task03 Pytorch模型定义

task03 Pytorch模型定义

2022/6/19 雾切凉宫

先引入必要的包

import os
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

1. 模型定义方式

定义方法特点使用方式
sequential:direct list直接按顺序定义模型直接传入变量
sequential:ordered dict类似字典构建,但是有序定义模型直接传入变量
ModuleList可以逐个定义模型层,但需要实例化先实例化再使用
ModuleDict重写前馈函数时需要用层名遍历先实例化再使用

P.S.实例化指的是需要继承nn.Module类,并重写构造函数(init)和前馈函数(forward)

1.1 sequential方法

1.1.1 sequential:direct list

直接按模型层序一一排列,定义好直接就可以前向传播运算。

优点是方便快捷

缺点是不方便定义复杂模型

net1 = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
print(net1)
Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

输入纬度:784->256->10
总共三层:两层全连接一层ReLU激活函数

1.1.2 sequential:ordered dict

ordered dict形式定义模型:
每一层的定义都在一个tuple中,tuple包含两个元素:第一个是该层的名字(自定义),第二个是层结点的定义

net2 = nn.Sequential(collections.OrderedDict([
    ("fcl", nn.Linear(784, 256)),
    ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(256, 10))
]))
print(net2)

以上定义了一个同1.1.1方式的模型,区别在于每一层有了自定义的名字。

P.S.虽然是字典形式,但是却有顺序,我觉得可以理解为带一个注释参数的list,模型的层序严格按照定义的层序

下面是模型的前馈运算,可见模型一定义就可以运算,不需要实例化

a = torch.rand(4,784)
out1 = net1(a)
out2 = net2(a)
print(out1.shape,out2.shape)
torch.Size([4, 10]) torch.Size([4, 10])

1.2 ModuleList方法

ModuleList方法与sequential方法最大的不同在于模型定义后需要实例化。自己重写构造函数(init)和前馈函数(forward)

net3 = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net3.append(nn.Linear(256, 10))
print(net3)
ModuleList(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

moduleList并不会实际定义一个网络,只是将不同模块存储。
所以通常moduleList的使用是这样的:

先进行类的继承与方法重写:

class Net3(nn.Module):
    def __init__(self):
        super().__init__()
        self.modulelist = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
        self.modulelist.append(nn.Linear(256, 10))
    
    def forward(self, x):
        for layer in self.modulelist:
            x = layer(x)
        return x
    

以上定义了一个Net3的网络继承nn.Module,包括了一个构造函数,完成了对ModuleList的定义
一个前向传播函数,并定义了参数如何在各层之间传播

net = Net3()
out = net(a)
print(out.shape)
torch.Size([4, 10])

以上实现了模型的计算,先实例化网络类,再传入数据,就能够完成前向计算

1.3 ModuleDict方法

在定义上与ModuleList方法大体相似,不同的是可以为每一层命名

net4 = nn.ModuleDict(
    {
        "linear":nn.Linear(784,256),
        "act":nn.ReLU(),
    })
net4["output"] = nn.Linear(256, 10)
print(net4)

使用上与ModuleList有一定区别

在forward上,因为字典没法直接遍历,要自己写每一层前向传播的过程,个人感觉区别就是可以给每一层设个名字然后用名字编写forward前馈函数。。。

class Net4(nn.Module):
    def __init__(self):
        super().__init__()
        self.ModuleDict= nn.ModuleDict(
        {
            "linear":nn.Linear(784,256),
            "act":nn.ReLU(),
        })
        self.ModuleDict["output"] = nn.Linear(256, 10)
    
    def forward(self, x):
        x = self.ModuleDict["linear"](x)
        x = self.ModuleDict["act"](x)
        x = self.ModuleDict["output"](x)
        return x
net4 = Net4()
out = net4(a)
print(out.shape)
torch.Size([4, 10])

1.4 总结

  • sequential 简单直白,模型定义完了,直接就可以运算,不用定义forward
  • ModuleList 不能够直接用来运算,需要实例化,自定义forward。可以批量定义一个复杂的多层网络
  • ModulaDict 与ModuleList类似,需要实例化、定义forward,不过可以给每层网络自定义名字,并用名字来调用各层的方法。

2. 利用模型块快速搭建网络

以U-Net为例,学习如何构建模型块,以及如何利用模型块快速搭建复杂模型。

2.1 U-Net简介

U-Net是分割 (Segmentation) 模型的杰作,在以医学影像为代表的诸多领域有着广泛的应用。U-Net模型结构如下图所示,通过残差连接结构解决了模型学习中的退化问题,使得神经网络的深度能够不断扩展。

在这里插入图片描述

2.2 U-Net模型块分析

结合上图,不难发现U-Net模型具有非常好的对称性。模型从上到下分为若干层,每层由左侧和右侧两个模型块组成,每侧的模型块与其上下模型块之间有连接;同时位于同一层左右两侧的模型块之间也有连接,称为“Skip-connection”。此外还有输入和输出处理等其他组成部分。由于模型的形状非常像英文字母的“U”,因此被命名为“U-Net”。

组成U-Net的模型块主要有如下几个部分:

1)每个子块内部的两次卷积(Double Convolution)

2)左侧模型块之间的下采样连接,即最大池化(Max pooling)

3)右侧模型块之间的上采样连接(Up sampling)

4)输出层的处理

除模型块外,还有模型块之间的横向连接,输入和U-Net底部的连接等计算,这些单独的操作可以通过forward函数来实现。

下面我们用PyTorch先实现上述的模型块,然后再利用定义好的模型块构建U-Net模型。

2.3 模型块代码实现

2.3.1 双层卷积模块

in_channels**–(卷积conv2d)–>mid_channels–(batchNorm2/ReLU)–>mid_channelss–(卷积conv2d)–>out_channels–(batchNorm2/ReLU)–>**out_channels

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


2.3.2 降采样模块

in_channels**–(MaxPools2d池化降维)–>**复用双层卷积

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
2.3.3 上采样模块

forward中实现了图中copy and crop的功能

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


2.3.4 输出模块

不会被复用
最后进行一次两次卷积将维数调整到需要的维数

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

2.4 模型块组装U-Net

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
unet = UNet(3, 1)

3. 模型修改

3.1 修改模型层

比如说我们想要一个5分类的任务,需要最后一层输出5个变量。
只需要取出原模型的最后一层重新实例化最后一层即可。

import copy
unet1 = copy.deepcopy(unet)
unet1.outc = OutConv(64,5)
b = torch.rand(1,3,224,224)
out_unet1 = unet1(b)
print(out_unet1.shape)
torch.Size([1, 5, 224, 224])

可以看见batchsize与图片大小均没有变化,channel数发生了变化,变成了5

3.2 添加额外输入

class UNet_more_input(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet_more_input, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x, add_variable):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = x + add_variable #修改点
        logits = self.outc(x)
        return logits
unet2 = UNet_more_input(3,1)
c = torch.rand(1, 224, 224)
out = unet2(b,c)
print(out.shape)
torch.Size([1, 1, 224, 224])

总结:在forward处添加变量,并加入各层运算中

3.3 添加额外输出

class UNet_more_output(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet_more_output, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits, x5   #修改点
unet3 = UNet_more_output(3,1)
out ,more_out= unet3(b)
print(out.shape,more_out.shape)
torch.Size([1, 1, 224, 224]) torch.Size([1, 512, 14, 14])

总结:在return处添加更多输出

4. 模型保存与读取

unet(b)
unet.state_dict()
#保存&读取整个模型
torch.save(unet, "./unet.pth")
torch.load("./unet.pth")
#保存&读取模型权重
torch.save(unet.state_dict(),"./unet_weight.pth")
loaded_weight = torch.load("./unet_weight.pth")
unet.load_state_dict(loaded_weight)
<All keys matched successfully>
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值