基于pytorch的图像分类框架-更新日志
源码地址 github
使用示例 使用pytorch实现花朵分类
pytorch-classifier v1.1 更新日志
-
2022.11.8
- 修改processing.py的分配数据集逻辑,之前是先分出test_size的数据作为测试集,然后再从剩下的数据里面分val_size的数据作为验证集,这种分数据的方式,当我们的val_size=0.2和test_size=0.2,最后出来的数据集比例不是严格等于6:2:2,现在修改为等比例的划分,也就是现在的逻辑分割数据集后严格等于6:2:2.
- 参考yolov5,训练中的模型保存改为FP16保存.(在精度基本保持不变的情况下,模型相比FP32小一半)
- metrice.py和predict.py新增支持FP16推理.(在精度基本保持不变的情况下,速度更加快)
-
2022.11.9
- 支持albumentations库的数据增强.
- 训练过程新增R-Drop,具体在main.py中添加–rdrop参数即可.
-
2022.11.10
- 利用Pycm库进行修改metrice.py中的可视化内容.增加指标种类.
-
2022.11.11
- 支持EMA(Exponential Moving Average),具体在main.py中添加–ema参数即可.
- 修改早停法中的–patience机制,当–patience参数为0时,停止使用早停法.
- 知识蒸馏中增加了一些实验数据.
- 修复一些bug.
FP16推理实验:
实验环境:
System | CPU | GPU | RAM |
---|---|---|---|
Ubuntu | i9-12900KF | RTX-3090 | 32G |
训练mobilenetv2:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
训练resnext50:
python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
训练RepVGG-A0:
python main.py --model_name RepVGG-A0 --config config/config.py --save_path runs/RepVGG-A0 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
训练densenet121:
python main.py --model_name densenet121 --config config/config.py --save_path runs/densenet121 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
计算各个模型的指标:
python metrice.py --task val --save_path runs/mobilenetv2
python metrice.py --task val --save_path runs/resnext50
python metrice.py --task val --save_path runs/RepVGG-A0
python metrice.py --task val --save_path runs/densenet121
python metrice.py --task val --save_path runs/mobilenetv2 --half
python metrice.py --task val --save_path runs/resnext50 --half
python metrice.py --task val --save_path runs/RepVGG-A0 --half
python metrice.py --task val --save_path runs/densenet121 --half
计算各个模型的fps:
python metrice.py --task fps --save_path runs/mobilenetv2
python metrice.py --task fps --save_path runs/resnext50
python metrice.py --task fps --save_path runs/RepVGG-A0
python metrice.py --task fps --save_path runs/densenet121
python metrice.py --task fps --save_path runs/mobilenetv2 --half
python metrice.py --task fps --save_path runs/resnext50 --half
python metrice.py --task fps --save_path runs/RepVGG-A0 --half
python metrice.py --task fps --save_path runs/densenet121 --half
model | val accuracy(train stage) | val accuracy(test stage) | val accuracy half(test stage) | FP32 FPS(batch_size=64) | FP16 FPS(batch_size=64) |
---|---|---|---|---|---|
mobilenetv2 | 0.74284 | 0.74340 | 0.74396 | 52.43 | 92.80 |
resnext50 | 0.80966 | 0.80966 | 0.80966 | 19.48 | 30.28 |
RepVGG-A0 | 0.73666 | 0.73666 | 0.73666 | 54.74 | 98.87 |
densenet121 | 0.77035 | 0.77148 | 0.77035 | 18.87 | 32.75 |
R-Drop实验:
训练mobilenetv2:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop
训练resnext50:
python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop
训练ghostnet:
python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop
训练efficientnet_v2_s:
python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop
计算各个模型的指标:
python metrice.py --task val --save_path runs/mobilenetv2
python metrice.py --task val --save_path runs/mobilenetv2_rdrop
python metrice.py --task val --save_path runs/resnext50
python metrice.py --task val --save_path runs/resnext50_rdrop
python metrice.py --task val --save_path runs/ghostnet
python metrice.py --task val --save_path runs/ghostnet_rdrop
python metrice.py --task val --save_path runs/efficientnet_v2_s
python metrice.py --task val --save_path runs/efficientnet_v2_s_rdrop
python metrice.py --task test --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2_rdrop
python metrice.py --task test --save_path runs/resnext50
python metrice.py --task test --save_path runs/resnext50_rdrop
python metrice.py --task test --save_path runs/ghostnet
python metrice.py --task test --save_path runs/ghostnet_rdrop
python metrice.py --task test --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s_rdrop
model | val accuracy | val accuracy(r-drop) | test accuracy | test accuracy(r-drop) |
---|---|---|---|---|
mobilenetv2 | 0.74340 | 0.75126 | 0.73784 | 0.73741 |
resnext50 | 0.80966 | 0.81134 | 0.82437 | 0.82092 |
ghostnet | 0.77597 | 0.76698 | 0.76625 | 0.77012 |
efficientnet_v2_s | 0.84166 | 0.85289 | 0.84460 | 0.85837 |
EMA实验:
训练mobilenetv2:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --ema
训练resnext50:
python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --ema
训练ghostnet:
python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --ema
训练efficientnet_v2_s:
python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --ema
计算各个模型的指标:
python metrice.py --task val --save_path runs/mobilenetv2
python metrice.py --task val --save_path runs/mobilenetv2_ema
python metrice.py --task val --save_path runs/resnext50
python metrice.py --task val --save_path runs/resnext50_ema
python metrice.py --task val --save_path runs/ghostnet
python metrice.py --task val --save_path runs/ghostnet_ema
python metrice.py --task val --save_path runs/efficientnet_v2_s
python metrice.py --task val --save_path runs/efficientnet_v2_s_ema
python metrice.py --task test --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2_ema
python metrice.py --task test --save_path runs/resnext50
python metrice.py --task test --save_path runs/resnext50_ema
python metrice.py --task test --save_path runs/ghostnet
python metrice.py --task test --save_path runs/ghostnet_ema
python metrice.py --task test --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s_ema
model | val accuracy | val accuracy(ema) | test accuracy | test accuracy(ema) |
---|---|---|---|---|
mobilenetv2 | 0.74340 | 0.74958 | 0.73784 | 0.73870 |
resnext50 | 0.80966 | 0.81246 | 0.82437 | 0.82307 |
ghostnet | 0.77597 | 0.77765 | 0.76625 | 0.77142 |
efficientnet_v2_s | 0.84166 | 0.83998 | 0.84460 | 0.83986 |
pytorch-classifier v1.2 更新日志
-
新增export.py,支持导出(onnx, torchscript, tensorrt)模型.
-
metrice.py支持onnx,torchscript,tensorrt的推理.
此处在predict.py中暂不支持onnx,torchscript,tensorrt的推理的推理,原因是因为predict.py中的热力图可视化没办法在onnx、torchscript、tensorrt中实现,后续单独推理部分会额外写一部分代码. 在metrice.py中,onnx和torchscript和tensorrt的推理也不支持tsne的可视化,那么我在metrice.py中添加onnx,torchscript,tensorrt的推理的目的是为了测试fps和精度. 所以简单来说,使用metrice.py最好还是直接用torch模型,torchscript和onnx和tensorrt的推理的推理模型后续会写一个单独的推理代码.
-
main.py,metrice.py,predict.py,export.py中增加–device参数,可以指定设备.
-
优化程序和修复一些bug.
训练命令:
python main.py --model_name efficientnet_v2_s --config config/config.py --batch_size 128 --Augment AutoAugment --save_path runs/efficientnet_v2_s --device 0 \
--pretrained --amp --warmup --ema --imagenet_meanstd
GPU 推理速度测试 sh脚本:
batch_size=1 # 1 2 4 8 16 32 64
python metrice.py --task fps --save_path runs/efficientnet_v2_s --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --model_type torchscript --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --half --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --half --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --half --batch_size $batch_size
CPU 推理速度测试 sh脚本:
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --dynamic --device cpu
batch_size=1
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=2
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=4
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=8
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=16
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
各导出模型在cpu和gpu上的fps实验:
实验环境:
System | CPU | GPU | RAM | Model |
---|---|---|---|---|
Ubuntu20.04 | i7-12700KF | RTX-3090 | 32G DDR5 6400 | efficientnet_v2_s |
GPU
model | Torch FP32 FPS | Torch FP16 FPS | TorchScript FP32 FPS | TorchScript FP16 FPS | ONNX FP32 FPS | ONNX FP16 FPS | TensorRT FP32 FPS | TensorRT FP16 FPS |
---|---|---|---|---|---|---|---|---|
batch-size 1 | 93.77 | 105.65 | 233.21 | 260.07 | 177.41 | 308.52 | 311.60 | 789.19 |
batch-size 2 | 94.32 | 108.35 | 208.53 | 253.83 | 166.23 | 258.98 | 275.93 | 713.71 |
batch-size 4 | 95.98 | 108.31 | 171.99 | 255.05 | 130.43 | 190.03 | 212.75 | 573.88 |
batch-size 8 | 94.03 | 85.76 | 118.79 | 210.58 | 87.65 | 122.31 | 147.36 | 416.71 |
batch-size 16 | 61.93 | 76.25 | 75.45 | 125.05 | 50.33 | 69.01 | 87.25 | 260.94 |
batch-size 32 | 34.56 | 58.11 | 41.93 | 72.29 | 26.91 | 34.46 | 48.54 | 151.35 |
batch-size 64 | 18.64 | 31.57 | 23.15 | 38.90 | 12.67 | 15.90 | 26.19 | 85.47 |
CPU
model | Torch FP32 FPS | Torch FP16 FPS | TorchScript FP32 FPS | TorchScript FP16 FPS | ONNX FP32 FPS | ONNX FP16 FPS | TensorRT FP32 FPS | TensorRT FP16 FPS |
---|---|---|---|---|---|---|---|---|
batch-size 1 | 27.91 | Not Support | 46.10 | Not Support | 79.27 | Not Support | Not Support | Not Support |
batch-size 2 | 25.26 | Not Support | 24.98 | Not Support | 45.62 | Not Support | Not Support | Not Support |
batch-size 4 | 14.02 | Not Support | 13.84 | Not Support | 23.90 | Not Support | Not Support | Not Support |
batch-size 8 | 7.53 | Not Support | 7.35 | Not Support | 12.01 | Not Support | Not Support | Not Support |
batch-size 16 | 3.07 | Not Support | 3.64 | Not Support | 5.72 | Not Support | Not Support | Not Support |
pytorch-classifier v1.3 更新日志
- 增加repghost模型.
- 推理阶段把模型中的conv和bn进行fuse.
- 发现mnasnet0_5有点问题,暂停使用.
- torch.no_grad()更换成torch.inference_mode().
pytorch-classifier v1.4 更新日志
- predict.py支持检测灰度图,其读取后会检测是否为RGB通道,不是的话会进行转换.
- 更新readme.md.
- 修复一些bug.
Knowledge Distillation Experiment
为了测试知识蒸馏的可用性,基于CUB-200-2011百度网盘链接数据集进行实验.
stduent为mobilenetv2,teacher为resnet50.
普通训练mobilenetv2:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
计算mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta
普通训练resnet50:
python main.py --model_name resnet50 --config config/config.py --save_path runs/resnet50_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
计算resnet50指标:
python metrice.py --task val --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw --test_tta
知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/resnet50_admaw
知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_MGD1 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/resnet50_admaw
知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/resnet50_admaw
计算通过resnet50蒸馏mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT --test_tta
model | val accuracy | val mpa | test accuracy | test mpa | test accuracy(TTA) | test mpa(TTA) |
---|---|---|---|---|---|---|
mobilenetv2 | 0.74116 | 0.74200 | 0.73483 | 0.73452 | 0.77012 | 0.76979 |
resnet50 | 0.78720 | 0.78744 | 0.77744 | 0.77670 | 0.81231 | 0.81162 |
teacher->resnet50 student->mobilenetv2 SoftTarget | 0.77092 | 0.77179 | 0.75248 | 0.75191 | 0.77787 | 0.77752 |
teacher->resnet50 student->mobilenetv2 MGD | 0.78888 | 0.78994 | 0.78390 | 0.78296 | 0.79940 | 0.79890 |
teacher->resnet50 student->mobilenetv2 AT | 0.74789 | 0.74878 | 0.73870 | 0.73795 | 0.76324 | 0.76244 |
stduent为mobilenetv2,teacher为ghostnet.
普通训练mobilenetv2:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
计算mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta
普通训练ghostnet:
python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
计算ghostnet指标:
python metrice.py --task val --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnetadmaw --test_tta
知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/ghostnet_admaw
知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.2 --teacher_path runs/ghostnet_admaw
知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 1000.0 --teacher_path runs/ghostnet_admaw
计算通过ghostnet蒸馏mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT --test_tta
model | val accuracy | val mpa | test accuracy | test mpa | test accuracy(TTA) | test mpa(TTA) |
---|---|---|---|---|---|---|
mobilenetv2 | 0.74116 | 0.74200 | 0.73483 | 0.73452 | 0.77012 | 0.76979 |
ghostnet | 0.77709 | 0.77756 | 0.76367 | 0.76277 | 0.78046 | 0.77958 |
teacher->ghostnet student->mobilenetv2 SoftTarget | 0.77878 | 0.77968 | 0.76108 | 0.76022 | 0.77916 | 0.77807 |
teacher->ghostnet student->mobilenetv2 MGD | 0.75632 | 0.75723 | 0.74688 | 0.74638 | 0.77357 | 0.77302 |
teacher->ghostnet student->mobilenetv2 AT | 0.74846 | 0.74945 | 0.73827 | 0.73782 | 0.76625 | 0.76534 |
由于SP蒸馏开启AMP时,kd_loss大概率会出现nan,所在SP蒸馏实验中,我们把所有模型都不开启AMP.
stduent为mobilenetv2,teacher为ghostnet.
普通训练mobilenetv2:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd
计算mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta
普通训练ghostnet:
python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd
计算ghostnet指标:
python metrice.py --task val --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnetadmaw --test_tta
知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用SP进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_SP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd \
--kd --kd_method SP --kd_ratio 10.0 --teacher_path runs/ghostnet_admaw
计算通过ghostnet蒸馏mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2_admaw_SP
python metrice.py --task test --save_path runs/mobilenetv2_admaw_SP
python metrice.py --task test --save_path runs/mobilenetv2_admaw_SP --test_tta
model | val accuracy | val mpa | test accuracy | test mpa | test accuracy(TTA) | test mpa(TTA) |
---|---|---|---|---|---|---|
mobilenetv2 | 0.74509 | 0.74568 | 0.73827 | 0.73761 | 0.76969 | 0.76903 |
ghostnet | 0.77821 | 0.77881 | 0.75807 | 0.75708 | 0.77873 | 0.77805 |
teacher->ghostnet student->mobilenetv2 SP | 0.74733 | 0.74836 | 0.73267 | 0.73198 | 0.75893 | 0.75850 |
stduent为mobilenetv2,teacher为resnet50.
普通训练mobilenetv2:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd
计算mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta
普通训练resnet50:
python main.py --model_name resnet50 --config config/config.py --save_path runs/resnet50_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd
计算resnet50指标:
python metrice.py --task val --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw --test_tta
知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用SP进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_SP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd \
--kd --kd_method SP --kd_ratio 10.0 --teacher_path runs/resnet50_admaw
model | val accuracy | val mpa | test accuracy | test mpa | test accuracy(TTA) | test mpa(TTA) |
---|---|---|---|---|---|---|
mobilenetv2 | 0.74509 | 0.74568 | 0.73827 | 0.73761 | 0.76969 | 0.76903 |
resnet50 | 0.78720 | 0.78707 | 0.77400 | 0.77321 | 0.81231 | 0.81138 |
teacher->resnet50 student->mobilenetv2 SP | 0.74116 | 0.74200 | 0.74042 | 0.73969 | 0.76840 | 0.76753 |
以下实验是通过训练好的自身模型再作为教师模型进行训练.
知识蒸馏, resnet50作为teacher, resnet50作为student, 使用AT进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/resnet50_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/resnet50_admaw
计算通过resnet50蒸馏resnet50指标:
python metrice.py --task val --save_path runs/resnet50_admaw_AT_self
python metrice.py --task test --save_path runs/resnet50_admaw_AT_self
python metrice.py --task test --save_path runs/resnet50_admaw_AT_self --test_tta
知识蒸馏, mobilenetv2作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/mobilenetv2_admaw
计算通过mobilenetv2蒸馏mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT_self
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT_self
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT_self --test_tta
知识蒸馏, ghostnet作为teacher, ghostnet作为student, 使用AT进行蒸馏:
python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 1000 --teacher_path runs/ghostnet_admaw
计算通过ghostnet蒸馏ghostnet指标:
python metrice.py --task val --save_path runs/ghostnet_admaw_AT_self
python metrice.py --task test --save_path runs/ghostnet_admaw_AT_self
python metrice.py --task test --save_path runs/ghostnet_admaw_AT_self --test_tta
model | val accuracy | val mpa | test accuracy | test mpa | test accuracy(TTA) | test mpa(TTA) |
---|---|---|---|---|---|---|
mobilenetv2 | 0.74116 | 0.74200 | 0.73483 | 0.73452 | 0.77012 | 0.76979 |
teacher->mobilenetv2 student->mobilenetv2 AT | 0.74677 | 0.74758 | 0.74430 | 0.74342 | 0.77012 | 0.76926 |
resnet50 | 0.78720 | 0.78744 | 0.77744 | 0.77670 | 0.81231 | 0.81162 |
teacher->resnet50 student->resnet50 AT | 0.79057 | 0.79091 | 0.79165 | 0.79026 | 0.81102 | 0.81030 |
ghostnet | 0.77709 | 0.77756 | 0.76367 | 0.76277 | 0.78046 | 0.77958 |
teacher->ghostnet student->ghostnet AT | 0.78046 | 0.78080 | 0.77142 | 0.77069 | 0.78820 | 0.78742 |
在V1.1版本的测试中发现efficientnet_v2网络作为teacher网络效果还不错.
普通训练mobilenetv2:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
计算mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2 --test_tta
普通训练efficientnet_v2_s:
python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd
计算efficientnet_v2_s指标:
python metrice.py --task val --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s --test_tta
知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s
知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s
python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s
计算通过efficientnet_v2_s蒸馏mobilenetv2指标:
python metrice.py --task val --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP --test_tta
model | val accuracy | val mpa | test accuracy | test mpa | test accuracy(TTA) | test mpa(TTA) |
---|---|---|---|---|---|---|
mobilenetv2 | 0.74116 | 0.74200 | 0.73483 | 0.73452 | 0.77012 | 0.76979 |
efficientnet_v2_s | 0.84166 | 0.84191 | 0.84460 | 0.84441 | 0.86483 | 0.86484 |
teacher->efficientnet_v2_s student->mobilenetv2 ST | 0.76137 | 0.76209 | 0.75161 | 0.75088 | 0.77830 | 0.77715 |
teacher->efficientnet_v2_s student->mobilenetv2 MGD | 0.77204 | 0.77288 | 0.77529 | 0.77464 | 0.79337 | 0.79261 |
teacher->efficientnet_v2_s student->mobilenetv2 MGD(EMA) | 0.77204 | 0.77267 | 0.77744 | 0.77671 | 0.80284 | 0.80201 |
teacher->efficientnet_v2_s student->mobilenetv2 MGD(RDrop) | 0.77204 | 0.77288 | 0.77529 | 0.77464 | 0.79337 | 0.79261 |
teacher->efficientnet_v2_s student->mobilenetv2 MGD(EMA,RDrop) | 0.77204 | 0.77267 | 0.77744 | 0.77671 | 0.80284 | 0.80201 |
关于Knowledge Distillation的一些解释
实验解释:
- 对于AT和SP蒸馏方法,上述实验都是使用block3和block4的特征层进行蒸馏.
- MPA是平均类别精度,在类别不平衡的情况下非常有用,当类别基本平衡的情况下,跟accuracy差不多.
- 当蒸馏loss出现nan的时候请不要开启AMP,AMP可能会导致浮点溢出导致的nan.
目前支持的类型有:
Name | Method | paper |
---|---|---|
SoftTarget | logits | https://arxiv.org/pdf/1503.02531.pdf |
MGD | features | https://arxiv.org/abs/2205.01529.pdf |
SP | features | https://arxiv.org/pdf/1907.09682.pdf |
AT | features | https://arxiv.org/pdf/1612.03928.pdf |
蒸馏学习跟模型,参数,蒸馏的方法,蒸馏的层都有关系,效果不好需要自行调整,其中SP和AT都可以对模型中的四个block进行组合计算蒸馏损失具体代码在utils/utils_fit.py的fitting_distill函数中可以进行修改.