第12章 PyTorch图像分割代码框架-2

模型模块

本书的第5-9章重点介绍了各种2D3D的语义分割和实例分割网络模型,所以在模型模块中,我们需要做的事情就是将要实验的分割网络写在该目录下。有时候我们可能想尝试不同的分割网络结构,所以在该目录下可以存在多个想要实验的网络模型定义文件。对于PASCAL VOC这样的自然数据集,我们可能想实验Deeplab v3+PSPNetRefineNet等网络的训练效果。代码11-3给出了Deeplab v3+网络封装后的主体部分,完整网络搭建代码可参考本书配套代码对应章节。

代码11-3 Deeplab v3+网络的主体部分

# 定义Deeplab V3+类
class DeepLabHeadV3Plus(nn.Module):
    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHeadV3Plus, self).__init__()


        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )
    # ASPP
        self.aspp = ASPP(in_channels, aspp_dilate)
    # classifier head
        self.classifier = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )


        self._init_weight()
  # forward method
    def forward(self, feature):
        # print(feature['low_level'].shape)
        # print(feature['out'].shape)
        low_level_feature = self.project(feature['low_level'])
        output_feature = self.aspp(feature['out'])
        output_feature = F.interpolate(
            output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))
  # weight initilize
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

对于复杂网络搭建,一般都是采用自下而上的搭建方法,先搭建底层组件,再逐步向上封装,对于本例中的Deeplab v3+,可以先分别搭建backbone骨干网络、ASPP和编解码结构,最后再进行封装。

工具函数模块

工具函数是为项目完成各项功能所自定义的辅助函数,可以统一定义在utils文件夹下,根据实际项目的不同,工具函数也各不相同。常用的工具函数包括各种损失函数的定义loss.py、训练可视化函数的定义visualize.py、用于记录训练日志的log.py等。代码11-4给出了一个关于Focal loss损失函数的定义,该损失函数作为工具函数可放在loss.py文件中。

代码11-4 工具函数示例:定义一个Focal loss

# 导入相关库
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义一个Focal loss类
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma


    def forward(self, inputs, targets):
        # Compute cross-entropy loss
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')


        # Compute the focal loss
        pt = torch.exp(-ce_loss)  
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()

配置模块

配置模块是为项目模型训练传入各种参数而进行设置的模块,比如训练数据所在目录、训练所需要的各种参数、训练过程是否需要可视化等。一般来说,我们有两种方式来对项目执行参数进行配置管理,一种是直接在主函数main.py中使用argparse库对参数进行配置,然后再命令行中进行传入;另一种则是单独定义一个config.py或者config.yaml文件来对所有参数进行统一配置。基于argparse库的参数配置管理简单示例如代码11-5所示。

代码11-5 argparser参数配置管理

# 导入argparse库
import argparse
# 创建参数管理器
parser = argparse.ArgumentParser()
# 涉及数据相关的参数管理
parser.add_argument("--data_root", type=str, default='./dataset',
                     help="path to Dataset")
parser.add_argument("--save_root", type=str, default='./',
                     help="path to save result")
parser.add_argument("--dataset", type=str, default='voc',
                     choices=['voc', 'cityscapes', 'ade'], help='Name of dataset')
parser.add_argument("--num_classes", type=int, default=None,
                     help="num classes (default: None)")

在上述代码中,我们基于argparse给出了一小部分参数配置管理代码,涉及训练数据相关的部分参数,包括数据读取路径、存放路径、训练所用数据集、分割类别数量等。

主函数模块

主函数模块main.py是项目的启动模块,该模块将定义好的数据和模型模块进行组装,并结合损失函数、优化器、评估方法和可视化等组件,将config.py中配置好的项目参数传入,根据训练-验证的模式,执行图像分割项目模型训练和验证。代码11-6VOC数据集训练验证部分代码。

代码11-6 主函数模块中的训练迭代部分

# 初始化区间损失
interval_loss = 0
while True:  
  # 执行训练
  model.train()
  cur_epochs += 1
  for (images, labels) in train_loader:
    cur_itrs += 1
    images = images.to(device, dtype=torch.float32)
    labels = labels.to(device, dtype=torch.long)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()


    np_loss = loss.detach().cpu().numpy()
    interval_loss += np_loss


    if vis is not None:
      vis.vis_scalar('Loss', cur_itrs, np_loss)
    # 打印训练信息
    if (cur_itrs) % opts.print_interval == 0:
      pass
    # 保存模型
    if (cur_itrs) % opts.val_interval == 0:
      pass
      # 日志记录
      logger.info("Save the latest model to %s" % save_path_checkpoints)
      # 模型验证
      print("validation...")
      model.eval()
      val_score, ret_samples = validate(
        opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
        ret_samples_ids=vis_sample_id)
      logger.info("Validation performance: %s", val_score)
      
      # 保存最优模型
      if val_score['mean_dice'] > best_score:  
        best_score = val_score['mean_dice']
        save_ckpt(os.path.join(save_path_checkpoints, 'best_%s_%s_os%d.pth' %
                     (opts.model, opts.dataset, opts.output_stride)))
        logger.info("Save best-performance model so far to %s" % save_path_checkpoints)


      # 训练过程可视化
      if vis is not None:  
        vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
        vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
        vis.vis_table("[Val] Class IoU", val_score['Class IoU'])


        for k, (img, target, lbl) in enumerate(ret_samples):
          img = (denorm(img) * 255).astype(np.uint8)
          target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
          lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
          concat_img = np.concatenate((img, target, lbl), axis=2)  
          vis.vis_image('Sample %d' % k, concat_img)
          
    scheduler.step()

在代码11-6中,我们展示了一个图像分割项目主函数模块中最核心的训练和验证部分。在训练时,按照指定迭代次数保存模型和对训练过程进行可视化展示。图11-2为训练打印的部分信息。

4b3c438a10649b752070ff5d0c7ce8fa.png

11-2 VOC训练过程信息

11-3为基于visdom的训练过程可视化展示,包括当前训练配置参数信息,训练损失函数变化曲线、验证集全局准确率、mIoU和类别IoU等指标变化曲线图。

9907f54a46275eafc2e600b66157fc84.png

11-3 Deeplab v3+训练过程可视化

11-4展示了两组训练过程中验证集的输入图像、标签图像和模型预测图像的对比图。可以看到,基于Deeplab v3+的分割模型在PASCAL VOC 2012上表现还不错。

dcb6cef8c0b0abdf98f2c64c97552bb0.png

11-4 验证集模型效果图

后续全书内容和代码将在github上开源,请关注仓库:

https://github.com/luwill/Deep-Learning-Image-Segmentation

(未完待续)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个基于Pytorch图像分割代码示例: ```python import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.datasets import Cityscapes from model import UNet # 设置超参数 num_epochs = 10 batch_size = 4 learning_rate = 0.001 # 加载Cityscapes数据集 train_dataset = Cityscapes(root='./data', split='train', mode='fine', target_type=['semantic'], transform=transforms.ToTensor()) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 定义UNet模型和损失函数 model = UNet(n_channels=3, n_classes=19) criterion = nn.CrossEntropyLoss() # 定义优化器 optimizer = optim.Adam(model.parameters(), lr=learning_rate) # 训练模型 for epoch in range(num_epochs): for i, (images, labels) in enumerate(train_loader): # 前向传播 outputs = model(images) loss = criterion(outputs, labels) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step() # 每100个iteration打印一次损失 if (i+1) % 100 == 0: print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, len(train_loader), loss.item())) # 保存模型 torch.save(model.state_dict(), 'unet.ckpt') ``` 这里使用了一个名为`UNet`的卷积神经网络模型进行图像分割,该模型的实现可以参考这个[代码库](https://github.com/milesial/Pytorch-UNet)。在训练过程中,我们使用了`Cityscapes`数据集,它是一个用于城市场景图像分割的数据集。在每个iteration结束后,我们打印了损失值以便于观察训练过程。最后,我们将训练好的模型保存到了`unet.ckpt`文件中。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值