最近看的paper里的pytorch代码太复杂,我之前也没接触过pytorch,遂决定先自己实现一个基础的裸代码,这样走一遍,对跑网络的基本流程和一些常用的基础函数的印象会更深刻。
本文的代码和数据主要来自pytorch笔记:05)UNet网络简单实现_Javis486的专栏-CSDN博客,
附上该博主的github地址:https://github.com/JavisPeng/u_net_liver
并在自己的理解的基础上做了一些改动,以及加了大量注释。
如有错误,欢迎指出。
unet.py(实现unet网络)
import torch.nn as nn
import torch
class DoubleConv(nn.Module):
def __init__(self,in_ch,out_ch):
super(DoubleConv,self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch,out_ch,3,padding=1),#in_ch、out_ch是通道数
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace = True),
nn.Conv2d(out_ch,out_ch,3,padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace = True)
)
def forward(self,x):
return self.conv(x)
class UNet(nn.Module):
def __init__(self,in_ch,out_ch):
super(UNet,self).__init__()
self.conv1 = DoubleConv(in_ch,64)
self.pool1 = nn.MaxPool2d(2)#每次把图像尺寸缩小一半
self.conv2 = DoubleConv(64,128)
self.pool2 = nn.MaxPool2d(2)
self.conv3 = DoubleConv(128,256)
self.pool3 = nn.MaxPool2d(2)
self.conv4 = DoubleConv(256,512)
self.pool4 = nn.MaxPool2d(2)
self.conv5 = DoubleConv(512,1024)