总结:nn.Module的children()与modules()方法、如何获取网络的某些层

本文详细介绍了PyTorch中nn.Module的children()与modules()方法的区别,children()返回最外层模块,而modules()返回所有模块,包括子模块。通过实例展示了如何利用这些方法获取网络的特定层,如ResNet18模型中除最后两层外的所有层。这对于理解和操作复杂网络结构十分有用。
摘要由CSDN通过智能技术生成

一、nn.Module的children()方法与modules()方法的区别

children()与modules()都是返回网络模型里的组成元素,但是children()返回的是最外层的元素,modules()返回的是所有的元素,包括不同级别的子元素。

首先定义以下全连接网络:

import torch
from torch import nn


class SimpleNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim ):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(),
        )
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        x = self.layer1(x),
        x = self.layer2(x),
        x = self.layer3(x)

        return x

if __name__ == "__main__":
    net = SimpleNet(2, 3, 3, 2)
    print(net)

测试运行,结果如下:

可以看到这个网络的结构如下:

1.1 Module类的children()方法

children()方法返回的是最外层,也就是1,2,3这三个。

Module.children()是一个生成器,生成器是一种迭代器。迭代器实现了__iter__() 和__next__()方法。迭代器肯定是可迭代对象,可迭代对象就能放在for x in ...后面进行遍历。

例:

import torch
from torch import nn


class SimpleNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim ):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(),
        )
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        x = self.layer1(x),
        x = self.layer2(x),
        x = self.layer3(x)

        return x

if __name__ == "__main__":
    net = SimpleNet(2, 3, 3, 2)
    print(net.children())          #net.children()是一个生成器,生成器是一种迭代器
    for i, e in enumerate(net.children()):
        print("第{}个元素为:\n {}".format(i, e))

结果:

也就是输入了第一层的元素1,2,3。

1.2 Module类的modules()方法

modules()方法类似与深度优先遍历,不光返回的是最外层。

Module.modules()也是一个生成器。

import torch
from torch import nn


class SimpleNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim ):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(),
        )
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        x = self.layer1(x),
        x = self.layer2(x),
        x = self.layer3(x)

        return x

if __name__ == "__main__":
    net = SimpleNet(2, 3, 3, 2)
    print(net.modules())          #net.modules()是一个生成器,生成器是一种迭代器
    for i, e in enumerate(net.modules()):
        print("第{}个元素为:\n {}".format(i, e))


结果:

即,按照以下顺序进行返回的。

二、如何获取网络的某些层

可以借助children()方法来获取网络的某些层,比如只要经典网络的前几层,后面的层不要了。

比如,resnet18:

import torchvision.models as models

Resnet = models.resnet18(pretrained=False)

print(Resnet)

结果:

D:\Anaconda3\envs\pytorch_env\python.exe D:/pythonCodes/深度学习实验/行人重识别实验1:IDENet/aaa.py
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentu
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值