TorchScript incompatible (as of 1.2.0)
First of all, your example torch.nn.Module has a few minor mistakes (possibly introduced by accident).

Secondly, you can pass anything to forward; register_forward_pre_hook will receive exactly the arguments that are passed to your torch.nn.Module (be it a layer, a model, or anything else). You indeed cannot do it without modifying the forward call, but why would you want to avoid that? You can simply forward the arguments on to the base function, like this:

import torch
class NeoEmbeddings(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

    # First argument should be named something like module, as that's what
    # you are registering this hook to
    @staticmethod
    def neo_genesis(module, inputs):  # No need for self as first argument
        net_input, higgs_bosson = inputs  # Simply unpack the tuple here
        return net_input

    # The hook returns only net_input, so after it runs forward is called
    # without higgs_bosson; hence the default. forward just delegates to
    # the base class.
    def forward(self, inputs, higgs_bosson=None):
        return super().forward(inputs)
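To see the hook fire, call the module itself rather than forward, since hooks are run by Module.__call__. Below is a minimal runnable sketch of the same idea; the class is repeated so the snippet stands alone, and the sample tensor values, sizes, and the higgs_bosson=None default are my additions for illustration:

```python
import torch


class NeoEmbeddings(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

    @staticmethod
    def neo_genesis(module, inputs):
        # The hook receives all positional arguments as one tuple and
        # strips the extra one before forward runs
        net_input, higgs_bosson = inputs
        return net_input

    def forward(self, inputs, higgs_bosson=None):
        return super().forward(inputs)


embedding = NeoEmbeddings(10, 5, 1)
# Call the module (not .forward directly) so registered hooks run
out = embedding(torch.tensor([0, 2, 5, 8]), "anything")
print(out.shape)  # torch.Size([4, 5])
```

Because the hook's return value replaces the original input tuple, forward only ever sees the tensor; the second argument exists purely so the call site can pass it through.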