使用UNet+mmsegmentation完成全景X射线图像牙齿分割任务(阿里天池题目)

在这里插入图片描述

最近在寻找分割练手数据集的时候正好在阿里天池上看到了牙齿分割任务的数据集和代码。

于是当作练习mmsegmentation的练手项目,并且写下完整的文档来记录学习过程。

为什么要用框架

框架的作用相当于是别人写好了一堆的模型,你只需要用命令调用他们完成任务就行。当然你也可以直接去写自己的模型放到框架里跑,还可以使用框架中写好的一些图像处理算法。

其实是可以直接从python写分割任务的代码的,但是使用框架的好处是能够很方便的测试多个主流的分割模型在任务上的表现。

这里我使用的是mmsegmentation,官方github代码:https://github.com/SwanHubX/mmsegmentation

环境安装

参考我另一篇博客从CUDA开始的mmsegmentation安装教程(尽量不受版本影响)

当然更推荐大家直接跟着官方教程走mmsegmentation官方教程

准备天池数据

天池的官方提供了下载链接

在这里插入图片描述

一共四个压缩包,这里我们仅用初赛的压缩包train.zip作为案例教程。

下载完成后接下来就是解压,转换为mmsegmentation的格式,官方格式的样式可以参考官方文档

这边提供了自动解压+转换为mmsegmentation通用格式的脚本:

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
from mmengine.utils import mkdir_or_exist

from tqdm import tqdm


def parse_args():
    parser = argparse.ArgumentParser(
        description="Convert MICCAI dataset to mmsegmentation format"
    )
    parser.add_argument("dataset_path", help="path of MICCAI train.zip)")
    parser.add_argument("--tmp_dir", help="path of the temporary directory")
    parser.add_argument("-o", "--out_dir", default=None, help="output path")
    parser.add_argument(
        "-s", "--split_rate", type=float, help="splite rate for train val"
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join("data", "miccai")
    else:
        out_dir = args.out_dir

    print("Making directories...")
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, "images"))
    mkdir_or_exist(osp.join(out_dir, "images", "training"))
    mkdir_or_exist(osp.join(out_dir, "images", "validation"))
    mkdir_or_exist(osp.join(out_dir, "annotations"))
    mkdir_or_exist(osp.join(out_dir, "annotations", "training"))
    mkdir_or_exist(osp.join(out_dir, "annotations", "validation"))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        print("Extracting train.zip...")
        zip_file = zipfile.ZipFile(dataset_path)
        zip_file.extractall(tmp_dir)

        # split
        train_img_path = osp.join(tmp_dir, "train/image")
        train_anno_path = osp.join(tmp_dir, "train/mask")
        image_num = len(os.listdir(train_img_path))
        training_len = int(image_num * args.split_rate)

        print("Processing train data...")

        for img_name in tqdm(sorted(os.listdir(train_img_path))[:training_len]):
            img = mmcv.imread(osp.join(train_img_path, img_name))
            mmcv.imwrite(
                img,
                osp.join(
                    out_dir,
                    "images",
                    "training",
                    osp.splitext(img_name)[0] + ".png",
                ),
            )
        print("Processing train anno data...")
        for img_name in tqdm(sorted(os.listdir(train_anno_path))[:training_len]):
            # The annotation img should be divided by 128, because some of
            # the annotation imgs are not standard. We should set a
            # threshold to convert the nonstandard annotation imgs. The
            # value divided by 128 is equivalent to '1 if value >= 128
            # else 0'
            img = mmcv.imread(osp.join(train_anno_path, img_name))
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(
                    out_dir,
                    "annotations",
                    "training",
                    osp.splitext(img_name)[0] + ".png",
                ),
            )
        print("Processing val data...")
        for img_name in tqdm(sorted(os.listdir(train_img_path))[training_len:]):
            img = mmcv.imread(osp.join(train_img_path, img_name))
            mmcv.imwrite(
                img,
                osp.join(
                    out_dir,
                    "images",
                    "validation",
                    osp.splitext(img_name)[0] + ".png",
                ),
            )
        print("Processing val anno data...")
        for img_name in tqdm(sorted(os.listdir(train_anno_path))[training_len:]):
            img = mmcv.imread(osp.join(train_anno_path, img_name))
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(
                    out_dir,
                    "annotations",
                    "validation",
                    osp.splitext(img_name)[0] + ".png",
                ),
            )
        print("Removing the temporary files...")
    print("Done!")


if __name__ == "__main__":
    main()

按照mmsegmentation的官方建议,建议将上述代码命名并保存在<mmsegmentation root>/tools/dataset_converters/miccai.py

使用如下命令完成转换

python tools/dataset_converters/miccai.py <train.zip的路径> -s <训练集和测试集的切分比例,推荐为0.7>

完成后即可在<mmsegmentation root>/data/miccai/文件夹中看到转换的结果

在这里插入图片描述

实现mmsegmentation数据读取代码

这里可以安装mmsegmentation的官方教程实现,此处不多赘述,就把我实现的代码和路径粘贴出来,路径我统一使用相对于mmsegmentation项目的相对路径。

实现一个简单mmsegmentation数据读取类的代码,代码位置:./mmseg/datasets/miccai.py

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class MICCAIDataset(BaseSegDataset):

    METAINFO = dict(
        classes=("unlabelled", "tooth"), palette=[[120, 120, 120], [6, 230, 230]]
    )

    def __init__(
        self,
        img_suffix=".png",
        seg_map_suffix=".png",
        reduce_zero_label=False,
        **kwargs
    ) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs
        )

注册该数据读取类,在./mmseg/datasets/__init__.py 最后一行添加代码

from .miccai import MICCAIDataset
__all__.append("MICCAIDataset")

以UNet为例对牙齿数据进行分割

在这里插入图片描述

这里境以Unet为例,创建一个文件在configs/unet/unet-s5-d16_fcn_4xb4-20k_miccai-128x128.py

代码如下:

_base_ = [
    "../_base_/datasets/miccai_128x128.py",
    "../_base_/default_runtime.py",
    "../_base_/schedules/schedule_20k.py",
]
# 数据集配置
# dataset settings
dataset_type = "MICCAIDataset"
data_root = "data/miccai"
img_scale = (640, 320)
crop_size = (128, 128)
train_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type="LoadAnnotations"),
    dict(type="RandomResize", scale=img_scale, ratio_range=(0.5, 2.0), keep_ratio=True),
    dict(type="RandomCrop", crop_size=crop_size, cat_max_ratio=0.75),
    dict(type="RandomFlip", prob=0.5),
    dict(type="PhotoMetricDistortion"),
    dict(type="PackSegInputs"),
]
test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type="Resize", scale=img_scale, keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type="LoadAnnotations"),
    dict(type="PackSegInputs"),
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type="LoadImageFromFile", backend_args=None),
    dict(
        type="TestTimeAug",
        transforms=[
            [dict(type="Resize", scale_factor=r, keep_ratio=True) for r in img_ratios],
            [
                dict(type="RandomFlip", prob=0.0, direction="horizontal"),
                dict(type="RandomFlip", prob=1.0, direction="horizontal"),
            ],
            [dict(type="LoadAnnotations")],
            [dict(type="PackSegInputs")],
        ],
    ),
]

train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type="InfiniteSampler", shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path="images/training", seg_map_path="annotations/training"
        ),
        pipeline=train_pipeline,
    ),
)
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type="DefaultSampler", shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path="images/validation", seg_map_path="annotations/validation"
        ),
        pipeline=test_pipeline,
    ),
)
test_dataloader = val_dataloader

val_evaluator = dict(type="IoUMetric", iou_metrics=["mIoU", "mDice"])
test_evaluator = val_evaluator

# 模型设置
crop_size = (128, 128)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    test_cfg=dict(crop_size=(128, 128), stride=(85, 85)),
)

# SwanLab
custom_imports = dict(
    imports=["swanlab.integration.mmengine"], allow_failed_imports=False
)
import sys

vis_backends = [
    dict(
        type="SwanlabVisBackend",
        save_dir="runs/swanlab",
        init_kwargs={
            "project": "MICCAI",
            "experiment_name": "unet baseline",
            "description": "unet-s5-d16_fcn_4xb4-20k_miccai-128x128.py"
        },
    ),
]

visualizer = dict(
    type="Visualizer",
    vis_backends=vis_backends,
)

这里分享我实现的config文件,功能包括:

  1. 读取牙齿图片,随机crop成128x128分辨率,使用一些随机翻转,增强等。
  2. 使用U2Net网络进行训练,训练20k个epoch,batch数是4
  3. 计算测试指标:mIoU和Dice
  4. 配置文件使用了swanlab作为日志跟踪工具,一个开源的在线日志记录工具:https://github.com/SwanHubX/SwanLab

SwanLab安装+配置

在这里插入图片描述

SwanLab是一个开源的在线日志记录软件,可以参考SwanLab快速上手(Wandb国产平替)安装,这里附上安装命令,方便我们等等查看训练进度:

pip install -U swanlab

github地址:https://github.com/SwanHubX/SwanLab

官方网站:https://swanlab.cn/

训练以及结果

使用如下命令开启训练

单卡训练:

python tools/train.py configs/unet/unet-s5-d16_fcn_4xb4-20k_miccai-128x128.py

4卡训练:

CUDA_VISIBLE_DEVICES=0,1,2,3 bash tools/dist_train.sh configs/unet/unet-s5-d16_fcn_4xb4-20k_miccai-128x128.py 4

如果你选择了SwanLab作为可视化工具,会出现下面的提示你粘贴Key的情况

在这里插入图片描述

可以在这里获得Key来让训练日志可以在线观看:https://swanlab.cn/settings

在这里插入图片描述

填写完成后便可以成功训练

在这里插入图片描述

训练日志可以在swanlab.cn上看到:

请添加图片描述

  • 10
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值