【 torch.jit.script 踩坑记录】‘method_descriptor‘ object has no attribute ‘__globals__‘ 暂未解决

22 篇文章 0 订阅
8 篇文章 0 订阅
本文遇到的情况:

在 Pytorch 模型到 onnx 转化的过程中,由于模型的 forward() 函数中带有条件控制流,不能采用 torch.onnx.export 的默认方式直接将模型的分支遍历(原因可参考文章 ➡️ torch.jit.script 的使用),便

采用 torch.jit.script 时遇到的问题:

参考模型部署入门教程(三) 中的代码如下

import torch 
 
class Model(torch.nn.Module): 
    def __init__(self, n): 
        super().__init__() 
        self.n = n 
        self.conv = torch.nn.Conv2d(3, 3, 3) 
 
    def forward(self, x): 
        for i in range(self.n): 
            x = self.conv(x) 
        return x 
 
 
models = [Model(2), Model(3)] 
model_names = ['model_2', 'model_3'] 
 
for model, model_name in zip(models, model_names): 
    dummy_input = torch.rand(1, 3, 10, 10) 
    dummy_output = model(dummy_input) 
    model_trace = torch.jit.trace(model, dummy_input) 
    model_script = torch.jit.script(model) 
 
    # 跟踪法与直接 torch.onnx.export(model, ...)等价 
    torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx', example_outputs=dummy_output) 
    # 记录法必须先调用 torch.jit.sciprt 
    torch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx', example_outputs=dummy_output)

本文所用代码如下

def forward(self, pix, xy, text, sign, return_loss=False):
    '''
    pix.shape  [batch_size, max_len]
    xy.shape   [batch_size, max_len, 2]
    text.shape [batch_size, text_len]
    '''
    pixel_v = pix
    xy_v = xy

    c_bs, c_seqlen, device = text.shape[0], text.shape[1], text.device
    if sign != 0:
      c_seqlen += pixel_v.shape[1]  

    # Context embedding values
    context_embedding = torch.zeros((1, c_bs, self.embed_dim)).to(device) # [1, bs, dim]
    tokens = self.text_emb(text)

    # Data input embedding
    if sign != 0:
      coord_embed = self.coord_embed_x(xy_v[...,0]) + self.coord_embed_y(xy_v[...,1]) # [bs, vlen, dim]
      pixel_embed = self.pixel_embed(pixel_v)
      embed_inputs = pixel_embed + coord_embed
      # tokens.shape [batch_size, text_len+max_len-1, emb_dim]
      tokens = torch.cat((tokens, embed_inputs), dim=1)
    
    embeddings = torch.cat([context_embedding, tokens.transpose(0,1)], axis=0)
    decoder_inputs = self.pos_embed(embeddings) 
    memory_encode = torch.zeros((1, c_bs, self.embed_dim)).to(device)
    
    # nopeak_mask.shape [c_seqlen+1, c_seqlen+1]
    nopeak_mask = torch.nn.Transformer.generate_square_subsequent_mask(c_seqlen+1).to(device)  # masked with -inf
    decoder_out = self.decoder(tgt=decoder_inputs, memory=memory_encode, memory_key_padding_mask=None,
                               tgt_mask=nopeak_mask, tgt_key_padding_mask=None)

    # Logits fc
    logits = self.logit_fc(decoder_out)  # [seqlen, bs, dim] 
    logits = logits.transpose(1,0)  # [bs, textlen+seqlen, total_token] 

    logits_mask = self.logits_mask[:, :c_seqlen+1]
    max_neg_value = -torch.finfo(logits.dtype).max
    logits.masked_fill_(logits_mask, max_neg_value)
    return logits

# 加载 PyTorch 模型
sketch_decoder = SketchDecoder(
    config={
        'hidden_dim':...
    },
    pix_len=cfg['pix_len'],
    text_len=cfg['text_len'],
    num_text_token=tokenizer.vocab_size,
    word_emb_path=cfg['word_emb_path'],
    pos_emb_path=cfg['pos_emb_path'],
)

scripted_model = torch.jit.script(sketch_decoder(pixel_seq, xy_seq, text, sign))

报错如下:

  File "/path/bin2onnx.py", line 299, in <module>
    scripted_model = torch.jit.script(sketch_decoder(pixel_seq, xy_seq, text, sign))
  File "/path/lib/python3.9/site-packages/torch/jit/_script.py", line 1351, in script
    return torch.jit._recursive.create_script_class(obj)
  File "/path/lib/python3.9/site-packages/torch/jit/_recursive.py", line 424, in create_script_class
    rcb = _jit_internal.createResolutionCallbackForClassMethods(type(obj))
  File "/path/lib/python3.9/site-packages/torch/_jit_internal.py", line 395, in createResolutionCallbackForClassMethods
    captures.update(get_closure(fn))
  File "/path/lib/python3.9/site-packages/torch/_jit_internal.py", line 169, in get_closure
    captures.update(fn.__globals__)
AttributeError: 'method_descriptor' object has no attribute '__globals__'

找到一篇 理解 AttributeError: ‘method_descriptor’ object has no attribute ‘globals’ 的文章
,详细地解释了这个问题,便翻译全文并精炼,提供一些解决思路:

1. 什么是 AttributeError?

AttributeError 是一种异常,当你试图访问或操作一个对象不存在的属性时会发生。在 Python 中,一切皆对象,对象具有定义其行为和属性的属性。当你试图访问一个不存在的属性时,Python 会引发 AttributeError 来告知你出了问题。

2. ‘method_descriptor’ 对象是什么?

在 Python 中,method_descriptor 是一种特殊类型的对象,表示类中定义的方法。当你在类中定义一个方法时,它会被创建,并且可以通过该类的实例进行访问。方法描述符用于定义方法的行为,并允许你在类的实例上调用它们。

3. ‘globals’ 属性是什么?

globals’ 属性是 Python 函数和方法中存在的一个特殊属性。它表示定义该函数或方法时的全局命名空间。它包含一个字典,将全局名称映射到其对应的值。

当你在 Python 中定义一个函数或方法时,它可以访问其定义时的全局命名空间。这意味着它可以访问和修改全局变量和函数。‘globals’ 属性允许你从函数或方法内部访问这个全局命名空间。

4. 错误信息意味着什么?

错误信息“AttributeError: ‘method_descriptor’ object has no attribute ‘globals’”发生在你试图在 method_descriptor 对象上访问 ‘globals’ 属性时。这通常发生在你错误地试图在方法而不是函数上访问 ‘globals’ 属性时。

5. 为什么会发生错误?

该错误发生是因为 method_descriptor 对象没有 ‘globals’ 属性。只有使用 ‘def’ 关键字定义的函数和方法才有此属性。当你试图在 method_descriptor 对象上访问 ‘globals’ 属性时,Python 会引发 AttributeError,因为该属性不存在。

6. 如何修复错误?

要修复“AttributeError: ‘method_descriptor’ object has no attribute ‘globals’”错误,你需要确保你是在函数或方法上访问 ‘globals’ 属性,而不是在 method_descriptor 对象上。

  1. 如果你试图在方法上访问 ‘globals’ 属性,可以通过使用 ‘staticmethod’ 或 ‘classmethod’ 装饰器将其转换为函数。这些装饰器创建一个新的函数对象,封装原始方法,并允许你访问 ‘globals’ 属性。
class MyClass:
    @staticmethod
    def my_static_method():
        # 访问 __globals__ 属性
        print(my_static_method.__globals__)
  1. 如果你试图在 method_descriptor 对象上访问 ‘globals’ 属性,你需要重新考虑你的方法。方法描述符不是为了直接访问而设计的,而是通过定义它们的类的实例进行访问。如果你需要访问 ‘globals’ 属性,应该定义一个可以在类实例上调用的单独函数或方法。
结论

“AttributeError: ‘method_descriptor’ object has no attribute ‘globals’” 错误发生在你试图在 method_descriptor 对象上访问 ‘globals’ 属性时。通过确保你是在函数或方法上访问该属性,而不是在 method_descriptor 对象上,可以修复此错误。如果你需要访问 ‘globals’ 属性,可以使用 ‘staticmethod’ 或 ‘classmethod’ 装饰器将方法转换为函数。

回到解决本文 torch.jit.script 的问题

从本文代码来看,SketchDecoder 类定义了一个 forward 方法,该方法用于前向传播。问题发生在尝试将 sketch_decoder 的实例调用传递给 torch.jit.script 时。

解决方案分析

要解决这个问题,需要确保传递给 torch.jit.script 的是一个可以被脚本化的函数对象。torch.jit.script 通常应用于函数或 nn.Module 的实例。如果想脚本化 SketchDecoderforward 方法,可以直接传递 SketchDecoder 的实例,而不是调用结果。

这里有两个常见的解决方法:

方法1:直接脚本化模型实例 (❌不断出现对于模型的报错)

将整个模型实例传递给 torch.jit.script。这样,JIT 将脚本化整个模型,包括其 forward 方法。

方法2:单独脚本化 forward 方法(❌还是需要模型的输入才行)

如果需要单独脚本化 forward 方法,可以定义一个包装函数或确保传递的是方法引用,而不是调用结果。

# 将 forward 方法的引用脚本化
scripted_model = torch.jit.script(sketch_decoder.forward)
最终结论

使用 torch.jit.script 的问题待解决,还是暂时通过 if 外置 + 用 torch.jit.trace 分别导出不同的 onnx 来解决问题。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值