一、.view()函数简介
- PyTorch中的
.view()
函数是一个用于改变张量形状的方法。它类似于NumPy中的.reshape()
函数,可以通过重新排列张量的维度来改变其形状,而不改变张量的数据。 - 在深度学习中,
.view()
函数常用于调整输入数据的形状以适应模型的输入要求,或者在网络层之间传递数据时进行形状的转换。 - .view()函数的语法如下,
shape
是一个整数元组,用于指定新的张量形状,新形状的元素个数必须与原形状的元素个数相同。函数返回一个具有指定形状的新张量,但与原始张量共享数据存储,因此它们指向相同的内存区域。:
New_Tensor = Tensor.view(*shape)
二、.view()函数的使用方法
1. 改变形状
直接使用.view(new_shape)
将张量修改为指定形状,保持新张量和原张量在数量上一致即可。
import torch
# 创建一个形状为(2, 3, 4)的张量
x = torch.randn(2, 3, 4)
print("原始张量形状:", x.shape)
# 使用.view()改变张量形状为(6, 4)
y = x.view(6, 4)
print("改变形状后的张量形状:", y.shape)
运行结果:
2. 使用-1展平张量
在PyTorch中,.view()函数可以接受一个特殊的参数 -1,用于自动计算张量在该维度上的大小。将某个维度的大小设置为 -1,可以使得该维度的大小根据其他维度的大小自动确定,以保持张量的元素总数不变。
import torch
# 创建一个形状为(2, 3, 4)的张量
x = torch.randn(2, 3, 4)
print("原始张量形状:", x.shape)
# 使用.view()展平张量为形状为(12,2)
y1 = x.view(-1,2)
print("展平后的张量y1形状:", y1.shape)
# 使用.view()展平张量为形状为(6,2,2)
y2 = x.view(-1,2,2)
print("展平后的张量y2形状:", y2.shape)
运行结果:
3. 调整维度
import torch
# 创建一个形状为(2, 3, 4)的张量
x = torch.randn(2, 3, 4)
print("原始张量形状:", x.shape)
# 使用.view()调整张量的维度为(2, 4, 3)
y = x.view(2, 4, 3)
print("调整维度后的张量形状:", y.shape)
运行结果: