一、nn.Module的children()方法与modules()方法的区别
children()与modules()都是返回网络模型里的组成元素,但是children()返回的是最外层的元素,modules()返回的是所有的元素,包括不同级别的子元素。
首先定义以下全连接网络:
import torch
from torch import nn
class SimpleNet(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim ):
super().__init__()
self.layer1 = nn.Sequential(
nn.Linear(in_dim, n_hidden_1),
nn.ReLU(),
)
self.layer2 = nn.Sequential(
nn.Linear(n_hidden_1, n_hidden_2),
nn.ReLU(),
)
self.layer3 = nn.Linear(n_hidden_2, out_dim)
def forward(self, x):
x = self.layer1(x),
x = self.layer2(x),
x = self.layer3(x)
return x
if __name__ == "__main__":
net = SimpleNet(2, 3, 3, 2)
print(net)
测试运行,结果如下:
可以看到这个网络的结构如下:
1.1 Module类的children()方法
children()方法返回的是最外层,也就是1,2,3这三个。
Module.children()是一个生成器,生成器是一种迭代器。迭代器实现了__iter__() 和__next__()方法。迭代器肯定是可迭代对象,可迭代对象就能放在for x in ...后面进行遍历。
例:
import torch
from torch import nn
class SimpleNet(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim ):
super().__init__()
self.layer1 = nn.Sequential(
nn.Linear(in_dim, n_hidden_1),
nn.ReLU(),
)
self.layer2 = nn.Sequential(
nn.Linear(n_hidden_1, n_hidden_2),
nn.ReLU(),
)
self.layer3 = nn.Linear(n_hidden_2, out_dim)
def forward(self, x):
x = self.layer1(x),
x = self.layer2(x),
x = self.layer3(x)
return x
if __name__ == "__main__":
net = SimpleNet(2, 3, 3, 2)
print(net.children()) #net.children()是一个生成器,生成器是一种迭代器
for i, e in enumerate(net.children()):
print("第{}个元素为:\n {}".format(i, e))
结果:
也就是输入了第一层的元素1,2,3。
1.2 Module类的modules()方法
modules()方法类似与深度优先遍历,不光返回的是最外层。
Module.modules()也是一个生成器。
import torch
from torch import nn
class SimpleNet(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim ):
super().__init__()
self.layer1 = nn.Sequential(
nn.Linear(in_dim, n_hidden_1),
nn.ReLU(),
)
self.layer2 = nn.Sequential(
nn.Linear(n_hidden_1, n_hidden_2),
nn.ReLU(),
)
self.layer3 = nn.Linear(n_hidden_2, out_dim)
def forward(self, x):
x = self.layer1(x),
x = self.layer2(x),
x = self.layer3(x)
return x
if __name__ == "__main__":
net = SimpleNet(2, 3, 3, 2)
print(net.modules()) #net.modules()是一个生成器,生成器是一种迭代器
for i, e in enumerate(net.modules()):
print("第{}个元素为:\n {}".format(i, e))
结果:
即,按照以下顺序进行返回的。
二、如何获取网络的某些层
可以借助children()方法来获取网络的某些层,比如只要经典网络的前几层,后面的层不要了。
比如,resnet18:
import torchvision.models as models
Resnet = models.resnet18(pretrained=False)
print(Resnet)
结果:
D:\Anaconda3\envs\pytorch_env\python.exe D:/pythonCodes/深度学习实验/行人重识别实验1:IDENet/aaa.py
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentu