Vision Transformer(1):ViT源码逐行阅读解析

 上图是Vision Transformer原文的模型结构展示,可以看到模型包含了几个核心模块:

 Vision Transformer:

        1. Embedding模块

        2.Transformer Encoder模块

                2.1 NormLayer ( × depth )

                        2.1.1 Multi-Head Attention层

                                 关于Attention机制的详细解析

                        2.1.2 MLP多层感知器

        3.MLP-Head 模块映射为类别

自底向上摸索是在未知中探索的不可缺少的方式,但通过摸索后,发现自顶向下能更好的阐述清楚整个逻辑。

一、ViT & Embedding

假设训练数据维度为(64, 3, 224, 224),意味着有64张三通道的224*224的图像。

设定参数dim=128意味着编码向量长度为128。

ViT中出现的PreNorm、Attention、FeedForward、Transformer后续解释

class ViT(nn.Module):
    '''
    :param
        *: input data
        image_size: 等边图像尺寸
        patch_size: patch的尺寸
        num_classes: 分类类别
        dim: 为每一个patch编码的长度
        depth: Encoder的深度,也就是连接encoder的数目
        heads: 多头注意力中头的数目
        mlp_dim: 多层感知器中隐含层的维度
        pool: 使用cls token还是使用均值池化
        channel: 图像的通道数
        dim_head: 注意力机制中一个头的输入维度
        dropout: NormLayer中dropout的参数比例
        emb_dropout: Embedding中的dropout比例
    :return 分类结果(64, 2)
    '''
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        # image_size就是每一张图像的长和宽,通过pair函数便捷明了的表现
        # patch_size就是图像的每一个patch的长和宽
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        # 保证图像可以整除为若干个patch
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # 计算出每一张图片会被切割为多少个patch
        # 假设输入维度(64, 3, 224, 224), num_patches = 49
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 每一个patch数组大小, patch_dim = 3*32*32=3072
        patch_dim = channels * patch_height * patch_width
        # cls就是分类的Token, mean就是均值池化
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # embeding操作:假设输入维度(64, 3, 224, 224),那么经过Rearange层后变成了(64, 7*7=49, 32*32*3=3072)
        self.to_patch_embedding = nn.Sequential(
            # 将图片分割为b*h*w个三通道patch,b表示输入图像数量
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            # 经过线性全连接后,维度变成(64, 49, 128)
            nn.Linear(patch_dim, dim),
        )
        # dim张图像,每张图像需要num_patches个向量进行编码
        # 位置编码(1, 50, 128) 本应该为49,但因为cls表示类别需要增加一个
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # CLS类别token,(1, 1, 128)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # 设置dropout
        self.dropout = nn.Dropout(emb_dropout)
        # 初始化Transformer
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        # pool默认是cls进行分类
        self.pool = pool
        self.to_latent = nn.Identity()
        # 多层感知用于将最终特征映射为2个类别
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # 第一步,原始图像ebedding,进行了图像切割以及线性变换,变成x->(64, 49, 128)
        x = self.to_patch_embedding(img)
        # 得到原始图像数目和单图像的patches数量, b=64, n=49
        b, n, _ = x.shape
        # (1, 1, 128) -> (64, 1, 128) 为每一张图像设置一个cls的token
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 将cls token加入到数据中 -> (64, 50, 128)
        x = torch.cat((cls_tokens, x), dim=1)
        # x(64, 50, 128)添加位置编码(1, 50, 128)
        x += self.pos_embedding[:, :(n + 1)]
        # 经过dropout层防止过拟合
        x = self.dropout(x)

        x = self.transformer(x)
        # 进行均值池化
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        # 最终进行分类映射
        return self.mlp_head(x)

二、Transformer

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        # 设定depth个encoder相连,并添加残差结构
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        # 每次取出包含Norm-attention和Norm-mlp这两个的ModuleList,实现残差结构
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

1.Norm层

class PreNorm(nn.Module):
    '''
    :param  dim 输入维度
            fn 前馈网络层,选择Multi-Head Attn和MLP二者之一
    '''
    def __init__(self, dim, fn):
        super().__init__()
        # LayerNorm: ( a - mean(last 2 dim) ) / sqrt( var(last 2 dim) )
        # 数据归一化的输入维度设定,以及保存前馈层
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    # 前向传播就是将数据归一化后传递给前馈层
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

2.MLP层

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

3.Attention层

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        # 表示1/(sqrt(dim_head))用于消除误差,保证方差为1,避免向量内积过大导致的softmax将许多输出置0的情况
        # 可以看原文《attention is all you need》中关于Scale Dot-Product Attention如何抑制内积过大
        self.scale = dim_head ** -0.5
        # dim =  > 0 时,表示mask第d维度,对相同的第d维度,进行softmax
        # dim =  < 0 时,表示mask倒数第d维度,对相同的倒数第d维度,进行softmax
        self.attend = nn.Softmax(dim = -1)
        # 生成qkv矩阵,三个矩阵被放在一起,后续会被分开
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # 如果是多头注意力机制则需要进行全连接和防止过拟合,否则输出不做更改
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # 分割成q、k、v三个矩阵
        # qkv为 inner_dim * 3,其中inner_dim = heads * dim_head
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # qkv的维度是(3, inner_dim = heads * dim_head)
        # 'b n (h d) -> b h n d' 重新按思路分离出8个头,一共8组q,k,v矩阵
        # rearrange后维度变成 (3, heads, dim, dim_head)
        # 经过map后,q、k、v维度变成(1, heads, dim, dim_head)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # query * key 得到对value的注意力预测,并通过向量内积缩放防止softmax无效化部分参数
        # heads * dim * dim
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # 对最后一个维度进行softmax后得到预测的概率值
        attn = self.attend(dots)
        # 乘积得到预测结果
        # out -> heads * dim * dim_head
        out = torch.matmul(attn, v)
        # 重组张量,将heads维度重新还原
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

4.其他部分

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

三、MLP-Head模块

self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

  • 16
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
这段代码是用于实现Vision Transformer框架的一部分功能,具体逐行解析如下: 1. `conv_output = F.conv2d(image, kernel, stride=stride)`: 这一行代码使用PyTorch中的卷积函数`F.conv2d`来对输入图像进行卷积操作。 2. `bs, oc, oh, ow = conv_output.shape`: 这一行代码通过`conv_output.shape`获取卷积输出张量的形状信息,其中`bs`表示批次大小,`oc`表示输出通道数,`oh`和`ow`分别表示输出张量的高度和宽度。 3. `patch_embedding = conv_output.reshape((bs, oc, oh*ow))`: 这一行代码通过`reshape`函数将卷积输出张量进行形状变换,将其转换为形状为`(bs, oc, oh*ow)`的张量。 4. `patch_embedding = patch_embedding.transpose(-1, -2)`: 这一行代码使用`transpose`函数交换张量的最后两个维度,将形状为`(bs, oh*ow, oc)`的张量转换为`(bs, oc, oh*ow)`的张量。 5. `weight = weight.transpose(0, 1)`: 这一行代码将权重张量进行转置操作,交换第0维和第1维的位置。 6. `kernel = weight.reshape((-1, ic, patch_size, patch_size))`: 这一行代码通过`reshape`函数将权重张量进行形状变换,将其转换为形状为`(outchannel*inchannel, ic, patch_size, patch_size)`的张量。 7. `patch_embedding_conv = image2emb_conv(image, kernel, patch_size)`: 这一行代码调用了`image2emb_conv`函数,并传入了图像、权重张量和补丁大小作为参数。 8. `print(patch_embedding_conv.shape)`: 这一行代码打印了`patch_embedding_conv`的形状信息。 以上是对Vision Transformer代码的逐行解析

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

尼卡尼卡尼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值