Implementing U-Net in PyTorch

http://t.zoukankan.com/wanghui-garcia-p-10719121.html

https://github.com/1024210879/unet-denoising-dirty-documents/blob/master/datasets.py

(Figure: U-Net architecture diagram; the size annotations in the comments below, e.g. 572*572 -> 570*570, refer to it.)

Model.py

# sub-parts of the U-Net model
import torch
import torch.nn as nn
import torch.nn.functional as F

# The double 3x3 convolution block used along both the contracting and expanding paths
class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            # Taking the first block as an example:
            # in_ch input channels, out_ch output channels, 3*3 kernel, padding=1, stride=1, dilation=1.
            # Output size = ((H + 2*padding - dilation*(kernel_size-1) - 1) / stride) + 1.
            # In the original paper the convolutions are unpadded, so 572*572 shrinks to 570*570;
            # here padding=1 is used, so the spatial size is preserved.
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),  # batch normalization; during training it uses per-batch statistics and updates running averages of the mean and variance
            nn.ReLU(inplace=True),  # activation
            nn.Conv2d(out_ch, out_ch, 3, padding=1),  # second convolution (570*570 -> 568*568 in the unpadded paper version)
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        x = self.conv(x)
        return x

# The first double convolution at the top of the contracting path
class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)  # e.g. in_ch=3 for an RGB input, out_ch=64
    def forward(self, x):
        x = self.conv(x)
        return x

# One downsampling step on the contracting path: 2x2 max pooling followed by a double convolution
class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )
    def forward(self, x):
        x = self.mpconv(x)
        return x

# One upsampling step on the expanding path: upsample, match the size of the skip connection, concatenate, then run a double convolution
class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        #  It would be a nice idea if the upsampling could be learned too,
        #  but my machine does not have enough memory to handle all those weights.
        if bilinear:  # default: bilinear interpolation; output size is floor(H * scale_factor), so 28*28 becomes 56*56
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:  # otherwise upsample with a transposed convolution; output size = (H-1)*stride - 2*padding + kernel_size + output_padding
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)
    def forward(self, x1, x2):  # x2 is the skip-connection feature map from the contracting path
        # upsample x1; in the paper's first expanding step this turns 28*28 into 56*56
        x1 = self.up(x1)

        # tensors are NCHW ([0] = batch size, [1] = channels); this block is modified from the upstream code
        diffY = x1.size()[2] - x2.size()[2]  # height difference between x1 and x2, e.g. 56 - 64 = -8
        diffX = x1.size()[3] - x2.size()[3]  # width difference between x1 and x2, e.g. 56 - 64 = -8

        # If the upsampled x1 and the skip feature x2 differ in size, pad x2 so it matches x1.
        # In the example above the padding is (-4, -4, -4, -4): negative padding crops 4 pixels
        # from each side, so 64*64 becomes 56*56.
        x2 = F.pad(x2, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        # for padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        # concatenate the upsampled x1 with the skip feature x2 along dim=1 (the channel dimension),
        # e.g. 512 + 512 = 1024 channels
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

# The final 1x1 convolution mapping the last 64-channel feature map to the desired number of output channels
class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)
    def forward(self, x):
        x = self.conv(x)
        return x

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):  # number of image channels: 1 for grayscale, 3 for RGB
        super(UNet, self).__init__()
        self.inc = inconv(in_channels, 64)  # e.g. with in_channels=3 (RGB input), this produces 64 feature channels
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, out_channels)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x
        # return torch.sigmoid(x)  # for binary segmentation
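
A quick sanity check, not part of the original post: instantiate the network and push a dummy batch through it to confirm that, because double_conv uses padding=1, the output keeps the input's spatial size. The 1-channel 128*128 input matches the UNet(1, 1) and RandomCrop(128, 128) settings used in train.py below; the import assumes the model code is saved so that it is importable as unet, as train.py expects (the post shows it under the name Model.py).

import torch
from unet import UNet

net = UNet(1, 1)                 # 1 input channel, 1 output channel, as in train.py
x = torch.randn(4, 1, 128, 128)  # dummy batch: N=4, C=1, H=W=128
with torch.no_grad():
    y = net(x)
print(y.shape)                   # torch.Size([4, 1, 128, 128])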

datasets.py

import torch
import os
import numpy as np
import transforms as Transforms
from torch.utils.data import Dataset


class UNetDataset(Dataset):
    def __init__(self, dir_train, dir_mask, transform=None):
        self.dirTrain = dir_train
        self.dirMask = dir_mask
        self.transform = transform
        self.dataTrain = [os.path.join(self.dirTrain, filename)
                          for filename in os.listdir(self.dirTrain)]
                          # if filename.endswith('.jpg') or filename.endswith('.png')]
        self.dataMask = [os.path.join(self.dirMask, filename)
                         for filename in os.listdir(self.dirMask)]
                         # if filename.endswith('.jpg') or filename.endswith('.png')]
        self.trainDataSize = len(self.dataTrain)
        self.maskDataSize = len(self.dataMask)

    def __getitem__(self, index):
        assert self.trainDataSize == self.maskDataSize
        # each sample is a raw, headerless int16 file holding a single 512x512 image
        image = np.fromfile(self.dataTrain[index], dtype='int16')
        image = np.reshape(image, (512, 512))
        label = np.fromfile(self.dataMask[index], dtype='int16')
        label = np.reshape(label, (512, 512))
        # train on the residual: the label becomes target minus input
        label = label - image
        # image = cv2.imread(self.dataTrain[index])
        # label = cv2.imread(self.dataMask[index])

        if self.transform:
            for method in self.transform:
                image, label = method(image, label)

        # add a channel axis so each sample has shape (1, H, W)
        return image[np.newaxis], label[np.newaxis]

    def __len__(self):
        assert self.trainDataSize == self.maskDataSize
        return self.trainDataSize
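
The dataset therefore expects each input/target pair to be stored as raw, headerless int16 files containing a single 512*512 image, and it returns target minus input as the label. A minimal sketch (the tmp/ paths and file names are made up for illustration) that writes one synthetic pair and checks what __getitem__ returns:

import os
import numpy as np
from datasets import UNetDataset

# write one synthetic raw int16 pair (hypothetical paths, for illustration only)
os.makedirs('tmp/input', exist_ok=True)
os.makedirs('tmp/target', exist_ok=True)
np.random.randint(0, 1000, (512, 512), dtype='int16').tofile('tmp/input/0.raw')
np.random.randint(0, 1000, (512, 512), dtype='int16').tofile('tmp/target/0.raw')

dataset = UNetDataset('tmp/input', 'tmp/target', transform=None)
image, label = dataset[0]
print(image.shape, label.shape, image.dtype)  # (1, 512, 512) (1, 512, 512) int16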

train.py

The loss function is L1 loss (mean absolute error).
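
For reference (separate from train.py below), nn.L1Loss(reduction='mean') is simply the mean absolute error between the prediction and the target:

import torch
import torch.nn as nn

pred = torch.tensor([1.0, 2.0, 4.0])
target = torch.tensor([1.0, 3.0, 2.0])
print(nn.L1Loss(reduction='mean')(pred, target))  # tensor(1.) = (|0| + |1| + |2|) / 3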

import torch
import torch.nn as nn
from torch import optim
import os
from unet import UNet
from datasets import UNetDataset
import transforms as Transforms
from torch.utils.data import DataLoader

if not os.path.exists('./weight'):
    os.mkdir('./weight')
LR = 1e-3
EPOCH = 250
BATCH_SIZE = 4
weight = './weight/weight.pth'
weight_with_optimizer = './weight/weight_with_optimizer.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def train():

    # dataset
    transforms = [
        # Transforms.ToGray(),
        # Transforms.RondomFlip(),
        # Transforms.RandomRotate(15),
        Transforms.RandomCrop(128,128),
        # Transforms.Log(0.5),
        # Transforms.EqualizeHist(0.5),
        # Transforms.Blur(0.2),
        # Transforms.ToTensor()
    ]
    # note: the transforms list defined above is not passed to the dataset here (transform=None)
    dataset = UNetDataset(r'D:\DataSet\artifact\artifact_part\input', r'D:\DataSet\artifact\artifact_part\target', transform=None)
    dataLoader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    # init model
    net = UNet(1, 1).to(device)
    optimizer = optim.Adam(net.parameters(), lr=LR)
    # loss_func = nn.CrossEntropyLoss().to(device)
    loss_func = nn.L1Loss(reduction='mean')  # L1 loss (mean absolute error)
    # load weight
    if os.path.exists(weight_with_optimizer):
        checkpoint = torch.load(weight_with_optimizer)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('load weight')

    # train
    for epoch in range(EPOCH):
        total_loss = 0
        nstep = len(dataLoader)  # number of batches per epoch
        for step, (batch_x, batch_y) in enumerate(dataLoader):
            # import cv2
            # import numpy as np
            # display = np.concatenate(
            #     (batch_x[0][0].numpy(), batch_y[0][0].numpy().astype(np.float32)),
            #     axis=1
            # )
            # cv2.imshow('display', display)
            # cv2.waitKey()
            batch_x = batch_x.to(device).float()
            batch_y = batch_y.to(device).float()
            output = net(batch_x)   # torch.float32
            loss = loss_func(output, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()  # accumulate a Python float rather than a CUDA tensor
            if step % 50 == 0:
                print("epoch: [%3d/%d] Batch:[%5d/%5d] | loss: %.4f"
                      % (epoch, EPOCH, step, nstep, loss.item()))

        mean_loss = total_loss / nstep

        print('epoch: %d | loss: %.4f' % (epoch, mean_loss))

        # save weight
        if (epoch + 1) % 1 == 0:  # save every epoch; increase the modulus to save less often
            torch.save({
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }, weight_with_optimizer)
            torch.save({
                'net': net.state_dict()
            }, weight)
            print('saved')


if __name__ == '__main__':
    train()
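
After training has produced ./weight/weight.pth, inference could look like the sketch below. The test file path is hypothetical; the checkpoint format ({'net': state_dict}) and the residual convention (label = target - input) follow the code above.

import numpy as np
import torch
from unet import UNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = UNet(1, 1).to(device)
checkpoint = torch.load('./weight/weight.pth', map_location=device)
net.load_state_dict(checkpoint['net'])
net.eval()

# load one raw int16 512*512 image in the same format as the training data (path is hypothetical)
image = np.fromfile(r'D:\DataSet\artifact\artifact_part\input\test.raw', dtype='int16')
image = image.reshape(512, 512).astype(np.float32)
x = torch.from_numpy(image)[None, None].to(device)  # shape (1, 1, 512, 512)

with torch.no_grad():
    residual = net(x)                                # the network predicts target - input
corrected = image + residual[0, 0].cpu().numpy()     # add the predicted residual back to the input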