pytorch系列教程(三)-自定义网络模型

前言

接下来将要实战自定义模型,本篇博客参考了:pytorch教程之nn.Module类详解——使用Module类来自定义模型
  

步骤

在自定义网络模型时,需要继承nn.Module类,并且重新实现__init__和forward这两个方法

一、简单用法

1、把可学习参数的层和不具有学习参数的层都放到构造函数中
先来看一个简单的例子

import torch
 
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # 第一句话,调用父类的构造函数
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1=torch.nn.ReLU()
        self.max_pooling1=torch.nn.MaxPool2d(2,1)
 
        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu2=torch.nn.ReLU()
        self.max_pooling2=torch.nn.MaxPool2d(2,1)
 
        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pooling1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pooling2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
 
model = MyNet()
print(model)
'''运行结果为:
MyNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (max_pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (max_pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  (dense1): Linear(in_features=288, out_features=128, bias=True)
  (dense2): Linear(in_features=128, out_features=10, bias=True)
)
'''

注意:上面的是将所有的层都放在了构造函数__init__里面,但是只是定义了一系列的层,各个层之间到底是什么连接关系并没有,而是在forward里面实现所有层的连接关系
  

二、高级用法

通过Sequential来包装层,把一些重复的层包装到Sequential中
这次使用UNet3D网络的例子,可以等到图中橙色的箭头代表conv(+BN)+Relu
在这里插入图片描述

可以使用Sequential包装起来,每次调用自己写的包装可以减少代码量

##自己定义的类 要实现两部分功能
##1、init函数中要说明输入和输出in_ch out_ch
##2、在forward函数中把各个部分连接起来
##in_ch out_ch
class DoubleConv2(nn.Module):
    def __init__(self,in_ch, out_ch):
        super(DoubleConv2, self).__init__()
        ###在init中定义各个部分
        self.conv = nn.Sequential(conv3x3(in_ch, out_ch),
                                    nn.BatchNorm3d(out_ch),
                                    nn.ReLU(inplace=True),
                                    conv3x3(out_ch, out_ch),
                                    nn.BatchNorm3d(out_ch),
                                    nn.ReLU(inplace=True)) 

    def forward(self, x):
        ##forward函数把init函数中的各个部分连接起来
        x = self.conv(x)
        return x

完整的UNet3D代码

import torch as torch
import torchvision as tv
import torch.nn as nn


import numpy as np 


def conv3x3(in_ch, out_ch, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class DoubleConv1(nn.Module):
    def __init__(self,in_ch, out_ch):
        super(DoubleConv1, self).__init__()

        self.conv = nn.Sequential(conv3x3(in_ch, in_ch),
                                    nn.BatchNorm3d(in_ch),
                                    nn.ReLU(inplace=True),
                                    conv3x3(in_ch, out_ch),
                                    nn.BatchNorm3d(out_ch),
                                    nn.ReLU(inplace=True)) 

    def forward(self, x):
        
        x = self.conv(x)
        return x

class DoubleConv2(nn.Module):
    def __init__(self,in_ch, out_ch):
        super(DoubleConv2, self).__init__()
        ###在init中定义各个部分
        self.conv = nn.Sequential(conv3x3(in_ch, out_ch),
                                    nn.BatchNorm3d(out_ch),
                                    nn.ReLU(inplace=True),
                                    conv3x3(out_ch, out_ch),
                                    nn.BatchNorm3d(out_ch),
                                    nn.ReLU(inplace=True)) 

    def forward(self, x):
        ##forward函数把init函数中的各个部分连接起来
        x = self.conv(x)
        return x


class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool3d(2),
            ###DoubleConv1是自己定义的类,自己定义的类要放在init
            DoubleConv1(in_ch, out_ch)
        )

    
    def forward(self, x):
        x = self.mpconv(x)
        return x

class Up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Up, self).__init__()
        ###在init中定义各个部分
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        ###DoubleConv2是自己定义的类,自己定义的类要放在init
        self.conv = DoubleConv2(in_ch, out_ch)
   
    def forward(self, x1,x2):
        ##forward函数把init函数中的各个部分连接起来
        ##注意这个forword函数是有两个输入
        x1=self.up(x1)
        ##将两个通道融合在一起
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)

class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = nn.Sequential(conv3x3(in_ch, out_ch//2),
                                    nn.BatchNorm3d(out_ch//2),
                                    nn.ReLU(inplace=True),
                                    conv3x3(out_ch//2, out_ch),
                                    nn.BatchNorm3d(out_ch),
                                    nn.ReLU(inplace=True)) 

    def forward(self, x):
        
        x = self.conv(x)
        return x

class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, 1)

    def forward(self, x):
        
        x = self.conv(x)
        return x


class UNet3D(nn.Module):
    def __init__(self,in_channels):
        super(UNet3D, self).__init__()

        self.inc=inconv(in_channels,64)

        self.down1=Down(64,128)
        self.down2=Down(128,256)
        self.down3=Down(256,512)

        self.up1 = Up(512, 256)
        self.up2 = Up(256, 128)
        self.up3 = Up(128, 64)

        self.outc = outconv(64, 2)

    def forward(self, x):

        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        x = self.outc(x)

        return x


model=UNet3D(3)
print(model)
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值