Deep Learning - VOC Dataset & Processing Utility Class

Time: 2025-02-11 14:19:22
#!/usr/bin/env python
# encoding: utf-8
'''
@Author  : pentiumCM
@Email   : 842679178@
@Software: PyCharm
@File    : voc_data_util.py
@Time    : 2021/6/30 21:31
@desc    : VOC dataset utility class
'''

import os
import random
import xml.dom.minidom as xmldom
import xml.etree.ElementTree as ET


class VOCDataUtil:
    """
    VOC dataset utility class
    """

    def __init__(self, dataset_path, JPEGImages_dir='JPEGImages', annotation_dir='Annotations',
                 ImageSets_dir='ImageSets', label_type='xyxy'):
        """
        Initialise a VOC-format dataset
        :param dataset_path: root directory of the dataset
        :param JPEGImages_dir: image folder
        :param annotation_dir: annotation folder
        :param ImageSets_dir: dataset-split folder
        :param label_type: annotation parsing mode, xyxy (ssd) / xywh (yolo)
        """
        self.dataset_path = dataset_path
        self.JPEGImages = os.path.join(self.dataset_path, JPEGImages_dir)
        self.Annotations = os.path.join(self.dataset_path, annotation_dir)
        self.ImageSets = os.path.join(self.dataset_path, ImageSets_dir)
        self.label_type = label_type

        assert os.path.exists(self.JPEGImages)
        assert os.path.exists(self.Annotations)

    def fill_blanklabel(self):
        """
        Create empty (background) annotation files
        :return:
        """
        JPEGImages_names = [i.split('.')[0] for i in os.listdir(self.JPEGImages)]
        Annotations_names = [i.split('.')[0] for i in os.listdir(self.Annotations)]

        # images that exist in JPEGImages but have no annotation file yet
        background_list = list(set(JPEGImages_names).difference(set(Annotations_names)))

        for item in background_list:
            file_name = item + '.xml'
            file_path = os.path.join(self.Annotations, file_name)
            open(file_path, 'w').close()
            print('process:', file_path)

    def remove_blanklabel(self):
        """
        Remove empty (background) annotation files
        :return:
        """
        Annotations_names = [os.path.join(self.Annotations, i) for i in os.listdir(self.Annotations)]

        for item in Annotations_names:
            file_size = os.path.getsize(item)
            if file_size == 0:
                os.remove(item)
                print('process:', item)

    def voc_label_statistics(self):
        """
        Class statistics for the VOC dataset
        :return: {'class1': count, ...}
        """
        count = 0
        annotation_names = [os.path.join(self.Annotations, i) for i in os.listdir(self.Annotations)]

        labels = dict()
        for names in annotation_names:
            names_arr = names.split('.')
            file_type = names_arr[-1]
            if file_type != 'xml':
                continue

            # skip empty (background) annotation files
            file_size = os.path.getsize(names)
            if file_size == 0:
                continue

            count = count + 1
            print('process:', names)
            xmlfilepath = names
            domobj = xmldom.parse(xmlfilepath)
            # document element of the DOM
            elementobj = domobj.documentElement
            # all <object> child tags
            subElementObj = elementobj.getElementsByTagName("object")
            for s in subElementObj:
                label = s.getElementsByTagName("name")[0].firstChild.data
                label_count = labels.get(label, 0)
                labels[label] = label_count + 1
        print('annotated files:', count)
        return labels

    def dataset_divide(self, train_percent=0.8, trainval_percent=1):
        """
        1. Split the dataset: [(train, val), test]
        Generates trainval.txt, test.txt, train.txt and val.txt under ImageSets/Main/
        :param train_percent: train_percent = train / (train + val)
        :param trainval_percent: trainval_percent = (train + val) / whole dataset; the test share is 1 - trainval_percent
        :return:
        """
        # the split files are saved under ImageSets/Main/
        ds_divide_save_path = os.path.join(self.ImageSets, 'Main')
        if not os.path.exists(ds_divide_save_path):
            os.makedirs(ds_divide_save_path)

        # train_percent: share of train within (train + val)
        # train_percent = 0.8
        # trainval_percent: share of (train + val) within the whole dataset; the test share is 1 - trainval_percent
        # trainval_percent = 1

        temp_xml = os.listdir(self.Annotations)
        total_xml = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)

        num = len(total_xml)
        indices = range(num)
        tv_len = int(num * trainval_percent)
        tr_len = int(tv_len * train_percent)
        test_len = int(num - tv_len)
        trainval = random.sample(indices, tv_len)
        train = random.sample(trainval, tr_len)

        print("train and val size:", tv_len)
        print("train size:", tr_len)
        print("test size:", test_len)
        ftrainval = open(os.path.join(ds_divide_save_path, 'trainval.txt'), 'w')
        ftest = open(os.path.join(ds_divide_save_path, 'test.txt'), 'w')
        ftrain = open(os.path.join(ds_divide_save_path, 'train.txt'), 'w')
        fval = open(os.path.join(ds_divide_save_path, 'val.txt'), 'w')

        for i in indices:
            name = total_xml[i][:-4] + '\n'
            if i in trainval:
                ftrainval.write(name)
                if i in train:
                    ftrain.write(name)
                else:
                    fval.write(name)
            else:
                ftest.write(name)

        ftrainval.close()
        ftrain.close()
        fval.close()
        ftest.close()

    def convert(self, size, box):
        """
        Convert a VOC annotation into a yolo annotation, i.e. xyxy -> xywh
        :param size: image size (w, h)
        :param box: bounding box (xyxy)
        :return:
        """
        dw = 1. / size[0]
        dh = 1. / size[1]
        # center point
        x = (box[0] + box[2]) / 2.0
        y = (box[1] + box[3]) / 2.0
        # width and height
        w = box[2] - box[0]
        h = box[3] - box[1]
        # normalise by image size
        x = x * dw
        w = w * dw
        y = y * dh
        h = h * dh
        return (x, y, w, h)

    def xml_annotation(self, image_id, out_file, classes):
        """
        Parse a VOC xml annotation file
        :param image_id: index of the annotation of one image
        :param out_file: output file that aggregates the dataset annotations
        :param classes: dataset classes
        :return:
        """
        # the size check lets empty (background) annotation files take part in training
        annotation_filepath = os.path.join(self.Annotations, image_id + '.xml')
        file_size = os.path.getsize(annotation_filepath)
        if file_size > 0:
            in_file = open(annotation_filepath, encoding='utf-8')
            tree = ET.parse(in_file)
            in_file.close()
            root = tree.getroot()
            size = root.find('size')
            w = int(size.find('width').text)
            h = int(size.find('height').text)

            for obj in root.iter('object'):
                difficult = obj.find('difficult').text
                cls = obj.find('name').text
                if cls not in classes or int(difficult) == 1:
                    continue
                cls_id = classes.index(cls)
                xmlbox = obj.find('bndbox')
                b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)),
                     int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
                b_box = b
                if self.label_type == 'xywh':
                    # annotation format: x,y,w,h,label
                    b_box = self.convert((w, h), b)
                elif self.label_type == 'xyxy':
                    # annotation format: x,y,x,y,label
                    pass
                # yolo: out_file.write(str(cls_id) + " " + " ".join([str(a) for a in b_box]) + '\n')
                out_file.write(" " + ",".join([str(a) for a in b_box]) + ',' + str(cls_id))

    def convert_annotation(self, classes):
        """
        2. Process the annotation results: aggregate the dataset annotations into txt files
        :param classes: dataset classes
        :return:
        """
        sets = ['train', 'test', 'val']
        for image_set in sets:
            # dataset split file: ImageSets/Main/<set>.txt
            ds_divide_file = os.path.join(self.ImageSets, 'Main', '%s.txt' % image_set)
            # image ids of the (train / val / test) split
            with open(ds_divide_file) as f:
                image_ids = f.read().strip().split()

            list_file = open(os.path.join(self.dataset_path, '%s.txt' % image_set), 'w', encoding='utf-8')
            for image_id in image_ids:
                list_file.write(os.path.join(self.JPEGImages, '%s.jpg' % image_id))
                self.xml_annotation(image_id, list_file, classes)
                list_file.write('\n')
            list_file.close()


if __name__ == '__main__':
    # VOC_CLASS = [
    #     "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog",
    #     "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"
    # ]

    vocutil = VOCDataUtil(dataset_path='E:/project/nongda/code/ssd-keras-nd/VOCdevkit/VOC2007', label_type='xyxy')

    # 1. collect the annotation classes present in the dataset
    label_dict = vocutil.voc_label_statistics()
    cus_classes = []
    for key in label_dict.keys():
        cus_classes.append(key)
    # sort class names alphabetically
    cus_classes.sort()

    # 2. split the dataset
    vocutil.dataset_divide()

    # 3. process the annotation results
    vocutil.convert_annotation(cus_classes)
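For orientation, the constructor defaults imply the usual VOC2007 layout sketched below. The folder names come straight from the __init__ parameters; the tree itself is only a sketch, not part of the original script:

    VOC2007/
    ├── Annotations/      # one XML per image; empty files stand for pure-background images
    ├── ImageSets/
    │   └── Main/         # trainval.txt / train.txt / val.txt / test.txt written by dataset_divide()
    └── JPEGImages/       # the .jpg images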
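A minimal sketch of the xyxy -> xywh conversion done by convert(). The image size and box values are made up for illustration, and the module is assumed to be saved as voc_data_util.py (per the @File header above):

    from voc_data_util import VOCDataUtil

    # convert() never touches self, so it can be exercised without a real VOC folder
    # (the constructor would otherwise assert that JPEGImages/Annotations exist).
    box_xyxy = (100, 200, 300, 400)   # xmin, ymin, xmax, ymax in pixels (made-up values)
    x, y, w, h = VOCDataUtil.convert(None, (640, 480), box_xyxy)
    # center (200, 300) and size (200, 200), normalised by image width/height:
    print(x, y, w, h)                 # approx. 0.3125 0.625 0.3125 0.4167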
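For reference, each line that convert_annotation() writes to train.txt / val.txt / test.txt starts with the image path and is followed by one "xmin,ymin,xmax,ymax,class_id" group per object; the path, coordinates and class index below are illustrative only:

    VOCdevkit/VOC2007/JPEGImages/000005.jpg 263,211,324,339,8 165,264,253,372,8

With label_type='xywh' the four coordinates are the normalised center-x, center-y, width and height returned by convert() instead of pixel values, and the commented-out line in xml_annotation() shows the space-separated variant used by some YOLO tooling.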