view()函数有些像numpy中的reshape函数,是用来的tensor(张量)形式的数据进行围堵重构的,直接用程序来说明用法
-
生成测试数据
import torch torch.manual_seed(0) # 用来控制内部的随机机制使每次得到的随机数一样 tt = torch.rand(3,4) # tensor([[0.4963, 0.7682, 0.0885, 0.1320], # [0.3074, 0.6341, 0.4901, 0.8964], # [0.4556, 0.6323, 0.3489, 0.4017]])
-
实现方法
print(tt.view((2,-1))) # tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341], # [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017]]) print(tt.view(2,-1)) # tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341], # [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017]])
其中
-1
表示不对这一维度的数量做具体限定,算出来是多少就是多少,注意在所有维度中只能有一个维度指定为-1view()函数可以接收两种形式的输入,一种是给出一个‘形状’
(2,-1)
,一种是一次列举各维度的维度值2,-1
-
可以用reshape()函数实现
pytorch提供了很好的numpy兼容性,很多numpy下的方法在pytorch中也可以使用,用reshape()函数实现方式和实现结果与view相同
print(tt.reshape((2,-1))) # tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341], # [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017]]) print(tt.reshape(2,-1)) # tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341], # [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017]])
-
再多做一点儿,三维
(2,2,-1)
print(tt.reshape(2,2,-1)) # tensor([[[0.4963, 0.7682, 0.0885], # [0.1320, 0.3074, 0.6341]], # # [[0.4901, 0.8964, 0.4556], # [0.6323, 0.3489, 0.4017]]])
要在2维空间里print3维的数据,大概就是这样了吧