使用torch.nn中的Container简化模型代码

1.多种容器

在这里插入图片描述
此处重点关注nn.Sequential()nn.ModuleList()

官方文档torch.nn中的container

2.nn.Sequential()

A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.

To make it easier to understand, here is a small example

# Example of using Sequential
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# Example of using Sequential with OrderedDict
from collections import OrderedDict


model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

使用nn.Sequential()构建网络模型如下:
包括卷积都可以直接放进去,更简洁的是forward(self, x)的写法
一句话就写完了…

class Model(nn.Module):
    def __init__(self):
        super(Modle, self).__init__()
        self.network=nn.Sequential(
            nn.Linear(1,10),nn.ReLU(),
            nn.Linear(10,100),nn.ReLU(),
            nn.Linear(100,10),nn.ReLU(),
            nn.Linear(10,1)
        )
    def forward(self, x):
        return self.network(x)

可以使用 self.network[0] 获取第一个 Linear子模型,由于每一个子模型没有设置唯一的名称,所以只能使用数字索引来获取。

添加子模型的方法:

也可以在创建之后加入新的

self.network.add_module("linear1",nn.Linear(100,100))

获取该新层的方法

linear=self.network.linear1

要是放在以前,写的老冗长

class Model(nn.Module):
	def __init__(self):
        super(Model,self).__init__()
        
        self.linear1=nn.Linear(1,10)
        self.activation1=nn.ReLU()
        self.linear2=nn.Linear(10,100)
        self.activation2=nn.ReLU()
        self.linear3=nn.Linear(100,10)
        self.activation3=nn.ReLU()
        self.linear4=nn.Linear(10,1)
        
    def forward(self,x):
        out=self.linear1(x)
        out=self.activation1(out)
        out=self.linear2(out)
        out=self.activation2(out)
        out=self.linear3(out)
        out=self.activation3(out)
        out=self.linear4(out)
        return out

3.nn.ModuleList

Holds submodules in a list.

ModuleList can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all Module methods.

这里是一个list,因而写forward()和前面不同,基本都是把索引和list中的值取出来

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers=nn.ModuleList([
            nn.Linear(1,10), nn.ReLU(),
            nn.Linear(10,1)])
    def forward(self,x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

查看网络结构

model = Model()
print(model)

输出如下:

Model(
  (layers): ModuleList(
    (0): Linear(in_features=1, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=1, bias=True)
  )
)

添加单个子模块

append()
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers=nn.ModuleList([
            nn.Linear(1,10), nn.ReLU(),
            nn.Linear(10,1)])
        self.layers.append(nn.Linear(1, 5))
    def forward(self,x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

添加多个子模块(另一个list)

extend(),必须也为一个list

self.layers.extend([nn.Linear(size1, size2) for i in range(1, num_layers)])

参考来源

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值