使用numpy或pytorch校验两个张量是否相等

文章目录


做算法过程中,如果涉及到模型落地,那必然会将原始的深度学习的框架训练好的模型转换成目标硬件模型的格式,如onnx,tensorrt,openvino,tflite;那么就有对比不同格式模型输出的一致性,从而判断模型转换是否成功。

1、numpy

用到的核心代码就一行,就是:

import numpy as np
np.testing.assert_allclose(actual,expected,rtol,atol)

上代示例:

import numpy as np

# 定义两个数组
actual= np.array([1.0, 2.0, 3.0])
expected = np.array([1.01, 1.99, 3.0])

# 使用 np.testing.allclose 检查它们是否近似相等
np.testing.assert_allclose(actual,expected,rtol=0,atol=0.01)

输出:
在这里插入图片描述
最大的绝对误差是0.01,最大的相对误差是0.00990099.
再一个示例:

import numpy as np

# 定义两个数组
actual= np.array([1.0, 2.0, 3.0])
expected = np.array([1.01, 1.99, 3.0])

# 使用 np.testing.allclose 检查它们是否近似相等
np.testing.assert_allclose(actual,expected,rtol=0,atol=0.0100001)

只是改了atol 从0.01改成0.0100001。

所以关于rtol和atol做如下理解:
rtol 就是relative tolarance ,atol 就absolute tolarance.
先计算绝对误差:

diff = abs(actual-expecd) #绝对误差
tolarance = atol+ rtol*abs(expected) #误差容忍上限
if diff<tolarance:
	pass
else:
	print("报错信息,如图,有最大绝对误差 最大相对误差 不相等的百分比等")

最大绝对误差= max(diff)
最大相对误差= max(diff)/abs(expected)

函数默认的 atol=1e-7,rtol=0
但考虑到float32精度,有效数字也就7位,可以设置atol=1e-5,小数点后五位有效数字即可。

2、pytorch

pytorch有相似的api:

import numpy as np
import torch
actual= np.array([1.0, 2.0, 3.0])
expected = np.array([1.01, 1.99, 3.0])
torch.testing.assert_close(torch.tensor(actual),torch.tensor(expected),rtol=0,atol=0.011)

以上不会有任何输出

import numpy as np
import torch
actual= np.array([1.0, 2.0, 3.0])
expected = np.array([1.01, 1.99, 3.0])
torch.testing.assert_close(torch.tensor(actual),torch.tensor(expected),rtol=0,atol=0.01)

在这里插入图片描述
相比numpy,多给出了相关最大误差的位置及允许的上限。

  • 4
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值