TypeError: inc() takes 3 positional arguments but 4 were given
在复现别人代码时程序报错:TypeError: __inc__() takes 3 positional arguments but 4 were given
。
Debug发现是site-pacakages/torch_geometric/data/collate.py/get_incs
方法出了问题。
原因分析:
猜测是因为PYG版本的问题。
源代码推荐PYG版本为1.6,而我的PYG版本为2.0。
根据github确定为版本问题。当PYG>=2.0以后data.__inc__的参数发生变化。
解决方案:
两种解决方案
一、如果你能找到你代码中的__cat_dim__()(注意不是源文件的__cat_dim__()),你可以参考这篇文章PyG中自定义Data的注意事项(cat_dim)。
二、如果你代码中没有__cat_dim__(),你可以将你site-pacakages/torch_geometric/data/collate.py/get_incs
(源文件)中def get_incs()方法按照如下方式修改。
def get_incs(key, values: List[Any], data_list: List[BaseData],
stores: List[BaseStorage]) -> Tensor:
repeats = [
data.__inc__(key, value)
for value, data, store in zip(values, data_list, stores)
]
if isinstance(repeats[0], Tensor):
repeats = torch.stack(repeats, dim=0)
else:
repeats = torch.tensor(repeats)
return cumsum(repeats[:-1])