视觉transfomer怎么对一张图去提取特征?
1.把图片切分成很多个子图,每个子图用卷积神经网络去提取特征
2.假设卷积网络里面有100个卷积核,那最终一张子图就可以得到100维的特征向量
3.此时计算出来的特征没有考虑到其他图片的信息,因此使用self_attn
4.对于分类任务,我们希望得到一个全局的信息,因此再加入一个cls向量,去和其他每张图去算self_attn,分析在分类过程中每张子图的重要程度。
transformer层网络结构
def forward(self, input_ids):
embedding_output = self.embeddings(input_ids)
encoded, attn_weights = self.encoder(embedding_output)
return encoded, attn_weights
嵌入层
输入
每个batch的输入:16组3通道的224*224大小的图片像素数据
输出是16组图片的分类概率分布
特征提取
要对其进行切割成子图并进行特征提取,使用如下卷积神经网络完成:
Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
224*224/(16*16)=14*14 就是说利用卷积核将原图切分为14*14张子图,每张子图大小就是卷积核大小,然后有768个卷积核,如果将每一张子图看作一个特征,每一张原图可以得到768个特征
卷积网络的输出为:
在卷积网络中,每个卷积核可以得到来自原图的一张特征图(每个特征来自原图的不同子图,14*14=196就表示特征图)
标准化
1,需要加入cls
2.需要转化为[batch,sequence,hidden]的形式,也即特征序列
3.特征序列包含位置编码信息
x = x.flatten(2) #x.shape=[16,768,196]拉平操作,将196张子图的特征按照顺序排列,同时每个子图有768个(特征)卷积核
x = x.transpose(-1, -2) #转置x.shape=[16,196,768]x = torch.cat((cls_tokens, x), dim=1) #加入cls(参数) x.shape=[16,197,768]embeddings = x + self.position_embeddings 加入位置编码(参数)
嵌入层的输出:
16个batch的cls+196张子图的特征向量,每张图用768个卷积核提取共有768维度
编码层
def forward(self, hidden_states):
print(hidden_states.shape)
attn_weights = []
for layer_block in self.layer: #self.layer是专门的网络列表
hidden_states, weights = layer_block(hidden_states)
if self.vis:
attn_weights.append(weights)
encoded = self.encoder_norm(hidden_states)
return encoded, attn_weights
编码层由多个子层block堆叠而成,嵌入层提取的特征首先经过若干个block层,然后再经过一个layernorm层
block层
#h是用于残差计算的备份输入
h = x
x = self.attention_norm(x) #LayerNorm模块
x, weights = self.attn(x) #多头自注意力机制
x = x + h
h = x
x = self.ffn_norm(x) #LayerNorm模块
x = self.ffn(x) #输出线形层
x = x + h
return x, weights
Norm
实际上是nn模块里面的LayerNorm
按照序列的维度,对同序列内的特征进行调整,使其分布符合归一化至均值为0,方差为1。
多头自注意力层
先进行Q、K、V的提取,均由网络参数的线性层来提取
然后把Q、K、V切分成多头768=12*64
把12也看成是batch,此时的QKV是197*64的矩阵
算出注意力值之后再经过线形层
输出用于分类的logits
transformer输出头
def forward(self, x, labels=None):
x, attn_weights = self.transformer(x)
print(x.shape)
logits = self.head(x[:, 0])
print(logits.shape)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
return loss
else:
return logits, attn_weights
也就是把拿到的logits通过线性层映射为概率(非0-1下的)