基于pytorch的图像分类框架-更新日志

该更新日志详细记录了一个基于PyTorch的图像分类框架的改进过程,包括数据集划分逻辑的调整、FP16推理的支持、R-Drop和EMA技术的集成,以及模型性能的评估。此外,还介绍了知识蒸馏的应用,展示了不同模型作为教师和学生模型时的效果,并提到了ONNX、TorchScript和TensorRT的模型导出与推理性能测试。
摘要由CSDN通过智能技术生成

基于pytorch的图像分类框架-更新日志

源码地址 github

使用示例 使用pytorch实现花朵分类


pytorch-classifier v1.1 更新日志

  • 2022.11.8

    1. 修改processing.py的分配数据集逻辑,之前是先分出test_size的数据作为测试集,然后再从剩下的数据里面分val_size的数据作为验证集,这种分数据的方式,当我们的val_size=0.2和test_size=0.2,最后出来的数据集比例不是严格等于6:2:2,现在修改为等比例的划分,也就是现在的逻辑分割数据集后严格等于6:2:2.
    2. 参考yolov5,训练中的模型保存改为FP16保存.(在精度基本保持不变的情况下,模型相比FP32小一半)
    3. metrice.py和predict.py新增支持FP16推理.(在精度基本保持不变的情况下,速度更加快)
  • 2022.11.9

    1. 支持albumentations库的数据增强.
    2. 训练过程新增R-Drop,具体在main.py中添加–rdrop参数即可.
  • 2022.11.10

    1. 利用Pycm库进行修改metrice.py中的可视化内容.增加指标种类.
  • 2022.11.11

    1. 支持EMA(Exponential Moving Average),具体在main.py中添加–ema参数即可.
    2. 修改早停法中的–patience机制,当–patience参数为0时,停止使用早停法.
    3. 知识蒸馏中增加了一些实验数据.
    4. 修复一些bug.

FP16推理实验:

实验环境:

SystemCPUGPURAM
Ubuntui9-12900KFRTX-309032G

训练mobilenetv2:

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

训练resnext50:

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

训练RepVGG-A0:

    python main.py --model_name RepVGG-A0 --config config/config.py --save_path runs/RepVGG-A0 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

训练densenet121:

    python main.py --model_name densenet121 --config config/config.py --save_path runs/densenet121 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

计算各个模型的指标:

    python metrice.py --task val --save_path runs/mobilenetv2
    python metrice.py --task val --save_path runs/resnext50
    python metrice.py --task val --save_path runs/RepVGG-A0
    python metrice.py --task val --save_path runs/densenet121

    python metrice.py --task val --save_path runs/mobilenetv2 --half
    python metrice.py --task val --save_path runs/resnext50 --half
    python metrice.py --task val --save_path runs/RepVGG-A0 --half
    python metrice.py --task val --save_path runs/densenet121 --half

计算各个模型的fps:

    python metrice.py --task fps --save_path runs/mobilenetv2
    python metrice.py --task fps --save_path runs/resnext50
    python metrice.py --task fps --save_path runs/RepVGG-A0
    python metrice.py --task fps --save_path runs/densenet121

    python metrice.py --task fps --save_path runs/mobilenetv2 --half
    python metrice.py --task fps --save_path runs/resnext50 --half
    python metrice.py --task fps --save_path runs/RepVGG-A0 --half
    python metrice.py --task fps --save_path runs/densenet121 --half
modelval accuracy(train stage)val accuracy(test stage)val accuracy half(test stage)FP32 FPS(batch_size=64)FP16 FPS(batch_size=64)
mobilenetv20.742840.743400.7439652.4392.80
resnext500.809660.809660.8096619.4830.28
RepVGG-A00.736660.736660.7366654.7498.87
densenet1210.770350.771480.7703518.8732.75

R-Drop实验:

训练mobilenetv2:

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd --rdrop

训练resnext50:

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --rdrop

训练ghostnet:

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --rdrop

训练efficientnet_v2_s:

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --rdrop

计算各个模型的指标:

    python metrice.py --task val --save_path runs/mobilenetv2
    python metrice.py --task val --save_path runs/mobilenetv2_rdrop
    python metrice.py --task val --save_path runs/resnext50
    python metrice.py --task val --save_path runs/resnext50_rdrop
    python metrice.py --task val --save_path runs/ghostnet
    python metrice.py --task val --save_path runs/ghostnet_rdrop
    python metrice.py --task val --save_path runs/efficientnet_v2_s
    python metrice.py --task val --save_path runs/efficientnet_v2_s_rdrop

    python metrice.py --task test --save_path runs/mobilenetv2
    python metrice.py --task test --save_path runs/mobilenetv2_rdrop
    python metrice.py --task test --save_path runs/resnext50
    python metrice.py --task test --save_path runs/resnext50_rdrop
    python metrice.py --task test --save_path runs/ghostnet
    python metrice.py --task test --save_path runs/ghostnet_rdrop
    python metrice.py --task test --save_path runs/efficientnet_v2_s
    python metrice.py --task test --save_path runs/efficientnet_v2_s_rdrop
modelval accuracyval accuracy(r-drop)test accuracytest accuracy(r-drop)
mobilenetv20.743400.751260.737840.73741
resnext500.809660.811340.824370.82092
ghostnet0.775970.766980.766250.77012
efficientnet_v2_s0.841660.852890.844600.85837

EMA实验:

训练mobilenetv2:

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd --ema

训练resnext50:

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --ema

训练ghostnet:

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --ema

训练efficientnet_v2_s:

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --ema

计算各个模型的指标:

    python metrice.py --task val --save_path runs/mobilenetv2
    python metrice.py --task val --save_path runs/mobilenetv2_ema
    python metrice.py --task val --save_path runs/resnext50
    python metrice.py --task val --save_path runs/resnext50_ema
    python metrice.py --task val --save_path runs/ghostnet
    python metrice.py --task val --save_path runs/ghostnet_ema
    python metrice.py --task val --save_path runs/efficientnet_v2_s
    python metrice.py --task val --save_path runs/efficientnet_v2_s_ema

    python metrice.py --task test --save_path runs/mobilenetv2
    python metrice.py --task test --save_path runs/mobilenetv2_ema
    python metrice.py --task test --save_path runs/resnext50
    python metrice.py --task test --save_path runs/resnext50_ema
    python metrice.py --task test --save_path runs/ghostnet
    python metrice.py --task test --save_path runs/ghostnet_ema
    python metrice.py --task test --save_path runs/efficientnet_v2_s
    python metrice.py --task test --save_path runs/efficientnet_v2_s_ema
modelval accuracyval accuracy(ema)test accuracytest accuracy(ema)
mobilenetv20.743400.749580.737840.73870
resnext500.809660.812460.824370.82307
ghostnet0.775970.777650.766250.77142
efficientnet_v2_s0.841660.839980.844600.83986

pytorch-classifier v1.2 更新日志

  1. 新增export.py,支持导出(onnx, torchscript, tensorrt)模型.

  2. metrice.py支持onnx,torchscript,tensorrt的推理.

     此处在predict.py中暂不支持onnx,torchscript,tensorrt的推理的推理,原因是因为predict.py中的热力图可视化没办法在onnx、torchscript、tensorrt中实现,后续单独推理部分会额外写一部分代码.
     在metrice.py中,onnx和torchscript和tensorrt的推理也不支持tsne的可视化,那么我在metrice.py中添加onnx,torchscript,tensorrt的推理的目的是为了测试fps和精度.
     所以简单来说,使用metrice.py最好还是直接用torch模型,torchscript和onnx和tensorrt的推理的推理模型后续会写一个单独的推理代码.
    
  3. main.py,metrice.py,predict.py,export.py中增加–device参数,可以指定设备.

  4. 优化程序和修复一些bug.

训练命令:
python main.py --model_name efficientnet_v2_s --config config/config.py --batch_size 128 --Augment AutoAugment --save_path runs/efficientnet_v2_s --device 0 \
--pretrained --amp --warmup --ema --imagenet_meanstd
GPU 推理速度测试 sh脚本:
batch_size=1 # 1 2 4 8 16 32 64
python metrice.py --task fps --save_path runs/efficientnet_v2_s --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --model_type torchscript --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --half --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --half --batch_size $batch_size 
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --half --batch_size $batch_size
CPU 推理速度测试 sh脚本:
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --dynamic --device cpu
batch_size=1
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=2
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=4
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=8
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=16
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size

各导出模型在cpu和gpu上的fps实验:

实验环境:

SystemCPUGPURAMModel
Ubuntu20.04i7-12700KFRTX-309032G DDR5 6400efficientnet_v2_s
GPU
modelTorch FP32 FPSTorch FP16 FPSTorchScript FP32 FPSTorchScript FP16 FPSONNX FP32 FPSONNX FP16 FPSTensorRT FP32 FPSTensorRT FP16 FPS
batch-size 193.77105.65233.21260.07177.41308.52311.60789.19
batch-size 294.32108.35208.53253.83166.23258.98275.93713.71
batch-size 495.98108.31171.99255.05130.43190.03212.75573.88
batch-size 894.0385.76118.79210.5887.65122.31147.36416.71
batch-size 1661.9376.2575.45125.0550.3369.0187.25260.94
batch-size 3234.5658.1141.9372.2926.9134.4648.54151.35
batch-size 6418.6431.5723.1538.9012.6715.9026.1985.47
CPU
modelTorch FP32 FPSTorch FP16 FPSTorchScript FP32 FPSTorchScript FP16 FPSONNX FP32 FPSONNX FP16 FPSTensorRT FP32 FPSTensorRT FP16 FPS
batch-size 127.91Not Support46.10Not Support79.27Not SupportNot SupportNot Support
batch-size 225.26Not Support24.98Not Support45.62Not SupportNot SupportNot Support
batch-size 414.02Not Support13.84Not Support23.90Not SupportNot SupportNot Support
batch-size 87.53Not Support7.35Not Support12.01Not SupportNot SupportNot Support
batch-size 163.07Not Support3.64Not Support5.72Not SupportNot SupportNot Support

pytorch-classifier v1.3 更新日志

  1. 增加repghost模型.
  2. 推理阶段把模型中的conv和bn进行fuse.
  3. 发现mnasnet0_5有点问题,暂停使用.
  4. torch.no_grad()更换成torch.inference_mode().

pytorch-classifier v1.4 更新日志

  1. predict.py支持检测灰度图,其读取后会检测是否为RGB通道,不是的话会进行转换.
  2. 更新readme.md.
  3. 修复一些bug.

Knowledge Distillation Experiment

为了测试知识蒸馏的可用性,基于CUB-200-2011百度网盘链接数据集进行实验.

stduent为mobilenetv2,teacher为resnet50.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练resnet50:

python main.py --model_name resnet50 --config config/config.py --save_path runs/resnet50_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算resnet50指标:

python metrice.py --task val --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw --test_tta

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/resnet50_admaw

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_MGD1 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/resnet50_admaw

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/resnet50_admaw 

计算通过resnet50蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT --test_tta
modelval accuracyval mpatest accuracytest mpatest accuracy(TTA)test mpa(TTA)
mobilenetv20.741160.742000.734830.734520.770120.76979
resnet500.787200.787440.777440.776700.812310.81162
teacher->resnet50
student->mobilenetv2
SoftTarget
0.770920.771790.752480.751910.777870.77752
teacher->resnet50
student->mobilenetv2
MGD
0.788880.789940.783900.782960.799400.79890
teacher->resnet50
student->mobilenetv2
AT
0.747890.748780.738700.737950.763240.76244

stduent为mobilenetv2,teacher为ghostnet.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练ghostnet:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算ghostnet指标:

python metrice.py --task val --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnetadmaw --test_tta

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/ghostnet_admaw

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.2 --teacher_path runs/ghostnet_admaw

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 1000.0 --teacher_path runs/ghostnet_admaw

计算通过ghostnet蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT --test_tta
modelval accuracyval mpatest accuracytest mpatest accuracy(TTA)test mpa(TTA)
mobilenetv20.741160.742000.734830.734520.770120.76979
ghostnet0.777090.777560.763670.762770.780460.77958
teacher->ghostnet
student->mobilenetv2
SoftTarget
0.778780.779680.761080.760220.779160.77807
teacher->ghostnet
student->mobilenetv2
MGD
0.756320.757230.746880.746380.773570.77302
teacher->ghostnet
student->mobilenetv2
AT
0.748460.749450.738270.737820.766250.76534

由于SP蒸馏开启AMP时,kd_loss大概率会出现nan,所在SP蒸馏实验中,我们把所有模型都不开启AMP.

stduent为mobilenetv2,teacher为ghostnet.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练ghostnet:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算ghostnet指标:

python metrice.py --task val --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnetadmaw --test_tta

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用SP进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_SP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd \
--kd --kd_method SP --kd_ratio 10.0 --teacher_path runs/ghostnet_admaw

计算通过ghostnet蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_SP
python metrice.py --task test --save_path runs/mobilenetv2_admaw_SP
python metrice.py --task test --save_path runs/mobilenetv2_admaw_SP --test_tta
modelval accuracyval mpatest accuracytest mpatest accuracy(TTA)test mpa(TTA)
mobilenetv20.745090.745680.738270.737610.769690.76903
ghostnet0.778210.778810.758070.757080.778730.77805
teacher->ghostnet
student->mobilenetv2
SP
0.747330.748360.732670.731980.758930.75850
stduent为mobilenetv2,teacher为resnet50.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练resnet50:

python main.py --model_name resnet50 --config config/config.py --save_path runs/resnet50_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算resnet50指标:

python metrice.py --task val --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw --test_tta

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用SP进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_SP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd \
--kd --kd_method SP --kd_ratio 10.0 --teacher_path runs/resnet50_admaw
modelval accuracyval mpatest accuracytest mpatest accuracy(TTA)test mpa(TTA)
mobilenetv20.745090.745680.738270.737610.769690.76903
resnet500.787200.787070.774000.773210.812310.81138
teacher->resnet50
student->mobilenetv2
SP
0.741160.742000.740420.739690.768400.76753

以下实验是通过训练好的自身模型再作为教师模型进行训练.

知识蒸馏, resnet50作为teacher, resnet50作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/resnet50_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/resnet50_admaw 

计算通过resnet50蒸馏resnet50指标:

python metrice.py --task val --save_path runs/resnet50_admaw_AT_self
python metrice.py --task test --save_path runs/resnet50_admaw_AT_self
python metrice.py --task test --save_path runs/resnet50_admaw_AT_self --test_tta

知识蒸馏, mobilenetv2作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/mobilenetv2_admaw 

计算通过mobilenetv2蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT_self
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT_self
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT_self --test_tta

知识蒸馏, ghostnet作为teacher, ghostnet作为student, 使用AT进行蒸馏:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 1000 --teacher_path runs/ghostnet_admaw 

计算通过ghostnet蒸馏ghostnet指标:

python metrice.py --task val --save_path runs/ghostnet_admaw_AT_self
python metrice.py --task test --save_path runs/ghostnet_admaw_AT_self
python metrice.py --task test --save_path runs/ghostnet_admaw_AT_self --test_tta
modelval accuracyval mpatest accuracytest mpatest accuracy(TTA)test mpa(TTA)
mobilenetv20.741160.742000.734830.734520.770120.76979
teacher->mobilenetv2
student->mobilenetv2
AT
0.746770.747580.744300.743420.770120.76926
resnet500.787200.787440.777440.776700.812310.81162
teacher->resnet50
student->resnet50
AT
0.790570.790910.791650.790260.811020.81030
ghostnet0.777090.777560.763670.762770.780460.77958
teacher->ghostnet
student->ghostnet
AT
0.780460.780800.771420.770690.788200.78742

在V1.1版本的测试中发现efficientnet_v2网络作为teacher网络效果还不错.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2 --test_tta

普通训练efficientnet_v2_s:

python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算efficientnet_v2_s指标:

python metrice.py --task val --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s --test_tta

知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

计算通过efficientnet_v2_s蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP --test_tta
modelval accuracyval mpatest accuracytest mpatest accuracy(TTA)test mpa(TTA)
mobilenetv20.741160.742000.734830.734520.770120.76979
efficientnet_v2_s0.841660.841910.844600.844410.864830.86484
teacher->efficientnet_v2_s
student->mobilenetv2
ST
0.761370.762090.751610.750880.778300.77715
teacher->efficientnet_v2_s
student->mobilenetv2
MGD
0.772040.772880.775290.774640.793370.79261
teacher->efficientnet_v2_s
student->mobilenetv2
MGD(EMA)
0.772040.772670.777440.776710.802840.80201
teacher->efficientnet_v2_s
student->mobilenetv2
MGD(RDrop)
0.772040.772880.775290.774640.793370.79261
teacher->efficientnet_v2_s
student->mobilenetv2
MGD(EMA,RDrop)
0.772040.772670.777440.776710.802840.80201

关于Knowledge Distillation的一些解释

实验解释:

  1. 对于AT和SP蒸馏方法,上述实验都是使用block3和block4的特征层进行蒸馏.
  2. MPA是平均类别精度,在类别不平衡的情况下非常有用,当类别基本平衡的情况下,跟accuracy差不多.
  3. 当蒸馏loss出现nan的时候请不要开启AMP,AMP可能会导致浮点溢出导致的nan.

目前支持的类型有:

NameMethodpaper
SoftTargetlogitshttps://arxiv.org/pdf/1503.02531.pdf
MGDfeatureshttps://arxiv.org/abs/2205.01529.pdf
SPfeatureshttps://arxiv.org/pdf/1907.09682.pdf
ATfeatureshttps://arxiv.org/pdf/1612.03928.pdf

蒸馏学习跟模型,参数,蒸馏的方法,蒸馏的层都有关系,效果不好需要自行调整,其中SP和AT都可以对模型中的四个block进行组合计算蒸馏损失具体代码在utils/utils_fit.py的fitting_distill函数中可以进行修改.

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

魔鬼面具

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值