MMclassification自定义数据集训练验证&&格式转换脚本

MMclassification_训练与验证

  • 数据集准备

数据集格式

文件夹格式:

DATA

├──train.txt #存放训练集图片路径及类名的txt文件,每一行为图片路径和类名

├──val.txt#存放验证集图片路径及类名的txt文件,每一行为图片路径和类名

├── train存放训练集图片

│   ├── class1

│   │   ├── 026.JPEG

│   │   ├── ...

│   ├── class2

│   │   ├── 999.JPEG

│   │   ├── ...

│   ├── ...

├── val存放验证集图片

│   ├── class1

│   │   ├── 0027.JPEG

│   │   ├── ...

│   ├── class2

│   │   ├── 993.JPEG

│   │   ├── ...

│   ├── ...

Txt文件可以通过格式转换脚本IMGDataset_txt.py生成。

其中--data_dir为数据集DATA的路径。

--image_type为数据集图片的格式。

--output_dir为转换后的txt文件存放路径。

  • 修改

以resnet101为例

/mmclassification/configs/resnet/resnet101_8xb32_in1k_IMGDataset.py

 

三、训练

训练命令格式:

# 单 GPU 训练

python tools/train.py ${CONFIG_FILE} [optional arguments]

# 多 GPU 训练

bash tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]

说明:

config_file:模型配置文件的路径

gpu_num:使用 GPU 的数量

--work-dir:设置存放训练生成文件的路径

--resume-from:设置恢复训练的模型检查点文件的路径

--no-validate(不建议):设置训练时不验证模型

--seed:设置随机种子,便于复现结果

这里以resnet101为例,cd 到yuml_web目录下,运行命令:

python  mmclassification/tools/train.py  mmclassification/configs/resnet/resnet101_8xb32_in1k_IMGDataset.py

即可开始训练模型。其中训练产生的所有日志文件都保存在work_dir中。

四、验证

# 单 GPU 测试

python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \

    [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]

# 多 GPU 测试

bash tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} \

[--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]

config_file:模型配置文件的路径

checkpoint_file:模型检查点文件的路径

gpu_num:使用的 GPU 数量

--out:设置输出 pkl 测试结果文件的路径

--work-dir:设置存放 json 日志文件的路径

--eval:设置度量指标(voc:mAP, recall | coco:bbox, segm, proposal)

--show:设置显示有预测框的测试集图像

--show-dir:设置存放有预测框的测试集图像的路径

--show-score-thr:设置显示预测框的阈值,默认值为 0.3

--fuse-conv-bn: 设置融合卷积层和批归一化层,能够稍微提升推理速度

这里以resnet101为例,建议在work_dir中需要验证的pth模型文件复制到yuml_web/checkpoints/下,cd 到yuml_web目录下,运行命令:

python  mmclassification/tools/test.py  mmclassification/configs/resnet/resnet101_8xb32_in1k_IMGDataset.py

checkpoints/pth模型文件名 --show

转换脚本:

from pathlib import Path
import argparse
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str,default='.\products10k',
                        help='Directory of images and xml.')
    parser.add_argument('--image_type', type=str,default='jpg',
                        help='Directory of images and xml.')
    parser.add_argument('--output_dir', type=str,default='.\products10k',
                        help='Directory of output.')
    a=parser.parse_args()
    root = Path(a.data_dir)
    train_paths = (root / "train").rglob("*")
    val_paths = (root / "val").rglob("*")
    img_type = a.image_type
    train_txt = []
    for i in train_paths:
        for j in i.glob("*."+img_type):
            train_txt.append(str(j) + " " + str(i.stem) + "\n")
    val_txt = []
    for i in val_paths:
        for j in i.glob("*."+img_type):
            val_txt.append(str(j) + " " + str(i.stem) + "\n")
    train_txt[-1] = train_txt[-1].strip()
    train_dir = a.output_dir + "/train.txt"
    val_dir = a.output_dir + "/val.txt"
    with open(train_dir, "w", encoding="utf-8") as f:
        f.writelines(train_txt)
    val_txt[-1] = val_txt[-1].strip()
    with open(val_dir, "w", encoding="utf-8") as f:
        f.writelines(val_txt)
    print("success")

替换:

"/mmclassification/mmcls/datasets/__init__.py"

# Copyright (c) OpenMMLab. All rights reserved.
from .base_dataset import BaseDataset
from .builder import (DATASETS, PIPELINES, SAMPLERS, build_dataloader,
                      build_dataset, build_sampler)
from .cifar import CIFAR10, CIFAR100
from .cub import CUB
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               KFoldDataset, RepeatDataset)
from .imagenet import ImageNet
from .imagenet21k import ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .samplers import DistributedSampler, RepeatAugSampler
from .voc import VOC
from .img import IMGDataset

__all__ = [
    'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
    'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
    'DistributedSampler', 'ConcatDataset', 'RepeatDataset', 'IMGDataset'
    'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
    'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset'
]
 

新增"/mmclassification/mmcls/datasets/img.py"

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class IMGDataset(CustomDataset):
    """`ImageNet <http://www.image-net.org>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py
    """  # noqa: E501

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    class_dir = '/data/class.txt'
    with open(class_dir, "r", encoding="utf-8") as f:
        a = [i.strip() for i in f.readlines()]
    CLASSES = a
    def __init__(self,
                 data_prefix: str,
                 pipeline: Sequence = (),
                 classes: Union[str, Sequence[str], None] = None,
                 ann_file: Optional[str] = None,
                 test_mode: bool = False,
                 file_client_args: Optional[dict] = None):
        super().__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            classes=classes,
            ann_file=ann_file,
            extensions=self.IMG_EXTENSIONS,
            test_mode=test_mode,
            file_client_args=file_client_args)
 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值