如果只是单纯的依靠 next(trainloader_iter) 获取数据,则有可能获取数据时候出现不能获取的情况数据为non,此时我们需要再次回访该数据集,再次读取。
代码如下:
#读取数据代码部分:
train_dataset = Dataset(args.data_dir, args.data_list, mean=IMG_MEAN)
train_dataset_size = len(train_dataset)
print ('dataset size: ', train_dataset_size)
trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
trainloader_iter = iter(trainloader)
....
# 运行 CNN training 部分
for i_iter in range(args.num_steps+1):
.....
##### 原始出错方案 #####
batch = next(trainloader_iter)
##### 改进方案 #####
try:
batch = next(trainloader_iter)
except:
trainloader_iter = iter(trainloader) # 再次读取,获取数据
batch = next(trainloader_iter)
......
images, labels, image_id = batch
images = Variable(images).cuda()
labels = labels.cuda()
......