torch.eq与torch.ne——判断数组中的数值是否相等
torch.eq()——判断元素是否相等
torch.eq(input, other, *, out=None) → Tensor
功能:判断两个数组的元素是否相等。
输出:返回与输入具有相同形状的张量数组,若对应位置上的元素相等,则该位置上的元素是True
,否则是False
。
输入:
input
:要比较的张量数组other
:判断标准(判断input
是否与other
相等),张量数组或者值
注意:
torch.eq
具有广播机制的效应,other
的形状必须能够通过广播机制扩充为input
的形状,最终比较的是input
和扩充后的other
torch.eq
也可以通过a.eq
实现,效果类似,只是后者的a相当于前者输入中的input
- 如果输入的是数组,则必须是
tensor
类型
具体的广播机制条件可以见这里
案例代码
import torch
a=torch.arange(10).view(2,5)
b=torch.arange(5).view(1,5)
c=torch.eq(a,b)
d=a.eq(5)
print(a)
print(b)
print(c)
print(d)
输出
# 数组a
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
# 数组b
tensor([[0, 1, 2, 3, 4]])
# a、b比较结果
tensor([[ True, True, True, True, True],
[False, False, False, False, False]])
# a与单个数字5比较,相当于做了一个扩充,再比较
tensor([[False, False, False, False, False],
[ True, False, False, False, False]])
当other
与input
形状不一样的时候,会通过广播机制将other
扩充,扩充为和input
具有一样大小的数组,再进行逐元素比较
torch.ne()——判断元素是否不相等
torch.ne(input, other, *, out=None) → Tensor
功能:判断两个数组的元素是否不相等。(相当于是eq
的反运算,下面的性质与eq
类似)
输出:返回与输入具有相同形状的张量数组,若对应位置上的元素不相等,则该位置上的元素是True
,否则是False
。
输入:
input
:要比较的张量数组other
:判断标准(判断input
是否与other
不相等),张量数组或者值
注意:
torch.ne
同样具有广播机制的效应,other
的形状必须能够通过广播机制扩充为input
的形状,最终比较的是input
和扩充后的other
torch.ne
也可以通过a.ne
实现,效果类似,只是后者的a相当于前者输入中的input
- 如果输入的是数组,则必须是
tensor
类型
案例代码
import torch
a=torch.arange(10).view(2,5)
b=torch.arange(5).view(1,5)
c=torch.ne(a,b)
d=a.ne(6)
print(a)
print(b)
print(c)
print(d)
输出
# 数组a
tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
# 数组b
tensor([[0, 1, 2, 3, 4]])
# a与b的比较情况
tensor([[False, False, False, False, False],
[ True, True, True, True, True]])
# a与单个数字6比较的结果
tensor([[ True, True, True, True, True],
[ True, False, True, True, True]])
扩展
torch.eq
常用于分类任务中,判断预测值是否与真实值(标签值)相等,然后再配合.sum()
方法,将返回数组中True
的个数求和,从而计算正确预测的个数,最终得到正确率。
代码案例
import torch
import numpy as np
# 从0,1内随机选50个数,作为标签和预测值
labels=torch.tensor(np.random.choice(2,50))
pre=torch.tensor(np.random.choice(2,50))
print(labels)
print(pre)
right=torch.eq(pre,labels).sum().item()
# .sum()是求和操作,将数组中True数量求和
# .item()方法是为了取出数值
print('正确预测:',right,'个')
acc=right/len(labels)
print('准确率为:',acc,'%')
输出
# 标签
tensor([1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1,
0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0,
1, 1], dtype=torch.int32)
# 预测值
tensor([1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1,
1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0,
0, 1], dtype=torch.int32)
正确预测: 31 个
准确率为: 0.62 %
官方文档
torch.eq():https://pytorch.org/docs/stable/generated/torch.eq.html#torch.eq
torch.ne():https://pytorch.org/docs/stable/generated/torch.ne.html?highlight=ne#torch.ne
点个赞支持一下吧