def test_view():
    """Demonstrate reshaping a 4x4 tensor with ``Tensor.view``.

    Builds a random 4x4 tensor ``x`` and derives two views from it:
    ``x.view(16)`` flattens all 16 elements into one dimension, and
    ``x.view(-1, 8)`` reshapes them into rows of 8. The ``-1`` lets PyTorch
    infer that dimension: 16 / 8 = 2. Each tensor and its shape is printed.

    The resulting shapes are then asserted, so the function checks something
    when a test runner collects it. The check also confirms that the views
    share ``x``'s storage rather than copying it.
    """
    # Imported locally so the function runs even without a module-level import.
    import torch

    x = torch.randn(4, 4)
    print(x.size())

    # Flatten all 4 * 4 = 16 elements into a 1-D tensor.
    y = x.view(16)
    print(f"y: {y}")
    print(f"y.shape: {y.shape}")

    # -1 asks view() to infer this dimension from the total element count.
    z = x.view(-1, 8)
    print(f"z: {z}")
    print(f"z.shape: {z.shape}")

    assert y.shape == torch.Size([16])
    assert z.shape == torch.Size([2, 8])
    # view() returns a tensor backed by the same memory as x, not a copy.
    assert y.data_ptr() == x.data_ptr()
    assert z.data_ptr() == x.data_ptr()
打印结果（x 由 torch.randn 随机生成，每次运行的具体数值都会不同）：
torch.Size([4, 4])
y: tensor([-1.0162, -0.2829, -1.5826, 0.0502, -0.2395, 2.3658, 0.3365, 2.3561,
0.6161, 1.6659, 0.2101, -0.7243, -0.9361, 0.1402, -0.5288, 0.3216])
y.shape: torch.Size([16])
z: tensor([[-1.0162, -0.2829, -1.5826, 0.0502, -0.2395, 2.3658, 0.3365, 2.3561],
[ 0.6161, 1.6659, 0.2101, -0.7243, -0.9361, 0.1402, -0.5288, 0.3216]])
z.shape: torch.Size([2, 8])