Unet基础代码(修补版)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

1.概述

UNet 是医学图像分割领域的经典论文,其网络结构形似字母 U,因此得名。本文代码在其他博主代码的基础上做了一些细节上的修改和补充,并增加了测试代码。

以下是原博主的文章链接,其中提供了预训练模型:

UNet的Pytorch实现_Natuski_的博客-CSDN博客_pytorch unet

一、dataset.py

import os
import torchvision
from PIL import Image
from torch.utils.data import Dataset
import torch

class SEGData(Dataset):
    """Paired image/mask dataset for segmentation training.

    Images and masks are paired purely by position after sorting each
    directory listing, so both directories must contain the same number of
    files and their sorted orders must correspond (e.g. identical base names).
    """

    def __init__(self, path1, path2, size=512):
        """
        :param path1: directory containing the input images
        :param path2: directory containing the label masks
        :param size: side length of the square tensors returned by
            ``__getitem__`` (default 512; larger sizes usually train better
            but more slowly)
        :raises ValueError: if the two directories hold a different number of
            entries, which would otherwise silently mis-pair images and masks
        """
        self.img_path = path1
        self.label_path = path2
        self.images = sorted(os.listdir(self.img_path))
        self.labels = sorted(os.listdir(self.label_path))
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"found {len(self.images)} images but {len(self.labels)} labels; "
                "every image needs exactly one mask"
            )
        self.size = size
        self.totensor = torchvision.transforms.ToTensor()
        # Bilinear resize for the input image (torchvision's default).
        self.resizer = torchvision.transforms.Resize((size, size))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        """Return the i-th ``(image, label)`` pair as float tensors in [0, 1].

        Source images can have different sizes, which the default DataLoader
        collation cannot batch. Each image and its mask are therefore pasted
        into the top-left corner of a black square whose side is the image's
        longer edge (preserving aspect ratio), then resized to
        ``size x size``. Both tensors have shape ``(3, size, size)``.
        """
        img_file = os.path.join(self.img_path, self.images[i])
        label_file = os.path.join(self.label_path, self.labels[i])
        # Context managers close the underlying file handles once copied.
        with Image.open(img_file) as img, Image.open(label_file) as label:
            w, h = img.size
            side = max(w, h)
            black_img = Image.new('RGB', (side, side))
            black_label = Image.new('RGB', (side, side))
            # paste converts the source to RGB when its mode differs.
            black_img.paste(img, (0, 0, w, h))
            black_label.paste(label, (0, 0, w, h))
        img = self.resizer(black_img)
        # Nearest-neighbour keeps mask values discrete; bilinear interpolation
        # would invent fractional labels along object boundaries.
        label = black_label.resize((self.size, self.size), Image.NEAREST)
        return self.totensor(img), self.totensor(label)

二、Model.py

from __future__ import print_function, division

import torch
import torch.nn as nn

class UNet(nn.Module):
    """Four-level U-Net for image segmentation.

    The encoder halves the spatial size four times with stride-2
    convolutions and the decoder doubles it four times, so input height and
    width must be divisible by 16 for the skip connections to line up. The
    head ends in a sigmoid, producing per-pixel probabilities suited to
    ``nn.BCELoss``.

    :param in_channels: channels of the input image (default 3, RGB)
    :param out_channels: channels of the predicted map (default 3)
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()
        widths = [2 ** (i + 6) for i in range(5)]  # [64, 128, 256, 512, 1024]
        # Encoder: each stage returns (pre-downsampling features, downsampled).
        self.d1 = DownsampleLayer(in_channels, widths[0])  # in -> 64
        self.d2 = DownsampleLayer(widths[0], widths[1])  # 64 -> 128
        self.d3 = DownsampleLayer(widths[1], widths[2])  # 128 -> 256
        self.d4 = DownsampleLayer(widths[2], widths[3])  # 256 -> 512
        # Decoder: each stage convolves, upsamples x2 and concatenates the
        # skip tensor, so it emits twice its out_ch channels.
        self.u1 = UpSampleLayer(widths[3], widths[3])  # 512 -> 1024 -> 512 (+512 skip)
        self.u2 = UpSampleLayer(widths[4], widths[2])  # 1024 -> 512 -> 256 (+256 skip)
        self.u3 = UpSampleLayer(widths[3], widths[1])  # 512 -> 256 -> 128 (+128 skip)
        self.u4 = UpSampleLayer(widths[2], widths[0])  # 256 -> 128 -> 64 (+64 skip)
        # Output head: 128 -> 64 -> 64 -> out_channels, then sigmoid.
        self.o = nn.Sequential(
            nn.Conv2d(widths[1], widths[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
            nn.Conv2d(widths[0], widths[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
            nn.Conv2d(widths[0], out_channels, 3, 1, 1),
            nn.Sigmoid(),  # probabilities for nn.BCELoss
        )

    def forward(self, x):
        """Map a batch ``(N, in_channels, H, W)`` to ``(N, out_channels, H, W)``."""
        skip1, x1 = self.d1(x)
        skip2, x2 = self.d2(x1)
        skip3, x3 = self.d3(x2)
        skip4, x4 = self.d4(x3)
        y = self.u1(x4, skip4)
        y = self.u2(y, skip3)
        y = self.u3(y, skip2)
        y = self.u4(y, skip1)
        return self.o(y)

# 下采样
class DownsampleLayer(nn.Module):
    """One U-Net encoder stage.

    Two size-preserving 3x3 conv + BatchNorm + ReLU units are followed by a
    stride-2 3x3 conv + BatchNorm + ReLU that halves height and width
    (rounding up).

    :param in_ch: number of input channels
    :param out_ch: number of channels produced by every convolution here
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        def conv_bn_relu(src, dst, stride):
            # 3x3 conv with padding 1, normalised and rectified.
            return [
                nn.Conv2d(in_channels=src, out_channels=dst, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(),
            ]

        # Attribute names and layer order define the state_dict keys; they are
        # kept stable so previously saved checkpoints still load.
        self.Conv_BN_ReLU_2 = nn.Sequential(
            *conv_bn_relu(in_ch, out_ch, 1),
            *conv_bn_relu(out_ch, out_ch, 1),
        )
        self.downsample = nn.Sequential(*conv_bn_relu(out_ch, out_ch, 2))

    def forward(self, x):
        """Return ``(features, downsampled)``.

        ``features`` are the full-resolution outputs (UNet uses them as the
        skip connection); ``downsampled`` feeds the next encoder stage.
        """
        features = self.Conv_BN_ReLU_2(x)
        return features, self.downsample(features)

# 上采样
class UpSampleLayer(nn.Module):
    """One U-Net decoder stage.

    Two 3x3 conv + BatchNorm + ReLU units map ``in_ch`` to ``2 * out_ch``
    channels. A stride-2 transposed convolution (+ BatchNorm + ReLU) then
    doubles height and width while reducing to ``out_ch`` channels, and the
    result is concatenated with the encoder skip tensor along dim 1.

    Channel plans used by UNet: 512-1024-512, 1024-512-256, 512-256-128,
    256-128-64.

    :param in_ch: channels of the incoming (deeper) tensor
    :param out_ch: channels after upsampling, before concatenation
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        mid = 2 * out_ch
        # Attribute names and layer order define the state_dict keys; keep
        # them stable so previously saved checkpoints still load.
        self.Conv_BN_ReLU_2 = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=mid, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(in_channels=mid, out_channels=mid, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
        )
        # kernel 3, stride 2, padding 1 and output_padding 1 give exactly 2x
        # the input height and width.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=mid, out_channels=out_ch, kernel_size=3,
                stride=2, padding=1, output_padding=1,
            ),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x, out):
        """
        :param x: tensor from the previous (deeper) stage
        :param out: encoder skip tensor; its spatial size must equal the
            upsampled size
        :return: ``[upsampled, out]`` concatenated along the channel dim
        """
        upsampled = self.upsample(self.Conv_BN_ReLU_2(x))
        return torch.cat([upsampled, out], dim=1)

三、train.py

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
import os
from torchvision.utils import save_image
from min_unet.Model import UNet
from min_unet.dataset import SEGData
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def main(path1, path2, EPOCH, Batch):
    """Train UNet on the image/mask pairs found in ``path1`` / ``path2``.

    The run proceeds as follows:

    - Resume from ``SAVE/Unet.pt`` when that checkpoint exists; otherwise
      start from scratch.
    - Log the per-batch BCE loss to TensorBoard under ``Log`` using a global
      step that keeps increasing across epochs.
    - Save the weights after every epoch.
    - Every 10 epochs, write the prediction for one training sample to
      ``Log_imgs``.

    :param path1: directory of training images
    :param path2: directory of training masks
    :param EPOCH: number of epochs to run
    :param Batch: mini-batch size (an incomplete final batch is dropped)
    :raises ValueError: if the dataset yields no complete batch
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ckpt = r'SAVE/Unet.pt'

    net = UNet().to(device)
    optimizer = torch.optim.Adam(net.parameters())
    loss_func = nn.BCELoss()
    data = SEGData(path1, path2)
    dataloader = DataLoader(data, batch_size=Batch, shuffle=True, num_workers=0, drop_last=True)
    if len(dataloader) == 0:
        raise ValueError(
            f"{len(data)} samples cannot fill a single batch of size {Batch} "
            "(drop_last=True)"
        )

    if os.path.exists(ckpt):
        print('load net')
        # map_location lets a GPU-trained checkpoint load on a CPU-only host.
        net.load_state_dict(torch.load(ckpt, map_location=device))
        print('load success')
    else:
        print('no checkpoint at {}, training from scratch'.format(ckpt))

    os.makedirs(os.path.dirname(ckpt), exist_ok=True)
    os.makedirs(r"Log_imgs", exist_ok=True)
    # The original code previewed sample 2; fall back for tiny datasets.
    preview_idx = min(2, len(data) - 1)

    summary = SummaryWriter(r'Log')
    step = 0
    try:
        for epoch in range(EPOCH):
            print('开始第{}轮'.format(epoch))
            net.train()
            for img, label in dataloader:
                img = img.to(device)
                label = label.to(device)
                img_out = net(img)
                loss = loss_func(img_out, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # A running step (not the per-epoch batch index) stops each
                # epoch's points from overwriting the previous epoch's.
                summary.add_scalar('bceloss', loss.item(), step)
                step += 1

            torch.save(net.state_dict(), ckpt)

            if epoch % 10 == 0:
                net.eval()
                # Inference only: no autograd graph needed.
                with torch.no_grad():
                    img, _ = data[preview_idx]
                    out = net(img.unsqueeze(0).to(device))
                save_image(out, 'Log_imgs/segimg_{}——.png'.format(epoch), nrow=1, scale_each=True)
            print(f"第{epoch}轮train_loss={loss.item()}")
            print('第{}轮结束'.format(epoch))
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        summary.close()


if __name__ == '__main__':
    # Paths are relative to the working directory: training images live in
    # ../data/imgs/, the corresponding masks in ../data/masks/.
    path1, path2 = r'../data/imgs/', r'../data/masks/'
    EPOCH, Batch = 11, 2
    main(path1, path2, EPOCH, Batch)

四、test.py

import torch
import torchvision
import os
from torchvision.utils import save_image
from min_unet.Model import UNet
from PIL import Image


def test(input_path, size=512):
    """Segment every image file in ``input_path`` and save the predictions.

    Each image goes through the same preprocessing as training:

    - It is pasted into the top-left corner of a black square whose side is
      its longer edge.
    - The square is resized to ``size x size`` (bilinear).

    The result is fed to the network. Resizing also guarantees the spatial
    size is a multiple of 16, which UNet's four stride-2 stages require.
    Predictions are written to ``Test_imgs/segimg_<name>.png``.

    :param input_path: directory containing the images to segment
    :param size: side the padded square is resized to (default 512, matching
        training); pass ``None`` to keep the padded size, which then must be
        divisible by 16
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet().to(device)
    weight = r'SAVE/Unet.pt'
    if os.path.exists(weight):
        # map_location lets a GPU-trained checkpoint load on a CPU-only host.
        net.load_state_dict(torch.load(weight, map_location=device))
        print("successful")
    else:
        # Deliberately continues: predictions then come from untrained weights.
        print("no")

    os.makedirs(r"Test_imgs", exist_ok=True)

    to_tensor = torchvision.transforms.ToTensor()
    resizer = torchvision.transforms.Resize((size, size)) if size is not None else None
    net.eval()
    with torch.no_grad():
        for file in sorted(os.listdir(input_path)):
            path = os.path.join(input_path, file)
            if not os.path.isfile(path):
                continue  # skip sub-directories
            # splitext keeps dots inside the base name ("a.b.png" -> "a.b").
            name = os.path.splitext(file)[0]
            with Image.open(path) as img:
                w, h = img.size
                side = max(w, h)
                square = Image.new('RGB', (side, side))
                # paste converts the source to RGB when its mode differs.
                square.paste(img, (0, 0, w, h))
            if resizer is not None:
                square = resizer(square)
            batch = to_tensor(square).unsqueeze(0).to(device)
            out = net(batch)
            save_image(out, f'Test_imgs/segimg_{name}.png', nrow=1, scale_each=True)

if __name__ == "__main__":
    # Directory of images to segment, relative to the working directory.
    input_path = r"../data/test_imgs"
    test(input_path)

  • 2
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值