最终的代码如下：
def main():
# parse arguments
args = parse_args()
if args is None:
exit()
if args.gpu_mode and not torch.cuda.is_available():
raise Exception("No GPU found, please run without --gpu_mode=False")
# print 'scale factor = ', scale_factor, \
# '\ntest_dir =', args.test_dataset,\
from network import Net_new4 as net
model = net(num_channels=1, scale_factor=4, d=32, s=5, m=1)
model.load_state_dict(torch.load(pretrained_model, map_location = torch.device('cpu')))
image_dir = args.test_dataset
image_filenames = [join(image_dir, x) for x in sorted(listdir(image_dir)) if is_image_file(x)]
file_num = len(image_filenames)
for idx in range(file_num):
img_ycbcr = Image.open(image_filenames[idx]).convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()
input_x_t = torch.from_numpy(numpy.zeros((1, 1, np.array(img_y).shape[0], np.array(img_y).shape[1]), dtype='f'))
#temp = torch.from_numpy(np.array(img_y))
input_x_t[0, 0, :, :] = torch.from_numpy(np.array(img_y))
#print(input_x.size())
recon_y = model(input_x_t).detach()
print(recon_y.size())
#print(recon_y[:,:,:5,:5])
#temp = recon_y.data.squeeze(0).