pytorch中的reshape函数
1.首先什么是reshape函数?
reshape函数是MATLAB中将指定的矩阵变换成特定维数矩阵一种函数,且矩阵中元素个数不变,函数可以重新调整矩阵的行数、列数、维数。函数语法为 B = A.reshape(size) 是指返回一个和A元素相同的n维数组,但是由向量size来决定重构数组维数的大小。下面看在jupyter中运行的代码示例。
int:
import torch
import numpy
a = torch.rand(12) #创建一个0维的张量
print(a)
print(a.shape)
out:
tensor([0.3902, 0.0670, 0.2945, 0.3515, 0.0290, 0.7229, 0.8366, 0.4437, 0.4730,
0.6963, 0.4553, 0.4144])
torch.Size([12])
int:
b = a.reshape(3,4) #通过reshape函数将数组改成了二维的
print(b)
print(b.shape)
out:
tensor([[0.3902, 0.0670, 0.2945, 0.3515],
[0.0290, 0.7229, 0.8366, 0.4437],
[0.4730, 0.6963, 0.4553, 0.4144]])
torch.Size([3, 4])
#如果你愿意也可以将其转化为三维的数组
int:
c = b.reshape(2,2,3) #变成了三维的数组
print(c) #形状的变化是基于数组元素所决定的
print(c.shape)
out:
tensor([[[0.3902, 0.0670, 0.2945],
[0.3515, 0.0290, 0.7229]],
[[0.8366, 0.4437, 0.4730],
[0.6963, 0.4553, 0.4144]]])
torch.Size([2, 2, 3])
"""另外因为reshape函数生成的新数组和原始数组公用一个内存,
所以不管改变的是新的或者是旧的数组,另一个数组也会随之改变"""
int:
a[0] = 100 #可以很明显的观察到b随着a的改变而改变了
print(a)
print(b)
out:
tensor([**1.0000e+02**, 6.6983e-02, 2.9452e-01, 3.5150e-01, 2.8957e-02, 7.2294e-01,
8.3661e-01, 4.4366e-01, 4.7304e-01, 6.9634e-01, 4.5533e-01, 4.1437e-01])
tensor([[**1.0000e+02**, 6.6983e-02, 2.9452e-01, 3.5150e-01],
[2.8957e-02, 7.2294e-01, 8.3661e-01, 4.4366e-01],
[4.7304e-01, 6.9634e-01, 4.5533e-01, 4.1437e-01]])