pytorch框架学习总结(一)

pytorch中张量的基本操作

1张量的概念

pytorch提供了两种类型的数据,成为张量和变量。张量类似于Numpy中的数组。常见的张量如下:标量(0维张量)、向量(1维张量)、矩阵(2维张量)、3维张量……

2 张量的创建

(1)利用torch.tensor()直接创建
torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)
其中:
data: 数据,可以是list,numpy的ndarray
dtype: 数据类型,默认与data的类型一致
device: 所在设备,gpu/cpu
requires_grad: 是否需要梯度,因为神经网络结构经常会要求梯度
pin_memory: 是否存于锁页内存

import torch
import numpy as np
#通过torch.tensor创建张量
arr = np.ones((3, 3))
print("ndarray的数据类型:", arr.dtype)
t = torch.tensor(arr)
print(t)
#结果如下
ndarray的数据类型: float64
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)

(2)通过数值创建
torch.zeros() torch.zeros_like()
torch.ones() torch.ones_like()
torch.full() torch.full_like()

#通过torch.zeros创建张量
#其余torch.ones、torch.ones_like、torch.zeros_like
import torch
out_t = torch.tensor([1])
t=torch.zeros((33),out=out_t)
print(t, '\n', out_t)
print(id(t), id(out_t), id(t) == id(out_t))
#结果如下
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]]) 
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])
2409918589672 2409918589672 True

#通过torch.full创建张量
t = torch.full((3, 3), 1)
print(t)
#结果如下
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
# 通过torch.arange创建等差数列张量
t = torch.arange(2, 10, 2)
print(t)
#结果如下
tensor([2, 4, 6, 8])

(3)通过torch.normal创建正态分布张量

# 通过torch.normal创建正态分布张量
import torch
# mean:张量 std: 张量
mean = torch.arange(1, 5, dtype=torch.float)
std = 1
t_normal = torch.normal(mean, std)
print("mean:{}\nstd:{}".format(mean, std))
print(t_normal)
#结果如下
mean:tensor([1., 2., 3., 4.])
std:1
tensor([1.6614, 2.2669, 3.0617, 4.6213])

3 张量的操作

#==================example 1===============
# torch.cat拼接
# a=True
a=False
if a:
    t = torch.ones((2, 3))

    t_0 = torch.cat([t, t], dim=0)
    t_1 = torch.cat([t, t, t], dim=1)

    print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))

# ======================================= example 2 =======================================
# torch.stack拓展维度拼接
# a=True
a=False
if a:
    t = torch.ones((2, 3))

    t_stack = torch.stack([t, t, t], dim=0)

    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))

# ======================================= example 3 =======================================
# torch.chunk切分
# a=True
a=False
if a:
    a = torch.ones((2, 7))  # 7
    list_of_tensors = torch.chunk(a, dim=1, chunks=3)   # 3

    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

# ======================================= example 4 =======================================
# torch.split
# a=True
a=False
if a:
    t = torch.ones((2, 5))

    list_of_tensors = torch.split(t, [2, 1, 1], dim=1)  # [2 , 1, 2]
    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

    # list_of_tensors = torch.split(t, [2, 1, 2], dim=1)  # 报错
    # for idx, t in enumerate(list_of_tensors):
    #     print("第{}个张量:{}, shape is {}".format(idx, t, t.shape))

# ======================================= example 5 =======================================
# torch.index_select索引
# a=True
a=False
if a:
    t = torch.randint(0, 9, size=(3, 3))
    idx = torch.tensor([0, 2], dtype=torch.long)    # float报错
    t_select = torch.index_select(t, dim=0, index=idx)
    print("t:\n{}\nt_select:\n{}".format(t, t_select))

# ======================================= example 6 =======================================
# torch.masked_select索引
# a=True
a=False
if a:

    t = torch.randint(0, 9, size=(3, 3))
    mask = t.le(5)  # ge is mean greater than or equal/   gt: greater than  le  lt
    t_select = torch.masked_select(t, mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))

# ======================================= example 7 =======================================
# torch.reshape变换
# a=True
a=False
if a:
    t = torch.randperm(8)
    t_reshape = torch.reshape(t, (-1, 2, 2))    # -1
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))

    t[0] = 1024
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
    print("t.data 内存地址:{}".format(id(t.data)))
    print("t_reshape.data 内存地址:{}".format(id(t_reshape.data)))

# ======================================= example 8 =======================================
# torch.transpose变换
# a=True
a=False
if a:
    # torch.transpose
    t = torch.rand((2, 3, 4))
    t_transpose = torch.transpose(t, dim0=1, dim1=2)    # c*h*w     h*w*c
    print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))

# ======================================= example 9 =======================================
# torch.squeeze
# 压缩维度为1的维度
# a=True
a=False
if a:
    t = torch.rand((1, 2, 3, 1))
    t_sq = torch.squeeze(t)
    t_0 = torch.squeeze(t, dim=0)
    t_1 = torch.squeeze(t, dim=1)
    print(t.shape)
    print(t_sq.shape)
    print(t_0.shape)
    print(t_1.shape)

# ======================================= example 8 =======================================
# torch.add
# a=True
a=False
if a:
    t_0 = torch.randn((3, 3))
    t_1 = torch.ones_like(t_0)
    t_add = torch.add(t_0, 10, t_1)

    print("t_0:\n{}\nt_1:\n{}\nt_add_10:\n{}".format(t_0, t_1, t_add))

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值