1.torch.max()
b=torch.randn(3,4)
print("b:\n",b)
print('b.max(0):',b.max(0))
print('b.max(1):',b.max(1))
print('b.max(1)[0]:',b.max(1)[0])
print('b.max(1)[1]:',b.max(1)[1])
print('b.max(0)[0]:',b.max(0)[0])
print('b.max(0)[1]:',b.max(0)[1])
print('torch.max(b,0)[0]:',torch.max(b,0)[0])
print('torch.max(b,0)[1]:',torch.max(b,0)[1])
#torch.max(tensor,a)[b]中
#tensor代表想要进行操作的张量
#(a)表示想要进行操作的维度,如本例子为3*4的张量,维度为2。但是维度是从0开始索引的。
#例如对一个3*4*5的三维张量,a=0表示输出一个4*5的二维张量,进行比较的都是位于第一维的数据,详见下面的示例
#[b]表示想要输出的东西,0表示输出所得的最大值组成的张量,1表示输出所得的最大值对应的索引
b=torch.randn(3,4,5)
print("b:\n",b)
print('b.max(0)[0]:',b.max(0)[0])
print('b.max(0)[1]:',b.max(0)[1])
print('b.max(1)[0]:',b.max(1)[0])
print('b.max(1)[1]:',b.max(1)[1])
print('b.max(2)[0]:',b.max(2)[0])
print('b.max(2)[1]:',b.max(2)[1])
结果:
b:
tensor([[ 0.8548, 0.8018, -1.8448, -1.9834],
[-0.1640, 0.1191, -0.1988, 1.4577],
[ 0.6882, 1.3879, -0.9425, 0.5111]])
b.max(0): torch.return_types.max(
values=tensor([ 0.8548, 1.3879, -0.1988, 1.4577]),
indices=tensor([0, 2, 1, 1]))
b.max(1): torch.return_types.max(
values=tensor([0.8548, 1.4577, 1.3879]),
indices=tensor([0, 3, 1]))
b.max(1)[0]: tensor([0.8548, 1.4577, 1.3879])
b.max(1)[1]: tensor([0, 3, 1])
b.max(0)[0]: tensor([ 0.8548, 1.3879, -0.1988, 1.4577])
b.max(0)[1]: tensor([0, 2, 1, 1])
torch.max(b,0)[0]: tensor([ 0.8548, 1.3879, -0.1988, 1.4577])
torch.max(b,0)[1]: tensor([0, 2, 1, 1])
b:
tensor([[[-1.0886, -0.2318, 0.5687, 0.6992, 0.7993],
[-1.6488, -0.4074, 0.7266, -1.5632, -0.2511],
[ 0.2802, -0.1441, -0.0184, -1.2758, 1.0951],
[ 0.1700, -1.6584, -1.8804, 1.1791, -0.3551]],
[[-0.6114, -0.8189, 0.3983, -0.3744, 1.0545],
[ 1.1177, 0.5962, 0.5073, 0.3863, 0.9423],
[ 0.0200, -1.0322, 0.1329, -0.3569, 0.3295],
[-1.0800, -0.5290, -1.7301, 0.0407, 0.4790]],
[[ 0.6548, -0.6043, 1.7130, -0.9530, 0.7656],
[-0.1233, -1.5174, 0.8852, -0.2953, -0.1956],
[-0.6458, -1.4962, -1.4248, 0.6494, 0.8585],
[-1.9865, -2.4318, 1.4741, -0.9116, -0.2791]]])
b.max(0)[0]: tensor([[ 0.6548, -0.2318, 1.7130, 0.6992, 1.0545],
[ 1.1177, 0.5962, 0.8852, 0.3863, 0.9423],
[ 0.2802, -0.1441, 0.1329, 0.6494, 1.0951],
[ 0.1700, -0.5290, 1.4741, 1.1791, 0.4790]])
b.max(0)[1]: tensor([[2, 0, 2, 0, 1],
[1, 1, 2, 1, 1],
[0, 0, 1, 2, 0],
[0, 1, 2, 0, 1]])
b.max(1)[0]: tensor([[ 0.2802, -0.1441, 0.7266, 1.1791, 1.0951],
[ 1.1177, 0.5962, 0.5073, 0.3863, 1.0545],
[ 0.6548, -0.6043, 1.7130, 0.6494, 0.8585]])
b.max(1)[1]: tensor([[2, 2, 1, 3, 2],
[1, 1, 1, 1, 0],
[0, 0, 0, 2, 2]])
b.max(2)[0]: tensor([[0.7993, 0.7266, 1.0951, 1.1791],
[1.0545, 1.1177, 0.3295, 0.4790],
[1.7130, 0.8852, 0.8585, 1.4741]])
b.max(2)[1]: tensor([[4, 2, 4, 3],
[4, 0, 4, 4],
[2, 2, 4, 2]])
2.torch.gather()
在1中提到,[b]表示想要输出的东西,0表示输出所得的最大值组成的张量,1表示输出所得的最大值对应的索引张量。
而torch.gather(input, dim, index,out=None)中,input是想要抓取数据的原始张量,dim是要索引的维度,index就对应想要抓取的元素的索引(与上文提到的b=1时的输出是一个东西)。
给出示例:
b = torch.Tensor([[1,2,3],[4,5,6]])
print(b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print(torch.gather(b, dim=1, index=index_1))
print(torch.gather(b, dim=0, index=index_2))
结果:
tensor([[1., 2., 3.], [4., 5., 6.]]) tensor([[1., 2.], [6., 4.]]) tensor([[1., 5., 6.], [1., 2., 3.]])