U_net 网络(pytorch学习)
U_net是一个经典的图像分割网络,可以完成许多功能,在学习U_net网络后结合B站的视频尝试编写U_net代码,锻炼编程能力
一 U_net网络结构
U_net网络的网络结构如下图所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-alsUIo1Z-1662831702017)(“C:\Users\23577\Desktop[2H[T27D8~{THTHD]]VPJD6.png”)]
网络模型代码
步骤:
-
先定义下采样网络既两个卷积
import torch import torchvision from torch import nn class Double_conv(nn.Module): def __init__(self,in_channel, out_channel): super(Double_conv, self).__init__() """ 在这里使用卷积,保持图像尺寸不变,以便更好计算 """ self.layer = nn.Sequential( nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1,bias=False), nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1,bias=False), nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), ) def forward(self,x): return self.layer(x) urn self.layer(x)
-
定义U_net整个网络模型
class U_NET(nn.Module): def __init__(self, in_channel,out_channel,features=[64,128,256,512]): super(U_NET, self).__init__() self.DOWN = nn.ModuleList() self.UP = nn.ModuleList() self.maxpool = nn.MaxPool2d(2) for feature in features: self.DOWN.append(Double_conv(in_channel,feature)) in_channel = feature for feature in reversed(features): self.UP.append(nn.ConvTranspose2d(feature*2,feature,kernel_size=2,stride=2)) self.UP.append(Double_conv(feature*2,feature)) self.botten = Double_conv(features[-1],features[-1]*2) self.final_conv = nn.Conv2d(features[0],out_channel,kernel_size=1,padding=0) def forward(self,x): skip_connect=[] for idx in self.DOWN: x = idx(x) skip_connect.append(x) x = self.maxpool(x) x = self.botten(x) skip_connect = skip_connect[::-1] for idx in range(0,len(self.UP),2): x = self.UP[idx](x) """ 为了适用,任意尺寸的图片特征融合时为保证尺寸相同Resize一下 """ if x.shape != skip_connect[idx // 2].shape: x = torchvision.transforms.Resize( skip_connect[idx // 2].shape[2:])(x) x = torch.cat((x,skip_connect[idx//2]),dim=1) x = self.UP[idx+1](x) return self.final_conv(x)
-
测试结果
if __name__ == "__main__": x = torch.randn(1,1,161,161) model = U_NET(in_channel=1,out_channel=1) y = model(x) print(y.shape) # -> torch.Size([1, 1, 161, 161])