介绍
self.modules()是继承torch.nn.Modules()的类拥有的方法,以迭代器形式返回此前声明的所有layers
实验
代码
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.lls = nn.Sequential(self.relu,nn.Linear(10,10))
self.lls1 = nn.Sequential(nn.Linear(10,19),nn.Linear(10,10))
for i,j in enumerate(self.modules()):
print(i,"->:",j)
结果
- 这是一个深度优先遍历,遇到Sequential会继续深入
- 如果有layer在之前已经声明过,则不再添加,就像上图第六个,Sequential里的relu,在第三个已经打印,在之后不会打印,也就是说这是一个集合,没有重复元素。