最近项目拿到了一个别人已标注但尚未划分的数据集,共有13类。经过统计发现各类别的图片数目差距较大:最多的一类有五万多张图片,最少的一类只有两千多张。如果使用传统的划分方法,对所有数据进行随机划分,将会导致样本严重不均衡的问题,甚至可能出现训练集中不存在某一类图片的情况。因此考虑以最少的一类图片数目为基准,对每一类都选择两千张左右的图片,并使用蓄水池抽样算法保证选取的随机性。考虑到同一张图片中可能存在多个目标,并且这些目标不一定属于同一类,因此对每一张图片的标注文件只参考其第一个标注目标的类别(标注文件中可能没有任何标注目标,需要先进行判断)。最后,对每一类图片按照数据集划分的比例随机划分到训练集、验证集、测试集中。虽然无法保证最终划分的数据集中每一类图片数目非常相近,但差别不会太大,并且保证了训练集、验证集、测试集中每一类都会存在一定数目的图片。
import os
import xml.dom.minidom
import random
# The project root is resolved two directory levels above the current working directory.
master_root = os.path.abspath(os.path.join(os.getcwd(), "../../"))
data_root = os.path.join(master_root, "name of your dataset")  # e.g. os.path.join(master_root, "coco")

# Split lists (train/val/test) and the list of images without any annotated
# object all live in the same ImageSets directory.
ImageSets_path = os.path.join(data_root, "ImageSets/Main")
train_txt_path, val_txt_path, test_txt_path, none_tag_path = (
    os.path.join(ImageSets_path, txt_name)
    for txt_name in ("train.txt", "val.txt", "test.txt", "none_tag.txt")
)

# The trailing separator is kept so the value can be prefixed directly to a file name.
xml_path = os.path.join(data_root, "Annotations/")
# Every entry of the annotation directory; read once at import time.
files = os.listdir(xml_path)

# Class names as they appear in the <name> tag of the annotation files.
classes = ['classes of your dataset']
# classes = [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat']
# Reservoir sampling
def add(list, size, len, file):
    """Offer one item to a fixed-capacity reservoir, modifying it in place.

    Note that the parameter names ``list`` and ``len`` shadow the builtins
    inside this function, so neither builtin is used here.

    Args:
        list: The reservoir (a mutable sequence) to update.
        size: Capacity of the reservoir.
        len: Number of items supplied by the caller as "already seen"
            before this one; it is not derived from the reservoir itself.
        file: The item being offered.

    Returns:
        None. The reservoir is updated in place.
    """
    if not (len < size):
        # Reservoir is full: draw a slot from the inclusive range [0, len]
        # and overwrite only when the draw lands inside the reservoir.
        slot = random.randint(0, len)
        if slot < size:
            list[slot] = file
        return
    # Still filling up: every offered item is kept.
    list.append(file)
def create_imagesets_train_val_test(lists, traintxt_full_path, valtxt_full_path, testtxt_full_path,
                                    train_percent=0.6, val_percent=0.2):
    """Randomly split every class into train/val/test and write the three list files.

    Each class is split independently so that every split receives a share of
    every class. For a class with ``n`` entries, ``int(n * train_percent)``
    entries go to train, ``int(n * val_percent)`` to val, and all remaining
    entries to test. The merged entries of each split are sorted before being
    written.

    Args:
        lists: A sequence of lists, one per class. Entries are written
            verbatim, so each entry must already end with a newline. The
            input lists are not modified.
        traintxt_full_path: Output path of the training list file.
        valtxt_full_path: Output path of the validation list file.
        testtxt_full_path: Output path of the test list file.
        train_percent: Fraction of each class assigned to the training split.
        val_percent: Fraction of each class assigned to the validation split.
            The test split receives whatever is left.

    Raises:
        ValueError: If a fraction is negative or the two fractions sum to
            more than 1.
    """
    if train_percent < 0 or val_percent < 0 or train_percent + val_percent > 1:
        raise ValueError(
            "train_percent and val_percent must be non-negative and sum to at most 1"
        )

    train_entries = []
    val_entries = []
    test_entries = []
    for class_entries in lists:
        # Shuffle a copy and slice it: this yields the same kind of random
        # partition as sampling-and-removing, without mutating the caller's
        # list and without the quadratic cost of repeated list.remove().
        shuffled = random.sample(class_entries, len(class_entries))
        num_train = int(len(shuffled) * train_percent)
        num_val = int(len(shuffled) * val_percent)
        train_entries.extend(shuffled[:num_train])
        val_entries.extend(shuffled[num_train:num_train + num_val])
        test_entries.extend(shuffled[num_train + num_val:])

    outputs = (
        (traintxt_full_path, train_entries),
        (valtxt_full_path, val_entries),
        (testtxt_full_path, test_entries),
    )
    for txt_path, entries in outputs:
        # The context manager guarantees the file is closed even if a write fails.
        with open(txt_path, 'w') as txt_file:
            txt_file.writelines(sorted(entries))
if __name__ == '__main__':
    # One reservoir per class, filled by reservoir sampling below.
    lists = [[] for _ in classes]
    # Approximate number of images to keep per class; drawn per class so the
    # class sizes are close but not identical.
    sizes = [random.randint(1950, 2250) for _ in classes]
    # Number of images of each class offered to its reservoir so far.
    length = [0] * len(classes)
    # Images whose annotation file contains no object at all.
    none_tag = []

    # Walk through every annotation file.
    for file_name in files:
        xmlfile = os.path.join(xml_path, file_name)
        # Ignore sub-directories and stray non-XML entries instead of letting
        # the XML parser crash on them.
        if not (file_name.lower().endswith('.xml') and os.path.isfile(xmlfile)):
            continue
        entry = os.path.splitext(file_name)[0] + '\n'

        dom = xml.dom.minidom.parse(xmlfile)  # parse the XML document
        objectlist = dom.documentElement.getElementsByTagName("object")
        if not objectlist:
            none_tag.append(entry)
            continue

        # An annotated image is classified by its first annotated object only.
        namelist = objectlist[0].getElementsByTagName("name")
        objectname = namelist[0].childNodes[0].data
        if objectname in classes:
            cls_id = classes.index(objectname)
            # Reservoir sampling keeps a uniformly random subset of each class.
            add(lists[cls_id], sizes[cls_id], length[cls_id], entry)
            length[cls_id] += 1

    # Write none_tag.txt only after parsing succeeded; the context manager
    # closes the file even if a write fails.
    with open(none_tag_path, 'w') as none:
        none.writelines(none_tag)

    create_imagesets_train_val_test(lists, train_txt_path, val_txt_path, test_txt_path)