使用 MMSegmentation 构建训练手写字迹分割模型

【语义分割】使用 MMSegmentation 构建训练手写字迹分割模型


仓库地址:

GitHub - open-mmlab/mmsegmentation: OpenMMLab Semantic Segmentation Toolbox and Benchmark.

MMSegmentation介绍

MMSegmentation 是一个用于语义分割的开源工具箱,它基于 PyTorch 实现。它是 OpenMMLab 项目的一部分。语义分割是将图像分割成属于同一对象类别的部分的任务,它是一个像素级别的预测形式,旨在为图像中的每个像素分配一个类别标签。

MMSegmentation 提供了一个统一的接口以及大量的预训练模型和配置文件,为研究者和开发者在各种标准数据集上进行语义分割实验提供便利。用户可以基于这些模型和配置文件来进行自己的训练和测试,或者对它们进行定制化的修改以适应特定需求。

通常而言,如果仅仅需要将手写内容与印刷内容分开,并不需要识别单独的手写笔迹,那么使用语义分割可能更简单直接。但如果要详细分析或进一步处理图像中每一个手写笔划(例如,提取个别签名或注释),实例分割可能更合适。选择哪种方法取决于具体的应用场景和需求。

如果目的是将手写内容与印刷内容分开,而不需要区分不同的手写实例,那么语义分割是更合适的选择。通过使用语义分割算法,可以为图像中的每个像素分配一个类别标签,从而创建一个二进制掩码,其中一个类别代表手写文本,另一个类别代表印刷文本。

语义分割算法将整张图像分为两部分:一部分包含所有手写字迹的像素,另一部分包含所有印刷体字迹的像素。这样,就可以获得一个清晰的区分,将手写内容从印刷内容中分割出来。这对于自动化文档处理、手写笔记识别以及历史文档的数字化都是有用的。在选择具体的模型和工具时,应该寻找那些专门为类似任务设计的语义分割算法。

**deeplabv3_unet_s5-d16**这个模型名中包含了几个部分,每个部分代表了不同的含义:

  • deeplabv3:这部分指的是网络的主体架构,DeepLabV3是一个用于图像语义分割的深度神经网络架构,它属于DeepLab系列模型中的第三代,特点是使用了Atrous Spatial Pyramid Pooling(ASPP)来捕捉多尺度的图像信息,并且提高了分割的精确度。
  • unet:这部分表示该模型结合了UNet架构的特点。UNet是一种常用于医学图像分割的网络结构,其特点是有一个编码器-解码器结构,能够在图像分割中提供很好的性能,并且对小量的数据有很好的泛化能力。
  • s5:这个参数一般表示模型中某个特定的设置或者是版本,可能代表着空洞卷积层的一个特定设置或者是网络的一个特殊的设计。不过这个参数并没有一个标准的定义,具体意义可能需要查阅相关的文档资料。
  • d16:通常这一部分代表的是网络中某个层使用的卷积核的密度或者步长,例如d16可能表示步长为16或者是某种方式下卷积核的密集度。

由于不同的实现细节,具体参数的意义可能会略有差异,详细的参数含义可能需要核查特定文档或该模型的源代码注释来获得精确解释。如果需要更详细的解释,可以查看MMsegmentation项目在GitHub上的文档。

数据预处理及训练过程

step 1 找到官方类似的训练数据集并修改自己数据集格式

由于我的数据集和官方数据集 CHASE DB1 都是由 Image 和 Mask 组成

故依照 CHASE DB1 数据集配置文件和参数进行更改和适配,对自己的数据集进行训练

step 1.1 将自己的数据集修改成与 CHASE DB1 同样的文件目录

— CHASE_DB1 
   |     |— annotations     # Mask 图片                   
   |     |     |— training  
   |     |     |— validation  
   |     |— images          # 图片         
   |     |     |— training           
   └──   └──   └── validation  

修改后自己的数据集如下:

handwriting
├── annotations
│   ├── rgb2bw.py     # 用于将非二值图片的mask转换为二值图片
│   ├── training
│   └── validation
└── images
    ├── training
    └── validation

!! image 应和其对应的每一个 annotation 的文件名保持一致

step 1.2 annotations 下的图片必须为二值图片,将非二值图片转换为二值图片的代码如下:

import os
from PIL import Image
from tqdm import tqdm

def convert_to_binary(image_path, threshold):
    image = Image.open(image_path).convert('L')
    binary_image = image.point(lambda x: 255 if x > threshold else 0, '1')
    return binary_image

def batch_convert_images(folder_path, output_folder, threshold):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    image_files = os.listdir(folder_path)
    for image_file in tqdm(image_files, desc="Processing Images"):
        image_path = os.path.join(folder_path, image_file)
        if os.path.isfile(image_path):
            binary_image = convert_to_binary(image_path, threshold)
            output_path = os.path.join(output_folder, image_file)
            binary_image.save(output_path)

    print("图片批量处理完成!")

folder_path = 'input_folder'
output_folder = 'output_folder'
threshold = 128
batch_convert_images(folder_path, output_folder, threshold)

step 2 修改配置文件【重点】

step 2.1 创建一个新文件 mmseg/datasets/handwriting.py

注:classes 和 mask 中的类别对应(虽然只有 handwriting 这1个类别,但是background也必须算作1个类别);palette 和 mask 中类别的色彩对应(可以用ps查看每个类别的RGB色彩)。

  • mmseg/datasets/handwriting.py

    import mmengine.fileio as fileio
    
    from mmseg.registry import DATASETS
    from .basesegdataset import BaseSegDataset
    
    @DATASETS.register_module()
    class HandwritingDataset(BaseSegDataset):
        METAINFO = dict(
            classes=('background', 'handwriting'),
            palette=[[0, 0, 0], [255, 255, 255]])
    
        def __init__(self,
                     img_suffix='.png',
                     seg_map_suffix='.png',
                     reduce_zero_label=False,
                     **kwargs) -> None:
            super().__init__(
                img_suffix=img_suffix,
                seg_map_suffix=seg_map_suffix,
                reduce_zero_label=reduce_zero_label,
                **kwargs)
            assert fileio.exists(
                self.data_prefix['img_path'], backend_args=self.backend_args)
    
    

step 2.2mmseg/datasets/__init__.py 中添加语句

# 在开头加入
from . handwriting import HandwritingDataset

# __all__中添加
__all__ = ['HandwritingDataset']

step 2.3 创建一个新的数据集配置文件 configs/__base__/datasets/handwriting.py

  • configs/__base__/datasets/handwriting.py

    # dataset settings
    dataset_type = 'HandwritingDataset' # 与 mmseg/datasets/handwriting.py中的类名是对应的
    data_root = 'data/handwriting/' # 自己的数据集所在的位置
    img_scale = (320, 640)
    crop_size = (160, 320)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(
            type='RandomResize',
            scale=img_scale,
            ratio_range=(0.5, 2.0),
            keep_ratio=True),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='PackSegInputs')
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='Resize', scale=img_scale, keep_ratio=True),
        # add loading annotation after ``Resize`` because ground truth
        # does not need to do resize data transform
        dict(type='LoadAnnotations'),
        dict(type='PackSegInputs')
    ]
    img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    tta_pipeline = [
        dict(type='LoadImageFromFile', backend_args=None),
        dict(
            type='TestTimeAug',
            transforms=[
                [
                    dict(type='Resize', scale_factor=r, keep_ratio=True)
                    for r in img_ratios
                ],
                [
                    dict(type='RandomFlip', prob=0., direction='horizontal'),
                    dict(type='RandomFlip', prob=1., direction='horizontal')
                ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
            ])
    ]
    
    train_dataloader = dict(
        batch_size=4,
        num_workers=2,
        persistent_workers=True,
        sampler=dict(type='InfiniteSampler', shuffle=True),
        dataset=dict(
            type='RepeatDataset',
            times=40000,
            dataset=dict(
                type=dataset_type,
                data_root=data_root,
                data_prefix=dict(
                    img_path='images/training',
                    seg_map_path='annotations/training'),
                pipeline=train_pipeline)))
    
    val_dataloader = dict(
        batch_size=1,
        num_workers=4,
        persistent_workers=True,
        sampler=dict(type='DefaultSampler', shuffle=False),
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            data_prefix=dict(
                img_path='images/validation',
                seg_map_path='annotations/validation'),
            pipeline=test_pipeline))
    test_dataloader = val_dataloader
    
    val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice'])
    test_evaluator = val_evaluator
    

step 2.4mmseg/utils/class_names 中补充数据集元信息

# 在 dataset_aliases 字典中添加!!!
'handwriting':['handwriting']

# 添加到最后一行即可
def handwriting_classes():
    return [
        'background','handwriting'
    ]

def handwriting_palette():
    return [
        [0,0,0],[255,255,255]
    ]

step 2.5 创建一个总配置文件 configs/unet/unet_s5-d16_deeplabv3_4xb4-40k_handwriting-320×640.py

_base_ = [
    '../_base_/models/deeplabv3_unet_s5-d16.py',
    '../_base_/datasets/handwriting.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py'
]
crop_size = (160, 320)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    test_cfg=dict(crop_size=(160, 320), stride=(85, 85)))

step 2.6 【可选】 模型配置文件 _base_/models/deeplabv3_unet_s5-d16.py

  • _base_/models/deeplabv3_unet_s5-d16.py

    # model settings
    norm_cfg = dict(type='BN', requires_grad=True) # 单卡训练为BN,多卡训练为SyncBN
    data_preprocessor = dict(
        type='SegDataPreProcessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_val=0,
        seg_pad_val=255)
    model = dict(
        type='EncoderDecoder',
        data_preprocessor=data_preprocessor,
        pretrained=None,
        backbone=dict(
            type='UNet',
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1),
            with_cp=False,
            conv_cfg=None,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'),
            upsample_cfg=dict(type='InterpConv'),
            norm_eval=False),
        decode_head=dict(
            type='ASPPHead',
            in_channels=64,
            in_index=4,
            channels=16,
            dilations=(1, 12, 24, 36),
            dropout_ratio=0.1,
            num_classes=2, # 类别记得符合Mask的类别数
            norm_cfg=norm_cfg,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        auxiliary_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=64,
            num_convs=1,
            concat_input=False,
            dropout_ratio=0.1,
            num_classes=2, # 类别记得符合Mask的类别数
            norm_cfg=norm_cfg,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
        # model training and testing settings
        train_cfg=dict(),
        test_cfg=dict(mode='slide', crop_size=128, stride=85))
    

step 3 重新启动

python setup.py install

pip install -v -e .

step 4 启动训练

python tools/train.py ./configs/unet/unet_s5-d16_deeplabv3_4xb4-40k_handwriting-320×640.py --work-dir ./mmseg_log/

其中:

tools/train.py 为训练启动脚本

./configs/unet/unet_s5-d16_deeplabv3_4xb4-40k_handwriting-320×640.py 为参数文件

--work-dir ./mmseg_log/ 指定日志文件和模型文件的保存位置

!!训练环境说明:官方Dockerfile + 1 * A100

训练出来的模型文件为 228M (iter_40000.pth)

推理自己的数据集

python mm_seg_rgb.py ./data/eval/ ./mmseg_log/unet_s5-d16_deeplabv3_4xb4-40k_handwriting-320×640.py ./mmseg_log/iter_40000.pth ./seg_output/

其中:

./data/eval/ 为测试图片所在的文件夹

./mmseg_log/unet_s5-d16_deeplabv3_4xb4-40k_handwriting-320×640.py 为参数文件

./mmseg_log/iter_40000.pth 为模型文件

./seg_output/ 为输出 mask 存储的文件夹

!!推理文件编写

  • mm_seg_rgb.py

    import os
    from argparse import ArgumentParser
    from pathlib import Path
    
    from mmengine.model import revert_sync_batchnorm
    from mmseg.apis import inference_model, init_model, show_result_pyplot
    
    def main():
        parser = ArgumentParser()
        parser.add_argument('img_dir', help='Directory where the images are stored')
        parser.add_argument('config', help='Config file')
        parser.add_argument('checkpoint', help='Checkpoint file')
        parser.add_argument('out_dir', help='Directory to save output mask images')
        parser.add_argument(
            '--device', default='cuda:0', help='Device used for inference')
        parser.add_argument(
            '--opacity',
            type=float,
            default=1,     # 调节背景透明度,0为全白1为全黑,默认为0.5
            help='Opacity of painted segmentation map. In (0, 1] range.')
        parser.add_argument(
            '--with-labels',
            action='store_true',
            default=False,
            help='Whether to display the class labels.')
        parser.add_argument(
            '--title', default='result', help='The image identifier.')
        args = parser.parse_args()
    
        # Create output directory if it doesn't exist
        os.makedirs(args.out_dir, exist_ok=True)
    
        # Build the model from a config file and a checkpoint file
        model = init_model(args.config, args.checkpoint, device=args.device)
        if args.device == 'cpu':
            model = revert_sync_batchnorm(model)
        
        # Process each image in the input directory
        for img_file in os.listdir(args.img_dir):
            img_path = os.path.join(args.img_dir, img_file)
            if os.path.isfile(img_path):
                # Test a single image
                result = inference_model(model, img_path)
                # Define the output file path
                out_file = os.path.join(args.out_dir, os.path.splitext(img_file)[0] + '_mask.png')
                # Show the results
                show_result_pyplot(
                    model,
                    img_path,
                    result,
                    title=args.title,
                    opacity=args.opacity,
                    with_labels=args.with_labels,
                    draw_gt=False,
                    show=False,
                    out_file=out_file)
    
    if __name__ == '__main__':
        main()
    
/usr/bin/time -v python mm_seg_predict.py ./test_time/source ./mmseg_log/unet_s5-d16_deeplabv3_4xb4-40k_handwriting-320×640.py ./mmseg_log/iter_40000.pth ./seg_output

测试效果

在这里插入图片描述 在这里插入图片描述 在这里插入图片描述

参考文章:

https://juejin.cn/post/7259624494960869435

  • 16
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值