Pytorch1.5入门笔记

此笔记仅供本人日后回顾使用
参考:

1.数据类型

Tensors 类似于 NumPy 的 ndarrays ,同时 Tensors 可以使用 GPU 进行计算。
在这里插入图片描述

1.1 类型检查

Pytorch数据类型的检查可以通过三个方式:

1)python内置函数type()

2)Tensor的成员函数Tensor.type()

3)Pytorch提供的工具函数isinstance()

import torch
 
a = torch.randn(2, 3)  # 2行3列,正态分布~N(0,1)
print(a)
print(type(a))
print(a.type())
print(isinstance(a, torch.FloatTensor))
tensor([[-0.7467,  1.3035,  0.0909],
        [-0.7687,  1.4444,  0.7397]])
<class 'torch.Tensor'>
torch.FloatTensor
True

从运行结果来看,python内置的类型检测函数type()只能检查该数据是Tensort类型,具体的基本数据类型无法检测出来,而Tensor的成员函数type()更加直观,可以检测出Tensor数据的基本类型。

isinstance()函数主要用于判断某数据是否属于某个数据类型,如果属于返回True,否则返回False

1.2 数据类型转换

Tensor类型的变量进行类型转换一般有两种方法:

1)Tensor类型的变量直接调用long(), int(), double(),float(),byte()等函数就能将Tensor进行类型转换;

2)在Tensor成员函数type()中直接传入要转换的数据类型。

当你不知道要转换为什么类型时,但需要求a1,a2两个张量的乘积,可以使用a1.type_as(a2)将a1转换为a2同类型

import torch
 
a = torch.randn(2, 3)
print(a.type())
 
# 转换为IntTensort类型
b = a.int()
 
# 转换为LongTensor类型
c = a.type(torch.LongTensor)
 
print(b.type())
print(c.type())
 
# 将a转换为与b相同的类型
a.type_as(b)
print(a.type())
torch.FloatTensor
torch.IntTensor
torch.LongTensor
torch.FloatTensor

1.3 Tensor与Numpy ndarray之间的转换

Tensor和numpy.ndarray之间还可以相互转换,其方式如下:

1)Numpy转化为Tensor:torch.from_numpy(numpy矩阵)

2)Tensor转化为numpy:Tensor矩阵.numpy()

import torch
import numpy as np
 
# 定义一个3行2列的全为0的矩阵
b = torch.randn((3, 2))
 
# tensor转化为numpy
numpy_b = b.numpy()
print(numpy_b)
 
# numpy转化为tensor
numpy_e = np.array([[1, 2], [3, 4], [5, 6]])
torch_e = torch.from_numpy(numpy_e)
 
print(numpy_e)
print(torch_e)
[[-3.4111385   0.28919587]
 [-0.527978    0.3108479 ]
 [-0.6355707  -0.47348464]]
[[1 2]
 [3 4]
 [5 6]]
tensor([[1, 2],
        [3, 4],
        [5, 6]], dtype=torch.int32)

1.4 CPU或GPU张量之间的转换

  1. CPU张量 ----> GPU张量, 使用Tensor.cuda()

  2. GPU张量 ----> CPU张量 使用Tensor.cpu()

我们可以通过torch.cuda.is_available()函数来判断当前的环境是否支持GPU,如果支持,则返回True。所以,为保险起见,在项目代码中一般采取“先判断,后使用”的策略来保证代码的正常运行,其基本结构如下:

import torch
 
# 定义一个3行2列的全为0的矩阵
tmp = torch.randn((3, 2))
 
# 如果支持GPU,则定义为GPU类型
if torch.cuda.is_available():
    inputs = tmp.cuda()
# 否则,定义为一般的Tensor类型
else:
    inputs = tmp

1.5 其他:向量相关

torch.tensor([ 数据 ])

torch.FloatTensor(维度)

1维的形状如何得到:

  • .size
  • .shape

几个概念:

dim:指的是size/shape的长度

size/shape指的是具体的形状

场景 : CNN

[b, c, h, w] b:几张照片 c: 通道 w:宽 h:高度

numel() 占用内存的大小

x = torch.empty(5,3) #构造一个5x3矩阵,不初始化
print(x)
tensor([[4.9971e-27, 7.4689e-43, 4.9969e-27],
        [7.4689e-43, 4.9971e-27, 7.4689e-43],
        [4.9971e-27, 7.4689e-43, 4.9970e-27],
        [7.4689e-43, 4.9970e-27, 7.4689e-43],
        [4.9981e-27, 7.4689e-43, 4.9981e-27]])
a = torch.randn(2,3) #随机生成一个正态分布的tensor
type(a) # 输出类型

torch.Tensor
len(a.shape)
2
torch.tensor(1.)# 0维度
tensor(1.)
torch.tensor([1.1,1.2])# 1维度
tensor([1.1000, 1.2000])
torch.FloatTensor(1) # 1维度, 给定长度为1,random初始化值
tensor([1.4013e-45])
torch.FloatTensor(2) # 1维度, 给定长度为2,random初始化值
tensor([4.2039e-45, 0.0000e+00])
c.shape
torch.Size([2, 3])
c.numel() # c占用内存的大小 
6
c.dim() # 维度
4
a.dim()
1

2.创建Tensor

2.1 从Numpy创建Tensor

import torch
import numpy as np
a = np.array([2,3.3])
a
array([2. , 3.3])
torch.from_numpy(a) # Import from numpy
tensor([2.0000, 3.3000], dtype=torch.float64)

2.2 从List创建Tensor

a = torch.FloatTensor([2, 3.3])  
b = torch.tensor([2, 3.3]) 
print(a)
print(b)
tensor([2.0000, 3.3000])
tensor([2.0000, 3.3000])

注意:小写的tensor只接受现有的数据;而大写的Tensor相当于就是FloatTensor,既可以接收现有的数据,也可以接受shape来创建指定形状的Tensor。为了避免混淆,建议接收现有数据的时候使用tensor,指定shape的时候使用Tensor。

2.3 指定维度创建Tensor

注意:通过指定维度创建的Tensor,初始化的值是随机的,数据不规则,容易出现问题。

torch.empty(5,3) # 构造一个5x3矩阵,不初始化
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
torch.FloatTensor(2, 3)
tensor([[-1.2763e+38,  4.8082e-34,  1.0862e-35],
        [ 2.3595e+07,  1.0804e-42,  5.5313e-05]])
torch.IntTensor(2, 3)
tensor([[-58718223,   2086912,         2],
        [        0,         1,         0]], dtype=torch.int32)

2.4 随机初始化创建Tensor

(1) rand(shape):生成shape维度的并且随机均匀采样0~1之间的数据

(2)rand_like(Tensor):形如*_like()函数,接收一个Tensor,并根据Tensor的shape生成对应维度的随机均匀采样的数据

(3)randint(min, max, shape):生成最小值为min,最大值为max(应该是不能等于max),维度为shape的随机均匀采样的数据

(4)randn(shape):生成均值为0,方差为1,shape维度的正态分布数据

x = torch.rand(5, 3) # 构造一个均匀分布随机初始化的矩阵
print(x)
tensor([[0.9699, 0.8675, 0.6635],
        [0.8416, 0.4337, 0.6939],
        [0.9386, 0.4049, 0.3297],
        [0.3012, 0.1845, 0.5385],
        [0.9237, 0.5872, 0.9059]])
x = torch.randn(5, 3) # 构造一个正态分布*(0,1)随机初始化的矩阵
print(x)
tensor([[ 0.4568, -0.6562,  0.2477],
        [-0.1330,  0.1071, -0.9438],
        [-0.8317, -1.4119, -0.8854],
        [-0.9180,  0.8430,  0.9199],
        [-0.9091, -0.2747,  0.5647]])
a=torch.rand(3,3)
print(a)

#与a形状类似的数组
b=torch.rand_like(a)
print(b)

#随机产生1~9的数,形状为3x3
c=torch.randint(1,10,(3,3))
print(c)

#产生一个全为7的2x3的张量
d=torch.full([2,3],7)
print(d)
tensor([[0.0270, 0.9725, 0.0142],
        [0.6040, 0.9199, 0.9383],
        [0.0540, 0.0522, 0.7065]])
tensor([[0.2406, 0.8567, 0.0787],
        [0.5401, 0.3085, 0.4922],
        [0.9112, 0.2104, 0.7182]])
tensor([[5, 6, 1],
        [3, 3, 7],
        [8, 5, 1]])
tensor([[7., 7., 7.],
        [7., 7., 7.]])

2.5 使用相同元素构建Tensor & 对角阵

ones/zeros/eye

全部是0/全部是1/单位tensor

full() 使用的相同的元素

注意:shape的指定方式是list方式。

a = torch.full([2,3],7)# 创建(2,3)的Tensor,元素都为7
a
tensor([[7., 7., 7.],
        [7., 7., 7.]])
x = torch.zeros(5, 3, dtype=torch.long)# 构造一个矩阵全为 0,而且数据类型是 long.
print(x)
print(x.type())
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])
torch.LongTensor
x = torch.ones(5, 3, dtype=torch.double)
# new_* methods take in sizes
print(x)
x = torch.randn_like(x, dtype=torch.float)# 创建一个 tensor 基于已经存在的 tensor
# override dtype!
print(x)
# result has the same size
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)
tensor([[ 0.9671,  0.3310,  0.2566],
        [ 1.2472,  0.4178,  0.6680],
        [ 0.2135,  0.0991, -0.8704],
        [ 0.5057, -0.5526, -1.1499],
        [ 1.1014, -2.0613,  0.2811]])
c = torch.eye(3, 4)  # 只能是二维的,传入dim=2的shape
c
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.]])

也可以只给一个参数n,得到n阶的对角方阵:

c = torch.eye(4)  
print(c)
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
print(x.size()) #维度信息
x.shape
torch.Size([5, 3])





torch.Size([5, 3])

2.6 指定参数的正态分布Tensor:

# 指定均值和标准差
a = torch.normal(mean=torch.full([10], 0), std=torch.arange(1, 0, -0.1))
a
tensor([ 0.8948,  1.0663, -0.0016,  0.2452, -0.3059,  0.3745,  0.4634,  0.0153,
         0.1452, -0.0325])

上面参数mean = [0, 0, 0, 0, 0, 0, 0, 0, 0,0],std = [1, 0.9, … , 0.1],得到的结果是1x10的Tensor,如果想得到其它shape的正态分布,需要在1x10的基础上reshape 为其它维度。

2.7 有序数列Tensor

a = torch.arange(0, 10) # 不包含10,默认步长为1
print(a)
torch.arange(0,10,2) # 生成一维向量,元素的值填充为0-10,步长为2
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])





tensor([0, 2, 4, 6, 8])

注意torch.range()是包含结尾的,但是已经被弃用了,一律用arange。

2.8 set_default_tensor_type设置默认tensor生成的数据类型

使用torch.tensor传入浮点数元素,或者使用torch.Tensor仅指定维度时,生成的默认是FloatTensor,也可以修改默认设置使其默认是其它类型的。

torch.tensor([1.2,3]).type() # 默认torch.FloatTensor
'torch.FloatTensor'
torch.set_default_tensor_type(torch.DoubleTensor) # 改变默认类型
torch.tensor([1.2,3]).type() # 默认类型已改变
'torch.DoubleTensor'
torch.set_default_tensor_type(torch.FloatTensor)

2.9 生成等分序列Tensor

torch.linspace()与torch.logspace()

torch.linspace(start, end, steps=100, out=None) → Tensor

返回一个1维张量,包含在区间start和end上均匀间隔的step个点。

输出张量的长度由steps决定。

参数:

start (float) - 区间的起始点

end (float) - 区间的终点

steps (int) - 在start和end间生成的样本数

out (Tensor, optional) - 结果张量

logspace(n, m, step=s)

从10的n次方取到10的m次方,指数是等差的,也就是元素值是等比的。

torch.linspace(0,10,steps=4) 
tensor([ 0.0000,  3.3333,  6.6667, 10.0000])
#生成10的0次方为起始值,10的-1次方为终止值的8个数构成的等比数列
c = torch.logspace(0,-1,steps=8)
c
tensor([1.0000, 0.7197, 0.5179, 0.3728, 0.2683, 0.1931, 0.1389, 0.1000])

2.10 randperm()

类似 random.shuffle

使用randperm可以生成一个从0开始的、已经打乱的连续索引Tensor,用它可以对其它Tensor做shuffle。特别是在有几个需要保持一致顺序的Tensor时,用相同的索引Tensor就能保证shuffle之后的Tensor在那个维度上的顺序一致了。

import torch
 
# 两个Tensor的shape[0]是相同的,都是3
a = torch.rand(3, 1)
b = torch.rand(3, 1)
print(a, b, sep='\n')
print("-" * 20)
 
# 制造一个[0,3)的索引序列
idx = torch.randperm(3)
print(idx)
print("-" * 20)
 
# 给a,b做shuffle,保证第一个维度在shuffle后的对应关系不变
a_sf = a[idx]
b_sf = b[idx]
print(a_sf, b_sf, sep='\n')
tensor([[0.7408],
        [0.1081],
        [0.4498]])
tensor([[0.9441],
        [0.8759],
        [0.3226]])
--------------------
tensor([1, 0, 2])
--------------------
tensor([[0.1081],
        [0.7408],
        [0.4498]])
tensor([[0.8759],
        [0.9441],
        [0.3226]])

3.索引

3.1 Pytorch风格的索引

import torch
 
a = torch.rand(4, 3, 28, 28)
print(a[0].shape) #取到第一个维度
print(a[0, 0].shape) # 取到二个维度
print(a[1, 2, 2, 4])  # 具体到某个元素
torch.Size([3, 28, 28])
torch.Size([28, 28])
tensor(0.5277)

上述代码创建了一个shape=[4, 3, 28, 28]的Tensor,我们可以理解为4张图片,每张图片有3个通道,每个通道是28x28的图像数据。

a代表这个Tensor,a后面跟着的列表[]表示对Tensor进行索引,a的维度dim = 4,决定了[]中的元素个数不能超过4个,[]中的值表示对应维度上的哪一个元素,比如 a[0]表示取第一个维度上的第一个元素,可以理解为第一张图片,a[1]表示取第一个维度上的第二个元素,可以理解为第二张图片。

a[0, 0]表示取第一个维度上第一个元素的与第二个维度上的第一个元素,也就是第一张图片第一个通道的元素。a[1, 2, 2, 4]表示取第第一个维度上的第二个元素与第二个维度上的第三个元素与第三个维度上的第三个元素与第四个维度上的第5个元素,也就是第二张图片第三个通道第三行第四列的像素值是一个标量值。


3.2 index_select()选择特定索引

具体的索引:index_select(dim, indices)

dim为维度,indices是索引序号

这里的indeces必须是tensor ,不能直接是一个list

a.shape
torch.Size([4, 3, 28, 28])
 a.index_select(0, torch.tensor([0,2])).shape	# 当前维度为0,取第0,2张图片
torch.Size([2, 3, 28, 28])
a.index_select(1, torch.tensor([1,2])).shape   	# 当前维度为1,取第1,2个通道
torch.Size([4, 2, 28, 28])
a.index_select(2,torch.arange(28)).shape		# 第二个参数,取28行【0,28】
torch.Size([4, 3, 28, 28])
a.index_select(2, torch.arange(8)).shape		# 取8行  [0,8)
torch.Size([4, 3, 8, 28])

选择特定下标有时候很有用,比如上面的a这个Tensor可以看作4张RGB(3通道)的MNIST图像,长宽都是28px。那么在第一维度上可以选择特定的图片,在第二维度上选择特定的通道,在第三维度上选择特定的行等:

# 选择第一张和第三张图
print(a.index_select(0, torch.tensor([0, 2])).shape)
 
# 选择R通道和B通道
print(a.index_select(1, torch.tensor([0, 2])).shape)
 
# 选择图像的0~8行
print(a.index_select(2, torch.arange(8)).shape)
torch.Size([2, 3, 28, 28])
torch.Size([4, 2, 28, 28])
torch.Size([4, 3, 8, 28])

3.3 使用 … 索引任意多的维度

表示任意多维度,根据实际的shape来推断。

当有 … 出现时,右边的索引理解为最右边

为什么会有它,没有它的话,存在这样一种情况 a[0,: ,: ,: ,: ,: ,: ,: ,: ,: ,2] 只对最后一个维度做了限度,这个向量的维度又很高,以前的方式就不太方便了。

 a[...].shape		# 所有维度
torch.Size([4, 3, 28, 28])
a[0,...].shape		# 后面都有,取第0个图片 = a[0]
torch.Size([3, 28, 28])
a[:,1,...].shape
torch.Size([4, 28, 28])
a[...,:2].shape		# 当有...出现时,右边的索引理解为最右边,只取两列
torch.Size([4, 3, 28, 2])

3.4 使用mask索引

masked_select()

求掩码位置原来的元素大小

缺点:会把数据,默认打平(flatten),

x = torch.randn(3,4)
x
tensor([[ 0.4645, -0.9875, -1.7468, -0.5176],
        [ 0.3795, -1.1566, -0.8264,  1.2487],
        [ 1.7771,  0.8601,  1.5568, -1.3688]])
mask = x.ge(0.5)          # >= 0.5 的元素的位置上为True,其余地方为False
mask
tensor([[False, False, False, False],
        [False, False, False,  True],
        [ True,  True,  True, False]])
torch.masked_select(x,mask) # 之所以打平是因为大于0.5的元素个数是根据内容才能确定的
tensor([1.2487, 1.7771, 0.8601, 1.5568])
torch.masked_select(x,mask).shape
torch.Size([4])

3.5 take索引

take索引是在原来Tensor的shape基础上打平,然后在打平后的Tensor上进行索引

torch.take(src, torch.tensor([index]))

打平后,按照index来取对应位置的元素

src = torch.tensor([[4,3,5],[6,7,8]])		# 先打平成1维的,共6列
src
tensor([[4, 3, 5],
        [6, 7, 8]])
torch.take(src, torch.tensor([0, 2, 5]))	# 取打平后编码,位置为0 2 5
tensor([4, 5, 8])

4.切片(python风格)

顾头不顾尾,起始位置计入切片,截止位置不计入切片,“:”为默认取该位置/维度所有内容

a.shape
torch.Size([4, 3, 28, 28])
a[:2].shape # 前面两张图片的所有数据
torch.Size([2, 3, 28, 28])
 a[:2,:1,:,:].shape      # 前面两张图片的第0通道的数据  
torch.Size([2, 1, 28, 28])
a[:2,1:,:,:].shape		# 前面两张图片,第1,2通道的数据
torch.Size([2, 2, 28, 28])
a[:2,-1:,:,:].shape		# 前面两张图片,-1表示最后一个通道的数据  从-1到最末尾,就是它本身
torch.Size([2, 1, 28, 28])

顾头不顾尾 + 步长

start : end : step

对于步长为1的,通常就省略了。

a[:,:,0:28,0:28:2].shape    # 隔点采样
torch.Size([4, 3, 28, 14])
a[:,:,::2,::2].shape
torch.Size([4, 3, 14, 14])

5. 维度变换

维度变化改变的是数据的理解方式!

  • view/reshape:大小不变的条件下,转变shape
  • squeeze/unsqueeze:减少/增加维度
  • transpose/t/permute:转置,单次/多次交换
  • expand/repeat:维度扩展

5.1 改变shape --[view/reshape]

  • 在pytorch0.3的时候,默认是view .为了与numpy一致0.4以后增加了reshape。
  • 损失维度信息,如果不额外存储/记忆的话,恢复时会出现问题。
  • 执行view/reshape是有一定的物理意义的,不然不会这样做。
  • 保证tensor的size不变即可/numel()一致/元素个数不变。
  • 数据的存储/维度顺序非常非常非常重要
 a = torch.rand(4,1,28,28)
 a.shape
torch.Size([4, 1, 28, 28])
a.view(4,28*28)# 4, 1*28*28 将后面的进行合并/合并通道,长宽,忽略了通道(channel为1)信息,上下左右的
tensor([[0.4669, 0.5704, 0.7877,  ..., 0.1791, 0.5174, 0.3081],
        [0.4520, 0.8231, 0.7201,  ..., 0.4881, 0.8215, 0.4224],
        [0.0197, 0.2802, 0.6098,  ..., 0.8035, 0.7624, 0.8517],
        [0.1558, 0.3465, 0.9982,  ..., 0.1680, 0.4229, 0.8472]])
a.view(4,28*28).shape
torch.Size([4, 784])
a.view(4*28, 28).shape    # 合并batch,channel,行合并 放在一起为N [N,28] 每个N,刚好有28个像素
torch.Size([112, 28])
a.view(4*1,28,28).shape	# 4张叠起来了
torch.Size([4, 28, 28])
# ❌错误示范
b = a.view(4,784)  # a原来的维度信息是[b,c,h,w],但a这样赋值后,它是恢复不到原来的
b.view(4,28,28,1)  # logic Bug  # 语法上没有问题,但逻辑上 [b h w c] 与以前是不对应的。
tensor([[[[0.4669],
          [0.5704],
          [0.7877],
          ...,
          [0.5856],
          [0.3245],
          [0.6232]],

         [[0.1831],
          [0.2813],
          [0.6289],
          ...,
          [0.4807],
          [0.7636],
          [0.3489]],

         [[0.2941],
          [0.5808],
          [0.9420],
          ...,
          [0.9006],
          [0.7880],
          [0.1321]],

         ...,

         [[0.5025],
          [0.7782],
          [0.3745],
          ...,
          [0.4442],
          [0.2080],
          [0.4459]],

         [[0.1220],
          [0.3747],
          [0.2139],
          ...,
          [0.5086],
          [0.5290],
          [0.4658]],

         [[0.7972],
          [0.4315],
          [0.9320],
          ...,
          [0.1791],
          [0.5174],
          [0.3081]]],


​ [[[0.4520],
​ [0.8231],
​ [0.7201],
​ …,
​ [0.4736],
​ [0.8436],
​ [0.2378]],

​ [[0.6897],
​ [0.2226],
​ [0.0204],
​ …,
​ [0.5949],
​ [0.0052],
​ [0.4773]],

[[0.4987],
[0.4569],
[0.5842],
…,
[0.0182],
[0.7573],
[0.6908]],

         ...,

         [[0.0894],
          [0.0050],
          [0.3347],
          ...,
          [0.8968],
          [0.7086],
          [0.6496]],

         [[0.6287],
          [0.7542],
          [0.0427],
          ...,
          [0.4744],
          [0.3224],
          [0.5295]],

         [[0.7941],
          [0.7476],
          [0.3190],
          ...,
          [0.4881],
          [0.8215],
          [0.4224]]],


​ [[[0.0197],
​ [0.2802],
​ [0.6098],
​ …,
​ [0.5006],
​ [0.5052],
​ [0.8960]],

​ [[0.9132],
​ [0.6296],
​ [0.1584],
​ …,
​ [0.2940],
​ [0.4965],
​ [0.7288]],

[[0.0660],
[0.0540],
[0.6224],
…,
[0.8429],
[0.6459],
[0.0344]],

         ...,

         [[0.1814],
          [0.6548],
          [0.7547],
          ...,
          [0.4313],
          [0.7968],
          [0.8281]],

         [[0.1409],
          [0.7694],
          [0.8473],
          ...,
          [0.7594],
          [0.4243],
          [0.1088]],

         [[0.3739],
          [0.2026],
          [0.9536],
          ...,
          [0.8035],
          [0.7624],
          [0.8517]]],


​ [[[0.1558],
​ [0.3465],
​ [0.9982],
​ …,
​ [0.5428],
​ [0.5633],
​ [0.5330]],

​ [[0.4698],
​ [0.6095],
​ [0.2370],
​ …,
​ [0.4649],
​ [0.5074],
​ [0.8569]],

[[0.8299],
[0.8584],
[0.8836],
…,
[0.1807],
[0.5713],
[0.8904]],

         ...,

         [[0.9641],
          [0.7628],
          [0.2850],
          ...,
          [0.2401],
          [0.4738],
          [0.1625]],

         [[0.5041],
          [0.1227],
          [0.5483],
          ...,
          [0.1078],
          [0.0815],
          [0.7116]],

         [[0.5174],
          [0.2978],
          [0.5122],
          ...,
          [0.1680],
          [0.4229],
          [0.8472]]]])

5.2 squeeze 与 unsqueeze 增删维度

5.2.1 unsqueeze

torch.unsqueeze(index)可以为Tensor增加一个维度,增加的这一个维度的位置由我们自己定义,新增加的这一个维度不会改变数据本身,只是为数据新增加了一个组别,这个组别是什么由我们自己定义。

  • unsqueeze(index) 拉伸(增加一个维度) (增加一个组别)

  • 参数的范围是 [-a.dim()-1, a.dim()+1) 如下面例子中范围是[-5,5)

  • -5 –> 0 … -1 –> 4 这样的话,0表示在前面插入,-1表示在后面插入,正负会有些混乱,所以推荐用正数。

  • 0与正数,就是在xxx前面插入。

a.shape
torch.Size([4, 1, 28, 28])

这个Tensor有4个维度,我们可以在现有维度的基础上插入一个新的维度,插入维度的index在[-a.dim()-1, a.dim()+1]范围内,并且当index>=0,则在index前面插入这个新增加的维度;当index < 0,则在index后面插入这个新增的维度。

a.unsqueeze(0).shape	# 在0的前面插入一个维度
torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-1).shape	# 在-1之后插入一个维度
torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(4).shape
torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(-4).shape
torch.Size([4, 1, 1, 28, 28])
a.unsqueeze(-5).shape
torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-6).shape # 超出Dimension范围
---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

<ipython-input-54-3ae5b5ed17c8> in <module>
----> 1 a.unsqueeze(-6).shape # 超出Dimension范围


IndexError: Dimension out of range (expected to be in range of [-5, 4], but got -6)
a = torch.tensor([1.2,2.3])
print(a)
a.shape
tensor([1.2000, 2.3000])





torch.Size([2])
a.unsqueeze(-1)  # 维度变成 [2,1]  2行1列
tensor([[1.2000],
        [2.3000]])
 a.unsqueeze(0) # 维度变成 [1,2]  1行2列
tensor([[1.2000, 2.3000]])

给一个bias(偏置),bias相当于给每个channel上的所有像素增加一个偏置

为了做到 f+b 我们需要改变b的维度

f = torch.rand(4,32,14,14)
print('f:',f.shape)
b = torch.rand(32)
print('b:',b.shape)
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print('b:',b.shape)
# 后面进一步扩张到[4,32,14,14]
f: torch.Size([4, 32, 14, 14])
b: torch.Size([32])
b: torch.Size([1, 32, 1, 1])

5.2.2 squeeze

删减维度实际上是一个维度挤压的过程,直观地看是把那些多余的[]给去掉,也就是只是去删除那些size=1的维度。

  • squeeze(index) 当index对应的dim为1,就产生作用。
  • 不写参数,会挤压所有维度为1的。
b.shape
torch.Size([1, 32, 1, 1])
b.squeeze().shape  # 默认将所有维度为1的进行挤压 这32个channel,每个channel有一个值
torch.Size([32])
b.squeeze(0).shape #挤压batch维度,batch维度为1,可挤压
torch.Size([32, 1, 1])
b.squeeze(-1).shape #挤压width维度,width维度为1,可挤压
torch.Size([1, 32, 1])
b.squeeze(1).shape #挤压channel维度,channel维度为32(非1),不可挤压,维持原样
torch.Size([1, 32, 1, 1])
b.squeeze(-4).shape #挤压batch维度,batch维度为1,可挤压
torch.Size([32, 1, 1])

5.3 维度扩展 /维度重复–>expand / repeat

  • Expand:broadcasting (推荐)

    • 只是改变了理解方式,并没有增加数据

    • 在需要的时候复制数据

  • Reapeat:memory copied

    • 会实实在在的增加数据

上面提到的b [1, 32, 1, 1] f[ 4, 32, 14, 14 ]

目标是将b的维度变成与f相同的维度。

expand 扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view)
实际用例:

我们有一个shape=[4, 32, 14,14]的Tensor data,相当于4张图片,每张图片32个通道,每个通道行为14,列为14的图像数据,需要将每个通道上的所有像素增加一个偏置bias。图像数据的channel=32,因此bias = torch.rand(32),但是还不能完成data + bias的操作,因为两者dim与shape不一致。为了使得dim一致,需要增加bias的维度到4维,这就用到了unsqueeze()函数;为了使得shape一致,需要bias的4个维度的shape=[4, 32, 14, 14],这就用到了维度扩展expand。

import torch
 
bias = torch.rand(32)
data = torch.rand(4, 32, 14, 14)
 
# 想要把bias加到data上面去
# 先进行维度增加
bias = bias.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(bias.shape)
 
# 再进行维度扩展
bias = bias.expand(4, -1, 14, 14)  # -1表示这个维度保持不变,这里写32也可以
print(bias.shape)
 
data + bias
torch.Size([1, 32, 1, 1])
torch.Size([4, 32, 14, 14])





tensor([[[[0.8960, 1.5496, 0.9050,  ..., 1.1332, 0.7219, 1.0841],
          [1.2384, 0.6549, 1.0343,  ..., 0.7361, 1.5057, 1.4857],
          [1.2363, 0.9493, 1.3008,  ..., 0.9948, 0.8533, 1.6006],
          ...,
          [1.2948, 1.5342, 1.3291,  ..., 0.9560, 1.4789, 1.5602],
          [1.2962, 1.6183, 0.8428,  ..., 0.7917, 1.4077, 1.5777],
          [1.5127, 1.3154, 1.5063,  ..., 1.0158, 1.6256, 1.3794]],

         [[0.4441, 0.2225, 0.1535,  ..., 0.6123, 0.9616, 0.5543],
          [0.7507, 0.6727, 0.8909,  ..., 0.3457, 0.3017, 0.4114],
          [0.9226, 0.3908, 0.7864,  ..., 0.9857, 0.9885, 0.5947],
          ...,
          [0.8641, 1.1385, 0.1861,  ..., 0.3100, 0.6958, 0.5509],
          [0.9549, 1.0589, 0.7720,  ..., 0.5826, 1.1375, 0.9327],
          [0.5031, 1.0676, 1.0717,  ..., 0.2520, 0.6183, 0.8391]],

         [[0.7972, 1.6426, 1.0653,  ..., 1.7070, 1.1878, 1.4111],
          [0.9750, 1.4498, 1.0564,  ..., 0.9259, 1.0028, 1.6294],
          [1.0747, 1.6507, 1.4046,  ..., 1.1628, 1.0327, 1.4641],
          ...,
          [1.1980, 1.1770, 0.8256,  ..., 0.8555, 1.7524, 1.4209],
          [1.4331, 1.4346, 0.8684,  ..., 1.5482, 1.6331, 1.7779],
          [0.9896, 1.2697, 1.3178,  ..., 1.5928, 1.3262, 1.7125]],

         ...,

         [[1.0823, 0.9267, 0.7842,  ..., 0.6663, 1.1005, 0.8686],
          [1.2478, 1.1958, 0.3029,  ..., 1.0981, 0.6295, 0.5600],
          [1.1592, 0.9508, 1.1857,  ..., 0.8862, 0.9153, 0.8447],
          ...,
          [0.6752, 0.6254, 1.1716,  ..., 0.6047, 0.9779, 1.0749],
          [0.7814, 0.7104, 0.5610,  ..., 0.7313, 0.8480, 0.5892],
          [1.1465, 1.1682, 0.4911,  ..., 0.8453, 0.5524, 1.2125]],

         [[0.9857, 1.0350, 0.3590,  ..., 1.0411, 0.9470, 0.7949],
          [0.9931, 0.8492, 0.4514,  ..., 0.6077, 0.9882, 0.5753],
          [0.2135, 0.3392, 0.3776,  ..., 0.6816, 1.0357, 0.1273],
          ...,
          [0.6319, 0.9481, 1.0224,  ..., 0.7874, 0.7231, 0.1123],
          [0.6008, 0.9750, 0.8735,  ..., 0.6382, 0.7437, 0.1621],
          [0.5434, 0.8973, 0.7807,  ..., 0.3193, 0.4836, 0.5635]],

         [[0.7691, 1.1445, 0.8433,  ..., 0.3676, 1.1207, 1.1451],
          [0.7049, 0.4491, 0.4440,  ..., 0.3983, 0.7808, 0.6401],
          [0.6629, 0.4251, 1.0158,  ..., 1.1307, 0.3601, 1.2655],
          ...,
          [1.0580, 1.0222, 0.8507,  ..., 0.8253, 0.9767, 1.2937],
          [1.2397, 1.2240, 0.8112,  ..., 0.4000, 1.2583, 0.7200],
          [0.7175, 0.6339, 0.9625,  ..., 0.5887, 0.6523, 0.7274]]],


​ [[[1.5821, 1.0126, 1.6251, …, 1.4024, 1.1506, 0.9686],
​ [1.2884, 1.1859, 1.1280, …, 1.5143, 0.8557, 1.2237],
​ [1.1607, 0.8949, 0.7522, …, 0.9110, 0.9019, 0.9924],
​ …,
​ [0.7478, 0.8074, 0.8327, …, 1.4760, 1.2925, 0.8008],
​ [0.8813, 1.5677, 1.2962, …, 0.9556, 1.3413, 0.6787],
​ [1.3487, 1.3935, 1.1745, …, 1.2406, 1.3121, 1.5710]],

​ [[0.7101, 1.0555, 0.7033, …, 0.7977, 1.0045, 0.5837],
​ [0.6229, 0.6162, 0.6749, …, 1.0871, 0.5055, 0.5083],
​ [1.1294, 0.7006, 1.0414, …, 0.8605, 1.0098, 0.9168],
​ …,
​ [0.9028, 0.6679, 0.4674, …, 0.8716, 0.8924, 0.9703],
​ [0.4764, 0.7611, 0.7661, …, 0.8754, 1.0787, 1.0647],
​ [0.5523, 0.4778, 0.1919, …, 0.3149, 0.2454, 0.3703]],

[[1.6638, 1.6593, 1.6733, …, 1.7072, 1.3977, 1.1409],
[1.4301, 1.0561, 1.5163, …, 1.7737, 1.7532, 1.2684],
[1.2233, 1.3827, 1.1962, …, 0.8356, 1.3332, 1.4315],
…,
[1.1889, 1.7574, 1.7767, …, 0.9843, 1.0351, 1.5139],
[0.9718, 0.9305, 1.7841, …, 1.7643, 0.8967, 1.1640],
[0.8402, 1.0112, 1.7218, …, 1.3540, 1.7035, 1.0860]],

         ...,

         [[0.3265, 0.3134, 0.3151,  ..., 0.7248, 0.9859, 0.6728],
          [0.5627, 0.3210, 0.8208,  ..., 0.6669, 0.3255, 0.6620],
          [0.3609, 0.4460, 0.4548,  ..., 0.3164, 1.1621, 0.3869],
          ...,
          [0.3386, 1.1839, 1.0291,  ..., 1.0107, 1.2012, 0.2615],
          [0.5268, 1.2531, 1.1718,  ..., 0.9357, 0.8857, 1.2049],
          [0.4103, 1.0125, 1.1386,  ..., 0.7609, 0.3334, 0.4204]],

         [[0.8981, 0.8913, 0.4239,  ..., 0.3515, 0.4484, 0.8452],
          [1.0060, 0.1983, 0.1685,  ..., 0.9507, 0.5422, 1.0523],
          [0.3541, 0.1084, 0.9824,  ..., 0.6919, 0.2175, 0.3640],
          ...,
          [0.5873, 0.7695, 0.6969,  ..., 0.2764, 0.9349, 0.5279],
          [0.7354, 0.6831, 0.4657,  ..., 0.3226, 0.2009, 0.1248],
          [0.6900, 0.5530, 0.9687,  ..., 0.0911, 1.0034, 0.2468]],

         [[0.5653, 1.3259, 1.3363,  ..., 1.1609, 0.8436, 0.4663],
          [0.7371, 0.6001, 0.8747,  ..., 0.9240, 1.0972, 0.7577],
          [0.4713, 0.9068, 0.6100,  ..., 0.7166, 0.7765, 1.0245],
          ...,
          [1.3455, 1.2005, 0.7732,  ..., 0.6486, 0.7815, 0.6680],
          [0.9508, 0.4244, 0.9604,  ..., 1.1831, 0.7802, 0.7836],
          [0.7778, 0.7078, 1.1940,  ..., 0.3695, 0.8826, 1.0622]]],


​ [[[1.4143, 1.2995, 0.8139, …, 1.2896, 1.0774, 1.4234],
​ [1.3680, 1.4998, 1.3175, …, 1.3321, 0.8465, 0.7184],
​ [1.1172, 1.3721, 1.0740, …, 1.3989, 1.2516, 1.5627],
​ …,
​ [1.4785, 0.7343, 1.5378, …, 0.8592, 1.1209, 0.6605],
​ [0.9701, 1.4871, 1.0067, …, 0.7337, 0.8663, 1.1495],
​ [1.3008, 0.6869, 0.9265, …, 1.0017, 1.0166, 0.8520]],

​ [[0.7263, 0.3851, 0.2985, …, 0.8997, 0.8832, 0.4244],
​ [0.3260, 0.2953, 0.2261, …, 1.0073, 0.2056, 0.4568],
​ [0.7663, 0.6739, 1.0634, …, 0.9713, 0.5267, 0.8518],
​ …,
​ [0.8437, 0.9616, 0.3930, …, 0.6469, 0.7048, 0.3315],
​ [0.8692, 0.3884, 0.8911, …, 0.6669, 0.8625, 0.2133],
​ [1.1026, 0.7332, 0.7229, …, 0.2621, 0.4842, 0.4307]],

[[1.0641, 1.2921, 1.4621, …, 1.5553, 0.9183, 1.0367],
[1.0713, 1.2038, 1.3822, …, 1.4696, 1.4837, 1.6925],
[1.0082, 1.3797, 1.6363, …, 0.8198, 1.0169, 1.3877],
…,
[1.6118, 1.1801, 1.6943, …, 0.9778, 1.6738, 1.3158],
[1.1205, 1.3153, 1.2241, …, 1.0480, 1.7347, 1.5804],
[1.3998, 1.3264, 1.7306, …, 1.1373, 1.7730, 1.6030]],

         ...,

         [[0.3273, 0.7954, 1.1417,  ..., 0.8539, 0.3452, 0.6718],
          [0.8258, 0.6311, 0.8442,  ..., 1.0350, 0.9399, 0.6821],
          [0.5262, 0.3852, 0.5783,  ..., 1.0395, 0.8961, 0.3328],
          ...,
          [0.2712, 0.3258, 0.4477,  ..., 1.0990, 0.5651, 1.2033],
          [0.4650, 0.4297, 0.2735,  ..., 1.2499, 0.6360, 0.8601],
          [0.4908, 0.6670, 0.2758,  ..., 1.1991, 0.6807, 0.9915]],

         [[0.7808, 0.3267, 0.6485,  ..., 0.3851, 0.9846, 0.3835],
          [0.3318, 1.0238, 0.1418,  ..., 0.3042, 0.9472, 0.5742],
          [0.3162, 0.6359, 0.5158,  ..., 0.9014, 0.2424, 0.2757],
          ...,
          [0.9427, 0.7432, 0.1039,  ..., 0.4693, 0.5416, 0.3310],
          [0.4778, 0.9404, 0.4470,  ..., 0.5457, 0.3839, 0.9550],
          [0.6661, 0.7379, 0.1209,  ..., 0.7481, 0.1007, 0.7574]],

         [[0.4333, 0.4733, 1.1029,  ..., 1.1464, 0.8244, 0.6725],
          [0.9565, 0.5544, 0.6506,  ..., 1.0873, 1.1043, 0.8266],
          [1.1002, 0.5953, 1.3077,  ..., 0.9131, 0.3614, 0.8806],
          ...,
          [0.3689, 1.2011, 1.2095,  ..., 0.5513, 0.5514, 1.1391],
          [0.8860, 0.8691, 0.3859,  ..., 1.0626, 0.3986, 1.2656],
          [0.9867, 0.5254, 0.6492,  ..., 0.7394, 1.1726, 0.3677]]],


​ [[[0.7826, 1.1791, 0.8322, …, 0.8235, 1.0302, 1.0302],
​ [0.7702, 0.7221, 0.7272, …, 1.1678, 1.1863, 1.2819],
​ [1.2859, 1.5474, 1.4420, …, 0.7208, 1.5377, 1.0424],
​ …,
​ [1.5282, 0.9626, 0.6973, …, 1.5183, 1.5253, 1.2377],
​ [0.9410, 0.9411, 1.2357, …, 1.3346, 1.0733, 1.1061],
​ [1.5082, 1.2841, 0.8869, …, 0.9085, 1.4093, 0.9720]],

​ [[0.9618, 0.3823, 0.8777, …, 0.2757, 0.6889, 0.7530],
​ [0.6587, 1.1392, 0.9078, …, 0.4723, 1.1280, 1.0106],
​ [0.9830, 0.6176, 0.9818, …, 0.3107, 0.2464, 0.4503],
​ …,
​ [0.2503, 0.8345, 0.9273, …, 0.7442, 0.3182, 0.3155],
​ [0.4500, 0.1692, 1.0537, …, 1.1179, 0.5593, 1.0297],
​ [0.5964, 0.7033, 0.6544, …, 0.7268, 0.8003, 0.2025]],

[[1.6139, 1.1706, 1.1551, …, 1.1415, 0.8433, 1.0383],
[1.3710, 1.1335, 1.4540, …, 0.8371, 1.7301, 0.9495],
[1.6506, 1.7513, 1.4456, …, 1.2237, 1.5025, 1.6623],
…,
[1.1021, 1.4716, 1.4360, …, 1.2542, 1.4605, 1.2551],
[0.9401, 1.7444, 1.5189, …, 1.1773, 1.0009, 1.6454],
[0.9772, 0.9761, 1.6185, …, 1.1677, 1.4474, 1.5011]],

         ...,

         [[0.6740, 1.0540, 1.1193,  ..., 1.0580, 0.4050, 0.7746],
          [1.1201, 0.5247, 0.4629,  ..., 0.2560, 0.6034, 0.7429],
          [1.0713, 0.2652, 0.2633,  ..., 0.6131, 0.9528, 1.2320],
          ...,
          [0.7086, 1.0358, 0.7211,  ..., 1.2341, 0.9666, 0.4551],
          [0.8569, 0.3819, 1.2362,  ..., 0.3135, 0.7698, 0.3801],
          [0.6370, 0.6134, 0.4177,  ..., 0.6679, 0.4445, 0.5533]],

         [[0.9399, 0.3603, 0.3353,  ..., 0.3444, 0.5325, 0.7679],
          [0.4846, 0.3284, 0.6561,  ..., 0.2825, 0.8862, 0.1674],
          [0.4866, 0.5143, 0.5387,  ..., 0.7211, 0.6506, 0.4244],
          ...,
          [1.0203, 0.2508, 0.3907,  ..., 0.9631, 0.9657, 0.4968],
          [1.0628, 0.3519, 0.6227,  ..., 0.4153, 0.6817, 0.3877],
          [0.9314, 0.9506, 0.3368,  ..., 0.4821, 0.4303, 1.0270]],

         [[0.9297, 0.8843, 0.9464,  ..., 0.3750, 0.4969, 1.1617],
          [0.7578, 1.2751, 0.6529,  ..., 0.6165, 0.4856, 0.7537],
          [1.3444, 1.2611, 0.4272,  ..., 1.0863, 1.2938, 1.0524],
          ...,
          [0.8385, 0.7071, 0.5467,  ..., 0.7337, 1.2096, 0.5874],
          [0.5056, 1.3452, 0.7003,  ..., 0.5440, 0.9026, 0.6525],
          [1.1298, 0.9594, 0.5471,  ..., 0.7591, 0.5507, 1.2535]]]])
b.expand(-1,32,-1,-4).shape	# -4这里是一个bug,没有意义,最新版已经修复了
torch.Size([1, 32, 1, -4])

repeat

repeat就是将每个位置的维度都重复至指定的次数,以形成新的Tensor,功能与维度扩展一样,但是repeat会重新申请内存空间,repeat()参数表示各个维度指定的重复次数。

  • 主动复制原来的。
  • 参数表示的是要拷贝的次数/是原来维度的倍数
  • 沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据.
print("b:",b.shape)
print("a:",a.shape)
# 将b复制为和a 的shape对齐
b: torch.Size([1, 32, 1, 1])
a: torch.Size([4, 32, 14, 14])
b.repeat(4,32,1,1).shape # 复制:[1*4,32*32,1*1,1*1]
torch.Size([4, 1024, 1, 1])
b.repeat(4,1,1,1).shape
torch.Size([4, 32, 1, 1])
b.repeat(4,1,32,32)
b.repeat(4,1,32,32).shape

torch.Size([4, 32, 32, 32])
b.repeat(4,1,14,14).shape	# 这样就达到目标了
torch.Size([4, 32, 14, 14])

5.4 转置操作

[.t] Pytorch的转置操作只适用于dim=2的Tensor,也就是矩阵

a = torch.randn(3,4)
a
tensor([[-0.7498, -0.0034, -1.3071,  0.6672],
        [ 1.9724,  1.2990, -0.4861,  0.0552],
        [ 0.6652,  0.0232,  0.6340, -0.4155]])
a.shape
torch.Size([3, 4])
a.t()
tensor([[-0.7498,  1.9724,  0.6652],
        [-0.0034,  1.2990,  0.0232],
        [-1.3071, -0.4861,  0.6340],
        [ 0.6672,  0.0552, -0.4155]])
a.t().shape
torch.Size([4, 3])
b.shape
torch.Size([1, 32, 1, 1])
b.t()
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-103-b510e0f64f40> in <module>
----> 1 b.t()


RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D

5.5 维度变换

(1) transpose(dim1, dim2)交换dim1与dim2

注意这种交换使得存储不再连续,再执行一些reshape的操作会报错,所以要调用一下contiguous()使其变成连续的维度。

  • 在结合view使用的时候,view会导致维度顺序关系变模糊,所以需要人为跟踪。
  • 错误的顺序,会导致数据污染
  • 一次只能两两交换
  • contiguous

有一些对Tensor的操作不会真正改变Tensor的内容(真实维度),不会开辟新内存空间来存放处理之后的数据,新数据与原始数据共享同一块内存,改变的仅仅是Tensor中字节位置的索引。
这些操作如下:narrow(), view(), expand(), transpose()
因此在进行这些操作之前,需要使用contiguous()确保将数据划分到整块内存
参考:https://www.zhihu.com/question/60321866

a = torch.randn(4,3,32,32)
a.shape
torch.Size([4, 3, 32, 32])
a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
#[b c h w] 交换1,3维度的数据 [b w h c],再把后面的三个连在一起,展开后变为 [b c w h] 导致和原来的顺序不同
a1.shape
torch.Size([4, 3, 32, 32])
a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
# [b c h w] -> [b w h c] -> [b w h c] -> [b c h w] 和原来顺序相同。
a2.shape
torch.Size([4, 3, 32, 32])
# 验证向量一致性
torch.all(torch.eq(a,a1))
tensor(False)
# 验证向量一致性
torch.all(torch.eq(a,a2))
tensor(True)

(2) permute

如果四个维度表示上节的[batch,channel,h,w] ,如果想把channel放到最后去,形成[batch,h,w,channel],那么如果使用前面的维度交换,至少要交换两次(先13交换再12交换)。而使用permute可以直接指定维度新的所处位置,更加方便。

  • 会打乱内存顺序

  • 由于transpose一次只能两两交换,所以变换后在变回去至少需要两次操作,而permute一次就好。例如对于[b,h,w,c]

  • [b,h,w,c]是numpy存储图片的格式,需要这一步才能导出numpy

a = torch.rand(4,3,28,28)
a.transpose(1,3).shape	# [b c h w] -> [b w h c]  h与w的顺序发生了变换,导致图像发生了变化
torch.Size([4, 28, 28, 3])
b = torch.rand(4,3,28,32)
b.transpose(1,3).shape
torch.Size([4, 32, 28, 3])
b.transpose(1,3).transpose(1,2).shape# [b,h,w,c]是numpy存储图片的格式,需要这一步才能导出numpy
torch.Size([4, 28, 32, 3])
b.permute(0,2,3,1).shape # 调整维度顺序一步到位
torch.Size([4, 28, 32, 3])

6.Broadcast

Broadcasting能够实现Tensor自动维度增加(unsqueeze)与维度扩展(expand),主要按照如下步骤进行:

  • 从最后面的维度开始匹配;
  • 在前面插入若干维度,进行unsqueeze操作;
  • 将维度的size从1通过expand变到和某个Tensor相同的维度。

总之,Broadcasting也就是自动实现了若干unsqueeze和expand操作,以使两个Tensor的shape一致,从而完成某些操作,往往是加法操作。

Broadcasting 是指,在运算中,不同大小的两个 array 应该怎样处理的操作。通常情况下,小一点的数组会被 broadcast 到大一点的,这样才能保持大小一致。Broadcasting 过程中的循环操作都在 C 底层进行,所以速度比较快。但也有一些情况下 Broadcasting 会带来性能上的下降。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6tFOCP3u-1594740384688)(attachment:image.png)]

自动扩展:

维度扩展,自动调用expand
without copying data ,不需要拷贝数据。
参考 https://www.w3cschool.cn/tensorflow_python/tensorflow_python-nt3v2hcn.html

核心思想
我们希望进行几种计算,但需要满足数学上的约束(size相同),为了节省人们为满足数学上的约束而手动复制的过程,而产生的Broadcast,它节省了大量的内容消耗。

例如:对于 feature maps : [4, 32, 14, 14],想给它添加一个偏置Bias

需要进行操作Bias:[32] –> [32, 1 , 1] (这里是手动的) => [1, 32, 1, 1] => [4, 32, 14, 14]

目标:当Bias和feature maps的size一样时,才能执行叠加操作!!


两个 Tensors 只有在下列情况下才能进行 broadcasting 操作:

  • 每个 tensor 至少有一维

  • 遍历所有的维度,从尾部维度开始,每个对应的维度大小要么相同,要么其中一个是 1,要么其中一个不存在。

x=torch.empty(5,7,3)
y=torch.empty(5,7,3)
# 相同维度,一定可以 broadcasting
x=torch.empty((0,))
y=torch.empty(2,2)
# x 没有符合“至少有一个维度”,所以不可以 broadcasting
# 按照尾部维度对齐
x=torch.empty(5,3,4,1)
y=torch.empty(  3,1,1)
# x 和 y 是 broadcastable
# 1st 尾部维度: 都为 1
# 2nd 尾部维度: y 为 1
# 3rd 尾部维度: x 和 y 相同
# 4th 尾部维度: y 维度不存在
# 但是:
x=torch.empty(5,2,4,1)
y=torch.empty(  3,1,1)
# x 和 y 不能 broadcasting, 因为尾3维度 2 != 3

如果两个 tensors 可以 broadcasting,那么计算过程是这样的:

  • 如果 x 和 y 的维度不同,那么对于维度较小的 tensor 的维度补 1,使它们维度相同。
  • 然后,对于每个维度,计算结果的维度值就是 x 和 y 中较大的那个值。
# 按照尾部维度对齐
x=torch.empty(5,1,4,1)
y=torch.empty(  3,1,1)
(x+y).size()
# 结果维度如下
torch.Size([5, 3, 4, 1])

一个不对的例子:

x=torch.empty(5,2,4,1)
y=torch.empty(3,1,1)
(x+y).size()
# 报错提示说:在 non-singleton 维度上,tensor a 和 b 的 维度应该相同。
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-140-7836c771d982> in <module>
      1 x=torch.empty(5,2,4,1)
      2 y=torch.empty(3,1,1)
----> 3 (x+y).size()


RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

7.拼接与拆分

cat
stack
split
chunk

7.1 维度拼接

cat

numpy中使用concat,在pytorch中使用更加简写的 cat

完成一个拼接

两个向量维度相同,想要拼接的维度上的值可以不同,但是其它维度上的值必须相同。

举个例子:想将这两组班级的成绩合并起来

a[class 1-4, students, scores]

b[class 5-9, students, scores]

理解cat:

  • 行拼接:[4, 4] 与 [5, 4] 以 dim=0(行)进行拼接 —> [9, 4] 9个班的成绩合起来
  • 列拼接:[4, 5] 与 [4, 3] 以 dim=1(列)进行拼接 —> [4, 8] 每个班合成8项成绩
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
torch.cat([a,b],dim=0).shape # 结果就是9个班级的成绩
torch.Size([9, 32, 8])
a1 = torch.rand(4,3,32,32)
a2 = torch.rand(5,3,32,32)
torch.cat([a1,a2],dim=0).shape		# 合并第1维 理解上相当于合并batch
torch.Size([9, 3, 32, 32])
a2 = torch.rand(4,1,32,32)
torch.cat([a1,a2],dim=1).shape		# 合并第2维 理解上相当于合并为 rgba
torch.Size([4, 4, 32, 32])
a1 = torch.rand(4,3,16,32)
a2 = torch.rand(4,3,16,32)
torch.cat([a1,a2],dim=3).shape		# 合并第3维 理解上相当于合并照片的上下两半
torch.Size([4, 3, 16, 64])
# 错误示范,其他非合并dim不同
a1 = torch.rand(4,3,32,32)
torch.cat([a1,a2],dim=0).shape
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-145-bc8f55519fa7> in <module>
      1 # 错误示范,其他非合并dim不同
      2 a1 = torch.rand(4,3,32,32)
----> 3 torch.cat([a1,a2],dim=0).shape


RuntimeError: Sizes of tensors must match except in dimension 0. Got 32 and 16 in dimension 2

stack

在指定维度的位置前插入新的维度。

stack需要保证两个Tensor的shape是一致的,这就像是有两类东西,它们的其它属性都是一样的(比如男的一张表,女的一张表)。使用stack时候要指定一个维度位置,在那个位置前会插入一个新的维度,因为是两类东西合并过来所以这个新的维度size是2,通过指定这个维度是0或者1来选择性别是男还是女。

a1 = torch.rand(4,3,16,32)
a2 = torch.rand(4,3,16,32) 
torch.cat([a1,a2],dim=2).shape		# 合并照片的上下部分
torch.Size([4, 3, 32, 32])
torch.stack([a1,a2],dim=2).shape	# 添加了一个维度 一个值代表上半部分,一个值代表下半部分
torch.Size([4, 3, 2, 16, 32])
a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0).shape		# 将两个班级的学生成绩合并,添加一个新的维度
torch.Size([2, 32, 8])

7.2 维度拆分

split

split是torch.chunk()函数的升级版本,它不仅可以按份数均匀分割,还可以按特定方案进行分割。

对一个Tensor而言,要拆分的那个维度的size就是"这个维度的总长度"了,可以指定拆分后的几个Tensor各取多少长度,或者指定每个Tensor取多少长度

源码定义:torch.split(tensor,split_size_or_sections,dim=0)

  • 第一个参数是待分割张量
  • 第二个参数有两种形式。
    • 一种是分割份数,这就和torch.chunk()一样了。
    • 第二种这是分割方案,这是一个list,待分割张量将会分割为len(list)份,每一份的大小取决于list中的元素
  • 第三个参数为分割维度
a = torch.rand(32,8)
b = torch.rand(32,8)
c = torch.rand(32,8)
d = torch.rand(32,8)
e = torch.rand(32,8)
f = torch.rand(32,8)
s = torch.stack([a,b,c,d,e,f],dim=0)
s.shape
torch.Size([6, 32, 8])
aa,bb = s.split(3,dim=0)	# 按数量切分,可以使用一个常数
aa.shape, bb.shape
(torch.Size([3, 32, 8]), torch.Size([3, 32, 8]))
cc,dd,ee = s.split([3,2,1],dim=0)	# 按单位长度切分,可以使用一个列表
cc.shape, dd.shape, ee.shape
(torch.Size([3, 32, 8]), torch.Size([2, 32, 8]), torch.Size([1, 32, 8]))
 ff,gg = s.split(6,dim=0)	# 只切了一半,有一半不存在,所以报错
---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-153-58be9098868e> in <module>
----> 1 ff,gg = s.split(6,dim=0)        # 只切了一半,有一半不存在,所以报错


ValueError: not enough values to unpack (expected 2, got 1)

chunk

按数量进行拆分

对于按数量切分:chunk中的参数为每份有几个。

torch.cat()函数是把各个tensor连接起来,这里的torch.chunk()的作用是把一个tensor均匀分割成若干个小tensor

源码定义:torch.chunk(intput,chunks,dim=0)

  • 第一个参数input是你想要分割的tensor
  • 第二个参数chunks是你想均匀分割的份数,如果该tensor在你要进行分割的维度上的size不能被chunks整除,则最后一份会略小(也可能为空)
  • 第三个参数表示分割维度,dim=0按行分割,dim=1表示按列分割
  • 该函数返回由小tensor组成的list
s.shape
torch.Size([6, 32, 8])
aa,bb = s.chunk(2,dim=0)
aa.shape, bb.shape
(torch.Size([3, 32, 8]), torch.Size([3, 32, 8]))
cc,dd,ee = s.split(2,dim=0) # 每份2个
cc.shape,dd.shape,ee.shape
(torch.Size([2, 32, 8]), torch.Size([2, 32, 8]), torch.Size([2, 32, 8]))
cc,dd = s.split(3,dim=0)# 每份3个,可分2份
cc.shape,dd.shape
(torch.Size([3, 32, 8]), torch.Size([3, 32, 8]))

8.数学运算

  • add/minus/multiply/divide
  • matmul
  • pow
  • sqrt/rsqrt
  • round

8.1 基础运算:

可以使用 + - * / 推荐

也可以使用 torch.add, mul, sub, div

# 这两个Tensor加减乘除会对b自动进行Broadcasting
a = torch.rand(3, 4)
b = torch.rand(4)
 
c1 = a + b
c2 = torch.add(a, b)
print(c1.shape, c2.shape)
print(torch.all(torch.eq(c1, c2)))
torch.Size([3, 4]) torch.Size([3, 4])
tensor(True)
# 减法运算
a = torch.rand(3, 4)
b = torch.rand(4)
 
c1 = a - b
c2 = torch.sub(a, b)
print(c1.shape, c2.shape)
print(torch.all(torch.eq(c1, c2)))
torch.Size([3, 4]) torch.Size([3, 4])
tensor(True)
# 哈达玛积(element wise,对应元素相乘)
c1 = a * b
c2 = torch.mul(a, b)
print(c1.shape, c2.shape)
print(torch.all(torch.eq(c1, c2)))
torch.Size([3, 4]) torch.Size([3, 4])
tensor(True)
# 除法运算
c1 = a / b
c2 = torch.div(a, b)
print(c1.shape, c2.shape)
print(torch.all(torch.eq(c1, c2)))
torch.Size([3, 4]) torch.Size([3, 4])
tensor(True)

8.2 矩阵乘法

(1)二维矩阵相乘

二维矩阵乘法运算操作包括torch.mm()、torch.matmul()、@,

a = torch.ones(2, 1)
b = torch.ones(1, 2)
print(torch.mm(a, b).shape)
print(torch.matmul(a, b).shape)
print((a @ b).shape)
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([2, 2])

(2)多维矩阵相乘

对于高维的Tensor(dim>2),定义其矩阵乘法仅在最后的两个维度上,要求前面的维度必须保持一致,就像矩阵的索引一样并且运算操只有torch.matmul()。

c = torch.rand(4, 3, 28, 64)
d = torch.rand(4, 3, 64, 32)
print(torch.matmul(c, d).shape)
torch.Size([4, 3, 28, 32])
#在这种情形下的矩阵相乘,前面的"矩阵索引维度"如果符合Broadcasting机制,也会自动做Broadcasting,然后相乘。
c = torch.rand(4, 3, 28, 64)
d = torch.rand(4, 1, 64, 32)
print(torch.matmul(c, d).shape)
torch.Size([4, 3, 28, 32])

8.3 幂运算-pow

a = torch.full([2, 2], 3)
 
b = a.pow(2)  # 也可以a**2
print(b)
tensor([[9., 9.],
        [9., 9.]])

8.4 开方运算-sqrt()

rsqrt() 表示平方根的倒数

对矩阵中每一个元素生效

print(b)
c = b.sqrt()  # 也可以a**(0.5)
print(c)
 
d = b.rsqrt()  # 平方根的倒数
print(d)
tensor([[9., 9.],
        [9., 9.]])
tensor([[3., 3.],
        [3., 3.]])
tensor([[0.3333, 0.3333],
        [0.3333, 0.3333]])

8.5 指数与对数运算

exp(n) 表示:e的n次方

log(a) 表示:ln(a)

log2() 、 log10()

注意:log是以自然对数为底数的,以2为底的用log2,以10为底的用log10

a = torch.exp(torch.ones(2, 2))  # 得到2*2的全是e的Tensor次幂
print(a)
print(torch.log(a))  # 取自然对数
b = torch.FloatTensor([8,8])
print(torch.log2(b))  # 取以2为底的对数
tensor([[2.7183, 2.7183],
        [2.7183, 2.7183]])
tensor([[1., 1.],
        [1., 1.]])
tensor([3., 3.])

8.6 近似值运算

floor、ceil 向下取整、向上取整

round 4舍5入

trunc、frac 裁剪

a = torch.tensor(3.14)
print(a.floor(), a.ceil(), a.trunc(), a.frac())  # 取下,取上,取整数,取小数
b = torch.tensor(3.49)
c = torch.tensor(3.5)
print(b.round(), c.round())  # 四舍五入
tensor(3.) tensor(4.) tensor(3.) tensor(0.1400)
tensor(3.) tensor(4.)

8.7 裁剪运算

即对Tensor中的元素进行范围过滤,不符合条件的可以把它变换到范围内部(边界)上,常用于梯度裁剪(gradient clipping)

gradient clipping 梯度裁剪:

(min) 小于min的都变为某某值

(min, max) 不在这个区间的都变为某某值

梯度爆炸:一般来说,当梯度达到100左右的时候,就已经很大了,正常在10左右,通过打印梯度的模来查看 w.grad.norm(2)

对于w的限制叫做weight clipping,对于weight gradient clipping称为 gradient clipping。

grad = torch.rand(2, 3) * 15  # 0~15随机生成
print(grad.max(), grad.min(), grad.median())  # 最大值最小值平均值
 
print(grad)
print(grad.clamp(10))  # 最小是10,小于10的都变成10
print(grad.clamp(3, 10))  # 最小是3,小于3的都变成3;最大是10,大于10的都变成10
tensor(14.9548) tensor(4.7532) tensor(7.9810)
tensor([[ 9.5116, 14.9548,  6.5294],
        [ 7.9810, 14.2554,  4.7532]])
tensor([[10.0000, 14.9548, 10.0000],
        [10.0000, 14.2554, 10.0000]])
tensor([[ 9.5116, 10.0000,  6.5294],
        [ 7.9810, 10.0000,  4.7532]])

9.统计属性

求值或位置

  • norm
  • mean sum
  • prod
  • max, min, argmin, argmax
  • kthvalue, top

9.1 norm-p范数

1-Norm就是所有元素的绝对值之和

2-Norm就是所有元素的平方和并开根号

不加dim参数,默认所有维度

从shape出发,加入dim后,这个dim就会消失(做Norm)

  • 向量范数与矩阵范数:https://blog.csdn.net/bitcarmanlee/article/details/51945271
  • 机器学习下的各种norm到底是个什么东西?:https://www.zhihu.com/question/29458275
  • 机器学习中的范数规则化之(一)L0、L1与L2范数:https://blog.csdn.net/zouxy09/article/details/24971995
a = torch.full([8], 1)
b = a.reshape([2, 4])
c = a.reshape([2, 2, 2])
 
# 求L1范数(所有元素绝对值求和)
print(a.norm(1), b.norm(1), c.norm(1))
# 求L2范数(所有元素的平方和再开根号)
print(a.norm(2), b.norm(2), c.norm(2))

tensor(8.) tensor(8.) tensor(8.)
tensor(2.8284) tensor(2.8284) tensor(2.8284)
# 在b的1号维度上求L1范数
print(b.norm(1, dim=1))
# 在b的1号维度上求L2范数
print(b.norm(2, dim=1))
tensor([4., 4.])
tensor([2., 2.])
# 在c的0号维度上求L1范数
print(c.norm(1, dim=0))
# 在c的0号维度上求L2范数
print(c.norm(2, dim=0))
tensor([[2., 2.],
        [2., 2.]])
tensor([[1.4142, 1.4142],
        [1.4142, 1.4142]])

9.2 均值、累加、最小、最大、累积

max() 求最大的值

min() 求最小的值

mean() 求平均值 mean = sum / size

prod() 累乘

sum() 求和

argmax() 返回最大值元素的索引

argmin() 返回最大值元素的索引

argmax(dim=l) 求 l 维中,最大元素的位置,这样的话这一维将消失。

note:以上这些,如果不加参数,会先打平,在计算,所以对于 argmax 和 argmin来说得到的是打平后的索引。

b = torch.arange(8).reshape(2, 4).float()
print(b)
# 均值,累加,最小,最大,累积
print(b.mean(), b.sum(), b.min(), b.max(), b.prod())
# 打平后的最小、最大值索引
print(b.argmax(), b.argmin())
tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.]])
tensor(3.5000) tensor(28.) tensor(0.) tensor(7.) tensor(0.)
tensor(7) tensor(0)

注意:上面的argmax、argmin操作默认会将Tensor打平后取最大值索引和最小值索引,如果不希望Tenosr打平,而是求给定维度上的索引,需要指定在哪一个维度上求最大值索引或最小值索引。

比如,有shape=[4, 10]的Tensor,表示4张图片在10分类的概率结果,我们需要知道每张图片的最可能的分类结果:

a = torch.rand(4, 10)
print(a)
# 在第二维度上求最大值索引
print(a.argmax(dim=1))
tensor([[0.5089, 0.8470, 0.7515, 0.5390, 0.7651, 0.9499, 0.4106, 0.0224, 0.3070,
         0.0815],
        [0.6564, 0.6338, 0.5161, 0.1036, 0.7365, 0.4423, 0.6938, 0.4452, 0.0061,
         0.1241],
        [0.4390, 0.3248, 0.1907, 0.7630, 0.7486, 0.3885, 0.9798, 0.4495, 0.4535,
         0.4067],
        [0.1909, 0.9897, 0.5197, 0.3004, 0.1333, 0.5913, 0.8592, 0.4751, 0.5998,
         0.2342]])
tensor([5, 4, 6, 1])

直接使用max和min配合dim参数也可以获得最值索引,同时得到最值的具体值:

print(a.max(dim=1))
torch.return_types.max(
values=tensor([3., 7.]),
indices=tensor([3, 3]))

使用max(dim=) 函数配上dim参数,可以很好的返回最大值与该值的位置

argmax 其实是 max 的一部分(位置)

keepdim=True 设置这个参数后,维度得以保留,与原来的维度是一样的。

a = torch.randn(4,10)	# 假设生成4张手写体数字照片的概率(发生过偏移)
a
tensor([[-1.1685,  0.4939,  0.1523, -0.7941,  1.4138,  0.7953,  0.4168, -0.7407,
          0.2822, -0.5448],
        [-1.1060, -0.2001, -0.9245,  0.6636,  0.1517,  0.9220,  0.3030,  0.0887,
          1.3177,  1.4373],
        [ 0.6177,  2.2362,  0.3999,  0.6556, -0.2605, -0.5839, -1.2971, -1.7552,
          0.6277, -0.2984],
        [-0.5320, -0.0683, -1.5299, -1.8177, -0.7979,  0.8961,  0.1037,  0.2006,
         -1.5410,  0.6457]])
a.max(dim=1)
torch.return_types.max(
values=tensor([1.4138, 1.4373, 2.2362, 0.8961]),
indices=tensor([4, 9, 1, 5]))
a.max(dim=1,keepdim=True)
torch.return_types.max(
values=tensor([[1.4138],
        [1.4373],
        [2.2362],
        [0.8961]]),
indices=tensor([[4],
        [9],
        [1],
        [5]]))
 a.argmax(dim=1, keepdim=True)	# 返回一个 [4,1] , dim=1这一维并没有消失
tensor([[4],
        [9],
        [1],
        [5]])

9.3 取前k大/前k小/第k小的概率值及其索引

由于max只能找出一个最大,如果想找最大的几个就做不到了。

使用topk代替max可以完成更灵活的需求,有时候不是仅仅要概率最大的那一个,而是概率最大的k个。

如果不是求最大的k个,而是求最小的k个,只要使用参数largest=False,kthvalue还可以取第k小的概率值及其索引。

# 2个样本,分为10个类别的置信度
d = torch.randn(2, 10)  
d
tensor([[ 1.8442, -2.5249,  0.3780, -3.0356, -0.6604, -1.2991, -1.6812,  0.4360,
         -1.4151, -1.3771],
        [ 0.4271,  1.2599, -0.9852, -0.4424, -0.5836,  0.5713, -1.1026,  0.2437,
         -1.6229,  1.6462]])
# 最大概率的3个类别
print(d.topk(3, dim=1))  
# 最小概率的3个类别
print(d.topk(3, dim=1, largest=False))  
torch.return_types.topk(
values=tensor([[1.8442, 0.4360, 0.3780],
        [1.6462, 1.2599, 0.5713]]),
indices=tensor([[0, 7, 2],
        [9, 1, 5]]))
torch.return_types.topk(
values=tensor([[-3.0356, -2.5249, -1.6812],
        [-1.6229, -1.1026, -0.9852]]),
indices=tensor([[3, 1, 6],
        [8, 6, 2]]))
#kthvalue(i, dim=j) 求 j 维上,第 i 小的元素以及位置
# 求第8小概率的类别(一共10个那就是第3大)
print(d.kthvalue(8, dim=1))  
torch.return_types.kthvalue(
values=tensor([0.3780, 0.5713]),
indices=tensor([2, 5]))

9.4 比较操作

“>, >=, <, <=, !=, ==”

进行比较后,返回的是一个 bytetensor,不再是floattensor,由于pytorch中所有的类型都是数值,没有True or False ,为了表达使用整型的0,1

torch.eq(a,b) 判断每一个元素是否相等,返回 bytetensor

torch.equal(a,b) 返回True or False

a = torch.randn(2, 3)
b = torch.randn(2, 3)
print(a)
print(b)
# 比较是否大于0,是对应位置返回1,否对应位置返回0,注意得到的是ByteTensor
print(a > 0)  
# 作用同 > 号
print(torch.gt(a, 0))
# 是否不等于0,是对应位置返回1,否对应位置返回0
print(a != 0)
# 比较每个位置是否相等,是对应位置返回1,否对应位置返回0
print(torch.eq(a, b))  
# 比较每个位置是否相等,全部相等时才返回True
print(torch.equal(a, b), torch.equal(a, a))  
tensor([[ 1.1085, -1.1223,  0.6210],
        [-2.8198, -2.0091,  0.8598]])
tensor([[ 0.2230,  1.5942, -0.7011],
        [ 0.1289, -1.4188, -0.1122]])
tensor([[ True, False,  True],
        [False, False,  True]])
tensor([[ True, False,  True],
        [False, False,  True]])
tensor([[True, True, True],
        [True, True, True]])
tensor([[False, False, False],
        [False, False, False]])
False True

10 Pytorch 一些高阶操作

10.1 where

使用C=torch.where(condition,A,B)其中A,B,C,condition是shape相同的Tensor,C中的某些元素来自A,某些元素来自B,这由condition中对应位置的元素是1还是0来决定。如果condition对应位置元素是1,则C中的该位置的元素来自A中的该位置的元素,如果condition对应位置元素是0,则C中的该位置的元素来自B中的该位置的元素。

import torch
 
cond = torch.tensor([[0.6, 0.1], [0.2, 0.7]])
print(cond)
print(cond > 0.5)
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[4, 5], [6, 7]])
c = torch.where(cond > 0.5, a, b)
print(c)
tensor([[0.6000, 0.1000],
        [0.2000, 0.7000]])
tensor([[ True, False],
        [False,  True]])
tensor([[1, 5],
        [6, 4]])

10.2 gather

使用torch.gather(input,dim,index, out=None)对元素实现一个查表映射的操作:

# 4张图像的10种分类的概率值
prob = torch.randn(4, 10)
print(prob)
# 取概率最大的前3个的概率值及其索引
_, idx = prob.topk(3, dim=1)
print(idx)
label = torch.arange(10) + 100
# 用于将idx的0~9映射到100~109
print(label)
out = torch.gather(label.expand(4, 10), dim=1, index=idx.long())
print(label.expand(4, 10))
print(out)
tensor([[-0.4805,  0.6682, -1.1893,  0.7377,  0.6221,  0.1750, -1.3162,  1.5251,
          0.8921, -0.0975],
        [ 1.5657,  0.5704, -0.7770,  0.5287,  1.2055, -0.0726,  0.4084, -1.8432,
          0.3574,  1.7794],
        [ 0.8718,  1.4115,  0.2857, -2.4984, -0.4921,  0.0663, -0.7654,  1.4403,
         -0.7046,  0.4614],
        [-1.1375, -0.1773, -2.1657,  0.2821,  0.4089, -0.3924, -0.5846,  0.3160,
          1.0131,  0.1173]])
tensor([[7, 8, 3],
        [9, 0, 4],
        [7, 1, 0],
        [8, 4, 7]])
tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])
tensor([[107, 108, 103],
        [109, 100, 104],
        [107, 101, 100],
        [108, 104, 107]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Shine.Zhang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值