unet模型学习笔记
前言
提示:这里可以添加本文要记录的大概内容:
例如:随着人工智能的不断发展,机器学习这门技术也越来越重要,很多人都开启了学习机器学习,本文就介绍了机器学习的基础内容。
提示:以下是本篇文章正文内容,下面案例可供参考
一、模型图
二、依据模型图搭建模型
1.图中可以看出每层有两次卷积核,将图像输入后,进行两次卷积,每次卷积跟一个批归一化和激活函数激活
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),#padding为1,卷积后图片尺寸大小不变
nn.BatchNorm2d(out_channels),#批归一化
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)##经过两次卷积,每次卷积后跟着批归一化和relu激活函数
def forward(self, x):
return self.double_conv(x)
下采样:
下采样用最大池化,同时改变图像通道数
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
## Down:经过一次最大池化(图像缩小一倍)后再经过两次卷积
上采样:
上采样有反卷积操作、双线性差值、最邻近插值等等,pytorch有封装好的API
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
双线性插值很好理解,pytorch中反卷积操作可以参考:反卷积和上采样
最后是结果输出,输出分类结果的特征图
class OutConv(nn.Module):
def init(self, in_channels, out_channels):
super(OutConv, self).init()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
最后再申明一个类,把上述模块按照模型图搭积木搭起来就可以了。