源代码:
class DatasetFromFolderEval(data.Dataset):
    """Evaluation dataset that serves the image files found in one folder.

    Each item is a triple ``(image, rescaled, file_name)``:

    * ``image`` is the result of ``load_img`` for the file,
    * ``rescaled`` is ``rescale_img(image, upscale_factor)``,
    * ``file_name`` is the base name of the file as a plain ``str``.

    Because ``file_name`` is a string rather than a tensor, code consuming
    batches from a ``DataLoader`` must not wrap the third batch element in a
    tensor/``Variable``.

    Args:
        lr_dir: Directory that is scanned (non-recursively) for image files.
        upscale_factor: Factor forwarded to ``rescale_img``.
        transform: Optional callable applied to both ``image`` and
            ``rescaled`` before they are returned.
    """

    def __init__(self, lr_dir, upscale_factor, transform=None):
        super(DatasetFromFolderEval, self).__init__()
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.image_filenames = sorted(
            join(lr_dir, name) for name in listdir(lr_dir) if is_image_file(name)
        )
        self.upscale_factor = upscale_factor
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, rescaled, file_name)`` for the file at ``index``."""
        path = self.image_filenames[index]
        image = load_img(path)
        file_name = os.path.basename(path)
        rescaled = rescale_img(image, self.upscale_factor)

        if self.transform:
            image = self.transform(image)
            rescaled = self.transform(rescaled)

        return image, rescaled, file_name

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.image_filenames)
# Build the evaluation set (the second argument is presumably the upscale
# factor -- confirm against get_eval_set) and iterate it one sample per batch.
test_set = get_eval_set(test_dataset, 2)
training_data_loader = DataLoader(test_set, batch_size=1)
出错地方:
# NOTE: this is the failing version. The dataset's __getitem__ returns a file
# name (a string) as its third element, so batch[2] is not a tensor; wrapping
# it in Variable(...) is the call reported as the error.
for iteration, batch in enumerate(training_data_loader, 1):
input, bicubic,_ = Variable(batch[0]), Variable(batch[1]), Variable(batch[2])
原因:
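因为 DatasetFromFolderEval 类的 __getitem__ 返回的是 input, bicubic, file,
其中 file 是文件名(字符串);经过 DataLoader 组成 batch 之后,batch[2] 是由字符串组成的元组,不是 Tensor,所以不能传给 Variable。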
解决办法:去掉 batch[2] 外面的 Variable(文件名不需要、也不能转换成 Variable):
# The third batch element holds file names, not tensors, so it is left unwrapped.
input, bicubic, _ = Variable(batch[0]), Variable(batch[1]), batch[2]