A summary of frontier AI methods: Transformer, ViT, and Swin T
[1]. Attention Is All You Need
[2]. https://zhuanlan.zhihu.com/p/366592542
[3]. Code implementation: https://zhuanlan.zhihu.com/p/653170203
[4]. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
[5]. Analysis of the time complexity of Swin Transformer: https://blog.csdn.net/weixin_45943887/article/details/127881179
[6]. ConvMixer: Trockman, Asher, and J. Zico Kolter. “Patches are all you need?” arXiv preprint arXiv:2201.09792 (2022). CSDN: https://blog.csdn.net/baidu_36913330/article/details/120655407
[7]. Yadav, Saurabh, and Koteswar Rao Jerripothula. “FCCNs: Fully Complex-valued Convolutional Networks using Complex-valued Color Model and Loss Function.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023.
1.Transformer
1.1.Attention mechanism
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.[1]
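The inputs are a query and a set of key-value pairs. The attention mechanism first computes the compatibility between the query and each key; each compatibility score serves as the weight of the corresponding value, and the output is the sum of the values multiplied by their weights.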
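The weighted sum described above can be sketched in a few lines of plain Python. This is only an illustration: the dot product is assumed as the compatibility function, and the names `attend`, `query`, `keys` and `values` and the toy numbers are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Return (weights, output) for one query over a set of key-value pairs."""
    # Compatibility of the query with each key (dot product is an assumption here)
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    # Each compatibility score becomes the weight of the corresponding value
    weights = softmax(scores)
    # The output is the weighted sum of the values
    dim = len(values[0])
    output = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return weights, output


query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [-10.0, 0.0]]
weights, output = attend(query, keys, values)
print(weights)  # the key most aligned with the query gets the largest weight
print(output)
```

The output is pulled toward the value whose key matches the query best, which is the behaviour the quoted definition describes.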
The attention used in Attention Is All You Need is called “Scaled Dot-Product Attention”; the procedure is illustrated in Figure 2 of [1].
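In formula form, Scaled Dot-Product Attention in [1] is written as follows, where Q, K and V are the matrices of queries, keys and values and d_k is the dimension of the keys:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```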
Code implementation:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
        # Per-head projections; they act on the last dimension (head_dim)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # the number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            # Positions where mask == 0 get a large negative score,
            # so softmax assigns them (almost) zero weight
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        # Scale by sqrt(d_k), where d_k = head_dim is the per-head key dimension
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum (N, query_len, heads, head_dim) then flatten last two dimensions
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out
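1.Why is there a mask?
NLP deals with variable-length text, which requires padding. The padded positions carry no meaning, so they must be masked out during processing.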
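A minimal sketch of what the mask does to the attention weights, in plain Python. The helper `masked_softmax` and the toy scores are hypothetical; the fill value -1e20 mirrors the `masked_fill` call in the code above.

```python
import math


def masked_softmax(scores, mask, fill=-1e20):
    """Softmax over scores; positions where mask == 0 are effectively ignored."""
    # Replace the scores of padded positions with a very large negative number
    filled = [s if m != 0 else fill for s, m in zip(scores, mask)]
    top = max(filled)
    exps = [math.exp(s - top) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 3.0, 3.0]  # the last two positions are padding
mask = [1, 1, 0, 0]
weights = masked_softmax(scores, mask)
print(weights)
```

Although the padded positions have the largest raw scores, their weights come out as zero, so they contribute nothing to the output.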
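2.About q, k, v
In self-attention, q, k and v all come from the same input. The query q is multiplied (dot product) with every key to obtain the weights, and the weights are multiplied with v, so the result is dominated by the values with large weights.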
3.Why divide by the square root of dk
We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by 1/√dk.[1]
If the dot products are too large, softmax enters its saturated region, where the gradients are very small.
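The saturation effect can be checked numerically. The sketch below assumes dk = 64 and a hypothetical set of raw dot products whose largest entry is 8, roughly one standard deviation of a dot product of 64 unit-variance components. The derivative of the top softmax probability p with respect to its own score is p(1 - p).

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


d_k = 64
raw_scores = [8.0, 0.0, 0.0, 0.0]  # hypothetical unscaled dot products
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]

p_raw = softmax(raw_scores)[0]
p_scaled = softmax(scaled_scores)[0]

# d p_i / d s_i = p_i * (1 - p_i): close to zero once softmax saturates
grad_raw = p_raw * (1 - p_raw)
grad_scaled = p_scaled * (1 - p_scaled)

print(p_raw, grad_raw)
print(p_scaled, grad_scaled)
```

Without scaling, almost all of the probability mass sits on one position and the gradient nearly vanishes; after dividing by √dk the distribution is much softer and the gradient is orders of magnitude larger.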
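4.Why multiple heads are needed
The outputs of the different heads are obtained by considering the compatibility from different perspectives (representation subspaces).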
5.Complexity analysis
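Table 1 of [1] gives the per-layer complexity of self-attention as O(n²·d) for sequence length n and representation dimension d. The sketch below is a rough multiply-add count under an assumed breakdown (four d×d projections plus the two n×n matrix products); the function name `attention_cost` is hypothetical.

```python
def attention_cost(n, d):
    """Rough multiply-add counts for one self-attention layer.

    Returns (projection_cost, attention_cost); an illustrative breakdown only.
    """
    projections = 4 * n * d * d   # Q, K, V and output projections
    scores = n * n * d            # Q @ K^T
    weighted_sum = n * n * d      # attention @ V
    return projections, scores + weighted_sum


proj_a, attn_a = attention_cost(n=196, d=768)
proj_b, attn_b = attention_cost(n=392, d=768)
print(proj_b / proj_a)  # projections grow linearly with n
print(attn_b / attn_a)  # the attention term grows quadratically with n
```

Doubling the sequence length doubles the projection cost but quadruples the attention term, which is why long sequences are expensive.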
1.2.TransformerBlock
The last two sub-layers of a decoder layer are the same as an encoder layer, so they are packaged into one class.
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
1.3.Encoder
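The key part is the positional encoding; in this implementation it is a learned nn.Embedding indexed by position.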
class Encoder(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                 ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        # Position indices 0..seq_length-1, repeated for every example in the batch
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            # Self-attention: value, key and query are all the same tensor
            out = layer(out, out, out, mask)
        return out
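The Encoder above learns its position vectors through an embedding table. For comparison, [1] uses fixed sinusoidal encodings, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). A minimal plain-Python sketch of that table follows; the function name `sinusoidal_positions` and the tiny sizes are hypothetical.

```python
import math


def sinusoidal_positions(max_length, embed_size):
    """Build a (max_length, embed_size) table of sinusoidal position encodings."""
    table = []
    for pos in range(max_length):
        row = []
        for i in range(embed_size):
            # Dimensions 2k and 2k+1 share the same frequency
            angle = pos / (10000 ** ((i - i % 2) / embed_size))
            # Even dimensions use sine, odd dimensions use cosine
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


pe = sinusoidal_positions(max_length=4, embed_size=6)
for row in pe:
    print([round(v, 3) for v in row])
```

Such a table could be added to the word embeddings in place of `self.position_embedding(positions)`; whether that helps on a given task is an empirical question.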
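2.ViT
Serialization (image to token sequence)
Original image size: 3×224×224
Patch size: 16×16 (×3)
Number of tokens: (224/16)^2 = 14^2 = 196
So the final sequence is (196, 768), where 768 = 16×16×3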
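The arithmetic above can be wrapped in a small helper. The function name `patch_sequence_shape` is hypothetical; it only reproduces the counting, not the embedding itself.

```python
def patch_sequence_shape(image_size, patch_size, channels=3):
    """Return (number of tokens, values per token) for a square image."""
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    per_side = image_size // patch_size
    num_tokens = per_side * per_side
    token_dim = patch_size * patch_size * channels
    return num_tokens, token_dim


print(patch_sequence_shape(224, 16))  # (196, 768)
```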
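In Swin T,
Patch size: 4×4 (×3)
Number of tokens: (224/4)^2 = 56^2 = 3136
So the final sequence is (3136, 48), where 48 = 4×4×3
However, Swin T introduces the concept of a window and performs self-attention within each window. With window size = 7 there are 56/7 = 8 windows along each side, i.e. 8×8 = 64 windows in total; each window contains 7×7 patches, and each patch is 4×4 pixels.
Because Swin T performs self-attention inside each window, the sequence length seen by attention drops from 3136 to 49.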
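The saving from windowed attention can be estimated with the complexity formulas from the Swin Transformer paper, Ω(MSA) = 4hwC² + 2(hw)²C and Ω(W-MSA) = 4hwC² + 2M²hwC (the topic of [5]). The sketch below plugs in the first-stage sizes; the channel count C = 96 is an assumption (it is not stated in these notes), and the function names are hypothetical.

```python
def window_stats(image_size=224, patch_size=4, window_size=7):
    """Return (total tokens, number of windows, tokens per window)."""
    per_side = image_size // patch_size          # patches along one side
    windows_per_side = per_side // window_size   # windows along one side
    return per_side * per_side, windows_per_side ** 2, window_size ** 2


def msa_flops(h, w, C):
    """Global multi-head self-attention over h*w tokens with C channels."""
    return 4 * h * w * C * C + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M):
    """Window-based self-attention with M x M windows."""
    return 4 * h * w * C * C + 2 * M * M * h * w * C


tokens, num_windows, tokens_per_window = window_stats()
global_cost = msa_flops(56, 56, 96)      # C = 96 is an assumed channel count
window_cost = wmsa_flops(56, 56, 96, 7)
print(tokens, num_windows, tokens_per_window)
print(global_cost / window_cost)
```

The projection term 4hwC² is identical in both cases; only the attention term shrinks, by a factor of hw/M² = 3136/49 = 64.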
3.Swin T
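Patch partition and Linear Embedding
Patch partition and Linear Embedding are done together by a single convolution: Patch partition splits the original image into patches, and Linear Embedding projects the 48 values of each patch to the specified dimension.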
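# The projection is a convolution whose kernel size and stride both equal the patch size
# (here in_chans=3, embed_dim=256, patch_size=(2, 2))
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# One-line form: project, flatten the spatial dimensions, move channels last
x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
# Step by step, for an input of shape (32, 3, 32, 32):
x = self.proj(x)       # (32, 3, 32, 32) --> (32, 256, 16, 16); if patch_size = 4: (32, 256, 8, 8)
x = x.flatten(2)       # (32, 256, 256); if patch_size = 4: (32, 256, 64)
x = x.transpose(1, 2)  # (32, 256, 256); if patch_size = 4: (32, 64, 256), i.e. (B, N, C)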
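Why one convolution can do both steps: when the kernel size equals the stride, every output position sees exactly one non-overlapping patch, so the convolution is the same as flattening each patch and applying a linear layer. The plain-Python sketch below checks this on a single-channel 4×4 image with one output channel; all names and numbers are hypothetical.

```python
def conv2d_patchify(image, kernel, patch):
    """Single-channel convolution with kernel size == stride == patch, no padding."""
    H, W = len(image), len(image[0])
    out = []
    for top in range(0, H, patch):
        row = []
        for left in range(0, W, patch):
            acc = 0.0
            for i in range(patch):
                for j in range(patch):
                    acc += image[top + i][left + j] * kernel[i][j]
            row.append(acc)
        out.append(row)
    return out


def partition_then_linear(image, weight, patch):
    """Cut non-overlapping patches, flatten each one, apply one linear weight vector."""
    H, W = len(image), len(image[0])
    tokens = []
    for top in range(0, H, patch):
        for left in range(0, W, patch):
            flat = [image[top + i][left + j] for i in range(patch) for j in range(patch)]
            tokens.append(sum(x * w for x, w in zip(flat, weight)))
    return tokens


image = [[float(r * 4 + c) for c in range(4)] for r in range(4)]
kernel = [[1.0, 2.0], [3.0, 4.0]]
weight = [w for row in kernel for w in row]  # the same numbers, flattened

conv_out = conv2d_patchify(image, kernel, patch=2)
flattened_conv = [v for row in conv_out for v in row]  # like flatten(2)
tokens = partition_then_linear(image, weight, patch=2)
print(flattened_conv)
print(tokens)
```

Both routes give the same four token values, one per 2×2 patch.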
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# qkv: (3, B_, num_heads, N, C // num_heads), e.g. (3, 128, 8, 64, 32)
q, k, v = qkv[0], qkv[1], qkv[2]
# each of q, k, v: (B_, num_heads, N, C // num_heads), e.g. (128, 8, 64, 32)
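What does patch merging mean, and what is it for?
It is downsampling, used to obtain multi-scale information; in essence it trades a reduction in spatial size for more channels.
First the patches are merged: the feature map is sampled with stride 2 along H and W (each of the four sub-maps takes every other patch rather than a contiguous block), and the four sub-maps are concatenated along the channel dimension, so H and W are halved and C becomes 4C. To match the pooling stages in ResNet (H and W halved, C becomes 2C), a 1×1 convolution is applied after merging to reduce 4C to 2C; in the code this is an nn.Linear(4 * dim, 2 * dim), which is equivalent when applied per token.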
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        x = self.norm(x)
        x = self.reduction(x)
        return x
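To see which patches end up in the same merged token, the strided slicing in `forward` can be traced on a labelled 4×4 grid. This plain-Python sketch covers only the slicing and concatenation (no normalization, no 4C to 2C reduction); the helper `patch_merge` and the labels are hypothetical.

```python
def patch_merge(grid):
    """grid[h][w] is a list of C channel values; returns an (H/2, W/2) grid with 4C channels."""
    H, W = len(grid), len(grid[0])
    assert H % 2 == 0 and W % 2 == 0, "H and W must be even"
    merged = []
    for h in range(0, H, 2):
        row = []
        for w in range(0, W, 2):
            # Same order as torch.cat([x0, x1, x2, x3], -1):
            # x0 = (even h, even w), x1 = (odd h, even w),
            # x2 = (even h, odd w),  x3 = (odd h, odd w)
            row.append(grid[h][w] + grid[h + 1][w] + grid[h][w + 1] + grid[h + 1][w + 1])
        merged.append(row)
    return merged


# Each patch carries a single "channel" holding its own label p<row><col>
grid = [[[f"p{h}{w}"] for w in range(4)] for h in range(4)]
merged = patch_merge(grid)
for row in merged:
    print(row)
```

Each output position holds the four patches of one 2×2 neighbourhood stacked along the channel axis, so the grid shrinks from 4×4 to 2×2 while the channel count grows from C to 4C.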
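Shifted windows
Self-attention is computed inside windows, but without communication between windows there is no global modeling (no context across windows), so the windows are shifted.
Each basic Swin T block therefore contains two multi-head self-attention operations: one W-MSA and one SW-MSA.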
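How shifting lets windows exchange information can be shown with a small index experiment. The sketch assumes an 8×8 token grid, window size 4 and a cyclic shift of 2 (half the window), and records which regular windows contribute tokens to each shifted window. The function names are hypothetical, and the attention mask that the Swin Transformer paper applies to the wrapped-around regions is not modelled.

```python
def window_id(row, col, window):
    """Index of the window that position (row, col) falls into."""
    return (row // window, col // window)


def shifted_window_members(size, window, shift):
    """Map each window (after a cyclic shift) to the regular windows its tokens came from."""
    members = {}
    for r in range(size):
        for c in range(size):
            # Cyclic shift: the token now at (r, c) originally sat at
            # ((r + shift) % size, (c + shift) % size)
            src_r, src_c = (r + shift) % size, (c + shift) % size
            members.setdefault(window_id(r, c, window), set()).add(
                window_id(src_r, src_c, window)
            )
    return members


regular = shifted_window_members(size=8, window=4, shift=0)
shifted = shifted_window_members(size=8, window=4, shift=2)
print(regular[(0, 0)])
print(shifted[(0, 0)])
```

Without a shift every window only sees its own tokens; after the shift a single window mixes tokens from four neighbouring regular windows, which is how alternating W-MSA and SW-MSA propagates context across windows.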