【1】网络结构
Unet包括两部分:
1 特征提取部分,每经过一个池化层就降低一个尺度,包括原图尺度一共有5个尺度。
2 上采样部分,每上采样一次,就和特征提取部分对应的通道数相同尺度融合,但是融合之前要将其crop。这里的融合也是拼接。
该网络由收缩路径(contracting path)和扩张路径(expanding path)组成。其中,收缩路径用于获取上下文信息(context),扩张路径用于精确定位(localization)。
【1.1】网络优点
(1) overlap-tile策略
(2)数据增强(data augmentation)
(3)加权loss
【1.2】网络缺点
U-Net++作者分析U-Net不足并如何做改进:https://zhuanlan.zhihu.com/p/44958351
参考文献:https://zhuanlan.zhihu.com/p/118540575
【2】网络训练
代码以及权重下载地址:https://github.com/JavisPeng/u_net_liver
data and trained weight link: https://pan.baidu.com/s/1dgGnsfoSmL1lbOUwyItp6w code: 17yr
all dataset you can access from: https://competitions.codalab.org/competitions/15595
【2.1】代码展示
文件夹介绍
(1)data文件夹中放的是训练和测试的图片
(2)dowoload是下载的权重文件
Unetmodel.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : Unetmodel.py
@Time : 2021/03/23 20:09:25
@Author : Jian Song
@Contact : 1248975661@qq.com
@Desc : None
'''
# here put the import lib
import torch.nn as nn
import torch
from torch import autograd
'''
文件介绍:定义了unet网络模型,
******pytorch定义网络只需要定义模型的具体参数,不需要将数据作为输入定义到网络中。
仅需要在使用时实例化这个网络,然后将数据输入。
******tensorflow定义网络时则需要将输入张量输入到模型中,即用占位符完成输入数据的输入。
'''
#把常用的2个卷积操作简单封装下
class DoubleConv(nn.Module):
#通过此处卷积,特征图的大小减4,但是通道数保持不变;
def __init__(self, in_ch, out_ch):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
#添加了BN层
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, input):
return self.conv(input)
class Unet(nn.Module):
def __init__(self, in_ch, out_ch):
super(Unet, self).__init__()
#定义网络模型
#下采样-》编码
self.conv1 = DoubleConv(in_ch, 64)
self.pool1 = nn.MaxPool2d(2)
self.conv2 = DoubleConv(64, 128)
self.pool2 = nn.MaxPool2d(2)
self.conv3 = DoubleConv(128, 256)
self.pool3 = nn.MaxPool2d(2)
self.conv4 = DoubleConv(256, 512)
self.pool4 = nn.MaxPool2d(2)
self.conv5 = DoubleConv(512, 1024)
# 逆卷积,也可以使用上采样(保证k=stride,stride即上采样倍数)
self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)#反卷积
self.conv6 = DoubleConv(1024, 512)
self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.conv7 = DoubleConv(512, 256)
self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.conv8 = DoubleConv(256, 128)
self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.conv9 = DoubleConv(128, 64)
self.conv10 = nn.Conv2d(64, out_ch, 1)
#定义网络前向传播过程
def forward(self, x):
c1 = self.conv1(x)
p1 = self.pool1(c1)
c2 = self.conv2(p1)
p2 = self.pool2(c2)
c3 = self.conv3(p2)
p3 = self.pool3(c3)
c4 = self.conv4(p3)
p4 = self.pool4(c4)
c5 = self.conv5(p4)
#上采样
up_6 = self.up6(c5)
#cat函数讲解:https://www.cnblogs.com/JeasonIsCoding/p/10162356.html
merge6 = torch.cat([up_6, c4], dim=1)#此处横着拼接,dim=1表示在行的后面添加上原有矩阵
c6 = self.conv6(merge6)
up_7 = self.up7(c6)
merge7 = torch.cat([up_7, c3], dim=1)
c7 = self.conv7(merge7)
up_8 = self.up8(c7)
merge8 = torch.cat([up_8, c2], dim=1)
c8 = self.conv8(merge8)
up_9 = self.up9(c8)
merge9 = torch.cat([up_9, c1], dim=1)
c9 = self.conv9(merge9)
c10 = self.conv10(c9)
out = nn.Sigmoid()(c10)
return out
if __name__ == '__main__':
    # Quick sanity check: build a single-channel-in/out U-Net and dump
    # its module structure.
    net = Unet(1, 1)
    print(net)
main.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
(1)参考文献:UNet网络简单实现
https://blog.csdn.net/jiangpeng59/article/details/80189889
(2)FCN和unet的区别
https://zhuanlan.zhihu.com/p/118540575
'''
import torch
import argparse
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from Unetmodel import Unet
from setdata import LiverDataset
from setdata import *
# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Two preprocessing pipelines: one for the raw images, one for the masks.
# ToTensor maps image pixels to [0, 1]; Normalize then rescales to [-1, 1].
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# The mask only needs conversion to a tensor (values stay in [0, 1]).
y_transforms = transforms.ToTensor()
def train_model(model, criterion, optimizer, dataload, num_epochs=5):
    """Run the training loop and checkpoint the weights after every epoch.

    Args:
        model: network to train (already moved to its target device).
        criterion: loss function, e.g. nn.BCEWithLogitsLoss.
        optimizer: optimizer over model.parameters().
        dataload: DataLoader yielding (image, mask) batches.
        num_epochs: number of passes over the dataset.

    Returns:
        The trained model (training also happens in place).
    """
    # Derive the device from the model itself instead of relying on the
    # module-level `device` global, so the function works standalone.
    device = next(model.parameters()).device
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        model.train()  # ensure BatchNorm/Dropout are in training mode
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients accumulated by the previous step
            optimizer.zero_grad()
            # forward + loss + backward + update
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() extracts the Python float from a 0-dim tensor
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss/step))
        # checkpoint after every epoch
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model
#训练模型
# Train the model on the liver dataset.
def train(batch_size, num_epochs=5):
    """Build model/criterion/optimizer, load the training set and train.

    Args:
        batch_size: mini-batch size for the DataLoader.
        num_epochs: number of training epochs (defaults to 5, matching the
            previous hard-coded behavior; now configurable).
    """
    # 3 input channels (RGB image), 1 output channel (binary liver mask)
    model = Unet(3, 1).to(device)
    # BCEWithLogitsLoss applies the sigmoid internally, so the model must
    # output raw logits.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters())
    # load the training data
    liver_dataset = LiverDataset("data/train",transform=x_transforms,target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, dataloaders, num_epochs=num_epochs)
#模型的测试结果
# Visualize the model's predictions on the validation set, one image at a time.
def test(ckptpath):
    model = Unet(3, 1)
    # load checkpoint on CPU regardless of where it was trained
    model.load_state_dict(torch.load(ckptpath,map_location='cpu'))
    liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
    # batch_size=1: load one image at a time
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    # switch BatchNorm/Dropout to inference mode (this is nn.Module.eval,
    # unrelated to the builtin eval() that parses strings)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()  # interactive mode: show each prediction as it is produced
    with torch.no_grad():
        for x, _ in dataloaders:
            # .sigmoid() turns the model output into per-pixel probabilities
            y=model(x).sigmoid()
            # squeeze() drops dimensions of size 1 (batch and channel)
            img_y=torch.squeeze(y).numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
    plt.show()
# def trainmodel(batchsize):
# train(batchsize)
# def testmodel(ckptpath):
# test(ckptpath)
if __name__ == '__main__':
    # Hard-coded configuration (an argparse-based CLI existed previously
    # but was disabled in favor of fixed values).
    batchsize = 10
    train(batchsize)
    ckptpath = './dowoload/weights_19.pth'
    test(ckptpath)
setdata.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch.utils.data import Dataset
import PIL.Image as Image
import os
#创建一个列表,存放图像和研磨图像的图像路径
# Build the list of (image_path, mask_path) pairs for every sample in root.
def make_dataset(root):
    # Each sample contributes exactly two files, e.g. "012.png" and
    # "012_mask.png", so half of the directory entries are samples.
    # "%03d" zero-pads the index to three digits (12 -> "012").
    sample_count = len(os.listdir(root)) // 2
    pairs = []
    for idx in range(sample_count):
        image_path = os.path.join(root, "%03d.png" % idx)
        mask_path = os.path.join(root, "%03d_mask.png" % idx)
        pairs.append((image_path, mask_path))
    return pairs
class LiverDataset(Dataset):
    """Dataset of liver images paired with their segmentation masks."""

    def __init__(self, root, transform=None, target_transform=None):
        # transform preprocesses the input image,
        # target_transform preprocesses the mask
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image_path, mask_path = self.imgs[index]
        image = Image.open(image_path)
        mask = Image.open(mask_path)
        # apply each preprocessing pipeline only when one was supplied
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.imgs)
【3】UNet模型参数展示
Unet(
(conv1): DoubleConv(
(conv): Sequential(
(0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): DoubleConv(
(conv): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3): DoubleConv(
(conv): Sequential(
(0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv4): DoubleConv(
(conv): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv5): DoubleConv(
(conv): Sequential(
(0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up6): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
(conv6): DoubleConv(
(conv): Sequential(
(0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up7): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
(conv7): DoubleConv(
(conv): Sequential(
(0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up8): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
(conv8): DoubleConv(
(conv): Sequential(
(0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up9): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
(conv9): DoubleConv(
(conv): Sequential(
(0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(conv10): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
PS F:\PytorchTest\torchdeeplearnmodel\Unet> & G:/Anaconda3/envs/tensorflow/python.exe f:/PytorchTest/torchdeeplearnmodel/Unet/Unetmodel.py
Unet(
(conv1): DoubleConv(
(conv): Sequential(
(0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): DoubleConv(
(conv): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3): DoubleConv(
(conv): Sequential(
(0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv4): DoubleConv(
(conv): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv5): DoubleConv(
(conv): Sequential(
(0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up6): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
(conv6): DoubleConv(
(conv): Sequential(
(0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up7): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
(conv7): DoubleConv(
(conv): Sequential(
(0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up8): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
(conv8): DoubleConv(
(conv): Sequential(
(0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(up9): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
(conv9): DoubleConv(
(conv): Sequential(
(0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(conv10): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
【4】参考文献
(1)UNet网络简单实现
(2)FCN和unet的区别