1、转pth
因为有可能pytorch版本的问题,可能会报一个zip类型的错误,所以需要考虑兼容问题
def load_model(model,path):
state=torch.load(str(path))
model.load_state_dict(state['model'])
return state
state_dict=load_model(model,'./snapshot3/model_79.pt')
torch.save(state_dict,'./snapshot3/model_79_new.pth',_use_new_zipfile_serialization=False)
2、转换onnx
这里是为了记录当我的输入不止一个的时候,如这里的话,我输入的除了是图片之外,还要输入图像尺寸作为辅助参数的时候,这里就需要做出以下转化
model.eval()
inputs=torch.ones((1,3,224,224))
sizes=[[0,0]]
sizes=np.array(sizes)
sizes=torch.from_numpy(sizes)
sizes=torch.tensor(sizes,dtype=torch.float32)
torch.onnx.export(model,
(inputs,sizes), #这里可以元组的形式进行封装
"SizeAttention.onnx",
input_names=["input"],
output_names=["output"]
)
3、反向验证
import onnx
onnx_model=onnx.load("SizeAttention.onnx")
onnx.checker.check_model(onnx_model)
如果没有报错,说明成功
进一步验证
import onnxruntime
ort_session=onnxruntime.InferenceSession("SizeAttention.onnx")
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
x=torch.randn((1,3,224,224))
sizes=[[0,0]]
sizes=np.array(sizes)
sizes=torch.from_numpy(sizes)
sizes=torch.tensor(sizes,dtype=torch.float32)
#compute ONNX Runtime
ort_inputs={ort_session.get_inputs()[0].name:to_numpy(x),ort_session.get_inputs()[0].name:to_numpy(sizes)}
ort_outs=ort_session.run(None,ort_inputs)
或者还有一种验证方式
import onnx
import caffe2.python.onnx.backend as onnx_caffe2_backend
cmodel=onnx.load("SizeAttention.onnx")
prepared_backend=onnx_caffe2_backend.prepare(cmodel)
W={cmodel.graph.input[0].name:image.data.cpu().numpy(),
cmodel.graph.input[1].name:sizes.data.cpu().numpy()}
c2_out=prepared_backend.run(W)[0]
np.testing.assert_almost_equal(torch_out,data.cpu().numpy(),c2_out,decimal=3)