vit-pytorch实现 MobileViT注意力可视化

项目链接 https://github.com/lucidrains/vit-pytorch

注意一下参数设置:

Parameters

  1. image_size: int.
    Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
  2. patch_size: int.
    Number of patches. image_size must be divisible by patch_size.
    The number of patches is: n = (image_size // patch_size) ** 2 and n must be greater than 16.
  3. num_classes: int.
    Number of classes to classify.
  4. dim: int.
    Last dimension of output tensor after linear transformation nn.Linear(…, dim).
  5. depth: int.
    Number of Transformer blocks.
  6. heads: int.
    Number of heads in Multi-head Attention layer.
  7. mlp_dim: int.
    Dimension of the MLP (FeedForward) layer.
  8. channels: int, default 3.
    Number of image’s channels.
  9. dropout: float between [0, 1], default 0…
    Dropout rate.
  10. emb_dropout: float between [0, 1], default 0.
    Embedding dropout rate.
  11. pool: string, either cls token pooling or mean pooling

image_size:表示图像大小的整数。图片应该是正方形的,并且image_size必须是宽度和高度中的最大值。
patch_size:表示补丁大小的整数。image_size必须能被 整除patch_size。补丁的数量计算为n =
(image_size // patch_size) ** 2并且n必须大于 16。 num_classes:一个整数,表示要分类的类数。
dim:一个整数,表示线性变换后输出张量的最后一维nn.Linear(…, dim)。 depth:一个整数,表示
Transformer 块的数量。 heads:一个整数,表示多头注意力层中的头数。 mlp_dim:一个整数,表示
MLP(前馈)层的维度。 channels:一个整数,表示图像中的通道数,默认值为3。 dropout:一个介于 0 和 1
之间的浮点数,代表辍学率。 emb_dropout:一个介于 0 和 1 之间的浮点数,表示嵌入丢失率。
pool:表示池化方法的字符串,可以是“cls token pooling”或“mean pooling”。

快速使用实例

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

SimpleViT

来自原论文的一些作者的更新建议对ViT进行简化,使其能够更快更好地训练。

这些简化包括2d正弦波位置嵌入、全局平均池(无CLS标记)、无辍学、批次大小为1024而不是4096,以及使用RandAugment和MixUp增强。他们还表明,最后的简单线性并不明显比原始MLP头差。

你可以通过导入SimpleViT来使用它,如下图所示

import torch
from vit_pytorch import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

可视化
Accessing Attention
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below

import torch
from vit_pytorch.vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# import Recorder and wrap the ViT

from vit_pytorch.recorder import Recorder
v = Recorder(v)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)

# there is one extra patch due to the CLS token

attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)

在这里插入图片描述
本文介绍了 MobileViT,一种用于移动设备的轻量级通用视觉转换器。MobileViT 为全球信息处理与转换器提供了不同的视角。

您可以将其与以下代码一起使用(例如 mobilevit_xs)

import torch
from vit_pytorch.mobile_vit import MobileViT

mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)

pred = mbvit_xs(img) # (1, 1000)
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
vit-pytorch是一个Python库,用于实现Vision Transformer(ViT)模型。ViT是一种基于Transformer架构的图像分类模型,它将图像分割成小的图像块,并使用Transformer编码器来学习图像的表示。ViT在计算机视觉任务中取得了很好的效果,特别是在图像分类任务中。 要使用vit-pytorch进行图像分类,首先需要安装该库。你可以按照官方提供的安装方法进行安装,链接为:https://lanzao.blog.csdn.net/article/details/101784059。 在使用vit-pytorch进行图像分类时,你需要创建一个VisionTransformer的实例,并在其初始化函数中设置一些参数。其中包括class token(用于表示整个图像的特殊标记)、dist token(用于蒸馏模型的特殊标记)和位置编码。位置编码是为了将图像块的位置信息引入模型中。 下面是一个示例代码,展示了如何使用vit-pytorch进行图像分类: ```python import torch import torch.nn as nn from vit_pytorch import VisionTransformer # 设置一些参数 num_patches = 16 # 图像分割成的图像块数量 embed_dim = 256 # 嵌入维度 drop_ratio = 0.1 # Dropout比率 distilled = False # 是否使用蒸馏模型 # 创建VisionTransformer实例 model = VisionTransformer( num_patches=num_patches, embed_dim=embed_dim, drop_ratio=drop_ratio, distilled=distilled ) # 输入图像数据 input_data = torch.randn(1, 3, 224, 224) # 假设输入图像大小为224x224,通道数为3 # 前向传播 output = model(input_data) # 输出分类结果 print(output) ``` 这是一个基本的使用vit-pytorch进行图像分类的示例。你可以根据自己的需求进行参数设置和模型调整。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值