个人体会: 这个工作有点像是做了一个伪字符级别(定长处理)的结果,在处理上,通过垂直方向的缩小,在水平方向上进行字符级的注意力机制和序列预测。全文的主要工作在于做了一个注意力机制的2D-attention,并且在decoder中进行了运用。实验结果上看可以在不规则文本的数据集上取得良好的性能。
encoder部分
这里不同于常见的encoder论文,这里只取了最后一个特征信息,称之为Holistic Feature (1*1*c)
,这个做法很独特。
@ENCODERS.register_module()
class SAREncoder(BaseEncoder):
def __init__(self):
super().__init__()
# LSTM Encoder
if enc_gru:
self.rnn_encoder = nn.GRU(**kwargs)
else:
self.rnn_encoder = nn.LSTM(**kwargs)
# global feature transformation
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
def init_weights(self):
# initialize weight and bias
def forward(self, feat, img_metas=None):
h_feat = feat.size(2)
#step 1 :压缩了水平方向的向量,维度从(b,c,w,h) -> (b,c,w,h) -> (b,c,w) -> (b,w,c)
feat_v = F.max_pool2d(
feat, kernel_size=(h_feat, 1), stride=1, padding=0)
feat_v = feat_v.squeeze(2) # bsz * C * W
feat_v = feat_v.permute(0, 2, 1).contiguous() # bsz * W * C
#step 2 : 将1D序列放入循环序列网络中,提取上下文关系依赖
holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
#step 3 : 提取序列中的最后一个
valid_hf = holistic_feat[:, -1, :] # bsz * C
holistic_feat = self.linear(valid_hf) # bsz * C
return holistic_feat
decoder部分
作者把这个Holistic Feature的片段作为开始序列,(这里我理解为start字符的开始位置为特征信息所做的解码,这里用编码的部分作为start去模拟这一方式,个人理解),之后使用ground truth作为后序输入信息(图中y所示),代码中将两个部分进行了cat拼接操作。
def forward_train(self, feat, out_enc, targets_dict, img_metas):
#step 1 : 将target部分词嵌入,并输入到对应的编码器中
targets = targets_dict['padded_targets'].to(feat.device)
tgt_embedding = self.embedding(targets)
# bsz * seq_len * emb_dim
#step 2 : 将encoder输出的部分调整成相应维度并进行拼接
out_enc = out_enc.unsqueeze(1)
# bsz * 1 * emb_dim
in_dec = torch.cat((out_enc, tgt_embedding), dim=1)
# bsz * (seq_len + 1) * C
#step 3 : 送入2Dattention机制中
out_dec = self._2d_attention(
in_dec, feat, out_enc, valid_ratios=valid_ratios)
# bsz * (seq_len + 1) * num_classes
return out_dec[:, 1:, :] # bsz * seq_len * num_classes
def forward_test(self, feat, out_enc, img_metas):
seq_len = self.max_seq_len
bsz = feat.size(0)
#step 1 : 用开始字符去替代target,并进行同样处理
start_token = torch.full((bsz, ),
self.start_idx,
device=feat.device,
dtype=torch.long)
# bsz
start_token = self.embedding(start_token)
# bsz * emb_dim
start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
#step 2 : 将encoder部分作为第一个序列,进行输入
# bsz * seq_len * emb_dim
out_enc = out_enc.unsqueeze(1)
# bsz * 1 * emb_dim
decoder_input = torch.cat((out_enc, start_token), dim=1)
# bsz * (seq_len + 1) * emb_dim
#step 3 : 通过循环的方式不断输入到2d attention中,进行预测
outputs = []
for i in range(1, seq_len + 1):
decoder_output = self._2d_attention(
decoder_input, feat, out_enc, valid_ratios=valid_ratios)
char_output = decoder_output[:, i, :] # bsz * num_classes
char_output = F.softmax(char_output, -1)
outputs.append(char_output)
_, max_idx = torch.max(char_output, dim=1, keepdim=False)
char_embedding = self.embedding(max_idx) # bsz * emb_dim
if i < seq_len:
decoder_input[:, i + 1, :] = char_embedding
outputs = torch.stack(outputs, 1) # bsz * seq_len * num_classes
return outputs
2D-attention部分:
这里的2D-attention其实就是一个注意力机制,Hidden_state作为一个Query,Feature Map作为一个Key和Value存在。(这里的hidden states 第一个字符为encoder的输出Holistic Feature,之后的为解码器的LSTM的输出)注意力公式可以列成
step 1: 加性模型
a = tanh(W Hidden_state + W Feature Map )
step 2 :注意力权重
a = softmax(a)
step 3 : 相乘(Feature Map可以理解成Value)
output = a* Feature Map
如果从K,Q,V的角度来理解这个注意力机制,即用decoder的结果去做CNN feature Map的attention机制,即(CNN encoder decoder )+CNN 做attention
def _2d_attention(self,
decoder_input,
feat,
holistic_feat,
valid_ratios=None):
y = self.rnn_decoder(decoder_input)[0]
# y: bsz * (seq_len + 1) * hidden_size
#step 1 : 如图所示,通过卷积,获得query(hidden states)
#(这里的hidden states 第一个字符为encoder的输出Holistic Feature,之后的为解码器的LSTM的输出)
attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
bsz, seq_len, attn_size = attn_query.size()
attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1)
#step 2: feat 为backbone的输出,看成key,这里作为attention的feature map出现
attn_key = self.conv3x3_1(feat)
# bsz * attn_size * h * w
attn_key = attn_key.unsqueeze(1)
# bsz * 1 * attn_size * h * w
#step 3 : 做加性模型的注意力机制
attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
# bsz * (seq_len + 1) * attn_size * h * w
attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
# bsz * (seq_len + 1) * h * w * attn_size
attn_weight = self.conv1x1_2(attn_weight)
# bsz * (seq_len + 1) * h * w * 1
bsz, T, h, w, c = attn_weight.size()
assert c == 1
#掩码部分,去填充图谱未缩放所造成的问题
if valid_ratios is not None:
# cal mask of attention weight
attn_mask = torch.zeros_like(attn_weight)
for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio))
attn_mask[i, :, :, valid_width:, :] = 1
attn_weight = attn_weight.masked_fill(attn_mask.bool(),
float('-inf'))
#制作掩码部分
attn_weight = attn_weight.view(bsz, T, -1)
attn_weight = F.softmax(attn_weight, dim=-1)
attn_weight = attn_weight.view(bsz, T, h, w,
c).permute(0, 1, 4, 2, 3).contiguous()
#step 3 : 和掩码相加,并作为注意力权重进行相乘
attn_feat = torch.sum(
torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
# bsz * (seq_len + 1) * C
#step 4 : 预测最终结果
# linear transformation
if self.pred_concat:
hf_c = holistic_feat.size(-1)
holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c)
y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2))
else:
y = self.prediction(attn_feat)
# bsz * (seq_len + 1) * num_classes
if self.train_mode:
y = self.pred_dropout(y)
return y