ViT (Vision Transformer) ONNX Model Analysis

Background: applying Transformers to computer vision

Paper: https://arxiv.org/abs/2010.11929

PyTorch implementation: pytorch_classification/vision_transformer (code by the blogger 太阳花的小绿豆)

Researchers have been exploring both CNN+Transformer hybrids and pure-Transformer approaches.

The paper's abstract states: "We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks."

For a more detailed explanation, see 太阳花的小绿豆's post 《Vision Transformer详解》 (the related figures are from that blog). Here I go through the ONNX model structure directly, which may be a bit more intuitive (skip this if it's not your thing).

  1. Overall ViT architecture diagram

  2. ViT shape transformations

PyTorch API: summary(model, (3, 224, 224))

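A minimal sketch of how this table can be generated with torchsummary; the vit_model import path and the vit_base_patch16_224 factory are assumptions mirroring the blogger's code layout, and num_classes=5 matches Linear-199 below:

    import torch
    from torchsummary import summary
    from vit_model import vit_base_patch16_224  # assumed module from the blogger's repo

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = vit_base_patch16_224(num_classes=5).to(device)  # 5-way head, see Linear-199
    summary(model, (3, 224, 224), device=device)
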
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
(1) Preprocessing
            Conv2d-1          [-1, 768, 14, 14]         590,592
          Identity-2             [-1, 196, 768]               0
        PatchEmbed-3             [-1, 196, 768]               0
           Dropout-4             [-1, 197, 768]               0
(2) transformer encoder

block 1
         LayerNorm-5             [-1, 197, 768]           1,536
            Linear-6            [-1, 197, 2304]       1,771,776
           Dropout-7         [-1, 12, 197, 197]               0
            Linear-8             [-1, 197, 768]         590,592
           Dropout-9             [-1, 197, 768]               0
        Attention-10             [-1, 197, 768]               0
         Identity-11             [-1, 197, 768]               0
        LayerNorm-12             [-1, 197, 768]           1,536
           Linear-13            [-1, 197, 3072]       2,362,368
             GELU-14            [-1, 197, 3072]               0
          Dropout-15            [-1, 197, 3072]               0
           Linear-16             [-1, 197, 768]       2,360,064
          Dropout-17             [-1, 197, 768]               0
              Mlp-18             [-1, 197, 768]               0
         Identity-19             [-1, 197, 768]               0
            Block-20             [-1, 197, 768]               0
block 2 ... block 12
        (blocks 2 through 12, layers 21-196, repeat the block 1 layer
         pattern with identical output shapes and parameter counts;
         omitted here for brevity)
(3) Postprocessing
       LayerNorm-197             [-1, 197, 768]           1,536
        Identity-198                  [-1, 768]               0
          Linear-199                    [-1, 5]           3,845
================================================================
Total params: 85,650,437
Trainable params: 85,650,437
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 408.54
Params size (MB): 326.73
Estimated Total Size (MB): 735.84
----------------------------------------------------------------

3. Data preprocessing

  1. The 3×224×224 input passes through a Conv2d with 768 kernels of size 16×16 (stride 16), producing a 768×14×14 feature map.

  2. That output is flattened to 768×196 (14×14 = 196).

  3. The axes are swapped to 196×768 (one 768-dim token per patch).

  4. A learnable class token (1×768, carrying the classification information) is prepended, concatenating with the 196×768 patch tokens into 197×768.

  5. The positional embedding pos is added element-wise (shape stays 197×768); see the sketch below.
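
A minimal PyTorch sketch of these five steps (variable names are mine; shapes follow ViT-Base/16):

    import torch
    import torch.nn as nn

    proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)   # (1) 3x224x224 -> 768x14x14
    cls_token = nn.Parameter(torch.zeros(1, 1, 768))      # learnable class token
    pos_embed = nn.Parameter(torch.zeros(1, 197, 768))    # learnable position embedding

    x = torch.randn(1, 3, 224, 224)
    x = proj(x)                        # [1, 768, 14, 14]
    x = x.flatten(2)                   # (2) [1, 768, 196]
    x = x.transpose(1, 2)              # (3) [1, 196, 768]
    x = torch.cat((cls_token, x), 1)   # (4) [1, 197, 768]
    x = x + pos_embed                  # (5) shape stays [1, 197, 768]
    print(x.shape)                     # torch.Size([1, 197, 768])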

4. ONNX structure of the data fed into the transformer encoder

Honestly, the ops in the ONNX graph are a bit hard to hold in your head: the LayerNorm layer gets exported as a complicated subgraph of primitive ops, and I haven't looked into op fusion yet (I will later; no time right now). In any case, this subgraph is the transformer encoder.
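
For readers who want to reproduce such a graph, a minimal export sketch; the module name mirrors the assumption above, and the file name is hypothetical. Note that ONNX only gained a dedicated LayerNormalization op in opset 17, which is why LayerNorm appears here as a decomposed subgraph:

    import torch
    from vit_model import vit_base_patch16_224  # assumed module, as above

    model = vit_base_patch16_224(num_classes=5).eval()
    dummy = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, dummy, "vit_base_patch16_224.onnx",
                      input_names=["input"], output_names=["logits"],
                      opset_version=12)  # pre-17 opset: LayerNorm exports as primitive ops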

LayerNorm-5     [-1, 197, 768]
Linear-6        [-1, 197, 2304]
Dropout-7       [-1, 12, 197, 197]
Linear-8        [-1, 197, 768]
Dropout-9       [-1, 197, 768]
Attention-10    [-1, 197, 768]
Identity-11     [-1, 197, 768]
LayerNorm-12    [-1, 197, 768]
Linear-13       [-1, 197, 3072]
GELU-14         [-1, 197, 3072]
Dropout-15      [-1, 197, 3072]
Linear-16       [-1, 197, 768]
Dropout-17      [-1, 197, 768]
Mlp-18          [-1, 197, 768]
Identity-19     [-1, 197, 768]
Block-20        [-1, 197, 768]
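
Two details worth decoding from the shapes: Linear-6's width of 2304 is the fused QKV projection (3 × 768), and Dropout-7's shape [-1, 12, 197, 197] is the attention map over 12 heads. Below is a minimal sketch of what one Block computes: pre-norm self-attention and a pre-norm MLP, each wrapped in a residual connection. nn.MultiheadAttention is a stand-in for the blogger's fused-QKV Attention module:

    import torch
    import torch.nn as nn

    class Block(nn.Module):
        # One ViT-Base encoder block: dim=768, 12 heads, MLP ratio 4
        def __init__(self, dim=768, num_heads=12):
            super().__init__()
            self.norm1 = nn.LayerNorm(dim, eps=1e-6)   # LayerNorm-5
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            self.norm2 = nn.LayerNorm(dim, eps=1e-6)   # LayerNorm-12
            self.mlp = nn.Sequential(                  # Linear-13 / GELU-14 / Linear-16
                nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

        def forward(self, x):
            y = self.norm1(x)
            x = x + self.attn(y, y, y)[0]    # residual 1 (Attention-10)
            x = x + self.mlp(self.norm2(x))  # residual 2 (Mlp-18)
            return x

    x = torch.randn(1, 197, 768)
    print(Block()(x).shape)  # torch.Size([1, 197, 768])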

5. Postprocessing

LayerNorm-197   [-1, 197, 768]
Identity-198    [-1, 768]
Linear-199      [-1, 5]

So there you have it: this is the LayerNorm operation spelled out as ONNX ops (I'll refrain from commenting).
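
Written out, that subgraph is just the LayerNorm formula expressed with primitive ops (ReduceMean, Sub, Pow, Sqrt, Div, Mul, Add). A sketch of the equivalent computation:

    import torch
    import torch.nn.functional as F

    def layernorm_decomposed(x, gamma, beta, eps=1e-6):
        # The chain of primitive ops the ONNX graph shows:
        # ReduceMean -> Sub -> Pow/ReduceMean -> Add -> Sqrt -> Div -> Mul -> Add
        mean = x.mean(dim=-1, keepdim=True)            # ReduceMean
        diff = x - mean                                # Sub
        var = (diff ** 2).mean(dim=-1, keepdim=True)   # Pow + ReduceMean
        return diff / torch.sqrt(var + eps) * gamma + beta

    x = torch.randn(1, 197, 768)
    gamma, beta = torch.ones(768), torch.zeros(768)
    out = layernorm_decomposed(x, gamma, beta)
    # agrees with the fused LayerNorm up to numerical precision:
    print(torch.allclose(out, F.layer_norm(x, (768,), gamma, beta, 1e-6), atol=1e-5))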

Finally, the class token passes through a fully connected layer to produce the classification output.
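
A minimal sketch of this final stage: LayerNorm-197 normalizes the sequence, Identity-198 corresponds to taking the class token (pre_logits is a no-op here), and Linear-199 is the 5-way head (768 × 5 + 5 = 3,845 parameters, matching the table):

    import torch
    import torch.nn as nn

    norm = nn.LayerNorm(768, eps=1e-6)   # LayerNorm-197
    head = nn.Linear(768, 5)             # Linear-199

    x = torch.randn(1, 197, 768)         # encoder output
    cls = norm(x)[:, 0]                  # keep only the class token -> [1, 768]
    logits = head(cls)                   # [1, 5]
    print(logits.shape)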

Summary

From the op level, ViT introduces no new operators; it is simply a composition of standard layers, yet it performs remarkably well: a plain, unadorned structure doing deep work. There is always more to learn! If enough readers are interested, I can upload the related ONNX code later for everyone to try; feel free to follow or comment. For reference, here is a sample inference result on a 5-class flower dataset:

class: daisy       prob: 0.995
class: dandelion   prob: 0.00298
class: roses       prob: 0.000599
class: sunflowers  prob: 0.000633
class: tulips      prob: 0.000771
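
A minimal onnxruntime sketch that produces output in this form; the model file name is the hypothetical one from the export sketch above, and a real run would feed a properly preprocessed image instead of random data:

    import numpy as np
    import onnxruntime as ort

    classes = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
    sess = ort.InferenceSession("vit_base_patch16_224.onnx")  # hypothetical file name

    img = np.random.rand(1, 3, 224, 224).astype(np.float32)  # stand-in for a real image
    logits = sess.run(None, {sess.get_inputs()[0].name: img})[0]

    e = np.exp(logits - logits.max(axis=-1, keepdims=True))  # numerically stable softmax
    probs = e / e.sum(axis=-1, keepdims=True)
    for name, p in zip(classes, probs[0]):
        print(f"class: {name:11s} prob: {p:.3g}")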
