【Pytorch框架学习】之张量的创建和操作(1)

【Pytorch框架】之张量的创建和操作(1)

一、张量的创建

import torch
import numpy as np
torch.manual_seed(1)

# ===============================  exmaple 1 ===============================
# 通过torch.tensor创建张量
# flag = True
flag = False
if flag:
    arr = np.ones((3, 3))
    print("ndarray的数据类型:", arr.dtype)

    t = torch.tensor(arr, device='cuda')
    # t = torch.tensor(arr)

    print(t)

# ===============================  exmaple 2 ===============================
# 通过torch.from_numpy创建张量
# 共享内存,修改其中一个都改变
# flag = True
flag = False
if flag:
    arr = np.array([[1, 2, 3], [4, 5, 6]])
    t = torch.from_numpy(arr)
    # print("numpy array: ", arr)
    # print("tensor : ", t)

    # print("\n修改arr")
    # arr[0, 0] = 0
    # print("numpy array: ", arr)
    # print("tensor : ", t)

    print("\n修改tensor")
    t[0, 0] = -1
    print("numpy array: ", arr)
    print("tensor : ", t)

# ===============================  exmaple 3 ===============================
# 通过torch.zeros创建张量
# 其余torch.ones、torch.ones_like、torch.zeros_like
# flag = True
flag = False
if flag:
    out_t = torch.tensor([1])

    t = torch.zeros((3, 3), out=out_t)

    print(t, '\n', out_t)
    print(id(t), id(out_t), id(t) == id(out_t))

# ===============================  exmaple 4 ===============================
# 通过torch.full创建全1张量
# flag = True
flag = False
if flag:
    t = torch.full((3, 3), 1)
    print(t)

# ===============================  exmaple 5 ===============================
# 通过torch.arange创建等差数列张量
# flag = True
flag = False
if flag:
    t = torch.arange(2, 10, 2)
    print(t)

# ===============================  exmaple 6 ===============================
# 通过torch.linspace创建均分数列张量
# flag = True
flag = False
if flag:
    # t = torch.linspace(2, 10, 5)
    t = torch.linspace(2, 10, 6)
    print(t)

# ===============================  exmaple 7 ===============================
# 通过torch.normal创建正态分布张量
flag = True
# flag = False
if flag:

    # mean:张量 std: 张量
    # mean = torch.arange(1, 5, dtype=torch.float)
    # std = torch.arange(1, 5, dtype=torch.float)
    # t_normal = torch.normal(mean, std)
    # print("mean:{}\nstd:{}".format(mean, std))
    # print(t_normal)

    # mean:标量 std: 标量
    # t_normal = torch.normal(0., 1., size=(4,))
    # print(t_normal)

    # mean:张量 std: 标量
    mean = torch.arange(1, 5, dtype=torch.float)
    std = 1
    t_normal = torch.normal(mean, std)
    print("mean:{}\nstd:{}".format(mean, std))
    print(t_normal)

二、张量的操作

import torch
torch.manual_seed(1)

# ======================================= example 1 =======================================
# torch.cat拼接
# flag = True
flag = False

if flag:
    t = torch.ones((2, 3))

    t_0 = torch.cat([t, t], dim=0)
    t_1 = torch.cat([t, t, t], dim=1)

    print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))

# ======================================= example 2 =======================================
# torch.stack拓展维度拼接
# flag = True
flag = False

if flag:
    t = torch.ones((2, 3))

    t_stack = torch.stack([t, t, t], dim=0)

    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))

# ======================================= example 3 =======================================
# torch.chunk切分
# flag = True
flag = False

if flag:
    a = torch.ones((2, 7))  # 7
    list_of_tensors = torch.chunk(a, dim=1, chunks=3)   # 3

    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

# ======================================= example 4 =======================================
# torch.split
# flag = True
flag = False

if flag:
    t = torch.ones((2, 5))

    list_of_tensors = torch.split(t, [2, 1, 1], dim=1)  # [2 , 1, 2]
    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

    # list_of_tensors = torch.split(t, [2, 1, 2], dim=1)  # 报错
    # for idx, t in enumerate(list_of_tensors):
    #     print("第{}个张量:{}, shape is {}".format(idx, t, t.shape))

# ======================================= example 5 =======================================
# torch.index_select索引
# flag = True
flag = False

if flag:
    t = torch.randint(0, 9, size=(3, 3))
    idx = torch.tensor([0, 2], dtype=torch.long)    # float报错
    t_select = torch.index_select(t, dim=0, index=idx)
    print("t:\n{}\nt_select:\n{}".format(t, t_select))

# ======================================= example 6 =======================================
# torch.masked_select索引
# flag = True
flag = False

if flag:

    t = torch.randint(0, 9, size=(3, 3))
    mask = t.le(5)  # ge is mean greater than or equal/   gt: greater than  le  lt
    t_select = torch.masked_select(t, mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))

# ======================================= example 7 =======================================
# torch.reshape变换
# flag = True
flag = False

if flag:
    t = torch.randperm(8)
    t_reshape = torch.reshape(t, (-1, 2, 2))    # -1
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))

    t[0] = 1024
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
    print("t.data 内存地址:{}".format(id(t.data)))
    print("t_reshape.data 内存地址:{}".format(id(t_reshape.data)))

# ======================================= example 8 =======================================
# torch.transpose变换
# flag = True
flag = False

if flag:
    # torch.transpose
    t = torch.rand((2, 3, 4))
    t_transpose = torch.transpose(t, dim0=1, dim1=2)    # c*h*w     h*w*c
    print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))

# ======================================= example 9 =======================================
# torch.squeeze
# 压缩维度为1的维度
# flag = True
flag = False

if flag:
    t = torch.rand((1, 2, 3, 1))
    t_sq = torch.squeeze(t)
    t_0 = torch.squeeze(t, dim=0)
    t_1 = torch.squeeze(t, dim=1)
    print(t.shape)
    print(t_sq.shape)
    print(t_0.shape)
    print(t_1.shape)

# ======================================= example 8 =======================================
# torch.add
# flag = True
flag = False

if flag:
    t_0 = torch.randn((3, 3))
    t_1 = torch.ones_like(t_0)
    t_add = torch.add(t_0, 10, t_1)

    print("t_0:\n{}\nt_1:\n{}\nt_add_10:\n{}".format(t_0, t_1, t_add))
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值