行人重识别Relation-Aware Global Attention
的部分代码。
N = dist_mat.size(0) # 8
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
#print(is_pos.shape)
is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
#print(is_neg.shape)
一、测试eq()函数:
import torch
lab = torch.tensor([389, 389, 389, 389, 628, 628, 628, 628])
print(lab.expand(8, 8))
print("****")
print(lab.expand(8, 8).t())
print("****")
print(lab.expand(8, 8).eq(lab.expand(8, 8).t()))
输出结果:
tensor([[389, 389, 389, 389, 628, 628, 628, 628],
[389, 389, 389, 389, 628, 628, 628, 628],
[389, 389, 389, 389, 628, 628, 628, 628],
[389, 389, 389, 389, 628, 628, 628, 628],
[389, 389, 389, 389, 628, 628, 628, 628],
[389, 389, 389, 389, 628, 628, 628, 628],
[389, 389, 389, 389, 628, 628, 628, 628],
[389, 389, 389, 389, 628, 628, 628, 628]])
****
tensor([[389, 389, 389, 389, 389, 389, 389, 389],
[389, 389, 389, 389, 389, 389, 389, 389],
[389, 389, 389, 389, 389, 389, 389, 389],
[389, 389, 389, 389, 389, 389, 389, 389],
[628, 628, 628, 628, 628, 628, 628, 628],
[628, 628, 628, 628, 628, 628, 628, 628],
[628, 628, 628, 628, 628, 628, 628, 628],
[628, 628, 628, 628, 628, 628, 628, 628]])
****
tensor([[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True]])
二、扩展:
扩展内容来自一篇博客:https://blog.csdn.net/weixin_44604887/article/details/109385651.
2.1、equal – 张量比较
- 原型:equal(other)
- 比较两个张量是否相等–相等返回:True; 否则返回:False
'''
tensor调用equal方法与torch.equal是一致的
都是比较两个张量是否相等
'''
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 1, 3])
print(x.equal(y))
print(x.equal(y) == torch.equal(x, y))
2.2、eq – 逐元素判断
- 原型:eq(other)
- 比较两个张量tensor中,每一个对应位置上元素是否相等–对应位置相等,就返回一个True;否则返回一个False.
'''
逐元素进行判断是否相等
'''
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 1, 3])
print(x.eq(y))
2.3、eq_ – 将判断结果返回并替换原tensor
- 原型:eq_(other)
- 等价于tensor = tensor.eq(other); 即:将比较后的结果替换原张量的值
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 1, 3])
print(x.eq_(y))
print(x)