[DAY3]图像分类工具包 MMClassification

代码仓库:https://github.com/open-mmlab/mmclassification

文档教程:https://mmclassification.readthedocs.io/en/latest/

一、安装

1.创建环境

  • 加载 anaconda ,创建一个 python 3.8 的环境。

# 创建 python=3.8 的环境
conda create --name mmclassification python=3.8
​
# 激活环境
conda activate mmclassification
  • 安装torch

查看已安装的cuda版本

nvcc --version
# Cuda compilation tools, release 11.3, V11.3.58

根据已有的cuda版本安装torch

pytorch下载地址:

https://pytorch.org/get-started/previous-versions/

# CUDA 11.3
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
  • 安装 mmcv-full 模块,mmcv-full 模块安装时候需要注意 torch 和 cuda 版本。

参考安装 MMCV — mmcv 1.7.1 文档

在安装 mmcv-full 之前,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch 官方安装文档。可使用以下命令验证:

python -c 'import torch;print(torch.__version__)'

使用 mim 安装

# 安装 mim
pip install -U openmim
# 安装基础库 mmcv 完整版
mim install mmcv-full  
  • 安装mmdetection

# 源码安装 mmdet
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
mim install -v -e .
  • 安装mmclassification

# 源码安装 mmcls
cd ..
git clone https://github.com/open-mmlab/mmclassification.git
cd mmclassification
mim install -v -e .

2.验证安装

为了验证 MMClassification 的安装是否正确,我们提供了一些示例代码来执行模型推理。

  • 第 1 步 我们需要下载配置文件和模型权重文件

mim download mmcls --config resnet50_8xb32_in1k --dest .
  • 第 2 步 验证示例的推理流程

如果你是从源码安装的 mmcls,那么直接运行以下命令进行验证:

python demo/image_demo.py demo/demo.JPEG resnet50_8xb32_in1k.py resnet50_8xb32_in1k_20210831-ea4938fc.pth --device cpu

你可以看到命令行中输出了结果字典,包括 pred_label,pred_score 和 pred_class 三个字段。

3.数据集

  • flower 数据集包含 5 种类别的花卉图像:雏菊 daisy 588张,蒲公英 dandelion 556张,玫瑰 rose 583张,向日葵 sunflower 536张,郁金香 tulip 585张。

3.1划分数据集

  • 将数据集按照 8:2 的比例划分成训练和验证子数据集,并将数据集整理成 ImageNet的格式

  • 将训练子集和验证子集放到 train 和 val 文件夹下。

  • 创建并编辑标注文件将所有类别的名称写到 classes.txt 中,每行代表一个类别。

tulip
dandelion
daisy
sunflower
rose
  • 生成训练(可选)和验证子集标注列表 train.txt 和 val.txt ,每行应包含一个文件名和其对应的标签。如下,可将处理好的数据集迁移到 mmclassification/data 文件夹下。

  • 数据集划分代码 split_data.py 如下,执行:

python split_data.py [源数据集路径] [目标数据集路径]
import os
import sys
import shutil
import numpy as np


def load_data(data_path):
    count = 0
    data = {}
    for dir_name in os.listdir(data_path):
        dir_path = os.path.join(data_path, dir_name)
        if not os.path.isdir(dir_path):
            continue

        data[dir_name] = []
        for file_name in os.listdir(dir_path):
            file_path = os.path.join(dir_path, file_name)
            if not os.path.isfile(file_path):
                continue
            data[dir_name].append(file_path)

        count += len(data[dir_name])
        print("{} :{}".format(dir_name, len(data[dir_name])))

    print("total of image : {}".format(count))
    return data


def copy_dataset(src_img_list, data_index, target_path):
    target_img_list = []
    for index in data_index:
        src_img = src_img_list[index]
        img_name = os.path.split(src_img)[-1]

        shutil.copy(src_img, target_path)
        target_img_list.append(os.path.join(target_path, img_name))
    return target_img_list
    
    
def write_file(data, file_name):
    if isinstance(data, dict):
        write_data = []
        for lab, img_list in data.items():
            for img in img_list:
                write_data.append("{} {}".format(img, lab))
    else:
        write_data = data

    with open(file_name, "w") as f:
        for line in write_data:
            f.write(line + "\n")

    print("{} write over!".format(file_name))
    
    
def split_data(src_data_path, target_data_path, train_rate=0.8):
    src_data_dict = load_data(src_data_path)

    classes = []
    train_dataset, val_dataset = {}, {}
    train_count, val_count = 0, 0
    for i, (cls_name, img_list) in enumerate(src_data_dict.items()):
        img_data_size = len(img_list)
        random_index = np.random.choice(img_data_size, img_data_size,replace=False)

        train_data_size = int(img_data_size * train_rate)
        train_data_index = random_index[:train_data_size]
        val_data_index = random_index[train_data_size:]

        train_data_path = os.path.join(target_data_path, "train", cls_name)
        val_data_path = os.path.join(target_data_path, "val", cls_name)
        os.makedirs(train_data_path, exist_ok=True)
        os.makedirs(val_data_path, exist_ok=True)

        classes.append(cls_name)
        train_dataset[i] = copy_dataset(img_list, train_data_index,train_data_path)
        val_dataset[i] = copy_dataset(img_list, val_data_index, val_data_path)

        print("target {} train:{}, val:{}".format(cls_name,len(train_dataset[i]), len(val_dataset[i])))

        train_count += len(train_dataset[i])
        val_count += len(val_dataset[i])

    print("train size:{}, val size:{}, total:{}".format(train_count, val_count, train_count + val_count))

    write_file(classes, os.path.join(target_data_path, "classes.txt"))
    write_file(train_dataset, os.path.join(target_data_path, "train.txt"))
    write_file(val_dataset, os.path.join(target_data_path, "val.txt"))
    
    
def main():
    src_data_path = sys.argv[1]
    target_data_path = sys.argv[2]
    split_data(src_data_path, target_data_path, train_rate=0.8)


if __name__ == '__main__':
    main()

二、代码教学

pytorch官方demo:

https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

OpenMMLab 项目中的重要概念——配置文件

作业

https://github.com/open-mmlab/OpenMMLabCamp/issues/1

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值