U-Net和U-Net++代码,Pytorch实现,数据集2018 Data Science Bowl

参考代码

Python Unet ++ :医学图像分割,医学细胞分割,Unet医学图像处理,语义分割

数据集

100+医学影像数据集集锦

改动

保证已有pytorch环境。

  1. 上述链接代码导入pycharm

  2. 安装 pip install -U albumentations 需要用命令,其他的就用pycharm自动安装即可

  3. 将数据dsb-18放到inputs文件夹下,并修改preprocess_dsb2018.py中paths的路径
    在这里插入图片描述

  4. 运行preprocess_dsb2018.py文件,得到dsb2018_96文件夹,以及下面的已经处理好的数据

  5. 将train.py文件中的cuda换成有cuda就cuda,没有就cpu的形式:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
  1. 将train.py的main方法最后一行清空显存的cuda的语句(torch.cuda.empty_cache())加上有cuda条件。

cpu运行的话,如果换成以下cpu清内存,会发生错误就是内存中的被清空,反向传播计算的梯度无法保存,效果不提升。但是不清空速度会慢,但没有办法

    # 删除所有全局变量
    for var in globals():
        del globals()[var]

    # 删除所有局部变量
    for var in locals():
        del locals()[var]

    # 触发垃圾回收器对不再使用的对象进行清理
    gc.collect()
  1. 接着将unet++模型定义中的输出语句注释

  2. 将epoch从100改为10,轮数太多运行太慢

  3. 改显示的epoch

   epoch = epoch + 1
   print('Epoch [%d/%d]' % (epoch, config['epochs']))
  1. 将train.py的main方法中的cudnn.benchmark = True这句代码加上cuda可用条件

  2. 将使用命令行参数的代码中的默认值使用ini配置文件中的值配置。这样可以使用ini配置文件改配置值,也可以使用命令行改配置。优先级是命令行大于ini。

  3. 将保存的模型,在模型名称后加上时间,防止多次运行覆盖之前运行结果

import os
import datetime

timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
folder_name = f"models/{config['name']}_{timestamp}"

os.makedirs(folder_name, exist_ok=True)
  1. 加入tensorboard:writer = SummaryWriter(f"{folder_name}/SummaryWriter"),可以可视化展示运行过程中iou、loss变化
   writer.add_scalars("loss",
                      {"train_loss": train_log['loss'],
                       "test_loss": val_log['loss']},
                      epoch)
   writer.add_scalars("iou",
                      {"train_iou": train_log['iou'],
                       "test_iou": val_log['iou']},
                      epoch)
  1. 加入时间控制:

    start_time = time.time()
    end_time = time.time()
    print("本轮截至运行时间:",end_time - start_time)
    
  2. 开始运行train.py

改动后代码及代码理解

  1. 目录结构
    在这里插入图片描述
  2. defaultValue.ini
[training]
epochs = 10
batch_size = 8
arch = NestedUNet
input_channels = 3
num_classes = 1
input_w = 96
input_h = 96
loss = BCEDiceLoss

[data]
dataset = dsb2018_96
img_ext = .png
mask_ext = .png

[optimizer]
optimizer = Adam
lr = 1e-3
momentum = 0.9
weight_decay = 1e-4
nesterov = False

[scheduler]
scheduler = CosineAnnealingLR
min_lr = 1e-5
factor = 0.1
patience = 2
milestones = 1,2
gamma = 0.666666

[early_stopping]
early_stopping = -1

[other]
num_workers = 0
  1. archs.py
import torch
from torch import nn

__all__ = ['UNet', 'NestedUNet']

"""
- 这是一个 VGG 网络中常用的基本块类 VGGBlock 的定义,它包含两个卷积层和两个批归一化层,其中 ReLU 激活函数被应用在每个卷积层之后。
- 这个类的构造函数 __init__ 接受三个参数:输入通道数 in_channels,中间通道数 middle_channels 和输出通道数 out_channels。在构造函数中,首先调用父类的构造函数 super().__init__() 进行初始化。然后,定义了卷积层 self.conv1 和 self.conv2,分别将输入通道数变换为中间通道数和输出通道数。同时,两个卷积层的卷积核大小都是 3x3,**填充大小为 1,以保持特征图的尺寸不变**。接着,定义了批归一化层 self.bn1 和 self.bn2,用于规范化卷积层的输出。最后,定义了 ReLU 激活函数 self.relu。
- forward 方法是前向传播函数,在该方法中实现了网络层的前向计算逻辑。首先通过第一个卷积层和批归一化层处理输入 x,然后应用 ReLU 激活函数。接着,将结果传递给第二个卷积层和批归一化层,再次应用 ReLU 激活函数。最后,将输出结果返回。
- 这个类定义了一个 VGG 网络中常用的基本块,可以在 UNet 网络的不同层中多次调用这个块来构建整个网络。
"""
class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out

"""
- 这段代码定义了一个名为 `UNet` 的类,继承自 `nn.Module` 类。该类是一个 U-Net 网络,用于图像分割任务。它的输入是一个大小为 `input_channels` 的图像,输出是一个大小为 `num_classes` 的分割图像。
- 在类的初始化函数中,定义了一个包含 5 个元素的列表 `nb_filter`,表示每个卷积层的输出通道数。然后创建了 `MaxPool2d` 和 `Upsample` 对象,用于网络中的最大池化和上采样操作。接下来创建了 5 个卷积块对象 `conv0_0` 到 `conv4_0`,每个卷积块包含两个卷积层和一个 ReLU 层。其中,第一个卷积层的输入通道数为 `input_channels` 或前一层卷积层的输出通道数,输出通道数为 `nb_filter[0]` 到 `nb_filter[4]`,依次递增。第二个卷积层的输入通道数等于第一个卷积层的输出通道数,输出通道数仍然为 `nb_filter[0]` 到 `nb_filter[4]`。这些卷积块的作用是提取不同尺度的特征信息。
- 接下来创建了 4 个卷积块对象 `conv3_1` 到 `conv0_4`。这些卷积块的输入通道数分别为 `nb_filter[3]+nb_filter[4]`、`nb_filter[2]+nb_filter[3]`、`nb_filter[1]+nb_filter[2]` 和 `nb_filter[0]+nb_filter[1]`,输出通道数仍然为 `nb_filter[0]`。这些卷积块的作用是将不同尺度的特征信息进行融合,以便更好地进行分割。
- 最后定义了一个卷积层 `final`,用于生成分割结果。它的输入通道数等于最后一个卷积块的输出通道数,输出通道数为 `num_classes`。在 `forward()` 函数中,先通过各个卷积块提取特征信息,然后将不同尺度的特征信息进行融合,最终通过卷积层 `final` 生成分割结果。
- 在这里,`1` 是 `torch.cat()` 函数的第二个参数 `dim` 的值。`torch.cat()` 函数用于沿着指定的维度将张量进行拼接。
- 在这个特定的语句中,`torch.cat([x3_0, self.up(x4_0)], 1)` 表示将 `x3_0` 和 `self.up(x4_0)` 沿着第一个维度(维度索引从0开始)进行拼接。通过使用 `1` 作为 `dim` 参数的值,表示沿着第一个维度进行拼接。
- 这里的目的是将 `x3_0` 和上采样后的 `x4_0` 进行拼接,以便在 U-Net 网络中进行特征融合操作。拼接后的结果将作为输入传递给 `self.conv3_1` 卷积块进行处理。
"""
"""
测试网络正确性代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 测试网络的正确性
input = torch.ones((2,3,400,400))
unet =UNet(1).to(device)  
output = unet(input)
print(output.shape)

- 这里出现了问题:
    - 当测试图片大小为572,572时:
        1. torch.Size([2, 32, 572, 572])
        2. torch.Size([2, 64, 286, 286])
        3. torch.Size([2, 128, 143, 143])
        4. torch.Size([2, 256, 71, 71])
        5. torch.Size([2, 512, 35, 35])
        6. torch.Size([2, 512, 70, 70])
- 这是因为:
    - 下采样71是单数,变到了35,而上采样35乘2是70
    - x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))  这个加法没办法加
"""
"""
torch.Size([2, 32, 400, 400])
torch.Size([2, 64, 200, 200])
torch.Size([2, 128, 100, 100])
torch.Size([2, 256, 50, 50])
torch.Size([2, 512, 25, 25])
torch.Size([2, 256, 50, 50])
torch.Size([2, 128, 100, 100])
torch.Size([2, 64, 200, 200])
torch.Size([2, 32, 400, 400])
torch.Size([2, 1, 400, 400])
"""
class UNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # scale_factor:放大的倍数  插值

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))

        output = self.final(x0_4)
        return output

"""
- conv0_0: 前面的0代表第0层。后面的0代表第0列,斜着的列
- unet++中的通道数: 每一层的输出通道数都一致,输入通道数为前面密集连接拼起来的和
"""
"""
# 测试网络的正确性
input = torch.ones((2,3,400,400))
nest_unet =NestedUNet(1, deep_supervision=True).to(device)  
output = nest_unet(input)
for result in output:
    print("out: ",result.shape)
"""
"""
input: torch.Size([2, 3, 400, 400])
x0_0: torch.Size([2, 32, 400, 400])
x1_0: torch.Size([2, 64, 200, 200])
x0_1: torch.Size([2, 32, 400, 400])
x2_0: torch.Size([2, 128, 100, 100])
x1_1: torch.Size([2, 64, 200, 200])
x0_2: torch.Size([2, 32, 400, 400])
x3_0: torch.Size([2, 256, 50, 50])
x2_1: torch.Size([2, 128, 100, 100])
x1_2: torch.Size([2, 64, 200, 200])
x0_3: torch.Size([2, 32, 400, 400])
x4_0: torch.Size([2, 512, 25, 25])
x3_1: torch.Size([2, 256, 50, 50])
x2_2: torch.Size([2, 128, 100, 100])
x1_3: torch.Size([2, 64, 200, 200])
x0_4: torch.Size([2, 32, 400, 400])
out:  torch.Size([2, 1, 400, 400])
out:  torch.Size([2, 1, 400, 400])
out:  torch.Size([2, 1, 400, 400])
out:  torch.Size([2, 1, 400, 400])
"""
class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0] * 2 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1] * 2 + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2] * 2 + nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0] * 3 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1] * 3 + nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0] * 4 + nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        # print('input:', input.shape)
        x0_0 = self.conv0_0(input)
        # print('x0_0:', x0_0.shape)
        x1_0 = self.conv1_0(self.pool(x0_0))
        # print('x1_0:', x1_0.shape)
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
        # print('x0_1:', x0_1.shape)

        x2_0 = self.conv2_0(self.pool(x1_0))
        # print('x2_0:', x2_0.shape)
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        # print('x1_1:', x1_1.shape)
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
        # print('x0_2:', x0_2.shape)

        x3_0 = self.conv3_0(self.pool(x2_0))
        # print('x3_0:', x3_0.shape)
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        # print('x2_1:', x2_1.shape)
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        # print('x1_2:', x1_2.shape)
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
        # print('x0_3:', x0_3.shape)
        x4_0 = self.conv4_0(self.pool(x3_0))
        # print('x4_0:', x4_0.shape)
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        # print('x3_1:', x3_1.shape)
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        # print('x2_2:', x2_2.shape)
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        # print('x1_3:', x1_3.shape)
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
        # print('x0_4:', x0_4.shape)

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]

        else:
            output = self.final(x0_4)
            return output
  1. dataset.py
import os

import cv2
import numpy as np
import torch
import torch.utils.data

# 数据集处理类
class Dataset(torch.utils.data.Dataset):
    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
        """
        Args:
            img_ids (list): Image ids. 存着所有图像名称的列表
            img_dir: Image file directory.
            mask_dir: Mask file directory.
            img_ext (str): Image file extension. image图像的类型
            mask_ext (str): Mask file extension. mask图像的类型
            num_classes (int): Number of classes.
            transform (Compose, optional): Compose transforms of albumentations. Defaults to None.

        Note:
            Make sure to put the files as the following structure:
            <dataset name>
            ├── images
            |   ├── 0a7e06.jpg
            │   ├── 0aab0a.jpg
            │   ├── 0b1761.jpg
            │   ├── ...
            |
            └── masks
                ├── 0
                |   ├── 0a7e06.png
                |   ├── 0aab0a.png
                |   ├── 0b1761.png
                |   ├── ...
                |
                ├── 1
                |   ├── 0a7e06.png
                |   ├── 0aab0a.png
                |   ├── 0b1761.png
                |   ├── ...
                ...
        """
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]

        # 读入image图像
        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))

        # 这里mask为列表,分几类,就有几-1个值.也就是每个类别对背景二分类。
        mask = []
        for i in range(self.num_classes):
            # 由于mask是灰度图,与使用灰度图的方式读入,但是使用[..., None]强制将灰度图像转换为具有额外维度的三维数组。
            # 这样做通常是为了与 RGB 图像具有相同的形状,以方便进行处理。
            mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
                                                img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])

        # np.dstack(mask)是将多个mask图像沿着深度维度进行堆叠,生成一个新的三维数组。
        # mask变量是一个二维列表,每个元素表示一类掩码图像,经过np.dstack()后,数组的深度维度就表示不同的掩码类别。
        # 也就是:转置卷积和1*1卷积综合作用后生成大小和原图一致,但是通道数为分类类别数。每个通道中每个像素的值是一个概率,代表这个像素属于该通道的概率。
        # 这里就是生成了类别数个通道,每个通道代表每个像素分到这个通道的概率,但是这里是255
        mask = np.dstack(mask)

        """
             这段代码是使用`albumentations`库对图像进行数据增强。
               `albumentations`库是一个开源的图像增强库,能够在训练深度学习模型时为图像增加随机变换和扰动,从而提高模型的泛化能力。
                在这段代码中,如果定义了`transform`参数(即数据增强的变换),则会使用该参数对图像和掩模进行增强。
            
             如果要使用数据增强,`transform`参数应该传入一个`albumentations`库中定义的变换函数或变换组合。
                `albumentations`库提供了多种各具特色的图像增强操作,例如随机裁剪、旋转、缩放、翻转、色彩调整等。
                 这些增强操作可以单独使用,也可以通过`Compose`函数组合在一起形成一个变换组合。
    
            下面是一个示例,展示如何使用`Compose`函数将多个增强操作组合在一起:
    
            ```python
            import albumentations as A
    
            transform = A.Compose([
                A.RandomCrop(width=256, height=256),
                A.HorizontalFlip(p=0.5),
                A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ])
            ```
    
            在上述示例中,`RandomCrop`表示随机裁剪,`HorizontalFlip`表示水平翻转,`ColorJitter`表示颜色调整。
            这些操作会按照指定的概率(例如`p=0.5`)对输入图像进行随机变换。
            然后,可以将`transform`作为参数传递给图像增强代码中的`self.transform`,以应用相应的数据增强操作。
            
            具体地,`self.transform(image=img, mask=mask)`会将输入的原始图像和掩模传递给变换函数进行增强,返回增强后的图像和掩模。
            `augmented`是一个字典,包含了增强后的图像和掩模。
            `img = augmented['image']`和`mask = augmented['mask']`用于分别获取增强后的图像和掩模,并将其赋值回原来的变量。
    
            需要注意的是,`albumentations`库处理的图像和掩模都是`numpy`数组格式,且通道顺序为RGB。PyTorch以及OpenCV中通道都是RGB
        """
        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)  # 这个包比较方便,能把mask也一并做掉
            img = augmented['image']  # 参考https://github.com/albumentations-team/albumentations
            mask = augmented['mask']

        # 这段代码对图像和掩模进行了一些预处理操作。
        # 将图像的数据类型转换为float32,并将像素值归一化到[0, 1]的范围。
        # 改变图像的维度顺序,将通道维度放在最前面,以适应深度学习框架的输入要求。
        # 同样地,对掩模也进行了类似的处理,和神经网络预测的结果一样,要转换为0-1的范围内,也就是概率值,也就是最重要预测的值。
        # 对图像归一化的作用:
        # 1. 对图像进行归一化可以使得图像的像素值在[0,1]范围内,这有助于优化模型的训练和收敛速度。
        #     具体来说,归一化可以使得数据的均值接近于0,方差接近于1,这对于许多基于梯度的优化算法(如随机梯度下降)是非常有利的。
        # 2. 此外,由于神经网络的权重通常会被初始化为较小的随机值,因此在训练过程中,输入的像素值可能会因为网络权重较小而产生较小的响应,
        #     这可能导致网络学习缓慢或停滞不前。通过对输入图像进行归一化,可以缓解这个问题,使网络更容易学习到有用的特征。
        # 3. 另外,还可以防止图像数据因为数值范围过大而在计算中出现溢出或浮点数精度丢失等问题。
        img = img.astype('float32') / 255

        # 在深度学习中,通常要求输入的图像或掩模的维度顺序为通道维度在最前面,即`(channels, height, width)`。
        # 这种维度顺序可以方便地与卷积操作等深度学习操作相匹配。
        # 在原始的`mask`中,假设维度顺序为`(height, width, channels)`,也确实是这样。通过进行维度变换 `mask = mask.transpose(2, 0, 1)`,
        # 就将通道维度放在了最前面,符合深度学习框架对输入数据的要求。
        img = img.transpose(2, 0, 1)
        mask = mask.astype('float32') / 255
        mask = mask.transpose(2, 0, 1)

        # 返回三个数据:图像、掩膜、图像名称。最后调用该类会将所有的组合成一个列表返回。
        return img, mask, {'img_id': img_id}


    """
    fcn的用的是argmax,将所有通道的这个像素值中最大的索引取出来,代表值,于是fcn就用的标签数组。
    而这里的mask就是所有的通道,于是就不能用argmax了。
    """

    """
    # 展示多张图片
    def show_images(imgs, num_rows, num_cols, scale=2):
        figsize = (num_cols * scale, num_rows * scale)
        _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
        for i in range(num_rows):
            for j in range(num_cols):
                axes[i][j].imshow(imgs[i * num_cols + j])
                axes[i][j].axes.get_xaxis().set_visible(False)
                axes[i][j].axes.get_yaxis().set_visible(False)
        plt.show()
        return axes
    """

    """
    掩膜图像展示法
    img = cv2.bitwise_and(train[0][2],cv2.cvtColor(train[1][2], cv2.COLOR_GRAY2BGR))
    plt.title('Masks over image')
    plt.imshow(img)
    plt.show()
    """
  1. preprocess_dsb2018.py
import os
from glob import glob

import cv2
import numpy as np
from tqdm import tqdm

# 数据集预处理
def main():
    img_size = 96

    # glob这个函数是用来获取路径下所有文件的绝对路径,并组装成列表,以供后续使用
    paths = glob('inputs/dsb-18/stage1_train/*')

    """
        给这个文件夹生成目录树:
            tree:
                stage1_train
                ├── 0a7e06
                |    ├── images
                |       ├── 0a7e06.png
                |    └── masks
                |       ├── 0a7e06.png
                |       ├── 0aab0a.png
                |       ├── 0b1761.png
                ├── 0aab0a
                ├── 0b1761
                ├── ...
    """

    # 这里使用os库中的函数makedirs函数在文件夹不存在的情况下创建所给路径的文件夹
    os.makedirs('inputs/dsb2018_%d/images' % img_size, exist_ok=True)
    #  这里的 0 的意思是,本数据集只分一类,也就是只有背景和细胞两个分类,二分类。
    #  如果还有别的待分类的内容,就需要再建一个1的文件夹,这里面存的是二值图像关于这个类和背景的分割
    os.makedirs('inputs/dsb2018_%d/masks/0' % img_size, exist_ok=True)

    # 这里循环 每一个paths中的文件夹
    for i in tqdm(range(len(paths))):
        path = paths[i]
        # 取到路径中images文件夹中的图片
        # os.path.basename(path) :返回一个文件路径的基名(即文件名)
        img = cv2.imread(os.path.join(path, 'images',
                                      os.path.basename(path) + '.png'))
        #  根据原始图像的形状创建全零图像,用于存储拼接后的masks图像
        mask = np.zeros((img.shape[0], img.shape[1]))
        for mask_path in glob(os.path.join(path, 'masks', '*')):
            # 将下属所有图像中值大于127,也就是二值图像中的255的白色部分,全部集合到mask图像中
            mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 127
            mask[mask_] = 1

        # 这段代码的作用是处理图像的通道数:
        #       第一个条件len(img.shape) == 2判断图像是否是灰度图(通道数为1)。如果是灰度图,则使用np.tile()函数将图像在通道维度进行复制,使其变为RGB图像(通道数为3)。
        #       第二个条件img.shape[2] == 4判断图像是否具有4个通道(RGBA图像)。如果是RGBA图像,则取前3个通道,将其转换为RGB图像。
        # 这段代码的目的是确保图像的通道数为3,以便后续处理。如果图像是灰度图或RGBA图像,会根据需要进行相应的转换。
        if len(img.shape) == 2:
            img = np.tile(img[..., None], (1, 1, 3))
        if img.shape[2] == 4:
            img = img[..., :3]

        # 为方便后续训练,将图像与标签图像的大小resize为所要的大小。
        # 但是真实的UNet不是这样子的,真实的UNet用的是patch图像分块,并且还给图像进行了镜像padding
        img = cv2.resize(img, (img_size, img_size))
        mask = cv2.resize(mask, (img_size, img_size))

        # 将处理好的图像存入指定文件夹中
        cv2.imwrite(os.path.join('inputs/dsb2018_%d/images' % img_size,
                                 os.path.basename(path) + '.png'), img)
        # 这里要将mask*255,是因为mask中细胞位置是1,但是要显示为白色255,就需要都乘以255
        # 这种方式要记住
        cv2.imwrite(os.path.join('inputs/dsb2018_%d/masks/0' % img_size,
                                 os.path.basename(path) + '.png'), (mask * 255).astype('uint8'))


if __name__ == '__main__':
    main()
  1. losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge
except ImportError:
    pass

__all__ = ['BCEDiceLoss', 'LovaszHingeLoss']

"""
这是一个 PyTorch 实现的损失函数模块,包括 `BCEDiceLoss` 和 `LovaszHingeLoss` 两个损失函数。

`BCEDiceLoss` 是一种结合了二元交叉熵 (Binary Cross Entropy, BCE) 和 Dice Loss 的损失函数。
BCE 用于度量预测值与真实值之间的差异,而 Dice Loss 则用于衡量两个集合之间的相似度。
相比于单一的 BCE 或 Dice Loss,BCEDiceLoss 可以更好地平衡准确率和召回率之间的权衡,从而提高模型的性能。

`LovaszHingeLoss` 则是一种基于 Lovasz 损失函数的损失函数。
Lovasz 损失函数是一种非常适合处理不平衡数据的损失函数,它可以在训练中强制模型对较难的样本进行更多的关注,从而提高模型的泛化能力。
LovaszHingeLoss 在二分类问题中表现得很好,特别是在处理分割问题时,可以取得很好的效果。

需要注意的是,`LovaszHingeLoss` 中的 `lovasz_hinge` 函数是在外部定义的,可能需要单独导入才能使用。
"""
class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        bce = F.binary_cross_entropy_with_logits(input, target)
        smooth = 1e-5
        input = torch.sigmoid(input)
        num = target.size(0)
        input = input.view(num, -1)
        target = target.view(num, -1)
        intersection = (input * target)
        dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
        dice = 1 - dice.sum() / num
        return 0.5 * bce + dice


class LovaszHingeLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        input = input.squeeze(1)
        target = target.squeeze(1)
        loss = lovasz_hinge(input, target, per_image=True)

        return loss
  1. metrics.py
import numpy as np
import torch
import torch.nn.functional as F

"""
这段代码定义了两个评估指标函数:`iou_score`和`dice_coef`。这些指标通常用于评估图像分割模型的性能。

`iou_score`计算了预测输出(output)和目标标签(target)之间的交并比(Intersection over Union,IoU)。
    首先,通过对预测输出应用sigmoid函数将其转换为概率值,并使用`torch.sigmoid()`函数将其转换为[0,1]范围内的值。
    然后,将预测输出和目标标签转换为numpy数组。
    接下来,根据阈值0.5将预测输出和目标标签二值化为True/False表示。
    然后,计算交集和并集的元素数量,并将其相加得到交并比。
    为了避免除以0的情况,添加了一个平滑项smooth。

`dice_coef`计算了预测输出和目标标签之间的Dice系数。
    首先,对预测输出应用sigmoid函数并将其视为一维数组。
    然后,将预测输出和目标标签转换为numpy数组。
    接下来,计算预测输出和目标标签之间的交集元素的和,并将其乘以2。
    然后,计算预测输出和目标标签中各自元素的和,并将其相加。
    最后,将交集和平滑项添加到分子中,将预测输出和目标标签的和以及平滑项添加到分母中,得到Dice系数。

这段代码使用了PyTorch库的一些函数,如`torch.sigmoid`、`torch.is_tensor`和`torch.Tensor.view`,
并使用了NumPy库的函数,如`numpy.sum`和`numpy.ndarray`。
如果要完整运行该段代码,需要确保已正确导入相关库,并提供合适的输出和目标标签张量作为输入。
"""
def iou_score(output, target):
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output_ = output > 0.5
    target_ = target > 0.5
    intersection = (output_ & target_).sum()
    union = (output_ | target_).sum()

    return (intersection + smooth) / (union + smooth)


def dice_coef(output, target):
    smooth = 1e-5

    output = torch.sigmoid(output).view(-1).data.cpu().numpy()
    target = target.view(-1).data.cpu().numpy()
    intersection = (output * target).sum()

    return (2. * intersection + smooth) / \
           (output.sum() + target.sum() + smooth)
  1. utils.py
import argparse


"""
这段代码包含了一些常用的辅助函数和类。

`str2bool`是一个用于解析命令行参数的函数,将字符串表示的布尔值转换为对应的Python布尔值。
    如果输入的字符串是"true"(不区分大小写)或1,则返回True;
    如果输入的字符串是"false"(不区分大小写)或0,则返回False;
    否则,抛出一个类型错误的异常。

`count_params`是一个用于计算模型参数数量的函数。
    它接受一个模型对象作为输入,并返回其中所有需要进行梯度计算的参数的总数量。
    通过遍历模型的parameters属性,并使用requires_grad属性来过滤出需要进行梯度计算的参数,
    然后使用numel方法获取每个参数的元素数量,并将其求和得到总数量。

`AverageMeter`是一个用于计算和存储平均值的类。
    它包含了val、avg、sum和count四个属性,分别表示当前值、平均值、值的累加和以及值的累加次数。
    reset方法用于重置所有属性的值为0,
    update方法用于更新当前值、累加和和累加次数,并计算新的平均值。

这些辅助函数和类可以在训练和评估深度学习模型时非常有用。
    例如,可以使用str2bool函数解析命令行参数中的布尔值选项,
    使用count_params函数统计模型的参数数量,
    使用AverageMeter类跟踪训练过程中的损失值或指标值的平均值。

"""
def str2bool(v):
    if v.lower() in ['true', 1]:
        return True
    elif v.lower() in ['false', 0]:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
  1. train.py
import argparse
import configparser
import datetime
import os
import time
from collections import OrderedDict
from glob import glob

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
import albumentations as albu
from albumentations.augmentations import transforms
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import archs
import losses
from dataset import Dataset
from metrics import iou_score
from utils import AverageMeter, str2bool
"""
这段代码导入了一些必要的库和模块,下面我将逐一介绍每个包的作用:
1. argparse:用于解析命令行参数,可以方便地从命令行中获取用户输入的参数。
2. os:提供了许多与操作系统交互的函数,用于处理文件和目录。
3. collections 中的 OrderedDict:是一个有序字典,用于创建有序的键值对。
4. glob:根据指定的路径模式匹配文件,可以用来获取符合条件的文件列表。
5. pandas:用于数据分析和处理,提供了高性能、易用的数据结构和数据分析工具。
6. torch:PyTorch深度学习框架的核心库,提供了张量操作、自动求导等功能。
7. torch.backends.cudnn:用于设置一些与CUDA加速相关的选项,提供了对cuDNN库的接口。
8. torch.nn:PyTorch中用于定义神经网络层的模块,包括各种不同类型的层和损失函数。
9. torch.optim:提供了各种优化器的实现,用于更新模型的参数。
10. yaml:用于读取和写入YAML格式的配置文件。
11. albumentations:一个图像增强库,提供了各种图像增强的方法,如旋转、缩放、裁剪等。
12. transforms:albumentations中的模块,提供了各种图像增强的操作。
13. Compose:albumentations中的模块,用于将多个图像增强操作组合在一起。
14. train_test_split:用于将数据集划分为训练集和验证集的模块。
15. torch.optim.lr_scheduler:PyTorch中的学习率调度器,用于动态调整学习率。
16. tqdm:一个Python进度条库,用于在循环中显示进度条。
17. archs:自定义的模型架构,用于定义神经网络模型。
18. losses:自定义的损失函数,用于定义模型的损失计算方法。
19. dataset:自定义的数据集类,用于加载和处理数据。
20. metrics:自定义的评估指标,用于评估模型性能。
21. utils:自定义的工具函数,用于辅助操作和处理。
"""



"""
ARCH_NAMES = archs.__all__:从 archs 模块中导入所有可用的模型架构的名称,并将它们存储在 ARCH_NAMES 列表中。
LOSS_NAMES = losses.__all__:从 losses 模块中导入所有可用的损失函数的名称,并将它们存储在 LOSS_NAMES 列表中。
LOSS_NAMES.append('BCEWithLogitsLoss'):将一个名为 'BCEWithLogitsLoss' 的损失函数名称添加到 LOSS_NAMES 列表中。
"""
ARCH_NAMES = archs.__all__
LOSS_NAMES = losses.__all__
LOSS_NAMES.append('BCEWithLogitsLoss')

"""
命令行参数: 指定参数:
--dataset dsb2018_96 
--arch NestedUNet
"""

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


"""
这段代码是用来解析命令行参数的函数。下面是每一句代码的作用:
1. `parser = argparse.ArgumentParser()`:创建一个ArgumentParser对象,用于解析命令行参数。
2. `parser.add_argument('--name', default=None, help='model name: (default: arch+timestamp)')`:添加一个名为'name'的参数,用于指定模型的名称,默认值为None,帮助信息为'model name: (default: arch+timestamp)'。
3. `parser.add_argument('--epochs', default=10, type=int, metavar='N', help='number of total epochs to run')`:添加一个名为'epochs'的参数,用于指定总共运行的训练轮数,默认值为10,类型为整数,帮助信息为'number of total epochs to run'。
4. `parser.add_argument('-b', '--batch_size', default=8, type=int, metavar='N', help='mini-batch size (default: 16)')`:添加一个名为'batch_size'的参数,用于指定每个mini-batch的样本数量,默认值为8,类型为整数,帮助信息为'mini-batch size (default: 16)'。
5. `parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet', choices=ARCH_NAMES, help='model architecture: ' + ' | '.join(ARCH_NAMES) + ' (default: NestedUNet)')`:添加一个名为'arch'或'a'的参数,用于指定模型的架构,默认值为'NestedUNet',可选值为ARCH_NAMES中的架构名称,帮助信息为'model architecture: ' + ' | '.join(ARCH_NAMES) + ' (default: NestedUNet)'。
6. `parser.add_argument('--deep_supervision', default=False, type=str2bool)`:添加一个名为'deep_supervision'的参数,用于指定是否使用深度监督,默认值为False。
7. `parser.add_argument('--input_channels', default=3, type=int, help='input channels')`:添加一个名为'input_channels'的参数,用于指定输入图像的通道数,默认值为3,帮助信息为'input channels'。
8. `parser.add_argument('--num_classes', default=1, type=int, help='number of classes')`:添加一个名为'num_classes'的参数,用于指定分类的类别数量,默认值为1,帮助信息为'number of classes'。
9. `parser.add_argument('--input_w', default=96, type=int, help='image width')`:添加一个名为'input_w'的参数,用于指定输入图像的宽度,默认值为96,帮助信息为'image width'。
10. `parser.add_argument('--input_h', default=96, type=int, help='image height')`:添加一个名为'input_h'的参数,用于指定输入图像的高度,默认值为96,帮助信息为'image height'。
11. `parser.add_argument('--loss', default='BCEDiceLoss', choices=LOSS_NAMES, help='loss: ' + ' | '.join(LOSS_NAMES) + ' (default: BCEDiceLoss)')`:添加一个名为'loss'的参数,用于指定损失函数,默认值为'BCEDiceLoss',可选值为LOSS_NAMES中的损失函数名称,帮助信息为'loss: ' + ' | '.join(LOSS_NAMES) + ' (default: BCEDiceLoss)'。
12. `parser.add_argument('--dataset', default='dsb2018_96', help='dataset name')`:添加一个名为'dataset'的参数,用于指定数据集的名称,默认值为'dsb2018_96',帮助信息为'dataset name'。
13. `parser.add_argument('--img_ext', default='.png', help='image file extension')`:添加一个名为'img_ext'的参数,用于指定图像文件的扩展名,默认值为'.png',帮助信息为'image file extension'。
14. `parser.add_argument('--mask_ext', default='.png', help='mask file extension')`:添加一个名为'mask_ext'的参数,用于指定掩码文件的扩展名,默认值为'.png',帮助信息为'mask file extension'。
15. `parser.add_argument('--optimizer', default='SGD', choices=['Adam', 'SGD'], help='loss: ' + ' | '.join(['Adam', 'SGD']) + ' (default: Adam)')`:添加一个名为'optimizer'的参数,用于指定优化器,默认值为'SGD',可选值为['Adam', 'SGD'],帮助信息为'loss: ' + ' | '.join(['Adam', 'SGD']) + ' (default: Adam)'。
16. `parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float, metavar='LR', help='initial learning rate')`:添加一个名为'lr'或'learning_rate'的参数,用于指定初始学习率,默认值为1e-3,类型为浮点数,帮助信息为'initial learning rate'。
17. `parser.add_argument('--momentum', default=0.9, type=float, help='momentum')`:添加一个名为'momentum'的参数,用于指定动量,默认值为0.9。
18. `parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay')`:添加一个名为'weight_decay'的参数,用于指定权重衰减,默认值为1e-4。
19. `parser.add_argument('--nesterov', default=False, type=str2bool, help='nesterov')`:添加一个名为'nesterov'的参数,用于指定是否使用Nesterov动量,默认值为False。
20. `parser.add_argument('--scheduler', default='CosineAnnealingLR', choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])`:添加一个名为'scheduler'的参数,用于指定学习率调度器,默认值为'CosineAnnealingLR',可选值为['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR']。
21. `parser.add_argument('--min_lr', default=1e-5, type=float, help='minimum learning rate')`:添加一个名为'min_lr'的参数,用于指定最小学习率,默认值为1e-5。
22. `parser.add_argument('--factor', default=0.1, type=float)`:添加一个名为'factor'的参数,用于指定学习率调度器中的因子,默认值为0.1。
23. `parser.add_argument('--patience', default=2, type=int)`:添加一个名为'patience'的参数,用于指定学习率调度器中的耐心值,默认值为2。
24. `parser.add_argument('--milestones', default='1,2', type=str)`:添加一个名为'milestones'的参数,用于指定学习率调度器中的里程碑,默认值为'1,2'。
25. `parser.add_argument('--gamma', default=2 / 3, type=float)`:添加一个名为'gamma'的参数,用于指定学习率调度器中的γ值,默认值为2/3。
26. `parser.add_argument('--early_stopping', default=-1, type=int, metavar='N', help='early stopping (default: -1)')`:添加一个名为'early_stopping'的参数,用于指定早停的轮数,默认值为-1,帮助信息为'early stopping (default: -1)'。
27. `parser.add_argument('--num_workers', default=0, type=int)`:添加一个名为'num_workers'的参数,用于指定数据加载时的并行工作进程数,默认值为0。
28. `config = parser.parse_args()`:解析命令行参数,并将结果存储在config变量中。
29. `return config`:返回解析后的命令行参数配置。
"""
def original_parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default=None,
                        help='model name: (default: arch+timestamp)')
    parser.add_argument('--epochs', default=10, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 16)')

    # model
    parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
                        choices=ARCH_NAMES,
                        help='model architecture: ' +
                             ' | '.join(ARCH_NAMES) +
                             ' (default: NestedUNet)')
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=96, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=96, type=int,
                        help='image height')

    # loss
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES,
                        help='loss: ' +
                             ' | '.join(LOSS_NAMES) +
                             ' (default: BCEDiceLoss)')

    # dataset
    parser.add_argument('--dataset', default='dsb2018_96',
                        help='dataset name')
    parser.add_argument('--img_ext', default='.png',
                        help='image file extension')
    parser.add_argument('--mask_ext', default='.png',
                        help='mask file extension')

    # optimizer
    parser.add_argument('--optimizer', default='SGD',
                        choices=['Adam', 'SGD'],
                        help='loss: ' +
                             ' | '.join(['Adam', 'SGD']) +
                             ' (default: Adam)')
    parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', default=False, type=str2bool,
                        help='nesterov')

    # scheduler
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
    parser.add_argument('--min_lr', default=1e-5, type=float,
                        help='minimum learning rate')
    parser.add_argument('--factor', default=0.1, type=float)
    parser.add_argument('--patience', default=2, type=int)
    parser.add_argument('--milestones', default='1,2', type=str)
    parser.add_argument('--gamma', default=2 / 3, type=float)
    parser.add_argument('--early_stopping', default=-1, type=int,
                        metavar='N', help='early stopping (default: -1)')

    parser.add_argument('--num_workers', default=0, type=int)

    config = parser.parse_args()

    return config

def str2bool(v):
    return v.lower() in ('yes', 'true', 't', '1')

# 使用ini配置文件添入默认值
def parse_args():
    parser = argparse.ArgumentParser()

    config = configparser.ConfigParser()
    config.read('resources/defaultValue.ini')

    parser.add_argument('--name', default=None,
                        help='model name: (default: arch+timestamp)')
    parser.add_argument('--epochs', default=config.getint('training', 'epochs'), type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=config.getint('training', 'batch_size'), type=int,
                        metavar='N', help='mini-batch size (default: 16)')

    # model
    parser.add_argument('--arch', '-a', metavar='ARCH', default=config.get('training', 'arch'),
                        choices=ARCH_NAMES,
                        help='model architecture: ' +
                             ' | '.join(ARCH_NAMES) +
                             ' (default: NestedUNet)')
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=config.getint('training', 'input_channels'), type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=config.getint('training', 'num_classes'), type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=config.getint('training', 'input_w'), type=int,
                        help='image width')
    parser.add_argument('--input_h', default=config.getint('training', 'input_h'), type=int,
                        help='image height')

    # loss
    parser.add_argument('--loss', default=config.get('training', 'loss'),
                        choices=LOSS_NAMES,
                        help='loss: ' +
                             ' | '.join(LOSS_NAMES) +
                             ' (default: BCEDiceLoss)')

    # dataset
    parser.add_argument('--dataset', default=config.get('data', 'dataset'),
                        help='dataset name')
    parser.add_argument('--img_ext', default=config.get('data', 'img_ext'),
                        help='image file extension')
    parser.add_argument('--mask_ext', default=config.get('data', 'mask_ext'),
                        help='mask file extension')

    # optimizer
    parser.add_argument('--optimizer', default=config.get('optimizer', 'optimizer'),
                        choices=['Adam', 'SGD'],
                        help='loss: ' +
                             ' | '.join(['Adam', 'SGD']) +
                             ' (default: Adam)')
    parser.add_argument('--lr', '--learning_rate', default=config.getfloat('optimizer', 'lr'), type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=config.getfloat('optimizer', 'momentum'), type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=config.getfloat('optimizer', 'weight_decay'), type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', default=config.getboolean('optimizer', 'nesterov'), type=str2bool,
                        help='nesterov')

    # scheduler
    parser.add_argument('--scheduler', default=config.get('scheduler', 'scheduler'),
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
    parser.add_argument('--min_lr', default=config.getfloat('scheduler', 'min_lr'), type=float,
                        help='minimum learning rate')
    parser.add_argument('--factor', default=config.getfloat('scheduler', 'factor'), type=float)
    parser.add_argument('--patience', default=config.getint('scheduler', 'patience'), type=int)
    parser.add_argument('--milestones', default=config.get('scheduler', 'milestones'), type=str)
    parser.add_argument('--gamma', default=config.getfloat('scheduler', 'gamma'), type=float)
    parser.add_argument('--early_stopping', default=config.getint('early_stopping', 'early_stopping'), type=int,
                        metavar='N', help='early stopping (default: -1)')

    parser.add_argument('--num_workers', default=config.getint('other', 'num_workers'), type=int)

    config = parser.parse_args()

    return config

"""
这个函数接受以下参数:
- `model`: 要训练的模型。
- `train_loader`: 训练数据集的数据加载器。
- `criterion`: 损失函数。
- `optimizer`: 优化器。
- `device`: 设备,可以是 `'cuda'` 或 `'cpu'`。
- `config`: 包含训练配置的字典,包括是否使用深度监督、学习率等超参数。

函数的主要步骤:
1. 将模型切换到训练模式:`model.train()`
2. 创建进度条对象,用于显示训练进度:`pbar = tqdm(total=len(train_loader))`
3. 遍历训练数据加载器中的每个样本:
   - 将输入数据和目标数据移动到设备上:`input = input.to(device)` 和 `target = target.to(device)`
   - 如果使用深度监督,通过模型进行前向传播,获取多个输出结果(列表),并计算损失和性能指标。
     - 遍历每个输出结果并计算损失:`for output in outputs: loss += criterion(output, target)`
     - 计算平均损失:`loss /= len(outputs)`
     - 计算最后一个输出结果和目标数据之间的 IoU 指标:`iou = iou_score(outputs[-1], target)`
   - 如果不使用深度监督,通过模型进行前向传播,获取单个输出结果,并计算损失:`output = model(input)` 和 `loss = criterion(output, target)`
   - 清零优化器的梯度缓冲区:`optimizer.zero_grad()`
   - 执行反向传播计算梯度:`loss.backward()`
   - 根据梯度更新模型的参数:`optimizer.step()`
   - 使用 `AverageMeter` 对象更新平均损失和平均 IoU:`avg_meters['loss'].update(loss.item(), input.size(0))` 和 `avg_meters['iou'].update(iou, input.size(0))`
   - 更新进度条的后缀信息:`postfix = OrderedDict([...])` 和 `pbar.set_postfix(postfix)`
   - 更新进度条的计数器:`pbar.update(1)`
4. 关闭进度条:`pbar.close()`
5. 返回包含平均损失和性能指标的有序字典:`return OrderedDict([...])`

这些步骤涵盖了在训练数据集上对模型进行一次迭代的训练过程。
"""
def train(config, train_loader, model, criterion, optimizer):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    model.train()

    pbar = tqdm(total=len(train_loader))
    for input, target, _ in train_loader:
        input = input.to(device)
        target = target.to(device)

        # compute output
        if config['deep_supervision']:
            outputs = model(input)
            loss = 0
            for output in outputs:
                loss += criterion(output, target)
            loss /= len(outputs)
            iou = iou_score(outputs[-1], target)
        else:
            output = model(input)
            loss = criterion(output, target)
            iou = iou_score(output, target)

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))

        postfix = OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
        ])
        pbar.set_postfix(postfix)
        pbar.update(1)
    pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])


"""
这段代码定义了一个用于验证模型的函数 `validate()`。它接受以下参数:
- `config`:包含配置信息的字典。
- `val_loader`:验证数据集的数据加载器。
- `model`:要验证的模型。
- `criterion`:损失函数。

函数中的主要步骤如下:
1. 创建一个包含两个 `AverageMeter()` 对象的字典 `avg_meters`,用于计算并保存平均损失和平均 IoU(Intersection over Union)。
2. 将模型切换到评估模式,通过调用 `model.eval()` 来设置。
3. 使用 `torch.no_grad()` 上下文管理器,禁用梯度计算,以便在验证过程中不进行参数更新。
4. 使用 `tqdm` 进度条迭代 `val_loader` 中的每个验证样本。
5. 将输入数据和目标数据移动到设备(GPU 或 CPU)上。
6. 根据配置中的 `deep_supervision` 值,选择不同的计算输出和损失的方式:
   - 如果 `deep_supervision` 为真,则通过模型的多个输出计算损失,并取平均损失。
   - 如果 `deep_supervision` 为假,则使用模型的单个输出计算损失。
7. 使用输出和目标计算 IoU 值。
8. 更新平均损失和平均 IoU 的 `AverageMeter` 对象。
9. 设置进度条的后缀信息,包括平均损失和平均 IoU。
10. 更新并关闭进度条。
11. 返回一个有序字典,包含平均损失和平均 IoU 的键值对。

该函数的作用是计算模型在验证数据集上的损失和性能指标(IoU),并返回这些指标的平均值。
"""
def validate(config, val_loader, model, criterion):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        pbar = tqdm(total=len(val_loader))
        for input, target, _ in val_loader:
            input = input.to(device)
            target = target.to(device)

            # compute output
            if config['deep_supervision']:
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou = iou_score(outputs[-1], target)
            else:
                output = model(input)
                loss = criterion(output, target)
                iou = iou_score(output, target)

            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))

            postfix = OrderedDict([
                ('loss', avg_meters['loss'].avg),
                ('iou', avg_meters['iou'].avg),
            ])
            pbar.set_postfix(postfix)
            pbar.update(1)
        pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])


def main():

    """
    这是一个将命令行参数解析为字典的操作,其中parse_args()是一个用于解析命令行参数的函数。
    vars()函数将解析出来的命令行参数对象转换为字典类型。
    这个操作的结果是将命令行参数以字典形式存储在config中。
    """
    config = vars(parse_args())

    """
    这段代码是为了创建用于保存模型的文件夹。首先判断配置中是否指定了模型名称,如果没有指定,则根据数据集和网络结构来生成默认的模型名称。
    如果配置中开启了深度监督(deep_supervision),则模型名称为"数据集_网络结构_wDS";如果没有开启深度监督,则模型名称为"数据集_网络结构_woDS"。
    然后使用`os.makedirs`函数创建以模型名称和时间命名的文件夹,如果文件夹已存在则不会重复创建。
    """
    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    folder_name = f"models/{config['name']}_{timestamp}"
    os.makedirs(folder_name, exist_ok=True)

    # 添加tensorBoard
    writer = SummaryWriter(f"{folder_name}/SummaryWriter")

    """
    这段代码会打印模型训练的配置信息。首先输出一条20个连字符"-",作为分割线。
    然后遍历配置字典中的所有键值对,逐个输出键和对应的值,格式为"键: 值"。
    输出完后再次输出一条20个连字符"-"的分割线。
    这样可以方便用户在控制台上查看和核对模型训练的配置信息。
    """
    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    """
    这段代码将模型的配置信息保存为YAML格式的文件。
    使用`open`函数打开一个文件,文件路径是根据模型名称组合而成的,路径为"models/模型名称/config.yml"。
    然后使用`yaml.dump`函数将配置信息写入文件中,以YAML格式保存。
    存入config.yml文件中
    这样可以方便后续查看和恢复模型训练的配置信息。
    """
    with open('%s/config.yml' % folder_name, 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    """
    这段代码根据配置中的损失函数类型来创建相应的损失函数实例。
        首先判断配置中的损失函数类型是否为'BCEWithLogitsLoss',如果是,则使用`nn.BCEWithLogitsLoss()`来创建二分类交叉熵损失函数实例,并将其移动到指定的设备上。
        如果配置中的损失函数类型不是'BCEWithLogitsLoss',则通过`losses.__dict__[config['loss']]()`来动态获取指定名称的损失函数类,并使用该类创建损失函数实例。然后将损失函数实例移动到指定的设备上。
    这样可以根据配置中的损失函数类型灵活地选择和使用不同的损失函数。
    """
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().to(device)  # WithLogits 就是先将输出结果经过sigmoid再交叉熵
    else:
        # `losses.__dict__` 是一个字典,包含了当前上下文中所有可用的损失函数类。
        # `config['loss']` 是配置中指定的损失函数类型的字符串。
        # 通过 `losses.__dict__[config['loss']]` 可以获取到对应的损失函数类。
        # 然后使用 `()` 运算符创建该损失函数类的实例,并将其移动到指定的设备上(`to(device)`)。
        criterion = losses.__dict__[config['loss']]().to(device)

    """
    这行代码是用于设置cuDNN的benchmark模式为True。
    cuDNN(CUDA Deep Neural Network library)是一个针对深度神经网络计算的GPU加速库。
    cuDNN的benchmark模式可以自动寻找最适合当前硬件配置的卷积实现算法,并在训练过程中进行优化,提高运行效率。
    通过将`cudnn.benchmark`设置为True,可以启用cuDNN的benchmark模式,从而使得每次运行时都会重新评估和选择最佳的卷积实现算法,进而提高训练速度。
    需要注意的是,在某些情况下,benchmark模式可能会导致不确定性,因为每次运行时都会选择不同的实现算法。
        因此,如果要求结果的一致性比速度更重要,则可以将benchmark模式设置为False。
    而且:
        如果没有使用GPU加速,那么设置`cudnn.benchmark = True`不会产生任何效果,因为cuDNN库只能在GPU上使用。
        该代码的作用是启用cuDNN的benchmark模式来优化卷积计算,在使用GPU进行深度学习训练时可以提高训练速度。
        如果没有使用GPU,则无法享受到cuDNN带来的加速效果,即使设置了`cudnn.benchmark = True`也不会有任何效果。
    """
    if torch.cuda.is_available():
        cudnn.benchmark = True


    # create model
    print("=> creating model %s" % config['arch'])
    # 通过 archs.__dict__[config['arch']] 可以获取到指定名称的模型结构类。
    # 然后调用该模型结构类的构造函数,并传入相应的参数:
    #   config['num_classes'] 表示分类任务的类别数,
    #   config['input_channels'] 表示输入数据的通道数,
    #   config['deep_supervision'] 表示是否使用深度监督(即是否在网络中添加多个分支用于不同层次的特征提取),
    # 并将创建好的模型实例赋值给变量 model。
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    model = model.to(device)

    # 这段代码是根据配置文件中指定的优化器和学习率调度器类型,创建对应的优化器和学习率调度器。
    params = filter(lambda p: p.requires_grad, model.parameters())# 对模型的参数进行过滤,只选择需要梯度更新的参数,并将其赋值给变量 params。
    """
    根据配置文件中的 config['optimizer'] 来选择使用哪种优化器。
        如果是 "Adam",则创建一个 Adam 优化器,使用 optim.Adam() 函数,
            并传入相应的参数:params、lr=config['lr'](学习率)和 weight_decay=config['weight_decay'](权重衰减)。
        如果是 "SGD",则创建一个 SGD 优化器,使用 optim.SGD() 函数,
            并传入相应的参数:params、lr=config['lr']、momentum=config['momentum'](动量)和 weight_decay=config['weight_decay']。
        如果既不是 "Adam" 也不是 "SGD",则抛出一个 NotImplementedError 异常。
    """
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                              nesterov=config['nesterov'], weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    """
    根据配置文件中的 config['scheduler'] 来选择使用哪种学习率调度器。
        如果是 "CosineAnnealingLR",则创建一个 CosineAnnealingLR 调度器,使用 lr_scheduler.CosineAnnealingLR() 函数,
            并传入相应的参数:optimizer、T_max=config['epochs'](总的训练轮数)和 eta_min=config['min_lr'](学习率的最小值)。
        如果是 "ReduceLROnPlateau",则创建一个 ReduceLROnPlateau 调度器,使用 lr_scheduler.ReduceLROnPlateau() 函数,
            并传入相应的参数:optimizer、factor=config['factor'](学习率缩放因子)、
                patience=config['patience'](在验证集上等待多少个epoch之后,学习率开始下降)、
                verbose=1(是否打印日志信息)和 min_lr=config['min_lr'](学习率的最小值)。
        如果是 "MultiStepLR",则创建一个 MultiStepLR 调度器,使用 lr_scheduler.MultiStepLR() 函数,
            并传入相应的参数:optimizer、milestones=[int(e) for e in config['milestones'].split(',')](在哪些epoch时学习率进行调整)
                和 gamma=config['gamma'](学习率缩放因子)。如果是 "ConstantLR",则不使用学习率调度器,将 scheduler 设置为 None。
        如果既不是 "CosineAnnealingLR" 也不是 "ReduceLROnPlateau" 也不是 "MultiStepLR" 也不是 "ConstantLR",则抛出一个 NotImplementedError 异常。
    """
    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],
                                                   verbose=1, min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')],
                                             gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    # Data loading code
    """
    这行代码使用glob函数获取图像文件的路径。
        os.path.join('inputs', config['dataset'], 'images')用于生成图像文件的目录路径,
            其中'inputs'是根目录,config['dataset']是数据集名称,'images'是图像文件所在的子目录。
        '*' + config['img_ext']表示以任意字符开头,后面跟着图像文件扩展名。config['img_ext']是配置参数中指定的图像文件扩展名。
    最终,glob函数返回匹配指定模式的所有文件路径,并保存在img_ids列表中。每个文件的路径是相对于当前工作目录的路径。
    """
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    """
    这行代码对图像文件的路径进行处理,提取出图像文件的名称(去掉扩展名),并将名称保存在img_ids列表中。
    具体来说,代码通过遍历img_ids列表中的每个图像文件路径,
        使用os.path.basename(p)函数获取文件的基本名称(包括扩展名),
        然后使用os.path.splitext()函数将基本名称分割成文件名和扩展名的元组。
        由于我们只需要文件名部分,因此使用[0]索引获取文件名,并将文件名添加到img_ids列表中。
    最终,img_ids列表中保存的是图像文件的名称(不包含扩展名),用于后续的数据集划分或其他操作。
    """
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    # 将img_ids列表按照指定的比例划分为训练集和验证集,并将结果保存在train_img_ids和val_img_ids列表中。
    train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)
    # 数据增强:需要安装albumentations包
    train_transform = Compose([
        # 角度旋转
        albu.RandomRotate90(),
        # 图像翻转
        albu.Flip(),
        OneOf([
            transforms.HueSaturationValue(),
            transforms.RandomBrightness(),
            transforms.RandomContrast(),
        ], p=1),  # 按照归一化的概率选择执行哪一个
        # 将图像大小调整为模型可接受的输入尺寸。
        albu.Resize(config['input_h'], config['input_w']),
        # 对图像进行归一化处理,将图像像素值缩放到0到1之间。
        albu.Normalize(),
    ])

    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        albu.Normalize(),
    ])

    """
    通过Dataset类创建了训练集对象train_dataset。
        其中,img_ids参数是训练图像的名称列表(不包含扩展名),
        img_dir参数是训练图像文件存储的目录路径,
        mask_dir参数是训练标签(掩膜)文件存储的目录路径,
        img_ext参数是图像文件的扩展名,
        mask_ext参数是标签文件的扩展名,
        num_classes参数是数据集的类别数量,这里的分类数是不包括背景在内的。
        transform参数是对训练图像进行的数据增强变换操作流水线。
    """
    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=train_transform)
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,# 表示在每个 epoch 开始时是否对数据进行洗牌,即打乱顺序。
        num_workers=config['num_workers'],
        drop_last=True)  # 不能整除的batch是否就不要了
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    """
    这段代码定义了一个有序字典 log,用于保存训练和验证过程中的一些指标值。
    log 字典中包含了以下键值对:
        'epoch':用于保存每个 epoch 的编号。
        'lr':用于保存每个 epoch 的学习率。
        'loss':用于保存每个 epoch 的训练集损失。
        'iou':用于保存每个 epoch 的训练集 Intersection over Union (IoU) 指标值。
        'val_loss':用于保存每个 epoch 的验证集损失。
        'val_iou':用于保存每个 epoch 的验证集 IoU 指标值。
    最后存入log.csv文件中
    """
    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('val_loss', []),
        ('val_iou', []),
    ])

    start_time = time.time()
    best_iou = 0
    """
    trigger 的作用是用来判断是否触发早期停止(early stopping)。
        早期停止是一种常用的训练策略,它可以在验证集的性能不再提升时停止训练,以防止过拟合并节省计算资源。
    在这段代码中,
        trigger 的初始值为0。
        每当验证集的 IoU 指标不再提升时,trigger 的值就会增加1。
        如果 trigger 的值大于等于设定的早期停止阈值 config['early_stopping'],则会触发早期停止,训练循环会被中断,停止训练过程。
    通过使用 trigger 变量,可以在验证集性能不再提升时自动结束训练,避免过拟合并提高训练效率。
    """
    trigger = 0
    for epoch in range(config['epochs']):
        epoch = epoch + 1
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        # evaluate on validation set
        val_log = validate(config, val_loader, model, criterion)

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
              % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))
        writer.add_scalars("loss",
                           {"train_loss": train_log['loss'],
                            "test_loss": val_log['loss']},
                           epoch)
        writer.add_scalars("iou",
                           {"train_iou": train_log['iou'],
                            "test_iou": val_log['iou']},
                           epoch)

        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])

        """
        这行代码将一个名为 `log` 的字典转换为 DataFrame,并将其保存为 CSV 文件。具体来说,它执行以下操作:
            1. 将 `log` 字典传递给 `pd.DataFrame()` 函数,将其转换为一个 DataFrame 对象。
            2. 使用 `to_csv()` 方法将 DataFrame 对象保存为 CSV 文件。
                - 参数 `'models/%s/log.csv' % config['name']` 是文件路径和名称的格式字符串,
                    其中 `%s` 会被 `config['name']` 的值替代。
                    例如,如果 `config['name']` 的值是 `"model1"`,那么保存的文件路径就是 `'models/model1/log.csv'`。
                - 参数 `index=False` 表示不将 DataFrame 的索引写入 CSV 文件。
        这段代码的目的是将训练过程中的日志信息保存为 CSV 文件,以便后续分析和可视化。
        """
        pd.DataFrame(log).to_csv('%s/log.csv' %
                                 folder_name, index=False)

        trigger += 1

        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(), '%s/model.pth' %
                       folder_name)
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0

        # early stopping
        """
        config['early_stopping']的默认值是-1,不满足第一个if条件,不触发早停策略。
        通常情况下,早期停止的阈值应该是一个非负整数,表示在多少个连续的验证集性能不再提升时触发停止训练。
            当将早停阈值设为-1时,意味着不使用早期停止策略,训练过程将一直进行下去,直到达到指定的训练轮数(config['epochs'])为止,或者手动停止训练。
        禁用早期停止可能会导致模型在训练过程中过拟合,因为它没有根据验证集的性能动态调整训练的停止时机。
        因此,建议在实际训练中根据需要设置合适的早停阈值,以避免过拟合和节省计算资源。
        """
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break

        end_time = time.time()
        print("本轮截至运行时间:",end_time - start_time)

        # 清显存,保证运行速率
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

if __name__ == '__main__':
    main()

接下来

  • 对真正的测试集,也就是没有标签数据的测试集做预测,并生成掩膜,放到val.py
  • 10
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值