Implementing U-Net in PyTorch

This post walks through a U-Net implementation in PyTorch in detail, covering the convolution, pooling, and upsampling blocks, and shows how to load the dataset, set up the loss function, and run the training loop. Complete code is provided for an image-denoising task.

http://t.zoukankan.com/wanghui-garcia-p-10719121.html

https://github.com/1024210879/unet-denoising-dirty-documents/blob/master/datasets.py

(Figure: the U-Net architecture diagram; the 572*572 sizes mentioned in the comments below refer to it)

unet.py

# sub-parts of the U-Net model
import torch
import torch.nn as nn
import torch.nn.functional as F

# The double 3x3 convolution block used at every level (left and right side of the U)
class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            # Taking the first block as an example:
            # in_ch input channels, out_ch output channels, 3x3 kernel, stride=1, dilation=1.
            # Output size = ((H + 2*padding - dilation*(kernel_size-1) - 1) / stride) + 1.
            # In the original paper there is no padding, so 572*572 shrinks to 570*570;
            # here padding=1 is used, so H and W are preserved.
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),  # batch normalization; keeps running estimates of mean and variance during training
            nn.ReLU(inplace=True),  # activation
            nn.Conv2d(out_ch, out_ch, 3, padding=1),  # second convolution (570*570 -> 568*568 in the unpadded paper version)
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        x = self.conv(x)
        return x

# The first convolution block at the top of the contracting (left) path
class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)  # e.g. in_ch = 3 input channels, out_ch = 64 feature channels
    def forward(self, x):
        x = self.conv(x)
        return x

# One downsampling step on the left side: 2x2 max-pooling followed by a double convolution
class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )
    def forward(self, x):
        x = self.mpconv(x)
        return x

# One upsampling step on the right side, followed by the double convolution of that level
class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        #  it would be a nice idea if the upsampling could be learned too,
        #  but my machine does not have enough memory to handle all those weights
        if bilinear:  # bilinear interpolation (the default); output size is floor(H * scale_factor), so e.g. 28*28 becomes 56*56
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:  # otherwise upsample with a transposed convolution; output size = (H - 1)*stride - 2*padding + kernel_size + output_padding
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)
    def forward(self, x1, x2):  # x2 is the skip-connection feature map from the contracting (left) path
        # e.g. the first upsampling step turns 28*28 into 56*56
        x1 = self.up(x1)

        # tensors are NCHW: dim 0 is the batch size, dim 1 is the channel count
        # (slightly modified here from the reference code, which pads x1 instead of cropping x2)
        diffY = x1.size()[2] - x2.size()[2]  # height difference between x1 and x2, e.g. 56 - 64 = -8
        diffX = x1.size()[3] - x2.size()[3]  # width difference between x1 and x2, e.g. 56 - 64 = -8

        # If the upsampled x1 and the skip feature x2 differ in size, pad x2 to match x1
        # (negative padding crops: e.g. (-4, -4, -4, -4) trims 4 pixels on each side, turning 64*64 into 56*56)
        x2 = F.pad(x2, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        # for padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        # concatenate the skip feature x2 with the upsampled x1 along the channel dimension (dim=1), e.g. 512 + 512 = 1024 channels
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

# The final 1x1 convolution at the output of the network (top right of the U)
class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)
    def forward(self, x):
        x = self.conv(x)
        return x
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):  # number of image channels: 1 for grayscale, 3 for RGB
        super(UNet, self).__init__()
        self.inc = inconv(in_channels, 64)  # e.g. 3 input channels mapped to 64 feature channels
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)  # kept at 512 channels so that concatenation with x4 in up1 gives 1024
        self.up1 = up(1024, 256)     # in_ch of each up block is the channel count after concatenation
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, out_channels)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x
        # return torch.sigmoid(x)  # for binary segmentation (F.sigmoid is deprecated)
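
A quick sanity check of the model above (a minimal sketch, assuming the listing is saved as unet.py, matching the import in train.py): feed a dummy single-channel batch through the network and confirm that, thanks to padding=1 and the cropping in up, the output keeps the input's spatial size as long as H and W are divisible by 16.

import torch
from unet import UNet  # the model file above

net = UNet(1, 1)                 # 1 input channel, 1 output channel, as used in train.py
net.eval()
x = torch.randn(1, 1, 128, 128)  # dummy batch: N=1, C=1, H=W=128 (the RandomCrop size used in training)
with torch.no_grad():
    y = net(x)
print(y.shape)                   # torch.Size([1, 1, 128, 128])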

datasets.py

import torch
import os
import numpy as np
import transforms as Transforms
from torch.utils.data import Dataset


class UNetDataset(Dataset):
    def __init__(self, dir_train, dir_mask, transform=None):
        self.dirTrain = dir_train
        self.dirMask = dir_mask
        self.transform = transform
        self.dataTrain = [os.path.join(self.dirTrain, filename)
                          for filename in os.listdir(self.dirTrain)]
                          # if filename.endswith('.jpg') or filename.endswith('.png')]
        self.dataMask = [os.path.join(self.dirMask, filename)
                         for filename in os.listdir(self.dirMask)]
                         # if filename.endswith('.jpg') or filename.endswith('.png')]
        self.trainDataSize = len(self.dataTrain)
        self.maskDataSize = len(self.dataMask)

    def __getitem__(self, index):
        assert self.trainDataSize == self.maskDataSize
        # the files are raw, headerless int16 buffers, each holding a single 512x512 image
        image = np.fromfile(self.dataTrain[index], dtype='int16')
        image = np.reshape(image, (512, 512))
        label = np.fromfile(self.dataMask[index], dtype='int16')
        label = np.reshape(label, (512, 512))
        label = label - image  # the training target is the residual (target - input)
        # image = cv2.imread(self.dataTrain[index])
        # label = cv2.imread(self.dataMask[index])

        if self.transform:
            for method in self.transform:
                image, label = method(image, label)

        return image[np.newaxis], label[np.newaxis]  # add a channel dimension: (1, 512, 512)

    def __len__(self):
        assert self.trainDataSize == self.maskDataSize
        return self.trainDataSize
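
Because __getitem__ expects raw, headerless int16 files rather than ordinary image formats, the dataset is easy to smoke-test with synthetic data. A minimal sketch (the tmp/ directories and file names here are only illustrative, not part of the original project):

import os
import numpy as np
from datasets import UNetDataset  # the dataset class above

# write one synthetic 512x512 int16 input/target pair as headerless .raw files
os.makedirs('tmp/input', exist_ok=True)
os.makedirs('tmp/target', exist_ok=True)
img = np.random.randint(0, 1000, size=(512, 512), dtype=np.int16)
tgt = (img + np.random.randint(-50, 50, size=(512, 512))).astype(np.int16)
img.tofile('tmp/input/sample0.raw')   # np.fromfile(..., dtype='int16') reads these back
tgt.tofile('tmp/target/sample0.raw')

dataset = UNetDataset('tmp/input', 'tmp/target', transform=None)
image, label = dataset[0]
print(image.shape, label.shape, image.dtype)  # (1, 512, 512) (1, 512, 512) int16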

train.py

The loss is L1 loss (mean absolute error).
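
As a quick illustration of what nn.L1Loss(reduction='mean') computes (a standalone snippet, not part of train.py), it is simply the mean absolute difference over all elements:

import torch
import torch.nn as nn

pred   = torch.tensor([[0.5, 1.0], [2.0, 3.0]])
target = torch.tensor([[1.0, 1.0], [1.0, 1.0]])
print(nn.L1Loss(reduction='mean')(pred, target))  # tensor(0.8750)
print((pred - target).abs().mean())               # tensor(0.8750), the same value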

import torch
import torch.nn as nn
from torch import optim
import os
from unet import UNet
from datasets import UNetDataset
import transforms as Transforms
from torch.utils.data import DataLoader

if not os.path.exists('./weight'):
    os.mkdir('./weight')
LR = 1e-3
EPOCH = 250
BATCH_SIZE = 4
weight = './weight/weight.pth'
weight_with_optimizer = './weight/weight_with_optimizer.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def train():

    # dataset
    transforms = [
        # Transforms.ToGray(),
        # Transforms.RondomFlip(),
        # Transforms.RandomRotate(15),
        Transforms.RandomCrop(128,128),
        # Transforms.Log(0.5),
        # Transforms.EqualizeHist(0.5),
        # Transforms.Blur(0.2),
        # Transforms.ToTensor()
    ]
    # NOTE: the transforms list above is defined but not passed in (transform=None), so no augmentation is applied here
    dataset = UNetDataset(r'D:\DataSet\artifact\artifact_part\input', r'D:\DataSet\artifact\artifact_part\target', transform=None)
    dataLoader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    # init model
    net = UNet(1, 1).to(device)
    optimizer = optim.Adam(net.parameters(), lr=LR)
    # loss_func = nn.CrossEntropyLoss().to(device)
    loss_func = nn.L1Loss(reduction='mean')
    # L1 LOSS
    # load weight
    if os.path.exists(weight_with_optimizer):
        checkpoint = torch.load(weight_with_optimizer)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('load weight')

    # train
    for epoch in range(EPOCH):
        total_loss = 0.0
        nstep = len(dataLoader)  # number of batches per epoch
        for step, (batch_x, batch_y) in enumerate(dataLoader):
            # import cv2
            # import numpy as np
            # display = np.concatenate(
            #     (batch_x[0][0].numpy(), batch_y[0][0].numpy().astype(np.float32)),
            #     axis=1
            # )
            # cv2.imshow('display', display)
            # cv2.waitKey()
            batch_x = batch_x.to(device).float()
            batch_y = batch_y.to(device).float()
            output = net(batch_x)   # torch.float32
            loss = loss_func(output, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()  # .item() detaches the value so the graph is not kept alive across iterations
            if step % 50 == 0:
                print("epoch: [%3d/%d] Batch:[%5d/%5d] | loss: %.4f"
                      % (epoch, EPOCH, step, nstep, loss.item()))

        mean_loss = total_loss / nstep

        print('epoch: %d | loss: %.4f' % (epoch, mean_loss))

        # save a checkpoint every epoch (increase the modulus to save less often)
        if (epoch + 1) % 1 == 0:
            torch.save({
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }, weight_with_optimizer)
            torch.save({
                'net': net.state_dict()
            }, weight)
            print('saved')


if __name__ == '__main__':
    train()
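
Once training has produced './weight/weight.pth' (which stores only the network weights under the key 'net'), inference is straightforward. A minimal sketch, assuming the same raw int16 512x512 input format as datasets.py; since the network is trained on the residual target - input, the restored image is the input plus the prediction (the input path below is only a placeholder):

import numpy as np
import torch
from unet import UNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = UNet(1, 1).to(device)
checkpoint = torch.load('./weight/weight.pth', map_location=device)
net.load_state_dict(checkpoint['net'])
net.eval()

# read one raw int16 512x512 image, same format as in datasets.py (placeholder path)
image = np.fromfile('sample_input.raw', dtype='int16').reshape(512, 512)
x = torch.from_numpy(image.astype(np.float32))[None, None].to(device)  # shape (1, 1, 512, 512)
with torch.no_grad():
    residual = net(x)
restored = x + residual  # the net predicts target - input, so add the input back
print(restored.shape)    # torch.Size([1, 1, 512, 512])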