UNet语义分割模型的使用-Pytorch

1.概述

最近有时间,跑了一下UNet模型,因为自己的深度学习基础不扎实,导致用了一些时间。目前只停留在使用和理解别人模型的基础上,对于优化模型的相关方法还有待学习。
众所周知,UNent是进行语义分割的知名模型,它的U形结构很多人也都见过,但是如果自己没有亲自试过的话,也就只知道它的U形结构,其实里面还是有很多学问的,下面就把自己学习时候的一些理解写一下。
最后会拿个完整代码作为例子(实际上自己练习了两个比较成功的例子)

2.UNet模型理解

先放UNet模型的图,然后介绍再Pytorch相关实现的函数。
UNet模型
一般看到这个图,都会看到它从左边逐渐编码,到最底端,之后从底端不断解码,恢复为一张图像。但是很多人会忽略中间的从左往右的三条灰色的直线,这是它把一个图像处理为目标图像的一个关键。
从图中理解,Unet模型可以分为几个关键的部分
①ConvBlock():也就是U形结构左边在这里插入图片描述的这一部分,当然也包括下面和它很相似的三个,他们都是(卷积-激活-卷积-激活),这也就是图像从572×572变成568×568的原因,实际上就是卷积过程中边界设定的问题。
②MaxPool2d():然后紧接者就是红色箭头了在这里插入图片描述,它就是池化,把卷积后的特征筛选出来,这是图像尺寸急剧下降的原因。
接着①和②步骤重复了4次。得到了一个512维,32×32的图片。当然有的模型为了简化,这个步骤只进行了3次,实际上也能取得不错的效果。
最后它又进行了一次①,成功把图像变成1024维度,30×30,图像特征被压缩在了这样的一个信息里面了。
下面就是U形结构的右侧上升部分了,总之就是把刚才被压缩的信息展开,结合一步步压缩过程的中间图像(灰色箭头),把图像还原成想要的样子。
这里主要用到了下面的函数
③ConvBlock():和①一模一样,只不过,①是把维度逐渐变大也就是从3变成1024,而在经过4次③之后,1024变成了1
在这里插入图片描述ConvTranspose2d():与池化对应,这个是上采样(不确定叫法是否得当)过程,它拓展了数据的尺寸,并减少了维度。
⑤copyAndCrop和cat,这个就是灰色箭头的实现过程,首先把两个输入数据(也就是原始数据及编码过程数据上采样结果裁剪为一致的尺寸,之后进行连接)
在这里插入图片描述在最后一层,输出的维度为一,也就是灰度图像,不过也可以定义为其他维度,例如输出彩色,这跟自己实际的需求有关。

3.数据集加载

为了方便下面展示代码,先导入必要的模块

import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.utils as vutils
from torchsummary import summary
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, StepLR, MultiStepLR, CyclicLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T, datasets as dset
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from zipfile import ZipFile
from tqdm import tqdm
from glob import glob
from PIL import Image
import cv2
from torch.utils.tensorboard import SummaryWriter

数据加载过程跟普通的卷积神经网络没什么区别,无非就是构建Dataset从文件夹读取数据,之后构建dataset_loader,用来为训练数据做准备。下面直接放一段代码。
不过还是要介绍一下原数据的目录结构,下载数据请点这里Carvana Image Masking Challenge.,下载下面这两个压缩包解压就可以
在这里插入图片描述
解压后这样就行:
在这里插入图片描述
在下面的代码中,数据被分为了两组,0.7的数据被作为训练组,0.3的数据被用来验证。
多说两句,这个数据集的构建继承了Dataset类,实现了__getitem__(self, index: int)和__len__(self)两个函数。

class MyDataset(Dataset):
    def __init__(self, root_dir: str, train=True, transforms=None):
        super(MyDataset, self).__init__()
        self.train = train
        self.transforms = transforms

        file_path = root_dir + 'imgs/*.jpg'
        file_mask_path = root_dir + 'masks/*.gif'

        self.images = sorted(glob(file_path))
        self.image_mask = sorted(glob(file_mask_path))

        # manually split the train/valid data
        split_ratio = int(len(self.images) * 0.7)
        if train:
            self.images = self.images[:split_ratio]
            self.image_mask = self.image_mask[:split_ratio]
        else:
            self.images = self.images[split_ratio:]
            self.image_mask = self.image_mask[split_ratio:]

    def __getitem__(self, index: int):
        image = Image.open(self.images[index]).convert('RGB')
        image_mask = Image.open(self.image_mask[index]).convert('L')

        if self.transforms:
            image = self.transforms(image)
            image_mask = self.transforms(image_mask)

        return {'img': image, 'mask': image_mask}

    def __len__(self):
        return len(self.images)

4.构建模型

下面模型的代码跟**2.**的介绍是对应的,建议对照看一下,就会有所理解

class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super(ConvBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x: torch.Tensor):
        return self.block(x)


class CopyAndCrop(nn.Module):
    def forward(self, x: torch.Tensor, encoded: torch.Tensor):
        _, _, h, w = encoded.shape
        crop = T.CenterCrop((h, w))(x)
        output = torch.cat((x, crop), 1)

        return output


class UNet(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super(UNet, self).__init__()

        self.encoders = nn.ModuleList([
            ConvBlock(in_channels, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
            ConvBlock(256, 512),
        ])
        self.down_sample = nn.MaxPool2d(2)
        self.copyAndCrop = CopyAndCrop()
        self.decoders = nn.ModuleList([
            ConvBlock(1024, 512),
            ConvBlock(512, 256),
            ConvBlock(256, 128),
            ConvBlock(128, 64),
        ])

        # PixelShuffle, UpSample will modify the output channel (you can add extra operation to update the channel, e.g.conv2d)
        # preffer use convTranspose2d, it won't modify the output channel
        self.up_samples = nn.ModuleList([
            nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        ])

        self.bottleneck = ConvBlock(512, 1024)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1, stride=1)

    def forward(self, x: torch.Tensor):
        # encod
        encoded_features = []
        for enc in self.encoders:
            x = enc(x)
            encoded_features.append(x)
            x = self.down_sample(x)

        x = self.bottleneck(x)

        # decode
        for idx, denc in enumerate(self.decoders):
            x = self.up_samples[idx](x)
            encoded = encoded_features.pop()
            x = self.copyAndCrop(x, encoded)
            x = denc(x)

        output = self.final_conv(x)
        return output

5.模型训练

模型训练大约包含下面几个步骤,首先定义了几个必要的参数,例如图像大小,batch_size,device 等等。流程如下。没有介绍优化器和损失函数之类的,因为笔者自己理解还不够,但是代码里面是有的。代码里面有些绘图的内容,方便了可视化,感觉麻烦可以删掉。

Created with Raphaël 2.3.0 定义参数 加载数据(MyDataset) 创建dataset_loader 开始训练 训练集训练 验证 训练集训练 保存模型 达到epochs数量? 结束 yes no
batch_size = 4
n_iters = 10000
epochs = 10
learning_rate = 0.0002
n_workers = 2
width = 256
height = 256
channels = 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = 44
random.seed(seed)
torch.manual_seed(seed)

if __name__ == '__main__':
    transforms = T.Compose([
        T.Resize((width, height)),
        T.ToTensor(),
        #     T.Normalize(mean=[0.485, 0.456, 0.406],
        #                 std=[0.229, 0.224, 0.225]),
        #     T.RandomHorizontalFlip()
    ])
    train_dataset = MyDataset(root_dir='./data/',
                              train=True,
                              transforms=transforms)
    val_dataset = MyDataset(root_dir='./data/',
                            train=False,
                            transforms=transforms)

    train_dataset_loader = DataLoader(dataset=train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=n_workers)
    val_dataset_loader = DataLoader(dataset=val_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=n_workers)
    samples = next(iter(train_dataset_loader))

    fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 4))
    fig.tight_layout()


    ax1.axis('off')
    ax1.set_title('input image')
    ax1.imshow(np.transpose(vutils.make_grid(samples['img'], padding=2).numpy(),
                           (1, 2, 0)))

    ax2.axis('off')
    ax2.set_title('input mask')
    ax2.imshow(np.transpose(vutils.make_grid(samples['mask'], padding=2).numpy(),
                           (1, 2, 0)), cmap='gray')

    plt.show()

    def dice_score(pred: torch.Tensor, mask: torch.Tensor):
        dice = (2 * (pred * mask).sum()) / (pred + mask).sum()
        return np.mean(dice.cpu().numpy())

    def iou_score(pred: torch.Tensor, mask: torch.Tensor):
        pass


    def plot_pred_img(samples, pred):
        fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, figsize=(12, 6))
        fig.tight_layout()

        ax1.axis('off')
        ax1.set_title('input image')
        ax1.imshow(np.transpose(vutils.make_grid(samples['img'], padding=2).numpy(),
                                (1, 2, 0)))

        ax2.axis('off')
        ax2.set_title('input mask')
        ax2.imshow(np.transpose(vutils.make_grid(samples['mask'], padding=2).numpy(),
                                (1, 2, 0)), cmap='gray')

        ax3.axis('off')
        ax3.set_title('predicted mask')
        ax3.imshow(np.transpose(vutils.make_grid(pred, padding=2).cpu().numpy(),
                                (1, 2, 0)), cmap='gray')

        plt.show()


    def plot_train_progress(model):
        #     model.eval()

        #     with torch.no_grad():
        samples = next(iter(val_dataset_loader))
        val_img = samples['img'].to(device)
        val_mask = samples['mask'].to(device)

        pred = model(val_img)

        plot_pred_img(samples, pred.detach())


    def train(model, optimizer, criteration, scheduler=None):
        train_losses = []
        val_lossess = []
        lr_rates = []

        # calculate train epochs
        epochs = int(n_iters / (len(train_dataset) / batch_size))

        for epoch in range(epochs):
            model.train()
            train_total_loss = 0
            train_iterations = 0

            for idx, data in enumerate(tqdm(train_dataset_loader)):
                train_iterations += 1
                train_img = data['img'].to(device)
                train_mask = data['mask'].to(device)

                optimizer.zero_grad()
                # speed up the training
                with torch.set_grad_enabled(True):
                    train_output_mask = model(train_img)
                    train_loss = criterion(train_output_mask, train_mask)
                    train_total_loss += train_loss.item()

                train_loss.backward()
                optimizer.step()

            train_epoch_loss = train_total_loss / train_iterations
            train_losses.append(train_epoch_loss)

            # evaluate mode
            model.eval()
            with torch.no_grad():
                val_total_loss = 0
                val_iterations = 0
                scores = 0

                for vidx, val_data in enumerate(tqdm(val_dataset_loader)):
                    val_iterations += 1
                    val_img = val_data['img'].to(device)
                    val_mask = val_data['mask'].to(device)

                    with torch.set_grad_enabled(False):
                        pred = model(val_img)

                        val_loss = criterion(pred, val_mask)
                        val_total_loss += val_loss.item()
                        scores += dice_score(pred, val_mask)

                val_epoch_loss = val_total_loss / val_iterations
                dice_coef_scroe = scores / val_iterations

                val_lossess.append(val_epoch_loss)

                plot_train_progress(model)
                print('epochs - {}/{} [{}/{}], dice score: {}, train loss: {}, val loss: {}'.format(
                    epoch + 1, epochs,
                    idx + 1, len(train_dataset_loader),
                    dice_coef_scroe, train_epoch_loss, val_epoch_loss
                ))
                torch.save(model, 'modelCar' + str(epoch) + '.pkl')
            lr_rates.append(optimizer.param_groups[0]['lr'])
            if scheduler:
                scheduler.step()  # decay learning rate
                print('LR rate:', scheduler.get_last_lr())


        return {
            'lr': lr_rates,
            'train_loss': train_losses,
            'valid_loss': val_lossess
        }
    model = UNet(in_channels=3, out_channels=1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    # criterion = smp.losses.DiceLoss(mode='binary')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    history = train(model, optimizer, criterion)

6.训练结果

经过一段时间的训练,在这个数据集中的效果还不错,下面放几张图来看一下,展示3个epochs的吧。(dic score的计算可能并不准确)
100%|██████████| 891/891 [07:31<00:00, 1.97it/s]
100%|██████████| 382/382 [01:20<00:00, 4.76it/s]
Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).
epochs - 1/11 [891/891], dice score: -1.732437749183615, train loss: 0.10204851534531173, val loss: 0.037313264545969935
在这里插入图片描述
100%|██████████| 891/891 [07:29<00:00, 1.98it/s]
100%|██████████| 382/382 [01:15<00:00, 5.07it/s]
Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).
epochs - 2/11 [891/891], dice score: -1.2491931439382244, train loss: 0.027105745536460217, val loss: 0.021171443983522387
在这里插入图片描述
100%|██████████| 891/891 [07:30<00:00, 1.98it/s]
100%|██████████| 382/382 [01:15<00:00, 5.08it/s]
Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).
epochs - 3/11 [891/891], dice score: -1.2030698774060653, train loss: 0.01996645850143014, val loss: 0.02303329437815082
在这里插入图片描述

  • 5
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Matrix_CS

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值