深度学习Pytorch框架Tensor张量

点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达da1d87f1e57f6132333d52e7aac30e5a.png

作者 | 秦一@知乎(已授权)

来源 | https://zhuanlan.zhihu.com/p/399350505

编辑 | 极市平台

导读

 

本文主要介绍了Tensor的裁剪运算、索引与数据筛选、组合/拼接、切片、变形操作、填充操作和Tensor的频谱操作(傅里叶变换)。

1 Tensor的裁剪运算

  • 对Tensor中的元素进行范围过滤

  • 常用于梯度裁剪(gradient clipping),即在发生梯度离散或者梯度爆炸时对梯度的处理

  • torch.clamp(input, min, max, out=None) → Tensor:将输入input张量每个元素的夹紧到区间 [min,max],并返回结果到一个新张量。

64c9a5dfe6efcece73ea027c1e7aada6.png

2 Tensor的索引与数据筛选

  • torch.where(codition,x,y):按照条件从x和y中选出满足条件的元素组成新的tensor,输入参数condition:条件限制,如果满足条件,则选择a,否则选择b作为输出。

  • torch.gather(input,dim,index,out=None):在指定维度上按照索引赋值输出tensor

  • torch.inex_select(input,dim,index,out=None):按照指定索引赋值输出tensor

  • torch.masked_select(input,mask,out=None):按照mask输出tensor,输出为向量

  • torch.take(input,indices):将输入看成1D-tensor,按照索引得到输出tensor

  • torch.nonzero(input,out=None):输出非0元素的坐标

import torch
#torch.where

a = torch.rand(4, 4)
b = torch.rand(4, 4)

print(a)
print(b)

out = torch.where(a > 0.5, a, b)

print(out)
7ebac0f09bb0423b37613ea72b983caf.png
print("torch.index_select")
a = torch.rand(4, 4)
print(a)
out = torch.index_select(a, dim=0,
                   index=torch.tensor([0, 3, 2]))
#dim=0按列,index取的是行
print(out, out.shape)
e79b9e62601cfa457f1f3f66601160fd.png
print("torch.gather")
a = torch.linspace(1, 16, 16).view(4, 4)

print(a)

out = torch.gather(a, dim=0,
             index=torch.tensor([[0, 1, 1, 1],
                                 [0, 1, 2, 2],
                                 [0, 1, 3, 3]]))
print(out)
print(out.shape)
#注:从0开始,第0列的第0个,第一列的第1个,第二列的第1个,第三列的第1个,,,以此类推
#dim=0, out[i, j, k] = input[index[i, j, k], j, k]
#dim=1, out[i, j, k] = input[i, index[i, j, k], k]
#dim=2, out[i, j, k] = input[i, j, index[i, j, k]]
a5e80ed3f8e6926f742fe6e9b067cf78.png
print("torch.masked_index")
a = torch.linspace(1, 16, 16).view(4, 4)
mask = torch.gt(a, 8)
print(a)
print(mask)
out = torch.masked_select(a, mask)
print(out)
85036698b8b1b7de16a69ace09be29ea.png
print("torch.take")
a = torch.linspace(1, 16, 16).view(4, 4)

b = torch.take(a, index=torch.tensor([0, 15, 13, 10]))

print(b)
4af2352fd9a09d694c52aa542df7f1f4.png
#torch.nonzero
print("torch.take")
a = torch.tensor([[0, 1, 2, 0], [2, 3, 0, 1]])
out = torch.nonzero(a)
print(out)
#稀疏表示
b5daf6bdfcd870ece9ab3b8aae10d391.png

3 Tensor的组合/拼接

  • torch.cat(seq,dim=0,out=None):按照已经存在的维度进行拼接

  • torch.stack(seq,dim=0,out=None):沿着一个新维度对输入张量序列进行连接。序列中所有的张量都应该为相同形状。

print("torch.stack")
a = torch.linspace(1, 6, 6).view(2, 3)
b = torch.linspace(7, 12, 6).view(2, 3)
print(a, b)
out = torch.stack((a, b), dim=2)
print(out)
print(out.shape)

print(out[:, :, 0])
print(out[:, :, 1])
e70c1210576cbfe9038cba42d06fa9d4.png

4 Tensor的切片

  • torch.chunk(tensor,chunks,dim=0):按照某个维度平均分块(最后一个可能小于平均值)

  • torch.split(tensor,split_size_or_sections,dim=0):按照某个维度依照第二个参数给出的list或者int进行分割tensor

5 Tensor的变形操作

  • torch().reshape(input,shape)

  • torch().t(input):只针对2D tensor转置

  • torch().transpose(input,dim0,dim1):交换两个维度

  • torch().squeeze(input,dim=None,out=None):去除那些维度大小为1的维度

  • torch().unbind(tensor,dim=0):去除某个维度

  • torch().unsqueeze(input,dim,out=None):在指定位置添加维度,dim=-1在最后添加

  • torch().flip(input,dims):按照给定维度翻转张量

  • torch().rot90(input,k,dims):按照指定维度和旋转次数进行张量旋转

import torch
a = torch.rand(2, 3)
print(a)
out = torch.reshape(a, (3, 2))
print(out)
2a79c8410a4e2fb84048c14a81a088f1.png
print(a)
print(torch.flip(a, dims=[2, 1]))

print(a)
print(a.shape)
out = torch.rot90(a, -1, dims=[0, 2]) #顺时针旋转90°  
print(out)
print(out.shape)
72c4bfe8cdc1219ec5887dbfd8c73cf5.png

6 Tensor的填充操作

  • torch.full((2,3),3.14)

7 Tensor的频谱操作(傅里叶变换)

c93e20827928f9f78185d69b24134d76.png

如果觉得有用,就请分享到朋友圈吧!

8960c7a7d17b5fd5bec8cc3ff7f3aa1d.png

outside_default.png

点个在看 paper不断!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值