mmcls in Practice

This post records how I completed a classification task with mmclassification. The steps are as follows:

Click here to view the config files and the models saved during training


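1. Runtime environment

Beijing supercomputing platform (北京超算)

2. Splitting the dataset

The code for splitting the dataset is as follows: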

import os
import sys
import shutil
import numpy as np


def load_data(data_path):
    count = 0
    data = {}
    for dir_name in os.listdir(data_path):
        dir_path = os.path.join(data_path, dir_name)
        if not os.path.isdir(dir_path):
            continue

        data[dir_name] = []
        for file_name in os.listdir(dir_path):
            file_path = os.path.join(dir_path, file_name)
            if not os.path.isfile(file_path):
                continue
            data[dir_name].append(file_path)

        count += len(data[dir_name])
        print("{} :{}".format(dir_name, len(data[dir_name])))

    print("total images: {}".format(count))
    return data



def copy_dataset(src_img_list, data_index, target_path):
    target_img_list = []
    for index in data_index:
        src_img = src_img_list[index]
        img_name = os.path.split(src_img)[-1]  # tail: the last component of the path

        shutil.copy(src_img, target_path)
        target_img_list.append(os.path.join(target_path, img_name))
    return target_img_list




def write_file(data, file_name):
    if isinstance(data, dict):  #train && val
        write_data = []
        for lab, img_list in data.items():
            for img in img_list:
                write_data.append("{} {}".format(img, lab))
    else:  # test
        write_data = data

    with open(file_name, "w") as f:
        for line in write_data:
            f.write(line + "\n")  # write one record per line

    print("{} write over!".format(file_name))



def split_data(src_data_path, target_data_path, train_rate=0.8):
    src_data_dict = load_data(src_data_path)

    classes = []
    train_dataset, val_dataset = {}, {}
    train_count, val_count = 0, 0
    for i, (cls_name, img_list) in enumerate(src_data_dict.items()):
        img_data_size = len(img_list)
        # Shuffle the indices; replace=False means no index is drawn twice
        random_index = np.random.choice(img_data_size, img_data_size, replace=False)

        train_data_size = int(img_data_size * train_rate)
        train_data_index = random_index[:train_data_size]
        val_data_index = random_index[train_data_size:]

        train_data_path = os.path.join(target_data_path, "train", cls_name)
        val_data_path = os.path.join(target_data_path, "val", cls_name)

        os.makedirs(train_data_path, exist_ok=True)  # exist_ok=True avoids FileExistsError if the directory already exists
        os.makedirs(val_data_path, exist_ok=True)


        classes.append(cls_name)
        train_dataset[i] = copy_dataset(img_list, train_data_index, train_data_path)
        val_dataset[i] = copy_dataset(img_list, val_data_index, val_data_path)

        print("target {} train:{}, val:{}".format(cls_name, len(train_dataset[i]), len(val_dataset[i])))
        train_count += len(train_dataset[i])
        val_count += len(val_dataset[i])

    print("train size:{}, val size:{}, total:{}".format(train_count, val_count, train_count + val_count))

    write_file(classes, os.path.join(target_data_path, "classes.txt"))
    write_file(train_dataset, os.path.join(target_data_path, "train.txt"))
    write_file(val_dataset, os.path.join(target_data_path, "val.txt"))



def main():
    src_data_path = sys.argv[1]
    target_data_path = sys.argv[2]
    split_data(src_data_path, target_data_path, train_rate=0.8)



if __name__ == '__main__':
    main()
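
To make the per-class split in split_data easier to check, here is a minimal sketch of the same shuffle-then-slice logic in isolation. It assumes a seeded np.random.default_rng generator so the split can be reproduced, which the script above does not do, and the helper name split_indices is hypothetical.

```python
import numpy as np


def split_indices(n_items, train_rate=0.8, seed=0):
    """Return (train_idx, val_idx) for one class with n_items images.

    Hypothetical helper: mirrors the shuffle-then-slice logic of split_data,
    but uses a seeded generator so the same split can be reproduced.
    """
    rng = np.random.default_rng(seed)
    # A permutation visits every index exactly once (no index is drawn twice).
    shuffled = rng.permutation(n_items)
    n_train = int(n_items * train_rate)
    return shuffled[:n_train], shuffled[n_train:]


train_idx, val_idx = split_indices(10, train_rate=0.8, seed=0)
print(len(train_idx), len(val_idx))  # 8 2
```

Because both slices come from one permutation, the training and validation indices never overlap and together cover every image of the class.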

The folder structure of the data after splitting is shown in the figure below:
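
Going by the script, the target folder should contain train/<class>/ and val/<class>/ sub-folders plus classes.txt, train.txt and val.txt. A minimal sketch for checking that layout, assuming a target folder produced by the script above; the helper name count_images is hypothetical:

```python
import os


def count_images(target_data_path):
    """Count files per class under train/ and val/ (hypothetical helper)."""
    counts = {}
    for subset in ("train", "val"):
        subset_dir = os.path.join(target_data_path, subset)
        if not os.path.isdir(subset_dir):
            continue  # subset folder missing, skip it
        counts[subset] = {}
        for cls_name in sorted(os.listdir(subset_dir)):
            cls_dir = os.path.join(subset_dir, cls_name)
            if os.path.isdir(cls_dir):
                # Count only regular files inside the class folder.
                counts[subset][cls_name] = sum(
                    os.path.isfile(os.path.join(cls_dir, name))
                    for name in os.listdir(cls_dir)
                )
    return counts
```

Calling count_images on the target folder returns a nested dict of the form {"train": {class: n}, "val": {class: n}}, which can be compared against the counts printed by split_data.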

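3. Config files

One advantage of OpenMMLab is that a config can inherit from the base configs in configs/_base_, so any of the ImageNet-pretrained models can be reused. Compare against the files in the _base_ folder and override only the settings you need to change:

  • configs/_base_/models: model settings
  • configs/_base_/datasets: dataset settings (dataset paths, etc.)
  • configs/_base_/schedules: training settings (learning rate, optimizer, etc.)
  • configs/_base_/default_runtime.py: settings for saving weights (checkpoints)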

3.1 Downloading the pretrained weights

3.2 Fine-tuning

Use tools/train.py to fine-tune the model:

python tools/train.py configs/resnet/resnet18_b16_flower.py --work-dir work_dirs/flower

4. Start training

I trained on a single node with multiple GPUs. The job script is as follows:
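
As a rough sketch of what a single-node multi-GPU job would typically invoke, the snippet below assembles a launch command. It assumes the tools/dist_train.sh launcher that ships with mmclassification (arguments: config, GPU count, then extra train.py options); the helper name build_dist_train_cmd is hypothetical and the GPU count of 4 is a placeholder, not the author's setting.

```python
def build_dist_train_cmd(config, num_gpus, work_dir):
    """Build the argument list for tools/dist_train.sh (hypothetical helper)."""
    return [
        "bash", "tools/dist_train.sh",
        config,                  # config file to train with
        str(num_gpus),           # number of GPUs on this node (placeholder)
        "--work-dir", work_dir,  # forwarded to tools/train.py
    ]


cmd = build_dist_train_cmd(
    "configs/resnet/resnet18_b16_flower.py",  # config used in section 3.2
    4,                                        # placeholder GPU count
    "work_dirs/flower",
)
print(" ".join(cmd))
# bash tools/dist_train.sh configs/resnet/resnet18_b16_flower.py 4 --work-dir work_dirs/flower
```

To launch from Python instead of a shell script, the list can be passed to subprocess.run(cmd, check=True) from the mmclassification root directory.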

4.1 Training results

5. Summary

That is everything for this classification task.
