pytorch中的forward函数(以HAN算法为例详细说明)

模型定义

如HAN模型:

class HAN(nn.Module):
    def __init__(
        self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout
    ):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(
            HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout)
        )
        for l in range(1, len(num_heads)):
            self.layers.append(
                HANLayer(
                    meta_paths,
                    hidden_size * num_heads[l - 1],
                    hidden_size,
                    num_heads[l],
                    dropout,
                )
            )
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
        for gnn in self.layers:
            h = gnn(g, h)

        return self.predict(h)

模型使用

直接调用,如:

model = HAN(#之前构建的边pa,ap。组合成meta-path:pap
            meta_paths=[["pa", "ap"], ["pf", "fp"]],
            in_size=features.shape[1],
            hidden_size=args["hidden_units"],
            out_size=num_classes,
            num_heads=args["num_heads"],
            dropout=args["dropout"],
        ).to(args["device"])

而不用

model.forward()

forward函数的使用

python calss 中的__call__和__init__方法会调用forward函数,因此在实例化模型中已经调用forward函数。

class A():
    def __call__(self, param):#或者__init__()
        #此处省略代码
        res = self.forward(param)
        return res
 
    def forward(self, input_):
        print('forward 函数被调用了')
        #forward函数功能实现代码
        return input_
 
a = A()
 
#此时在实例化的过程中已经执行了forward()函数

注: 在声明网络架构是,常常使用class HAN(nn.Module),其中nn.Module中包含了__call__函数,在函数中调用了forward,由于继承关系,对于HAN同样具备__call__函数的功能。

相关HAN算法代码地址为:https://download.csdn.net/download/weixin_43333607/87513112

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

筱文rr

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值