【图像分类】mmclassification 安装、准备数据、训练、可视化


在这里插入图片描述

MMclassification 是一个分类工具库,这篇文章是简单记录一下如何用该工具库来训练自己的分类模型,包括数据准备,模型修改,模型训练,模型测试等等。

MMclassification链接:https://github.com/open-mmlab/mmclassification

安装:https://mmclassification.readthedocs.io/en/latest/install.html

训练:https://mmclassification.readthedocs.io/en/latest/getting_started.html

一、环境配置

  • 配置 /etc/apt/sources.list 为阿里云的源
  • 配置 /etc/resolv.conf 为 nameserver 114.114.114.114
sudo apt install software-properties-common -y
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt update && sudo apt upgrade -y

1.1 安装 conda

apt-get install libgl1-mesa-glx libegl1-mesa libxrandr2 libxrandr2 libxss1 libxcursor1 libxcomposite1 libasound2 libxi6 libxtst6
wget https://repo.continuum.io/archive/Anaconda3-2020.11-Linux-x86_64.sh

source ~/.bashrc # source路径生效

(base) root@k8s-master-133:/home/y# conda -V
conda 4.9.2

# 配置python
conda create -n mmcls python=3.8 -y
conda activate mmcls

1.2 安装cuda

https://blog.csdn.net/zhouchen1998/article/details/107778087
https://developer.nvidia.com/cuda-10.2-download-archive

wget https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run
sudo sh cuda_10.2.89_440.33.01_linux.run

最后只要nvidia-smi能看到就ok了

在这里插入图片描述

1.3 安装 pytorch

conda install pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=cuda版本 -c pytorch

1.4 工程准备

cd mmclassification(把给你的文件夹名字改成mmclassification,然后进入文件夹)
pip install -e .

二、数据准备

MMclassification 支持 ImageNet 和 cifar 两种数据格式,我们以 ImageNet 为例来看看数据结构:

|- imagenet
|    |- classmap.txt
|    |- train
|    |	 |- cls1
|    |	 |- cls2
|    |	 |- cls3
|    |	 |- ...
|    |- train.txt
|    |- val
|    |	 |- images
|    |- val.txt

假设我们要训练一个猫狗二分类模型,则需要组织的形式如下:

|- dog_cat_dataset
|    |- classmap.txt
|    |- train
|    |	 |- dog
|    |	 |- cat
|    |- train.txt
|    |- val
|    |	 |- images
|    |- val.txt

其中,classmap.txt 中的内容如下:

dog 0
cat 1

三、模型修改

假设使用 resnet18 来训练,则我们需要修改的内容主要集中在 config 文件里边,修改后的config文件 resnet18_b32x8_dog_cat_cls.py 如下:

  • 修改类别:将 1000 类改为 2 类
  • 修改数据路径:data
  • 如果数据前处理需要修改的话,也可以在config里边修改
  • 因为config是最高级的,所以在这里修改后会覆盖模型从mmcls库中读出来的参数
_base_ = [
    '../_base_/models/resnet18.py', '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(
    head=dict(
        type='LinearClsHead',
        num_classes=2,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, ),
    ))

data = dict(
    samples_per_gpu=32,
    workers_per_gpu=1,
    train=dict(
        data_prefix='data/dog_cat_dataset/train',
        ann_file='data/dog_cat_dataset/train.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    val=dict(
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        data_prefix='data/dog_cat_dataset/val',
        ann_file='data/dog_cat_dataset/val.txt',
        classes='data/dog_cat_dataset/classmap.txt'))
evaluation = dict(interval=1, metric='accuracy', metric_options={'topk': (1, )})

四、模型训练

python tools/train.py configs/resnet/resnet18_b32x8_dog_cat_cls.py

在这里插入图片描述

五、模型效果可视化

python tools/test.py configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --out result.pkl --show-dir output_cls

使用 gradcam 可视化:

python tools/visualizations/vis_cam.py visual_img/4.jpg configs/resnet/resnet18_b32x8_door.py  ./models/epoch_99.pth --s
ave-path visual_path/4.jpg

六、如何分别计算每个类别的精确率和召回率

先进行测试,得到 result.pkl 文件,然后运行下面的程序即可:

python tools/cal_precision.py configs/resnet/resnet18_b32x8_imagenet.py
import mmcv
import argparse
from mmcls.datasets import build_dataset
from mmcls.core.evaluation import calculate_confusion_matrix
from sklearn.metrics import confusion_matrix

def parse_args():
    parser = argparse.ArgumentParser(description='calculate precision and recall for each class')
    parser.add_argument('config', help='test config file path')
    args = parser.parse_args()
    return args

def main():
    args = parse_args()
    cfg = mmcv.Config.fromfile(args.config)
    dataset = build_dataset(cfg.data.test)
    pred = mmcv.load("./result.pkl")['pred_label']
    matrix = confusion_matrix(pred, dataset.get_gt_labels())
    print('confusion_matrix:', matrix)
    cat_recall = matrix[0,0]/(matrix[0,0]+matrix[1,0])
    dog_recall = matrix[1,1]/(matrix[0,1]+matrix[1,1])
    cat_precision = matrix[0,0]/sum(matrix[0])
    dog_precision = matrix[1,1]/sum(matrix[1])
    print(' cat_precision:{} \n dog_precison:{} \n cat_recall:{} \n dog_recall:{}'.format(cat_precision, dog_precison, cat_recall, dog_recall))

if __name__ == '__main__':
    main()
  • 7
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
可视化mmclassification的结果,你可以按照以下步骤进行操作: 1. 首先,你需要下载mmclassification的代码和模型。你可以通过引用中的下载链接来获取mmclassification的代码。该链接指向了mmclassification在GitHub上的仓库。 2. 如果你已经训练好了自己的模型,并且想要使用自己的模型进行可视化,那么你可以将模型文件放在mmclassification的默认模型目录中。该目录的路径可以在引用中找到。 3. 如果你想要使用已经训练好的模型进行可视化,你可以关注阿旭算法与机器学习公众号,并回复【mmlab实战1】,即可获取已经下载好的mmclassification源码和demo训练用的数据数据文件已经放置在mmcls/data目录中,这些数据可以用于可视化结果的展示。你可以在引用中找到更多关于如何获取这些数据的信息。 4. 一旦你准备好了代码和数据,你可以使用mmclassification提供的可视化工具来展示模型的结果。具体的可视化方法可以在mmclassification的文档或代码中找到。你可以通过查看mmclassification的GitHub仓库或者阅读相关的文档来了解如何使用这些可视化工具。 5. 最后,你可以根据你的需要选择不同的可视化方式,比如绘制混淆矩阵、生成分类报告或绘制类别概率分布图等。这些方法都可以帮助你更好地理解和展示mmclassification的结果。 请注意,以上步骤仅仅是为了帮助你开始可视化mmclassification的结果,并非详尽的操作指南。具体的操作步骤可能会因为你的实际需求和环境而有所不同。为了获得更详细的信息和指导,请参考mmclassification的文档和代码。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [【超详细】MMLab分类任务mmclassification:环境配置说明、训练、预测及模型结果可视化展示](https://blog.csdn.net/qq_42589613/article/details/129630044)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

呆呆的猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值