小白记录:
torch.reshape()
是 PyTorch 中用于改变张量(Tensor)形状的函数。它的参数含义如下:
函数签名
python
torch.reshape(input, shape) → Tensor
参数详解
-
input
(Tensor)-
含义:需要被改变形状的输入张量。
-
要求:必须是一个有效的 PyTorch 张量。
-
-
shape
(tuple of ints)-
含义:定义输出张量的新形状,是一个由整数组成的元组。
-
关键规则:
-
总元素数必须一致:新形状的元素总数(各维度大小的乘积)必须与原张量一致。
例如:原张量形状(4, 3)
(12个元素)可改为(6, 2)
或(12,)
,但不能改为(5, 2)
(10个元素)。 -
-1
的特殊含义:
元组中最多只能有一个-1
,表示该维度大小由 PyTorch 自动计算(根据总元素数和其他维度推断)。
例如:原张量有 12 个元素:-
shape=(4, -1)
→ 自动计算为(4, 3)
-
shape=(-1, 6)
→ 自动计算为(2, 6)
-
shape=(-1,)
→ 输出一维张量(12,)
-
-
-
示例说明
基础用法
python
import torch x = torch.arange(12) # [0, 1, 2, ..., 11] print(x.shape) # torch.Size([12]) # 改为 3x4 矩阵 y = torch.reshape(x, (3, 4)) print(y) # tensor([[ 0, 1, 2, 3], # [ 4, 5, 6, 7], # [ 8, 9, 10, 11]]) # 使用 -1 自动计算维度 z = torch.reshape(x, (2, 2, -1)) # 2x2x3 print(z.shape) # torch.Size([2, 2, 3])
错误用法
python
# 错误:元素总数不匹配 (12 != 8) torch.reshape(x, (2, 4)) # ❌ # 错误:多个 -1 torch.reshape(x, (-1, -1)) # ❌
与 view()
的区别
-
reshape()
:
尽可能返回原数据的视图(view)(不复制数据),如果原始内存连续则共享内存;否则自动复制数据返回新张量。 -
view()
:
严格要求原张量内存连续,否则报错。若内存不连续,需先.contiguous()
。
推荐:优先用
reshape()
,它更灵活且能处理非连续内存。
总结
参数 | 含义 | 规则 |
---|---|---|
input | 输入张量 | 必须是 PyTorch 张量 |
shape | 目标形状(整数元组) | 1. 元素总数必须与原张量一致 2. 最多一个 -1 (自动计算该维度大小) |