项目场景:
提示:这里简述项目相关背景:
最近在调试一个网络(3D V-Net),将网络的输入 输出形状改为需求的形状大小
问题描述
代码如下:
class InputTransition(nn.Module):
def __init__(self, outChans, elu):
super(InputTransition, self).__init__()
self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2)
self.bn1 = ContBatchNorm3d(16)
self.relu1 = ELUCons(elu, 16)
def forward3(self, x):
# do we want a PRELU here as well?
out = self.bn1(self.conv1(x))
# split input in to 16 channels
x16 = torch.cat((x, x, x, x, x, x, x, x, #把输入重复16次,变换为16通道
x, x, x, x, x, x, x, x), 0)
out = self.relu1(torch.add(out, x16))
return out
报错NotImplementedError
如下:
File "D:\Python3.7.8\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Python3.7.8\lib\site-packages\torch\nn\modules\module.py", line 201, in _forward_unimplemented
raise NotImplementedError
NotImplementedError
Process finished with exit code 1
原因分析:
提示:这里填写问题的分析:
通过阅览其他博主的博客后发现,在python定义类的时候可以预先留出一个接口不实现,而在后续继承的子类中实现,当我们想要提醒自己这个类的子类一定要实现这个接口时,可以调用NotImplementError
针对本个网络来说就是 在nn.Module父类中要求在子类中必须重写forward方法 但是子类中并没有重写这个方法
所以 报这种错的常见原因是:
1.没有重写forward方法
2.重写forward方法时 forward
单词的拼写有误 ,建议仔细检查一下拼写。 经常看到有的博主把 forward
拼写为 forword
或者像我这样的forward3
手欠在后面加了数字[doge]
解决方案:
提示:这里填写该问题的具体解决方案:
1.检查是否重写forward方法
2.检查重写forward方法时 forward
单词的拼写是否有误 ,建议仔细检查一下拼写。
class InputTransition(nn.Module):
def __init__(self, outChans, elu):
super(InputTransition, self).__init__()
self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2)
self.bn1 = ContBatchNorm3d(16)
self.relu1 = ELUCons(elu, 16)
def forward(self, x):
# do we want a PRELU here as well?
out = self.bn1(self.conv1(x))
# split input in to 16 channels
x16 = torch.cat((x, x, x, x, x, x, x, x, #把输入重复16次,变换为16通道
x, x, x, x, x, x, x, x), 0)
out = self.relu1(torch.add(out, x16))
return out
将InputTransition类中方法名称由原来的 forward3
改为 forward
后程序即可顺利运行