A Concise Guide to Using the Official DeiT Code

1 Social Media Accounts

The accounts I currently run are listed below.

QQ study/discussion groups:

  • CNN | RVfpga study group (recommended, member cap 2000): 541434600
  • FPGA & IC & DL study group: 866169462

I am just a beginner documenting my own learning process. There may or may not be future updates, so treat this as a reference with caution.

  • V1.0 24-02-28 Evaluate different DeiT variants on the ImageNet-1k validation set using the official DeiT code

This post has a companion Bilibili video [uploader: 雪天鱼]: https://www.bilibili.com/video/BV1JZ421278B

2 Resources

3 Basic Usage: Model Evaluation

  • Dev environment 1
    CUDA 11.1 | cuDNN 8 | PyTorch 1.8.0, torchvision 0.9.0 | Python 3.8.3

  • Dev environment 2
    RTX 4080
    CUDA 11.1 | PyTorch 1.9.0, torchvision 0.10.0 | Python 3.7.16 | timm 0.3.2

Download the code archive from the official GitHub repository, https://github.com/facebookresearch/deit/tree/main, and extract it locally.


Pay particular attention to the following files:

models.py — defines DistilledVisionTransformer, as well as the different DeiT configurations (tiny, small, base)

README_deit.md — its Model Zoo section provides download links for the pretrained DeiT models, along with usage instructions for the scripts
  • Install the required dependencies:

# I installed numpy==1.19.5
pip install numpy

conda install -c pytorch pytorch torchvision

# the ViT model definition is imported directly from timm
pip install timm==0.3.2
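
As a quick sanity check after installing, here is a small sketch of mine (not from the repo) that builds a DeiT model the way main.py does: the repo's models.py registers the deit_* variants with timm, so create_model can construct them by name. Run it from the extracted repo directory:

# A minimal sketch; assumes timm==0.3.2 is installed and the current
# directory is the repo, so that models.py is importable
import torch
from timm.models import create_model

import models  # the repo's models.py; importing it registers the deit_* variants

model = create_model('deit_base_patch16_224', pretrained=False, num_classes=1000)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # one fake 224x224 RGB image
print(logits.shape)  # expect: torch.Size([1, 1000])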
The expected ImageNet directory layout:

/path/to/imagenet/   top-level directory
  train/  training set
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/    validation set
    class1/
      img3.jpeg
    class2/
      img4.jpeg
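
For reference, a minimal sketch of how such a layout is read: as far as I can tell, the repo's datasets.py builds a torchvision ImageFolder for the IMNET case, where each subdirectory becomes one class. The transform below is a simplified stand-in for the real eval transform:

# Sketch: each subdirectory of val/ becomes one class label
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
val_set = datasets.ImageFolder('/path/to/imagenet/val', transform=transform)
print(len(val_set), 'images,', len(val_set.classes), 'classes')
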
  • Evaluation
    To run evaluation, you only need to download the ImageNet-1k validation set val.rar from the shared netdisk, extract it, and place it following the directory layout shown above.


The train folder here is just a copy of val, renamed: the current official script insists on reading both the training and validation sets even when only evaluating, and the real ImageNet training set is around 160 GB, which is far too large. So I copy val and rename the copy as a fake training set purely so that evaluation can proceed. For actual training, train must of course contain the real training set.
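
If you prefer to script the copy instead of doing it by hand, a one-off helper like this works (my own snippet, not part of the repo):

# Fake the train/ split by copying val/, so --eval can run without the
# ~160 GB real training set
import shutil
shutil.copytree('/path/to/imagenet/val', '/path/to/imagenet/train')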

To evaluate the pretrained DeiT-Base model on a single-GPU machine, with ImageNet val as the dataset:

python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth --data-path /path/to/imagenet

Here /path/to/imagenet is the absolute path to the ImageNet dataset.
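
Because --resume is given as a URL here, the checkpoint is downloaded via torch.hub. A sketch of roughly the equivalent manual steps (the weights sit under the checkpoint's 'model' key, as far as I can tell from the repo's loading code):

# Sketch: roughly what --resume does with a URL (downloaded once, then
# cached under ~/.cache/torch/hub/checkpoints/)
import torch
from timm.models import create_model
import models  # the repo's models.py; registers the deit_* variants

model = create_model('deit_base_patch16_224', pretrained=False)
checkpoint = torch.hub.load_state_dict_from_url(
    'https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
    map_location='cpu', check_hash=True)
model.load_state_dict(checkpoint['model'])  # the weights live under the 'model' key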

The command I actually ran:

> On the 智星云 Windows 10 machine, first run `activate base` in cmd to activate the environment.

python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth --data-path C:/Users/vipuser/Downloads/imagenet

In the path, use `/` or `\\` rather than a single `\`.

We invoke main.py to evaluate the model. The input arguments to pay attention to are:

Model arguments:
--model sets the model to use; the default is deit_base_patch16_224

Dataset arguments:
--data-path sets the path to the ImageNet dataset

Other arguments:
--eval is a `store_true` flag; when present, the model runs in evaluation mode
--resume sets the checkpoint, i.e. the pretrained model (a local path or a URL)

If your GPU has limited memory, remember to lower the --batch-size argument (default 64), e.g. append --batch-size 32 to the command; weaker GPUs may not manage the default.
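
For orientation, a stripped-down sketch of how flags like these are declared with argparse (the repo defines them in main.py's argument parser; the flag names follow the Namespace dump below, while the defaults here are illustrative):

import argparse

parser = argparse.ArgumentParser('DeiT evaluation (sketch)')
parser.add_argument('--model', default='deit_base_patch16_224', type=str)
parser.add_argument('--data-path', default='/path/to/imagenet', type=str)
parser.add_argument('--batch-size', default=64, type=int)
parser.add_argument('--eval', action='store_true')      # mere presence switches to eval mode
parser.add_argument('--resume', default='', type=str)   # checkpoint: local path or URL

args = parser.parse_args()
print(args)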

Run output:

Not using distributed mode

That is, all configuration parameters are printed:

Namespace(ThreeAugment=False, aa='rand-m9-mstd0.5-inc1', attn_only=False, batch_size=64, bce_loss=False, clip_grad=None, color_jitter=0.3, cooldown_epochs=10, cosub=False, cutmix=1.0, cutmix_minmax=None, data_path='C:\\\\jccao\\\\AppFiles\\\\PythonPj\\\\PTQ4ViT-main\\\\datasets\\\\imagenet', data_set='IMNET', decay_epochs=30, decay_rate=0.1, device='cuda', dist_eval=False, dist_url='env://', distillation_alpha=0.5, distillation_tau=1.0, distillation_type='none', distributed=False, drop=0.0, drop_path=0.1, epochs=300, eval=True, eval_crop_ratio=0.875, finetune='', inat_category='name', input_size=224, lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, min_lr=1e-05, mixup=0.8, mixup_mode='batch', mixup_prob=1.0, mixup_switch_prob=0.5, model='deit_base_patch16_224', model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, momentum=0.9, num_workers=10, opt='adamw', opt_betas=None, opt_eps=1e-08, output_dir='', patience_epochs=10, pin_mem=True, recount=1, remode='pixel', repeated_aug=True, reprob=0.25, resplit=False, resume='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', sched='cosine', seed=0, smoothing=0.1, src=False, start_epoch=0, teacher_model='regnety_160', teacher_path='', train_interpolation='bicubic', train_mode=True, unscale_lr=False, warmup_epochs=5, warmup_lr=1e-06, weight_decay=0.05, world_size=1)

Creating model: deit_base_patch16_224

number of params: 86567656

With batch size 64, the RTX 4080 takes roughly 2 minutes to complete the full ImageNet validation evaluation:

Test:  [  0/521]  eta: 1:59:35  loss: 0.2683 (0.2683)  acc1: 96.8750 (96.8750)  acc5: 98.9583 (98.9583)  time: 13.7734  data: 11.5260  max mem: 1607
Test:  [ 10/521]  eta: 0:11:33  loss: 0.4705 (0.3933)  acc1: 92.7083 (93.4659)  acc5: 98.9583 (98.2955)  time: 1.3579  data: 1.0481  max mem: 1607
Test:  [ 20/521]  eta: 0:06:24  loss: 0.5191 (0.5495)  acc1: 88.5417 (88.4425)  acc5: 97.9167 (97.9663)  time: 0.1175  data: 0.0004  max mem: 1607
Test:  [ 30/521]  eta: 0:04:34  loss: 0.6762 (0.6165)  acc1: 84.3750 (86.6263)  acc5: 97.9167 (97.4462)  time: 0.1188  data: 0.0003  max mem: 1607
Test:  [ 40/521]  eta: 0:03:37  loss: 0.7444 (0.6966)  acc1: 79.1667 (84.0193)  acc5: 95.8333 (96.9512)  time: 0.1189  data: 0.0002  max mem: 1607
Test:  [ 50/521]  eta: 0:03:01  loss: 0.5451 (0.6377)  acc1: 90.6250 (85.9886)  acc5: 97.9167 (97.2835)  time: 0.1173  data: 0.0003  max mem: 1607
Test:  [ 60/521]  eta: 0:02:37  loss: 0.4698 (0.6218)  acc1: 91.6667 (86.6633)  acc5: 97.9167 (97.3190)  time: 0.1173  data: 0.0001  max mem: 1607
Test:  [ 70/521]  eta: 0:02:19  loss: 0.5640 (0.6133)  acc1: 89.5833 (86.9572)  acc5: 97.9167 (97.4619)  time: 0.1179  data: 0.0001  max mem: 1607
Test:  [ 80/521]  eta: 0:02:06  loss: 0.3679 (0.5868)  acc1: 90.6250 (87.7443)  acc5: 98.9583 (97.5952)  time: 0.1165  data: 0.0002  max mem: 1607
Test:  [ 90/521]  eta: 0:01:55  loss: 0.5198 (0.6136)  acc1: 89.5833 (87.0421)  acc5: 98.9583 (97.4931)  time: 0.1165  data: 0.0002  max mem: 1607
Test:  [100/521]  eta: 0:01:46  loss: 0.7690 (0.6237)  acc1: 84.3750 (86.6955)  acc5: 96.8750 (97.4629)  time: 0.1178  data: 0.0004  max mem: 1607
Test:  [110/521]  eta: 0:01:38  loss: 0.6432 (0.6253)  acc1: 86.4583 (86.6554)  acc5: 97.9167 (97.4756)  time: 0.1181  data: 0.0004  max mem: 1607
Test:  [120/521]  eta: 0:01:32  loss: 0.6027 (0.6279)  acc1: 87.5000 (86.6649)  acc5: 97.9167 (97.4346)  time: 0.1177  data: 0.0003  max mem: 1607
Test:  [130/521]  eta: 0:01:26  loss: 0.6634 (0.6352)  acc1: 86.4583 (86.2914)  acc5: 97.9167 (97.4793)  time: 0.1166  data: 0.0002  max mem: 1607
Test:  [140/521]  eta: 0:01:21  loss: 0.6191 (0.6254)  acc1: 87.5000 (86.5396)  acc5: 97.9167 (97.5251)  time: 0.1172  data: 0.0002  max mem: 1607
Test:  [150/521]  eta: 0:01:17  loss: 0.5174 (0.6440)  acc1: 85.4167 (86.0582)  acc5: 97.9167 (97.4131)  time: 0.1183  data: 0.0002  max mem: 1607
Test:  [160/521]  eta: 0:01:13  loss: 0.6763 (0.6388)  acc1: 85.4167 (86.2901)  acc5: 97.9167 (97.4961)  time: 0.1180  data: 0.0002  max mem: 1607
Test:  [170/521]  eta: 0:01:09  loss: 0.4719 (0.6331)  acc1: 90.6250 (86.4340)  acc5: 97.9167 (97.5085)  time: 0.1179  data: 0.0001  max mem: 1607
Test:  [180/521]  eta: 0:01:05  loss: 0.4948 (0.6267)  acc1: 89.5833 (86.5677)  acc5: 97.9167 (97.5541)  time: 0.1176  data: 0.0000  max mem: 1607
Test:  [190/521]  eta: 0:01:02  loss: 0.4948 (0.6274)  acc1: 87.5000 (86.4692)  acc5: 98.9583 (97.6004)  time: 0.1181  data: 0.0000  max mem: 1607
Test:  [200/521]  eta: 0:00:59  loss: 0.6553 (0.6362)  acc1: 86.4583 (86.2666)  acc5: 97.9167 (97.5228)  time: 0.1172  data: 0.0001  max mem: 1607
Test:  [210/521]  eta: 0:00:56  loss: 0.6726 (0.6371)  acc1: 87.5000 (86.2707)  acc5: 95.8333 (97.4921)  time: 0.1160  data: 0.0002  max mem: 1607
Test:  [220/521]  eta: 0:00:53  loss: 0.7357 (0.6544)  acc1: 83.3333 (85.8267)  acc5: 95.8333 (97.3181)  time: 0.1175  data: 0.0002  max mem: 1607
Test:  [230/521]  eta: 0:00:51  loss: 0.7568 (0.6646)  acc1: 81.2500 (85.5429)  acc5: 94.7917 (97.2222)  time: 0.1185  data: 0.0003  max mem: 1607
Test:  [240/521]  eta: 0:00:48  loss: 0.8931 (0.6753)  acc1: 80.2083 (85.2654)  acc5: 93.7500 (97.0911)  time: 0.1180  data: 0.0003  max mem: 1607
Test:  [250/521]  eta: 0:00:46  loss: 0.9160 (0.6916)  acc1: 80.2083 (84.9228)  acc5: 92.7083 (96.8999)  time: 0.1175  data: 0.0003  max mem: 1607
Test:  [260/521]  eta: 0:00:44  loss: 1.1382 (0.7074)  acc1: 71.8750 (84.4628)  acc5: 93.7500 (96.7872)  time: 0.1175  data: 0.0002  max mem: 1607
Test:  [270/521]  eta: 0:00:42  loss: 1.0718 (0.7213)  acc1: 73.9583 (84.1136)  acc5: 93.7500 (96.6598)  time: 0.1168  data: 0.0002  max mem: 1607
Test:  [280/521]  eta: 0:00:40  loss: 0.9118 (0.7298)  acc1: 77.0833 (83.8894)  acc5: 93.7500 (96.5859)  time: 0.1175  data: 0.0002  max mem: 1607
Test:  [290/521]  eta: 0:00:38  loss: 0.7887 (0.7338)  acc1: 80.2083 (83.7915)  acc5: 94.7917 (96.5278)  time: 0.1190  data: 0.0001  max mem: 1607
Test:  [300/521]  eta: 0:00:36  loss: 0.6294 (0.7296)  acc1: 85.4167 (83.9113)  acc5: 96.8750 (96.5497)  time: 0.1187  data: 0.0001  max mem: 1607
Test:  [310/521]  eta: 0:00:34  loss: 0.7164 (0.7380)  acc1: 83.3333 (83.7621)  acc5: 95.8333 (96.4362)  time: 0.1188  data: 0.0002  max mem: 1607
Test:  [320/521]  eta: 0:00:32  loss: 0.8881 (0.7392)  acc1: 81.2500 (83.8169)  acc5: 92.7083 (96.3785)  time: 0.1197  data: 0.0002  max mem: 1607
Test:  [330/521]  eta: 0:00:30  loss: 0.8837 (0.7542)  acc1: 81.2500 (83.4687)  acc5: 94.7917 (96.2267)  time: 0.1191  data: 0.0002  max mem: 1607
Test:  [340/521]  eta: 0:00:28  loss: 1.0756 (0.7600)  acc1: 78.1250 (83.2570)  acc5: 93.7500 (96.1571)  time: 0.1187  data: 0.0002  max mem: 1607
Test:  [350/521]  eta: 0:00:26  loss: 0.8753 (0.7657)  acc1: 78.1250 (83.0247)  acc5: 95.8333 (96.1360)  time: 0.1196  data: 0.0002  max mem: 1607
Test:  [360/521]  eta: 0:00:25  loss: 1.0177 (0.7747)  acc1: 76.0417 (82.8428)  acc5: 93.7500 (96.0353)  time: 0.1202  data: 0.0001  max mem: 1607
Test:  [370/521]  eta: 0:00:23  loss: 0.9400 (0.7769)  acc1: 81.2500 (82.8055)  acc5: 93.7500 (96.0102)  time: 0.1201  data: 0.0000  max mem: 1607
Test:  [380/521]  eta: 0:00:21  loss: 0.7648 (0.7784)  acc1: 82.2917 (82.8111)  acc5: 95.8333 (95.9700)  time: 0.1192  data: 0.0001  max mem: 1607
Test:  [390/521]  eta: 0:00:20  loss: 1.0038 (0.7868)  acc1: 78.1250 (82.6034)  acc5: 93.7500 (95.8626)  time: 0.1174  data: 0.0003  max mem: 1607
Test:  [400/521]  eta: 0:00:18  loss: 1.0038 (0.7901)  acc1: 78.1250 (82.5566)  acc5: 93.7500 (95.8359)  time: 0.1175  data: 0.0003  max mem: 1607
Test:  [410/521]  eta: 0:00:16  loss: 0.8875 (0.7934)  acc1: 81.2500 (82.4944)  acc5: 93.7500 (95.7776)  time: 0.1178  data: 0.0002  max mem: 1607
Test:  [420/521]  eta: 0:00:15  loss: 0.9099 (0.7954)  acc1: 81.2500 (82.5020)  acc5: 93.7500 (95.7368)  time: 0.1178  data: 0.0002  max mem: 1607
Test:  [430/521]  eta: 0:00:13  loss: 0.9099 (0.8014)  acc1: 81.2500 (82.3473)  acc5: 93.7500 (95.6907)  time: 0.1199  data: 0.0005  max mem: 1607
Test:  [440/521]  eta: 0:00:12  loss: 1.0472 (0.8096)  acc1: 76.0417 (82.1122)  acc5: 93.7500 (95.6231)  time: 0.1205  data: 0.0004  max mem: 1607
Test:  [450/521]  eta: 0:00:10  loss: 0.9924 (0.8123)  acc1: 77.0833 (82.0261)  acc5: 94.7917 (95.5977)  time: 0.1194  data: 0.0002  max mem: 1607
Test:  [460/521]  eta: 0:00:09  loss: 0.8889 (0.8129)  acc1: 79.1667 (81.9844)  acc5: 95.8333 (95.5983)  time: 0.1187  data: 0.0002  max mem: 1607
Test:  [470/521]  eta: 0:00:07  loss: 0.9033 (0.8174)  acc1: 81.2500 (81.8781)  acc5: 94.7917 (95.5591)  time: 0.1187  data: 0.0002  max mem: 1607
Test:  [480/521]  eta: 0:00:06  loss: 0.9033 (0.8215)  acc1: 80.2083 (81.7654)  acc5: 94.7917 (95.5431)  time: 0.1197  data: 0.0002  max mem: 1607
Test:  [490/521]  eta: 0:00:04  loss: 0.7608 (0.8197)  acc1: 82.2917 (81.8186)  acc5: 95.8333 (95.5639)  time: 0.1197  data: 0.0002  max mem: 1607
Test:  [500/521]  eta: 0:00:03  loss: 0.5812 (0.8179)  acc1: 83.3333 (81.8738)  acc5: 96.8750 (95.5880)  time: 0.1181  data: 0.0003  max mem: 1607
Test:  [510/521]  eta: 0:00:01  loss: 0.7869 (0.8245)  acc1: 80.2083 (81.7005)  acc5: 96.8750 (95.5541)  time: 0.1165  data: 0.0001  max mem: 1607
Test:  [520/521]  eta: 0:00:00  loss: 0.7329 (0.8201)  acc1: 79.1667 (81.8080)  acc5: 96.8750 (95.5940)  time: 0.1151  data: 0.0001  max mem: 1607
Test: Total time: 0:01:15 (0.1446 s / it)
* Acc@1 81.808 Acc@5 95.594 loss 0.820
Accuracy of the network on the 50000 test images: 81.8%

Here Acc@1 denotes top-1 accuracy and Acc@5 denotes top-5 accuracy; the results, 81.808% and 95.594%, essentially match the numbers reported on GitHub.
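
As an aside, top-k accuracy counts a sample as correct when the true label appears among the k highest-scoring classes. A small self-contained sketch of the computation (my own; the eval loop relies on timm's accuracy() utility for this, as far as I can tell):

# Sketch: top-k accuracy — a prediction counts as correct if the true label
# appears among the k highest-scoring classes
import torch

def topk_accuracy(output, target, topk=(1, 5)):
    _, pred = output.topk(max(topk), dim=1)   # (N, maxk) class indices, best first
    hits = pred.eq(target.view(-1, 1))        # (N, maxk) boolean matches
    return [hits[:, :k].any(dim=1).float().mean().item() * 100 for k in topk]

logits = torch.randn(8, 1000)
labels = torch.randint(0, 1000, (8,))
print(topk_accuracy(logits, labels))  # [top-1 %, top-5 %]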

3.1 Evaluation Commands for Different DeiT Variants

# The smaller the model, the faster it runs

# DeiT-Base with input size 384 instead of the default 224 — did not run successfully for me (it raises an error)
python main.py --eval --model deit_base_patch16_384 --input-size 384 --resume https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth --data-path /path/to/imagenet

# DeiT-Small — ran successfully
python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth --model deit_small_patch16_224 --data-path /path/to/imagenet

Result:
* Acc@1 79.828 Acc@5 94.942 loss 0.881
Accuracy of the network on the 50000 test images: 79.8%

# DeiT-Tiny — ran successfully
python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth --model deit_tiny_patch16_224 --data-path /path/to/imagenet

Result:
* Acc@1 72.122 Acc@5 91.136 loss 1.219
Accuracy of the network on the 50000 test images: 72.1%

3.2 Possible Issues

You may run into the following errors when executing the script.

(1) ImportError: cannot import name 'container_abcs' from 'torch._six'

Cause: the way container_abcs is imported differs before and after torch 1.8:

torch 1.8 and above: import collections.abc as container_abcs
below torch 1.8: from torch._six import container_abcs

Go to the file indicated in the traceback (in my case C:\Users\admin\.conda\envs\pytorch_jccao\lib\site-packages\timm\models\layers\helpers.py) and change from torch._six import container_abcs to import collections.abc as container_abcs, then rerun; the error is gone.
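
Instead of a hard edit, a version-tolerant variant (a sketch along the lines of what newer timm releases do) also works:

# Works on both sides of the torch 1.8 split
try:
    from torch._six import container_abcs   # older torch
except ImportError:
    import collections.abc as container_abcs  # newer torch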

(2) ImportError: cannot import name '_pil_interp' from 'timm.data.transforms'

Full traceback:

Traceback (most recent call last):
  File "main.py", line 24, in <module>
    from augment import new_data_aug_generator
  File "C:\user\AppFiles\PythonPj\deit-main\augment.py", line 12, in <module>
    from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
ImportError: cannot import name '_pil_interp' from 'timm.data.transforms' 

Cause: the installed timm package is too new; when this error appeared, the installed timm version was 0.9.12.
Fix: open augment.py as indicated in the traceback and change line 12 to:

from timm.data.transforms import str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor

That is, replace the `_pil_interp` import with `str_to_pil_interp`, and update any call sites of this API accordingly. In this codebase, however, `_pil_interp` is never actually used, so simply deleting it from the import also works.
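
Similarly, if you want augment.py to tolerate both old and new timm versions, a try/except shim (my own sketch) avoids pinning:

# Old timm exposes _pil_interp; newer timm renamed it to str_to_pil_interp
try:
    from timm.data.transforms import str_to_pil_interp  # newer timm
except ImportError:
    from timm.data.transforms import _pil_interp as str_to_pil_interp  # old timm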
