# ------------------------------------------------#
# Run this script before training to generate cls_train.txt
# ------------------------------------------------#
import os
import random
def txt_annotation(datasets_path, select_num):
    """Write a random sample of file paths per class folder to cls_train.txt.

    Every entry of ``datasets_path`` is treated as a candidate class folder.
    Entries are sorted by name and numbered with ``enumerate``, so a folder's
    class id is its position in that sorted listing. Skipped entries still
    consume an id, which can leave gaps in the ids that are written.

    For each folder holding at least ``select_num`` files, ``select_num`` of
    them are drawn with ``random.sample`` and written one per line as
    ``<cls_id>;<absolute path>``. Entries that are not directories, or that
    hold fewer than ``select_num`` files, are reported on stdout and skipped.

    :param datasets_path: root directory of the dataset (one folder per class)
    :param select_num: number of files to pick from each class folder
    :return: None; the result is the file ``cls_train.txt``, created or
        overwritten in the current working directory
    """
    # List the dataset first so a missing root fails before the output file
    # is created or truncated.
    types_name = sorted(os.listdir(datasets_path))
    # Loop-invariant: every written path shares this absolute prefix.
    dataset_root = os.path.abspath(datasets_path)

    # The context manager closes the file even if a folder cannot be read.
    with open('cls_train.txt', 'w') as list_file:
        for cls_id, type_name in enumerate(types_name):
            photos_path = os.path.join(datasets_path, type_name)

            # Plain files in the dataset root cannot be listed; skip them
            # here instead of letting os.listdir raise NotADirectoryError.
            if not os.path.isdir(photos_path):
                print('pass {} : not a folder'.format(type_name))
                continue

            # List the folder once and reuse the result for both the size
            # check and the sampling.
            photos_name = os.listdir(photos_path)
            if select_num > len(photos_name):
                print(
                    'pass folder {} with {} numbers , required {} '.format(
                        type_name, len(photos_name), select_num))
                continue

            for photo_name in random.sample(photos_name, select_num):
                photo_path = os.path.join(dataset_root, type_name, photo_name)
                list_file.write(str(cls_id) + ";" + photo_path + '\n')
if __name__ == "__main__":
    # Script entry point: sample 32 files from every class folder under
    # the relative directory 'test_dataset' and write cls_train.txt.
    txt_annotation(datasets_path=r'test_dataset', select_num=32)
# The result looks like this: