目录
python 张量信息提取、交互有关函数
记录码代码中常用的一些pytorch函数,提升写代码的效率。
torch.gather()
对张量a按索引张量b取值。则最后得到的张量c维度应和张量b维度相同。
a = torch.tensor([[3, 4, 5],[6, 7, 8],[9, 10, 11]])
b = torch.tensor([[2, 1, 0]])
# dim=0 按列取值(分别取第一列索引2、第二列索引1、第三列索引0的值)
c = torch.gather(a, dim=0, index=b)
print(c) #c=tensor([[9, 7, 5]])
# dim=-1 按行取值(分别取第一行索引2、第一行索引1、第一行索引0的值)
c = torch.gather(a, dim=-1, index=b)
print(c) #c=tensor([[5, 4, 3]])
torch.eq()
对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False。输入的第二个张量可以是数字或张量,可以和第一个维度不同(广播扩展维度)
与torch.equal不同,torch.eq()是逐元素对比,torch.equal()是整个张量和张量是否相等。
print(torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[ True, False],
# [False, True]])
torch.ge()/torch.gt() & torch.le()/torch.lt()
ge(input, other): 逐元素对比input>=other
gt(input, other): 逐元素对比input>other
le(input, other): 逐元素对比input<=other
lt(input, other): 逐元素对比input<other
print(torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[False, False],
#[ True, False]])
torch.eye(raw, col)
返回一个 对角线上为1,其他地方为0 的二维张量。
torch.eye(raw, col), 其中raw是必须给出的,col可不给,默认为raw
torch.eye(3)
# tensor([[ 1., 0., 0.],
# [ 0., 1., 0.],
# [ 0., 0., 1.]])
torch.masked_fill()
a.masked_fill(mask, value), 其中mask必须是一个二值张量(ByteTensor),且大小维度必须和a一样。
该函数将a中对应mask为1的值替换为value。
a = torch.tensor([1,2,3,5,2,1])
a = a[:,None]
mask = torch.eq(a, a.t()).bool()
print(mask)
#tensor([[ True, False, False, False, False, True],
# [False, True, False, False, True, False],
# [False, False, True, False, False, False],
# [False, False, False, True, False, False],
# [False, True, False, False, True, False],
# [ True, False, False, False, False, True]])
eye = torch.eye(mask.shape[0], mask.shape[1]).bool()
print(eye)
#tensor([[ True, False, False, False, False, False],
# [False, True, False, False, False, False],
# [False, False, True, False, False, False],
# [False, False, False, True, False, False],
# [False, False, False, False, True, False],
# [False, False, False, False, False, True]])
mask_pos = mask.masked_fill(eye, 0)
print(mask_pos)
#tensor([[False, False, False, False, False, True],
# [False, False, False, False, True, False],
# [False, False, False, False, False, False],
# [False, False, False, False, False, False],
# [False, True, False, False, False, False],
# [ True, False, False, False, False, False]])
torch.view()
torch.view(新维度),返回一个新的tensor,这个tensor中包含原tensor中的所有数据,只是维度不一样,维度为括号里规定的新维度。
x = torch.randn(4, 4)
# torch.Size([4, 4])
y = x.view(16)
# torch.Size([16]),但元素数量和x相同
z = x.view(-1, 8)
# -1的维度大小由别的维数推断出来,比如这里z的第1个维度给出了8,但是x的总数量为16,所以第0个维度大小为2
# torch.Size([2, 8])
torch.index_select()
torch.index_select(a, dim, index): 取出a张量中index对用索引的值。
x = torch.randn(3, 4)
#tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
# [-0.4664, 0.2647, -0.1228, -1.1068],
# [-1.1734, -0.6571, 0.7230, -0.6004]])
indices = torch.tensor([0, 2])
torch.index_select(x, 0, indices)
#tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],
# [-1.1734, -0.6571, 0.7230, -0.6004]])
torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
# [-0.4664, -0.1228],
# [-1.1734, 0.7230]])
torch.mask_select()
torch.mask_select(a, mask): 取出a中对应mask为True的值,注意最后返回的张量是一维的。
x = torch.randn(3, 4)
#tensor([[ 0.3552, -2.3825, -0.8297, 0.3477],
# [-1.2035, 1.2252, 0.5002, 0.6248],
# [ 0.1307, -2.0608, 0.1244, 2.0139]])
mask = x.ge(0.5)
#tensor([[False, False, False, False],
# [False, True, True, True],
# [False, False, False, True]])
torch.masked_select(x, mask)
# tensor([ 1.2252, 0.5002, 0.6248, 2.0139])