Vision Transformer

 视频:11.1 Vision Transformer(vit)网络详解_哔哩哔哩_bilibili

Vision Transformer学习笔记_linear projection of flattened patches-CSDN博客

一、embedding 层 

 对于标准的Transformer模块,要求输入的是token (向量)序列,即二维矩阵[num_token,token_dim];

在代码实现中,直接通过一个卷积层来实现以ViT一 B/16为例,使用卷积核大小为16x16,stride为16, 卷积核个数为768;

  • [224, 224, 3] -> [14, 14, 768] -> [196, 768]

在输入Transformer Encoder之前需要加上[class]token 以及Position Embedding,都是可训练参数

  • 拼接[class]token: Cat([1,768],[196,768])->[197,768]
  • 叠加Position Embedding: [197,768]->[197,768]

在这里我画了一个图来解释一下整体过程: 

二、Encoder层

主要完成机制就是多头注意力机制

三、 MLP Head层

 把class token从最终结果[197,768]中切片拿出来,对其进行linear全连接(简单理解),如果需要类别概率的话,可以再接一个softmax

借用我导的图片来总结一下 

四、代码实现 

deep-learning-for-image-processing/pytorch_classification/vision_transformer at master · WZMIAOMIAO/deep-learning-for-image-processing · GitHub

4.1 数据集

使用的数据集为苹果树叶数据集,共有5788张,5个类,分别是苹果树叶的三种疾病类(各1000张)、一个healthy类(1645张)和一个无关类(1143张)

4.2 训练过程

使用patch=16、输入图像为224*224的vit在imagenet21k上预训练过的模型进行了实验,实验设置80%的训练集和20%的测试集

  • train.py文件需要修改的数据集路径和预训练权重

参数设置: epoch=10,batchsize=8,最终的训练结果准确率达到96.7%

训练过程:

生成的权重文件: 

 4.3 预测过程

  • predict.py文件需要修改的预测图片的路径以及训练的权重路径

  •  测试结果如下:

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值