Tensor(张量)
什么是tensor?
- scalar(标量):一个数值
- vector(向量):一维数组
- matrix(矩阵):二维数组
- tensor(张量):大于二维的数组,即多维数组
Tensor的类型:一共9种类型
- numpy.float64
- numpy.float32
- numpy.float16
- numpy.int64
- numpy.int32
- numpy.int16
- numpy.int8
- numpy.uint9(无符号整型)
- numpy.bool
1.torch.ones:默认值为1
2.torch.zeros:默认值为0
import torch
y = torch.zeros(3,4)
print(y)
//控制台结果
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
3.torch.full:填充
4.torch.empty:为空
5.torch.rand:随机数,均匀分布在0和1之间
import torch
x = torch.rand(10)
print(x)
//控制台输出
tensor([0.6338, 0.8150, 0.3313, 0.2341, 0.9809, 0.2198, 0.9884, 0.2101, 0.4671, 0.6889])
6.torch.randn:正态分布,均值为0,方差为1
import torch
x = torch.randn(10)
print(x)
//控制台输出
tensor([ 0.1908, -0.0746, -0.3406, -1.4599, 0.7656, 1.2863, -0.3496, 1.5926, 0.7925, 3.1737])
7.torch.randint:创建随机整数
import torch
x = torch.randint(1,99,(3,3))
print(x)
//控制台输出
tensor([[35, 56, 53],
[22, 67, 85],
[31, 28, 58]])
8.torch.randperm:选择随机数,下面的例子为输出0到9的10个随机数
import torch
x = torch.randperm(10)
print(x)
//控制台输出
tensor([9, 8, 2, 1, 4, 3, 0, 6, 5, 7])
9.torch.argmin():特别的在dim=0表示二维中的列,dim=1在二维矩阵中表示行
import torch
x = torch.randint(1,99,(3,3))
print(x)
print(torch.argmin(x, dim=0))
//控制台结果
tensor([[32, 75, 31],
[ 1, 36, 74],
[92, 9, 15]])
tensor([1, 2, 2])
10.torch.argmax():
import torch
x = torch.randint(1,99,(3,3))
print(x)
print(torch.argmax(x, dim=1))
//控制台结果
tensor([[40, 7, 31],
[60, 12, 85],
[71, 51, 89]])
tensor([0, 2, 2])
11.torch.ones_like:属性相似
12.torch.randn_like:属性相似
13.torch.add:加法
14.torch.tensor:创建
15.torch.as_tensor:创建
16.torch.from_numpy:从numpy转为tensor
import numpy as np
import torch
x = np.array([3,4,5,6,7])
print(x)
y = torch.from_numpy(x)
print(y)
//控制台结果
array([3 4 5 6 7])
tensor([3, 4, 5, 6, 7], dtype=torch.int32)
17.torch.reshape:改变shape
18.torch.numel:统计张量里有几个元素
import torch
x = torch.rand(2,3)
print(x)
print(torch.numel(x))
//控制台结果
tensor([[0.2432, 0.2156, 0.5053],
[0.7344, 0.4466, 0.5718]])
6
19.torch.view:视图
20.torch.arange:生成一个区间的数,下面例题是生成从10开始到30结束之间步长为4的数,左闭右开
import torch
x = torch.arange(10,30,4)
print(x)
//控制台输出
tensor([10, 14, 18, 22, 26])
21.torch.linspace:切分,把2到10切分成5等分
import torch
x = torch.linspace(2,10,5)
print(x)
//控制台结果
tensor([ 2., 4., 6., 8., 10.])
22.torch.eye:对角线为1的tensor
import torch
x = torch.eye(3,3)
print(x)
//控制台结果
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
23.torch.cat:连接,把两个张量堆叠起来,dim=0为纵轴堆叠,dim=1为横轴堆叠
import torch
x = torch.randint(1,10,(2,3))
print(x)
y = torch.cat((x,x),dim=0)
print(y)
//控制台结果
tensor([[3, 9, 9],
[7, 9, 6]])
tensor([[3, 9, 9],
[7, 9, 6],
[3, 9, 9],
[7, 9, 6]])
import torch
x = torch.randint(1,10,(2,3))
print(x)
y = torch.cat((x,x),dim=1)
print(y)
//控制台结果
tensor([[9, 9, 2],
[5, 2, 8]])
tensor([[9, 9, 2, 9, 9, 2],
[5, 2, 8, 5, 2, 8]])
24.torch.index_select:根据索引选择,.index_select中的参数0为横向选,1为纵向选
import torch
x = torch.randint(1,10,(4,4))
print(x)
indices = torch.tensor([0,2])
y = torch.index_select(x,0,indices)
print(y)
//控制台结果
tensor([[8, 1, 7, 8],
[3, 1, 6, 1],
[9, 2, 7, 9],
[2, 1, 7, 8]])
tensor([[8, 1, 7, 8],
[9, 2, 7, 9]])
25.torch.narrow:缩小
26.torch.t / torch.transpose:转置, torch.transpose中0表示行,1表示列,下面例子为行列交换
import torch
x = torch.tensor([[1,2,3],[4,5,6]])
y1 = x.t()
y2 = x.transpose(1,0)
print(y1)
print(y2)
//控制台结果
tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[1, 4],
[2, 5],
[3, 6]])
tensor([[1, 4],
[2, 5],
[3, 6]])
27.torch.take:根据索引获取元素
28.torch.split:分割
import torch
x = torch.tensor([1,2,3,4,5,6,7])
y = x.split(2)
print(y)
//控制台结果
(tensor([1, 2]), tensor([3, 4]), tensor([5, 6]), tensor([7]))
29.torch.is_tensor(x):检测x是否为张量
import torch
x = [2,8,3,3,2,1,2]
y = torch.rand(1,2)
print(torch.is_tensor(x))
print(torch.is_tensor(y))
//控制台结果
False
True
30.torch.chunk:把张量切块,dim=0为横向切,dim=1为纵向切
import torch
x = torch.randint(1,10,(3,3))
y = torch.chunk(x,2,dim=0)
print(x)
print(y)
//控制台结果
tensor([[4, 1, 9],
[6, 2, 4],
[7, 7, 3]])
(tensor([[4, 1, 9],
[6, 2, 4]]), tensor([[7, 7, 3]]))
import torch
x = torch.randint(1,10,(3,3))
y = torch.chunk(x,2,dim=1)
print(x)
print(y)
//控制台结果
tensor([[9, 1, 4],
[1, 2, 5],
[3, 9, 8]])
(tensor([[9, 1],
[1, 2],
[3, 9]]), tensor([[4],
[5],
[8]]))