报错
Traceback (most recent call last):
File "main.py", line 306, in <module>
args_config=args_config,
File "main.py", line 224, in train
avg_reward,
File "main.py", line 56, in train_one_epoch
selected_neg_items_list, _ = sampler(batch_data, adj_matrix, edge_matrix)
File "/home/zzy/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/zzy/KG-Policy/kgpolicy/modules/sampler/kgpolicy.py", line 91, in forward
one_hop, one_hop_logits = self.kg_step(pos, users, adj_matrix, step=1)
File "/home/zzy/KG-Policy/kgpolicy/modules/sampler/kgpolicy.py", line 139, in kg_step
i_e = gcn_embedding[one_hop]
IndexError: tensors used as indices must be long, byte or bool tensors
定位到kgpolicy文件的第139行
one_hop = adj_matrix[pos]
i_e = gcn_embedding[one_hop]
感觉索引indices是one_hop ,检查下one_hop:
print(one_hop)
结果为:
tensor([[ 1295., 6801., 6590., ..., 43463., 63427., 23697.],
[ 2940., 23298., 4720., ..., 54026., 45077., 68521.],
[ 136., 137., 1033., ..., 47605., 59133., 51274.],
...,
[ 1629., 6126., 118., ..., 56004., 23626., 41774.],
[12689., 6415., 21290., ..., 28091., 24405., 37709.],
[ 739., 818., 1202., ..., 54062., 29978., 25789.]])
因为默认的是float类型,所以不对,就直接类型转换为long就可以了.
关于究竟转化为哪种类型,参考链接:
https://blog.csdn.net/junqing_wu/article/details/99692296
https://blog.csdn.net/jacke121/article/details/82703640
https://blog.csdn.net/weixin_38314865/article/details/105949825
one_hop = adj_matrix[pos].type(torch.long)#已修改
i_e = gcn_embedding[one_hop]
ok,这样就可以啦…这么个错误弄了一上午,崩溃…希望能帮到小可爱们~