前言:工具用不好,万事都烦恼,原本真的就是很简单的一个思路实现,偏偏绕了一圈又一圈,今天就来认识认识Timm库吧!
目录
4.2 Timm库的Vision Transformer活用
4.2.1 vision_transformer.py参数解读
4.2.2 Timm库的调用搭建Vision Transformer
目前,pytorch的知名度就不用说了,vision transformer的知名度,懂的都懂。其实,自己搭建vision transformer原本就是我的初衷,奈何资料根本就是无从下手,主要就是活用问题。今天,突然发现,也没想有想象中那么糟糕,从前人的实验中,还是由迹可寻的。
先放初步学习时候看的一些手把手搭建transformer框架理念和实践的资源:
1.百度飞桨提供的-从零开始学视觉Transformer
理论还是比较详细的,代码实践则是构建在百度子的paddle库上,与pytorch有所区别,但本意还是不变的。
Paddle-Pytorch API对应表:paddle pytorch 对照
2.资源:视觉Transformer优秀的开源工作
视觉 Transformer 优秀开源工作:timm 库 vision transformer 代码解读
python timm库-CSDN博客_python timm
3.如何发现的Timm-Debug
一直在尝试搭建自己的网络模型,却迟迟没有得到自己想要的效果,其实,挺痛苦的。究其原因还是复现前人的工作,复现了很多,但往往也没有深入解读,因为浮在表面,怎么可能去知晓他的真实意图呢?怎么知道他是怎么引用前人的经验呢?就开始了一遍又一遍的Debug,一个点一个点的突破学习,介于空降python的我,很多知识也开始积累起来了,果然,只要思想不滑坡,办法总比困难多!
本次就是科普timm库啦!这是我复现Debug所观察到作者进行Vision Transformer模块时,所进行的操作,根本就不是完全自己搭建哇!有现成的工具,加以整改就好!一切好像又变的美好了起来哈哈哈哈!
4 Timm库
4.1 概念
Timm:pyTorImageModels,简单来说,就是PyTorch的库之一,也算是torchvision.models的扩展模块,面向CV的模型,主要以分类为主。同时,所有的模型都有默认的API。
模型介绍:Model Summaries - Pytorch Image Models (rwightman.github.io)
模型结果:Results - Pytorch Image Models (rwightman.github.io)
4.2 Timm库的Vision Transformer活用
4.2.1 vision_transformer.py参数解读
- img_size: 图像大小,默认224,tuple类型,内部int类型。
- patch_size: Patch size,默认16,tuple类型,内部int类型。
- in_chans: 输入图像的channel数,默认3,int类型。
- num_classes: classification head的分类数,默认1000,int类型。
- embed_dim:Transformer的embedding dimension,默认768, int类型。
- depth: Transformer的Block的数量,默认12,int类型。
- num_head: attention heads的数量,默认12,int类型。
- mlp_ratio: mlp hidden dim/embedding dim的值,默认4, int类型。
- qkv_bias: attention模块计算qkv时,需要bias吗?默认True,bool类型。
- qk_scale: 通常为None。
- drop_rate: dropout rate,默认0,float类型。
- attn_drop_rate: attention模块的dropout rate,默认0,float类型。
- drop_path_rate: 默认0,float类型。
- hybrid_backbone: 在把图像转换成Patch之前,需要先通过一个Backbone吗?默认None。如果是None,就直接把图像转化成Patch。如果不是None,就先通过这个Backbone,再转化成Patch,nn.Module类型。
- norm_layer: 归一化层类型,默认None,nn.Module类型。
- 元组(Tulpe)是Python中另外的一种数据类型,和列表(List)一样也是一组有序对象的集合。
4.2.2 Timm库的调用搭建Vision Transformer
(1)导入必要的库和模型:
import timm
(2)调用timm库中的模型:
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
(3)按需调整,有需求的话,后续整理