target.reshape(-1)
的作用是将多维张量变成一维张量,特别是在你有一个二维或更高维度的张量时,reshape(-1)
可以将它压平为一维。我们通过以下例子展示 reshape(-1)
的效果:
import torch
# 假设一个二维张量 target
target = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("原始 target 张量:")
print(target)
print("target 的形状:", target.shape)
# 使用 reshape(-1) 将二维张量变为一维
reshaped_target = target.reshape(-1)
print("\n使用 reshape(-1) 后的 target 张量:")
print(reshaped_target)
print("reshaped_target 的形状:", reshaped_target.shape)
输出结果:
原始 target 张量:
tensor([[1, 2, 3],
[4, 5, 6]])
target 的形状: torch.Size([2, 3])
使用 reshape(-1) 后的 target 张量:
tensor([1, 2, 3, 4, 5, 6])
reshaped_target 的形状: torch.Size([6])
解释:
- 原始的
target
是一个 2x3 的二维张量,形状为[2, 3]
。 - 通过
target.reshape(-1)
,这个二维张量被压平为一维,变成了一个长度为 6 的一维张量,形状为[6]
。
reshape(-1)
通常用于将多维数据转换为一维数据,在某些操作中(如创建掩码、计算损失等)需要使用这种压平的形式。