1 Installing from source
git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain
pip install -U openmim && mim install -e .
Below are the versions I used:
/media/xp/data/pydoc/mmlab/mmpretrain$ pip show mmcv mmpretrain mmengine
Name: mmcv
Version: 2.1.0
Summary: OpenMMLab Computer Vision Foundation
Home-page: https://github.com/open-mmlab/mmcv
Author: MMCV Contributors
Author-email: openmmlab@gmail.com
License: UNKNOWN
Location: /home/xp/anaconda3/envs/py3/lib/python3.8/site-packages
Requires: addict, mmengine, numpy, packaging, Pillow, pyyaml, yapf
Required-by:
---
Name: mmpretrain
Version: 1.2.0
Summary: OpenMMLab Model Pretraining Toolbox and Benchmark
Home-page: https://github.com/open-mmlab/mmpretrain
Author: MMPretrain Contributors
Author-email: openmmlab@gmail.com
License: Apache License 2.0
Location: /media/xp/data/pydoc/mmlab/mmpretrain
Editable project location: /media/xp/data/pydoc/mmlab/mmpretrain
Requires: einops, importlib-metadata, mat4py, matplotlib, modelindex, numpy, rich
Required-by:
---
Name: mmengine
Version: 0.10.3
Summary: Engine of OpenMMLab projects
Home-page: https://github.com/open-mmlab/mmengine
Author: MMEngine Authors
Author-email: openmmlab@gmail.com
License: UNKNOWN
Location: /home/xp/anaconda3/envs/py3/lib/python3.8/site-packages
Requires: addict, matplotlib, numpy, opencv-python, pyyaml, rich, termcolor, yapf
Required-by: mmcv
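After installing, a quick sanity check from Python confirms that the editable package is picked up (a minimal sketch; list_models is part of mmpretrain's public API):

import mmpretrain

# Print the installed version and a few matching model names
print(mmpretrain.__version__)
print(mmpretrain.list_models(pattern='mobilenet-v3')[:5])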
2 Dataset preparation
I use a cat-and-dog classification dataset as an example; my training set is laid out as follows:
/media/xp/data/image/deep_image/mini_cat_and_dog$ tree -L 2
.
├── train
│   ├── cat
│   └── dog
└── val
    ├── cat
    └── dog
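mmpretrain's CustomDataset (used later in the config) infers the class names from these sub-folder names, so it is worth confirming that every split/class folder actually contains images. A small sketch using my dataset root:

import os

root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
for split in ('train', 'val'):
    split_dir = os.path.join(root, split)
    for cls in sorted(os.listdir(split_dir)):
        num = len(os.listdir(os.path.join(split_dir, cls)))
        print(f'{split}/{cls}: {num} images')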
Note: some of my images turned out to be corrupted. mmcv uses OpenCV as the backend to read images, so it is best to filter out broken images first, otherwise training fails with cv imencode errors or complains that an image cannot be found. The script below removes every image that OpenCV cannot open.
import os

import cv2 as cv


def find_all_image_files(root_dir):
    """Collect all .jpg/.png files under root_dir recursively."""
    image_files = []
    for root, dirs, files in os.walk(root_dir):
        for file in files:
            if file.endswith('.jpg') or file.endswith('.png'):
                image_files.append(os.path.join(root, file))
    return image_files


def is_bad_image(image_file):
    """Return True if OpenCV cannot decode the image."""
    try:
        img = cv.imread(image_file)
        return img is None
    except Exception:
        return True


def remove_bad_images(root_dir):
    """Delete every image under root_dir that OpenCV fails to open."""
    for image_file in find_all_image_files(root_dir):
        if is_bad_image(image_file):
            os.remove(image_file)
            print(f"Removed bad image: {image_file}")


remove_bad_images("/media/xp/data/image/deep_image/mini_cat_and_dog")
3 Config files
Training, testing and model conversion in the mmlab repositories are all driven by configs. A config is built from three basic blocks: the dataset, the model and the runtime. Many models inherit these three components from the _base_ directory and then override a few options to train different models on different datasets.
During training, mm saves the fully resolved config into the work_dir directory. You can later copy that file and edit it directly, which consolidates everything into a single config and makes it easier to manage. If you prefer that workflow, just copy the config in the appendix, modify it and train with it.
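If you want to see what a config finally expands to after all the _base_ inheritance, you can resolve and dump it yourself with mmengine before training (a small sketch using mmengine's Config API):

from mmengine.config import Config

cfg = Config.fromfile('configs/mobilenet_v3/my_mobilenetv3.py')
print(cfg.pretty_text)        # the fully merged config
cfg.dump('merged_config.py')  # save a consolidated copy for later editing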
Below is the config I modified to train MobileNetV3.
- Add a file my_mobilenetv3.py under the configs/mobilenet_v3 directory:
configs/mobilenet_v3/my_mobilenetv3.py
_base_ = [
    # '../_base_/models/mobilenet_v3/mobilenet_v3_small_075_imagenet.py',
    '../_base_/datasets/my_custom.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV3', arch='small_075'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='StackedLinearClsHead',
        num_classes=2,
        in_channels=432,
        mid_channels=[1024],
        dropout_rate=0.2,
        act_cfg=dict(type='HSwish'),
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        init_cfg=dict(
            type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
        topk=(1, 1)))
# model = dict(backbone=dict(norm_cfg=dict(type='BN', eps=1e-5, momentum=0.1)))

my_image_size = 128
my_max_epochs = 300
my_batch_size = 128

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=my_image_size,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='AutoAugment',
        policies='imagenet',
        hparams=dict(pad_val=[round(x) for x in [128, 128, 128]])),
    dict(
        type='RandomErasing',
        erase_prob=0.2,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=[128, 128, 128],
        fill_std=[50, 50, 50]),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=my_image_size,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=my_image_size),
    dict(type='PackInputs'),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# schedule settings
optim_wrapper = dict(
    optimizer=dict(
        type='RMSprop',
        lr=0.064,
        alpha=0.9,
        momentum=0.9,
        eps=0.0316,
        weight_decay=1e-5))

param_scheduler = dict(type='StepLR', by_epoch=True, step_size=2, gamma=0.973)

train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (8 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=my_batch_size)
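Note that the head's in_channels=432 has to match the channel dimension of the backbone's output feature map. If you switch to another arch, one way to find the right number is to build just the backbone through the MODELS registry and run a dummy forward (a sketch; the 128 input size matches my_image_size above):

import torch
from mmpretrain.registry import MODELS

backbone = MODELS.build(dict(type='MobileNetV3', arch='small_075'))
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 128, 128))
# The channel dimension of the last feature map is what in_channels must
# equal (432 for small_075).
print(feats[-1].shape)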
- Create my_custom.py under configs/_base_/datasets/:
# dataset settings
dataset_type = 'CustomDataset'
data_preprocessor = dict(
    num_classes=2,
    # RGB format normalization parameters
    mean=[128, 128, 128],
    std=[50, 50, 50],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=128, edge='short'),
    dict(type='CenterCrop', crop_size=128),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=128, edge='short'),
    dict(type='CenterCrop', crop_size=128),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=1,
    dataset=dict(
        type=dataset_type,
        data_root='/media/xp/data/image/deep_image/mini_cat_and_dog',
        data_prefix='train',
        with_label=True,
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=32,
    num_workers=1,
    dataset=dict(
        type=dataset_type,
        data_root='/media/xp/data/image/deep_image/mini_cat_and_dog',
        data_prefix='val',
        with_label=True,
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)

val_evaluator = dict(type='Accuracy', topk=(1, 1))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
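Because CustomDataset reads the labels from the folder structure, it is also easy to check that the dataset config resolves correctly before launching training (a sketch; the class names should come back as the sub-folder names):

from mmpretrain.registry import DATASETS

train_set = DATASETS.build(dict(
    type='CustomDataset',
    data_root='/media/xp/data/image/deep_image/mini_cat_and_dog',
    data_prefix='train',
    with_label=True,
    pipeline=[]))
# Expect the number of training images and ('cat', 'dog')
print(len(train_set), train_set.metainfo['classes'])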
4 Training
$ python tools/train.py configs/mobilenet_v3/my_mobilenetv3.py
Output:
04/22 10:08:18 - mmengine - INFO -
------------------------------------------------------------
System environment:
sys.platform: linux
Python: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 769660308
GPU 0: Quadro M6000 24GB
CUDA_HOME: /usr/local/cuda-11.8
NVCC: Cuda compilation tools, release 11.8, V11.8.89
GCC: gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
PyTorch: 2.2.2+cu121
PyTorch compiling details: PyTorch built with:
- GCC 9.3
- C++ Version: 201703
- Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX2
- CUDA Runtime 12.1
- NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
- CuDNN 8.9.2
- Magma 2.6.1
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.2.2, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,
TorchVision: 0.17.2+cu121
OpenCV: 4.9.0
MMEngine: 0.10.4
Runtime environment:
cudnn_benchmark: False
mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}
dist_cfg: {'backend': 'nccl'}
seed: 1921958984
deterministic: False
Distributed launcher: none
Distributed training: False
GPU number: 1
--------------------------------------
04/22 10:09:08 - mmengine - WARNING - "FileClient" will be deprecated in future. Please use io functions in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
04/22 10:09:08 - mmengine - WARNING - "HardDiskBackend" is the alias of "LocalBackend" and the former will be deprecated in future.
04/22 10:09:08 - mmengine - INFO - Checkpoints will be saved to /media/xp/data/pydoc/mmlab/mmpretrain/work_dirs/my_mobilenetv3.
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:17 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:17 - mmengine - INFO - Epoch(train) [1][98/98] lr: 6.4000e-02 eta: 1:31:37 time: 0.0913 data_time: 0.0129 loss: 11.2596
04/22 10:09:17 - mmengine - INFO - Saving checkpoint at 1 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:26 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:26 - mmengine - INFO - Epoch(train) [2][98/98] lr: 6.4000e-02 eta: 1:30:36 time: 0.0905 data_time: 0.0129 loss: 0.7452
04/22 10:09:26 - mmengine - INFO - Saving checkpoint at 2 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:35 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:35 - mmengine - INFO - Epoch(train) [3][98/98] lr: 6.2272e-02 eta: 1:29:30 time: 0.0841 data_time: 0.0059 loss: 0.7198
04/22 10:09:35 - mmengine - INFO - Saving checkpoint at 3 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:44 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:44 - mmengine - INFO - Epoch(train) [4][98/98] lr: 6.2272e-02 eta: 1:29:02 time: 0.0856 data_time: 0.0047 loss: 0.6938
04/22 10:09:44 - mmengine - INFO - Saving checkpoint at 4 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:53 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:53 - mmengine - INFO - Epoch(train) [5][98/98] lr: 6.0591e-02 eta: 1:28:42 time: 0.0877 data_time: 0.0100 loss: 0.7128
04/22 10:09:53 - mmengine - INFO - Saving checkpoint at 5 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:02 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:02 - mmengine - INFO - Epoch(train) [6][98/98] lr: 6.0591e-02 eta: 1:28:32 time: 0.0857 data_time: 0.0069 loss: 0.7214
04/22 10:10:02 - mmengine - INFO - Saving checkpoint at 6 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:11 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:11 - mmengine - INFO - Epoch(train) [7][98/98] lr: 5.8955e-02 eta: 1:28:11 time: 0.0860 data_time: 0.0063 loss: 0.7113
04/22 10:10:11 - mmengine - INFO - Saving checkpoint at 7 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:20 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:20 - mmengine - INFO - Epoch(train) [8][98/98] lr: 5.8955e-02 eta: 1:28:05 time: 0.0881 data_time: 0.0083 loss: 0.6989
04/22 10:10:20 - mmengine - INFO - Saving checkpoint at 8 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:29 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:29 - mmengine - INFO - Epoch(train) [9][98/98] lr: 5.7363e-02 eta: 1:28:23 time: 0.0883 data_time: 0.0077 loss: 0.6874
04/22 10:10:29 - mmengine - INFO - Saving checkpoint at 9 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:39 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:39 - mmengine - INFO - Epoch(train) [10][98/98] lr: 5.7363e-02 eta: 1:28:28 time: 0.0894 data_time: 0.0068 loss: 0.7028
04/22 10:10:39 - mmengine - INFO - Saving checkpoint at 10 epochs
04/22 10:10:39 - mmengine - INFO - Epoch(val) [10][3/3] accuracy/top1: 60.8696 data_time: 0.0411 time: 0.0650
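If training gets interrupted, it can be resumed from the latest checkpoint in work_dirs; tools/train.py accepts a --resume flag (with no argument it tries to auto-resume from the most recent checkpoint):
$ python tools/train.py configs/mobilenet_v3/my_mobilenetv3.py --resume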
Fine-tuning a model
Reference: the official fine-tuning documentation.
When fine-tuning, we usually want to load pretrained weights into the backbone and then train a new classification head on our own dataset.
To load pretrained weights into the backbone, modify the backbone's init settings to use the Pretrained initializer. In the init settings we also pass prefix='backbone' to tell the initializer which sub-module of the checkpoint to load; backbone here refers to the backbone of the pretrained model.
Adjust the classification head to the number of classes in the new dataset; only the head's num_classes setting needs to change.
The frozen_stages parameter freezes the parameters of the first few stages of the backbone, so that only the later stages and the classification head are trained.
# >>>>>>>>>>>>>>> Override the model-related settings here >>>>>>>>>>>>>>>>>>>
model = dict(
    backbone=dict(
        frozen_stages=2,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
            prefix='backbone',
        )),
    head=dict(num_classes=10),
)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
- Below is the complete model config for MobileNet-V3-Small-0.75 with the pretrained backbone loaded:
model = dict(
    backbone=dict(
        arch='small_075',
        type='MobileNetV3',
        ############################ use the pretrained model
        frozen_stages=2,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='model/mobilenet-v3-small-075_3rdparty_in1k_20221114-2011fa76.pth',
            prefix='backbone',
        )
        ############################
    ),
    head=dict(
        act_cfg=dict(type='HSwish'),
        dropout_rate=0.2,
        in_channels=432,
        init_cfg=dict(
            bias=0.0, layer='Linear', mean=0.0, std=0.01, type='Normal'),
        loss=dict(loss_weight=1.0, type='CrossEntropyLoss'),
        mid_channels=[1024],
        num_classes=len(my_class_names),
        topk=(1, 1),
        type='StackedLinearClsHead'),
    neck=dict(type='GlobalAveragePooling'),
    type='ImageClassifier')
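The checkpoint path above points at a locally downloaded file. Judging by the file name, these are the mobilenet-v3-small-075_3rdparty_in1k ImageNet weights, and mmpretrain can fetch them into its cache for you (a sketch; the model name is inferred from the checkpoint file name, so confirm it with list_models first):

import mmpretrain

# Confirm the exact model name available in your mmpretrain version
print(mmpretrain.list_models(pattern='mobilenet-v3-small-075'))

# Download the pretrained weights into the local cache
# (model name assumed from the checkpoint file name above)
mmpretrain.get_model('mobilenet-v3-small-075_3rdparty_in1k', pretrained=True)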
5 Testing
python tools/test.py configs/mobilenet_v3/my_mobilenetv3.py work_dirs/my_mobilenetv3/epoch_150.pth --work-dir work_dirs/my_mobilenetv3_test/ --show-dir work_dirs/my_mobilenetv3_test/show
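For a quick check on single images, the same config and checkpoint can also be fed to mmpretrain's ImageClassificationInferencer (a sketch; the image path is a placeholder):

from mmpretrain import ImageClassificationInferencer

inferencer = ImageClassificationInferencer(
    model='configs/mobilenet_v3/my_mobilenetv3.py',
    pretrained='work_dirs/my_mobilenetv3/epoch_150.pth',
    device='cuda')
result = inferencer('path/to/some_image.jpg')[0]  # placeholder image path
print(result['pred_label'], result['pred_score'])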
6 Some installation issues [updated]
Python version issue
When training on another machine, the following error came up:
File "/home/xp/anaconda3/envs/mmlab/lib/python3.12/site-packages/mmengine/utils/dl_utils/collect_env.py", line 54, in collect_env
from distutils import errors
ModuleNotFoundError: No module named 'distutils'
This happens because the distutils module was removed in Python 3.12, so you need a different Python version; Python 3.11.5 works.
Reference: https://stackoverflow.com/questions/69919970/no-module-named-distutils-but-distutils-installed
In conda you cannot simply run conda install python=3.11.5: as soon as you try, basically every installed package reports a conflict, so the Python version of an existing environment cannot be changed in place. Switching Python versions therefore means reinstalling PyTorch, and conda cannot pull packages from another env. Either set up several environments with different PyTorch versions ahead of time and later clone one with conda create --name new_name --clone old_name, or just sit through a fresh install. Switching the conda and pip mirrors helps, but downloading several GB still takes a while.
Another observation: torch 2.2.2 does not seem to be able to use the GPU under Python 3.8. With the same packages installed, torch.cuda.is_available() returns True under Python 3.11.x but False under 3.8.x; the cause is unknown.
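A quick way to compare two environments when this happens is to print the interpreter and CUDA status directly:

import sys
import torch

print(sys.version)
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())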
Appendix
- Dataset preparation: official documentation
- Complete training config: the three modules merged into one file, ready to modify and use for training directly.
my_train_batch_size = 64
my_val_batch_size = 16
my_image_size = 128
my_max_epochs = 300
my_checkpoints_interval = 10  # save a checkpoint every 10 epochs
my_train_dataset_root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
my_train_data_prefix = 'train'
my_val_dataset_root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
my_val_data_prefix = 'val'
my_test_dataset_root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
my_test_data_prefix = 'test'
work_dir = './work_dirs/my_mobilenetv3'
my_class_names = ['cat', 'dog']

auto_scale_lr = dict(base_batch_size=128)
data_preprocessor = dict(
    mean=[128, 128, 128],
    num_classes=2,
    std=[50, 50, 50],
    to_rgb=True)
dataset_type = 'CustomDataset'
default_hooks = dict(
    checkpoint=dict(interval=my_checkpoints_interval, type='CheckpointHook'),
    logger=dict(interval=100, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(enable=False, type='VisualizationHook'))
default_scope = 'mmpretrain'
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
launcher = 'none'
load_from = None
log_level = 'INFO'
model = dict(
    backbone=dict(arch='small_075', type='MobileNetV3'),
    head=dict(
        act_cfg=dict(type='HSwish'),
        dropout_rate=0.2,
        in_channels=432,
        init_cfg=dict(
            bias=0.0, layer='Linear', mean=0.0, std=0.01, type='Normal'),
        loss=dict(loss_weight=1.0, type='CrossEntropyLoss'),
        mid_channels=[1024],
        num_classes=len(my_class_names),
        topk=(1, 1),
        type='StackedLinearClsHead'),
    neck=dict(type='GlobalAveragePooling'),
    type='ImageClassifier')
optim_wrapper = dict(
    optimizer=dict(
        alpha=0.9,
        eps=0.0316,
        lr=0.064,
        momentum=0.9,
        type='RMSprop',
        weight_decay=1e-05))
param_scheduler = dict(by_epoch=True, gamma=0.973, step_size=2, type='StepLR')
randomness = dict(deterministic=False, seed=None)
resume = False
test_cfg = dict()
test_dataloader = dict(
    batch_size=my_val_batch_size,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_prefix='val',
        data_root=my_val_dataset_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                backend='pillow',
                edge='short',
                interpolation='bicubic',
                scale=my_image_size,
                type='ResizeEdge'),
            dict(crop_size=my_image_size, type='CenterCrop'),
            dict(type='PackInputs'),
        ],
        type='CustomDataset',
        with_label=True),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(topk=(1, 1), type='Accuracy')
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        backend='pillow',
        edge='short',
        interpolation='bicubic',
        scale=my_image_size,
        type='ResizeEdge'),
    dict(crop_size=my_image_size, type='CenterCrop'),
    dict(type='PackInputs'),
]
train_cfg = dict(by_epoch=True, max_epochs=my_max_epochs, val_interval=10)
train_dataloader = dict(
    batch_size=my_train_batch_size,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_prefix=my_train_data_prefix,
        data_root=my_train_dataset_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                backend='pillow',
                interpolation='bicubic',
                scale=my_image_size,
                type='RandomResizedCrop'),
            dict(direction='horizontal', prob=0.5, type='RandomFlip'),
            dict(
                hparams=dict(pad_val=[128, 128, 128]),
                policies='imagenet',
                type='AutoAugment'),
            dict(
                erase_prob=0.2,
                fill_color=[128, 128, 128],
                fill_std=[50, 50, 50],
                max_area_ratio=0.3333333333333333,
                min_area_ratio=0.02,
                mode='rand',
                type='RandomErasing'),
            dict(type='PackInputs'),
        ],
        type='CustomDataset',
        with_label=True),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=True, type='DefaultSampler'))
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        backend='pillow',
        interpolation='bicubic',
        scale=my_image_size,
        type='RandomResizedCrop'),
    dict(direction='horizontal', prob=0.5, type='RandomFlip'),
    dict(
        hparams=dict(pad_val=[128, 128, 128]),
        policies='imagenet',
        type='AutoAugment'),
    dict(
        erase_prob=0.2,
        fill_color=[128, 128, 128],
        fill_std=[50, 50, 50],
        max_area_ratio=0.3333333333333333,
        min_area_ratio=0.02,
        mode='rand',
        type='RandomErasing'),
    dict(type='PackInputs'),
]
val_cfg = dict()
val_dataloader = dict(
    batch_size=my_val_batch_size,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_prefix=my_val_data_prefix,
        data_root=my_val_dataset_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                backend='pillow',
                edge='short',
                interpolation='bicubic',
                scale=my_image_size,
                type='ResizeEdge'),
            dict(crop_size=my_image_size, type='CenterCrop'),
            dict(type='PackInputs'),
        ],
        type='CustomDataset',
        with_label=True),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(topk=(1, 1), type='Accuracy')
vis_backends = [
    dict(type='LocalVisBackend'),
]
visualizer = dict(
    type='UniversalVisualizer', vis_backends=[
        dict(type='LocalVisBackend'),
    ])