模型定义
如HAN模型:
class HAN(nn.Module):
def __init__(
self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout
):
super(HAN, self).__init__()
self.layers = nn.ModuleList()
self.layers.append(
HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout)
)
for l in range(1, len(num_heads)):
self.layers.append(
HANLayer(
meta_paths,
hidden_size * num_heads[l - 1],
hidden_size,
num_heads[l],
dropout,
)
)
self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)
def forward(self, g, h):
for gnn in self.layers:
h = gnn(g, h)
return self.predict(h)
模型使用
直接调用,如:
model = HAN(#之前构建的边pa,ap。组合成meta-path:pap
meta_paths=[["pa", "ap"], ["pf", "fp"]],
in_size=features.shape[1],
hidden_size=args["hidden_units"],
out_size=num_classes,
num_heads=args["num_heads"],
dropout=args["dropout"],
).to(args["device"])
而不用
model.forward()
forward函数的使用
python calss 中的__call__和__init__方法会调用forward函数,因此在实例化模型中已经调用forward函数。
class A():
def __call__(self, param):#或者__init__()
#此处省略代码
res = self.forward(param)
return res
def forward(self, input_):
print('forward 函数被调用了')
#forward函数功能实现代码
return input_
a = A()
#此时在实例化的过程中已经执行了forward()函数
注: 在声明网络架构是,常常使用class HAN(nn.Module),其中nn.Module中包含了__call__函数,在函数中调用了forward,由于继承关系,对于HAN同样具备__call__函数的功能。
相关HAN算法代码地址为:https://download.csdn.net/download/weixin_43333607/87513112