前言:
因为和人大合作一个项目,人大小哥哥给我原来的代码做了个简化,因此想记录一下,关于torch.gather()这个函数,感觉突然通了
应用场景
主要是在input:(batch_size,seq_len,embedding_dim)作为输入,进入gru以后返回也是(batch_size,seq_len,embedding_dim),但是由于有padding_id,只想拿到第item_list_len返回的那个隐藏层。
代码
def forward(self, interaction):
#TODO behavior_list_emb = concat(item,catgory)
item_list_emb = self.item_list_embedding(interaction[self.ITEM_ID_LIST])
position_list_emb = self.position_list_embedding(interaction[self.POSITION_ID])
behavior_list_emb = item_list_emb + position_list_emb
short_term_intent_temp, _ = self.gru_layers(behavior_list_emb)
short_term_intent_temp = self.gather_indexes(short_term_intent_temp, interaction[self.ITEM_LIST_LEN] - 1)
predict_behavior_emb = self.layer_norm(short_term_intent_temp)
return predict_behavior_emb
def gather_indexes(self, gru_output, gather_index):
"Gathers the vectors at the spexific positions over a minibatch"
gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, self.embedding_size)
output_tensor = gru_output.gather(dim=1, index=gather_index)
return output_tensor.squeeze(1)
简单的一个demo的例子解释上面的code:
再贴两个写的很好的博客链接(以后供自己review,hhh):
先看第一个大概理解
第二个看
最后看官网的例子,基本就可以理解了