test
import torch
nums = 3
# 首先,我们随机生成一个3*5的矩阵
a = torch.randn(nums, 5)
# 在列的维度(dim=0),取每列的最大值
overlap_for_each_prior, object_for_each_prior = a.max(dim=0)
_, prior_for_each_object = a.max(dim=1) # (N_o)
print(a)
print(overlap_for_each_prior)
print('++++++++++++++++++')
print(object_for_each_prior)
print(prior_for_each_object)
##################################################
##################################################
'''这里是关键'''
print(object_for_each_prior[prior_for_each_object])
'''这里是关键'''
##################################################
##################################################
print(torch.LongTensor(range(nums)))
print('++++++++++++++++++')
# 这一步骤的操作是以prior_for_each_object为索引,从
# object_for_each_prior取对应的值
object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(nums, 2*nums))
print('++++++++++++++++++')
print(object_for_each_prior)
print('++++++++++++++++++')
代码解析:
# 首先,我们随机生成一个3*5的矩阵
a = torch.randn(nums, 5)
# 在列的维度(dim=0),取每列的最大值
overlap_for_each_prior, object_for_each_prior = a.max(dim=0)
_, prior_for_each_object = a.max(dim=1) # (N_o)
这个时候,假设
a的值为:
'''
tensor([[-0.1705, 1.2972, 1.8852, -1.0077, -0.6337],
[ 1.5984, -0.6461, 0.3798, -0.4751, 0.9754],
[ 0.7052, 0.4189, 0.1964, 1.0021, 1.6406]])
'''
那么
overlap_for_each_prior:
'''
tensor([1.5984, 1.2972, 1.8852, 1.0021, 1.6406])
'''
object_for_each_prior:
'''
tensor([1, 0, 0, 2, 2])
'''
prior_for_each_object
'''
tensor([2, 0, 4])
'''
然后,神一样的操作来了:
object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(nums, 2*nums))
- 首先:
'''
torch.LongTensor(range(nums, 2*nums))
'''
生成了一个Tensor数组:
'''
tensor([3, 4, 5])
'''
- 然后:
object_for_each_prior:
'''
tensor([1, 0, 0, 2, 2])
'''
在tensor([2, 0, 4])的作用下变成了(取2, 0, 4对应的索引值)
'''
tensor([0, 1, 2])
'''
然后把object_for_each_prior中2,0,4对应的值换成:3,4,5.
所以,最后的结果就是
tensor([4, 0, 3, 2, 5])