UNet - 代码的深入理解 - 入门2

目录

前言 - 基础知识

代码讲解

UNet模型

1. unet_parts.py

2. unet_model.py

模型训练

train.py

模型预测

predict.py

调参优化


前言 - 基础知识

这一部分,放一些基础知识,防止看到后面,会有术语看不懂或者因果不理解的情况(标红的字)

CNN基础知识——卷积(Convolution)、填充(Padding)、步长(Stride)

感受野是卷积神经网络(CNN)每一层输出的特征图(feature map)上的像素点在原始输入图像上映射的区域大小

批归一化的目的是通过对每个批次的数据进行归一化,使得数据的分布更加稳定,有助于提高模型的训练稳定性和收敛速度

激活函数的作用是对神经网络的输入进行转换,生成神经元的输出。它通过对输入进行非线性操作,引入非线性关系,使得神经网络能够学习和表示更加复杂的函数。

激活函数通常应用于神经网络的隐藏层和输出层。常见的激活函数包括:

  1. Sigmoid 函数(Logistic 函数):将输入映射到 (0, 1) 区间,具有平滑的 S 形曲线。
  2. ReLU 函数(Rectified Linear Unit):在输入大于零时,输出与输入相等;在输入小于等于零时,输出为零。
  3. Tanh 函数(双曲正切函数):将输入映射到 (-1, 1) 区间,具有平滑的 S 形曲线,输出范围比 Sigmoid 函数更广。
  4. Leaky ReLU 函数:在输入小于零时,输出为输入的一个小的斜率乘积;在输入大于零时,输出与输入相等。
  5. Softmax 函数:用于多分类问题,将输入向量转换为概率分布,使得各个输出的和为1。

双线性插值是一种图像处理技术,用于改变图像的大小,可以帮助我们在改变图像大小时保持图像的质量和细节,并且生成平滑的结果。

具体来说,双线性插值会找到新像素点周围的四个最近邻像素点,然后根据它们之间的距离和权重,计算出新像素点的值。这样可以保持图像的平滑性和连续性,避免出现锯齿状或失真的效果。

空间维度是指图像的宽度、高度和通道数。图像可以被看作是一个二维矩阵,其中每个元素代表一个像素值。图像的宽度和高度描述了图像的大小,而通道数表示图像中的颜色通道或特征通道的数量。

例如,在一个RGB彩色图像中,宽度和高度确定了图像的尺寸,而通道数为3,分别对应红色、绿色和蓝色通道。空间维度在卷积和转置卷积等操作中起着重要的作用,它们决定了卷积核的大小和移动步长,以及输出特征图的尺寸。

法线卷积是常规的卷积操作,通过滑动一个卷积核在输入上进行加权求和来生成输出特征图,常用于提取特征

转置卷积用于图像的上采样或还原操作,通过反向操作,使用卷积核对输入进行卷积,以扩大图像的尺寸或增加分辨率,常用于上采样图像

代码讲解

上一篇讲,如何先将UNet跑起来,接下来,还需要学习它的代码

UNet模型

1. unet_parts.py

        首先,找到项目目录下的unet/unet_parts.py,这是构成整个UNet网络的各个小组成部分。

        这些自定义模块的作用是构建UNet模型中的不同部分,如双卷积模块、下采样模块、上采样模块和输出卷积模块。它们在UNet模型的整体构建中起到组织和连接不同层的作用,帮助实现图像分割的功能。

        完整代码这里写的很好https://github.com/milesial/Pytorch-UNet/tree/master,后文只展示部分代码

        代码中有详细注释序号,和 tips 对应来看,要仔细看!!!!!!

  • class DoubleConv(nn.Module):

        双卷积模块(特征提取),用于执行两个连续的卷积操作。这个模块采用了一个3x3的卷积层,然后使用批归一化和ReLU激活函数来增加非线性性。接下来,又使用另一个3x3的卷积层进行进一步的特征提取

class DoubleConv(nn.Module):
    # (卷积 => [BN] => ReLU) * 2
    # 1
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # 2
        if not mid_channels:
            mid_channels = out_channels
        # 3
        self.double_conv = nn.Sequential(
            # 4
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            # 5
            nn.BatchNorm2d(mid_channels),
            # 6
            nn.ReLU(inplace=True),
            # 同上
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    # 7
    def forward(self, x):
        return self.double_conv(x)

tips:

  1. __init__ 方法是该模块的初始化函数。包含三个参数:三个参数:in_channels 输入通道数,out_channels 输出通道数,mid_channels 中间通道数
  2. 在初始化过程中,如果 mid_channels 未指定,则将其设置为 out_channels。
  3. self.double_conv 是一个 nn.Sequential 对象,它按顺序包含了一系列的卷积、批归一化和ReLU操作。这些操作被组合在一起形成双卷积结构。
  4. nn.Conv2d:二维卷积层,使用 in_channels 和 mid_channels 进行卷积操作。采用 3x3 的卷积核,padding 为 1,表示在输入边界进行零填充,使得输出的特征图尺寸与输入相同。bias 参数设置为 False,表示不使用偏置项。
  5. nn.BatchNorm2d:二维批归一化层,对卷积层的输出进行批归一化操作
  6. nn.ReLU:ReLU 激活函数,引入非线性性质,增加模型的表达能力。
  7. forward 方法定义了模块的前向传播过程。它接受输入 x,并将其传递给 self.double_conv,然后返回双卷积操作的结果。
  • class Down(nn.Module):

        下采样模块(减小图像尺寸和增加感受野),用于对输入图像进行下采样。它首先通过最大池化操作将图像的尺寸减半,然后使用 DoubleConv 模块进行特征提取,这有助于网络捕捉不同尺度的特征信息,并在后续的上采样过程中进行恢复和重建。

class Down(nn.Module):
    # 使用最大池缩小规模,然后double conv
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1
        self.maxpool_conv = nn.Sequential(
            # 2
            nn.MaxPool2d(2),
            # 3
            DoubleConv(in_channels, out_channels)
        )
    # 4
    def forward(self, x):
        return self.maxpool_conv(x)

tips:

  1. __init__ 方法中,我们首先定义了一个 maxpool_conv 属性,它是一个包含两个操作的序列。
  2.  nn.MaxPool2d(2),最大池化操作,将输入的特征图的尺寸缩小一半。
  3.  DoubleConv(in_channels, out_channels),它是在前面定义的 DoubleConv 自定义模块,用于进行两次卷积操作。
  4. forward 方法中,我们将输入 x 传递给 maxpool_conv,即先执行最大池化操作,然后通过 DoubleConv 进行两次卷积操作。最后,返回经过下采样后的特征图。
  • class Up(nn.Module):

        上采样模块(恢复图像尺寸和细化特征),用于对特征图进行上采样。如果使用双线性插值进行上采样(bilinear=True),则会使用 nn.Upsample 来将特征图的尺寸放大两倍。然后,使用 DoubleConv 模块将上采样后的特征图与相应的下采样层的特征图进行连接,并进行特征提取。

class Up(nn.Module):
    # 升级然后double conv
    # 1
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # 2
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # 3
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # 4
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            # 5
            self.conv = DoubleConv(in_channels, out_channels)
    # 6
    def forward(self, x1, x2):
        # 7
        x1 = self.up(x1)
        # 8
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # 9
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # 10
        x = torch.cat([x2, x1], dim=1)
        # 11
        return self.conv(x)

 tips:

  1. 初始化函数,接收输入通道数 in_channels、输出通道数 out_channels 和是否使用双线性插值的标志 bilinear。
  2. 使用双线性插值的上采样层,将特征图的大小沿着空间维度放大两倍。
  3. 使用法线卷积来减少通道数,将输入特征图的通道数减半,并进行两次卷积操作。
  4. 使用转置卷积进行上采样,将特征图的大小沿着空间维度放大两倍。
  5. 进行两次卷积操作。
  6. 前向传播函数,接收两个输入特征图 x1 和 x2。
  7. 对输入特征图 x1 进行上采样。
  8. 计算 x1 和 x2 在空间维度上的差异。
  9. 对 x1 进行边界填充,使其与 x2 的大小一致。
  10. 将 x2 和 x1 沿着通道维度拼接起来,得到融合后的特征图 x。
  11. 对融合后的特征图 x 进行两次卷积操作,输出最终的特征图结果。
  • class OutConv(nn.Module):

        输出卷积模块(生成最终分割结果),它是UNet模型中最后一层的卷积操作,用于生成最终的分割结果。它使用一个1x1的卷积层将上采样后的特征图映射到所需的输出通道数(即类别数)。

# 1
class OutConv(nn.Module):
    # 2
    def __init__(self, in_channels, out_channels):
        # 3
        super(OutConv, self).__init__()
        # 4
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    # 5
    def forward(self, x):
        # 6
        return self.conv(x)

 tips:

  1. 定义了一个名为 OutConv 的类,它继承自 nn.Module
  2. 初始化方法,接收输入通道数 in_channels 和输出通道数 out_channels
  3. 调用父类的初始化方法,确保正确地初始化父类的属性。
  4. 创建一个二维卷积层 nn.Conv2d,输入通道数为 in_channels,输出通道数为 out_channels,卷积核大小为 1x1,也就是对每个像素进行一个线性变换。
  5. 前向传播方法,定义了在输入 x 上执行的计算。
  6. 通过调用卷积层 self.conv 对输入 x 进行卷积操作,并将结果作为输出返回。

这一部分代码学完啦,接下来看 unet/unet_model.py(实现完整的UNet网络模型)

2. unet_model.py

        首先,导入刚刚写好的 unet_parts 模块中的所有内容,这个模块包含了定义 U-Net 模型的各个部分

from .unet_parts import *


class UNet(nn.Module):
    # 1
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        # 2
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # 3
        self.inc = (DoubleConv(n_channels, 64))
        # 4
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        # 5
        factor = 2 if bilinear else 1
        # 6
        self.down4 = (Down(512, 1024 // factor))
        # 7
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        # 8
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        # 9
        x1 = self.inc(x)
        # 10
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # 11
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        # 12
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)

 tips:

  1. 输入通道数 n_channels、输出类别数 n_classes 
  2. 将输入通道数、输出类别、是否使用双线性插值,分别存储在实例变量 n_channels、n_classes、bilinear
  3. 创建一个 DoubleConv 实例作为模型的输入层,并将其存储在实例变量 inc 中。输入通道数为 n_channels,输出通道数为 64
  4. 创建一个 Down 实例作为模型的下采样层1,并将其存储在实例变量 down1 中。输入通道数为 64,输出通道数为 128,.........
  5. 如果使用双线性插值,则缩放因子为 2,否则为 1
  6. 创建一个 Down 实例作为模型的下采样层4,并将其存储在实例变量 down4 中。输入通道数为 512,输出通道数为 1024//factor
  7. 创建一个 Up 实例作为模型的上采样层1,并将其存储在实例变量 up1 中。输入通道数为 1024,输出通道数为 512//factor,..........
  8. 创建一个 OutConv 实例作为模型的输出层,并将其存储在实例变量 outc 中。输入通道数为 64,输出通道数为 n_classes
  9. 将输入 x 传递给输入层 inc,得到输出 x1
  10. x1 传递给下采样层1 down1,得到输出 x2
  11. x5x4 传递给上采样层1 up1,得到输出 x
  12. x 传递给输出层 outc,得到模型的最终的输出 logits
  13. 每个组件(incdown1down2....)都通过 torch.utils.checkpoint 方法进行了修改。使得模型可以在需要时生成梯度,无需在整个前向传播过程中保留中间结果,从而降低内存占用并提高计算效率。

模型训练

train.py

首先是导入的包和模块,以及定义的文件路径

接下来是模型训练主函数,传了一些参数(模型本身、设备类型、训练的轮数等),如下:

def train_model(
        model,
        device,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
):

        使用 CarvanaDataset 类加载数据集,如果出现异常,则使用 BasicDataset 类作为替代。这些类用于加载图像和目标遮罩数据

try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

根据指定的验证集比例,将数据集拆分为训练集和验证集

n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

使用 DataLoader 类创建了训练集和验证集的数据加载器,以便在训练过程中按批次加载数据

loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

 使用 wandb.init 函数初始化了一个 Weights & Biases 实验,并更新了一些实验配置参数,如下:

experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )

定义了优化器(使用 RMSprop)、学习率调度器、梯度缩放器(用于混合精度训练)和损失函数

optimizer = optim.RMSprop(model.parameters(),
                              lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

要开始训练啦!进入循环,对于每个时期,我们将模型设置为训练模式,并初始化时期的损失为0。我们使用 tqdm 进度条来可视化训练进度

循环训练过程包括:将批次数据移动到设备、计算模型输出和损失、进行梯度优化,以及记录训练过程中的指标和损失

for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images, true_masks = batch['image'], batch['mask']

进行断言检查以确保加载的图像通道数与模型定义的输入通道数相匹配。这有助于捕捉数据加载错误

assert images.shape[1] == model.n_channels, \
    f'Network has been defined with {model.n_channels} input channels, ' \
    f'but loaded images have {images.shape[1]} channels. Please check that ' \
    'the images are loaded correctly.'

将图像和目标遮罩移动到所选设备上,并指定了所需的数据类型

images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
true_masks = true_masks.to(device=device, dtype=torch.long)

在此部分,我们进入了自动混合精度计算损失的上下文管理器。在此上下文中,我们计算模型的输出 masks_pred 并计算损失。根据模型的类别数,我们使用不同的损失函数和评估指标

with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
    masks_pred = model(images)
    if model.n_classes == 1:
        loss = criterion(masks_pred.squeeze(1), true_masks.float())
        loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
    else:
        loss = criterion(masks_pred, true_masks)
        loss += dice_loss(
            F.softmax(masks_pred, dim=1).float(),
            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
            multiclass=True
        )

执行梯度清零、梯度缩放、梯度裁剪和优化器的更新操作。这些步骤用于反向传播和参数优化

optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
grad_scaler.step(optimizer)
grad_scaler.update()

更新进度条、全局步骤数、时期损失,并记录训练损失。同时,我们更新进度条的后缀以显示当前批次的损失,如下:

pbar.update(images.shape[0])
global_step += 1
epoch_loss += loss.item()
experiment.log({
    'train loss': loss.item(),
    'step': global_step,
    'epoch': epoch
})
pbar.set_postfix(**{'loss (batch)': loss.item()})

这部分代码计算了每个时期内进行验证评估的步骤数。它通过将训练数据集的大小除以(5 * 批次大小)来计算,并确保大于0。然后,我们检查当前全局步骤数是否是这个步骤数的倍数,以确定是否进行验证评估

division_step = (n_train // (5 * batch_size))
if division_step > 0:
    if global_step % division_step == 0:

创建了一个字典 histograms,用于存储模型参数的权重和梯度的直方图。我们遍历模型的命名参数,并检查参数值和梯度是否包含无穷大或NaN值。如果不包含,则将直方图添加到 histograms 字典中

histograms = {}
for tag, value in model.named_parameters():
    tag = tag.replace('/', '.')
    if not (torch.isinf(value) | torch.isnan(value)).any():
        histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
    if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
        histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

调用 evaluate 函数对验证集进行评估,并获得验证分数 val_score。然后,我们使用 scheduler 对学习率进行调度,以便根据验证分数的表现调整学习率

val_score = evaluate(model, val_loader, device, amp)
scheduler.step(val_score)

记录验证评估的结果和相关信息。我们使用日志记录器记录验证分数和学习率,并尝试记录一些图像和直方图数据,以供后续分析和可视化。这些信息被记录到实验中(例如使用 wandb 进行记录)

logging.info('Validation Dice score: {}'.format(val_score))
try:
    experiment.log({
        'learning rate': optimizer.param_groups[0]['lr'],
        'validation Dice': val_score,
        'images': wandb.Image(images[0].cpu()),
        'masks': {
            'true': wandb.Image(true_masks[0].float().cpu()),
            'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
        },
        'step': global_step,
        'epoch': epoch,
        **histograms
    })
except:
    pass

如果设置了保存检查点的标志,我们将保存模型的当前状态字典到文件中。同时,我们还保存了数据集的遮罩值,以便在以后使用检查点时能够正确解码遮罩。检查点文件以 checkpoint_epoch{}.pth 的格式命名,其中 {} 是当前时期的数字,如下:

if save_checkpoint:
    Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
    state_dict = model.state_dict()
    state_dict['mask_values'] = dataset.mask_values
    torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
    logging.info(f'Checkpoint {epoch} saved!')

通过 get_args() 函数获取命令行参数,并将其存储在 args 变量中。这些参数包括训练的总轮数、批大小、学习率等

def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()

如下这段代码的作用是将前面定义的模型、数据加载器、优化器等组件整合在一起,通过命令行参数来配置训练过程,并执行相应的训练任务。它提供了一个方便的入口,使得可以通过命令行来灵活地调整训练的各个参数

if __name__ == '__main__':
    args = get_args()
    # 1
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    # 2
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    # 3
    model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')
    # 4
    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        del state_dict['mask_values']
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')
    # 5
    model.to(device=device)
    try:
        # 6
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp
        )
    # 7
    except torch.cuda.OutOfMemoryError:
        logging.error('Detected OutOfMemoryError! '
                      'Enabling checkpointing to reduce memory usage, but this slows down training. '
                      'Consider enabling AMP (--amp) for fast and memory efficient training')
        torch.cuda.empty_cache()
        model.use_checkpointing()
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp
        )

 tips:

  1. 通过 get_args() 函数获取命令行参数,并将其存储在 args 变量中。这些参数包括训练的总轮数、批大小、学习率等
  2. 设置日志记录器的配置,将日志级别设置为 INFO,这样在程序执行过程中会输出一些有用的日志信息
  3. 通过判断是否有可用的 CUDA 设备,选择将模型运行在 CUDA 设备上,否则使用 CPU 设备
  4. 创建 UNet 模型对象,并根据命令行参数设置输入通道数、输出通道数以及是否使用双线性上采样
  5. 将模型移动到选择的设备上
  6. 调用 train_model 函数来训练模型。训练过程会根据命令行参数设置的训练轮数、批大小、学习率等进行训练。训练过程中会输出一些训练过程的日志信息,并在每个训练轮结束后进行验证,并记录验证结果
  7. 如果在训练过程中遇到了内存溢出的错误,程序会输出一条错误消息,并尝试使用检查点机制来减少内存使用。然后重新调用 train_model 函数来继续训练模型

模型预测

predict.py

这个代码就不放在这里展示了,主要讲一下大概做了什么

这段代码用于对输入的图像进行预测,并生成对应的预测掩码。它执行以下步骤:

  1. 导入所需的模块和类,包括 argparse、logging、os、numpy、torch、PIL、transforms 等。

  2. 定义了一个函数 predict_img,用于对输入图像进行预测并生成预测掩码。它接受以下参数:

    • net:UNet 模型对象。
    • full_img:完整的输入图像。
    • device:模型所在的设备。
    • scale_factor:图像的缩放因子。
    • out_threshold:输出掩码的阈值。 函数首先将输入图像进行预处理,并将其转换为张量格式,然后将其传递给模型进行预测。根据模型输出的通道数,将预测结果转换为掩码图像。最后,将掩码图像转换为 NumPy 数组并返回。
  3. 定义了一个函数 get_args,用于获取命令行参数。它使用 argparse 模块来解析命令行参数,并返回解析后的结果。

  4. 定义了一个函数 get_output_filenames,用于生成输出文件名列表。它根据输入文件名生成对应的输出文件名。

  5. 定义了一个函数 mask_to_image,用于将掩码转换为图像。根据掩码的维度和像素值的类型,创建一个与输入掩码相同大小的输出图像,并根据掩码的像素值对输出图像进行赋值。最后,将输出图像转换为 PIL.Image 对象并返回。

  6. if __name__ == '__main__': 块中,首先获取命令行参数并设置日志记录器的配置。

  7. 获取输入文件名列表和输出文件名列表。

  8. 创建 UNet 模型对象,并根据命令行参数设置输入通道数、输出通道数以及是否使用双线性上采样。

  9. 判断是否有可用的 CUDA 设备,选择将模型运行在 CUDA 设备上,否则使用 CPU 设备。

  10. 将模型加载到选择的设备上,并加载预训练模型的状态字典。

  11. 遍历输入文件名列表,对每个输入图像进行预测。

  12. 如果不禁止保存输出掩码,则将预测的掩码保存为图像文件。

  13. 如果启用可视化选项,则显示输入图像和预测掩码的可视化结果。

这段代码的作用是使用预训练的 UNet 模型对输入的图像进行预测,并将预测的掩码保存为图像文件。它提供了命令行接口,可以方便地指定输入图像、模型文件、输出文件等参数。预测过程可以在 GPU 上加速,同时支持可视化预测结果。

调参优化

这个放到下一篇文章讲,写不动了

白白~

  • 7
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值