一、报错：AttributeError: 'Res_rec' object has no attribute '_parameters'
二、纠错前代码
class Res_com(nn.Module):
    """Original (buggy) version of Res_com, kept unchanged to illustrate the error.

    It stacks 2 * n_com Res_block layers followed by a 1x1 convolution to
    12 channels, but every layer is created inside forward(), so none of
    them are registered with the module.
    """
    # BUG: __init__ never calls super().__init__(), so nn.Module's internal
    # registries (_parameters, _modules, ...) are never created; calling
    # self.parameters() on an instance then fails with
    # "object has no attribute '_parameters'".
    def __init__(self, n_com=3, b_com=6, d_com=2, com_disable=False):
        # Build the per-block filter-count lists, each of length 2 * n:
        #   f1: powers of two rising from 2**b to 2**(b+n-1), then falling
        #       back to 2**b (the peak value appears twice);
        #   f2: f1 scaled by 2**d, with the middle entry removed and a final
        #       entry of 32 (if f1[0] > 32) or 16 appended.
        def make(n, b, d):
            f1 = [2 ** (b + i) for i in range(n)] + [2 ** (b + n - 1 - i) for i in range(n)]
            f2 = [i * (2 ** d) for i in f1]
            del f2[len(f2) // 2]
            f2_last = 32 if f1[0] > 32 else 16
            f2.append(f2_last)
            return f1, f2
        self.f1_com, self.f2_com = make(n_com, b_com, d_com)
        self.n_com = n_com
        self.com_disable = com_disable
    def forward(self,x):
        if self.com_disable:
            print("No implementation for com_disable")
            # BUG: a bare `raise` outside an except block fails with
            # RuntimeError("No active exception to reraise").
            raise
        else:
            # BUG: assuming x is a batched (N, C, H, W) tensor, x[1] is the
            # second sample of the batch, not the channel count (x.shape[1]).
            print("in_channels for class Res_com is :{}".format(x[1]))
            for i in range(self.n_com * 2):
                # BUG: a brand-new Res_block (freshly initialized weights) is
                # built on every forward call. It is never stored on self, so
                # it is not registered as a submodule: its weights are absent
                # from parameters() and are never trained.
                x = Res_block(x[1], [self.f1_com[i], self.f1_com[i], self.f2_com[i]])(x)
            # BUG: same problem for the final 1x1 convolution.
            out = nn.Conv2d(in_channels=x[1],out_channels=12,kernel_size=1,stride=1,padding=0)(x)
            return out
出错分析
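网络中含有可学习参数的层（nn.Conv2d、nn.BatchNorm2d 等，以及 Res_block）都是在 forward 中临时创建的：它们没有被注册为该模块的子模块，每次前向传播都会重新随机初始化，优化器也拿不到这些权重，因此网络实际上没有可以更新的权重参数。此外，报错信息 `has no attribute '_parameters'` 的直接原因是 `__init__` 中没有调用 `super().__init__()`：nn.Module 用来登记参数和子模块的内部字典（`_parameters`、`_modules` 等）是在父类的构造函数中创建的，不调用它，调用 `model.parameters()`（例如构建优化器时）就会因为找不到 `_parameters` 而报错。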
解决方案
正确的做法是：在 `__init__` 开头先调用 `super().__init__()`，然后把所有包含可学习参数的层都作为类的成员变量，在初始化方法 `__init__` 中创建好；forward 中只调用这些层，不再新建层。
三、纠错后代码
class Res_com(nn.Module):
    """Stack of 2 * n_com Res_block layers followed by a 1x1 conv to 12 channels.

    Every layer that owns trainable weights is created once in __init__ and
    stored as an attribute. nn.Module therefore registers it as a submodule,
    so its weights appear in parameters()/state_dict() and persist across
    forward calls.

    Args:
        n_com: half the number of Res_block layers; must be >= 1.
        b_com: base exponent; the first block's f1 filter count is 2 ** b_com.
        d_com: the f2 filter counts are the f1 counts scaled by 2 ** d_com.
        com_disable: not implemented; passing True raises NotImplementedError.
        in_channels: channel count of the tensor passed to forward().

    Raises:
        ValueError: if n_com < 1.
        NotImplementedError: if com_disable is True.
    """

    def __init__(self, n_com=3, b_com=6, d_com=2, com_disable=False, in_channels=3):
        # Must run first: it creates nn.Module's internal registries
        # (_parameters, _modules, ...). Without it, assigning a submodule
        # below raises AttributeError, and parameters() fails with
        # "object has no attribute '_parameters'".
        super().__init__()
        if n_com < 1:
            # make() below would otherwise fail with an opaque IndexError.
            raise ValueError("n_com must be >= 1, got {}".format(n_com))

        def make(n, b, d):
            """Return (f1, f2) filter-count lists, each of length 2 * n.

            f1 rises through powers of two from 2**b to 2**(b+n-1), then falls
            back to 2**b. f2 is f1 scaled by 2**d, with its middle entry
            removed and a final entry of 32 (if f1[0] > 32) or 16 appended.
            """
            f1 = [2 ** (b + i) for i in range(n)] + [2 ** (b + n - 1 - i) for i in range(n)]
            f2 = [i * (2 ** d) for i in f1]
            del f2[len(f2) // 2]
            f2_last = 32 if f1[0] > 32 else 16
            f2.append(f2_last)
            return f1, f2

        self.f1_com, self.f2_com = make(n_com, b_com, d_com)
        self.n_com = n_com
        self.com_disable = com_disable
        if self.com_disable:
            # A bare `raise` here would itself fail with RuntimeError
            # ("No active exception to reraise"); raise a meaningful error.
            raise NotImplementedError("No implementation for com_disable")

        layers = []
        channels = in_channels
        for i in range(self.n_com * 2):
            layers.append(Res_block(channels, [self.f1_com[i], self.f1_com[i], self.f2_com[i]]))
            # Assumes Res_block outputs filters[2] channels, i.e. f2_com[i];
            # TODO confirm against Res_block's definition.
            channels = self.f2_com[i]
        # nn.Sequential (capital S); torch.nn has no `sequential`.
        self.multi_res_block = nn.Sequential(*layers)
        # Store the layer itself, not its output. Its input width is the last
        # Res_block's output width (f2_com[-1]).
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=12, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Apply the Res_block stack and then the final 1x1 convolution.

        Assumes x is a batched (N, C, H, W) tensor with C == in_channels;
        TODO confirm with callers.
        """
        # x.shape[1] is the channel dimension; x[1] would be the second sample.
        print("in_channels for class Res_com is :{}".format(x.shape[1]))
        out = self.multi_res_block(x)
        out = self.conv1(out)
        return out