pytorch转换onnx,测试onnx,pytorch 模型结果是否一致

def to_onnx():
    dummy_input = torch.randn(1, 3, 112, 112, dtype=torch.float)
    # model = model_res()
    model = model_osnet()

    input_names = ["data"]
    output_names = ["fc"]
    torch.onnx.export(
        model,
        dummy_input,
        "./osnet.onnx",
        verbose=True,
        input_names=input_names,
        output_names=output_names,
    )
    print("转换模型成功^^")


def pytorch_out(input):
    model = model_res() #model.eval
    # input = input.cuda()
    # model.cuda()
    torch.no_grad()
    output = model(input)
    # print output[0].flatten()[70:80]
    return output

def pytorch_onnx_test():
    import onnxruntime
    from onnxruntime.datasets import get_example

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    # 测试数据
    torch.manual_seed(66)
    dummy_input = torch.randn(1, 3, 112, 112, device='cpu')

    example_model = get_example("/home/shiyy/nas/all_workspace/pytorch_code/reid/InsightFace-v2/res50.onnx")
    # netron.start(example_model) 使用 netron python 包可视化网络
    sess = onnxruntime.InferenceSession(example_model)

    # onnx 网络输出
    onnx_out = np.array(sess.run(None, { "data": to_numpy(dummy_input)}))  #fc 输出是三维列表
    print("==============>")
    print(onnx_out)
    print(onnx_out.shape)
    print("==============>")
    torch_out_res = pytorch_out(dummy_input).detach().numpy()   #fc输出是二维 列表
    print(torch_out_res)
    print(torch_out_res.shape)

    print("===================================>")
    print("输出结果验证小数点后五位是否正确,都变成一维np")

    torch_out_res = torch_out_res.flatten()
    onnx_out = onnx_out.flatten()

    pytor = np.array(torch_out_res,dtype="float32") #need to float32
    onn=np.array(onnx_out,dtype="float32")  ##need to float32
    np.testing.assert_almost_equal(pytor,onn, decimal=5)  #精确到小数点后5位,验证是否正确,不正确会自动打印信息
    print("恭喜你 ^^ ,onnx 和 pytorch 结果一致 , Exported model has been executed decimal=5 and the result looks good!")
	 [[...
	  -5.49954772e-02 -7.31383190e-02  2.37192452e-01  1.48571879e-02
	  -2.97113061e-02 -5.73663861e-02 -6.34638742e-02 -5.24816178e-02]]
	(1, 512)
	===================================>
	输出结果验证小数点后五位是否正确,都变成一维np
	恭喜你 ^^ ,onnx 和 pytorch 结果一致 , Exported model has been executed decimal=5 and the result looks good!
  • 3
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值