目录
3.torch.squeeze函数和torch.unsqueeze函数(tensor的维度去除和增加)
6.torch.cat函数和torch.stack函数(tensor的拼接)
7.torch.chunk函数和torch.split函数(tensor的拆分)
9.reshape函数和resize函数(改变tensor的形状)
1.linspace函数(tensor生成)
torch.linspace(start, end, steps=分割点数,dtype=返回类型)
>>> t = torch.linspace(3, 10, 5)
>>> t
tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000])
>>> c = torch.linspace(3, 10, 5, dtype=torch.int)
>>> c
tensor([ 3, 4, 6, 8, 10], dtype=torch.int32)
2.rand函数和randn函数(随机tensor生成)
rand函数和randn函数两个都是生成随机tensor的函数,两者的区别是rand生成的tensor是基于均匀分布的,而randn生成的向量是基于标准正态分布的
>>> print(torch.rand(2, 3))
tensor([[0.4525, 0.1064, 0.9951],
[0.2394, 0.7348, 0.5460]])
>>> print(torch.randn(2, 3))
tensor([[ 1.7014, 1.1992, -0.5759],
[-0.2074, -1.2572, 0.8707]])
3.torch.squeeze函数和torch.unsqueeze函数(tensor的维度去除和增加)
squeeze函数是将tenser维度为1的删除,保留其他维度
>>> import torch
>>> x = torch.randn(size=(2, 1, 2, 1, 2))
>>> x.shape
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
# 把x中维度为1删除掉,保留其他维度
>>> y.shape
torch.Size([2, 2, 2])
同时可以指定dim去删除,但是不可以删除维度大于1的
>>> x = torch.randn(size=(2, 1, 2, 1, 2))
>>> y = torch.squeeze(x, dim=1)
>>> y.shape
torch.Size([2, 2, 1, 2])
>>> z = torch.squeeze(x, dim=-1)
>>> z.shape
torch.Size([2, 1, 2, 1, 2])
unsqueeze函数主要是对tensor升维,
>>> x = torch.tensor([1, 2, 3, 4])
>>> y = torch.unsqueeze(x, dim=0)
>>> y
tensor([[1, 2, 3, 4]])
>>> y.shape
torch.Size([1, 4])
4.repeat函数(tensor的复制)
repeat函数是对tensor进行复制,可以指定维度复制,可以对非单数维复制
>>> import torch
>>> a = torch.tensor(([[1, 2, 3], [1, 2, 3]]))
>>> b = a.repeat(2, 2)
>>> b
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
>>> b.shape
torch.Size([4, 6])
当参数个数大于原有tensor的维度时,在第0维扩展维度为1的维度,然后复制
>>> a = torch.tensor(([[1, 2, 3], [1, 2, 3]]))
>>> a.shape
torch.Size([2, 3])
>>> b = a.repeat(1, 2, 1)
>>> b
tensor([[[1, 2, 3],
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]]])
>>> b.shape
torch.Size([1, 4, 3])
5.expand函数(tensor扩张)
expand函数是对tensor在指定dim进行扩张,与repeat函数区别是仅在单位维上扩张(复制)
>>> a = torch.randn(1, 1, 3)
>>> b = a.expand(-1, 3, -1)
>>> c = a.expand(3, 3 ,3)
>>> a
tensor([[[-0.8384, 0.5715, -1.5192]]])
>>> b.shape
torch.Size([1, 3, 3])
>>> b
tensor([[[-0.8384, 0.5715, -1.5192],
[-0.8384, 0.5715, -1.5192],
[-0.8384, 0.5715, -1.5192]]])
>>> c.shape
torch.Size([3, 3, 3])
>>> c
tensor([[[-0.8384, 0.5715, -1.5192],
[-0.8384, 0.5715, -1.5192],
[-0.8384, 0.5715, -1.5192]],
[[-0.8384, 0.5715, -1.5192],
[-0.8384, 0.5715, -1.5192],
[-0.8384, 0.5715, -1.5192]],
[[-0.8384, 0.5715, -1.5192],
[-0.8384, 0.5715, -1.5192],
[-0.8384, 0.5715, -1.5192]]])
6.torch.cat函数和torch.stack函数(tensor的拼接)
cat函数是将两个tensor按照指定维度拼接在一起,除了拼接维数dim数值可以不同,其他数值需要相同,不然无法对齐。(注,cat函数不会新增维度)
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 3)
>>> c = torch.cat((a, b), dim=0)
>>> c.shape
torch.Size([5, 3])
>>> a, b, c
(tensor([[-0.4516, -0.3405, 0.7484],
[ 0.5414, -1.2156, 0.2577]]),
tensor([[-1.9560, 1.1200, -0.1139],
[-0.5212, 0.3194, 0.8153],
[-0.8011, 1.1799, -0.0382]]),
tensor([[-0.4516, -0.3405, 0.7484],
[ 0.5414, -1.2156, 0.2577],
[-1.9560, 1.1200, -0.1139],
[-0.5212, 0.3194, 0.8153],
[-0.8011, 1.1799, -0.0382]]))
stack函数 是在指定参数dim新增维度,然后拼接
>>> a = torch.arange(6.0).reshape(2, 3)
>>> a
tensor([[0., 1., 2.],
[3., 4., 5.]])
>>> b = torch.linspace(0, 10, 6).reshape(2, 3)
>>> b
tensor([[ 0., 2., 4.],
[ 6., 8., 10.]])
>>> f = torch.stack((a, b), dim=1)
>>> f
tensor([[[ 0., 1., 2.],
[ 0., 2., 4.]],
[[ 3., 4., 5.],
[ 6., 8., 10.]]])
>>> f.shape
torch.Size([2, 2, 3])
7.torch.chunk函数和torch.split函数(tensor的拆分)
chunk函数是按照某维度,对tensor进行均匀切分
>>> a = torch.arange(12).reshape(4, 3)
>>> a
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
>>> c = torch.chunk(a, 4, dim=0)
>>> c
(tensor([[0, 1, 2]]),
tensor([[3, 4, 5]]),
tensor([[6, 7, 8]]),
tensor([[ 9, 10, 11]]))
split函数是对tensor在指定dim维度下,按照split_size进行拆分,若split_size不能整除,则剩下数据为一块
>>> a = torch.arange(12).reshape(4, 3)
>>> a
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
>>> s = torch.split(a, 2, dim=0)
>>> s
(tensor([[0, 1, 2],
[3, 4, 5]]),
tensor([[ 6, 7, 8],
[ 9, 10, 11]]))
>>> s = torch.split(a, 2, dim=1)
>>> s
(tensor([[ 0, 1],
[ 3, 4],
[ 6, 7],
[ 9, 10]]),
tensor([[ 2],
[ 5],
[ 8],
[11]]))
8.permute函数和transpose函数(交换维度)
permute函数可以将维度位置进行交换,和transpose函数区别是可以对任意高维矩阵进行转置
>>> a = torch.randn(2, 3, 5)
>>> a.shape
torch.Size([2, 3, 5])
>>> a.permute(2, 0, 1).shape
torch.Size([5, 2, 3])
transpose函数tensor转置,只可以操作2D转置
>>> a = torch.randn(2, 3)
>>> a
tensor([[ 2.0059, -0.5201, 1.0149],
[-0.3918, -0.8710, 0.5579]])
>>> torch.transpose(a, 0, 1)
tensor([[ 2.0059, -0.3918],
[-0.5201, -0.8710],
[ 1.0149, 0.5579]])
9.reshape函数和resize函数(改变tensor的形状)
reshape函数是对tensor形状进行改变,但是元素个数不变
>>> a = torch.randn(4)
>>> a
tensor([-1.2296, 0.2770, -0.7377, -0.9197])
>>> torch.reshape(a, (2, 2))
tensor([[-1.2296, 0.2770],
[-0.7377, -0.9197]])
resize函数是对tensor尺寸进行调整,但与view不同,它可以修改tensor的尺寸。如果新的尺寸超过原来的尺寸,会自动分配新的内存空间,而如果新的尺寸小于原尺寸,则之前的数据依旧会被保存。
>>> b = torch.arange(0, 6)
tensor([0, 1, 2, 3, 4, 5])
>>> b.resize(2, 3)
tensor([[0, 1, 2],
[3, 4, 5]])
10.view函数(改变tensor形状)
view函数可以调整tensor的形状,但必须保证调整前后元素总数一致。view不会修改自身的数据,返回的新tensor与源tensor共享内存,即更改其中一个,另外一个也会跟着改变。
>>> a = torch.arange(0, 6)
>>> a
tensor([0, 1, 2, 3, 4, 5])
>>> b = a.view(-1, 3)
>>> b
tensor([[0, 1, 2],
[3, 4, 5]])
11.torch.flatten函数(展平tensor)
torch.flatten(input, start_dim=0, end_dim=-1)将Tensor从start_dim展平到end_dim
>>> t = torch.randn([3, 2, 3])
>>> t
tensor([[[-0.9367, 0.5908, 0.8416],
[ 0.2513, 1.4519, -0.6566]],
[[-0.0527, -0.2044, 0.0610],
[ 1.7106, -1.7432, 0.2789]],
[[ 1.4201, 0.4685, -0.9257],
[ 1.5866, 0.9033, -0.8300]]])
>>> f = torch.flatten(t)
>>> f
tensor([-0.9367, 0.5908, 0.8416, 0.2513, 1.4519, -0.6566, -0.0527, -0.2044,
0.0610, 1.7106, -1.7432, 0.2789, 1.4201, 0.4685, -0.9257, 1.5866,
0.9033, -0.8300])
>>> t = torch.randn([3, 2, 3])
>>> t
tensor([[[-0.4668, 1.1902, -0.1419],
[ 0.8214, -1.3676, 1.5473]],
[[-0.1603, -0.8412, 0.0700],
[-0.3150, -1.2187, -0.8126]],
[[ 0.2662, -0.5577, -1.8281],
[ 0.4309, 0.6306, -1.8118]]])
>>> f = torch.flatten(t, start_dim=1)
>>> f
tensor([[-0.4668, 1.1902, -0.1419, 0.8214, -1.3676, 1.5473],
[-0.1603, -0.8412, 0.0700, -0.3150, -1.2187, -0.8126],
[ 0.2662, -0.5577, -1.8281, 0.4309, 0.6306, -1.8118]])
12.meshgrid函数(生成网格)
输入两个数据类型相同的一维tensor,输出两个tensor(tensor行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数)。(注,输入两个tensor数据类型和维度不是一维会出错)
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([5, 6, 7])
>>> resultx, resulty = torch.meshgrid(x,y)
>>> print(x,y)
tensor([1, 2, 3]) tensor([5, 6, 7])
>>> print(resultx)
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
>>> print(resulty)
tensor([[5, 6, 7],
[5, 6, 7],
[5, 6, 7]])
13.torch.argmax(返回指定维度最大值的序号)
>>> a = torch.tensor([[1, 5, 5, 2], [9, 6, -2, 8], [-3, 7, -9, 1]])
>>> a
>>> tensor([[ 1, 5, 5, 2],
[ 9, 6, -2, 8],
[-3, 7, -9, 1]])
>>> b = torch.argmax(a, dim=0)
>>> b
tensor([1, 2, 0, 1])