前言
接下来将要实战自定义模型,本篇博客参考了:pytorch教程之nn.Module类详解——使用Module类来自定义模型
步骤
在自定义网络模型时,需要继承nn.Module类,并且重新实现__init__和forward这两个方法
一、简单用法
1、把可学习参数的层和不具有学习参数的层都放到构造函数中
先来看一个简单的例子
import torch
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__() # 第一句话,调用父类的构造函数
self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
self.relu1=torch.nn.ReLU()
self.max_pooling1=torch.nn.MaxPool2d(2,1)
self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
self.relu2=torch.nn.ReLU()
self.max_pooling2=torch.nn.MaxPool2d(2,1)
self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
self.dense2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.max_pooling1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.max_pooling2(x)
x = self.dense1(x)
x = self.dense2(x)
return x
model = MyNet()
print(model)
'''运行结果为:
MyNet(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1): ReLU()
(max_pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2): ReLU()
(max_pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
(dense1): Linear(in_features=288, out_features=128, bias=True)
(dense2): Linear(in_features=128, out_features=10, bias=True)
)
'''
注意:上面的是将所有的层都放在了构造函数__init__里面,但是只是定义了一系列的层,各个层之间到底是什么连接关系并没有,而是在forward里面实现所有层的连接关系
二、高级用法
通过Sequential来包装层,把一些重复的层包装到Sequential中
这次使用UNet3D网络的例子,可以等到图中橙色的箭头代表conv(+BN)+Relu
可以使用Sequential包装起来,每次调用自己写的包装可以减少代码量
##自己定义的类 要实现两部分功能
##1、init函数中要说明输入和输出in_ch out_ch
##2、在forward函数中把各个部分连接起来
##in_ch out_ch
class DoubleConv2(nn.Module):
def __init__(self,in_ch, out_ch):
super(DoubleConv2, self).__init__()
###在init中定义各个部分
self.conv = nn.Sequential(conv3x3(in_ch, out_ch),
nn.BatchNorm3d(out_ch),
nn.ReLU(inplace=True),
conv3x3(out_ch, out_ch),
nn.BatchNorm3d(out_ch),
nn.ReLU(inplace=True))
def forward(self, x):
##forward函数把init函数中的各个部分连接起来
x = self.conv(x)
return x
完整的UNet3D代码
import torch as torch
import torchvision as tv
import torch.nn as nn
import numpy as np
def conv3x3(in_ch, out_ch, stride=1):
"""3x3 convolution with padding"""
return nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride,
padding=1, bias=False)
class DoubleConv1(nn.Module):
def __init__(self,in_ch, out_ch):
super(DoubleConv1, self).__init__()
self.conv = nn.Sequential(conv3x3(in_ch, in_ch),
nn.BatchNorm3d(in_ch),
nn.ReLU(inplace=True),
conv3x3(in_ch, out_ch),
nn.BatchNorm3d(out_ch),
nn.ReLU(inplace=True))
def forward(self, x):
x = self.conv(x)
return x
class DoubleConv2(nn.Module):
def __init__(self,in_ch, out_ch):
super(DoubleConv2, self).__init__()
###在init中定义各个部分
self.conv = nn.Sequential(conv3x3(in_ch, out_ch),
nn.BatchNorm3d(out_ch),
nn.ReLU(inplace=True),
conv3x3(out_ch, out_ch),
nn.BatchNorm3d(out_ch),
nn.ReLU(inplace=True))
def forward(self, x):
##forward函数把init函数中的各个部分连接起来
x = self.conv(x)
return x
class Down(nn.Module):
def __init__(self, in_ch, out_ch):
super(Down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool3d(2),
###DoubleConv1是自己定义的类,自己定义的类要放在init
DoubleConv1(in_ch, out_ch)
)
def forward(self, x):
x = self.mpconv(x)
return x
class Up(nn.Module):
def __init__(self, in_ch, out_ch):
super(Up, self).__init__()
###在init中定义各个部分
self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
###DoubleConv2是自己定义的类,自己定义的类要放在init
self.conv = DoubleConv2(in_ch, out_ch)
def forward(self, x1,x2):
##forward函数把init函数中的各个部分连接起来
##注意这个forword函数是有两个输入
x1=self.up(x1)
##将两个通道融合在一起
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = nn.Sequential(conv3x3(in_ch, out_ch//2),
nn.BatchNorm3d(out_ch//2),
nn.ReLU(inplace=True),
conv3x3(out_ch//2, out_ch),
nn.BatchNorm3d(out_ch),
nn.ReLU(inplace=True))
def forward(self, x):
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv3d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x
class UNet3D(nn.Module):
def __init__(self,in_channels):
super(UNet3D, self).__init__()
self.inc=inconv(in_channels,64)
self.down1=Down(64,128)
self.down2=Down(128,256)
self.down3=Down(256,512)
self.up1 = Up(512, 256)
self.up2 = Up(256, 128)
self.up3 = Up(128, 64)
self.outc = outconv(64, 2)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x = self.up1(x4, x3)
x = self.up2(x, x2)
x = self.up3(x, x1)
x = self.outc(x)
return x
model=UNet3D(3)
print(model)