精读论文U-Net: Convolutional Networks for Biomedical Image Segmentation(附翻译和代码)

U-Net: Convolutional Networks for Biomedical Image Segmentation

原文链接
写于2020年7月21日,本学期精读的第一篇论文,加油菜菜的小孙同学!(代码暂未复现)

一、总结

1. motivation

  • 经典卷积网络大部分都是针对图像分类任务的,但是在一些特定场景,如医疗图像处理领域,应是pixel-wise像素级的处理,也就是图像分割。图像分割指进行端对端操作,输入和输出应保持同一维度,输入一张待分割图像,输出应该是分割好的图像。
  • 已有网络需要大量标注训练样本,生物医学任务中没有数千个标注的数据集,所以需要对数据进行数据扩张。
  • 为了解决前两个问题,Ciresan用滑窗法来解决该问题。sliding-window 算法提出预测patch的类别(patch指像素周围的局部区域)。 Deep neural networks segment neuronal membranes in electron microscopy images. 该算法有两个主要问题:一、由于每个patch都需要训练导致这个算法很慢。二、分割准确率和上下文关系需要平衡的,平衡取决于patch的大小,patch越大需要pooling越多准确率越低,patch越小则不具备上下文联系。也就是说滑窗法的感受野(有效视野)大小和分割精度呈负相关关系。
  • 分割速度应更快,分割准确率也应提高。

2. unet

网络:
  作者以FCN全卷积神经网络(Fully Convolution Networks for Semantic Segmentation)为基础设计了Unet,其中包含两条串联的路径:contracting path用来提取图像特征,捕捉context,将图像压缩为由特征组成的feature maps;expanding path用来精准定位,precise localization,将提取的特征解码为与原始图像尺寸一样的分割后的预测图像。
  与FCN不同的是,在上采样过程中保留了大量的特征通道(feature channels),从而使更多的信息能流入最终复原的分割图像中。另外,为了降低在压缩路径上损失的图像信息,还将contracting path和expanding path同尺寸的feature map进行叠加,再继续进行卷积和上采样工作,以此整合更多信息进行图像分割。

  • 蓝色箭头代表3x3的卷积操作,channel的大小乘2,stride是1,padding为0,因此,每个该操作以后,featuremap的大小会减2。输入是572x572的,但是输出变成了388x388,灰色箭头表示复制和剪切操作,在同一层左边的最后一层要比右边的第一层要大一些,这就导致了,想要利用浅层的feature,就要进行一些剪切,也导致了最终的输出是输入的中心某个区域。
  • 红色箭头代表2x2的maxpooling操作,需要注意的是,如果pooling之前featuremap的大小是奇数,那么就会损失一些信息。
  • 绿色箭头代表2x2的反卷积操作,操作会将featuremap的大小乘2,channel的大小除以2。
  • 输出的最后一层,使用了1x1的卷积层做了分类,把64个特征向量分成了2类(细胞类、背景类)。

3. unet和FCN的区别

参考unet++的作者:
  U-Net和FCN非常的相似,U-Net比FCN稍晚提出来,但都发表在2015年,和FCN相比,U-Net的第一个特点是完全对称,也就是左边和右边是很类似的,而FCN的decoder相对简单,只用了一个deconvolution的操作,之后并没有跟上卷积结构。第二个区别就是skip connection,FCN用的是加操作(summation),U-Net用的是叠操作(concatenation)。这些都是细节,重点是它们的结构用了一个比较经典的思路,也就是编码和解码(encoder-decoder),早在2006年就被Hinton大神提出来发表在了nature上。
  降采样可以增加对输入图像的一些小扰动的鲁棒性,比如图像平移,旋转等,减少过拟合的风险,降低运算量,和增加感受野的大小。升采样的最大的作用其实就是把抽象的特征再还原解码到原图的尺寸,最终得到分割结果。

4. 数据增强

增强方式:弹性形变(elastic deformations)
  作者采用了弹性变形的图像增广,以此让网络学习更稳定的图像特征。因为数据集是细胞组织的图像,细胞组织的边界每时每刻都会发生不规则的畸变,所以这种弹性变形的增广是非常有效的。
首先创建随机位移场来使图像变形,即
  Δ x ( x , y ) = r a n d ( − 1 , + 1 ) \ Δx(x,y) = rand(-1,+1)  Δx(x,y)=rand(1,+1)
  Δ y ( x , y ) = r a n d ( − 1 , + 1 ) \ Δy(x,y)=rand(-1,+1)  Δy(x,y)=rand(1,+1)
  其中rand(-1,+1)是生成一个在(-1,1)之间均匀分布的随机数,然后用标准差为σ的高斯函数对Δx和Δy进行卷积。高斯函数如下:
  G ( x , y ) = 1 2 π σ 2 e − ( x 2 + y 2 ) / 2 σ 2 \ G(x,y) = \frac{1}{2\pi\sigma^2}e^{-(x^2+y^2)/2\sigma^2}  G(x,y)=2πσ21e(x2+y2)/2σ2
  如果   σ \ \sigma  σ值很大,则结果值很小,因为随机值平均为0;如果σ很小,则归一化后该字段看起来像一个完全随机的字段。
  对于中间σ值,位移场看起来像弹性变形,其中σ是弹性系数。然后将位移场乘以控制变形强度的比例因子   α \ \alpha  α。 将经过高斯卷积的位移场乘以控制变形强度的比例因子   α \ \alpha  α,得到一个弹性形变的位移场,最后将这个位移场作用在仿射变换之后的图像上,得到最终弹性形变增强的数据。作用的过程相当于在仿射图像上插值的过程,最后返回插值之后的结果。

5. 损失函数设置权重

loss:
  细胞组织图像的一大特点是,多个同类的细胞会紧紧贴合在一起,其中只有细胞壁或膜组织分割,因此,作者在计算损失的过程中,给两个细胞的边缘部分及细胞间的背景部分增加了损失的权重,以此让网络更加注重这类重合的边缘信息。
  E = ∑ i ∈ Ω w ( x ) log ⁡ ( p ι ( x ) ( x ) \ E= \sum\limits_{i\in\Omega}w(x)\log(p_{\iota (x)}(x)  E=iΩw(x)log(pι(x)(x)

6. overlap-tile重叠平铺策略

overlap-tile策略
  该策略的思想是:对图像的某一块像素点(黄框内部分)进行预测时,需要该图像块周围的像素点(蓝色框内)提供上下文信息(context),以获得更准确的预测。简单地说,就是在预处理中,对输入图像进行padding,通过padding扩大输入图像的尺寸,使得最后输出的结果正好是原始图像的尺寸, 同时, 输入图像块(黄框)的边界也获得了上下文信息从而提高预测的精度,本文用的是mirror padding
用于任意大图像的无缝分割的重叠拼贴策略。 对黄色区域中的分割的预测需要蓝色区域内的图像数据作为输入。 丢失的输入数据通过镜像推断
  医学图像是一般相当大,但是分割时候不可能将原图太小输入网络,所以必须切成一张一张的小patch,在切成小patch的时候,Unet由于网络结构的原因,因此适合用overlap的切图 overlap部分可以为分割区域边缘部分提供文理等信息, 并且分割结果并没有受到切成小patch而造成分割情况不好。

二、翻译

0. 摘要

abstract:
  人们普遍同意,成功地训练深度网络需要大量已标注的训练样本。在本文中,我们提出了一种网络和训练策略。为了更有效的利用标注数据,我们使用数据扩张的方法(data augmentation)。该体系结构包括两部分:用于捕获上下文的收缩路径(a contracting path)用于进行精确定位的对称扩展路径(a symmetric expanding path)。我们展示了这样的网络可以在很少的图像上进行端到端训练,并且在ISBI挑战方面优于现有的最佳方法。我们使用这个网络获得了赢得了ISBI cell tracking challenge 2015。不仅如此,这个网络非常的快,对一个512*512的图像,使用一块GPU只需要不到一秒的时间。
  网络获取地址(基于caffe): http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net

1. 介绍

introduction:
  在过去的两年中,深度卷积网络在许多视觉识别任务中的表现超越了现有技术。 虽然卷积网络已经存在很长时间了,但是由于可用训练集的大小和网络结构的大小,它们的成功受到限制。 Krizhevsky等人创造了很大的突破,他们对具有8层的大型网络和具有100万个训练图像的ImageNet数据集上的数百万个参数进行了监督训练。从那时起,甚至更大更深的网络也得到了训练。
  卷积网络的典型用途是用于分类任务,其中图像的输出是单个类别。然而,在许多视觉任务中,特别是在生物医学图像处理中,期望的输出应包括位置,所以应该给每个像素都进行标注。然而在生物医学任务中通常无法获得数千个训练图像。因此,Ciresan等人提出用滑动窗口训练网络,通过提供像素周围的局部区域(patch——每个patch包含很多pixel)作为输入,预测每个像素的类标签。首先,这个网络可以完成定位工作。第二,patch的训练数据远大于训练图像的数量。最终网络赢得了ISBI 2012的胜利。
  显然,Ciresan等人的策略有两个缺点。首先要分别预测每一个patch的类别,patch之间的重叠会导致大量的冗余。其次这个网络需要在局部准确性和获取整体上下文信息之间平衡,更大的patches需多的最大池化层,但是会降低准确率,小的patches仅允许网络查看很少的上下文信息。最近有方法提出了一种分类器输出,该输出考虑了来自多层的特征,既有准确地定位,又包含了上下文信息。
  在这篇文章中,我们建立了一个更加优雅的框架,通常被称为“全卷积网络”(fully convolutional network)。我们修改并拓展了这个框架,使其可以仅使用少量训练图片就可以工作,获得更高的分割准确率。网络如下图所示:
图一
  “全卷积网络”(fully convolutional network)的核心思想是修改一个普通的逐层收缩的网络,用上采样 (up sampling)操作代替网络后部的池化(pooling)操作。因此,这些层增加了输出的分辨率。为了精准定位,在网络收缩过程(路径)中产生的高分辨率特征(high resolution features) ,被连接到了修改后网络的上采样的结果上。在此之后,连续的卷积层基于这些综合信息得到更精确的结果。
  我们架构的重大改进是,在上采样部分中,我们还拥有大量特征通道,这些特征通道使网络可以将上下文信息传播到更高分辨率的层。结果,扩展路径或多或少地相对于收缩路径对称,并且产生u形结构。 网络不存在任何全连接层(fully connected layers),并且,只使用每个卷积的有效部分,例如,分割图(segmentation map)只包含这样一些像素点,这些像素点的完整上下文都出现在输入图像中。该策略允许通过重叠-平铺策略对任意大图像进行无缝分割,如下图。(用于任意大图像的无缝分割的重叠拼贴策略。 对黄色区域中的分割的预测需要蓝色区域内的图像数据作为输入。 丢失的输入数据通过镜像推断。)
用于任意大图像的无缝分割的重叠拼贴策略。 对黄色区域中的分割的预测需要蓝色区域内的图像数据作为输入。 丢失的输入数据通过镜像推断
  为了预测图像边界区域中的像素,可通过镜像输入图像来推断缺失的上下文。 这种平铺策略对于将网络应用于大图像非常重要,因为否则分辨率会受到GPU内存的限制。
  至于我们的任务,几乎没有可用的训练数据,我们通过对可用的训练图像应用弹性变形来使用过多的数据增强。这允许网络学习此类变形的不变性,而无需在带注释的图像语料库中查看这些转换。这对医学图像分割是非常重要的,因为组织的形变是非常常见的情况,并且计算机可以很有效的模拟真实的形变。在无监督特征学习的范围内,Dosovitskiy等人已经证明了数据增强对于学习不变性的价值。
  在细胞分割任务中的另一个挑战是分离同一类别的接触物体。 参见图3。为此,我们使用了加权损失,这些位于touching cells之间的背景在损失函数的权重很高。如图所示:
在这里插入图片描述
  (a)原始图像
  (b)标注图像实况分割 不同的颜色表示HeLa细胞的不同情况
  (c)生成分割蒙版(白色:前景,黑色:背景)
  (d)以像素为单位的权重映射,迫使网络学习边界像素
  本文提出的网络适用于各种生物医学分割问题。 在本文中,我们展示了有关EM stacks中神经元结构分割的结果(一场持续的竞争始于ISBI 2012),在此方面我们胜过了Ciresan等人的网络。 此外,我们在2015年ISBI cell tracking challenge的光学显微镜图像数据集中中显示了细胞分割的结果。我们在两个最具挑战性的2D数据集上取得了很好的效果。

2. 网络结构

architecture:
  网络架构如图所示。它由一个contracting path(收缩路径) 和 expansive path(扩展路径)组成。 收缩路径遵循卷积网络的典型架构。它包括了重复单元:2个3 * 3卷积层(unpadding)、ReLU激活函数和一个2 * 2的步长为2的max pooling层。每一次下采样后我们都把特征通道的数量加倍。扩展路径中的每一步都首先使用反卷积(up-convolution),每次使用反卷积都将特征通道数量减半,特征图大小加倍。反卷积过后,将反卷积的结果与收缩路径中对应步骤的特征图拼接起来,跟随2个3 * 3卷积层(unpadding)、ReLU激活函数。由于每次卷积中都会丢失边界像素,因此有必要进行裁剪。 在最后一层,使用1x1卷积将每个64分量特征向量映射到所需的类数。 该网络总共有23个卷积层(18个顺序卷积+4个收缩到扩展的卷积+1个1*1的卷积)。
图一
  为了无缝拼接输出分割图,选择输入图块大小非常重要,以保证所有的Max Pooling操作作用于长宽为偶数的feature map。

3. 训练

training:
  我们采用随机梯度下降法训练,基于caffe框架。为了最大限度的使用GPU显存,比起输入一个大的batch size,我们更倾向于大量输入tiles,因此实验batch size为1。我们使用了很高的momentum(0.99),大量先前的训练样本确定了当前优化步骤中的更新。损失函数就是pixel-wise softmax + cross_entropy
  softmax函数,   a k \ a_k  ak代表   k \ k  k通道,   x \ x  x像素位置,   K \ K  K代表类别的数量
  p k ( x ) = exp ⁡ ( a k ( x ) ) / ( ∑ k ′ = 1 K exp ⁡ ( a k ′ ( x ) ) \ p_k(x) = \exp(a_k(x))/(\sum_{k'=1}^K \exp(a_k'(x))  pk(x)=exp(ak(x))/(k=1Kexp(ak(x))
   交叉熵损失函数,   ι : Ω → ( 1 , 2 , 3... K ) \ \iota:\Omega\rightarrow({1,2,3...K})  ιΩ(1,2,3...K)代表true label,   w \ w  w是权重
  E = ∑ i ∈ Ω w ( x ) log ⁡ ( p ι ( x ) ( x ) \ E= \sum\limits_{i\in\Omega}w(x)\log(p_{\iota (x)}(x)  E=iΩw(x)log(pι(x)(x)
  为了使某些像素点更加重要,我们在公式中引入了   w ( x ) \ w(x)  w(x)。我们对每一张标注图像预计算了一个权重图,来补偿训练集中每类像素的不同频率,使网络更注重学习相互接触的细胞之间的小的分割边界。我们使用形态学操作计算分割边界。权重图计算公式如下:
  w ( x ) = w c ( x ) + w 0 ∗ exp ⁡ ( − ( ( d 1 ( x ) + d 2 ( x ) ) 2 ) 2 σ 2 ) \ w(x)=w_c(x)+w_0*\exp(-\frac{((d_1(x)+d_2(x))^2)}{2\sigma^2})  w(x)=wc(x)+w0exp(2σ2((d1(x)+d2(x))2))
     w c \ w_c  wc是用于平衡类别频率的权重图,   d 1 \ d_1  d1代表到最近细胞的边界的距离,   d 2 \ d_2  d2代表到第二近的细胞的边界的距离。基于经验我们设定   w 0 \ w_0  w0=10, σ = 5 \sigma=5 σ=5像素。
  在具有许多卷积层和通过网络的不同路径的深度网络中,权重的良好初始化非常重要。 否则,网络的某些部分可能会进行过多的激活,而其他部分则永远不会起作用。 理想情况下,应调整初始权重,以使网络中的每个特征图都具有大约单位方差。 对于具有我们(交替卷积和ReLU层)的网络结构,这可以通过从具有标准偏差   2 / N \ \sqrt {2/N}  2/N 的高斯分布中提取初始权重来实现,其中N表示一个神经元的传入节点数。 例如。 对于上一层的3x3卷积和64个特征通道,N = 9·64 = 576。

3.1 数据扩充

data augmentation:
  在只有少量样本的情况况下,要想尽可能的让网络获得不变性和鲁棒性,数据增加是很重要的。因为本论文需要处理显微镜图片,我们需要平移与旋转不变性,并且对形变和灰度变化鲁棒。将训练样本进行随机弹性变形,是训练带有很少标注图像的分割网络的关键。 我们使用随机位移矢量在粗糙的3*3网格上(random displacement vectors on a coarse 3 by 3 grid)产生平滑形变(smooth deformations)。 位移是从10像素标准偏差的高斯分布中采样的。然后使用双三次插值计算每个像素的位移。在contracting path的末尾采用drop-out 层更进一步增加数据。(不懂)

4. 实验

experiment:
  我们演示了unet在三个不同分割任务上的应用,第一项任务是在电子显微镜记录中分割神经元结构,在下图中演示了数据集中的一个例子和我们的分割结果。我们提供了全部结果作为补充材料。数据集是EM分割挑战提供的,这个挑战是从 ISBI 2012开始的,现在依旧开放。训练数据是一组来自果蝇幼虫腹侧腹侧神经索(VNC)的连续切片透射电镜的30张图像(512x512像素)。每个图像都带有一个对应的标注分割图,细胞(白色)和膜(黑色)。测试集是公开可用的,但对应的标注图是保密的,可以通过将预测的膜概率图发送给组织者来获得评估。通过在10个不同级别对结果进行阈值化和计算the “warping error”, the “Rand error” and the “pixel error”(预测的label和实际的label)
  u-net(输入数据的7个旋转版本的平均值)无需进行任何进一步的预处理或后处理即可获得0.0003529的warping error和Rand error为0.0382。
在这里插入图片描述

5. 结论

conclusion:  
  u-net体系结构在截然不同的生物医学细分应用中实现了非常好的性能。 由于具有弹性变形的数据增强,它仅需要很少的带注释的图像,并且在NVidia Titan GPU(6 GB)上只有10小时的非常合理的训练时间。 我们提供完整的基于Caffe的训练网络。 我们确信u-net架构可以轻松地应用于更多任务。

三、纸质版学习材料

在这里插入图片描述
在这里插入图片描述在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

四、代码

在LiTS-肝肿瘤分割挑战数据集训练unet模型
数据集https://github.com/JavisPeng/u_net_liver

import torch.utils.data as data
import os
import PIL.Image as Image
import torch
from torchvision.transforms import transforms as T
import argparse #argparse模块的作用是用于解析命令行参数,例如python parseTest.py input.txt --port=8080
from torch import optim
from torch.utils.data import DataLoader
#data.Dataset:
#所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)
import torch.nn as nn
import torch
 
class DoubleConv(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(DoubleConv,self).__init__()
        self.conv = nn.Sequential(
                nn.Conv2d(in_ch,out_ch,3,padding=1),#in_ch、out_ch是通道数
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace = True),
                nn.Conv2d(out_ch,out_ch,3,padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace = True)  
            )
    def forward(self,x):
        return self.conv(x)
 
 
class UNet(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(UNet,self).__init__()
        self.conv1 = DoubleConv(in_ch,64)
        self.pool1 = nn.MaxPool2d(2)#每次把图像尺寸缩小一半
        self.conv2 = DoubleConv(64,128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128,256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256,512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512,1024)
        #逆卷积
        self.up6 = nn.ConvTranspose2d(1024,512,2,stride=2)
        self.conv6 = DoubleConv(1024,512)
        self.up7 = nn.ConvTranspose2d(512,256,2,stride=2)
        self.conv7 = DoubleConv(512,256)
        self.up8 = nn.ConvTranspose2d(256,128,2,stride=2)
        self.conv8 = DoubleConv(256,128)
        self.up9 = nn.ConvTranspose2d(128,64,2,stride=2)
        self.conv9 = DoubleConv(128,64)
        
        self.conv10 = nn.Conv2d(64,out_ch,1)
        
    
    def forward(self,x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6,c4],dim=1)#按维数1(列)拼接,列增加
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7,c3],dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8,c2],dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9,c1],dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        
        out = nn.Sigmoid()(c10)#化成(0~1)区间
        return out
 
class LiverDataset(data.Dataset):
    #创建LiverDataset类的实例时,就是在调用init初始化
    def __init__(self,root,transform = None,target_transform = None):#root表示图片路径
        n = len(os.listdir(root))//2 #os.listdir(path)返回指定路径下的文件和文件夹列表。/是真除法,//对结果取整
        print(len(os.listdir(root)))
        imgs = []
        for i in range(n):
            img = os.path.join(root,"%03d.png"%i)#os.path.join(path1[,path2[,......]]):将多个路径组合后返回
            mask = os.path.join(root,"%03d_mask.png"%i)
            imgs.append([img,mask])#append只能有一个参数,加上[]变成一个list
        
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
    
    
    def __getitem__(self,index):
        x_path,y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x,img_y#返回的是图片
    
    
    def __len__(self):
        return len(self.imgs)#400,list[i]有两个元素,[img,mask]
 
 
# 是否使用current cuda device or torch.device('cuda:0')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
x_transform = T.Compose([
    T.ToTensor(),
    # 标准化至[-1,1],规定均值和标准差
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])#torchvision.transforms.Normalize(mean, std, inplace=False)
])
# mask只需要转换为tensor
y_transform = T.ToTensor()
 
def train_model(model,criterion,optimizer,dataload,num_epochs=20):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dataset_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0 #minibatch数
        rightCount = 0.0
        allCount = 0.0
        for x, y in dataload:# 分100次遍历数据集,每次遍历batch_size=4
            optimizer.zero_grad()#每次minibatch都要将梯度(dw,db,...)清零
            inputs = x.to(device)
            labels = y.to(device)
            print(y.shape[0])
            outputs = model(inputs)#前向传播
            allCount += y.shape[0]*y.shape[2]*y.shape[3]
            loss = criterion(outputs, labels)#计算损失
            loss.backward()#梯度下降,计算出梯度
            optimizer.step()#更新参数一次:所有的优化器Optimizer都实现了step()方法来对所有的参数进行更新
            epoch_loss += loss.item()
            rightCount += (torch.gt(outputs, labels)).sum()
            print("rightCount",rightCount)
            step += 1
            print("%d/%d,train_loss:%0.3f" % (step, dataset_size // dataload.batch_size, loss.item()))
            print("%d/%d,train_acc:%0.3f" % (step, dataset_size // dataload.batch_size, 100. * rightCount/allCount))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        print("第%d循环,准确率为%d" % (epoch+1,rightCount/allCount))
        torch.save(model.state_dict(),'weights_%d.pth' % epoch)
        # 返回模型的所有内容
    return model

#训练模型
def train():
    model = UNet(3,1).to(device)
    batch_size = args.batch_size
    #损失函数
    criterion = torch.nn.BCELoss()
    #梯度下降
    optimizer = optim.Adam(model.parameters())#model.parameters():Returns an iterator over module parameters
    #加载数据集
    liver_dataset = LiverDataset("./unet/train", transform=x_transform, target_transform=y_transform)
    dataloader = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True,num_workers=4)
    # DataLoader:该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor
    # batch_size:how many samples per minibatch to load,这里为4,数据集大小400,所以一共有100个minibatch
    # shuffle:每个epoch将数据打乱,这里epoch=10。一般在训练数据中会采用
    # num_workers:表示通过多个进程来导入数据,可以加快数据导入速度 
    train_model(model,criterion,optimizer,dataloader)
 
#测试
def test():
    model = unet.UNet(3,1)
    model.load_state_dict(torch.load(args.weight,map_location='cpu'))
    liver_dataset = LiverDataset("./unet/val", transform=x_transform, target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)#batch_size默认为1
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
        for x, _ in dataloaders:
            y=model(x)
            img_y=torch.squeeze(y).numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
        plt.show()
 
 
if __name__ == '__main__':
    #参数解析
    # parser = argparse.ArgumentParser() #创建一个ArgumentParser对象
    # parser.add_argument('action', type=str, help='train or test',default = 'train')#添加参数
    # parser.add_argument('--batch_size', type=int, default=4)
    # parser.add_argument('--weight', type=str, help='the path of the mode weight file')
    # #args = parser.parse_args()
    args = parser.parse_known_args()[0]
    args.action = 'train'
    args.batch_size = 4
    #args.weight = "/content/drive/My Drive/Pytorch_try/unet"

    print(args.action=='train')
    if args.action == 'train':
        train()
    elif args.action == 'test':
        test()

感谢大神的指引https://blog.csdn.net/weixin_42135399/article/details/90178673

  • 77
    点赞
  • 260
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值