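# 生成数据集 train/val/test 列表文件(前提是 train 与 test 图片已分别放在 train_pic/ 与 test_pic/ 中, val 从 train 中划分)
# Generate the dataset's train/val/test list files. Assumes the train and test
# images are already separated (train_pic/ vs test_pic/); val is split off from train.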
import numpy as np
import os
base = './dataset/road/'
NUM_CLASSES = 20   # class ids 0 .. NUM_CLASSES-1 are written to labels.txt
VAL_RATIO = 0.1    # fraction of train_pic/ images held out for validation


def _annotation_name(img_name):
    """Return the annotation file name for image *img_name*.

    A trailing ``.jpg`` extension is swapped for ``.png``; any other name is
    returned unchanged. Only the extension is touched, so a ``.jpg`` that
    appears elsewhere in the file name is preserved.
    """
    root, ext = os.path.splitext(img_name)
    return root + '.png' if ext == '.jpg' else img_name


def _write_list(list_name, names, pic_dir, tag_dir, crop=False):
    """Write one "<image path> <annotation path>" line per name to base/list_name.

    Paths written to the list are relative to ``base``
    (e.g. ``train_pic/x.jpg train_tag/x.png``).

    Args:
        list_name: output file name, created inside ``base``.
        names: image file names found in ``pic_dir``.
        pic_dir: image sub-directory of ``base``.
        tag_dir: annotation sub-directory of ``base``.
        crop: when true, ``crop_data`` is called on the image and its
            annotation before the line is written.
    """
    with open(os.path.join(base, list_name), 'w') as f:
        for pt in names:
            img = pic_dir + '/' + pt
            ann = tag_dir + '/' + _annotation_name(pt)
            if crop:
                # NOTE(review): crop_data is defined outside this view and is
                # called with identical source and destination, so it
                # presumably overwrites the file in place; re-running this
                # script may crop already-cropped files -- confirm idempotence.
                crop_data(os.path.join(base, img), os.path.join(base, img))
                crop_data(os.path.join(base, ann), os.path.join(base, ann))
            f.write(img + ' ' + ann + '\n')


# labels.txt: one class id per line, 0 .. NUM_CLASSES-1.
with open(os.path.join(base, 'labels.txt'), 'w') as f:
    for i in range(NUM_CLASSES):
        f.write(str(i) + '\n')

# Sort first: os.listdir order is filesystem-dependent, so without sorting the
# seeded shuffle would not give a reproducible split across machines.
imgs = sorted(os.listdir(os.path.join(base, 'train_pic')))
np.random.seed(42)
np.random.shuffle(imgs)

# Hold out the last val_num shuffled images for validation. Slicing at an
# explicit index (rather than imgs[:-val_num]) keeps every image in train when
# val_num == 0 (fewer than 10 images); imgs[:-0] would be an empty list.
val_num = int(VAL_RATIO * len(imgs))
n_train = len(imgs) - val_num
_write_list('train_list.txt', imgs[:n_train], 'train_pic', 'train_tag', crop=True)
_write_list('val_list.txt', imgs[n_train:], 'train_pic', 'train_tag', crop=True)

# Test images are listed without calling crop_data, unlike train/val.
# NOTE(review): confirm that skipping the crop for test data is intended.
_write_list('test_list.txt', sorted(os.listdir(os.path.join(base, 'test_pic'))),
            'test_pic', 'test_tag')