Implementing UNet in PyTorch

UNet is a classic image-segmentation network, named for its U-shaped encoder-decoder structure.

The implementation is fairly straightforward; the code is below:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # encoder block 1: 1 -> 64 channels (3x3 convs, no padding, as in the original paper)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        # encoder block 2: 64 -> 128
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True))
        # encoder block 3: 128 -> 256
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        # encoder block 4: 256 -> 512
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True))
        # decoder block 1: input is 512 (upsampled) + 512 (skip connection) = 1024 channels
        self.conv5 = nn.Sequential(
            nn.Conv2d(1024, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        # decoder block 2: 256 + 256 = 512 channels in
        self.conv6 = nn.Sequential(
            nn.Conv2d(512, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True))
        # decoder block 3: 128 + 128 = 256 channels in
        self.conv7 = nn.Sequential(
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))

        # bottleneck at the bottom of the "U": 512 -> 1024 -> 512
        self.trans = nn.Sequential(
            nn.Conv2d(512, 1024, 3),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True))

        # final decoder block: 64 + 64 = 128 channels in, 2 output classes.
        # note: BatchNorm + ReLU on the 2-channel output is unusual; most UNet
        # implementations end with a bare 1x1 conv that produces raw logits.
        self.end = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 2, 3),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True))

        # the original paper uses learned up-convolutions; bilinear upsampling is used here
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
    def forward(self, x):
        # encoder path: keep each block's output for the skip connections
        out_conv1 = self.conv1(x)
        out = self.pool(out_conv1)
        out_conv2 = self.conv2(out)
        out = self.pool(out_conv2)
        out_conv3 = self.conv3(out)
        out = self.pool(out_conv3)
        out_conv4 = self.conv4(out)
        out = self.pool(out_conv4)
        # bottleneck
        out = self.trans(out)
        # decoder path: upsample, crop the (larger) encoder feature map down to the
        # decoder size, and concatenate along the channel dimension.
        # the slice keeps the top-left corner; the original paper uses a centre crop.
        out = self.upsample(out)
        out = torch.cat((out, out_conv4[:, :, :out.shape[2], :out.shape[3]]), 1)
        out = self.conv5(out)
        out = self.upsample(out)
        out = torch.cat((out, out_conv3[:, :, :out.shape[2], :out.shape[3]]), 1)
        out = self.conv6(out)
        out = self.upsample(out)
        out = torch.cat((out, out_conv2[:, :, :out.shape[2], :out.shape[3]]), 1)
        out = self.conv7(out)
        out = self.upsample(out)
        out = torch.cat((out, out_conv1[:, :, :out.shape[2], :out.shape[3]]), 1)
        return self.end(out)
    
if __name__ == '__main__':
    x = torch.randn(1, 1, 572, 572)    # dummy single-channel input, paper-sized 572x572
    net = UNet()
    net.eval()                         # use running BatchNorm statistics for inference/export
    output = net(x)
    print(output)
    torch.save(net, 'unet.pth')        # saves the whole module; saving net.state_dict() is the more common practice
    torch.onnx.export(net, x, "unet.onnx", export_params=True, opset_version=10,
                      do_constant_folding=True, input_names=['input'], output_names=['output'])
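As a quick sanity check on the exported file, the ONNX model can be run with onnxruntime. This snippet is not part of the original script and assumes the onnxruntime package is installed; the "input"/"output" names come from the input_names/output_names passed to torch.onnx.export above. A minimal sketch:

# sanity-check the exported model with onnxruntime (assumed installed: pip install onnxruntime)
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("unet.onnx")
x = np.random.randn(1, 1, 572, 572).astype(np.float32)
onnx_out = sess.run(None, {"input": x})[0]   # "input"/"output" names match the export call
print(onnx_out.shape)                        # expected: (1, 2, 386, 386)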
    

Running the script prints the raw output tensor. With the 572×572 input, its shape is (1, 2, 386, 386): every 3×3 convolution here is unpadded, so the feature maps shrink at each layer.
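Continuing from the script above, the two-channel output can be turned into a per-pixel class map with an argmax over the channel dimension, for example:

# per-pixel class prediction from the raw two-channel output
pred = output.argmax(dim=1)        # shape (1, 386, 386); each pixel is class 0 or 1
print(pred.shape)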

This model has not been trained; the printed values come directly from the randomly initialized weights.
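For completeness, here is a minimal training-loop sketch. Everything in it is an assumption for illustration: the random stand-in data, the 386×386 target masks (they must match the output size of this unpadded network), and the optimizer settings are all hypothetical and not part of the original post.

# minimal training sketch -- the data and all hyperparameters below are hypothetical
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

net = UNet()
net.train()
criterion = nn.CrossEntropyLoss()                 # 2 classes -> per-pixel cross entropy
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# stand-in data: 4 random images and random 0/1 masks already cropped to the 386x386 output size
images = torch.randn(4, 1, 572, 572)
masks = torch.randint(0, 2, (4, 386, 386))
loader = DataLoader(TensorDataset(images, masks), batch_size=2, shuffle=True)

for epoch in range(2):                            # a couple of epochs just to show the loop
    for img, mask in loader:
        optimizer.zero_grad()
        logits = net(img)                         # (B, 2, 386, 386)
        loss = criterion(logits, mask)            # mask: (B, 386, 386) with class indices
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch}: loss {loss.item():.4f}')

In practice the masks would come from a real labelled dataset and would have to be cropped (or the convolutions padded) so that they line up with the network output.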
