用于记录初次手写 U-Net 时遇到的一些问题。
（下面放上最经典的 U-Net。）
import torch
from torch import nn
from torchsummary import summary
import torch.nn.functional as F
## 本 U-Net 与经典版本有所不同：本文的 3×3 卷积层都添加了一层 padding（padding=1），故卷积前后特征图的空间尺寸不发生改变。
## 需要分割成几类，就把最后的 1×1 卷积的输出通道数设置为对应的类别数（本文设置为 2）。
# Double convolution block
class Conv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the feature
    map is unchanged; only the channel count goes from ``in_channel`` to
    ``out_channel``.

    Args:
        in_channel: number of channels of the incoming feature map.
        out_channel: number of channels produced by both convolutions.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # One stateless ReLU module is shared by both convolutions.
        self.Relu = nn.ReLU()
        self.con1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.con2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU."""
        for layer in (self.con1, self.con2):
            x = self.Relu(layer(x))
        return x
# Encoder (downsampling / contracting path)
class Down(nn.Module):
    """U-Net encoder made of five double-convolution stages.

    Stage ``b1`` applies ``Conv(in_channels, 64)`` at full resolution. Each of
    ``b2``..``b5`` first halves the spatial size with a 2x2 max-pool
    (stride 2) and then applies ``Conv`` to double the channel count:
    64 -> 128 -> 256 -> 512 -> 1024.

    Args:
        Conv: factory called as ``Conv(in_ch, out_ch)``; must return an
            ``nn.Module``.
        in_channels: number of channels of the input image. Defaults to 1,
            the value that used to be hard-coded, so ``Down(Conv)`` behaves
            exactly as before; pass e.g. 3 for RGB input.
    """

    def __init__(self, Conv, in_channels=1):
        super(Down, self).__init__()
        # The stages stay wrapped in nn.Sequential so parameter names
        # (b1.0.*, b2.1.*, ...) remain the same as in earlier checkpoints.
        self.b1 = nn.Sequential(
            Conv(in_channels, 64),
        )
        self.b2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            Conv(64, 128),
        )
        self.b3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            Conv(128, 256),
        )
        self.b4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            Conv(256, 512),
        )
        self.b5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            Conv(512, 1024),
        )

    def forward(self, x):
        """Return the outputs of all five stages as ``(x1, x2, x3, x4, x5)``.

        Every stage's output is returned (not only the last one) so a decoder
        can use ``x1``..``x4`` as skip connections. The max-pool floors odd
        sizes (e.g. 17 -> 8).
        """
        x1 = self.b1(x)
        x2 = self.b2(x1)
        x3 = self.b3(x2)
        x4 = self.b4(x3)
        x5 = self.b5(x4)
        return x1, x2, x3, x4, x5
# Full network: encoder + decoder (upsampling / expanding path)
class Up(nn.Module):
    """U-Net that owns its encoder and implements the decoder.

    Each decoder stage does the following:
      1. Bilinearly upsamples to the spatial size of the matching encoder
         skip feature.
      2. Halves the channel count with a learned 1x1 convolution.
      3. Concatenates ``(skip, upsampled)`` along the channel axis.
      4. Runs a double convolution.

    A final 1x1 convolution maps the 64-channel map to ``num_classes``
    output channels.

    Args:
        Conv: factory ``Conv(in_ch, out_ch)`` used for the decoder blocks and
            passed to the encoder.
        Down: encoder class, instantiated once as ``Down(Conv)``. Its forward
            must return five feature maps with 64/128/256/512/1024 channels.
        num_classes: output channels of the last 1x1 conv (default 2).
    """

    def __init__(self, Conv, Down, num_classes=2):
        super(Up, self).__init__()
        # Build the encoder once, as a registered submodule. Its weights are
        # then trained, saved in state_dict() and moved by .to(device).
        # Previously a fresh random encoder was built inside every forward().
        self.down = Down(Conv)
        # Learned 1x1 convs that halve channels after each upsampling. These
        # too used to be re-created (re-randomised) on every forward call.
        self.reduce1 = nn.Conv2d(1024, 512, kernel_size=1)
        self.reduce2 = nn.Conv2d(512, 256, kernel_size=1)
        self.reduce3 = nn.Conv2d(256, 128, kernel_size=1)
        self.reduce4 = nn.Conv2d(128, 64, kernel_size=1)
        # Inputs are skip + reduced upsample, i.e. twice the output channels.
        self.b1 = Conv(1024, 512)
        self.b2 = Conv(512, 256)
        self.b3 = Conv(256, 128)
        self.b4 = Conv(128, 64)
        self.con1 = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to the spatial (H, W) size of ``ref``.

        For even sizes this equals ``scale_factor=2``. It also keeps
        ``torch.cat`` valid when the encoder floored an odd size.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def _stage(self, y, skip, reduce, conv):
        """One decoder stage: upsample, 1x1 reduce, concat skip first, double conv."""
        y = reduce(self._upsample_to(y, skip))
        return conv(torch.cat((skip, y), dim=1))

    def forward(self, x):
        """Run the encoder and decoder; return per-pixel ``num_classes`` logits."""
        x1, x2, x3, x4, x5 = self.down(x)
        y = self._stage(x5, x4, self.reduce1, self.b1)
        y = self._stage(y, x3, self.reduce2, self.b2)
        y = self._stage(y, x2, self.reduce3, self.b3)
        y = self._stage(y, x1, self.reduce4, self.b4)
        return self.con1(y)
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    model = Up(Conv, Down).to(device)
    # summary() prints each layer's output shape and parameter count for an
    # input of the given (C, H, W) size. It prints directly and returns
    # None, so it is not wrapped in print(). Its `device` argument defaults
    # to "cuda"; pass the device the model actually lives on so CPU-only
    # machines work too.
    summary(model, (1, 224, 224), device=device.type)