构建一个简单的U-Net网络
注:本文仅搭建网络结构,并未实现训练及预测
注:以下网络仅供学习交流,实际在投入训练过程中loss不能正确收敛,请不要直接使用,如有大佬知道什么情况,烦请指点一二
文章目录
一、导包
import torch
from torch import nn
二、下采样
'''
下采样,即U-Net的左半部分
'''
class block_down(nn.Module):
def __init__(self,inp_channel,out_channel):
"""
:param inp_channel: 输入通道数
:param out_channel: 输出通道数
"""
# 调用父类方法,传block_down是父类名字
super(block_down,self).__init__()
# 所有的卷积层添加padding=1会填充1个像素点,实现输入和输出的维度相同,也可以不选
# 注:本文中所有关于维度的数据均未添加padding,如果需要添加padding,需要自己逐步调试计算维度
# self.conv1 = nn.Conv2d(inp_channel, out_channel, 3, 1,padding=1)
# self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1,padding=1)
# 在这里面定义卷积层、标准化层和激活层,方便在forward方法中调用
self.conv1=nn.Conv2d(inp_channel,out_channel,3,1)
self.conv2=nn.Conv2d(out_channel,out_channel,3,1)
# BatchNorm2d层是批量标准化层,可以加快收敛速度
self.bn=nn.BatchNorm2d(out_channel)
# 激活函数
self.relu=nn.ReLU6(inplace=True)
'''
这个里面就是卷积两次,就是U-Net网络左半部分的每一行
注意每次卷积之后都要标准化和激活
'''
def forward(self,x):
x=self.conv1(x)
x=self.bn(x)
x=self.relu(x)
x=self.conv2(x)
x=self.bn(x)
x=self.relu(x)
return x
三、上采样
'''
上采样模块,结构的右半部分
'''
class block_up(nn.Module):
def __init__(self,inp_channel,out_channel,y):
super(block_up,self).__init__()
# 使用卷积转置实现上采样,增加输入特征图的尺寸
self.up=nn.ConvTranspose2d(inp_channel,out_channel,2,2)
# 所有的卷积层添加padding=1会填充1个像素点,实现输入和输出的维度相同,也可以不选
self.conv1=nn.Conv2d(inp_channel,out_channel,3,1)
self.conv2=nn.Conv2d(out_channel,out_channel,3,1)
self.bn=nn.BatchNorm2d(out_channel)
self.relu=nn.ReLU6(inplace=True)
self.y=y
def forward(self,x):
x = self.up(x) # 上采样
'''
需要对y进行处理,因为传入的y和上采样得到的维度不同,无法直接进行拼接,需要进行裁剪
首先以第一层上采样为例
block6=block_up(1024,512,x4_use)
x6=block6(x5)
传入的x4_use的size是 ([1, 512, 52, 72]),在这里是y
要与其拼接的x5的size是 ([1, 1024,22, 32]),在这里是x
x经过一层上采样后的size是 ([1, 512, 44, 64])
显然第三个维度不一致,无法进行拼接,y的第三个维度较大,需要进行裁剪
在这里采用头尾裁切,取中间的方法
所以有了关于delta的计算
计算思路为52-44=8,两个数据差8,头尾各去掉4,即可实现拼接
第四个维度同理
即self.y=self.y[:,:,delta:self.y.shape[2]-delta,delta:self.y.shape[3]-delta]
'''
if self.y.shape[2]!=x.shape[2]:
delta1=self.y.shape[2]-x.shape[2]
delta=delta1//2
'''
以第二层
block7=block_up(512,256,x3_use)
x7=block7(x6)
这层为例
进入该方法时x3_use的Size是 ([1, 256, 113, 153]),在这个方法里是y
要与其拼接的x6的Size是 ([1, 512, 40, 60]),在这个方法里是x
经过一层上采样后,x的Size变为([1, 256, 80, 120]),显然也无法进行拼接
经过计算delta1=33,delta=16
这里存在一个问题,delta是奇数,经过整除运算余1,两边同时去掉delta还差1,维度同样不匹配
所以需要在一边减去这个多余1,所以有了如下关于delta1奇偶的判断
如果为奇数,则上限多-1
'''
if delta1%2==0:
self.y=self.y[:,:,delta:self.y.shape[2]-delta,delta:self.y.shape[3]-delta]
else:
self.y=self.y[:,:,delta:self.y.shape[2]-delta-1,delta:self.y.shape[3]-delta-1]
# 将x和y在第二个维度上进行拼接,就是第二个维度的加法操作
x=torch.cat([x,self.y],dim=1)
# 正常的每一行卷两下
x=self.conv1(x)
x=self.bn(x)
x=self.relu(x)
x=self.conv2(x)
x=self.bn(x)
x=self.relu(x)
return x
四、构建网络
class U_net(nn.Module):
def __init__(self,out_channel):
super(U_net,self).__init__()
self.out=nn.Conv2d(64,out_channel,1)
# 使用最大池化实现下采样
self.maxpool=nn.MaxPool2d(2)
'''
U-Net网络的架构
'''
def forward(self,x):
# 下采样层
# 一个block就是左半或者右半的一横行,就是卷积两下
# use的作用在于crop and copy,裁剪和复制
# use必须留在这里用于后续和上采样的对应层进行拼接,所以没有写进方法
# 最大池化实现下采样
block1=block_down(3,64)
x1_use=block1(x) # torch.Size([1, 64, 476, 636])
x1=self.maxpool(x1_use) #torch.Size([1, 64, 238, 318])
block2=block_down(64,128) # torch.Size([1, 128, 119, 158])
x2_use=block2(x1) # torch.Size([1, 128, 234, 314])
x2=self.maxpool(x2_use) #torch.Size([1, 256, 117, 157])
block3=block_down(128,256)
x3_use=block3(x2) # torch.Size([1, 256, 113, 153])
x3=self.maxpool(x3_use) # torch.Size([1, 256, 56, 76])
block4=block_down(256,512)
x4_use=block4(x3) # torch.Size([1, 512, 52, 72])
x4=self.maxpool(x4_use) # torch.Size([1, 512, 26, 36])
# 这层不需要池化了,到底层了
block5=block_down(512,1024)
x5=block5(x4) # torch.Size([1, 1024, 22, 32])
# 上采样层
# 上采样时我们把转置卷积写在方法里面了,所以没有跨行操作
block6=block_up(1024,512,x4_use)
x6=block6(x5)
block7=block_up(512,256,x3_use)
x7=block7(x6)
block8=block_up(256,128,x2_use)
x8=block8(x7)
block9=block_up(128,64,x1_use)
x9=block9(x8)
x10=self.out(x9)
out=nn.Softmax2d()(x10)
return out
五、测试
'''
main只是在测试输入输出的size是否一致,并没有执行训练,本文也并没有写训练相关代码
'''
if __name__=="__main__":
# 创建一个测试输入,是一个随机图像
test_input=torch.rand(1,3,480,640)
# 输出测试输入的size
print("test_input:",test_input.size())
# 创建模型,输出通道为3
model=U_net(out_channel=3)
# 进行前向传播
output=model(test_input)
# 输出测试size
print("output size:",output.size())