pytorch中张量的基本操作
1张量的概念
pytorch提供了两种类型的数据,成为张量和变量。张量类似于Numpy中的数组。常见的张量如下:标量(0维张量)、向量(1维张量)、矩阵(2维张量)、3维张量……
2 张量的创建
(1)利用torch.tensor()直接创建
torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)
其中:
data: 数据,可以是list,numpy的ndarray
dtype: 数据类型,默认与data的类型一致
device: 所在设备,gpu/cpu
requires_grad: 是否需要梯度,因为神经网络结构经常会要求梯度
pin_memory: 是否存于锁页内存
import torch
import numpy as np
#通过torch.tensor创建张量
arr = np.ones((3, 3))
print("ndarray的数据类型:", arr.dtype)
t = torch.tensor(arr)
print(t)
#结果如下
ndarray的数据类型: float64
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], dtype=torch.float64)
(2)通过数值创建
torch.zeros() torch.zeros_like()
torch.ones() torch.ones_like()
torch.full() torch.full_like()
#通过torch.zeros创建张量
#其余torch.ones、torch.ones_like、torch.zeros_like
import torch
out_t = torch.tensor([1])
t=torch.zeros((3,3),out=out_t)
print(t, '\n', out_t)
print(id(t), id(out_t), id(t) == id(out_t))
#结果如下
tensor([[0, 0, 0],
[0, 0, 0],
[0, 0, 0]])
tensor([[0, 0, 0],
[0, 0, 0],
[0, 0, 0]])
2409918589672 2409918589672 True
#通过torch.full创建张量
t = torch.full((3, 3), 1)
print(t)
#结果如下
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
# 通过torch.arange创建等差数列张量
t = torch.arange(2, 10, 2)
print(t)
#结果如下
tensor([2, 4, 6, 8])
(3)通过torch.normal创建正态分布张量
# 通过torch.normal创建正态分布张量
import torch
# mean:张量 std: 张量
mean = torch.arange(1, 5, dtype=torch.float)
std = 1
t_normal = torch.normal(mean, std)
print("mean:{}\nstd:{}".format(mean, std))
print(t_normal)
#结果如下
mean:tensor([1., 2., 3., 4.])
std:1
tensor([1.6614, 2.2669, 3.0617, 4.6213])
3 张量的操作
#==================example 1===============
# torch.cat拼接
# a=True
a=False
if a:
t = torch.ones((2, 3))
t_0 = torch.cat([t, t], dim=0)
t_1 = torch.cat([t, t, t], dim=1)
print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))
# ======================================= example 2 =======================================
# torch.stack拓展维度拼接
# a=True
a=False
if a:
t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=0)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
# ======================================= example 3 =======================================
# torch.chunk切分
# a=True
a=False
if a:
a = torch.ones((2, 7)) # 7
list_of_tensors = torch.chunk(a, dim=1, chunks=3) # 3
for idx, t in enumerate(list_of_tensors):
print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
# ======================================= example 4 =======================================
# torch.split
# a=True
a=False
if a:
t = torch.ones((2, 5))
list_of_tensors = torch.split(t, [2, 1, 1], dim=1) # [2 , 1, 2]
for idx, t in enumerate(list_of_tensors):
print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
# list_of_tensors = torch.split(t, [2, 1, 2], dim=1) # 报错
# for idx, t in enumerate(list_of_tensors):
# print("第{}个张量:{}, shape is {}".format(idx, t, t.shape))
# ======================================= example 5 =======================================
# torch.index_select索引
# a=True
a=False
if a:
t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 2], dtype=torch.long) # float报错
t_select = torch.index_select(t, dim=0, index=idx)
print("t:\n{}\nt_select:\n{}".format(t, t_select))
# ======================================= example 6 =======================================
# torch.masked_select索引
# a=True
a=False
if a:
t = torch.randint(0, 9, size=(3, 3))
mask = t.le(5) # ge is mean greater than or equal/ gt: greater than le lt
t_select = torch.masked_select(t, mask)
print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))
# ======================================= example 7 =======================================
# torch.reshape变换
# a=True
a=False
if a:
t = torch.randperm(8)
t_reshape = torch.reshape(t, (-1, 2, 2)) # -1
print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
t[0] = 1024
print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
print("t.data 内存地址:{}".format(id(t.data)))
print("t_reshape.data 内存地址:{}".format(id(t_reshape.data)))
# ======================================= example 8 =======================================
# torch.transpose变换
# a=True
a=False
if a:
# torch.transpose
t = torch.rand((2, 3, 4))
t_transpose = torch.transpose(t, dim0=1, dim1=2) # c*h*w h*w*c
print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))
# ======================================= example 9 =======================================
# torch.squeeze
# 压缩维度为1的维度
# a=True
a=False
if a:
t = torch.rand((1, 2, 3, 1))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t, dim=0)
t_1 = torch.squeeze(t, dim=1)
print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)
# ======================================= example 8 =======================================
# torch.add
# a=True
a=False
if a:
t_0 = torch.randn((3, 3))
t_1 = torch.ones_like(t_0)
t_add = torch.add(t_0, 10, t_1)
print("t_0:\n{}\nt_1:\n{}\nt_add_10:\n{}".format(t_0, t_1, t_add))