PyTorch nn容器

在PyTorch中,nn容器(nn containers)是用于组织和管理神经网络层的工具,它们使得构建和管理复杂的神经网络结构变得更加简单和高效。以下是几个常用的nn容器及其特点:

1. nn.Sequential

  • 定义nn.Sequential是一个有序的容器,它按照传入构造器的顺序,依次创建相应的网络层,并将它们封装成一个整体。
  • 特点
    • 有序性:内部的层按照传入的顺序进行排列,前一个层的输出自动作为后一个层的输入。
    • 简化代码:使用nn.Sequential可以简化代码,避免显式地编写前向传播逻辑。
    • 灵活性受限:由于内部层的顺序是固定的,因此在需要复杂前向传播逻辑时可能不够灵活。
  • 示例
  • import torch.nn as nn
    
    model = nn.Sequential(
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
        nn.LogSoftmax(dim=1)
    )
    
    print(model)
    
    输出结果:
    Sequential(
      (0): Linear(in_features=784, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=10, bias=True)
      (3): LogSoftmax(dim=1)
    )

    2. nn.ModuleList

  • 定义nn.ModuleList是一个持有多个子模块的列表,它继承自nn.Module,因此可以被视为一个特殊的列表容器。
  • 特点
    • 迭代性:可以像普通的Python列表一样迭代nn.ModuleList中的模块。
    • 自动注册:添加到nn.ModuleList中的模块会被自动注册到整个网络中,这意味着它们的参数会被包含在net.parameters()中。
    • 无前向传播:与nn.Sequential不同,nn.ModuleList本身不定义前向传播逻辑,需要在自定义的forward方法中实现。
  • 示例
  • import torch.nn as nn
    
    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(4)])
    
        def forward(self, x):
            for linear in self.linears:
                x = linear(x)
            return x
    
    model = MyModule()
    print(model)
    
    输出结果:
    MyModule(
      (linears): ModuleList(
        (0): Linear(in_features=10, out_features=10, bias=True)
        (1): Linear(in_features=10, out_features=10, bias=True)
        (2): Linear(in_features=10, out_features=10, bias=True)
        (3): Linear(in_features=10, out_features=10, bias=True)
      )
    )

    3. nn.ModuleDict

  • 定义nn.ModuleDict是一个持有多个子模块的字典,它同样继承自nn.Module
  • 特点
    • 命名索引:通过键值对的方式存储模块,可以通过键名来索引和访问模块。
    • 自动注册:与nn.ModuleList类似,添加到nn.ModuleDict中的模块也会被自动注册到整个网络中。
    • 灵活性:允许根据需求动态地添加、删除或修改模块。
  • 示例
    import torch.nn as nn
    
    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.layers = nn.ModuleDict({
                'conv1': nn.Conv2d(1, 20, 5),
                'pool': nn.MaxPool2d(2, 2),
                'conv2': nn.Conv2d(20, 50, 5)
            })
    
        def forward(self, x):
            x = self.layers['conv1'](x)  
            x = self.layers['pool'](x)
            x = self.layers['conv2'](x)
            return x
    
    model = MyModule()
    print(model)
    
    输出结果:
    MyModule(
      (layers): ModuleDict(
        (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
      )
    )

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值