Pytorch学习笔记之模型的创建

Pytorch中所有模型都是基于Module这个类,也就是说无论是自定义的模型,还是Pytorch中已有的模型,都是这个类的子类,并重写了forward方法。Pytorch中创建模型有几种方法。

继承Module

这是最直接的方法,自己写一个模型继承Module,并重写forward方法。

from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F

class LinearMoudule(Module):
    def __init__(self):
        super(LinearMoudule,self).__init__()
        self.linear_1 = nn.Linear(10,30)
        self.linear_2 = nn.Linear(30,5)
    
    def forward(self,x):
        x = self.linear_1(x)
        x = F.tanh(x)
        x = self.linear_2(x)
        x = F.sigmoid(x)
        return x

使用Sequential

使用Sequential是一种快速构建模型的方法,只需将需要添加的模型放入其构造函数即可。

from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F

module = nn.Sequential(nn.Linear(10,30),
                      nn.Tanh(),
                      nn.Linear(30,5),
                      nn.Sigmoid())

另外还可以用OrderedDict来对每一层模型进行命名。

from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

module = nn.Sequential(OrderedDict(
							{'linear_1':nn.Linear(10,30),
                            'tanh':nn.Tanh(),
                            'linear_2': nn.Linear(30,5),
                            'sigmod': nn.Sigmoid()}
                            ))

module = nn.Sequential(OrderedDict(
							[('linear_1', nn.Linear(10,30)),
                            ('tanh',nn.Tanh()),
                            ('linear_2', nn.Linear(30,5)),
                            ('sigmod', nn.Sigmoid())
                            ))
                            #两种写法都可以

ModuleList和ModuleDict

这两个类顾名思义,是分别通过List和Dict两种容器将模块进行包装来创建新模型的。并且这两个类可以通过迭代来访问。

from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F、

class Module_List(nn.Module):
    def __init__(self):
        super(Module_List, self).__init__()
        self.modules = nn.ModuleList([nn.Linear(10, 30),nn.Tanh(),nn.Linear(30,5),nn.Sigmoid()])

    def forward(self, x):
        for layer in self.modules:
            x = layer(x)
        return x


class Module_Dict(nn.Module):
    def __init__(self):
        super(Module_Dict, self).__init__()
        self.modules = nn.ModuleDict({'linear_1' : nn.Linear(10,30),
                            'tanh':nn.Tanh(),
                            'linear_2': nn.Linear(30,5),
                            'sigmod': nn.Sigmoid()})

    def forward(self, x):
        for layer in self.modules:
            x = layer(x)
        return x


混合使用

前面几种创建模型的方法可以混合使用,来创建更为复杂的模型。

class Bottle(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size):
        super(Bottle, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self,x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.pool(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        self.bottle_1 = Bottle(3,6,5)
        self.bottle_2 = Bottle(6,16,5)

        self.fc = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                nn.ReLU(),
                                nn.Linear(120, 84),
                                nn.ReLU())
        self.last_fc = nn.Linear(84, 10)

    def forward(self,x):
        x = self.bottle_1(x)
        x = self.bottle_2(x)
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc(x)
        x = self.last_fc(x)
        return x
已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页