# Flatten a tensor into a one-dimensional vector with torch.flatten.
# torch.flatten(input, start_dim=0, end_dim=-1) merges dims start_dim..end_dim into one.
import torch

x = torch.randn(2, 3, 2)
x2 = torch.flatten(x, 0)  # dims 0..2 merged -> shape (12,)
x3 = torch.flatten(x, 1)  # dims 1..2 merged -> shape (2, 6)
x4 = torch.flatten(x, 2)  # only dim 2 "merged" -> shape unchanged, (2, 3, 2)
import torch

x = torch.randn(2, 3, 2)
print(x)
# Example output (shown without the tensor(...) wrapper; values differ on
# every run because torch.randn is unseeded):
# [[[-0.5829,  0.8214],
#   [ 0.6218,  0.3298],
#   [ 0.0222, -0.8473]],
#  [[ 0.1044, -1.8784],
#   [ 1.2323,  2.6551],
#   [ 0.0382,  0.6649]]]

# start_dim=0 is the default, so this is equivalent to torch.flatten(x):
# every dimension is merged and the result has shape (12,).
x2 = torch.flatten(x, 0)
print(x2)
# Example output:
# [-0.5829,  0.8214,  0.6218,  0.3298,  0.0222, -0.8473,
#   0.1044, -1.8784,  1.2323,  2.6551,  0.0382,  0.6649]
import torch

x = torch.randn(2, 3, 2)
print(x)
# Example output (shown without the tensor(...) wrapper; values differ on
# every run because torch.randn is unseeded):
# [[[-0.5829,  0.8214],
#   [ 0.6218,  0.3298],
#   [ 0.0222, -0.8473]],
#  [[ 0.1044, -1.8784],
#   [ 1.2323,  2.6551],
#   [ 0.0382,  0.6649]]]

# start_dim=1 keeps dim 0 and merges dims 1..2 (3 * 2 = 6), giving shape (2, 6).
x3 = torch.flatten(x, 1)
print(x3)
# Example output:
# [[-0.5829,  0.8214,  0.6218,  0.3298,  0.0222, -0.8473],
#  [ 0.1044, -1.8784,  1.2323,  2.6551,  0.0382,  0.6649]]
import torch

x = torch.randn(2, 3, 2)
print(x)
# Example output (shown without the tensor(...) wrapper; values differ on
# every run because torch.randn is unseeded):
# [[[-0.5829,  0.8214],
#   [ 0.6218,  0.3298],
#   [ 0.0222, -0.8473]],
#  [[ 0.1044, -1.8784],
#   [ 1.2323,  2.6551],
#   [ 0.0382,  0.6649]]]

# For a 3-D tensor, start_dim=2 (with the default end_dim=-1) "merges" only the
# last dimension with itself, so the shape stays (2, 3, 2) and the printed
# tensor is identical to x.
x4 = torch.flatten(x, 2)
print(x4)
# Example output:
# [[[-0.5829,  0.8214],
#   [ 0.6218,  0.3298],
#   [ 0.0222, -0.8473]],
#  [[ 0.1044, -1.8784],
#   [ 1.2323,  2.6551],
#   [ 0.0382,  0.6649]]]