print("auxi=",nn.Sequential(*list(self.auxi_network.children())))
This prints the layer-by-layer structure of the network.
To print the gradient of a specific weight:
print("auxi=",nn.Sequential(*list(self.auxi_network.children()))[0][4][1].conv1.weight.grad)
Note:
- Printing self.auxi_network.children() directly does not show the network structure, because children() returns a generator; only after wrapping it in nn.Sequential as shown above is the structure displayed (a quick check of this is given after the example below).
- As for how the network is built, you can either define a network class in the usual way and implement its forward method, or wrap modules such as nn.Conv2d, nn.ReLU and nn.Linear in nn.Sequential, e.g.:
model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU()
)
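Here is the quick check promised in the note above: printing children() only shows a generator object, while the nn.Sequential wrapper prints the actual layers (Toy is a hypothetical module used purely for illustration):

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv1(x))

toy = Toy()
print(toy.children())                        # <generator object Module.children at 0x...>
print(nn.Sequential(*list(toy.children())))  # Sequential((0): Conv2d(...), (1): ReLU())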
It is worth noting that nn.Sequential can also take any custom class that inherits from nn.Module, as long as it implements a forward method. Example:
Here the featurizer and the classifier are defined as:
self.featurizer = networks.Featurizer(input_shape, self.hparams)
self.classifier = networks.Classifier(self.featurizer.n_outputs, num_classes, self.hparams['nonlinear_classifier'])
The ResNet class that networks.Featurizer returns is a custom wrapper written by hand:
import torch
import torch.nn as nn
import torchvision

class ResNet(torch.nn.Module):
    """ResNet with the softmax chopped off and the batchnorm frozen"""
    def __init__(self, input_shape, hparams):
        super(ResNet, self).__init__()
        if hparams['resnet18']:
            self.network = torchvision.models.resnet18(pretrained=True)
            self.n_outputs = 512
        else:
            self.network = torchvision.models.resnet50(pretrained=True)
            self.n_outputs = 2048

        # Adapt the first conv layer when the input has a channel count other than 3
        nc = input_shape[0]
        if nc != 3:
            tmp = self.network.conv1.weight.data.clone()
            self.network.conv1 = nn.Conv2d(
                nc, 64, kernel_size=(7, 7),
                stride=(2, 2), padding=(3, 3), bias=False)
            for i in range(nc):
                self.network.conv1.weight.data[:, i, :, :] = tmp[:, i % 3, :, :]

        # save memory: replace the final fc layer with a pass-through module
        # (Identity is a small custom nn.Module defined elsewhere in the project,
        #  equivalent to nn.Identity)
        del self.network.fc
        self.network.fc = Identity()

        self.freeze_bn()
        self.hparams = hparams
        self.dropout = nn.Dropout(hparams['resnet_dropout'])

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        """
        super().train(mode)
        self.freeze_bn()

    def freeze_bn(self):
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
This works because, fundamentally, such custom classes inheriting from nn.Module are the same kind of thing as PyTorch's built-in classes like nn.Conv2d: each is a Module that takes an input and produces an output. Looking at the source code of nn.Sequential shows that it accepts both kinds of input.
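A minimal sketch of that idea, with a hypothetical TinyFeaturizer standing in for the project's ResNet featurizer:

import torch
import torch.nn as nn

class TinyFeaturizer(nn.Module):
    """Any class inheriting from nn.Module with a forward() can go into nn.Sequential."""
    def __init__(self, n_outputs=32):
        super().__init__()
        self.n_outputs = n_outputs
        self.conv = nn.Conv2d(3, n_outputs, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        return self.pool(self.conv(x)).flatten(1)

featurizer = TinyFeaturizer()
classifier = nn.Linear(featurizer.n_outputs, 10)   # e.g. 10 classes

# A custom Module and a built-in Module mix freely inside nn.Sequential.
network = nn.Sequential(featurizer, classifier)
print(network(torch.randn(4, 3, 64, 64)).shape)    # torch.Size([4, 10])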
Source code of nn.Sequential.__init__:
def __init__(self, *args):
    super(Sequential, self).__init__()
    if len(args) == 1 and isinstance(args[0], OrderedDict):
        # a single OrderedDict of (name, module) pairs
        for key, module in args[0].items():
            self.add_module(key, module)
    else:
        # any number of modules passed positionally (built-in or custom)
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)
As for how modules such as nn.Conv2d and nn.Linear relate to OrderedDict (every nn.Module keeps its submodules and parameters in OrderedDicts internally), refer to this article; it is not expanded on here.
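The two branches correspond to the two ways nn.Sequential can be constructed; a short illustration:

from collections import OrderedDict
import torch.nn as nn

# Branch 1: a single OrderedDict of (name, module) pairs -> layers keep those names
named = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 20, 5)),
    ('relu1', nn.ReLU()),
]))
print(named.conv1)    # accessed via the attribute name

# Branch 2: modules passed positionally -> layers are named '0', '1', ...
indexed = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU())
print(indexed[0])     # accessed via the index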
3. How to access individual layers. If the printed model shows a layer with a numeric label, such as (0), it can be accessed directly with the index [0]. If it is shown with a name, such as (conv1), it has to be accessed through the attribute name, i.e. .conv1.
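For instance, with a torchvision ResNet-18 (purely illustrative), the printed model mixes both forms, and the two access styles can be chained:

import torchvision

model = torchvision.models.resnet18()
print(model)                  # (conv1): Conv2d(...), (layer1): Sequential((0): BasicBlock(...), ...)

print(model.conv1)            # named layer -> access via the attribute name
print(model.layer1[0])        # numbered layer inside a Sequential -> access via the index [0]
print(model.layer1[0].conv1)  # the two styles chain naturally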