前面的数据是采用PyTorch的ImageFolder读取的,读取后的数据进行了训练集数据集的划分,分别得到了image路径列表和label标签列表,再传入自己写的label_shuffling进行数据增强,返回后通过PyTorch的DataLoader加载到网络里。
def label_shuffling(data, labels):
#得到样本数最多的类别的样本数
maxNum_sample = Counter(labels).most_common(1)
#得到所有label,分别对每个label进行扩增
no_repeat_labels = len(set(labels))
data_list = []
label_list = []
#先对label 0 扩增,再对label 1,再对label 2
for i in range(no_repeat_labels):
current_data = []
current_label = []
for j in range(len(labels)):
if i == labels[j]:
current_label.append(labels[j])
current_data.append(data[j])
#根据最多的样本数,对每类都产生一个随机排列的列表
one_class_list = list(range(maxNum_sample[0][1]))
random.shuffle(one