Pytorch的_version问题

1、问题描述:

最近在魔改一个GNN结构,出现了如下错误:

Traceback (most recent call last):
  File "D:\Desktop\Machine\PyTorchTemplate\train_multi_step.py", line 399, in <module>
  File "D:\Desktop\Machine\PyTorchTemplate\train_multi_step.py", line 251, in main
    metrics = engine.train(tx, ty[:, 0, :, :], id)  #Tag: ty only needs the Info, Not need Time EncodedInfo,  which 0 is select info   
  File "D:\Desktop\Machine\PyTorchTemplate\trainer.py", line 53, in train
    loss.backward(retain_graph=True)
  File "E:\SoftWares\anaconda3\envs\pytorch38\lib\site-packages\torch\_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "E:\SoftWares\anaconda3\envs\pytorch38\lib\site-packages\torch\autograd\__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [307, 307]], which is output 0 of struct torch::autograd::CopySlices, is at version 6; expected version 5 instead. Hint: enable anomaly 
detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

2、问题解析

刚开始我上来就是百度、bing、c**n大法查找这个问题,有人说是pytorch版本的问题,于是装了虚拟环境pytorch1.7.1的,结果不行。有人说是什么in_place操作,我看半天代码没有这个东西啊,反正就一直看。

然后,我仔细看报错,它说version 6和version 5,我记得里面好像有个_version用来记录当前变量版本的,然后一行行debug,就找到了错误所在。

class MGCN(nn.Module):
	....
    @staticmethod
    def process_graph(graph_data):  # 这个就是在原始的邻接矩阵之上,再次变换,也就是\hat A = D_{-1/2}*A*D_{-1/2}
        N = graph_data.size(0) # 获得节点的个数
        matrix_i = torch.eye(N, dtype=torch.float, device=graph_data.device)  # 定义[N, N]的单位矩阵
        graph_data += matrix_i  # [N, N]  ,就是 A+I

        degree_matrix = torch.sum(graph_data, dim=1, keepdim=False)  # [N],计算度矩阵,塌陷成向量,其实就是将上面的A+I每行相加
        degree_matrix = degree_matrix.pow(-1)  # 计算度矩阵的逆,若为0,-1次方可能计算结果为无穷大的数
        degree_matrix[degree_matrix == float("inf")] = 0.  # 让无穷大的数为0

        degree_matrix = torch.diag(degree_matrix)  # 转换成对角矩阵
        return torch.mm(degree_matrix, graph_data)  # 返回 \hat A=D^(-1) * A ,这个等价于\hat A = D_{-1/2}*A*D_{-1/2}
	def forward(self,data,A):	# 传入data和矩阵A(A是[307,307],本次出错的地方)
		# 因为要对A进行变换
		# 1、直接使用以下进行操作,报错
		adj = MGCN.process_graph(A)  # 变换邻接矩阵 \hat A = D_{-1/2}*A*D_{-1/2}
		# 2、修改为以下操作,不报错
		adj = MGCN.process_graph(A.data)  # 变换邻接矩阵 \hat A = D_{-1/2}*A*D_{-1/2}
	...

报错的根本原因在于你传入的是整个对象,中间处理时,会导致其_version变化,从而导致loss.back()出错。

3、解决办法

一般对传对象时,如果后期涉及修改值,要注意传入对象的data,而不是整个对象。
或者一步步debug了,注意[307,307]这个变量可能是谁,然后看它的_version在哪变化,就是哪有问题。

4、参考资料

1.Pytorch踩坑记:赋值、浅拷贝、深拷贝三者的区别以及model.state_dict()和model.load_state_dict()的坑点

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值