【Python】nn.Container基类Module, Sequential, ModuleList, ModuleDict函数详解和示例

前言

在 PyTorch 中,nn.Container 是一个基类,它并不直接提供实例化对象的功能,而是作为其他容器类(如 nn.Module, nn.Sequential, nn.ModuleList, nn.ModuleDict)的基础。这些容器类在构建和组织神经网络时起着至关重要的作用。本文将详细讲解这些容器的原理、原型、运行实例,并最后进行总结。

函数

函数原理

nn.Module

nn.Module 是所有神经网络模块的基类。当你定义一个神经网络模型时,你需要继承这个类。nn.Module 包含了一些基本属性和方法,比如 parameters() 和 buffers(),用于管理模型中的参数和缓冲区(如批量归一化层的均值和方差)。此外,它还定义了 forward() 方法,这是实现前向传播的关键。

nn.Sequential

nn.Sequential 是一个有序容器,可以包含多个子模块(layer),并自动实现这些子模块的前向传播。你只需要按照顺序添加模块,nn.Sequential 会按照添加的顺序自动调用它们的 forward() 方法。这使得模型构建变得非常简单和直观。

nn.ModuleList

nn.ModuleList 是一个持有子模块的列表,它本身也是一个 nn.Module。与普通的 Python 列表不同,nn.ModuleList 中的模块会被正确地注册,并能在整个模块树中可见。这使得在模型中包含多个相同或不同的模块时更加方便。

nn.ModuleDict

nn.ModuleDict 是一个持有子模块的字典,类似于 nn.ModuleList,但提供了通过键来访问子模块的功能。这对于需要参数化模型某些部分的场景非常有用,比如动态选择激活函数。

原型

nn.Module

class torch.nn.Module(object):
    def __init__(self):
        pass

    def forward(self, *input):
        raise NotImplementedError

nn.Sequential

class torch.nn.Sequential(Module):
    def __init__(self, *args):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        for module in self._modules.values():
            input = module(input)
        return input

nn.ModuleList

class torch.nn.ModuleList(Module):
    def __init__(self, modules=None):
        super(ModuleList, self).__init__()
        if modules is not None:
            self.extend(modules)

    def extend(self, modules):
        for module in modules:
            self.append(module)

    def append(self, module):
        self.add_module(str(len(self)), module)

nn.ModuleDict

class torch.nn.ModuleDict(Module):
    def __init__(self, modules=None):
        super(ModuleDict, self).__init__()
        if modules is not None:
            self.update(modules)

    def update(self, modules):
        for key, module in modules.items():
            self.add_module(key, module)

    def __getitem__(self, key):
        return self._modules[key]

运行实例

使用 nn.Sequential

import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU()
)

print(model)

在这里插入图片描述

使用 nn.ModuleList

import torch
import torch.nn as nn
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 创建一个包含5个线性层的列表
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])

    def forward(self, x):
        # 遍历每个线性层并应用
        for linear in self.linears:
            x = linear(x)
            # 注意:在实际应用中,可能需要加入激活函数或改变输入数据的维度
            # 这里为了简单,我们直接传递x
        return x

# 实例化模型
model = MyModel()
# 假设输入数据
input_data = torch.randn(1, 10)
# 前向传播
output = model(input_data)
print(output.shape)  # 输出维度取决于最后的线性层,这里是[1, 10]

在这里插入图片描述

使用 nn.ModuleDict


import torch
import torch.nn as nn
class MyModelWithDict(nn.Module):
    def __init__(self):
        super(MyModelWithDict, self).__init__()
        # 创建一个包含不同类型层的字典
        self.layers = nn.ModuleDict({
            'conv1': nn.Conv2d(1, 20, 5),
            'relu': nn.ReLU(),
            'fc': nn.Linear(11520, 10)  # 假设输入特征需根据实际情况调整
        })

    def forward(self, x):
        x = self.layers['conv1'](x)
        x = self.layers['relu'](x)
        # 假设我们进行了适当的reshape或flatten操作
        x = x.view(x.size(0), -1)
        x = self.layers['fc'](x)
        return x

# 实例化模型
model_dict = MyModelWithDict()
# 假设输入数据为图像数据
input_image = torch.randn(1, 1, 28, 28)  # 假设是28x28的灰度图像
# 前向传播
output_dict = model_dict(input_image)
print(output_dict.shape)  # 输出维度取决于最后的线性层,这里是[1, 10]

在这里插入图片描述

小结

nn.Container 虽然不是一个直接用于实例化的类,但它作为 PyTorch 中所有神经网络模块的基类,为模型构建提供了基础框架。nn.Sequential, nn.ModuleList, 和 nn.ModuleDict 是基于 nn.Module 的容器类,它们各自在模型构建中扮演着不同的角色:

nn.Sequential 提供了一种快速串联多个层的方式,使得模型结构更加清晰。
nn.ModuleList 类似于 Python 的列表,但专门用于存储 nn.Module 对象,并支持自动注册到网络中。
nn.ModuleDict 类似于 Python 的字典,除了存储 nn.Module 对象外,还支持通过键来访问模块,这在需要动态选择或操作模块时非常有用。
通过这些容器类,PyTorch 提供了一种灵活且强大的方式来构建和组织复杂的神经网络模型。


笔者主要从事计算机视觉方面研究和开发,包括实例分割、目标检测、追踪等方向,进行算法优化和嵌入式平台开发部署。欢迎大家沟通交流、互帮互助、共同进步。
`nn.Module` 是 PyTorch 中所有模型的基类。它是所有神经网络模块的父类,提供了许多有用的方法,包括参数管理、子模块管理和前向传播函数等。 `nn.Module` 类的主要作用是提供一个统一的接口来管理神经网络模块的参数和子模块。每个 `nn.Module` 都可以包含一些可学习的参数(例如权重和偏置),并且可以包含其他的子模块(例如卷积层、池化层、全连接层等)。这些参数和子模块都可以通过 `nn.Module` 类的方法来管理和访问。 下面是一些 `nn.Module` 类的常用方法: - `__init__`: 构造函数,用于定义模型的结构、初始化参数等。 - `forward`: 前向传播函数,用于定义模型的计算流程,将输入转换为输出。 - `parameters`: 返回一个可迭代的参数列表,包含所有的可学习参数。 - `named_parameters`: 返回一个可迭代的参数列表,包含所有的可学习参数及其名称。 - `children`: 返回一个可迭代的模块列表,包含所有的子模块。 - `named_children`: 返回一个可迭代的模块列表,包含所有的子模块及其名称。 - `to`: 将模型移动到指定的设备上,例如 CPU、GPU 等。 - `train`: 将模型设置为训练模式。 - `eval`: 将模型设置为评估模式。 使用 `nn.Module` 类可以方便地管理神经网络模型的结构和参数,并且可以支持模型的保存和加载等操作。通过继承 `nn.Module` 类,我们可以快速构建自己的神经网络模型,从而解决各种复杂的机器学习问题。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

木彳

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值