Tensor目录
一、张量Tensor
二、模块导入
import numpy as np
import torch
三、创建tensor的方式
(一)使用python中的列表创建tensor
代码示例
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(a)
运行结果
(二)使用numpy中的数组创建tensor
代码示例
b = np.random.rand(3, 4)
b = torch.tensor(b)
print(b)
运行结果
(三)使用torch的API创建tensor
代码示例(部分)
torch.empty(3, 4) # 创建空的tensor
torch.ones([3, 4]) # 创建全为1的tensor
torch.zeros([3, 4]) # 创建全为0的tensor
torch.rand([3, 4]) # 创建随机值的tensor,随机值的区间是[0,1)
torch.randint(low=0, high=10, size=[3, 4]) # 创建随机整数的tensor,随机区间是[low, high)
torch.randn([3, 4]) # 均匀分布,均值为0,方差为1
与numpy的用法大致相同。
二、pytorch中tensor的常用方法
-
获取tensor中的数据:当tensor中只有一个元素,使用tensor.item()
代码示例c = torch.tensor(np.array([1])) print(c.item())
运行结果
1
-
转化为numpy数组:tensor.numpy()
代码示例d = torch.tensor(np.arange(20).reshape(4, 5)) print(d.numpy())
运行结果
-
获取tensor形状:tensor.size()
e = torch.tensor([[1, 2, 3], [4, 5, 6]]) print(e.size()) print(e.size(0)) # 表示第0维的数据数量,含有[1,2,3]和[4,5,6]两个数据 print(e.size(1)) # 表示第1维的数据数量,含有1,2,3(或4,5,6)三个数据
-
形状改变:tensor.view((3,4))
d = torch.tensor(np.arange(20).reshape(4, 5)) print(d.view((5, 4))) d = d.view(-1) # 多维变成一维 print(d.size())
-
获取维数:tensor.dim()
f = torch.tensor(np.arange(10).reshape(2, 5)) print(f.dim()) # 维数为2 print(f.view(-1).dim()) # 维数为1
2
1 -
获取最大(小)值:tensor.max() / tensor.min()
f = torch.tensor(np.arange(10).reshape(2, 5)) print(f.max()) print(f.min())
tensor(9, dtype=torch.int32)
tensor(0, dtype=torch.int32) -
矩阵的转置:tensor.t()、tensor.transpose(参数1,参数2,…)、tensor.permute() 【注意:一维数组没有转置】
(1)tensor.t() 【注意:当tensor的维度<=2时,可以直接使用,不传参数;但当tensor的维度>2时,则要使用tensor.transpose()】g = torch.tensor(np.arange(12).reshape([3, 4])) h = g.t() print(g) print(h)
(2)tensor.transpose() # 只需传入tensor指定维度的下标
i = torch.tensor(np.arange(24).reshape([2, 3, 4])) j = i.transpose(1, 2) # 1,2指的是[2,3,4]的下标,意思是对指定的两个维度进行转置 print(j)
(3)tensor.permute() # 需要传入tensor中的所有维度对应的下标
k = torch.tensor(np.arange(24).reshape([2, 3, 4])) L = k.permute(1, 0, 2) print(L)
三、tensor的数据类型
-
获取tensor的数据类型:tensor.dtype
-
创建数据时指定类型
m = torch.tensor([3, 4], dtype=torch.float32) print(m)
注意一种报错情况:如果用torch.Tensor()则会报错
例如:print(torch.Tensor(np.array([1, 2]), dtype=torch.float))
运行结果
-
数据类型的修改:tensor.数据类型() 【注意:double == float64】
n = torch.tensor([1, 2], dtype=torch.float) n = n.double() print(n)