mmclassification (now MMPretrain): training on your own dataset

1 Install from source

git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain
pip install -U openmim && mim install -e .

These are the versions I used:

/media/xp/data/pydoc/mmlab/mmpretrain$ pip show mmcv mmpretrain mmengine
Name: mmcv
Version: 2.1.0
Summary: OpenMMLab Computer Vision Foundation
Home-page: https://github.com/open-mmlab/mmcv
Author: MMCV Contributors
Author-email: openmmlab@gmail.com
License: UNKNOWN
Location: /home/xp/anaconda3/envs/py3/lib/python3.8/site-packages
Requires: addict, mmengine, numpy, packaging, Pillow, pyyaml, yapf
Required-by: 
---
Name: mmpretrain
Version: 1.2.0
Summary: OpenMMLab Model Pretraining Toolbox and Benchmark
Home-page: https://github.com/open-mmlab/mmpretrain
Author: MMPretrain Contributors
Author-email: openmmlab@gmail.com
License: Apache License 2.0
Location: /media/xp/data/pydoc/mmlab/mmpretrain
Editable project location: /media/xp/data/pydoc/mmlab/mmpretrain
Requires: einops, importlib-metadata, mat4py, matplotlib, modelindex, numpy, rich
Required-by: 
---
Name: mmengine
Version: 0.10.3
Summary: Engine of OpenMMLab projects
Home-page: https://github.com/open-mmlab/mmengine
Author: MMEngine Authors
Author-email: openmmlab@gmail.com
License: UNKNOWN
Location: /home/xp/anaconda3/envs/py3/lib/python3.8/site-packages
Requires: addict, matplotlib, numpy, opencv-python, pyyaml, rich, termcolor, yapf
Required-by: mmcv
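
If you would rather check the versions from inside Python than with pip show, a minimal standard-library sketch could look like the following. The helper name installed_versions is my own and is not part of any OpenMMLab package.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The package is not installed in the current environment.
            versions[name] = None
    return versions


for name, version in installed_versions(["mmcv", "mmpretrain", "mmengine"]).items():
    print(f"{name}: {version or 'not installed'}")
```

Run it inside the same conda env you train in, since each env has its own set of packages.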

2 Preparing the dataset

I use a cat-and-dog classification dataset as the example. My dataset is laid out as follows:

/media/xp/data/image/deep_image/mini_cat_and_dog$ tree -L 2
.
├── train
│   ├── cat
│   └── dog
└── val
    ├── cat
    └── dog
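
Before training it can help to confirm how many images each class folder holds. The sketch below assumes the split/class layout shown above; the function name count_images and the extension list are my own choices.

```python
import os
from collections import Counter

# Assumed image extensions; adjust to match your data.
IMAGE_EXTS = ('.jpg', '.jpeg', '.png')


def count_images(root_dir):
    """Count images per (split, class) under root_dir/<split>/<class>/."""
    counts = Counter()
    for split in sorted(os.listdir(root_dir)):
        split_dir = os.path.join(root_dir, split)
        if not os.path.isdir(split_dir):
            continue
        for cls in sorted(os.listdir(split_dir)):
            cls_dir = os.path.join(split_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            counts[(split, cls)] = sum(
                1 for f in os.listdir(cls_dir) if f.lower().endswith(IMAGE_EXTS))
    return counts


# Example call (path taken from this post):
# for key, n in count_images('/media/xp/data/image/deep_image/mini_cat_and_dog').items():
#     print(key, n)
```

A heavily unbalanced count between cat and dog is worth knowing about before you read too much into the accuracy numbers.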

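Note: some of my images turned out to be corrupt when I trained. mmcv reads images with OpenCV as its backend, so it is best to filter out bad images first; otherwise training fails with errors such as a cv imencode failure or an image that cannot be found. The code below deletes every image that OpenCV cannot open.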

import os

import cv2 as cv


def find_all_image_files(root_dir):
    """Recursively collect .jpg/.jpeg/.png files under root_dir."""
    image_files = []
    for root, dirs, files in os.walk(root_dir):
        for file in files:
            # Match case-insensitively so .JPG and .jpeg files are checked too.
            if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                image_files.append(os.path.join(root, file))
    return image_files


def is_bad_image(image_file):
    """Return True if OpenCV cannot decode the file."""
    try:
        # cv.imread returns None (instead of raising) for unreadable files.
        return cv.imread(image_file) is None
    except Exception:
        return True


def remove_bad_images(root_dir):
    """Delete every image under root_dir that OpenCV fails to open."""
    for image_file in find_all_image_files(root_dir):
        if is_bad_image(image_file):
            os.remove(image_file)
            print(f"Removed bad image: {image_file}")


remove_bad_images("/media/xp/data/image/deep_image/mini_cat_and_dog")
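
Deleting files is irreversible, so a more cautious variant is to move suspect images into a separate folder instead. This is only a sketch: quarantine_files and the quarantine directory are names I made up, and the is_bad argument is any predicate you pass in, for example the is_bad_image function above.

```python
import os
import shutil


def quarantine_files(root_dir, quarantine_dir, is_bad):
    """Move files under root_dir for which is_bad(path) is True.

    Relative paths are preserved inside quarantine_dir so a file can be
    restored later. quarantine_dir should live outside root_dir, otherwise
    os.walk may descend into it. Returns the list of new locations.
    """
    moved = []
    for root, _dirs, files in os.walk(root_dir):
        for name in files:
            src = os.path.join(root, name)
            if not is_bad(src):
                continue
            dst = os.path.join(quarantine_dir, os.path.relpath(src, root_dir))
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            shutil.move(src, dst)
            moved.append(dst)
    return moved


# Example call (hypothetical quarantine path):
# quarantine_files('/media/xp/data/image/deep_image/mini_cat_and_dog',
#                  '/media/xp/data/image/deep_image/mini_cat_and_dog_bad',
#                  is_bad_image)
```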

3 Config files

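Training, testing, and model conversion in the OpenMMLab projects are all driven by config files. A config has three basic building blocks: the dataset, the model, and the runtime. Many models inherit these three components from the _base_ directory and then override a few options to train different models on different datasets.
During training, mm saves the full training config into the work_dir directory. You can also copy that saved config and edit it directly, which keeps everything in a single config and makes it easier to manage. If you prefer working this way, copy the config in the appendix and modify it for your own training.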
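
To make the inheritance idea concrete, here is a simplified model of how a child config overrides values from a base config: nested dicts are merged key by key, and anything else is replaced. This is my own illustration, not the mmengine implementation, which handles more cases (for example the _delete_ key).

```python
def merge_config(base, override):
    """Recursively merge override into base and return a new dict."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge them instead of replacing.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


# A base file sets the dataloader; the child only overrides the pipeline.
base = dict(train_dataloader=dict(
    batch_size=32,
    dataset=dict(type='CustomDataset', data_prefix='train')))
child = dict(train_dataloader=dict(
    dataset=dict(pipeline=['LoadImageFromFile'])))

merged = merge_config(base, child)
print(merged)
```

The merged result keeps batch_size, type, and data_prefix from the base and gains the pipeline from the child, which is why a child config can be so short.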
Below is the config I modified to train MobileNetV3.

  • Add a file named my_mobilenetv3.py under the configs/mobilenet_v3 directory:
    configs/mobilenet_v3/my_mobilenetv3.py
_base_ = [
    # '../_base_/models/mobilenet_v3/mobilenet_v3_small_075_imagenet.py',
    '../_base_/datasets/my_custom.py',
    '../_base_/default_runtime.py',
]

# model settings

model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV3', arch='small_075'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='StackedLinearClsHead',
        num_classes=2,
        in_channels=432,
        mid_channels=[1024],
        dropout_rate=0.2,
        act_cfg=dict(type='HSwish'),
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        init_cfg=dict(
            type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
        topk=(1, 1)))
# model = dict(backbone=dict(norm_cfg=dict(type='BN', eps=1e-5, momentum=0.1)))

my_image_size = 128
my_max_epochs = 300
my_batch_size = 128

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=my_image_size,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='AutoAugment',
        policies='imagenet',
        hparams=dict(pad_val=[round(x) for x in [128,128,128]])),
    dict(
        type='RandomErasing',
        erase_prob=0.2,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=[128,128,128],
        fill_std=[50,50,50]),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=my_image_size,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=my_image_size),
    dict(type='PackInputs'),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# schedule settings
optim_wrapper = dict(
    optimizer=dict(
        type='RMSprop',
        lr=0.064,
        alpha=0.9,
        momentum=0.9,
        eps=0.0316,
        weight_decay=1e-5))

param_scheduler = dict(type='StepLR', by_epoch=True, step_size=2, gamma=0.973)

train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size is the total batch size that the LR above is tuned for.
auto_scale_lr = dict(base_batch_size=my_batch_size)
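
The StepLR settings above (step_size=2, gamma=0.973, starting from lr=0.064) mean the learning rate is multiplied by 0.973 every two epochs. A small worked sketch of that arithmetic, using a helper name of my own:

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Learning rate for a 0-based epoch index under step decay."""
    return base_lr * gamma ** (epoch // step_size)


# The training log numbers epochs from 1, so subtract 1 to get the index.
for shown_epoch in (1, 3, 5, 7, 9):
    lr = step_lr(0.064, 0.973, 2, shown_epoch - 1)
    print(f"epoch {shown_epoch}: lr = {lr:.4e}")
```

These values line up with the lr column of the training log in section 4 (6.4000e-02, 6.2272e-02, 6.0591e-02, and so on), which is a quick way to confirm the scheduler is doing what the config says.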

  • Create my_custom.py under configs/_base_/datasets/:
# dataset settings
dataset_type = 'CustomDataset'
data_preprocessor = dict(
    num_classes=2,
    # RGB format normalization parameters
    mean=[128,128,128],
    std=[50,50,50],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=128, edge='short'),
    dict(type='CenterCrop', crop_size=128),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=128, edge='short'),
    dict(type='CenterCrop', crop_size=128),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=1,
    dataset=dict(
        type=dataset_type,
        data_root='/media/xp/data/image/deep_image/mini_cat_and_dog',
        data_prefix='train',
        with_label=True,
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

 
val_dataloader = dict(
    batch_size=32,
    num_workers=1,
    dataset=dict(
        type=dataset_type,
        data_root='/media/xp/data/image/deep_image/mini_cat_and_dog',
        data_prefix='val',
        with_label=True,
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 1))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
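
With data_prefix pointing at a folder of class sub-folders and no annotation file, the labels come from the sub-folder names. As far as I can tell they are assigned in sorted name order, so cat would be 0 and dog would be 1 here; treat that ordering as my assumption and check it against your installed version. A sketch of that mapping, with a function name of my own:

```python
import os


def folder_to_index(split_dir):
    """Map each class sub-folder name to an integer label, in sorted order."""
    classes = sorted(
        name for name in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, name)))
    return {name: idx for idx, name in enumerate(classes)}


# Example call (path taken from this post):
# folder_to_index('/media/xp/data/image/deep_image/mini_cat_and_dog/train')
```

Keeping the same folder names under train and val matters, because the two splits must agree on which index means which class.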

4 Training

$ python tools/train.py configs/mobilenet_v3/my_mobilenetv3.py 

Output

04/22 10:08:18 - mmengine - INFO - 
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
    CUDA available: True
    MUSA available: False
    numpy_random_seed: 769660308
    GPU 0: Quadro M6000 24GB
    CUDA_HOME: /usr/local/cuda-11.8
    NVCC: Cuda compilation tools, release 11.8, V11.8.89
    GCC: gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
    PyTorch: 2.2.2+cu121
    PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.2
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.2.2, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, 

    TorchVision: 0.17.2+cu121
    OpenCV: 4.9.0
    MMEngine: 0.10.4

Runtime environment:
    cudnn_benchmark: False
    mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}
    dist_cfg: {'backend': 'nccl'}
    seed: 1921958984
    deterministic: False
    Distributed launcher: none
    Distributed training: False
    GPU number: 1
--------------------------------------
04/22 10:09:08 - mmengine - WARNING - "FileClient" will be deprecated in future. Please use io functions in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
04/22 10:09:08 - mmengine - WARNING - "HardDiskBackend" is the alias of "LocalBackend" and the former will be deprecated in future.
04/22 10:09:08 - mmengine - INFO - Checkpoints will be saved to /media/xp/data/pydoc/mmlab/mmpretrain/work_dirs/my_mobilenetv3.
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:17 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:17 - mmengine - INFO - Epoch(train)   [1][98/98]  lr: 6.4000e-02  eta: 1:31:37  time: 0.0913  data_time: 0.0129  loss: 11.2596
04/22 10:09:17 - mmengine - INFO - Saving checkpoint at 1 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:26 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:26 - mmengine - INFO - Epoch(train)   [2][98/98]  lr: 6.4000e-02  eta: 1:30:36  time: 0.0905  data_time: 0.0129  loss: 0.7452
04/22 10:09:26 - mmengine - INFO - Saving checkpoint at 2 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:35 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:35 - mmengine - INFO - Epoch(train)   [3][98/98]  lr: 6.2272e-02  eta: 1:29:30  time: 0.0841  data_time: 0.0059  loss: 0.7198
04/22 10:09:35 - mmengine - INFO - Saving checkpoint at 3 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:44 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:44 - mmengine - INFO - Epoch(train)   [4][98/98]  lr: 6.2272e-02  eta: 1:29:02  time: 0.0856  data_time: 0.0047  loss: 0.6938
04/22 10:09:44 - mmengine - INFO - Saving checkpoint at 4 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:53 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:53 - mmengine - INFO - Epoch(train)   [5][98/98]  lr: 6.0591e-02  eta: 1:28:42  time: 0.0877  data_time: 0.0100  loss: 0.7128
04/22 10:09:53 - mmengine - INFO - Saving checkpoint at 5 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:02 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:02 - mmengine - INFO - Epoch(train)   [6][98/98]  lr: 6.0591e-02  eta: 1:28:32  time: 0.0857  data_time: 0.0069  loss: 0.7214
04/22 10:10:02 - mmengine - INFO - Saving checkpoint at 6 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:11 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:11 - mmengine - INFO - Epoch(train)   [7][98/98]  lr: 5.8955e-02  eta: 1:28:11  time: 0.0860  data_time: 0.0063  loss: 0.7113
04/22 10:10:11 - mmengine - INFO - Saving checkpoint at 7 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:20 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:20 - mmengine - INFO - Epoch(train)   [8][98/98]  lr: 5.8955e-02  eta: 1:28:05  time: 0.0881  data_time: 0.0083  loss: 0.6989
04/22 10:10:20 - mmengine - INFO - Saving checkpoint at 8 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:29 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:29 - mmengine - INFO - Epoch(train)   [9][98/98]  lr: 5.7363e-02  eta: 1:28:23  time: 0.0883  data_time: 0.0077  loss: 0.6874
04/22 10:10:29 - mmengine - INFO - Saving checkpoint at 9 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:39 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:39 - mmengine - INFO - Epoch(train)  [10][98/98]  lr: 5.7363e-02  eta: 1:28:28  time: 0.0894  data_time: 0.0068  loss: 0.7028
04/22 10:10:39 - mmengine - INFO - Saving checkpoint at 10 epochs
04/22 10:10:39 - mmengine - INFO - Epoch(val) [10][3/3]    accuracy/top1: 60.8696  data_time: 0.0411  time: 0.0650
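
If you want to plot the loss or learning rate from a console log like the one above, the per-epoch training lines are regular enough to parse with a regular expression. This is a sketch based only on the line format shown here; the pattern and function name are mine and may need adjusting for other log layouts.

```python
import re

# Matches lines such as:
# "... Epoch(train)   [3][98/98]  lr: 6.2272e-02  eta: ...  loss: 0.7198"
TRAIN_LINE = re.compile(
    r"Epoch\(train\)\s+\[(?P<epoch>\d+)\]\[\d+/\d+\]\s+"
    r"lr: (?P<lr>[\d.eE+-]+).*?loss: (?P<loss>[\d.]+)")


def parse_train_line(line):
    """Return (epoch, lr, loss) for a training log line, else None."""
    match = TRAIN_LINE.search(line)
    if match is None:
        return None
    return int(match["epoch"]), float(match["lr"]), float(match["loss"])


sample = ("04/22 10:09:35 - mmengine - INFO - Epoch(train)   [3][98/98]  "
          "lr: 6.2272e-02  eta: 1:29:30  time: 0.0841  data_time: 0.0059  "
          "loss: 0.7198")
print(parse_train_line(sample))
```

Lines that are not training summaries, such as the checkpoint or validation lines, simply return None and can be skipped.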

Fine-tuning the model

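Reference
When fine-tuning, we usually want to load a pretrained model into the backbone and then train a new classification head on our own dataset.

To load a pretrained model into the backbone, change the backbone's initialization settings to use the Pretrained initializer. In addition, set prefix='backbone' in the init settings to tell the initializer which submodule prefix to load from the checkpoint; backbone here means the backbone network of the model.
Update the classification head to match the number of classes in the new dataset. Only the num_classes setting of the head needs to change.

The frozen_stages setting freezes the parameters of the first few stages of the backbone, so that only the later layers and the classification head are trained.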

# >>>>>>>>>>>>>>> Override the model settings here >>>>>>>>>>>>>>>>>>>
model = dict(
    backbone=dict(
        frozen_stages=2,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
            prefix='backbone',
        )),
    head=dict(num_classes=10),
)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
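
To see what prefix='backbone' is for, it helps to picture the checkpoint as a flat dict of parameter names. The sketch below is a simplified illustration of prefix filtering, not the actual loading code; the key names and the helper select_by_prefix are hypothetical.

```python
def select_by_prefix(state_dict, prefix):
    """Keep entries whose key starts with '<prefix>.' and strip that prefix."""
    head = prefix.rstrip('.') + '.'
    return {
        key[len(head):]: value
        for key, value in state_dict.items()
        if key.startswith(head)
    }


# Hypothetical checkpoint contents; real values would be tensors.
checkpoint = {
    'backbone.conv1.weight': 'w0',
    'backbone.layer1.0.bn.bias': 'b0',
    'head.fc.weight': 'w1',  # dropped: the head is trained from scratch
}
backbone_only = select_by_prefix(checkpoint, 'backbone')
print(backbone_only)
```

Only the backbone entries survive, with names that match the backbone module itself, so the old classification head (sized for a different number of classes) never gets loaded.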
  • Below is the complete model config for mobilenet-v3-small-0.75:
model = dict(
    backbone=dict(
        arch='small_075', 
        type='MobileNetV3',
        ############################ use the pretrained model
        frozen_stages=2,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='model/mobilenet-v3-small-075_3rdparty_in1k_20221114-2011fa76.pth',
            prefix='backbone',
        )
        ############################
        ),
    head=dict(
        act_cfg=dict(type='HSwish'),
        dropout_rate=0.2,
        in_channels=432,
        init_cfg=dict(
            bias=0.0, layer='Linear', mean=0.0, std=0.01, type='Normal'),
        loss=dict(loss_weight=1.0, type='CrossEntropyLoss'),
        mid_channels=[
            1024,
        ],
        num_classes=len(my_class_names),
        topk=(
            1,
            1,
        ),
        type='StackedLinearClsHead'),
    neck=dict(type='GlobalAveragePooling'),
    type='ImageClassifier')

5 Some installation issues [update]

Python version issue

When training on another machine I got the following error:

  File "/home/xp/anaconda3/envs/mmlab/lib/python3.12/site-packages/mmengine/utils/dl_utils/collect_env.py", line 54, in collect_env
    from distutils import errors
ModuleNotFoundError: No module named 'distutils'

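This happens because distutils was removed in Python 3.12, so you need a different Python version; Python 3.11.5 works.
Reference link: https://stackoverflow.com/questions/69919970/no-module-named-distutils-but-distutils-installed
In conda you cannot simply run conda install python=3.11.5: if you try, nearly every installed package conflicts and the Python version cannot be changed. Switching Python versions therefore means reinstalling PyTorch, and conda cannot install packages from another env. Either set up environments with a few PyTorch versions when you have spare time and later clone one with conda create --name new_name --clone old_name, or just wait for the install. You can also switch the conda and pip mirrors, but downloading several GB still takes a while.
Also, torch 2.2.2 does not seem to be able to use the GPU under Python 3.8: with the same packages installed, torch.cuda.is_available() is True on Python 3.11.x and False on 3.8.x. I don't know why.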
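
A quick way to find out whether an environment will hit the distutils error before you start a training run is to check the interpreter version and whether the module can be found at all. This is a small sketch with function names of my own.

```python
import importlib.util
import sys


def lacks_stdlib_distutils(version_info=sys.version_info):
    """True for Python 3.12+, where distutils is not in the standard library."""
    return tuple(version_info[:2]) >= (3, 12)


def distutils_importable():
    """True if a distutils module can be located in this environment."""
    return importlib.util.find_spec("distutils") is not None


if lacks_stdlib_distutils() and not distutils_importable():
    print("distutils is missing here; switch to Python 3.11 as described above.")
else:
    print("distutils looks importable in this environment.")
```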

Appendix

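  • Dataset preparation
    Official documentation
  • Complete training config with the three modules merged into one file; you can edit it and use it for training directly.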

my_train_batch_size = 64
my_val_batch_size = 16
my_image_size = 128
my_max_epochs = 300

my_checkpoints_interval = 10 # 10 epochs to save a checkpoint

my_train_dataset_root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
my_train_data_prefix = 'train'
my_val_dataset_root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
my_val_data_prefix = 'val'
my_test_dataset_root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
my_test_data_prefix = 'test'

work_dir = './work_dirs/my_mobilenetv3'

my_class_names = ['cat', 'dog']


auto_scale_lr = dict(base_batch_size=128)
data_preprocessor = dict(
    mean=[
        128,
        128,
        128,
    ], num_classes=2, std=[
        50,
        50,
        50,
    ], to_rgb=True)
dataset_type = 'CustomDataset'



default_hooks = dict(
    checkpoint=dict(interval=my_checkpoints_interval, type='CheckpointHook'),
    logger=dict(interval=100, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(enable=False, type='VisualizationHook'))
default_scope = 'mmpretrain'
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
launcher = 'none'
load_from = None
log_level = 'INFO'
model = dict(
    backbone=dict(arch='small_075', type='MobileNetV3'),
    head=dict(
        act_cfg=dict(type='HSwish'),
        dropout_rate=0.2,
        in_channels=432,
        init_cfg=dict(
            bias=0.0, layer='Linear', mean=0.0, std=0.01, type='Normal'),
        loss=dict(loss_weight=1.0, type='CrossEntropyLoss'),
        mid_channels=[
            1024,
        ],
        num_classes=len(my_class_names),
        topk=(
            1,
            1,
        ),
        type='StackedLinearClsHead'),
    neck=dict(type='GlobalAveragePooling'),
    type='ImageClassifier')

optim_wrapper = dict(
    optimizer=dict(
        alpha=0.9,
        eps=0.0316,
        lr=0.064,
        momentum=0.9,
        type='RMSprop',
        weight_decay=1e-05))
param_scheduler = dict(by_epoch=True, gamma=0.973, step_size=2, type='StepLR')
randomness = dict(deterministic=False, seed=None)
resume = False
test_cfg = dict()
test_dataloader = dict(
    batch_size=my_val_batch_size,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_prefix='val',
        data_root=my_val_dataset_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                backend='pillow',
                edge='short',
                interpolation='bicubic',
                scale=my_image_size,
                type='ResizeEdge'),
            dict(crop_size=my_image_size, type='CenterCrop'),
            dict(type='PackInputs'),
        ],
        type='CustomDataset',
        with_label=True),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(
    topk=(
        1,
        1,
    ), type='Accuracy')
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        backend='pillow',
        edge='short',
        interpolation='bicubic',
        scale=my_image_size,
        type='ResizeEdge'),
    dict(crop_size=my_image_size, type='CenterCrop'),
    dict(type='PackInputs'),
]
train_cfg = dict(by_epoch=True, max_epochs=my_max_epochs, val_interval=10)
train_dataloader = dict(
    batch_size=my_train_batch_size,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_prefix=my_train_data_prefix,
        data_root=my_train_dataset_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                backend='pillow',
                interpolation='bicubic',
                scale=my_image_size,
                type='RandomResizedCrop'),
            dict(direction='horizontal', prob=0.5, type='RandomFlip'),
            dict(
                hparams=dict(pad_val=[
                    128,
                    128,
                    128,
                ]),
                policies='imagenet',
                type='AutoAugment'),
            dict(
                erase_prob=0.2,
                fill_color=[
                    128,
                    128,
                    128,
                ],
                fill_std=[
                    50,
                    50,
                    50,
                ],
                max_area_ratio=0.3333333333333333,
                min_area_ratio=0.02,
                mode='rand',
                type='RandomErasing'),
            dict(type='PackInputs'),
        ],
        type='CustomDataset',
        with_label=True),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=True, type='DefaultSampler'))
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        backend='pillow',
        interpolation='bicubic',
        scale=my_image_size,
        type='RandomResizedCrop'),
    dict(direction='horizontal', prob=0.5, type='RandomFlip'),
    dict(
        hparams=dict(pad_val=[
            128,
            128,
            128,
        ]),
        policies='imagenet',
        type='AutoAugment'),
    dict(
        erase_prob=0.2,
        fill_color=[
            128,
            128,
            128,
        ],
        fill_std=[
            50,
            50,
            50,
        ],
        max_area_ratio=0.3333333333333333,
        min_area_ratio=0.02,
        mode='rand',
        type='RandomErasing'),
    dict(type='PackInputs'),
]
val_cfg = dict()
val_dataloader = dict(
    batch_size=my_val_batch_size,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_prefix=my_val_data_prefix,
        data_root=my_val_dataset_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                backend='pillow',
                edge='short',
                interpolation='bicubic',
                scale=my_image_size,
                type='ResizeEdge'),
            dict(crop_size=my_image_size, type='CenterCrop'),
            dict(type='PackInputs'),
        ],
        type='CustomDataset',
        with_label=True),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(
    topk=(
        1,
        1,
    ), type='Accuracy')
vis_backends = [
    dict(type='LocalVisBackend'),
]
visualizer = dict(
    type='UniversalVisualizer', vis_backends=[
        dict(type='LocalVisBackend'),
    ])
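
One thing to watch in the config above: it defines my_test_data_prefix = 'test', while the dataset tree in section 2 only contains train and val (and the test_dataloader actually reads the val split). If you wire the my_test_* variables in, a small pre-flight check like this sketch can catch a missing folder before training starts. The function name is my own.

```python
import os


def missing_split_dirs(splits):
    """Return (name, path) pairs whose directory does not exist.

    splits maps a split name to a (data_root, data_prefix) tuple.
    """
    missing = []
    for name, (root, prefix) in splits.items():
        path = os.path.join(root, prefix)
        if not os.path.isdir(path):
            missing.append((name, path))
    return missing


# Example with the variables from the config above:
# print(missing_split_dirs({
#     'train': (my_train_dataset_root, my_train_data_prefix),
#     'val': (my_val_dataset_root, my_val_data_prefix),
#     'test': (my_test_dataset_root, my_test_data_prefix),
# }))
```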

