关于torch.index_select()和torch.gather()函数的使用和区别

前言:

因为和人大合作一个项目,人大小哥哥给我原来的代码做了个简化,因此想记录一下,关于torch.gather()这个函数,感觉突然通了

应用场景

主要是在input:(batch_size,seq_len,embedding_dim)作为输入,进入gru以后返回也是(batch_size,seq_len,embedding_dim),但是由于有padding_id,只想拿到第item_list_len返回的那个隐藏层。

代码

def forward(self, interaction):
        #TODO behavior_list_emb = concat(item,catgory)
        item_list_emb = self.item_list_embedding(interaction[self.ITEM_ID_LIST])
        position_list_emb = self.position_list_embedding(interaction[self.POSITION_ID])
        behavior_list_emb = item_list_emb + position_list_emb
        short_term_intent_temp, _ = self.gru_layers(behavior_list_emb)
        short_term_intent_temp = self.gather_indexes(short_term_intent_temp, interaction[self.ITEM_LIST_LEN] - 1)
        predict_behavior_emb = self.layer_norm(short_term_intent_temp)
        return predict_behavior_emb

    def gather_indexes(self, gru_output, gather_index):
        "Gathers the vectors at the spexific positions over a minibatch"
        gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, self.embedding_size)
        output_tensor = gru_output.gather(dim=1, index=gather_index)
        return output_tensor.squeeze(1)

简单的一个demo的例子解释上面的code:
在这里插入图片描述
再贴两个写的很好的博客链接(以后供自己review,hhh):
先看第一个大概理解
第二个看
最后看官网的例子,基本就可以理解了

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值