最近在加速pytorch模型时尝试转换为onnx格式,onnx确实快一些,但因为需要在推理时获得loss以计算ppl,发现onnx的output里loss为None(即使给定了label)因此需要自己根据logit计算loss
pytorch模型的output:
onnx模型的Output:
参考原Pytorch模型的forward定义里loss如何计算的(以qwen2为例):
抄过来就好了:
最终两边输出的loss应该是一样的,并且onnx要快:
onnx:
pytorch:
最近在加速pytorch模型时尝试转换为onnx格式,onnx确实快一些,但因为需要在推理时获得loss以计算ppl,发现onnx的output里loss为None(即使给定了label)因此需要自己根据logit计算loss
pytorch模型的output:
onnx模型的Output:
参考原Pytorch模型的forward定义里loss如何计算的(以qwen2为例):
抄过来就好了:
最终两边输出的loss应该是一样的,并且onnx要快:
onnx:
pytorch: