问题背景:该项目是基于图神经网络深度学习模型,前几个月代码跑通过没有任何问题,最近想重新训练得到新的数据,结果频频报错,以下是修改历程:
问题一:TypeError: forward() missing 1 required positional argument: 'batch'
问题原因分析:
-
在
gnnexplainer.py
的gnn_explainer_alg
函数中,调用self.model(x=h, edge_index=edge_index, batch=batch, **kwargs)
时,batch
参数未被定义。 -
gnn_explainer_alg
函数的参数列表中未包含batch
,导致无法从调用方接收该参数。
解决方案:
1.修改gnn_explainer_alg
函数定义,添加batch
参数:
def gnn_explainer_alg(self,
x: Tensor,
edge_index: Tensor,
ex_label: Tensor,
mask_features: bool = False,
batch: Tensor = None, # 添加batch参数
**kwargs
) -> Tensor:
2.在调用self.model
时使用传入的batch
参数:
raw_preds = self.model(x=h, edge_index=edge_index, batch=batch, **kwargs)
3.确保forward
函数调用gnn_explainer_alg
时传递batch
参数:
edge_mask = self.gnn_explainer_alg(x, edge_index, ex_label, mask_features=mask_features, batch=batch).sigmoid()
4. 在gnnexplainer_run
函数中,确保batch
被正确提取并传递给解释器
# 在main.py的gnnexplainer_run函数中:
x, edge_index, batch = graph.x, graph.edge_index.long(), graph.batch
# 调用explainer时确保batch传递到forward方法
explainer(x, edge_index, batch=batch, ...)
问题二: TypeError: forward() got an unexpected keyword argument 'num_classes'
问题原因分析
-
参数传递冲突:
在base_explainer.py
的eval_related_pred
方法中,调用self.model(x=x, edge_index=edge_index, **kwargs)
时,kwargs
可能包含num_classes
参数,但模型的forward
方法未定义该参数。 -
模型定义限制:
模型的forward
方法(如Detector
类)仅接受x
,edge_index
,batch
等参数,而num_classes
是模型初始化参数,不应在前向传播时传递。
解决方案
步骤 1:修改gnnexplainer.py
中的forward
方法
在调用eval_related_pred
时,过滤掉num_classes
参数:
# 修改前(报错代码)
related_preds = self.eval_related_pred(x, edge_index, hard_edge_masks, **kwargs)
# 修改后(过滤num_classes)
filtered_kwargs = {k: v for k, v in kwargs.items() if k != 'num_classes'}
related_preds = self.eval_related_pred(x, edge_index, hard_edge_masks, batch=batch, **filtered_kwargs)
总结
通过显式传递batch
参数至模型调用链的每一层,可修复因参数缺失导致的错误。此问题通常由参数在多层调用中遗漏引起,需逐层检查参数传递的完整性。