1 Social Media Accounts
My currently active accounts:
- Bilibili 【雪天鱼】: 雪天鱼个人主页-bilibili.com
- WeChat Official Account 【雪天鱼】
- CSDN 【雪天鱼】: 雪天鱼-CSDN博客
QQ study groups:
- CNN | RVfpga study group (recommended, capacity 2000): 541434600
- FPGA & IC & DL study group: 866169462
I'm just a beginner documenting my own learning; this post may or may not be updated later, so treat it with caution.
- V1.0 24-02-28: evaluated the different DeiT variants on the ImageNet1K validation set using the official DeiT code
Companion Bilibili video (uploader 雪天鱼): https://www.bilibili.com/video/BV1JZ421278B
2 Resources
- DeiT paper: [2012.12877] Training data-efficient image transformers & distillation through attention (arxiv.org)
- DeiT official code: https://github.com/facebookresearch/deit/tree/main
- My pre-processed ImageNet1K dataset (includes the full dataset and a validation-set-only archive):
  Link: https://pan.baidu.com/s/1F97KAEoAzF-5pPtOvgtzoA
  Access code: 2024
- The DeiT pre-trained models I downloaded (Tiny, Small, Base):
  Link: https://pan.baidu.com/s/1TOIoauo3-2yEADclJ4S9iw
  Access code: 2024
- Visualization tool for the features learned by ViT self-attention: GitHub - sayakpaul/probing-vits: Probing the representations of Vision Transformers.
3 Basic Usage: Model Evaluation
- Development environment 1:
  CUDA 11.1 | cuDNN 8 | PyTorch 1.8.0, torchvision 0.9.0 | Python 3.8.3
- Development environment 2 (RTX 4080):
  CUDA 11.1 | PyTorch 1.9.0, torchvision 0.10.0 | Python 3.7.16 | timm 0.3.2
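Before going further, it helps to confirm your environment actually matches one of the setups above. This stdlib-only sketch (my own helper, not part of the DeiT repo; it assumes Python ≥ 3.8 for importlib.metadata) prints whichever of the relevant packages are installed:

```python
import importlib.metadata
import importlib.util
import sys

def pkg_version(name):
    """Return the installed version of package `name`, or None if absent."""
    if importlib.util.find_spec(name) is None:
        return None  # the module cannot be imported at all
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return None  # importable, but no distribution metadata

print("python:", sys.version.split()[0])
for pkg in ("torch", "torchvision", "timm", "numpy"):
    print(f"{pkg}: {pkg_version(pkg) or 'not installed'}")
```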
Download the code archive from the official GitHub repository https://github.com/facebookresearch/deit/tree/main and unzip it locally.
Pay particular attention to these files:
- models.py: defines DistilledVisionTransformer and the DeiT model variants (tiny, small, base)
- README_deit.md: the Model Zoo section provides download links for the DeiT pre-trained models (not in .pth format) and usage instructions for the scripts
- Install the required dependencies:
# I installed numpy==1.19.5
pip install numpy
conda install -c pytorch pytorch torchvision
# the ViT model definition is imported directly from timm
pip install timm==0.3.2
- Prepare the ImageNet1K dataset
Download my pre-processed ImageNet1K dataset and arrange it in the directory structure below so that torchvision's datasets.ImageFolder (https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html) can read it: the training set goes under train/ and the validation set under val/.
/path/to/imagenet/        # dataset root
  train/                  # training set
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/                    # validation set
    class1/
      img3.jpeg
    class2/
      img4.jpeg
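The expected layout can be reproduced with a few lines of stdlib Python if you want to sanity-check your own arrangement (the class names here are placeholders, not real ImageNet synset IDs):

```python
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp()) / "imagenet"
for split in ("train", "val"):
    for cls in ("class1", "class2"):
        # ImageFolder expects one subdirectory per class, with that
        # class's images directly inside it
        (root / split / cls).mkdir(parents=True)
        (root / split / cls / "img.jpeg").touch()

layout = sorted(p.relative_to(root).as_posix() for p in root.rglob("*.jpeg"))
print(layout)
```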
- Evaluation
For evaluation you only need val.rar (the ImageNet1K validation set) from the network drive; unpack it and arrange it following the layout above.
The train folder here is just a copy of val, renamed: the current official script insists on loading both the training and validation sets even when only evaluating, and the real ImageNet training set is about 160 GB, which is far too large. Copying val and renaming the copy gives a stand-in training set so evaluation can proceed; for actual training, train/ must of course contain the real training data.
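The copy-and-rename trick can be scripted. This is a sketch of the idea (a hypothetical helper, not part of main.py), copying val/ to train/ only when train/ is missing:

```python
import shutil
from pathlib import Path

def fake_train_from_val(root):
    """For evaluation only: duplicate root/val as root/train so that the
    official script can construct both datasets. Real training still
    requires the real training images under train/."""
    root = Path(root)
    train = root / "train"
    if not train.exists():
        shutil.copytree(root / "val", train)
    return train
```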
To evaluate the pre-trained DeiT-base model on the ImageNet validation set on a single-GPU machine:
python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth --data-path /path/to/imagenet
where /path/to/imagenet is the absolute path to the ImageNet dataset.
The command I actually ran was:
> On 智星云's Windows 10 cmd you first need to run activate base to activate the environment
python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth --data-path C:/Users/vipuser/Downloads/imagenet
In the path, use `/` or `\\`, not a bare `\`.
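The slash rule exists because a single backslash starts an escape sequence in a plain Python string; forward slashes, doubled backslashes, or a raw string all denote the same path. A quick check:

```python
p1 = "C:/Users/vipuser/Downloads/imagenet"
p2 = "C:\\Users\\vipuser\\Downloads\\imagenet"
p3 = r"C:\Users\vipuser\Downloads\imagenet"

# all three spellings are equivalent once normalized to backslashes
print(p1.replace("/", "\\") == p2 == p3)
```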
We invoke main.py for the evaluation; the key arguments are:
Model arguments:
--model: the model to evaluate; defaults to deit_base_patch16_224
Dataset arguments:
--data-path: path to the ImageNet dataset
Other arguments:
--eval: a store_true flag; when present, the script runs in evaluation mode
--resume: the checkpoint to load, i.e. the pre-trained model
If your GPU has limited memory, remember to lower --batch-size (default 64); a weak GPU may not handle the default.
The output:
Not using distributed mode
followed by a dump of all configuration parameters:
Namespace(ThreeAugment=False, aa='rand-m9-mstd0.5-inc1', attn_only=False, batch_size=64, bce_loss=False, clip_grad=None, color_jitter=0.3, cooldown_epochs=10, cosub=False, cutmix=1.0, cutmix_minmax=None, data_path='C:\\\\jccao\\\\AppFiles\\\\PythonPj\\\\PTQ4ViT-main\\\\datasets\\\\imagenet', data_set='IMNET', decay_epochs=30, decay_rate=0.1, device='cuda', dist_eval=False, dist_url='env://', distillation_alpha=0.5, distillation_tau=1.0, distillation_type='none', distributed=False, drop=0.0, drop_path=0.1, epochs=300, eval=True, eval_crop_ratio=0.875, finetune='', inat_category='name', input_size=224, lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, min_lr=1e-05, mixup=0.8, mixup_mode='batch', mixup_prob=1.0, mixup_switch_prob=0.5, model='deit_base_patch16_224', model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, momentum=0.9, num_workers=10, opt='adamw', opt_betas=None, opt_eps=1e-08, output_dir='', patience_epochs=10, pin_mem=True, recount=1, remode='pixel', repeated_aug=True, reprob=0.25, resplit=False, resume='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', sched='cosine', seed=0, smoothing=0.1, src=False, start_epoch=0, teacher_model='regnety_160', teacher_path='', train_interpolation='bicubic', train_mode=True, unscale_lr=False, warmup_epochs=5, warmup_lr=1e-06, weight_decay=0.05, world_size=1)
Creating model: deit_base_patch16_224
number of params: 86567656
On an RTX 4080 with batch size 64, evaluating the full ImageNet validation set takes about 2 minutes.
Test: [ 0/521] eta: 1:59:35 loss: 0.2683 (0.2683) acc1: 96.8750 (96.8750) acc5: 98.9583 (98.9583) time: 13.7734 data: 11.5260 max mem: 1607
Test: [ 10/521] eta: 0:11:33 loss: 0.4705 (0.3933) acc1: 92.7083 (93.4659) acc5: 98.9583 (98.2955) time: 1.3579 data: 1.0481 max mem: 1607
Test: [ 20/521] eta: 0:06:24 loss: 0.5191 (0.5495) acc1: 88.5417 (88.4425) acc5: 97.9167 (97.9663) time: 0.1175 data: 0.0004 max mem: 1607
Test: [ 30/521] eta: 0:04:34 loss: 0.6762 (0.6165) acc1: 84.3750 (86.6263) acc5: 97.9167 (97.4462) time: 0.1188 data: 0.0003 max mem: 1607
Test: [ 40/521] eta: 0:03:37 loss: 0.7444 (0.6966) acc1: 79.1667 (84.0193) acc5: 95.8333 (96.9512) time: 0.1189 data: 0.0002 max mem: 1607
Test: [ 50/521] eta: 0:03:01 loss: 0.5451 (0.6377) acc1: 90.6250 (85.9886) acc5: 97.9167 (97.2835) time: 0.1173 data: 0.0003 max mem: 1607
Test: [ 60/521] eta: 0:02:37 loss: 0.4698 (0.6218) acc1: 91.6667 (86.6633) acc5: 97.9167 (97.3190) time: 0.1173 data: 0.0001 max mem: 1607
Test: [ 70/521] eta: 0:02:19 loss: 0.5640 (0.6133) acc1: 89.5833 (86.9572) acc5: 97.9167 (97.4619) time: 0.1179 data: 0.0001 max mem: 1607
Test: [ 80/521] eta: 0:02:06 loss: 0.3679 (0.5868) acc1: 90.6250 (87.7443) acc5: 98.9583 (97.5952) time: 0.1165 data: 0.0002 max mem: 1607
Test: [ 90/521] eta: 0:01:55 loss: 0.5198 (0.6136) acc1: 89.5833 (87.0421) acc5: 98.9583 (97.4931) time: 0.1165 data: 0.0002 max mem: 1607
Test: [100/521] eta: 0:01:46 loss: 0.7690 (0.6237) acc1: 84.3750 (86.6955) acc5: 96.8750 (97.4629) time: 0.1178 data: 0.0004 max mem: 1607
Test: [110/521] eta: 0:01:38 loss: 0.6432 (0.6253) acc1: 86.4583 (86.6554) acc5: 97.9167 (97.4756) time: 0.1181 data: 0.0004 max mem: 1607
Test: [120/521] eta: 0:01:32 loss: 0.6027 (0.6279) acc1: 87.5000 (86.6649) acc5: 97.9167 (97.4346) time: 0.1177 data: 0.0003 max mem: 1607
Test: [130/521] eta: 0:01:26 loss: 0.6634 (0.6352) acc1: 86.4583 (86.2914) acc5: 97.9167 (97.4793) time: 0.1166 data: 0.0002 max mem: 1607
Test: [140/521] eta: 0:01:21 loss: 0.6191 (0.6254) acc1: 87.5000 (86.5396) acc5: 97.9167 (97.5251) time: 0.1172 data: 0.0002 max mem: 1607
Test: [150/521] eta: 0:01:17 loss: 0.5174 (0.6440) acc1: 85.4167 (86.0582) acc5: 97.9167 (97.4131) time: 0.1183 data: 0.0002 max mem: 1607
Test: [160/521] eta: 0:01:13 loss: 0.6763 (0.6388) acc1: 85.4167 (86.2901) acc5: 97.9167 (97.4961) time: 0.1180 data: 0.0002 max mem: 1607
Test: [170/521] eta: 0:01:09 loss: 0.4719 (0.6331) acc1: 90.6250 (86.4340) acc5: 97.9167 (97.5085) time: 0.1179 data: 0.0001 max mem: 1607
Test: [180/521] eta: 0:01:05 loss: 0.4948 (0.6267) acc1: 89.5833 (86.5677) acc5: 97.9167 (97.5541) time: 0.1176 data: 0.0000 max mem: 1607
Test: [190/521] eta: 0:01:02 loss: 0.4948 (0.6274) acc1: 87.5000 (86.4692) acc5: 98.9583 (97.6004) time: 0.1181 data: 0.0000 max mem: 1607
Test: [200/521] eta: 0:00:59 loss: 0.6553 (0.6362) acc1: 86.4583 (86.2666) acc5: 97.9167 (97.5228) time: 0.1172 data: 0.0001 max mem: 1607
Test: [210/521] eta: 0:00:56 loss: 0.6726 (0.6371) acc1: 87.5000 (86.2707) acc5: 95.8333 (97.4921) time: 0.1160 data: 0.0002 max mem: 1607
Test: [220/521] eta: 0:00:53 loss: 0.7357 (0.6544) acc1: 83.3333 (85.8267) acc5: 95.8333 (97.3181) time: 0.1175 data: 0.0002 max mem: 1607
Test: [230/521] eta: 0:00:51 loss: 0.7568 (0.6646) acc1: 81.2500 (85.5429) acc5: 94.7917 (97.2222) time: 0.1185 data: 0.0003 max mem: 1607
Test: [240/521] eta: 0:00:48 loss: 0.8931 (0.6753) acc1: 80.2083 (85.2654) acc5: 93.7500 (97.0911) time: 0.1180 data: 0.0003 max mem: 1607
Test: [250/521] eta: 0:00:46 loss: 0.9160 (0.6916) acc1: 80.2083 (84.9228) acc5: 92.7083 (96.8999) time: 0.1175 data: 0.0003 max mem: 1607
Test: [260/521] eta: 0:00:44 loss: 1.1382 (0.7074) acc1: 71.8750 (84.4628) acc5: 93.7500 (96.7872) time: 0.1175 data: 0.0002 max mem: 1607
Test: [270/521] eta: 0:00:42 loss: 1.0718 (0.7213) acc1: 73.9583 (84.1136) acc5: 93.7500 (96.6598) time: 0.1168 data: 0.0002 max mem: 1607
Test: [280/521] eta: 0:00:40 loss: 0.9118 (0.7298) acc1: 77.0833 (83.8894) acc5: 93.7500 (96.5859) time: 0.1175 data: 0.0002 max mem: 1607
Test: [290/521] eta: 0:00:38 loss: 0.7887 (0.7338) acc1: 80.2083 (83.7915) acc5: 94.7917 (96.5278) time: 0.1190 data: 0.0001 max mem: 1607
Test: [300/521] eta: 0:00:36 loss: 0.6294 (0.7296) acc1: 85.4167 (83.9113) acc5: 96.8750 (96.5497) time: 0.1187 data: 0.0001 max mem: 1607
Test: [310/521] eta: 0:00:34 loss: 0.7164 (0.7380) acc1: 83.3333 (83.7621) acc5: 95.8333 (96.4362) time: 0.1188 data: 0.0002 max mem: 1607
Test: [320/521] eta: 0:00:32 loss: 0.8881 (0.7392) acc1: 81.2500 (83.8169) acc5: 92.7083 (96.3785) time: 0.1197 data: 0.0002 max mem: 1607
Test: [330/521] eta: 0:00:30 loss: 0.8837 (0.7542) acc1: 81.2500 (83.4687) acc5: 94.7917 (96.2267) time: 0.1191 data: 0.0002 max mem: 1607
Test: [340/521] eta: 0:00:28 loss: 1.0756 (0.7600) acc1: 78.1250 (83.2570) acc5: 93.7500 (96.1571) time: 0.1187 data: 0.0002 max mem: 1607
Test: [350/521] eta: 0:00:26 loss: 0.8753 (0.7657) acc1: 78.1250 (83.0247) acc5: 95.8333 (96.1360) time: 0.1196 data: 0.0002 max mem: 1607
Test: [360/521] eta: 0:00:25 loss: 1.0177 (0.7747) acc1: 76.0417 (82.8428) acc5: 93.7500 (96.0353) time: 0.1202 data: 0.0001 max mem: 1607
Test: [370/521] eta: 0:00:23 loss: 0.9400 (0.7769) acc1: 81.2500 (82.8055) acc5: 93.7500 (96.0102) time: 0.1201 data: 0.0000 max mem: 1607
Test: [380/521] eta: 0:00:21 loss: 0.7648 (0.7784) acc1: 82.2917 (82.8111) acc5: 95.8333 (95.9700) time: 0.1192 data: 0.0001 max mem: 1607
Test: [390/521] eta: 0:00:20 loss: 1.0038 (0.7868) acc1: 78.1250 (82.6034) acc5: 93.7500 (95.8626) time: 0.1174 data: 0.0003 max mem: 1607
Test: [400/521] eta: 0:00:18 loss: 1.0038 (0.7901) acc1: 78.1250 (82.5566) acc5: 93.7500 (95.8359) time: 0.1175 data: 0.0003 max mem: 1607
Test: [410/521] eta: 0:00:16 loss: 0.8875 (0.7934) acc1: 81.2500 (82.4944) acc5: 93.7500 (95.7776) time: 0.1178 data: 0.0002 max mem: 1607
Test: [420/521] eta: 0:00:15 loss: 0.9099 (0.7954) acc1: 81.2500 (82.5020) acc5: 93.7500 (95.7368) time: 0.1178 data: 0.0002 max mem: 1607
Test: [430/521] eta: 0:00:13 loss: 0.9099 (0.8014) acc1: 81.2500 (82.3473) acc5: 93.7500 (95.6907) time: 0.1199 data: 0.0005 max mem: 1607
Test: [440/521] eta: 0:00:12 loss: 1.0472 (0.8096) acc1: 76.0417 (82.1122) acc5: 93.7500 (95.6231) time: 0.1205 data: 0.0004 max mem: 1607
Test: [450/521] eta: 0:00:10 loss: 0.9924 (0.8123) acc1: 77.0833 (82.0261) acc5: 94.7917 (95.5977) time: 0.1194 data: 0.0002 max mem: 1607
Test: [460/521] eta: 0:00:09 loss: 0.8889 (0.8129) acc1: 79.1667 (81.9844) acc5: 95.8333 (95.5983) time: 0.1187 data: 0.0002 max mem: 1607
Test: [470/521] eta: 0:00:07 loss: 0.9033 (0.8174) acc1: 81.2500 (81.8781) acc5: 94.7917 (95.5591) time: 0.1187 data: 0.0002 max mem: 1607
Test: [480/521] eta: 0:00:06 loss: 0.9033 (0.8215) acc1: 80.2083 (81.7654) acc5: 94.7917 (95.5431) time: 0.1197 data: 0.0002 max mem: 1607
Test: [490/521] eta: 0:00:04 loss: 0.7608 (0.8197) acc1: 82.2917 (81.8186) acc5: 95.8333 (95.5639) time: 0.1197 data: 0.0002 max mem: 1607
Test: [500/521] eta: 0:00:03 loss: 0.5812 (0.8179) acc1: 83.3333 (81.8738) acc5: 96.8750 (95.5880) time: 0.1181 data: 0.0003 max mem: 1607
Test: [510/521] eta: 0:00:01 loss: 0.7869 (0.8245) acc1: 80.2083 (81.7005) acc5: 96.8750 (95.5541) time: 0.1165 data: 0.0001 max mem: 1607
Test: [520/521] eta: 0:00:00 loss: 0.7329 (0.8201) acc1: 79.1667 (81.8080) acc5: 96.8750 (95.5940) time: 0.1151 data: 0.0001 max mem: 1607
Test: Total time: 0:01:15 (0.1446 s / it)
* Acc@1 81.808 Acc@5 95.594 loss 0.820
Accuracy of the network on the 50000 test images: 81.8%
Here Acc@1 is the top-1 accuracy and Acc@5 the top-5 accuracy: 81.808% and 95.594% respectively, essentially matching the numbers reported on GitHub.
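What these metrics mean can be shown in a few lines of plain Python (toy scores, no torch needed): a sample counts toward top-k accuracy when its true label is among the k highest-scoring classes.

```python
def topk_accuracy(scores, labels, k):
    """scores: per-sample lists of class scores; labels: true class indices."""
    hits = 0
    for row, label in zip(scores, labels):
        # indices of the k highest-scoring classes for this sample
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in topk
    return 100.0 * hits / len(labels)

scores = [[0.1, 0.7, 0.2], [0.4, 0.25, 0.35], [0.2, 0.2, 0.6]]
labels = [1, 2, 2]
print(topk_accuracy(scores, labels, 1))  # exact argmax only
print(topk_accuracy(scores, labels, 2))  # true class within top-2 also counts
```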
3.1 Evaluation Commands for the Other DeiT Variants
# the smaller the model, the faster it runs
# DeiT Base with input size 384 (instead of the default 224): did not run successfully for me, raises an error
python main.py --eval --model deit_base_patch16_384 --input-size 384 --resume https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth --data-path /path/to/imagenet
# DeiT Small: ran successfully
python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth --model deit_small_patch16_224 --data-path /path/to/imagenet
The result:
* Acc@1 79.828 Acc@5 94.942 loss 0.881
Accuracy of the network on the 50000 test images: 79.8%
# DeiT Tiny: ran successfully
python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth --model deit_tiny_patch16_224 --data-path /path/to/imagenet
* Acc@1 72.122 Acc@5 91.136 loss 1.219
Accuracy of the network on the 50000 test images: 72.1%
3.2 Possible Issues
You may run into the following errors when executing the scripts:
(1) ImportError: cannot import name 'container_abcs' from 'torch._six'
Cause: the way container_abcs is imported changed at torch 1.8.
- torch 1.8 and later: import collections.abc as container_abcs
- before torch 1.8: from torch._six import container_abcs
Go to the file named in the traceback (for me it was C:\Users\admin\.conda\envs\pytorch_jccao\lib\site-packages\timm\models\layers\helpers.py) and change from torch._six import container_abcs to import collections.abc as container_abcs, then rerun; the error is gone.
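Rather than patching the installed timm file by hand, the same fix can be written as a version-tolerant import (a sketch; it works with old torch, new torch, or no torch at all, because a missing module also raises ImportError):

```python
try:
    from torch._six import container_abcs      # torch < 1.8
except ImportError:
    import collections.abc as container_abcs   # torch >= 1.8 (or no torch)

# either way, the usual abstract base classes are available:
print(isinstance([1, 2, 3], container_abcs.Iterable))
```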
(2)ImportError: cannot import name '_pil_interp' from 'timm.data.transforms'
Full traceback:
Traceback (most recent call last):
File "main.py", line 24, in <module>
from augment import new_data_aug_generator
File "C:\user\AppFiles\PythonPj\deit-main\augment.py", line 12, in <module>
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
ImportError: cannot import name '_pil_interp' from 'timm.data.transforms'
Cause: the installed timm package is too new; when this error occurred, the installed timm version was 0.9.12.
Fix: following the traceback, open augment.py and change line 12 to
from timm.data.transforms import str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
i.e. import str_to_pil_interp instead of _pil_interp, and update the call sites of that API to match. In this codebase, however, _pil_interp is never actually used, so simply deleting it from the import also works.
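The try/except pattern from the previous issue applies here too; more generally, a small helper (hypothetical, not part of timm) can return the first attribute a module actually defines:

```python
import importlib

def import_first(module, *names):
    """Return the first attribute in `names` that `module` defines."""
    mod = importlib.import_module(module)
    for name in names:
        if hasattr(mod, name):
            return getattr(mod, name)
    raise ImportError(f"none of {names} found in {module}")

# With timm installed, the manual edit above would become:
#   _pil_interp = import_first("timm.data.transforms",
#                              "_pil_interp", "str_to_pil_interp")
```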