U-Net代码实现与说明

概要

系列文章为《Pytorch神经网络实战》学习笔记

Pytorch 2.2.2 (CUDA11.2) Python 3.9.7|远程连接 Anaconda 环境


一、概述

代码仓库:

U-Net|Basic Module of Semantic Segmentation with PyTorch

代码文件使用说明

二、架构

各文件的作用:

datasets.py - 数据集相关的代码;

loss_utils.py - 自定义的损失函数及其相关的辅助函数;

human_dataset - 存储数据集的文件夹,用于训练深度学习模型;

augument.py - 用于数据增强;

utils.py - 辅助函数和工具函数;

UNet.py - 实现 U-Net 模型;

old_unet.py - 旧版本的 U-Net 模型的代码;

train.py - 模型训练;

infer.py - 在训练完成后使用模型进行预测或推断;

三、使用说明

augument.py

代码剖析:

import torchvision
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
  • torchvision是PyTorch的一个视觉库,提供了一些常用的图像数据集、模型和图像变换方法。
  • torchvision.transforms.functional模块被导入并重命名为TF,用于进行图像变换。
  • numpy库被导入并重命名为np,用于处理数组和矩阵运算。
  • PIL库中导入Image模块,用于图像处理。
def horizontal_flip(img , ref):
    return TF.hflip(img), TF.hflip(ref)
  • 定义了一个名为horizontal_flip的函数,接受两个参数:imgref,分别表示原始图像和参考图像。
  • 函数内部调用了TF.hflip函数,将原始图像和参考图像水平翻转,并将结果作为元组返回。
def vertical_flip(img, ref):
    return TF.vflip(img), TF.vflip(ref)
  • 定义了一个名为vertical_flip的函数,接受两个参数:imgref,分别表示原始图像和参考图像。
  • 函数内部调用了TF.vflip函数,将原始图像和参考图像垂直翻转,并将结果作为元组返回。
def random_crop(img, ref, crop_size):
    assert crop_size <= img.width and crop_size <= img.height

    max_left = img.height - crop_size
    max_top = img.width - crop_size

    left = 0
    top = 0
    if max_left > 0:
        left = np.random.randint(low=0, high=max_left, size=1)[0]
    if max_top > 0:
        top = np.random.randint(low=0, high=max_top, size=1)[0]
    img_crop = TF.crop(img, left, top, crop_size, crop_size)
    ref_crop = TF.crop(ref, left, top, crop_size, crop_size)
    return img_crop, ref_crop
  • 定义了一个名为random_crop的函数,接受三个参数:imgrefcrop_size,分别表示原始图像、参考图像和裁剪大小。
  • 函数内部首先确保裁剪大小不超过原始图像的宽度和高度。
  • 然后计算出裁剪的最大左上角坐标。
  • 根据最大坐标随机生成裁剪的左上角坐标。
  • 使用TF.crop函数对原始图像和参考图像进行裁剪,并将裁剪结果作为元组返回。
if __name__ == "__main__":
    img = Image.open("OHAZE/train/hazy/03_outdoor_hazy.jpg")
    ref = Image.open("OHAZE/train/GT/03_outdoor_GT.jpg")
    random_crop(img, ref, img.height - 1)

    print(help(TF.crop))
  • 如果当前脚本作为主程序运行,则执行以下操作。
  • 使用Image.open函数分别打开了两张图像,分别作为原始图像img和参考图像ref
  • 调用了random_crop函数,但没有将其返回的裁剪结果保存到变量中。
  • 打印了TF.crop函数的帮助信息。

datasets.py

代码剖析:

from torch.utils.data import Dataset
from PIL import Image
import os
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from augument import horizontal_flip, vertical_flip
import numpy as np
  • 这些是导入所需的库和模块。torch.utils.data.Dataset 用于创建自定义数据集类,PIL.Image 用于处理图像,os 用于文件路径操作,torchvision.transforms.functional 中的 TF 包含了图像变换功能,matplotlib.pyplot 用于图像展示,augument 中应该包含了水平翻转和垂直翻转的函数,numpy 用于数值计算。
class SegDatasets(Dataset):
    def __init__(self, image_size, data_root, input_dir_name,
                 label_dir_name, h_flip, v_flip, train=True):
        # 定义属性
        self.image_size = image_size
        self.data_dir = data_root
        self.input_dir_name = input_dir_name
        self.label_dir_name = label_dir_name
        self.h_flip = h_flip
        self.v_flip = v_flip
        self.train = train
        if self.train:
            self.prefix = "train_"
        else:
            self.prefix = "val_"

        # 检查目录是否存在
        if not os.path.exists(self.data_dir):
            raise Exception(r"[!] data set does not exist!")

        # 获取所有训练数据的名称,存储到self.files列表中
        self.files = sorted(os.listdir(os.path.join(self.data_dir,
                                                    self.prefix + self.input_dir_name)))
  • 定义 SegDatasets 的类,它继承自 torch.utils.data.Dataset。在 __init__ 方法中,初始化了数据集的属性,包括图像大小、数据根目录、输入图像目录名称、标签目录名称、是否水平翻转、是否垂直翻转以及是否为训练数据集。然后根据训练标志位决定文件名前缀是 "train_" 还是 "val_",并检查数据目录是否存在。最后,获取所有训练数据的文件名,并存储在 self.files 列表中。
def __getitem__(self, item):
    file_name = self.files[item]
    # 打开img和对应的mask,file_name[:-4] + "_matte.png"代表原始数据集中mask的命名方式
    img = Image.open(os.path.join(self.data_dir,
                                  self.prefix + self.input_dir_name,
                                  file_name)).convert('RGB')
    mask = Image.open(os.path.join(self.data_dir,
                                   self.prefix + self.label_dir_name,
                                   file_name[:-4] + "_matte.png")).convert('L')

    # 将img和mask进行resize,统一为相同尺寸
    img = TF.resize(img, (self.image_size, self.image_size))
    mask = TF.resize(mask, (self.image_size, self.image_size))

    if self.train:
        # 以0.5的概率进行数据增强,增强方式必须保证img和mask的变换是完全对应的
        if self.h_flip and np.random.random() > 0.5:
           img, mask = horizontal_flip(img, mask)

        if self.v_flip and np.random.random() > 0.5:
            img, mask = vertical_flip(img, mask)

    # 将图像转为tensor类型
    img = TF.to_tensor(img)
    mask = TF.to_tensor(mask)

    # 以字典形式返回img、mask和img的名字
    out = {'human': img, 'mask': mask, "img_name": file_name}

    return out
  • __getitem__ 方法用于获取数据集中的一个样本。首先根据索引获取对应的文件名,然后打开对应的图像和标签文件,并进行大小调整。如果是训练数据集,就以一定的概率进行数据增强,包括水平翻转和垂直翻转。最后将图像和标签转换为 Tensor 类型,并以字典形式返回,包括图像、掩码和图像名称。
def __len__(self):
    return len(self.files)
  • __len__ 方法返回数据集的长度,即文件列表的长度。
if __name__ == "__main__":
    # 一些训练参数和常量
    # 构建训练Dataset
    train_set = SegDatasets(IMAGE_SIZE, DATA_ROOT, INPUT_DIR_NAME, LABEL_DIR_NAME, H_FLIP, V_FLIP, train=True)

    # 数据集中图像数量
    print("num of Train set {}".format(len(train_set)))

    # 获取数据集中的第3条数据的原始图像img、掩码mask和图像名称
    img = train_set[3]["human"]
    mask = train_set[3]["mask"]
    name = train_set[3]["img_name"]

    # 展示原始图像img和掩码mask
    plt.subplot(1, 2, 1)
    plt.imshow(img.numpy().transpose(1, 2, 0))
    plt.subplot(1, 2, 2)
    plt.imshow(mask.numpy().squeeze())
    plt.show()
  • 在主函数中设置了一些训练参数和常量,并构建了训练数据集对象 train_set。然后输出了训练数据集的样本数量,并展示了第 3 个样本的原始图像和掩码。

loss_utils.py

代码剖析:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
  • 这一部分导入了所需的Python库,包括操作系统库(os)、PyTorch库(torch)、PyTorch的神经网络模块(torch.nn)、PyTorch的函数库(torch.nn.functional)以及用于绘图的matplotlib库。
class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, pred, label, smooth=1):
        # 将预测结果pred和真实标签label进行展开操作
        # 展开后的pred和label均为一维的向量,便于计算交并比
        pred = F.sigmoid(pred)
        pred = pred.view(-1)
        label = label.view(-1)

        # 计算pred和label的交集
        intersection = (pred * label).sum()
        # 计算pred和label的并集
        union = pred.sum() + label.sum()
        # 根据前述的dice损失计算公式,加入平滑因子计算交并比
        dice = (2. * intersection + smooth) / (union + smooth)

        return 1 - dice
  • 这部分定义了DiceLoss类,用于计算Dice损失函数。在__init__方法中,调用了父类的构造函数。在forward方法中,计算了预测结果和真实标签的Dice损失。
class DiceBCELoss(nn.Module):
    def __init__(self):
        super(DiceBCELoss, self).__init__()

    def forward(self,
                inputs: torch.Tensor,
                targets: torch.Tensor,
                smooth: int = 1
                ) -> torch.Tensor:
        inputs = F.sigmoid(inputs)

        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss

        return Dice_BCE
  • 这部分定义了DiceBCELoss类,结合了Dice损失和二元交叉熵(BCE)损失。在forward方法中,首先对输入进行sigmoid激活,然后计算Dice损失和BCE损失,并将它们结合起来作为最终的损失值。
class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, label, alpha=0.8, gamma=2):
        pred = F.sigmoid(pred)
        # 将pred和label展开为一维
        pred = pred.view(-1)
        label = label.view(-1)

        # 计算BCE损失
        bce = F.binary_cross_entropy(pred, label, reduction='mean')
        bce = torch.exp(-bce)

        # Focal Loss的核心计算公式
        focal_loss = alpha * (1 - bce) ** gamma * bce
        return focal_loss
  • 这部分定义了FocalLoss类,用于计算Focal损失函数。在forward方法中,首先对预测结果进行sigmoid激活,然后计算二元交叉熵(BCE)损失,并根据Focal Loss的计算公式计算最终的损失值。
class LossWriter():
    def __init__(self, save_dir):
        self.save_dir = save_dir

    def add(self, loss_name, loss, i):
        with open(os.path.join(self.save_dir, loss_name + ".txt"), mode="a") as f:
            term = str(i) + " " + str(loss) + "\n"
            f.write(term)
            f.close()
  • 这部分定义了LossWriter类,用于将损失值写入文本文件。初始化方法__init__接收保存目录参数save_dir,add方法用于将损失值写入以损失函数命名的文本文件中。
def plot_loss(txt_name, x_label, y_label, title, legend, font_size, fig_size, save_name):
    """
    损失函数绘图代码
    """
    all_i = []
    all_val = []
    with open(txt_name, "r") as f:
        # 读取txt文件中的所有行
        all_lines = f.readlines()
        # 遍历每一行
        for line in all_lines:
            # 每行的第一个元素和第二个元素以空格分隔
            sp= line.split(" ")
            i = int(sp[0])
            val = float(sp[1])
            all_i.append(i)
            all_val.append(val)
    # 绘图以及参数指定
    plt.figure(figsize=(6, 4))
    plt.plot(all_i, all_val)
    plt.xlabel(x_label, fontsize=font_size)
    plt.ylabel(y_label, fontsize=font_size)
    if legend:
        plt.legend(legend, fontsize=font_size)
    plt.title(title, fontsize=font_size)
    plt.tick_params(labelsize=font_size)
    plt.savefig(save_name, dpi=200, bbox_inches = "tight")
    plt.show()
  • 这部分定义了plot_loss函数,用于绘制损失函数的曲线图。该函数从文本文件中读取损失值,并根据指定的参数绘制曲线图,包括x轴和y轴的标签、图表标题、图例、字体大小等。
if __name__ == "__main__":
    plot_loss(txt_name="results_unet/loss/bce_loss.txt", x_label="iteration",
              y_label="loss value", title="Loss of BCE on UNet",
              legend=None, font_size=15, fig_size=(10, 10),
              save_name="unet_BCE_loss.png")
  • 这部分是主函数,调用了plot_loss函数来绘制BCE损失函数的曲线图。你可以修改参数来绘制其他损失函数的曲线图,比如DiceLoss、DiceBCELoss和FocalLoss。

utils.py

代码剖析:

import shutil
import os
from PIL import Image
import numpy as np
  • 这部分是导入所需的库。shutil库用于高级文件操作,os库用于操作系统相关的功能,PIL库用于图像处理,numpy库用于数值计算。
def save_image(image_tensor, out_name):
    """
    save a single image
    :param image_tensor: torch tensor with size=(3, h, w)
    :param out_name: path+name+".jpg"
    :return: None
    """
  • 这是一个保存单张图片的函数。它接受一个PyTorch张量 image_tensor 和输出路径 out_name 作为参数。
    if len(image_tensor.size()) == 3:
  • 这里检查输入的张量是否是三维的,即是否为三通道图像。
        image_numpy = image_tensor.cpu().detach().numpy().squeeze(0)
  • 将PyTorch张量转换为NumPy数组,并去除第一个维度(批量维度),因为此处只处理单张图片。
        image_numpy[image_numpy > 0.5] = 255
        image_numpy[image_numpy < 0.5] = 0
  • 将像素值大于0.5的设置为255,小于0.5的设置为0,将像素值二值化。
        image_numpy = image_numpy.astype(np.uint8)
  • 将数组元素类型转换为无符号8位整数类型,以便保存为图像。
        image = Image.fromarray(image_numpy)
  • 使用PIL库从NumPy数组创建图像对象。
        image.save(out_name)
  • 保存图像到指定路径。
    else:
        raise ValueError("input tensor not with size (3, h, w)")
  • 如果输入张量不是三维的,抛出值错误异常。
return None
  • 返回空值。
def check_mk_dir(dir_name):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
  • 检查目录是否存在,如果不存在则创建目录。
def split(ori_dir, tar_dir):
    files = os.listdir(ori_dir)
    for file in files:
        if len(file) > 10:
            shutil.copy(os.path.join(ori_dir, file), os.path.join(tar_dir, file))
  • 将原始目录中文件名长度大于10的文件复制到目标目录。
def make_sub_dirs(base_dir, sub_dirs):
    check_mk_dir(base_dir)
    for sub_dir in sub_dirs:
        check_mk_dir(os.path.join(base_dir, sub_dir))
  • 在基础目录下创建多个子目录。
def make_project_dir(train_dir, val_dir):
    make_sub_dirs(train_dir, ["train_images", "pth", "loss"])
    make_sub_dirs(val_dir, ["val_images"])
  • 在训练和验证目录下创建子目录。
if __name__ == "__main__":
    pass
  • 这是主函数的开始,但是主函数部分被注释掉了,没有具体实现。

loss_utils.py

代码剖析:

  1. 导入必要的库

    import os
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import matplotlib.pyplot as plt

  2. 定义Dice损失函数

    class DiceLoss(nn.Module):
        def __init__(self):
            super(DiceLoss, self).__init__()
    
        def forward(self, pred, label, smooth=1):
            pred = F.sigmoid(pred)
            pred = pred.view(-1)
            label = label.view(-1)
    
            intersection = (pred * label).sum()
            union = pred.sum() + label.sum()
            dice = (2. * intersection + smooth) / (union + smooth)
    
            return 1 - dice

  3. 定义结合二元交叉熵(BCE)的Dice损失函数

    class DiceBCELoss(nn.Module):
        def __init__(self):
            super(DiceBCELoss, self).__init__()
    
        def forward(self, inputs, targets, smooth=1):
            inputs = F.sigmoid(inputs)
    
            inputs = inputs.view(-1)
            targets = targets.view(-1)
    
            intersection = (inputs * targets).sum()
            dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
            BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
            Dice_BCE = BCE + dice_loss
    
            return Dice_BCE

  4. 定义Focal Loss

    class FocalLoss(nn.Module):
        def __init__(self):
            super(FocalLoss, self).__init__()
    
        def forward(self, pred, label, alpha=0.8, gamma=2):
            pred = F.sigmoid(pred)
            pred = pred.view(-1)
            label = label.view(-1)
    
            bce = F.binary_cross_entropy(pred, label, reduction='mean')
            bce = torch.exp(-bce)
    
            focal_loss = alpha * (1 - bce) ** gamma * bce
            return focal_loss

  5. 定义LossWriter类,用于将损失写入文件:

    class LossWriter():
        def __init__(self, save_dir):
            self.save_dir = save_dir
    
        def add(self, loss_name, loss, i):
            with open(os.path.join(self.save_dir, loss_name + ".txt"), mode="a") as f:
                term = str(i) + " " + str(loss) + "\n"
                f.write(term)
                f.close()

  6. 定义绘制损失曲线的函数

    def plot_loss(txt_name, x_label, y_label, title, legend, font_size, fig_size, save_name):
        all_i = []
        all_val = []
        with open(txt_name, "r") as f:
            all_lines = f.readlines()
            for line in all_lines:
                sp= line.split(" ")
                i = int(sp[0])
                val = float(sp[1])
                all_i.append(i)
                all_val.append(val)
        plt.figure(figsize=(6, 4))
        plt.plot(all_i, all_val)
        plt.xlabel(x_label, fontsize=font_size)
        plt.ylabel(y_label, fontsize=font_size)
        if legend:
            plt.legend(legend, fontsize=font_size)
        plt.title(title, fontsize=font_size)
        plt.tick_params(labelsize=font_size)
        plt.savefig(save_name, dpi=200, bbox_inches = "tight")
        plt.show()

  7. 主函数,用于调用绘制损失曲线的函数:

    if __name__ == "__main__":
        plot_loss(txt_name="results_unet/loss/bce_loss.txt", x_label="iteration",
                  y_label="loss value", title="Loss of BCE on UNet",
                  legend=None, font_size=15, fig_size=(10, 10),
                  save_name="unet_BCE_loss.png")

这段代码实现了定义不同的损失函数,以及绘制指定文件中损失值随迭代次数的变化曲线。

UNet.py

代码剖析:

  1. ConvBlock 类

    class ConvBlock(nn.Module):
        def __init__(self, ch_in, ch_out):
            super(ConvBlock, self).__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
                nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True))
        def forward(self, x):
            x = self.conv(x)
            return x
    • ConvBlock 类定义了一个卷积块,包括两个连续的卷积层(每层都包含卷积、批归一化和ReLU激活函数)。
    • ch_in 是输入通道数,ch_out 是输出通道数。
    • forward 方法定义了前向传播过程,将输入 x 经过卷积块的处理后返回结果。
  2. UpConvBlock 类

    class UpConvBlock(nn.Module):
        def __init__(self, ch_in, ch_out):
            super().__init__()
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True)
            )
        def forward(self, x):
            x = self.up(x)
            return x
    • UpConvBlock 类定义了一个上采样块,包括上采样、卷积、批归一化和ReLU激活函数。
    • ch_in 是输入通道数,ch_out 是输出通道数。
    • forward 方法定义了前向传播过程,将输入 x 经过上采样块处理后返回结果。
  3. UNet 类

    class UNet(nn.Module):
        def __init__(self, ch_in=3, ch_out=1):
            super().__init__()
            feature_channels = [8, 16, 32, 64, 128]
            # 定义四个池化层
            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
            # 定义卷积层和上采样层
            self.conv1 = ConvBlock(ch_in, feature_channels[0])
            self.conv2 = ConvBlock(feature_channels[0], feature_channels[1])
            self.conv3 = ConvBlock(feature_channels[1], feature_channels[2])
            self.conv4 = ConvBlock(feature_channels[2], feature_channels[3])
            self.conv5 = ConvBlock(feature_channels[3], feature_channels[4])
            self.up5 = UpConvBlock(feature_channels[4], feature_channels[3])
            self.up_conv5 = ConvBlock(feature_channels[4], feature_channels[3])
            self.up4 = UpConvBlock(feature_channels[3], feature_channels[2])
            self.up_conv4 = ConvBlock(feature_channels[3], feature_channels[2])
            self.up3 = UpConvBlock(feature_channels[2], feature_channels[1])
            self.up_conv3 = ConvBlock(feature_channels[2], feature_channels[1])
            self.up2 = UpConvBlock(feature_channels[1], feature_channels[0])
            self.up_conv2 = ConvBlock(feature_channels[1], feature_channels[0])
            # 最后一层卷积和sigmoid激活
            self.conv_last = nn.Conv2d(feature_channels[0], ch_out, kernel_size=1, stride=1, padding=0)
            self.sigmoid = nn.Sigmoid()
        def forward(self, x):
            # 编码器过程
            f1 = self.conv1(x)
            f2 = self.pool1(f1)
            f2 = self.conv2(f2)
            f3 = self.pool2(f2)
            f3 = self.conv3(f3)
            f4 = self.pool3(f3)
            f4 = self.conv4(f4)
            f5 = self.pool4(f4)
            f5 = self.conv5(f5)
            # 解码器过程
            up_f5 = self.up5(f5)
            up_f5 = torch.cat((f4, up_f5), dim=1)
            up_f5 = self.up_conv5(up_f5)
            up_f4 = self.up4(up_f5)
            up_f4 = torch.cat((f3, up_f4), dim=1)
            up_f4 = self.up_conv4(up_f4)
            up_f3 = self.up3(up_f4)
            up_f3 = torch.cat((f2, up_f3), dim=1)
            up_f3 = self.up_conv3(up_f3)
            up_f2 = self.up2(up_f3)
            up_f2 = torch.cat((f1, up_f2), dim=1)
            up_f2 = self.up_conv2(up_f2)
            mask = self.conv_last(up_f2)
            mask = self.sigmoid(mask)
            return mask
    • UNet 类定义了整个 UNet 模型的结构,包括编码器部分(卷积层和池化层)和解码器部分(上采样层和卷积层)。
    • 编码器部分将输入图像经过一系列卷积和池化操作,得到中间特征图。
    • 解码器部分将中间特征图经过上采样和卷积操作,与编码器对应层的特征图进行特征融合,最终输出预测的分割结果。
  4. 主函数

    if __name__ == "__main__":
        unet = UNet()
        x = torch.randn(size=(1, 3, 256, 256))
        print(unet(x).size())
    • 在主函数中创建 UNet 模型的实例,并生成一个随机输入张量 x。
    • 调用 UNet 模型进行前向传播,打印输出的预测 mask 的大小

train.py

用于训练图像分割网络的Python脚本。代码剖析:

  1. 导入所需的库和模块。

  2. 设置超参数,包括学习率、批量大小、训练时是否进行水平和垂直翻转等。

  3. 创建训练和验证数据集的DataLoader。

  4. 定义交叉熵损失函数(BCE损失)。

  5. 创建UNET模型,并将模型参数绑定到Adam优化器。

  6. 定义训练函数,其中遍历数据加载器中的所有数据,在每个迭代中执行以下操作:

    1. 将输入图像传递给UNET模型以获取预测的掩码。
    2. 计算预测掩码和真实掩码之间的BCE损失。
    3. 清除优化器的梯度,执行反向传播,并更新模型参数。
    4. 记录损失值并输出到控制台。
    5. 如果满足条件,保存训练过程中的分割掩码图像。
    6. 每隔一定的epoch保存模型的权重。
    7. 每隔一定的epoch计算验证集的分割效果,并保存结果用于观察模型表现。
  7. 最后,在主程序中调用train函数开始训练。

infer.py

代码剖析:

  1. 导入所需的库:

    import torch
    from old_unet import UNet  # 从旧的 UNet 模型中导入 UNet 类
    import torchvision.transforms.functional as TF
    from PIL import Image
    import matplotlib.pyplot as plt
    import numpy as np

  2. 加载预训练模型:

    device = torch.device("cpu")  # 指定设备为 CPU
    seg_net = UNet().to(device)  # 创建 UNet 模型实例并将其移动到指定设备
    seg_net.load_state_dict(torch.load("results/pth/30.pth", map_location="cpu"))  # 加载预训练权重
    seg_net.eval()  # 设置模型为评估模式

  3. 处理输入图像:

    image_path = "human_dataset/val_human/00031.png"  # 输入图像路径
    human_image = Image.open(image_path)  # 打开图像文件
    human_image = human_image.resize((256, 256))  # 调整图像大小为模型输入大小
    human_image = TF.to_tensor(human_image).to(device).unsqueeze(0)  # 图像转换为张量,并移动到指定设备

  4. 获取分割掩码:

    with torch.no_grad():
        predict_mask = seg_net(human_image)  # 使用 UNet 模型进行前向传播得到分割掩码
        predict_mask[predict_mask > 0.5] = 1  # 对掩码进行阈值处理
        predict_mask[predict_mask <= 0.5] = 0

  5. 可视化输出结果:

    human_image = human_image * 255  # 将图像张量转换为 0-255 范围内的整数
    predict_mask = predict_mask * 255  # 将预测掩码转换为 0-255 范围内的整数
    predict_mask = torch.cat((predict_mask, predict_mask, predict_mask), dim=1)  # 将预测掩码复制为3通道
    result = torch.cat((human_image, predict_mask), dim=3)[0]  # 将人像图像和预测掩码拼接在一起
    result = result.cpu().detach().numpy().transpose(1, 2, 0).astype(np.uint8)  # 将结果转换为 numpy 数组并转置通道顺序
    plt.imshow(result)  # 显示结果图像
    plt.savefig("pictures/demo.png", dpi=500, bbox_inches='tight')  # 将结果保存为图片文件
    plt.show()  # 显示图像

这段代码通过 UNet 模型对输入的人像进行分割,并将分割结果与原始图像拼接并可视化输出。


总结

U-Net简单应用的使用说明手册,对照仓库中的代码文件使用

  • 16
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值