topk(k, dim)（PyTorch 中 Tensor.topk 的参数名为 k，表示取前 k 个最大值）

topk(k, dim=1)

>>> output=torch.randn(3,4)
>>> output
tensor([[-1.9291,  1.4127, -2.2464,  0.8932],
        [-0.4483, -0.3458,  0.8384,  1.9580],
        [-0.5633, -2.2806,  0.6278,  1.3552]])
在行上取一个最大值
>>> topkv,topki=output.topk(1,1)
>>> topkv
tensor([[1.4127],
        [1.9580],
        [1.3552]])
>>> topki
tensor([[1],
        [3],
        [3]])
在行上取前两个最大值
>>> topkv,topki=output.topk(2,1)
>>> topkv
tensor([[1.4127, 0.8932],
        [1.9580, 0.8384],
        [1.3552, 0.6278]])
>>> topki
tensor([[1, 3],
        [3, 2],
        [3, 2]])


topk(k, dim=0)

>>> output=torch.randn(3,4)
>>> output
tensor([[-1.9291,  1.4127, -2.2464,  0.8932],
        [-0.4483, -0.3458,  0.8384,  1.9580],
        [-0.5633, -2.2806,  0.6278,  1.3552]])
在列上取一个最大值        
>>> topkv,topki=output.topk(1,0)
>>> topkv
tensor([[-0.4483,  1.4127,  0.8384,  1.9580]])
>>> topki
tensor([[1, 0, 1, 1]])
在列上取前两个最大值
>>> topkv,topki=output.topk(2,0)
>>> topkv
tensor([[-0.4483,  1.4127,  0.8384,  1.9580],
        [-0.5633, -0.3458,  0.6278,  1.3552]])
>>> topki
tensor([[1, 0, 1, 1],
        [2, 1, 2, 2]])
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
解释下面这段代码:

    for i, edge_index in enumerate(edge_index_sets):
        edge_num = edge_index.shape[1]
        cache_edge_index = self.cache_edge_index_sets[i]
        if cache_edge_index is None or cache_edge_index.shape[1] != edge_num*batch_num:
            self.cache_edge_index_sets[i] = get_batch_edge_index(edge_index, batch_num, node_num).to(device)
        batch_edge_index = self.cache_edge_index_sets[i]
        all_embeddings = self.embedding(torch.arange(node_num).to(device))
        weights_arr = all_embeddings.detach().clone()
        all_embeddings = all_embeddings.repeat(batch_num, 1)
        weights = weights_arr.view(node_num, -1)
        cos_ji_mat = torch.matmul(weights, weights.T)
        normed_mat = torch.matmul(weights.norm(dim=-1).view(-1,1), weights.norm(dim=-1).view(1,-1))
        cos_ji_mat = cos_ji_mat / normed_mat
        dim = weights.shape[-1]
        topk_num = self.topk
        topk_indices_ji = torch.topk(cos_ji_mat, topk_num, dim=-1)[1]
        self.learned_graph = topk_indices_ji
        gated_i = torch.arange(0, node_num).T.unsqueeze(1).repeat(1, topk_num).flatten().to(device).unsqueeze(0)
        gated_j = topk_indices_ji.flatten().unsqueeze(0)
        gated_edge_index = torch.cat((gated_j, gated_i), dim=0)
        batch_gated_edge_index = get_batch_edge_index(gated_edge_index, batch_num, node_num).to(device)
        gcn_out = self.gnn_layers[i](x, batch_gated_edge_index, node_num=node_num*batch_num, embedding=all_embeddings)
        gcn_outs.append(gcn_out)
    x = torch.cat(gcn_outs, dim=1)
    x = x.view(batch_num, node_num, -1)
    indexes = torch.arange(0,node_num).to(device)
    out = torch.mul(x, self.embedding(indexes))
    out = out.permute(0,2,1)
    out = F.relu(self.bn_outlayer_in(out))
    out = out.permute(0,2,1)
    out = self.dp(out)
    out = self.out_layer(out)
    out = out.view(-1, node_num)
    return out
最新发布
04-19

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值