这里记录一下作业的流程以及心得。
1. 安装依赖库
主要根据 mmclassification 的官方文档,依次安装了 PyTorch、mmcv 以及 mmclassification,具体步骤可参考 mmclassification 的官方安装文档。
2. 下载数据集
从作业中给出的数据集官网下载数据集,然后在 mmclassification 根目录下新建 data 文件夹,将压缩包放入其中并解压,得到 data/flower_dataset 目录(其中每个类别对应一个子文件夹)。
3. 对数据集进行分割,并生成.txt文件
话不多说,直接上代码
import os
from shutil import copy, rmtree
import random
def make_file(file_path: str):
    """Make sure *file_path* exists as a brand-new, empty directory.

    Whatever is already at that path is removed first (recursively), so the
    caller always gets a clean folder.
    """
    target_dir = file_path
    stale = os.path.exists(target_dir)
    if stale:
        rmtree(target_dir)  # drop the previous run's contents
    os.makedirs(target_dir)
# 保证随机可复现
random.seed(0)#保证每次随机抽取的都可以复现
# 将数据集中20%的数据划分到验证集中
split_rate = 0.2 #这里填多少 就是验证集的比例是多少,比如填0.1就是验证集的数量占总数据集的10%
data_path = './data/flower_dataset'#数据集存放的地方,建议在程序所在的文件夹下新建一个data文件夹,将需要划分的数据集存放进去
data_root = './data/flower_dataset' #这里是生成的训练集和验证集所处的位置,这里设置的是在当前文件夹下。
data_class = [cla for cla in os.listdir(data_path)]
print("数据的种类分别为:")
print(data_class)# 输出数据种类,数据种类默认为读取的文件夹的名称
# 建立保存训练集的文件夹
train_data_root = os.path.join(data_root, "train") #训练集的文件夹名称为 train
make_file(train_data_root)
for num_class in data_class:
# 建立每个类别对应的文件夹
make_file(os.path.join(train_data_root, num_class))
# 建立保存验证集的文件夹
val_data_root = os.path.join(data_root, "val")#验证集的文件夹名称为 val
make_file(val_data_root)
for num_class in data_class:
# 建立每个类别对应的文件夹
make_file(os.path.join(val_data_root, num_class))
for num_class in data_class:
num_class_path = os.path.join(data_path, num_class)
images = os.listdir(num_class_path)
num = len(images)
val_index = random.sample(images, k=int(num*split_rate)) #随机抽取图片
for index, image in enumerate(images):
if image in val_index:
# 将划分到验证集中的文件复制到相应目录
data_image_path = os.path.join(num_class_path, image)
val_new_path = os.path.join(val_data_root, num_class)
copy(data_image_path, val_new_path)
try:
eval_f = open("./data/flower_dataset/val.txt", "a+")
eval_f.write(str(num_class)+"/"+str(image))
eval_f.write(" ")
if num_class == "daisy":
eval_f.write("0")
elif num_class == "dandelion":
eval_f.write("1")
elif num_class == "rose":
eval_f.write("2")
elif num_class == "sunflower":
eval_f.write("3")
else:
eval_f.write("4")
eval_f.write("\n")
except FileExistsError as e:
print(e)
exit(1)
else:
# 将划分到训练集中的文件复制到相应目录
data_image_path = os.path.join(num_class_path, image)
train_new_path = os.path.join(train_data_root, num_class)
copy(data_image_path, train_new_path)
try:
train_f = open("./data/flower_dataset/train.txt", "a+")
train_f.write(str(num_class)+"/"+str(image))
train_f.write(" ")
if num_class == "daisy":
train_f.write("0")
elif num_class == "dandelion":
train_f.write("1")
elif num_class == "rose":
train_f.write("2")
elif num_class == "sunflower":
train_f.write("3")
else:
train_f.write("4")
train_f.write("\n")
except FileExistsError as e:
print(e)
exit(1)
print("\r[{}] split_rating [{}/{}]".format(num_class, index+1, num), end="") # processing bar
print()
print(" ")
print(" ")
print("划分完成")
4. 生成classes.txt文件
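因为只有 5 个类别,所以直接手写了该文件。每个类别后面的编号需与划分脚本写入 train.txt/val.txt 的标签保持一致(daisy 0 … tulip 4)。注意:mmcls 读取类别文件时会把每一行整体当作一个类别名,若要在配置中通过 classes 引用该文件,每行应只写类别名。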
daisy 0
dandelion 1
rose 2
sunflower 3
tulip 4
5. 修改配置文件
修改训练集和验证集的数据路径、数据集标注列表(ann_file)以及类别名文件的路径;此外,将评估指标改为仅计算 top-1 准确率(metric='accuracy',topk=(1, ))。
# MobileNetV2 fine-tuned on the 5-class flower dataset (mmcls 0.x-style config).
model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=5,  # daisy, dandelion, rose, sunflower, tulip
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        # With only 5 classes, top-5 accuracy is trivially 100 %.
        topk=(1, 5)))
# ImageNet-pretrained MobileNetV2 checkpoint; adjust to your own checkout.
load_from = '/home/xilm/python_file/mmclassification-master/resources/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth'

dataset_type = 'CustomDataset'
# Folder produced by the split script: train/, val/, train.txt, val.txt.
# Change this single path when moving the project.
data_root = '/home/xilm/python_file/mmclassification-master/data/flower_dataset'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224, backend='pillow'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1), backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
# The dataset entries reuse the pipelines above instead of inline copies, so
# editing a pipeline in one place cannot leave a stale duplicate behind.
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix=data_root + '/train',
        # Use the generated list (entries are '<class>/<image> <label>',
        # relative to data_prefix) so training labels are exactly the ones
        # the split script wrote, just like val below.
        ann_file=data_root + '/train.txt',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix=data_root + '/val',
        ann_file=data_root + '/val.txt',
        pipeline=test_pipeline),
    # No separate test split: testing reuses the validation set.
    test=dict(
        type=dataset_type,
        data_prefix=data_root + '/val',
        ann_file=data_root + '/val.txt',
        pipeline=test_pipeline))
# Evaluate after every epoch, reporting only top-1 accuracy.
evaluation = dict(
    interval=1, metric='accuracy', metric_options=dict(topk=(1, )))
optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=4e-05)
optimizer_config = dict(grad_clip=None)
# Step policy with step=1, gamma=0.5: the learning rate is halved every epoch.
lr_config = dict(policy='step', gamma=0.5, step=1)
runner = dict(type='EpochBasedRunner', max_epochs=10)
checkpoint_config = dict(interval=5)
log_config = dict(interval=10, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
resume_from = None
workflow = [('train', 1)]
work_dir = './work_dirs/mobilenet-v2_flower'
gpu_ids = range(0, 1)
6. 训练
在终端执行mim train mmcls [配置文件名]
以上就是完整的流程了!