出现IndexError: Only slices (‘:’), list, tuples, torch.tensor and np.ndarray of dtype long or bool are valid indices (got ‘Tensor’)报错
是因为类型错误
报错地方代码如下:
test_dset = GraphDataset(dset[test_fold_idx], n_tags, degree=True)
把数据类型改为long类型的即可
改正后的代码如下:
test_dset = GraphDataset(dset[test_fold_idx.long()], n_tags, degree=True)
即可运行成功