mmclassification: Training on a Custom Dataset
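This post walks step by step through loading a custom dataset into the mmclassification framework and training on it. Installing mmclassification is not covered here; please install it yourself beforehand.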
1. Prepare the dataset
First, prepare your dataset and arrange it in the following directory structure:
```
imagenet
├── meta
│ ├── classmap.txt
├── train
│ ├── class1
│ │ ├── 026.JPEG
│ │ ├── ...
│ ├── class2
│ │ ├── 999.JPEG
│ │ ├── ...
│ ├── ...
├── val
│ ├── class1
│ │ ├── 0027.JPEG
│ │ ├── ...
│ ├── class2
│ │ ├── 993.JPEG
│ │ ├── ...
│ ├── ...
```
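- The `classmap.txt` file must contain one line per class, with fields separated by single spaces: the folder name (`class1`, `class2`, which must match the class subfolders under `train` and `val`), a class name, and an integer label. Labels should start at 0 and follow the same order as the `CLASSES` list defined in step 3:
```
class1 dog 0
class2 cat 1
```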
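A mismatch between `classmap.txt` and the folder names only shows up later as a `KeyError` in the txt-generation script of step 2, so it can help to check the layout first. The following is a minimal sketch, assuming the `train`/`val`/`meta/classmap.txt` layout above; `check_classmap` is a hypothetical helper, not part of mmclassification.

```python
import os


def check_classmap(root_dir):
    """Hypothetical helper: compare meta/classmap.txt with the train/ and val/ folders.

    Returns a list of problem descriptions; an empty list means no issue was found.
    """
    problems = []
    folders, labels = [], []
    with open(os.path.join(root_dir, "meta", "classmap.txt")) as f:
        for lineno, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue  # ignore blank lines
            if len(parts) != 3:
                problems.append(f"classmap.txt line {lineno}: expected 3 fields, got {len(parts)}")
                continue
            folder, _name, label = parts
            if not label.isdigit():
                problems.append(f"classmap.txt line {lineno}: label {label!r} is not a non-negative integer")
                continue
            folders.append(folder)
            labels.append(int(label))

    # Labels should be exactly 0..N-1 so that label i lines up with CLASSES[i]
    if sorted(labels) != list(range(len(labels))):
        problems.append(f"labels {sorted(labels)} are not 0..{len(labels) - 1}")

    # Each listed folder must exist in both splits, and no split may contain unlisted folders
    for split in ("train", "val"):
        split_dir = os.path.join(root_dir, split)
        present = {d for d in os.listdir(split_dir)
                   if os.path.isdir(os.path.join(split_dir, d))}
        for folder in folders:
            if folder not in present:
                problems.append(f"{split}/{folder} is missing")
        for extra in sorted(present - set(folders)):
            problems.append(f"{split}/{extra} is not listed in classmap.txt")
    return problems
```

For example, calling `check_classmap` on your dataset root should return an empty list before you generate the txt files.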
2. Generate the txt files
The following script generates the annotation txt files that mmclassification will load:
```python
import os
import glob

# Generate train.txt and val.txt
# Change this to your own path
root_dir = "/media/dmmm/CE31-3598/DataSets/classification_mine"
# root_dir contains three folders: train, val and meta
train_dir = os.path.join(root_dir, "train")
val_dir = os.path.join(root_dir, "val")
meta_dir = os.path.join(root_dir, "meta")


def generate_txt(images_dir, map_dict):
    # Collect every file of the form <images_dir>/<class folder>/<image>
    img_paths = sorted(glob.glob(os.path.join(images_dir, "*", "*")))
    # "train" or "val" determines the output file name
    typename = os.path.basename(os.path.normpath(images_dir))
    target_txt_path = os.path.join(meta_dir, typename + ".txt")
    with open(target_txt_path, "w") as f:
        for img_path in img_paths:
            # The parent folder name is the class folder, e.g. "class1"
            folder = os.path.basename(os.path.dirname(img_path))
            label = map_dict[folder]
            # Path relative to images_dir, e.g. "class1/026.JPEG", always "/"-separated
            rel_path = os.path.relpath(img_path, images_dir).replace(os.sep, "/")
            f.write(rel_path + " " + label + "\n")


def get_map_dict():
    # Read the folder-name -> label mapping from meta/classmap.txt
    class_map_dict = {}
    with open(os.path.join(meta_dir, "classmap.txt"), "r") as F:
        for line in F:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            folder, cls, num = line.split()
            class_map_dict[folder] = num
    return class_map_dict


if __name__ == '__main__':
    class_map_dict = get_map_dict()
    generate_txt(images_dir=train_dir, map_dict=class_map_dict)
    generate_txt(images_dir=val_dir, map_dict=class_map_dict)
```
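When the script finishes, `train.txt` and `val.txt` are created in the `meta` folder; these are the annotation files mmclassification loads. Each line holds an image path relative to the split folder followed by its label (`train.txt` is shown; `val.txt` has the same format):
```
class1/026.JPEG 0
class2/999.JPEG 1
```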
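A quick way to confirm the generated files look right is to count how many images each label received. This is a small sketch using only the standard library; `count_labels` is a hypothetical name, and it assumes the `<relative path> <label>` format shown above.

```python
from collections import Counter


def count_labels(txt_path):
    """Hypothetical helper: count images per integer label in train.txt / val.txt."""
    counts = Counter()
    with open(txt_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # The label is the last space-separated field
            _path, label = line.rsplit(" ", 1)
            counts[int(label)] += 1
    return dict(sorted(counts.items()))
```

Keep in mind that file names containing spaces will still break the dataset class of step 3, which splits each line on every space.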
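3. Modify the mmclassification code
Create a new Python file in the `mmcls/datasets` directory (any name works; `mydataset.py` is used here) with the following content. The line marked `#****` must list your own class names. This code uses the mmclassification 0.x dataset API (`load_annotations`):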
```python
import numpy as np

from .builder import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class MyDataset(BaseDataset):
    # Class names; entry i corresponds to label i in classmap.txt
    CLASSES = ["dog", "cat"]  #***********************************

    def load_annotations(self):
        assert isinstance(self.ann_file, str)
        data_infos = []
        with open(self.ann_file) as f:
            # Each line: "<path relative to data_prefix> <label>"
            samples = [x.strip().split(' ') for x in f.readlines()]
        for filename, gt_label in samples:
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos
```
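`load_annotations` trusts the annotation file completely: each label is expected to correspond to one entry of `CLASSES` (label `i` to `CLASSES[i]`), and an out-of-range label may only surface later, for example as an out-of-bounds target error in the loss. The sketch below re-implements the same line parsing in plain Python (no mmcls or NumPy import) to catch such problems early; `validate_annotations` is a hypothetical helper, and the class list used in the usage note is the `["dog", "cat"]` example above.

```python
def validate_annotations(ann_file, classes):
    """Hypothetical helper: parse ann_file the way MyDataset.load_annotations does
    and report lines whose label does not map onto an entry of `classes`."""
    problems = []
    with open(ann_file) as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.strip().split(' ')  # same split as load_annotations
            if len(fields) != 2:
                problems.append(f"line {lineno}: expected '<filename> <label>', got {line.strip()!r}")
                continue
            _filename, gt_label = fields
            try:
                label = int(gt_label)
            except ValueError:
                problems.append(f"line {lineno}: label {gt_label!r} is not an integer")
                continue
            if not 0 <= label < len(classes):
                problems.append(f"line {lineno}: label {label} outside 0..{len(classes) - 1}")
    return problems
```

For example, `validate_annotations(".../meta/train.txt", ["dog", "cat"])` should return an empty list for a correctly generated file.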
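Then edit `mmcls/datasets/__init__.py`: import the new class and append `'MyDataset'` to the existing `__all__` list, keeping the entries already there:
```python
from .mydataset import MyDataset

__all__ = [
    # ... existing entries ...
    'MyDataset',  # add MyDataset to the list
]
```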
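Why the import in `__init__.py` matters: `@DATASETS.register_module()` records the class in a registry when its module is imported, and the config later refers to it only by the string `type='MyDataset'`. If the module is never imported, the name is never registered and the lookup fails. The toy registry below illustrates this name-to-class lookup; it is a simplified stand-in written for illustration (its `Registry` class and `build` method are hypothetical), not mmcv's actual `Registry` implementation.

```python
class Registry:
    """Toy registry mapping a string name to a class (illustration only)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Used as @REGISTRY.register_module(); the decorator runs when the class
        # is defined, i.e. when the file containing it is imported
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        # Look up cfg['type'] and pass the remaining keys as constructor kwargs
        args = dict(cfg)
        type_name = args.pop('type')
        if type_name not in self._module_dict:
            raise KeyError(f"{type_name} is not in the {self.name} registry")
        return self._module_dict[type_name](**args)


DATASETS = Registry('dataset')


@DATASETS.register_module()
class MyDataset:
    def __init__(self, data_prefix, ann_file, pipeline=None):
        self.data_prefix = data_prefix
        self.ann_file = ann_file
        self.pipeline = pipeline or []


# Mirrors the train=dict(type=dataset_type, ...) entry of the dataset config
ds = DATASETS.build(dict(type='MyDataset', data_prefix='train', ann_file='meta/train.txt'))
```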
4. Modify the config files
Create `mydataset.py` in the `configs/_base_/datasets` directory with the following content. The values marked `#***` must be changed to your own paths:
```python
# dataset settings
dataset_type = 'MyDataset'#**************************************
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='/media/dmmm/CE31-3598/DataSets/classification_mine/train',#***************
        ann_file='/media/dmmm/CE31-3598/DataSets/classification_mine/meta/train.txt',#****************
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='/media/dmmm/CE31-3598/DataSets/classification_mine/val',#******************
        ann_file='/media/dmmm/CE31-3598/DataSets/classification_mine/meta/val.txt',#***************
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='/media/dmmm/CE31-3598/DataSets/classification_mine/val',#********************
        ann_file='/media/dmmm/CE31-3598/DataSets/classification_mine/meta/val.txt',#*******************
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='accuracy')
```
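The two `#***` entries of each split work together: every `filename` read from `ann_file` is relative, and `LoadImageFromFile` joins it onto `img_prefix`, which `MyDataset` fills from `data_prefix`. Pointing `data_prefix` at the dataset root instead of the `train`/`val` folder therefore makes every lookup fail. The sketch below checks that each listed image exists under the given prefix; `find_missing_images` is a hypothetical helper.

```python
import os


def find_missing_images(data_prefix, ann_file):
    """Hypothetical helper: list annotation entries whose image does not exist
    at os.path.join(data_prefix, filename)."""
    missing = []
    with open(ann_file) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            filename = line.split(' ')[0]
            if not os.path.isfile(os.path.join(data_prefix, filename)):
                missing.append(filename)
    return missing
```

Calling it with the `data_prefix` and `ann_file` values of each split from the config above should return an empty list.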
- If you have used mmlab code before, the setup is essentially done at this point.
5. Start training
Edit `configs/resnet/resnet18_b32x8_imagenet.py` so that it reads:
```python
_base_ = [
    '../_base_/models/resnet18.py', '../_base_/datasets/mydataset.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
```
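You can now start training with `tools/train.py`. One way to set the config is to turn the positional `config` argument into an optional `--config` whose default points at the config above; the relative path assumes the script is run from inside the `tools` directory:
```python
def parse_args():
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('--config', default="../configs/resnet/resnet18_b32x8_imagenet.py", help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
```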