1. Problem description
While recently modifying a GNN architecture, I ran into the following error:
Traceback (most recent call last):
File "D:\Desktop\Machine\PyTorchTemplate\train_multi_step.py", line 399, in <module>
File "D:\Desktop\Machine\PyTorchTemplate\train_multi_step.py", line 251, in main
metrics = engine.train(tx, ty[:, 0, :, :], id) #Tag: ty only needs the Info, Not need Time EncodedInfo, which 0 is select info
File "D:\Desktop\Machine\PyTorchTemplate\trainer.py", line 53, in train
loss.backward(retain_graph=True)
File "E:\SoftWares\anaconda3\envs\pytorch38\lib\site-packages\torch\_tensor.py", line 307, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "E:\SoftWares\anaconda3\envs\pytorch38\lib\site-packages\torch\autograd\__init__.py", line 154, in backward
Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [307, 307]], which is output 0 of struct torch::autograd::CopySlices, is at version 6; expected version 5 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
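The hint at the end is worth taking. Below is a minimal, self-contained sketch of the same failure mode (toy tensors, not the original model), run with anomaly detection enabled so that PyTorch also reports the forward-pass operation whose saved input was overwritten:

```python
import torch

w = torch.randn(3, 3, requires_grad=True)
with torch.autograd.set_detect_anomaly(True):
    A = w * 2          # non-leaf tensor
    y = A @ A          # matmul saves A for its backward pass
    A += torch.eye(3)  # in-place edit bumps A's version after it was saved
    try:
        y.sum().backward()
    except RuntimeError as e:
        print(e)       # "... modified by an inplace operation ..."
```

With anomaly mode on, the warning printed before the exception includes the forward traceback of the matmul, i.e. exactly the operation that needed the unmodified A.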
2. Analysis
My first move was to search Baidu, Bing, and c**n for the error. Some posts blamed the PyTorch version, so I set up a fresh virtual environment with PyTorch 1.7.1; that didn't help. Others mentioned an "in-place operation", but I stared at the code for a long time and couldn't see one.
Then I read the error message more carefully: it says "at version 6; expected version 5". I remembered that tensors carry a `_version` counter that records how many times the tensor has been mutated, so I stepped through the code line by line and found where the error came from.
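For reference, the counter the error message compares is the tensor's internal `_version` attribute (undocumented, but handy for exactly this kind of debugging). Every in-place operation increments it, while an out-of-place operation produces a fresh tensor:

```python
import torch

x = torch.zeros(3)
print(x._version)  # 0: freshly created
x += 1             # in-place add
print(x._version)  # 1
x[0] = 5.0         # indexed assignment is also in-place (it is what appears as CopySlices)
print(x._version)  # 2
x = x + 1          # out-of-place: binds x to a brand-new tensor
print(x._version)  # 0 again
```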
class MGCN(nn.Module):
    ...
    @staticmethod
    def process_graph(graph_data):
        # Re-normalize the raw adjacency matrix. The intended form is
        # \hat A = D^{-1/2} (A + I) D^{-1/2}; what is actually computed below is the
        # random-walk variant \hat A = D^{-1} (A + I), which serves the same purpose.
        N = graph_data.size(0)                                                # number of nodes
        matrix_i = torch.eye(N, dtype=torch.float, device=graph_data.device)  # [N, N] identity
        graph_data += matrix_i                                                # A + I -- in-place: mutates the caller's tensor!
        degree_matrix = torch.sum(graph_data, dim=1, keepdim=False)           # [N], degrees = row sums of A + I
        degree_matrix = degree_matrix.pow(-1)                                 # inverse degrees; a zero degree yields inf
        degree_matrix[degree_matrix == float("inf")] = 0.                     # replace inf with 0
        degree_matrix = torch.diag(degree_matrix)                             # promote the vector to a diagonal matrix
        return torch.mm(degree_matrix, graph_data)                            # \hat A = D^{-1} (A + I)

    def forward(self, data, A):  # A is the [307, 307] adjacency matrix -- the tensor the error complains about
        # A has to be transformed first.
        # 1. Passing the tracked tensor directly raises the error:
        adj = MGCN.process_graph(A)
        # 2. Passing A.data instead does not:
        adj = MGCN.process_graph(A.data)
        ...
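The key point is that `graph_data` inside `process_graph` is the very same object as the `A` passed in, so the `+=` mutates the caller's tensor. A small demonstration of the same pattern (hypothetical helper, toy tensor):

```python
import torch

def add_identity_inplace(m):
    m += torch.eye(m.size(0), device=m.device)  # same pattern as `graph_data += matrix_i`
    return m

A = torch.zeros(2, 2)
B = add_identity_inplace(A)
print(B is A)      # True: no copy was made anywhere
print(A)           # the caller's A is now the identity matrix
print(A._version)  # 1: the in-place add bumped A's version counter
```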
The root cause: `process_graph` receives the tracked tensor object itself, and `graph_data += matrix_i` modifies it in place. That increments the tensor's `_version`, so the version recorded when autograd saved the tensor for backward no longer matches, and `loss.backward()` fails.
3. Solution
When passing a tensor into code that will later modify its values, pass `A.data` rather than the tracked tensor itself (or, better still, rewrite the modification to be out of place).
Alternatively, debug step by step: work out which variable the [307, 307] shape in the error message refers to, then watch where its `_version` changes; wherever it changes is where the problem is.
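A caveat on the `.data` fix: it works by hiding the mutation from autograd, which also means no gradients flow through `process_graph` and the tensor saved for backward is silently changed underneath the graph. The cleaner fix is to make the normalization out of place so the caller's tensor is never touched. A sketch (simplified names and shapes, not the original training code):

```python
import torch

def process_graph(graph_data):
    # Out-of-place variant: `graph_data + matrix_i` allocates a new tensor
    # instead of mutating the argument, so the version counters recorded
    # in the autograd graph stay valid and gradients still flow.
    N = graph_data.size(0)
    matrix_i = torch.eye(N, dtype=graph_data.dtype, device=graph_data.device)
    graph_data = graph_data + matrix_i      # NOT `+=`
    degree = graph_data.sum(dim=1).pow(-1)  # inverse degrees
    degree[degree == float("inf")] = 0.     # zero out infs from isolated nodes
    return torch.mm(torch.diag(degree), graph_data)

w = torch.randn(4, 4, requires_grad=True)
A = torch.sigmoid(w)        # stand-in for a learnable adjacency matrix
y = (A @ A).sum()           # another use of A that saves it for backward
adj = process_graph(A)      # A itself is never modified
(y + adj.sum()).backward()  # no version error, and gradients reach w
```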
4. References
1. PyTorch pitfalls: the differences between assignment, shallow copy, and deep copy, and the gotchas of model.state_dict() and model.load_state_dict()