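1. Overview
This project is based on PaddleClas and walks through training, evaluating, and running prediction with a multi-label classifier. PaddleClas provides a multi-label classification example only for MobileNetV1, so this project adapts that MobileNetV1 setup to the SwinTransformer model.
2. Dataset
The project uses a subset of NUS-WIDE-SCENE. Each image is classified against 33 labels, and an image may carry several labels at once.
- Download link for this subset: https://paddle-imagenet-models-name.bj.bcebos.com/data/NUS-SCENE-dataset.tar
- Download link for the full NUS-WIDE-SCENE dataset: https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html
The labels are: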
airport
beach
bridge
buildings
castle
cityscape
clouds
frost
garden
glacier
grass
harbor
house
lake
moon
mountain
nighttime
ocean
plants
railroad
rainbow
reflection
road
sky
snow
street
sunset
temple
town
valley
water
waterfall
window
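II. PaddleClas Installation
1. Download PaddleClas
# PaddleClas can be downloaded directly from the AI Studio toolkit, which is faster; v2.4.0 is used here.
2. Install PaddleClas
This step installs the required dependencies and then PaddleClas itself.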
!pip install -r ~/PaddleClas-2.4.0/requirements.txt >log.log
!pip install -e ~/PaddleClas-2.4.0 >log.log
III. Dataset Preparation
1. Extract the data
This step downloads and extracts the dataset.
%cd ~/PaddleClas-2.4.0
!mkdir -p dataset/NUS-WIDE-SCENE
%cd dataset/NUS-WIDE-SCENE
!wget https://paddle-imagenet-models-name.bj.bcebos.com/data/NUS-SCENE-dataset.tar
!tar -xf NUS-SCENE-dataset.tar
2. Inspect the data
In the list file, the first column is the image file name. The 33 values after it are 0/1 flags for the labels listed earlier (airport through window): 1 means the label applies to the image, 0 means it does not.
!head NUS-SCENE-dataset/multilabel_train_list.txt
from PIL import Image
%cd ~
img=Image.open("PaddleClas-2.4.0/dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/0006_2074187535.jpg")
img.show()
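To make the 0/1 flags easier to read, the sketch below decodes one line of the list file into label names. It assumes each line is the image file name, a tab, and then 33 comma-separated flags in the same order as the label list in the dataset section. The helper name `decode_line` and the sample line are made up for illustration and are not taken from the dataset.

```python
# Label names in the order given in the dataset section
# (assumed to match the order of the flags in the list file).
LABELS = [
    "airport", "beach", "bridge", "buildings", "castle", "cityscape",
    "clouds", "frost", "garden", "glacier", "grass", "harbor", "house",
    "lake", "moon", "mountain", "nighttime", "ocean", "plants", "railroad",
    "rainbow", "reflection", "road", "sky", "snow", "street", "sunset",
    "temple", "town", "valley", "water", "waterfall", "window",
]


def decode_line(line):
    """Split one list-file line into (file_name, [names of labels flagged 1])."""
    # Assumed layout: "<file name>\t<flag>,<flag>,...".
    file_name, flags_text = line.rstrip("\n").split("\t")
    flags = [int(flag) for flag in flags_text.split(",")]
    if len(flags) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} flags, got {len(flags)}")
    return file_name, [name for name, flag in zip(LABELS, flags) if flag == 1]


# Made-up sample line with the flags for "clouds" and "sky" set.
sample_flags = ["0"] * len(LABELS)
sample_flags[LABELS.index("clouds")] = "1"
sample_flags[LABELS.index("sky")] = "1"
sample = "example.jpg\t" + ",".join(sample_flags)

print(decode_line(sample))
```

If the real file uses a different separator, only the two `split` calls need to change.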
IV. Model Training
1. Training configuration
Replace the contents of the config file PaddleClas-2.4.0/ppcls/configs/ImageNet/SwinTransformer/SwinTransformer_tiny_patch4_window7_224.yaml with the following:
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 5
  eval_during_train: True
  eval_interval: 5
  epochs: 300
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False
  use_multilabel: True

# model architecture
Arch:
  name: SwinTransformer_tiny_patch4_window7_224
  class_num: 33
  pretrained: True

# loss function config for traing/eval process
Loss:
  Train:
    - MultiLabelLoss:
        weight: 1.0
  Eval:
    - MultiLabelLoss:
        weight: 1.0

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  epsilon: 1e-8
  weight_decay: 0.05
  no_weight_decay_name: absolute_pos_embed relative_position_bias_table .bias norm
  one_dim_param_no_weight_decay: True
  lr:
    # for 8 cards
    name: Cosine
    learning_rate: 1e-3
    eta_min: 2e-5
    warmup_epoch: 20
    warmup_start_lr: 2e-6

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: MultiLabelDataset
      image_root: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/
      cls_label_path: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
            interpolation: bicubic
            backend: pil
        - RandFlipImage:
            flip_code: 1
        - TimmAutoAugment:
            config_str: rand-m9-mstd0.5-inc1
            interpolation: bicubic
            img_size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.25
            sl: 0.02
            sh: 1.0/3.0
            r1: 0.3
            attempt: 10
            use_log_aspect: True
            mode: pixel
      # batch_transform_ops:
      #   - OpSampler:
      #       MixupOperator:
      #         alpha: 0.8
      #         prob: 0.5
      #       CutmixOperator:
      #         alpha: 1.0
      #         prob: 0.5
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: MultiLabelDataset
      image_root: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/
      cls_label_path: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_test_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: ./deploy/images/0517_2715693311.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: MultiLabelTopk
    topk: 5
    class_id_map_file: None

Metric:
  Train:
    - HammingDistance:
    - AccuracyScore:
  Eval:
    - HammingDistance:
    - AccuracyScore:
!cp /home/aistudio/SwinTransformer_tiny_patch4_window7_224.yaml ~/PaddleClas-2.4.0/ppcls/configs/ImageNet/SwinTransformer/SwinTransformer_tiny_patch4_window7_224.yaml
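To get a feel for the learning-rate settings in this config, the sketch below evaluates a linear warmup followed by cosine decay, one value per epoch. It is a simplified approximation under stated assumptions: the numbers come from the config, but the function name `lr_at_epoch` and the per-epoch formula are illustrative, and the actual scheduler may update per iteration rather than per epoch.

```python
import math

# Values copied from the Optimizer.lr and Global sections of the config.
EPOCHS = 300
BASE_LR = 1e-3
ETA_MIN = 2e-5
WARMUP_EPOCHS = 20
WARMUP_START_LR = 2e-6


def lr_at_epoch(epoch):
    """Approximate learning rate at the start of a 0-based epoch."""
    if epoch < WARMUP_EPOCHS:
        # Linear ramp from warmup_start_lr up to learning_rate.
        fraction = epoch / WARMUP_EPOCHS
        return WARMUP_START_LR + (BASE_LR - WARMUP_START_LR) * fraction
    # Cosine decay from learning_rate down to eta_min over the remaining epochs.
    progress = (epoch - WARMUP_EPOCHS) / (EPOCHS - WARMUP_EPOCHS)
    return ETA_MIN + 0.5 * (BASE_LR - ETA_MIN) * (1 + math.cos(math.pi * progress))


for epoch in (0, 10, 20, 160, 299):
    print(epoch, f"{lr_at_epoch(epoch):.6f}")
```

Note that the config comment says the learning rate was chosen for 8 cards, so a smaller value may be more appropriate when training on a single GPU.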
2. Start training
%cd ~/PaddleClas-2.4.0/
!python3 tools/train.py \
-c ./ppcls/configs/ImageNet/SwinTransformer/SwinTransformer_tiny_patch4_window7_224.yaml
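Single-label training applies a softmax across all classes, whereas multi-label training treats each of the 33 labels as an independent yes/no decision. The sketch below shows that idea in plain Python as a sigmoid followed by binary cross-entropy averaged over the labels. It illustrates the principle behind a loss such as MultiLabelLoss and is not the PaddleClas implementation; the function names and sample numbers are made up.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def multilabel_bce(logits, targets, eps=1e-12):
    """Mean binary cross-entropy over the labels of one sample."""
    total = 0.0
    for logit, target in zip(logits, targets):
        # Clamp the probability so log() never receives exactly 0.
        p = min(max(sigmoid(logit), eps), 1.0 - eps)
        total += -(target * math.log(p) + (1 - target) * math.log(1.0 - p))
    return total / len(logits)


# Three labels: the first and third apply to the image, the second does not.
targets = [1, 0, 1]
good_logits = [3.0, -3.0, 2.0]   # scores that agree with the targets
bad_logits = [-3.0, 3.0, -2.0]   # scores that contradict the targets

print("loss when predictions agree:   ", multilabel_bce(good_logits, targets))
print("loss when predictions disagree:", multilabel_bce(bad_logits, targets))
```

Because every label has its own sigmoid, the scores do not have to sum to 1, which is what allows several labels to be active for the same image.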
V. Model Evaluation
!python3 tools/eval.py \
-c ./ppcls/configs/ImageNet/SwinTransformer/SwinTransformer_tiny_patch4_window7_224.yaml \
-o Arch.pretrained="./output/SwinTransformer_tiny_patch4_window7_224/latest"
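The two metrics named in the config compare thresholded predictions with the ground-truth flags. As a rough illustration, the sketch below assumes a 0.5 threshold on the sigmoid scores and counts every label slot separately: the Hamming distance is the fraction of slots predicted wrongly, and a label-wise accuracy is the fraction predicted correctly. The helper names are hypothetical and the PaddleClas metrics may differ in detail.

```python
def binarize(scores, threshold=0.5):
    """Turn per-label scores into 0/1 predictions."""
    return [[1 if s >= threshold else 0 for s in row] for row in scores]


def hamming_distance(targets, preds):
    """Fraction of (sample, label) slots where prediction and target differ."""
    mismatches = sum(
        t != p
        for target_row, pred_row in zip(targets, preds)
        for t, p in zip(target_row, pred_row)
    )
    total = sum(len(row) for row in targets)
    return mismatches / total


def labelwise_accuracy(targets, preds):
    """Fraction of (sample, label) slots predicted correctly."""
    return 1.0 - hamming_distance(targets, preds)


# Two made-up samples with three labels each.
targets = [[1, 0, 1], [0, 1, 0]]
scores = [[0.9, 0.2, 0.4], [0.1, 0.8, 0.7]]
preds = binarize(scores)

print("predictions:", preds)
print("Hamming distance:", hamming_distance(targets, preds))
print("label-wise accuracy:", labelwise_accuracy(targets, preds))
```

A lower Hamming distance is better, while a higher accuracy is better, so the two numbers move in opposite directions during training.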
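VI. Model Prediction
The following command predicts the labels of the image set in Infer.infer_imgs of the config: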
!python3 tools/infer.py \
-c ./ppcls/configs/ImageNet/SwinTransformer/SwinTransformer_tiny_patch4_window7_224.yaml \
-o Arch.pretrained="./output/SwinTransformer_tiny_patch4_window7_224/latest"
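For multi-label output, a natural post-processing step is to keep every label whose score passes a threshold instead of taking a single arg-max. The sketch below assumes a threshold of 0.5 and uses a made-up three-label example; whether MultiLabelTopk behaves exactly this way should be checked against the PaddleClas source. The function name `select_labels` is hypothetical.

```python
def select_labels(scores, label_names, threshold=0.5):
    """Return (name, score) pairs at or above the threshold, highest score first."""
    picked = [
        (name, score)
        for name, score in zip(label_names, scores)
        if score >= threshold
    ]
    return sorted(picked, key=lambda pair: pair[1], reverse=True)


# Made-up scores for three of the dataset's labels.
names = ["sky", "snow", "clouds"]
scores = [0.91, 0.08, 0.62]

for name, score in select_labels(scores, names):
    print(f"{name}: {score:.2f}")
```

Raising the threshold trades recall for precision: fewer labels are reported, but the ones that remain are more likely to be correct.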
VII. Prediction with the Inference Engine
1. Export the inference model
!python3 tools/export_model.py \
-c ./ppcls/configs/ImageNet/SwinTransformer/SwinTransformer_tiny_patch4_window7_224.yaml \
-o Arch.pretrained="./output/SwinTransformer_tiny_patch4_window7_224/latest"
By default the inference model is saved to ./inference under the current directory.
%cd ~/PaddleClas-2.4.0
!ls ./inference -l
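2. Predict with the inference engine
- First, enter the deploy directory.
- Then run prediction through the inference engine.
Change the config file PaddleClas-2.4.0/deploy/configs/inference_cls_multilabel.yaml to the following: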
Global:
  infer_imgs: "./images/0517_2715693311.jpg"
  inference_model_dir: "../inference/"
  batch_size: 1
  use_gpu: True
  enable_mkldnn: False
  cpu_num_threads: 10
  enable_benchmark: True
  use_fp16: False
  ir_optim: True
  use_tensorrt: False
  gpu_mem: 8000
  enable_profile: False

PreProcess:
  transform_ops:
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 0.00392157
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
        channel_num: 3
    - ToCHWImage:

PostProcess:
  main_indicator: MultiLabelTopk
  MultiLabelTopk:
    topk: 5
    class_id_map_file: None
  SavePreLabel:
    save_dir: ./pre_label/
!cp /home/aistudio/inference_cls_multilabel.yaml ~/PaddleClas-2.4.0/deploy/configs/inference_cls_multilabel.yaml
%cd ~/PaddleClas-2.4.0/deploy
!python3 python/predict_cls.py \
-c ./configs/inference_cls_multilabel.yaml
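Both configs leave class_id_map_file as None, so predictions are reported as numeric class ids rather than names. One option is to supply a mapping file. The sketch below writes one line per class in an `<id> <name>` layout, which is my understanding of what the PaddleClas post-processing reads, so verify the format against the library before relying on it. The file name and the helper `write_class_id_map` are arbitrary choices for illustration.

```python
import tempfile
from pathlib import Path

# Label names in the order given in the dataset section
# (assumed to match the model's class indices 0..32).
LABELS = [
    "airport", "beach", "bridge", "buildings", "castle", "cityscape",
    "clouds", "frost", "garden", "glacier", "grass", "harbor", "house",
    "lake", "moon", "mountain", "nighttime", "ocean", "plants", "railroad",
    "rainbow", "reflection", "road", "sky", "snow", "street", "sunset",
    "temple", "town", "valley", "water", "waterfall", "window",
]


def write_class_id_map(path, labels):
    """Write one '<id> <name>' line per label, with ids starting at 0."""
    lines = [f"{index} {name}" for index, name in enumerate(labels)]
    Path(path).write_text("\n".join(lines) + "\n", encoding="utf-8")


# Written to the system temp directory here; point it wherever the config expects.
out_path = Path(tempfile.gettempdir()) / "nus_wide_scene_label_list.txt"
write_class_id_map(out_path, LABELS)
print(out_path.read_text(encoding="utf-8").splitlines()[:3])
```

The resulting path could then be set as class_id_map_file in the PostProcess section of either config.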
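VIII. Summary
Multi-label image classification is common in everyday applications, for example the layout analysis track of the Baidu Netdisk AI Competition: https://aistudio.baidu.com/aistudio/competition/detail/850/0/leaderboard
In competitions like this one, training the multi-label classifier with a newer model is likely to give better results.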