【Pytorch】常用函数及其用法总结

python 张量信息提取、交互有关函数

记录码代码中常用的一些pytorch函数,提升写代码的效率。

torch.gather()

对张量a按索引张量b取值。则最后得到的张量c维度应和张量b维度相同。

a = torch.tensor([[3,  4,  5],[6,  7,  8],[9, 10, 11]])
b = torch.tensor([[2, 1, 0]])
# dim=0 按列取值(分别取第一列索引2、第二列索引1、第三列索引0的值)
c = torch.gather(a, dim=0, index=b)
print(c) #c=tensor([[9, 7, 5]])
# dim=-1 按行取值(分别取第一行索引2、第一行索引1、第一行索引0的值)
c = torch.gather(a, dim=-1, index=b)
print(c) #c=tensor([[5, 4, 3]])

torch.eq()

对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False。输入的第二个张量可以是数字或张量,可以和第一个维度不同(广播扩展维度)
与torch.equal不同,torch.eq()是逐元素对比,torch.equal()是整个张量和张量是否相等。

print(torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[ True, False],
        # [False,  True]])

torch.ge()/torch.gt() & torch.le()/torch.lt()

ge(input, other): 逐元素对比input>=other
gt(input, other): 逐元素对比input>other
le(input, other): 逐元素对比input<=other
lt(input, other): 逐元素对比input<other

print(torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])))
#tensor([[False, False],
        #[ True, False]])

torch.eye(raw, col)

返回一个 对角线上为1,其他地方为0 的二维张量。
torch.eye(raw, col), 其中raw是必须给出的,col可不给,默认为raw

torch.eye(3)
# tensor([[ 1.,  0.,  0.],
       # [ 0.,  1.,  0.],
       # [ 0.,  0.,  1.]])

torch.masked_fill()

a.masked_fill(mask, value), 其中mask必须是一个二值张量(ByteTensor),且大小维度必须和a一样。
该函数将a中对应mask为1的值替换为value。

a = torch.tensor([1,2,3,5,2,1])
a = a[:,None]
mask = torch.eq(a, a.t()).bool()
print(mask)
#tensor([[ True, False, False, False, False,  True],
       # [False,  True, False, False,  True, False],
       # [False, False,  True, False, False, False],
       # [False, False, False,  True, False, False],
       # [False,  True, False, False,  True, False],
       # [ True, False, False, False, False,  True]])
eye = torch.eye(mask.shape[0], mask.shape[1]).bool()
print(eye)
#tensor([[ True, False, False, False, False, False],
       # [False,  True, False, False, False, False],
       # [False, False,  True, False, False, False],
       # [False, False, False,  True, False, False],
       # [False, False, False, False,  True, False],
       # [False, False, False, False, False,  True]])
mask_pos = mask.masked_fill(eye, 0)
print(mask_pos)
#tensor([[False, False, False, False, False,  True],
       # [False, False, False, False,  True, False],
       # [False, False, False, False, False, False],
       # [False, False, False, False, False, False],
       # [False,  True, False, False, False, False],
       # [ True, False, False, False, False, False]])

torch.view()

torch.view(新维度),返回一个新的tensor,这个tensor中包含原tensor中的所有数据,只是维度不一样,维度为括号里规定的新维度。

x = torch.randn(4, 4)
# torch.Size([4, 4])
y = x.view(16)
# torch.Size([16]),但元素数量和x相同
z = x.view(-1, 8) 
# -1的维度大小由别的维数推断出来,比如这里z的第1个维度给出了8,但是x的总数量为16,所以第0个维度大小为2
# torch.Size([2, 8])

torch.index_select()

torch.index_select(a, dim, index): 取出a张量中index对用索引的值。

x = torch.randn(3, 4)
#tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
       # [-0.4664,  0.2647, -0.1228, -1.1068],
       # [-1.1734, -0.6571,  0.7230, -0.6004]])
indices = torch.tensor([0, 2])
torch.index_select(x, 0, indices)
#tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
       # [-1.1734, -0.6571,  0.7230, -0.6004]])
torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
       # [-0.4664, -0.1228],
       # [-1.1734,  0.7230]])

torch.mask_select()

torch.mask_select(a, mask): 取出a中对应mask为True的值,注意最后返回的张量是一维的。

x = torch.randn(3, 4)
#tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
       # [-1.2035,  1.2252,  0.5002,  0.6248],
       # [ 0.1307, -2.0608,  0.1244,  2.0139]])
mask = x.ge(0.5)
#tensor([[False, False, False, False],
       # [False, True, True, True],
       # [False, False, False, True]])
torch.masked_select(x, mask)
# tensor([ 1.2252,  0.5002,  0.6248,  2.0139])
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值