2021-06-24

Introduction

A while ago, a project I was working on required training a custom dataset with mmclassification, so I took the opportunity to write up the workflow for reference. This post assumes you already have the full mmcv environment installed.

Environment used in this post

  1. cuda ->11.1
  2. torch ->1.9.0
  3. python ->3.6
  4. mmcv-full ->1.3.0
  5. mmcls ->0.12.0

Preparing the dataset

  1. Split the images into training, validation and test sets. A typical directory layout is sketched below.
  2. Generate the TXT label files with the script that follows the sketch.
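A minimal sketch of the layout assumed in this post (the lp_data root, the three splits and the four class folders 0/8/B/D are taken from the paths and CLASSES used later; the image file names are only placeholders):

lp_data
├── train
│   ├── 0
│   │   ├── xxx.jpg
│   │   └── ...
│   ├── 8
│   ├── B
│   └── D
├── val
│   ├── 0
│   ├── 8
│   ├── B
│   └── D
└── test
    ├── 0
    ├── 8
    ├── B
    └── D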
import pathlib
import random

path='/home/sychen/mmclassification/lp_data/val'

data_path = pathlib.Path(path)
all_images_path = list(data_path.glob('*/*'))
all_images_path = [str(path) for path in all_images_path]  # store every image path as a string
random.shuffle(all_images_path)  # shuffle the list

print(len(all_images_path))
print(all_images_path[:5])  # print the first five paths

# build the labels
label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())
print(label_names)  # print the class names; the next step maps each name to an index
label_to_index = dict((name, index) for index, name in enumerate(label_names))

all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_images_path]

for image, label in zip(all_images_path[:5], all_image_labels[:5]):
    print(image, '-----', label)

filename = '/home/sychen/mmclassification/lp_data/val.txt'  # ***remember to change this path too***
with open(filename, 'w') as f:
    for image, label in zip(all_images_path, all_image_labels):
        image = image.split("/")[-2] + "/" + image.split("/")[-1]  # keep only class_folder/file_name
        f.write(image + " " + str(label) + "\n")
print("\nAll images and labels have been written in the txt!\n")

Run the script once for each of train, val and test, changing both path and filename. The generated txt file contains one line per image in the form class_folder/file_name label_index, for example: B/xxx.jpg 2

Modifying the mmclassification code

Create a new .py file under the mmcls/datasets directory (name it whatever you like):

import os
import numpy as np
from .base_dataset import BaseDataset
from .builder import DATASETS

def has_file_allowed_extension(filename, extensions):
    """Check whether a file name ends with one of the allowed extensions."""
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)


def find_folders(root):
    """Map each class folder name under root to an integer index."""
    folders = [
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    ]
    folders.sort()
    folder_to_idx = {folders[i]: i for i in range(len(folders))}
    return folder_to_idx


def get_samples(root, folder_to_idx, extensions):
    """Collect (relative_path, label_index) pairs for every image under root."""
    samples = []
    root = os.path.expanduser(root)
    for folder_name in sorted(os.listdir(root)):
        _dir = os.path.join(root, folder_name)
        if not os.path.isdir(_dir):
            continue
        for _, _, fns in sorted(os.walk(_dir)):
            for fn in sorted(fns):
                if has_file_allowed_extension(fn, extensions):
                    path = os.path.join(folder_name, fn)
                    item = (path, folder_to_idx[folder_name])
                    samples.append(item)
    return samples


@DATASETS.register_module()
class MyDataset(BaseDataset):  # the class name is up to you
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    CLASSES = ['0', '8', 'B', 'D']  # change this to your own class names

    def load_annotations(self):
        if self.ann_file is None:
            folder_to_idx = find_folders(self.data_prefix)
            samples = get_samples(
                self.data_prefix,
                folder_to_idx,
                extensions=self.IMG_EXTENSIONS)
            if len(samples) == 0:
                raise (RuntimeError('Found 0 files in subfolders of: '
                                    f'{self.data_prefix}. '
                                    'Supported extensions are: '
                                    f'{",".join(self.IMG_EXTENSIONS)}'))

            self.folder_to_idx = folder_to_idx
        elif isinstance(self.ann_file, str):
            with open(self.ann_file) as f:
                samples = [x.strip().split(' ') for x in f.readlines()]
        else:
            raise TypeError('ann_file must be a str or None')
        self.samples = samples

        data_infos = []
        for filename, gt_label in self.samples:
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos

Add the custom dataset class to the __init__.py file under mmcls/datasets.
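An excerpt of what the registration might look like (the existing imports shown here are placeholders; keep whatever your mmcls version already exports, while mydataset.py / MyDataset are the file and class created above):

from .base_dataset import BaseDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .mydataset import MyDataset  # the new dataset class

__all__ = [
    # ... keep the entries that were already here ...
    'BaseDataset', 'DATASETS', 'PIPELINES', 'build_dataloader',
    'build_dataset', 'MyDataset'
]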

Modifying the config files

Create a new mydataset.py under configs/_base_/datasets and write in your data paths.
Note: the data_prefix joined with the relative path stored in the .txt file must form the full path of each image.

dataset_type = 'MyDataset'  # class name of the dataset class registered above
……
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='/home/sychen/mmclassification/lp_data/train',
        ann_file='/home/sychen/mmclassification/lp_data/train.txt',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='/home/sychen/mmclassification/lp_data/val',
        ann_file='/home/sychen/mmclassification/lp_data/val.txt',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_prefix='/home/sychen/mmclassification/lp_data/test',
        ann_file='/home/sychen/mmclassification/lp_data/test.txt',
        pipeline=test_pipeline))
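The "……" above stands for img_norm_cfg, train_pipeline and test_pipeline, which in the real file must be defined before the data dict. They can be copied from an existing config such as configs/_base_/datasets/imagenet_bs32.py; a sketch along those lines, assuming 224x224 inputs and ImageNet normalization:

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]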

Create a new mymodel.py under configs/_base_/models; you can simply copy the config of the model you want to use, e.g. resnet18.py. Make sure num_classes matches the number of classes in your dataset (4 here).

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=4,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        # topk=(1, 5),  # comment this out: top-5 cannot be computed with only 4 classes
    ))

Create a new myschedule.py under configs/_base_/schedules; you can simply copy the training schedule you want to use, e.g. imagenet_bs256.py.

optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[30, 60, 90])
runner = dict(type='EpochBasedRunner', max_epochs=100)

Finally, create a myconfig.py under configs/resnet that loads the config files above:

_base_ = [
    '../_base_/models/mymodel.py', '../_base_/datasets/mydataset.py',
    '../_base_/schedules/myschedule.py', '../_base_/default_runtime.py'
]

Training

python tools/train.py configs/resnet/myconfig.py --work-dir work_dirs/task

Training logs and the epoch_n.pth checkpoints are written to the --work-dir given above.
Problem 1: an error during model initialization.
Fix: just delete the model.init_weights() call.
Problem 2: an error involving the custom_hooks_config argument.
Fix: remove that parameter from the corresponding call.

Testing

python tools/test.py configs/resnet/myconfig.py work_dirs/dirname/epoch_n.pth --out test_result.json

After testing finishes, a test_result.json file is written to the path passed to --out (here, the repository root).

Exporting an ONNX file

python tools/pytorch2onnx.py \
    configs/resnet/myconfig.py \
    --checkpoint work_dirs/dirname/epoch_n.pth \
    --output-file work_dirs/dirname/epoch_n.onnx \
    --dynamic-shape \
    --show \
    --simplify \
    --verify
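To sanity-check the exported model you can run it with onnxruntime. A minimal sketch, assuming a 1x3x224x224 float32 input and that real images get the same resize/normalization as the test pipeline (the file path matches the export command above):

import numpy as np
import onnxruntime as ort

# load the exported model
session = ort.InferenceSession('work_dirs/dirname/epoch_n.onnx')
input_name = session.get_inputs()[0].name

# a random tensor just to check that the graph runs; real images must be
# resized, normalized and laid out as NCHW exactly like the test pipeline
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
scores = session.run(None, {input_name: dummy})[0]
print(scores.shape)           # (1, num_classes), i.e. (1, 4) here
print(scores.argmax(axis=1))  # predicted class index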