论文原文:https://arxiv.org/pdf/2010.11929.pdf
中文版知乎解读: [论文笔记] ViT - 知乎
网络基本思想:
通过transformer完成nlp和cv类网络结构的统一。 NLP网络输入序列,通过Encoder和Decoder完成对于序列信息的编码和解码。CV网络中基础为图片(假设图片大小为224×224×3),通过卷积和MLP完成图片信息的抽取和分类(假设为分类网络)。 将Transformer应用与图片上,主要就是替换掉原有的卷积层(更确切地说是通过attention的结构替换原有的Conv),完成图片信息的编码。
Attention的网络结构输入一般为NTC,其中N为batch size, T为序列长度, C为特征信息。 由此ViT网络切片将大图切分为16×16的图片块。即将N×224×224×3 转换为 N × (14 × 14) × (16×16×3)的序列输入,其中14×14可以理解为Token, 16×16×3可以理解为特征信息。
网络基本结构(摘自论文原文)
主体步骤:
1)通过Rearrange+Linear或者(Conv空洞卷积+Linear)完成图片切分和线性变换,切片是将图片切分为16×16的一个一个patch,而线性变换主要是patch转换为一维向量(一般dim和patch_dim大小一致;输出维度: N × 196 × 768。
self.to_patch_embedding = nn.Sequential( Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), nn.Linear(patch_dim, dim), ) |
2)添加cls token和位置信息,这两个变量都是可学习的。(其中cls token借鉴了Bert的思想,认为其拥有全局的所有信息,最终通过cls token进行分类。 论文中也说明不使用cls token,通过avg pooling也可以达到相同的效果)
# cls_token shape: (N, 1, dim) cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) x = torch.cat((cls_tokens, x), dim=1) # pos_embedding Shape: (N, 16*16+1, dim) x += self.pos_embedding[:, :(n + 1)] x = self.dropout(x) |
3)transformer结构进行图片信息编码抽取,即Attention结构 + FFN;
attention基本公式:
多头原因:【深度学习笔记】为什么transformer(Bert)的多头注意力要对每一个head进行降维? – Sniper
FFN基本公式:
延伸阅读:
详解Transformer (Attention Is All You Need) - 知乎
https://arxiv.org/pdf/1706.03762.pdf
The Illustrated Transformer – Jay Alammar – Visualizing machine learning one concept at a time.
详解Transformer中Self-Attention以及Multi-Head Attention_霹雳吧啦Wz-CSDN博客_multi-head self-attention
4)MLP 结构给出分类结果。
代码:
1)网络代码
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): super().__init__() image_height, image_width = pair(image_size) patch_height, patch_width = pair(patch_size) assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' num_patches = (image_height // patch_height) * (image_width // patch_width) patch_dim = channels * patch_height * patch_width assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' # Rearrange类似进行多次slice抽取,将图片分割成对应大小的小块图片。通过linear,将特征维度映射到更高的dim维度。 self.to_patch_embedding = nn.Sequential( Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), nn.Linear(patch_dim, dim), ) # 生成可训练的pos_embedding和cls_token self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.dropout = nn.Dropout(emb_dropout) self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) self.pool = pool self.to_latent = nn.Identity() self.mlp_head = nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, num_classes) ) |
2)Transformer基本结构代码
class Transformer(nn.Module): def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) ])) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return x class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs) class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout = 0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) class Attention(nn.Module): def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads # 生成scale,论文里使用根号下dim值,减少QK矩阵乘后数据太大, # 防止softmax反向grad过小,缩放功能。 self.scale = dim_head ** -0.5 self.attend = nn.Softmax(dim = -1) # 输入分别和Wq,Wk, Wv权值矩阵运算,生成Q,K,V self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): qkv = self.to_qkv(x).chunk(3, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = self.attend(dots) out = torch.matmul(attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) |
引用: