python中forward的参数_如何将关键字参数传递给preforward钩子使用的forward?

本文介绍了在PyTorch中如何在`forward`方法中处理参数,并利用`register_forward_pre_hook`注册预钩子。示例展示了如何在不修改`forward`调用的情况下,将参数传递给钩子函数。同时,讨论了TorchScript的兼容性问题,提出使用组合而非继承来解决。
摘要由CSDN通过智能技术生成

Torchscript不兼容(截至1.2.0)

首先,您的示例torch.nn.Module有一些小错误(可能是意外造成的)。在

第二,您可以将任何传递给forward,register_forward_pre_hook将只获得将传递给您的torch.nn.Module(无论是层或模型或任何其他内容)的参数。如果不修改forward调用,确实无法做到这一点,但为什么要避免这种情况呢?您可以简单地将参数转发到基函数,如下所示:import torch

class NeoEmbeddings(torch.nn.Embedding):

def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):

super().__init__(num_embeddings, embedding_dim, padding_idx)

self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

# First argument should be named something like module, as that's what

# you are registering this hook to

@staticmethod

def neo_genesis(module, inputs): # No need for self as first argument

net_input, higgs_bosson = inputs # Simply unpack tuple here

return net_input

def forward(self, inputs, higgs_bosson):

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值