task03 Pytorch模型定义

task03 Pytorch模型定义

2022/6/19 雾切凉宫

先引入必要的包

import os
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

1. 模型定义方式

定义方法 特点 使用方式
sequential:direct list 直接按顺序定义模型 直接传入变量
sequential:ordered dict 类似字典构建,但是有序定义模型 直接传入变量
ModuleList 可以逐个定义模型层,但需要实例化 先实例化再使用
ModuleDict 重写前馈函数时需要用层名遍历 先实例化再使用

P.S.实例化指的是需要继承nn.Module类,并重写构造函数(init)和前馈函数(forward)

1.1 sequential方法

1.1.1 sequential:direct list

直接按模型层序一一排列,定义好直接就可以前向传播运算。

优点是方便快捷

缺点是不方便定义复杂模型

net1 = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
print(net1)
Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

输入纬度:784->256->10
总共三层:两层全连接一层ReLU激活函数

1.1.2 sequential:ordered dict

ordered dict形式定义模型:
每一层的定义都在一个tuple中,tuple包含两个元素:第一个是该层的名字(自定义),第二个是层结点的定义

net2 = nn.Sequential(collections.OrderedDict([
    ("fcl", nn.Linear(784, 256)),
    ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(256, 10))
]))
print(net2)

以上定义了一个同1.1.1方式的模型,区别在于每一层有了自定义的名字。

P.S.虽然是字典形式,但是却有顺序,我觉得可以理解为带一个注释参数的list,模型的层序严格按照定义的层序

下面是模型的前馈运算,可见模型一定义就可以运算,不需要实例化

a = torch.rand(4,784)
out1 = net1(a)
out2 = net2(a)
print(out1.shape,out2.shape)
torch.Size([4, 10]) torch.Size([4, 10])

1.2 ModuleList方法

ModuleList方法与sequential方法最大的不同在于模型定义后需要实例化。自己重写构造函数(init)和前馈函数(forward)

net3 = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net3.append(nn.Linear(256, 10))
print(net3)
ModuleList(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

moduleList并不会实际定义一个网络,只是将不同模块存储。
所以通常moduleList的使用是这样的:

先进行类的继承与方法重写:

class Net3(nn.Module):
    def __init__(self):
        super().__init__()
        self.modulelist = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
        self.modulelist.append(nn.Linear(256, 10))
    
    def forward(self, x):
        for layer in self.modulelist:
            x = layer(x)
        return x
    

以上定义了一个Net3的网络继承nn.Module,包括了一个构造函数,完成了对ModuleList的定义
一个前向传播函数,并定义了参数如何在各层之间传播

net = Net3()
out = net(a)
print(out.shape)
torch.Size([4, 10])

以上实现了模型的计算,先实例化网络类,再传入数据,就能够完成前向计算

1.3 ModuleDict方法

在定义上与ModuleList方法大体相似,不同的是可以为每一层命名

net4 = nn.ModuleDict(
    {
   
   
        "linear":nn.Linear(784,256),
        "act":nn.ReLU(),
    })
net4[
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值