mmclassification自定义数据集并训练

mmclassification自定义数据集并训练

本文手把手实现mmclassification框架的自定义数据集的导入和训练,不对mmclassification安装做解释,阅读者自行安装mmclassification。

1.准备数据集

首先准备好数据集,并搞成如下的文件结构:

​```
imagenet
├── meta
|	├── classmap.txt
├── train
│   ├── class1
│   │   ├── 026.JPEG
│   │   ├── ...
│   ├── class2
│   │   ├── 999.JPEG
│   │   ├── ...
│   ├── ...
├── val
│   ├── class1
│   │   ├── 0027.JPEG
│   │   ├── ...
│   ├── class2
│   │   ├── 993.JPEG
│   │   ├── ...
│   ├── ...
​```
  • 其中classmap.txt文件需要写入如下内容:(空格隔开 class1和class2需要与train和val文件夹中的class1和class2对应)
class1 dog 0
class2 cat 1

2.生成txt文件

生成txt文件用于导入mmclassification

import os
import glob
import re

# 生成train.txt和val.txt

#需要改为您自己的路径
root_dir = "/media/dmmm/CE31-3598/DataSets/classification_mine"
#在该路径下有train,val,meta三个文件夹
train_dir = os.path.join(root_dir, "train")
val_dir = os.path.join(root_dir, "val")
meta_dir = os.path.join(root_dir, "meta")

def generate_txt(images_dir,map_dict):
    # 读取所有文件名
    imgs_dirs = glob.glob(images_dir+"/*/*")
    # 打开写入文件
    typename = images_dir.split("/")[-1]
    target_txt_path = os.path.join(meta_dir,typename+".txt")
    f = open(target_txt_path,"w")
    # 遍历所有图片名
    for img_dir in imgs_dirs:
        # 获取第一级目录名称
        filename = img_dir.split("/")[-2]
        num = map_dict[filename]
        # 写入文件
        relate_name = re.findall(typename+"/([\w / - .]*)",img_dir)
        f.write(relate_name[0]+" "+num+"\n")

def get_map_dict():
    # 读取所有类别映射关系
    class_map_dict = {}
    with open(os.path.join(meta_dir,"classmap.txt"),"r") as F:
        lines = F.readlines()
        for line in lines:
            line = line.split("\n")[0]
            filename,cls,num = line.split(" ")
            class_map_dict[filename] = num
    return class_map_dict

if __name__ == '__main__':

    class_map_dict = get_map_dict()

    generate_txt(images_dir=train_dir,map_dict=class_map_dict)

    generate_txt(images_dir=val_dir,map_dict=class_map_dict)

运行结束后会在meta文件夹中生成train.txtval.txt,用于导入到mmclassification中,内容如下所示(以train为例,val也是一样的)

class1/026.JPEG 0
class2/999.JPEG 1

3.修改mmclassification代码

mmcls/datasets目录下新建py文件(名字自取,以mydataset.py为例),写入内容如下:(#****对应自己的类别)

import numpy as np

from .builder import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class MyDataset(BaseDataset):
	CLASSES = ["dog","cat"]#***********************************
    def load_annotations(self):
        assert isinstance(self.ann_file, str)

        data_infos = []
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, gt_label in samples:
                info = {'img_prefix': self.data_prefix}
                info['img_info'] = {'filename': filename}
                info['gt_label'] = np.array(gt_label, dtype=np.int64)
                data_infos.append(info)
            return data_infos

mmcls/datasets目录下修改__init__.py文件,添加内容如下:

from .mydataset import MyDataset

__all__ = [
    #增加MyDataset这一项
    'MyDataset'
]

4.修改configs文件

configs/_base_/datasets目录下新建mydataset.py文件,写入内容如下:(#***的内容是需要您自行修改为自己的路径,聪明的你肯定知道怎么改)

# dataset settings
dataset_type = 'MyDataset'#**************************************
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='/media/dmmm/CE31-3598/DataSets/classification_mine/train',#***************
        ann_file='/media/dmmm/CE31-3598/DataSets/classification_mine/meta/train.txt',#****************
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='/media/dmmm/CE31-3598/DataSets/classification_mine/val',#******************
        ann_file='/media/dmmm/CE31-3598/DataSets/classification_mine/meta/val.txt',#***************
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='/media/dmmm/CE31-3598/DataSets/classification_mine/val',#********************
        ann_file='/media/dmmm/CE31-3598/DataSets/classification_mine/meta/val.txt',#*******************
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='accuracy')
  • 如果您使用过mmlab的代码,这边结束您应该已经ok了。

5.开始训练

configs/resnet/resnet18_b32x8_imagenet.py,修改为如下内容:

_base_ = [
    '../_base_/models/resnet18.py', '../_base_/datasets/mydataset.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]

下面就可以在tools/train中修改config文件进行训练:

def parse_args():
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('--config',default="../configs/resnet/resnet18_b32x8_imagenet.py", help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
  • 15
    点赞
  • 62
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
### 回答1: mmclassification是一个深度学习框架,主要用于图像分类任务。在此框架中,有一组名为猫狗数据集的图像数据集。该数据集包含25,000张猫和狗的图片,分别来自不同的种类,其中12,500张是猫的图片,12,500张是狗的图片。 这个数据集是一个很好的基础数据集,用于训练图像分类模型。它可以用来评估不同的深度学习算法在包含多种类别的图像分类任务中的效果。同时,该数据集的图像质量较高,与实际场景更为接近,因此训练出来的模型具有较高的实用价值。 在使用mmclassification的猫狗数据集进行训练时,可以采用各种深度学习模型进行训练,并通过交叉验证等方式评估不同模型的效果。此外,可以对图像进行预处理以提高训练效果,比如对图像进行剪切、旋转、缩放等操作。在训练过程中,还可以使用分布式训练等技术,加快模型训练的速度。 总之,mmclassification猫狗数据集是一个常用的图像分类数据集,可以用于训练和评估各种深度学习模型,在实际应用中具有广泛的应用和推广价值。 ### 回答2: mmclassification猫狗数据集是一个用于图像分类任务的数据集,其中包含有大量的猫和狗的图像。这个数据集可以被广泛应用于机器学习算法的训练和测试中。 使用mmclassification猫狗数据集,我们可以训练一个分类器来识别一张图片中是猫还是狗。这个任务涉及到图像预处理、特征提取和模型训练等很多方面,需要综合运用图像处理、机器学习和深度学习等多个领域的知识和技术。 对于这个数据集,我们需要预处理数据,包括图像的大小和颜色等方面。然后使用现有的深度学习算法或自行设计模型来提取图像特征和训练模型。最后使用测试数据集来评估模型的准确性。 通过使用这个数据集进行训练和测试,我们可以得到一个高准确率的分类器,它可以成功地识别一张图片中是猫还是狗,并且能够适应不同场景和环境的变化。同时,这个数据集也能够促进机器学习和深度学习技术的发展和应用。
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值