Introduction
A while ago I needed to train a custom dataset with mmclassification for a project, so I have written up the workflow here for reference. This article assumes that the full mmcv environment is already installed.
Environment used in this article
- cuda -> 11.1
- torch -> 1.9.0
- python -> 3.6
- mmcv-full -> 1.3.0
- mmcls -> 0.12.0
Preparing the dataset
- Split the images into a training set, a validation set and a test set. The directory layout looks like the sketch below.
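A hypothetical layout, using the class folders '0', '8', 'B', 'D' that appear later in this article; the image file names are placeholders:
lp_data
├── train
│   ├── 0
│   │   ├── xxx.jpg
│   │   └── ...
│   ├── 8
│   ├── B
│   └── D
├── val
│   ├── 0
│   ├── 8
│   ├── B
│   └── D
└── test
    ├── 0
    ├── 8
    ├── B
    └── D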
- Generate the TXT label files with the script below.
import pathlib
import random

path = '/home/sychen/mmclassification/lp_data/val'
data_path = pathlib.Path(path)
all_images_path = list(data_path.glob('*/*'))
all_images_path = [str(path) for path in all_images_path]  # collect every image path into a list
random.shuffle(all_images_path)  # shuffle the paths
print(len(all_images_path))
print(all_images_path[:5])  # print the first five

# build the labels
label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())
print(label_names)  # print the class names; the next step maps each class name to a label index
label_to_index = dict((name, index) for index, name in enumerate(label_names))
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_images_path]
for image, label in zip(all_images_path[:5], all_image_labels[:5]):
    print(image, '-----', label)

filename = '/home/sychen/mmclassification/lp_data/val.txt'  # ***remember to change this path as well***
with open(filename, 'w') as f:
    for image, label in zip(all_images_path, all_image_labels):
        image = image.split("/")[-2] + "/" + image.split("/")[-1]  # keep only "class_folder/file_name"
        f.write(image + " " + str(label) + "\n")
print("\nAll images and labels have been written in the txt!\n")
The generated file looks like the example below.
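A few hypothetical lines of the resulting val.txt (file names are placeholders; the label indices follow the sorted class-folder names, so '0' -> 0, '8' -> 1, 'B' -> 2, 'D' -> 3):
B/xxx_0432.jpg 2
0/xxx_0017.jpg 0
D/xxx_0281.jpg 3
8/xxx_0096.jpg 1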
Modifying the mmclassification code
Create a new .py file (the name is up to you) under the mmcls/datasets directory with the following contents.
import os

import numpy as np

from .base_dataset import BaseDataset
from .builder import DATASETS


def has_file_allowed_extension(filename, extensions):
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)


def find_folders(root):
    folders = [
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    ]
    folders.sort()
    folder_to_idx = {folders[i]: i for i in range(len(folders))}
    return folder_to_idx


def get_samples(root, folder_to_idx, extensions):
    samples = []
    root = os.path.expanduser(root)
    for folder_name in sorted(os.listdir(root)):
        _dir = os.path.join(root, folder_name)
        if not os.path.isdir(_dir):
            continue
        for _, _, fns in sorted(os.walk(_dir)):
            for fn in sorted(fns):
                if has_file_allowed_extension(fn, extensions):
                    path = os.path.join(folder_name, fn)
                    item = (path, folder_to_idx[folder_name])
                    samples.append(item)
    return samples


@DATASETS.register_module()
class MyDataset(BaseDataset):  # the class name is up to you

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    CLASSES = ['0', '8', 'B', 'D']  # change this to your own class names

    def load_annotations(self):
        if self.ann_file is None:
            folder_to_idx = find_folders(self.data_prefix)
            samples = get_samples(
                self.data_prefix,
                folder_to_idx,
                extensions=self.IMG_EXTENSIONS)
            if len(samples) == 0:
                raise (RuntimeError('Found 0 files in subfolders of: '
                                    f'{self.data_prefix}. '
                                    'Supported extensions are: '
                                    f'{",".join(self.IMG_EXTENSIONS)}'))
            self.folder_to_idx = folder_to_idx
        elif isinstance(self.ann_file, str):
            with open(self.ann_file) as f:
                samples = [x.strip().split(' ') for x in f.readlines()]
        else:
            raise TypeError('ann_file must be a str or None')
        self.samples = samples

        data_infos = []
        for filename, gt_label in self.samples:
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos
Add the custom dataset class to the __init__.py file in the mmcls/datasets directory so that it gets registered.
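A minimal sketch of that change, assuming the new file under mmcls/datasets was saved as my_dataset.py (the existing entries in __all__ are abridged and must be kept as they are):
# mmcls/datasets/__init__.py
from .my_dataset import MyDataset  # import the custom dataset class

__all__ = [
    'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST',  # existing entries, abridged
    'MyDataset',  # append the new class name to the existing list
]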
Modifying the config files
Create a new mydataset.py file under the configs/_base_/datasets directory and fill in the corresponding data paths.
Note: the data_prefix joined with the relative path stored in the .txt file must give the full path of an image. For example, data_prefix='/home/sychen/mmclassification/lp_data/val' plus a line such as 'B/xxx.jpg 2' resolves to /home/sychen/mmclassification/lp_data/val/B/xxx.jpg.
dataset_type = 'MyDataset'  # class name of the custom dataset loader
# ... (define img_norm_cfg, train_pipeline and test_pipeline here, e.g. copied from an existing dataset config)
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='/home/sychen/mmclassification/lp_data/train',
        ann_file='/home/sychen/mmclassification/lp_data/train.txt',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='/home/sychen/mmclassification/lp_data/val',
        ann_file='/home/sychen/mmclassification/lp_data/val.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='/home/sychen/mmclassification/lp_data/test',
        ann_file='/home/sychen/mmclassification/lp_data/test.txt',
        pipeline=test_pipeline))
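The dataset configs shipped with mmcls usually end with an evaluation setting as well; a minimal sketch, assuming you simply want accuracy computed on the validation set once per epoch:
# hedged sketch: evaluate accuracy on the validation set every epoch
evaluation = dict(interval=1, metric='accuracy')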
Create a new mymodel.py file under the configs/_base_/models directory. You can directly copy the config of the model architecture you want to use, e.g. resnet18.py.
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=4,  # set to your own number of classes
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        # topk=(1, 5),  # comment this out: top-5 accuracy is meaningless with only 4 classes
    ))
Create a new myschedule.py file under the configs/_base_/schedules directory. You can directly copy the config of the training schedule you want to use, e.g. imagenet_bs256.py.
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[30, 60, 90])
runner = dict(type='EpochBasedRunner', max_epochs=100)
Create a new myconfig.py file under the configs/resnet directory that loads the config files above.
_base_ = [
    '../_base_/models/mymodel.py', '../_base_/datasets/mydataset.py',
    '../_base_/schedules/myschedule.py', '../_base_/default_runtime.py'
]
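myconfig.py can also override anything inherited from the _base_ files. Two commonly useful overrides, sketched here with placeholder paths:
# optional overrides in myconfig.py (paths are placeholders)
# load_from = '/path/to/pretrained.pth'      # start training from an existing checkpoint
# resume_from = 'work_dirs/task/latest.pth'  # resume an interrupted run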
Training
python tools/train.py configs/resnet/myconfig.py --work-dir work_dirs/task
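For multi-GPU training, the standard OpenMMLab distributed launcher that ships with the repo can be used instead (here assuming 2 GPUs):
bash tools/dist_train.sh configs/resnet/myconfig.py 2 --work-dir work_dirs/task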
Problem 1:
Solution: this is a model weight-initialization issue; simply delete the model.init_weights() call.
Problem 2:
Solution: remove the corresponding parameter: custom_hooks_config.
Testing
python tools/test.py configs/resnet/myconfig.py work_dirs/dirname/epoch_n.pth --out test_result.json
After testing finishes, a test_result.json file is generated in the repository root directory.
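If you just want the accuracy printed to the console instead of (or in addition to) the raw result dump, tools/test.py should also accept a --metrics option in this version of mmcls (a hedged usage sketch):
python tools/test.py configs/resnet/myconfig.py work_dirs/dirname/epoch_n.pth --metrics accuracy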
Exporting an ONNX file
python tools/pytorch2onnx.py \
    configs/resnet/myconfig.py \
    --checkpoint work_dirs/dirname/epoch_n.pth \
    --output-file work_dirs/dirname/epoch_n.onnx \
    --dynamic-shape \
    --show \
    --simplify \
    --verify
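As an extra sanity check outside of mmcls, the exported model can be run directly with onnxruntime. A minimal sketch, assuming a 224x224 input resolution, the four classes defined earlier, and the output path used above (the input size and file path are assumptions, not values confirmed by the export step):
import numpy as np
import onnxruntime as ort

CLASSES = ['0', '8', 'B', 'D']  # must match the order used in MyDataset

# load the exported model (path is hypothetical)
sess = ort.InferenceSession('work_dirs/dirname/epoch_n.onnx')
input_name = sess.get_inputs()[0].name

# dummy input; replace with a real preprocessed image of shape (N, C, H, W)
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)

# run inference; the classification head outputs one score per class
scores = sess.run(None, {input_name: dummy})[0]
print('predicted class:', CLASSES[int(scores.argmax(axis=1)[0])])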