torch.reshape的使用

小白记录:

torch.reshape() 是 PyTorch 中用于改变张量(Tensor)形状的函数。它的参数含义如下:

函数签名

python

torch.reshape(input, shape) → Tensor

参数详解

  1. input (Tensor)

    • 含义:需要被改变形状的输入张量。

    • 要求:必须是一个有效的 PyTorch 张量。

  2. shape (tuple of ints)

    • 含义:定义输出张量的新形状,是一个由整数组成的元组。

    • 关键规则

      • 总元素数必须一致:新形状的元素总数(各维度大小的乘积)必须与原张量一致。
        例如:原张量形状 (4, 3)(12个元素)可改为 (6, 2) 或 (12,),但不能改为 (5, 2)(10个元素)。

      • -1 的特殊含义
        元组中最多只能有一个 -1,表示该维度大小由 PyTorch 自动计算(根据总元素数和其他维度推断)。
        例如:原张量有 12 个元素:

        • shape=(4, -1) → 自动计算为 (4, 3)

        • shape=(-1, 6) → 自动计算为 (2, 6)

        • shape=(-1,) → 输出一维张量 (12,)


示例说明

基础用法

python

import torch

x = torch.arange(12)  # [0, 1, 2, ..., 11]
print(x.shape)        # torch.Size([12])

# 改为 3x4 矩阵
y = torch.reshape(x, (3, 4))
print(y)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# 使用 -1 自动计算维度
z = torch.reshape(x, (2, 2, -1))  # 2x2x3
print(z.shape)  # torch.Size([2, 2, 3])
错误用法

python

# 错误:元素总数不匹配 (12 != 8)
torch.reshape(x, (2, 4))   # ❌

# 错误:多个 -1
torch.reshape(x, (-1, -1)) # ❌

与 view() 的区别

  • reshape()
    尽可能返回原数据的视图(view)(不复制数据),如果原始内存连续则共享内存;否则自动复制数据返回新张量。

  • view()
    严格要求原张量内存连续,否则报错。若内存不连续,需先 .contiguous()

推荐:优先用 reshape(),它更灵活且能处理非连续内存。


总结

参数含义规则
input输入张量必须是 PyTorch 张量
shape目标形状(整数元组)1. 元素总数必须与原张量一致
2. 最多一个 -1(自动计算该维度大小)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值