【PyTorch框架】之张量的创建和操作(1)
一、张量的创建
import torch
import numpy as np
# Fix the global RNG seed so the random examples in this script are reproducible.
torch.manual_seed(1)

# =============================== example 1 ===============================
# Create a tensor from a NumPy array with torch.tensor.
# Toggle the example by switching which `flag` assignment is active.
# flag = True
flag = False

if flag:
    arr = np.ones((3, 3))
    print("ndarray的数据类型:", arr.dtype)

    # NOTE: device='cuda' needs a CUDA-capable GPU; switch to the CPU
    # variant on the next line when no GPU is available.
    t = torch.tensor(arr, device='cuda')
    # t = torch.tensor(arr)
    print(t)
# =============================== example 2 ===============================
# Create a tensor with torch.from_numpy.
# The tensor and the ndarray share memory: modifying one changes the other.
# flag = True
flag = False

if flag:
    arr = np.array([[1, 2, 3], [4, 5, 6]])
    t = torch.from_numpy(arr)
    # print("numpy array: ", arr)
    # print("tensor : ", t)

    # Alternative demo: modify the ndarray and watch the tensor change.
    # print("\n修改arr")
    # arr[0, 0] = 0
    # print("numpy array: ", arr)
    # print("tensor : ", t)

    # Active demo: modify the tensor and watch the ndarray change.
    print("\n修改tensor")
    t[0, 0] = -1
    print("numpy array: ", arr)
    print("tensor : ", t)
# =============================== example 3 ===============================
# Create a tensor with torch.zeros.
# Related constructors: torch.ones, torch.ones_like, torch.zeros_like.
# flag = True
flag = False

if flag:
    out_t = torch.tensor([1])
    # With out=, the result is written into out_t; the id comparison below
    # checks whether t and out_t are the same Python object.
    t = torch.zeros((3, 3), out=out_t)
    print(t, '\n', out_t)
    print(id(t), id(out_t), id(t) == id(out_t))
# =============================== example 4 ===============================
# Create a 3x3 tensor filled with the value 1 using torch.full.
# flag = True
flag = False

if flag:
    t = torch.full((3, 3), 1)
    print(t)
# =============================== example 5 ===============================
# Create an arithmetic-progression tensor with torch.arange.
# flag = True
flag = False

if flag:
    # start=2, end=10 (exclusive), step=2
    t = torch.arange(2, 10, 2)
    print(t)
# =============================== example 6 ===============================
# Create an evenly spaced tensor with torch.linspace.
# flag = True
flag = False

if flag:
    # start=2, end=10 (inclusive), steps=number of points
    # t = torch.linspace(2, 10, 5)
    t = torch.linspace(2, 10, 6)
    print(t)
# =============================== example 7 ===============================
# Sample normally distributed tensors with torch.normal.
flag = True
# flag = False

if flag:
    # Variant: mean is a tensor, std is a tensor
    # mean = torch.arange(1, 5, dtype=torch.float)
    # std = torch.arange(1, 5, dtype=torch.float)
    # t_normal = torch.normal(mean, std)
    # print("mean:{}\nstd:{}".format(mean, std))
    # print(t_normal)

    # Variant: mean is a scalar, std is a scalar (size must be given)
    # t_normal = torch.normal(0., 1., size=(4,))
    # print(t_normal)

    # Active variant: mean is a tensor, std is a scalar
    mean = torch.arange(1, 5, dtype=torch.float)
    std = 1
    t_normal = torch.normal(mean, std)
    print("mean:{}\nstd:{}".format(mean, std))
    print(t_normal)
二、张量的操作
import torch
# Fix the global RNG seed so the random examples in this part are reproducible.
torch.manual_seed(1)

# ======================================= example 1 =======================================
# torch.cat: concatenate tensors along an existing dimension.
# flag = True
flag = False

if flag:
    t = torch.ones((2, 3))

    t_0 = torch.cat([t, t], dim=0)
    t_1 = torch.cat([t, t, t], dim=1)

    print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))
# ======================================= example 2 =======================================
# torch.stack: concatenate tensors along a newly created dimension.
# flag = True
flag = False

if flag:
    t = torch.ones((2, 3))

    t_stack = torch.stack([t, t, t], dim=0)

    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
# ======================================= example 3 =======================================
# torch.chunk: split a tensor into a given number of chunks.
# flag = True
flag = False

if flag:
    a = torch.ones((2, 7))  # 7 columns
    # Ask for 3 chunks along dim 1; 7 is not divisible by 3, so the last
    # chunk is smaller than the others.
    list_of_tensors = torch.chunk(a, dim=1, chunks=3)

    for idx, piece in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, piece, piece.shape))
# ======================================= example 4 =======================================
# torch.split: split a tensor into sections of the given sizes.
# flag = True
flag = False

if flag:
    t = torch.ones((2, 5))

    # The section sizes must sum to the length of the split dimension (5).
    list_of_tensors = torch.split(t, [2, 1, 2], dim=1)
    # Use a separate loop variable so the source tensor `t` is not overwritten.
    for idx, piece in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, piece, piece.shape))

    # Sizes that do not sum to 5 raise a RuntimeError, e.g.:
    # list_of_tensors = torch.split(t, [2, 1, 1], dim=1)
# ======================================= example 5 =======================================
# torch.index_select: pick entries along a dimension by index.
# flag = True
flag = False

if flag:
    t = torch.randint(0, 9, size=(3, 3))
    # The index tensor must be an integer type; a float index raises an error.
    idx = torch.tensor([0, 2], dtype=torch.long)
    t_select = torch.index_select(t, dim=0, index=idx)
    print("t:\n{}\nt_select:\n{}".format(t, t_select))
# ======================================= example 6 =======================================
# torch.masked_select: pick the elements where a boolean mask is True.
# flag = True
flag = False

if flag:
    t = torch.randint(0, 9, size=(3, 3))
    # Comparison helpers: ge (>=), gt (>), le (<=), lt (<).
    mask = t.le(5)
    t_select = torch.masked_select(t, mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))
# ======================================= example 7 =======================================
# torch.reshape: change the shape of a tensor.
# flag = True
flag = False

if flag:
    t = torch.randperm(8)
    # -1 lets PyTorch infer that dimension from the element count.
    t_reshape = torch.reshape(t, (-1, 2, 2))
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))

    # Modify the original; the second print shows whether t_reshape follows.
    t[0] = 1024
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))

    # data_ptr() is the address of the first element in the underlying
    # storage. id(t.data) would only be the id of a temporary Python wrapper
    # object and says nothing about whether the storage is shared.
    print("t.data 内存地址:{}".format(t.data_ptr()))
    print("t_reshape.data 内存地址:{}".format(t_reshape.data_ptr()))
# ======================================= example 8 =======================================
# torch.transpose: swap two dimensions of a tensor.
# flag = True
flag = False

if flag:
    t = torch.rand((2, 3, 4))
    # Swap dims 1 and 2: shape (2, 3, 4) -> (2, 4, 3).
    t_transpose = torch.transpose(t, dim0=1, dim1=2)
    print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))
# ======================================= example 9 =======================================
# torch.squeeze: remove dimensions of size 1.
# flag = True
flag = False

if flag:
    t = torch.rand((1, 2, 3, 1))

    t_sq = torch.squeeze(t)        # no dim given: drop every size-1 dim
    t_0 = torch.squeeze(t, dim=0)  # dim 0 has size 1
    t_1 = torch.squeeze(t, dim=1)  # dim 1 has size 2

    print(t.shape)
    print(t_sq.shape)
    print(t_0.shape)
    print(t_1.shape)
# ======================================= example 10 =======================================
# torch.add: element-wise addition with an optional scale on the second operand.
# flag = True
flag = False

if flag:
    t_0 = torch.randn((3, 3))
    t_1 = torch.ones_like(t_0)
    # Computes t_0 + 10 * t_1. The scale is passed via the `alpha` keyword;
    # the older positional form torch.add(input, value, other) is deprecated.
    t_add = torch.add(t_0, t_1, alpha=10)

    print("t_0:\n{}\nt_1:\n{}\nt_add_10:\n{}".format(t_0, t_1, t_add))