使用GCN并且传入edge weights参数时,反向传播报错PowBackward0
问题描述
在使用torch_geometric库的GCN块时,传入edge weights参数。按照官方文档描述,该参数会将本为0-1二值矩阵的邻接矩阵
A
A
A 中的 1‘s 替换成edge weights中相应的值。
但是在使用edge weights参数的时候,反向传播会报错
RUntimeError: Function 'PowBackward0' returned nan values in its 0th output
后经过设置
with torch.autograd.detect_anomaly():
loss.backward()
发现是GCN层出了错
File "/home/user/anaconda3/envs/pytorch1.12.1-python3.9/lib/python3.9/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 210, in forward
edge_index, edge_weight = gcn_norm( # yapf: disable
File "/home/user/anaconda3/envs/pytorch1.12.1-python3.9/lib/python3.9/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 101, in gcn_norm
deg_inv_sqrt = deg.pow_(-0.5)
经过查看源码,确认是edge _weights的问题
deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
deg_inv_sqrt = deg.pow_(-0.5)
默认情况下edge_wight是一个全1的矩阵,但是在传入edge_weights之后,计算的deg矩阵可能会有负值,对负值开平方根导致反向传播报错
解决思路
一、
使用GCN时不使用edge weights
二、
对edge weights做预处理,确保传入的edge weights均为正值(未实验,猜测应当可以)