本文遇到的情况:
在 Pytorch 模型到 onnx 转化的过程中,由于模型的 forward() 函数中带有条件控制流,不能采用 torch.onnx.export
的默认方式直接将模型的分支遍历(原因可参考文章 ➡️ torch.jit.script 的使用),便
采用 torch.jit.script
时遇到的问题:
参考模型部署入门教程(三) 中的代码如下
import torch
class Model(torch.nn.Module):
def __init__(self, n):
super().__init__()
self.n = n
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
for i in range(self.n):
x = self.conv(x)
return x
models = [Model(2), Model(3)]
model_names = ['model_2', 'model_3']
for model, model_name in zip(models, model_names):
dummy_input = torch.rand(1, 3, 10, 10)
dummy_output = model(dummy_input)
model_trace = torch.jit.trace(model, dummy_input)
model_script = torch.jit.script(model)
# 跟踪法与直接 torch.onnx.export(model, ...)等价
torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx', example_outputs=dummy_output)
# 记录法必须先调用 torch.jit.sciprt
torch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx', example_outputs=dummy_output)
本文所用代码如下
def forward(self, pix, xy, text, sign, return_loss=False):
'''
pix.shape [batch_size, max_len]
xy.shape [batch_size, max_len, 2]
text.shape [batch_size, text_len]
'''
pixel_v = pix
xy_v = xy
c_bs, c_seqlen, device = text.shape[0], text.shape[1], text.device
if sign != 0:
c_seqlen += pixel_v.shape[1]
# Context embedding values
context_embedding = torch.zeros((1, c_bs, self.embed_dim)).to(device) # [1, bs, dim]
tokens = self.text_emb(text)
# Data input embedding
if sign != 0:
coord_embed = self.coord_embed_x(xy_v[...,0]) + self.coord_embed_y(xy_v[...,1]) # [bs, vlen, dim]
pixel_embed = self.pixel_embed(pixel_v)
embed_inputs = pixel_embed + coord_embed
# tokens.shape [batch_size, text_len+max_len-1, emb_dim]
tokens = torch.cat((tokens, embed_inputs), dim=1)
embeddings = torch.cat([context_embedding, tokens.transpose(0,1)], axis=0)
decoder_inputs = self.pos_embed(embeddings)
memory_encode = torch.zeros((1, c_bs, self.embed_dim)).to(device)
# nopeak_mask.shape [c_seqlen+1, c_seqlen+1]
nopeak_mask = torch.nn.Transformer.generate_square_subsequent_mask(c_seqlen+1).to(device) # masked with -inf
decoder_out = self.decoder(tgt=decoder_inputs, memory=memory_encode, memory_key_padding_mask=None,
tgt_mask=nopeak_mask, tgt_key_padding_mask=None)
# Logits fc
logits = self.logit_fc(decoder_out) # [seqlen, bs, dim]
logits = logits.transpose(1,0) # [bs, textlen+seqlen, total_token]
logits_mask = self.logits_mask[:, :c_seqlen+1]
max_neg_value = -torch.finfo(logits.dtype).max
logits.masked_fill_(logits_mask, max_neg_value)
return logits
# 加载 PyTorch 模型
sketch_decoder = SketchDecoder(
config={
'hidden_dim':...
},
pix_len=cfg['pix_len'],
text_len=cfg['text_len'],
num_text_token=tokenizer.vocab_size,
word_emb_path=cfg['word_emb_path'],
pos_emb_path=cfg['pos_emb_path'],
)
scripted_model = torch.jit.script(sketch_decoder(pixel_seq, xy_seq, text, sign))
报错如下:
File "/path/bin2onnx.py", line 299, in <module>
scripted_model = torch.jit.script(sketch_decoder(pixel_seq, xy_seq, text, sign))
File "/path/lib/python3.9/site-packages/torch/jit/_script.py", line 1351, in script
return torch.jit._recursive.create_script_class(obj)
File "/path/lib/python3.9/site-packages/torch/jit/_recursive.py", line 424, in create_script_class
rcb = _jit_internal.createResolutionCallbackForClassMethods(type(obj))
File "/path/lib/python3.9/site-packages/torch/_jit_internal.py", line 395, in createResolutionCallbackForClassMethods
captures.update(get_closure(fn))
File "/path/lib/python3.9/site-packages/torch/_jit_internal.py", line 169, in get_closure
captures.update(fn.__globals__)
AttributeError: 'method_descriptor' object has no attribute '__globals__'
找到一篇 理解 AttributeError: ‘method_descriptor’ object has no attribute ‘globals’ 的文章
,详细地解释了这个问题,便翻译全文并精炼,提供一些解决思路:
1. 什么是 AttributeError?
AttributeError 是一种异常,当你试图访问或操作一个对象不存在的属性时会发生。在 Python 中,一切皆对象,对象具有定义其行为和属性的属性。当你试图访问一个不存在的属性时,Python 会引发 AttributeError 来告知你出了问题。
2. ‘method_descriptor’ 对象是什么?
在 Python 中,method_descriptor 是一种特殊类型的对象,表示类中定义的方法。当你在类中定义一个方法时,它会被创建,并且可以通过该类的实例进行访问。方法描述符用于定义方法的行为,并允许你在类的实例上调用它们。
3. ‘globals’ 属性是什么?
‘globals’ 属性是 Python 函数和方法中存在的一个特殊属性。它表示定义该函数或方法时的全局命名空间。它包含一个字典,将全局名称映射到其对应的值。
当你在 Python 中定义一个函数或方法时,它可以访问其定义时的全局命名空间。这意味着它可以访问和修改全局变量和函数。‘globals’ 属性允许你从函数或方法内部访问这个全局命名空间。
4. 错误信息意味着什么?
错误信息“AttributeError: ‘method_descriptor’ object has no attribute ‘globals’”发生在你试图在 method_descriptor 对象上访问 ‘globals’ 属性时。这通常发生在你错误地试图在方法而不是函数上访问 ‘globals’ 属性时。
5. 为什么会发生错误?
该错误发生是因为 method_descriptor 对象没有 ‘globals’ 属性。只有使用 ‘def’ 关键字定义的函数和方法才有此属性。当你试图在 method_descriptor 对象上访问 ‘globals’ 属性时,Python 会引发 AttributeError,因为该属性不存在。
6. 如何修复错误?
要修复“AttributeError: ‘method_descriptor’ object has no attribute ‘globals’”错误,你需要确保你是在函数或方法上访问 ‘globals’ 属性,而不是在 method_descriptor 对象上。
- 如果你试图在方法上访问 ‘globals’ 属性,可以通过使用 ‘staticmethod’ 或 ‘classmethod’ 装饰器将其转换为函数。这些装饰器创建一个新的函数对象,封装原始方法,并允许你访问 ‘globals’ 属性。
class MyClass:
@staticmethod
def my_static_method():
# 访问 __globals__ 属性
print(my_static_method.__globals__)
- 如果你试图在 method_descriptor 对象上访问 ‘globals’ 属性,你需要重新考虑你的方法。方法描述符不是为了直接访问而设计的,而是通过定义它们的类的实例进行访问。如果你需要访问 ‘globals’ 属性,应该定义一个可以在类实例上调用的单独函数或方法。
结论
“AttributeError: ‘method_descriptor’ object has no attribute ‘globals’” 错误发生在你试图在 method_descriptor 对象上访问 ‘globals’ 属性时。通过确保你是在函数或方法上访问该属性,而不是在 method_descriptor 对象上,可以修复此错误。如果你需要访问 ‘globals’ 属性,可以使用 ‘staticmethod’ 或 ‘classmethod’ 装饰器将方法转换为函数。
回到解决本文 torch.jit.script
的问题
从本文代码来看,SketchDecoder
类定义了一个 forward
方法,该方法用于前向传播。问题发生在尝试将 sketch_decoder
的实例调用传递给 torch.jit.script
时。
解决方案分析
要解决这个问题,需要确保传递给 torch.jit.script
的是一个可以被脚本化的函数对象。torch.jit.script
通常应用于函数或 nn.Module
的实例。如果想脚本化 SketchDecoder
的 forward
方法,可以直接传递 SketchDecoder
的实例,而不是调用结果。
这里有两个常见的解决方法:
方法1:直接脚本化模型实例 (❌不断出现对于模型的报错)
将整个模型实例传递给 torch.jit.script
。这样,JIT 将脚本化整个模型,包括其 forward
方法。
方法2:单独脚本化 forward
方法(❌还是需要模型的输入才行)
如果需要单独脚本化 forward
方法,可以定义一个包装函数或确保传递的是方法引用,而不是调用结果。
# 将 forward 方法的引用脚本化
scripted_model = torch.jit.script(sketch_decoder.forward)
最终结论
使用 torch.jit.script 的问题待解决,还是暂时通过 if 外置 + 用 torch.jit.trace 分别导出不同的 onnx
来解决问题。