U-Net 是非常经典的图像分割网络，因其网络结构形似字母 U 而得名。
实现起来并不复杂，代码如下（与原论文一样，3×3 卷积不做填充，因此输出尺寸小于输入：572×572 的输入得到 386×386 的输出）：
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
class unet(nn.Module):
    """U-Net style encoder/decoder for per-pixel image segmentation.

    As in the original U-Net, every 3x3 convolution is unpadded ("valid"),
    so each one trims 2 pixels from height and width and the output is
    spatially smaller than the input (a 1x1x572x572 input yields a
    1xnum_classesx386x386 output). Unlike the paper, the decoder upsamples
    with parameter-free bilinear interpolation instead of learned 2x2
    up-convolutions; channel reduction happens in the decoder convs instead.

    Args:
        in_channels: Number of channels of the input image (default 1).
        num_classes: Number of output score maps, one per class (default 2).

    Square inputs must be at least 188x188 so that every intermediate
    feature map stays non-empty.
    """

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()
        # Encoder stages: two (conv3x3 -> BN -> ReLU) units each.
        self.conv1 = self._double_conv(in_channels, 64, 64)
        self.conv2 = self._double_conv(64, 128, 128)
        self.conv3 = self._double_conv(128, 256, 256)
        self.conv4 = self._double_conv(256, 512, 512)
        # Decoder stages: the skip concatenation doubles the input channels,
        # and the second conv halves them so the next concatenation balances.
        self.conv5 = self._double_conv(1024, 512, 256)
        self.conv6 = self._double_conv(512, 256, 128)
        self.conv7 = self._double_conv(256, 128, 64)
        # Bottleneck between encoder and decoder.
        self.trans = self._double_conv(512, 1024, 512)
        # Output head. The last conv keeps the original 3x3 kernel (the paper
        # uses 1x1) so the output size is unchanged.
        self.end = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # Raw per-class scores: no BatchNorm/ReLU after this layer, since a
            # trailing ReLU would clamp every score to >= 0 and break
            # softmax / cross-entropy on the logits.
            nn.Conv2d(64, num_classes, 3),
        )
        # Attribute name kept for compatibility; this is the 2x *up*sampler.
        self.unSample = nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _double_conv(in_ch, mid_ch, out_ch):
        """Return two unpadded (conv3x3 -> BatchNorm -> ReLU) units in sequence."""
        return nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, 3),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, 3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _center_crop(skip, target):
        """Crop ``skip`` spatially to ``target``'s height/width, keeping the centre.

        Valid convolutions make the decoder maps smaller than the matching
        encoder maps; cropping the centre keeps both covering (approximately)
        the same image region, whereas a top-left crop would misalign them.
        """
        height, width = target.shape[2], target.shape[3]
        top = (skip.shape[2] - height) // 2
        left = (skip.shape[3] - width) // 2
        return skip[:, :, top:top + height, left:left + width]

    def forward(self, x):
        """Run the network on an image batch.

        Args:
            x: 4-D image batch (N, in_channels, H, W) — BatchNorm2d requires
                4-D input.

        Returns:
            Raw per-class scores of shape (N, num_classes, H', W'), where
            H' < H and W' < W because all convolutions are unpadded.
        """
        # Encoder path; keep each stage's output for the skip connections.
        skip1 = self.conv1(x)
        skip2 = self.conv2(self.pool(skip1))
        skip3 = self.conv3(self.pool(skip2))
        skip4 = self.conv4(self.pool(skip3))
        out = self.trans(self.pool(skip4))
        # Decoder path: upsample, concatenate the centre-cropped skip on the
        # channel axis, then convolve (the last stage is the output head).
        stages = ((skip4, self.conv5), (skip3, self.conv6),
                  (skip2, self.conv7), (skip1, self.end))
        for skip, block in stages:
            out = self.unSample(out)
            out = block(torch.cat((out, self._center_crop(skip, out)), 1))
        return out
if __name__ == '__main__':
    # Smoke test: one forward pass with freshly initialised (untrained)
    # weights, then save the model and export it to ONNX.
    dummy_input = torch.randn(1, 1, 572, 572)
    net = unet()
    # eval(): BatchNorm uses its running statistics instead of updating them,
    # so this demo pass does not alter the state that gets saved/exported.
    net.eval()
    with torch.no_grad():  # inference only; no autograd graph needed
        output = net(dummy_input)
    print(output)
    # Pickles the whole module; loading it later requires this class to be
    # importable (saving net.state_dict() would be the more portable option).
    torch.save(net, 'unet.pth')
    # opset >= 11: ONNX Resize only gained the coordinate-transformation
    # attributes needed to match PyTorch bilinear upsampling with
    # align_corners=False in opset 11; opset 10 exports give different results.
    torch.onnx.export(
        net,
        dummy_input,
        "unet.onnx",
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
    )
运行结果如下：
注：此代码没有进行训练，输出是直接使用随机初始化权重得到的结果。