The error is as follows:
Traceback (most recent call last):
  File "main.py", line 23, in <module>
    t.train()
  File "c:\Paper Code\RCAN-master-Real\RCAN_TrainCode\code\trainer.py", line 51, in train
    sr = self.model(lr, idx_scale)
  File "C:\Anaconda3\envs\pytorch0.4.0\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "c:\Paper Code\RCAN-master-Real\RCAN_TrainCode\code\model\__init__.py", line 54, in forward
    return self.model(x)
  File "C:\Anaconda3\envs\pytorch0.4.0\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "c:\Paper Code\RCAN-master-Real\RCAN_TrainCode\code\model\rcan.py", line 107, in forward
    x = self.sub_mean(x)
  File "C:\Anaconda3\envs\pytorch0.4.0\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Anaconda3\envs\pytorch0.4.0\lib\site-packages\torch\nn\modules\conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: CUDNN_STATUS_EXECUTION_FAILED
Code before the change (main.py):
if __name__ == '__main__':
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    if checkpoint.ok:
        loader = data.Data(args)
        model = model.Model(args, checkpoint)
        loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, model, loss, checkpoint)
        while not t.terminate():
            t.train()
            t.test()
        checkpoint.done()
Code after the change (one line added to disable cuDNN):
if __name__ == '__main__':
    # Added: turn cuDNN off for the whole process to avoid CUDNN_STATUS_EXECUTION_FAILED
    torch.backends.cudnn.enabled = False
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    if checkpoint.ok:
        loader = data.Data(args)
        model = model.Model(args, checkpoint)
        loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, model, loss, checkpoint)
        while not t.terminate():
            t.train()
            t.test()
        checkpoint.done()
Training runs now~~
Notes for my own study.
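Setting `torch.backends.cudnn.enabled = False` at the top of main.py switches cuDNN off for the entire run, so every convolution falls back to PyTorch's own (typically slower) kernels. If you only want to bypass cuDNN around a particular call, a minimal sketch is a context manager that flips the flag and restores the previous value afterwards. `cudnn_disabled` is a hypothetical helper name, not part of PyTorch or the RCAN code, and the sketch assumes a PyTorch version where `torch.backends.cudnn.enabled` is writable (as it is in the 0.4.0 environment above).

```python
import contextlib

import torch


@contextlib.contextmanager
def cudnn_disabled():
    """Temporarily turn cuDNN off, restoring the previous setting on exit.

    Hypothetical helper for illustration; not part of PyTorch or RCAN.
    """
    previous = torch.backends.cudnn.enabled
    torch.backends.cudnn.enabled = False
    try:
        yield
    finally:
        # Restore the flag even if the wrapped code raises.
        torch.backends.cudnn.enabled = previous


if __name__ == '__main__':
    conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
    x = torch.randn(1, 3, 8, 8)
    if torch.cuda.is_available():
        conv, x = conv.cuda(), x.cuda()
    with cudnn_disabled():
        # Convolutions inside this block do not go through cuDNN.
        y = conv(x)
    print(y.shape)
```

In this project it could, for example, wrap `sr = self.model(lr, idx_scale)` in trainer.py instead of disabling cuDNN globally, though the one-line global switch is the simpler fix.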