class FeatDataTrain(Dataset):
def __init__(self, file_list , batch_size):
# image_filenames - 音频路径集合
self.filenames = self.load_files( file_list)
self.batch_size = batch_size
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx):
file_name = self.filenames[idx]
dir = file_name[0]
x = cPickle.load(open(dir, 'rb'))
x = x.astype(np.float32)
y = int(dir.split('/')[-1].split('.')[0].split('_')[0][0])
y= y - 1 # 若num_classes=3,则y应为[0,1,2]
return x ,y
def load_files(self, csv_path):
filpath = []
if csv_path != "":
f = open(csv_path)
data = csv.reader(f) #
for line in data:
tmp = line
filpath.append(tmp)
else:
print('没有特征文件!')
return filpath
最后一个Batch长度不足,会导致输出维度发生问题.drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了
通过读取文件夹路径分配文件。
(2)
loader = DataLoader(dataset, num_workers=opt.num_workers, drop_last=True, shuffle=True, batch_size= opt.batch_size)
在模型测试阶段,如果shuffle=False,会出现意想不到的情况。
参考:
(2)model.fit以及model.fit_generator区别及用法_猫爱吃鱼the的博客-CSDN博客_model.fit_generator参数
(3)DataLoader详解_sereasuesue的博客-CSDN博客_dataloader
(4)PyTorch 1.x 常用知识_Arrow的博客-CSDN博客
(5)pytorch官方教程中文版(一)PyTorch介绍_努力永远不会欺骗人的博客-CSDN博客_pytorch官方教程中文版