Deep Learning: VOC Dataset & Processing Utility Class
#!/usr/bin/env python
# encoding: utf-8
'''
@Author : pentiumCM
@Email : 842679178@
@Software: PyCharm
@File : voc_data_util.py
@Time : 2021/6/30 21:31
@desc : VOC dataset utility class
'''
import os
import random
import xml.dom.minidom as xmldom
import xml.etree.ElementTree as ET


class VOCDataUtil:
    """
    VOC dataset utility class
    """

    def __init__(self, dataset_path,
                 JPEGImages_dir='JPEGImages', annotation_dir='Annotations', ImageSets_dir='ImageSets',
                 label_type='xyxy'):
        """
        Initialize a VOC-format dataset.
        :param dataset_path: root directory of the dataset
        :param JPEGImages_dir: image directory
        :param annotation_dir: annotation directory
        :param ImageSets_dir: dataset-split directory
        :param label_type: annotation parsing mode, xyxy (SSD) / xywh (YOLO)
        """
        self.dataset_path = dataset_path
        self.JPEGImages = os.path.join(self.dataset_path, JPEGImages_dir)
        self.Annotations = os.path.join(self.dataset_path, annotation_dir)
        self.ImageSets = os.path.join(self.dataset_path, ImageSets_dir)
        self.label_type = label_type

        assert os.path.exists(self.JPEGImages)
        assert os.path.exists(self.Annotations)
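
    # Expected VOC-style layout under dataset_path (for reference only; directory names are the
    # constructor defaults and can be overridden):
    #   dataset_path/
    #       JPEGImages/      xxx.jpg   -- images
    #       Annotations/     xxx.xml   -- Pascal VOC XML annotations
    #       ImageSets/Main/  trainval.txt / train.txt / val.txt / test.txt -- splits written by dataset_divide()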

    def fill_blanklabel(self):
        """
        Create empty (background) annotation files for images that have no XML annotation.
        :return:
        """
        JPEGImages_names = [i.split('.')[0] for i in os.listdir(self.JPEGImages)]
        Annotations_names = [i.split('.')[0] for i in os.listdir(self.Annotations)]
        # images that exist in JPEGImages but have no counterpart in Annotations
        background_list = list(set(JPEGImages_names).difference(set(Annotations_names)))
        for item in background_list:
            file_name = item + '.xml'
            file_path = os.path.join(self.Annotations, file_name)
            # create an empty annotation file and close it right away
            with open(file_path, "w"):
                pass
            print('process:', file_path)

    def remove_blanklabel(self):
        """
        Delete empty (background) annotation files.
        :return:
        """
        Annotations_names = [os.path.join(self.Annotations, i) for i in os.listdir(self.Annotations)]
        for item in Annotations_names:
            file_size = os.path.getsize(item)
            if file_size == 0:
                os.remove(item)
                print('process:', item)

    def voc_label_statistics(self):
        """
        Count how often each class appears in the VOC annotations.
        :return: {'class1': count, ...}
        """
        count = 0
        annotation_names = [os.path.join(self.Annotations, i) for i in os.listdir(self.Annotations)]
        labels = dict()
        for names in annotation_names:
            names_arr = names.split('.')
            file_type = names_arr[-1]
            if file_type != 'xml':
                continue
            # skip empty (background) annotation files
            file_size = os.path.getsize(names)
            if file_size == 0:
                continue
            count = count + 1
            print('process:', names)
            xmlfilepath = names
            domobj = xmldom.parse(xmlfilepath)
            # document (root) element
            elementobj = domobj.documentElement
            # child <object> tags
            subElementObj = elementobj.getElementsByTagName("object")
            for s in subElementObj:
                label = s.getElementsByTagName("name")[0].firstChild.data
                label_count = labels.get(label, 0)
                labels[label] = label_count + 1
        print('number of annotated files:', count)
        return labels
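
    # Illustrative return value only (class names and counts are assumptions, not from a real run):
    #   {'car': 120, 'dog': 14, 'person': 86}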

    def dataset_divide(self, train_percent=0.8, trainval_percent=1):
        """
        1. Split the dataset into [(train, val), test].
        Writes trainval.txt, train.txt, val.txt and test.txt under ImageSets/Main/.
        :param train_percent: train_percent = train / (train + val)
        :param trainval_percent: trainval_percent = (train + val) / whole dataset; the test fraction is 1 - trainval_percent
        :return:
        """
        # directory the split files are saved to: ImageSets/Main/
        ds_divide_save__path = os.path.join(self.ImageSets, 'Main')
        if not os.path.exists(ds_divide_save__path):
            os.makedirs(ds_divide_save__path)

        # train_percent: fraction of (train + val) used for training
        # train_percent = 0.8
        # trainval_percent: fraction of the whole dataset used for (train + val); test fraction is 1 - trainval_percent
        # trainval_percent = 1
        temp_xml = os.listdir(self.Annotations)
        total_xml = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)

        num = len(total_xml)
        indices = range(num)
        tv_len = int(num * trainval_percent)
        tr_len = int(tv_len * train_percent)
        test_len = int(num - tv_len)

        trainval = random.sample(indices, tv_len)
        train = random.sample(trainval, tr_len)

        print("train and val size:", tv_len)
        print("train size:", tr_len)
        print("test size:", test_len)

        ftrainval = open(os.path.join(ds_divide_save__path, 'trainval.txt'), 'w')
        ftest = open(os.path.join(ds_divide_save__path, 'test.txt'), 'w')
        ftrain = open(os.path.join(ds_divide_save__path, 'train.txt'), 'w')
        fval = open(os.path.join(ds_divide_save__path, 'val.txt'), 'w')

        for i in indices:
            name = total_xml[i][:-4] + '\n'
            if i in trainval:
                ftrainval.write(name)
                if i in train:
                    ftrain.write(name)
                else:
                    fval.write(name)
            else:
                ftest.write(name)

        ftrainval.close()
        ftrain.close()
        fval.close()
        ftest.close()
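
    # Worked example (assuming 100 annotation files) with the defaults train_percent=0.8, trainval_percent=1:
    #   trainval.txt gets 100 ids, train.txt 80, val.txt 20, test.txt 0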

    def convert(self, size, box):
        """
        Convert a VOC annotation to a YOLO annotation, i.e. xyxy -> xywh.
        :param size: image size (w, h)
        :param box: bounding box (xmin, ymin, xmax, ymax)
        :return: (x_center, y_center, w, h), normalized to [0, 1]
        """
        dw = 1. / size[0]
        dh = 1. / size[1]
        # center coordinates
        x = (box[0] + box[2]) / 2.0
        y = (box[1] + box[3]) / 2.0
        # width and height
        w = box[2] - box[0]
        h = box[3] - box[1]
        # normalize
        x = x * dw
        w = w * dw
        y = y * dh
        h = h * dh
        return (x, y, w, h)
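
    # Worked example with assumed values: for a 640x480 image and box (100, 120, 300, 360),
    # convert((640, 480), (100, 120, 300, 360)) returns
    #   x = (100 + 300) / 2 / 640 = 0.3125, y = (120 + 360) / 2 / 480 = 0.5,
    #   w = (300 - 100) / 640 = 0.3125,     h = (360 - 120) / 480 = 0.5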

    def xml_annotation(self, image_id, out_file, classes):
        """
        Parse one VOC XML annotation file and append its boxes to the summary file.
        :param image_id: id (file name stem) of the annotated image
        :param out_file: output file that aggregates the dataset annotations
        :param classes: dataset classes
        :return:
        """
        # this check lets empty (background) annotation files take part in training
        annotation_filepath = os.path.join(self.Annotations, image_id + '.xml')
        file_size = os.path.getsize(annotation_filepath)
        if file_size > 0:
            in_file = open(annotation_filepath, encoding='utf-8')
            tree = ET.parse(in_file)
            root = tree.getroot()
            size = root.find('size')
            w = int(size.find('width').text)
            h = int(size.find('height').text)
            for obj in root.iter('object'):
                difficult = obj.find('difficult').text
                cls = obj.find('name').text
                if cls not in classes or int(difficult) == 1:
                    continue
                cls_id = classes.index(cls)
                xmlbox = obj.find('bndbox')
                b = (
                    int(float(xmlbox.find('xmin').text)),
                    int(float(xmlbox.find('ymin').text)),
                    int(float(xmlbox.find('xmax').text)),
                    int(float(xmlbox.find('ymax').text)))
                b_box = b
                if self.label_type == 'xywh':
                    # annotation format: x,y,w,h,label
                    b_box = self.convert((w, h), b)
                elif self.label_type == 'xyxy':
                    # annotation format: x,y,x,y,label
                    pass
                # yolo: out_file.write(str(cls_id) + " " + " ".join([str(a) for a in b_box]) + '\n')
                out_file.write(" " + ",".join([str(a) for a in b_box]) + ',' + str(cls_id))
            in_file.close()

    def convert_annotation(self, classes):
        """
        2. Process the annotation results: aggregate the dataset annotations into txt files.
        :param classes: dataset classes
        :return:
        """
        sets = ['train', 'test', 'val']
        for image_set in sets:
            # dataset split file: ImageSets/Main/<set>.txt
            ds_divide_file = os.path.join(self.ImageSets, 'Main', '%s.txt' % image_set)
            # image ids of the (train / val / test) split
            image_ids = open(ds_divide_file).read().strip().split()
            list_file = open(os.path.join(self.dataset_path, '%s.txt' % image_set), 'w', encoding='utf-8')
            for image_id in image_ids:
                # image path (assumes .jpg files under JPEGImages)
                list_file.write(os.path.join(self.JPEGImages, '%s.jpg' % image_id))
                self.xml_annotation(image_id, list_file, classes)
                list_file.write('\n')
            list_file.close()
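
    # Each line of the generated <dataset_path>/train.txt (and test.txt / val.txt) looks like
    # (path and numbers are illustrative only):
    #   E:/.../VOC2007/JPEGImages/000001.jpg 48,240,195,371,0 8,12,352,498,1
    # i.e. the image path followed by space-separated boxes, each written as xmin,ymin,xmax,ymax,cls_id
    # (or as normalized x,y,w,h,cls_id when label_type='xywh').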


if __name__ == '__main__':
    # VOC_CLASS = [
    #     "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog",
    #     "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"
    # ]
    vocutil = VOCDataUtil(dataset_path='E:/project/nongda/code/ssd-keras-nd/VOCdevkit/VOC2007', label_type='xyxy')

    # 1. collect the annotation classes of the dataset
    label_dict = vocutil.voc_label_statistics()
    cus_classes = []
    for key in label_dict.keys():
        cus_classes.append(key)
    # sort class names in ascending string order
    cus_classes.sort()

    # 2. split the dataset
    vocutil.dataset_divide()

    # 3. process the annotation results
    vocutil.convert_annotation(cus_classes)
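
    # Optional (a sketch, not part of the original flow): for datasets that contain
    # pure-background images, the blank-label helpers could be slotted in around step 1:
    #     vocutil.fill_blanklabel()     # create empty .xml files so background images are kept
    #     ...                           # statistics / divide / convert as above
    #     vocutil.remove_blanklabel()   # delete the empty annotation files again if unwanted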