Vision Transformer Pytorch 实现代码学习记录

目前运营的社交平台账号:


可能后续有更新,也可能没有更新,谨慎参考

  • V1.0 24-02-13 ViT 代码的基本训练, 预测推理脚本运行

1 学习目标

  1. 能用官方的 ViT 预训练模型在 imagenet1k 上进行预测推理 完成
  2. 在 ImageNet-1K 的完整验证集上验证下载的官方 ViT 预训练模型的准确率

未处理的问题:

  • 官方的 ViT 预处理模型训练时的图片数据预处理方法是什么?

Github pytorch实现的 ViT 代码下载:deep-learning-for-image-processing/pytorch_classification/vision_transformer at master · WZMIAOMIAO/deep-learning-for-image-processing · GitHub
Note: 非官方仓库代码,但 vit_model.py 即ViT 模型定义代码是用的被 TIMM 采用的代码。

已经处理好的 ImageNet1K数据集网盘链接:
链接:https://pan.baidu.com/s/1sYMIwqkNldmqpaJqDK8lSQ?pwd=2024
提取码:2024

2 运行 flops.py (不重要,可跳过)

先安装fvcore包: pip install fvcore
然后点击运行会出错,报错为:
ValueError: Invalid type <class 'numpy.int32'> for the flop count! Please use a wider type to avoid overflow.

|850

点击红框中的位置进入到 jit_handles.py 文件中,修改 14~19行代码如下:

try:  
    from math import prod  
except ImportError:  
    from numpy import prod as prodnp  
    def prod(x):   
        return int(prodnp(x))

然后再重新运行 flops.py 无报错。结果为:

Self-Attention FLOPs: 60129542144
# 中间有一些红色字体的 warnings
Multi-Head Attention FLOPs: 68719476736

3 训练—train.py

从 vit_model 中导入想要训练的 ViT版本, 把默认导入的 vit_base_patch16_224_in21k 给注释掉,确保加载的预训练权重和实例化的模型class一致。

from vit_model import vit_base_patch16_224 as create_model

运行脚本,默认训练10 epochs, 每轮都会将训练好的权重文件保存至 weights 目录下

模型有 327 MB
用tensorboard 打开 runs 目录下的训练log,如下图所示:

4 预测推理—predict.py

现在我们用训练好的模型进行预测推理,自己从数据集或者网上选择一张图作为输入,预测结果如下图所示:

5 在 ImageNet1K 数据集上进行预测推理

我们可以直接加载官方预训练模型在 ImageNet1K 数据集上进行预测推理,需要准备 imagenet 1k的类别索引 json文件,这里我们从github下载即可:
https://github.com/raghakot/keras-vis/blob/master/resources/imagenet_class_index.json

然后准备好部分的 imagenet1K 数据集作为输入的预测图片,最终效果如下图所示:
|775

在进行 data_transform预处理之后,输入图片数据的最大值为 1,最小值为 -0.97

6 其他未整理的学习资料

  • 23
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
Transformer发轫于NLP(自然语言处理),并跨界应用到CV(计算机视觉)领域。目前已成为深度学习的新范式,影响力和应用前景巨大。  本课程对Transformer的原理和PyTorch代码进行精讲,来帮助大家掌握其详细原理和具体实现。  原理精讲部分包括:注意力机制和自注意力机制、Transformer的架构概述、Encoder的多头注意力(Multi-Head Attention)、Encoder的位置编码(Positional Encoding)、残差链接、层规范化(Layer Normalization)、FFN(Feed Forward Network)、Transformer的训练及性能、Transformer的机器翻译工作流程。   代码精讲部分使用Jupyter Notebook对TransformerPyTorch代码进行逐行解读,包括:安装PyTorchTransformer的Encoder代码解读、Transformer的Decoder代码解读、Transformer的超参设置代码解读、Transformer的训练示例(人为随机数据)代码解读、Transformer的训练示例(德语-英语机器翻译)代码解读。相关课程: 《Transformer原理与代码精讲(PyTorch)》https://edu.csdn.net/course/detail/36697《Transformer原理与代码精讲(TensorFlow)》https://edu.csdn.net/course/detail/36699《ViT(Vision Transformer)原理与代码精讲》https://edu.csdn.net/course/detail/36719《DETR原理与代码精讲》https://edu.csdn.net/course/detail/36768《Swin Transformer实战目标检测:训练自己的数据集》https://edu.csdn.net/course/detail/36585《Swin Transformer实战实例分割:训练自己的数据集》https://edu.csdn.net/course/detail/36586《Swin Transformer原理与代码精讲》 https://download.csdn.net/course/detail/37045
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

雪天鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值