import torch
# Demo: Tensor.view reinterprets the same data under a new shape.
a = torch.rand(4, 1, 28, 28)
print("a.shape:\t", a.shape)

# view() only reinterprets the existing data, so the element count must be
# preserved: prod(a.size()) == prod(new_size). Here 4*1*28*28 == 4*784.
b = a.view(4, 28 * 28)
print("b:\t", b)
c = a.view(4, 28 * 28).shape
print("c:\t", c)
d = a.view(4, 784).shape
print("d:\t", d)
# Merge the leading dimensions instead: (4*1*28, 28) == (112, 28).
e = a.view(4 * 28, 28).shape
print("e:\t", e)
# Fold the size-1 dimension away: (4*1, 28, 28).
f = a.view(4 * 1, 28, 28).shape
print("f:\t", f)
g = a.view(4, 784)
print("g:\t", g)
import torch
# Demo: Tensor.unsqueeze(dim) inserts a new size-1 dimension at `dim`.
# For a 4-D input the valid range is [-5, 4]; a negative dim maps to
# dim + input.dim() + 1 (i.e. it counts from the end of the result).
a = torch.rand(4, 1, 28, 28)
b = a.shape  # torch.Size([4, 1, 28, 28]), derived rather than hard-coded
# Insert a new dimension at the very front.
c = a.unsqueeze(0).shape
print("c:\t", c)
# -1 and 4 both append a new trailing dimension.
d = a.unsqueeze(-1).shape
print("d:\t", d)
e = a.unsqueeze(4).shape
print("e:\t", e)
# -4 maps to index 1 and -5 maps to index 0.
f = a.unsqueeze(-4).shape
print("f:\t", f)
g = a.unsqueeze(-5).shape
print("g:\t", g)

h = torch.tensor([1.2, 2.3])
# (2,) -> (2, 1)
I = h.unsqueeze(-1)
print("I:\t", I)
# (2,) -> (1, 2)
J = h.unsqueeze(0)
print("J:\t", J)
import torch
# Demo: Tensor.squeeze removes size-1 dimensions.
a = torch.rand(32)
# (32,) -> (32, 1) -> (32, 1, 1) -> (1, 32, 1, 1)
b = a.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print("b.shape:\t", b.shape)

# squeeze() with no argument removes every size-1 dimension.
c = b.squeeze().shape
print("c:\t", c)
# squeeze(dim) removes only that dimension, and only if its size is 1.
d = b.squeeze(0).shape
print("d:\t", d)
e = b.squeeze(-1).shape
print("e:\t", e)
# Dimension 1 has size 32 (not 1), so the shape is left unchanged.
f = b.squeeze(1).shape
print("f:\t", f)
# -4 refers to dimension 0 of this 4-D tensor.
g = b.squeeze(-4).shape
print("g:\t", g)
import torch
# Demo: Tensor.expand broadcasts size-1 dimensions to a larger size; it returns
# a view and does not copy the data.
a = torch.rand(32)
b = a.unsqueeze(1).unsqueeze(2).unsqueeze(0)  # (1, 32, 1, 1)
print("b.shape:\t", b.shape)
c = b.expand(4, 32, 14, 14).shape
print("c:\t", c)
# -1 means "keep this dimension's current size".
d = b.expand(-1, 32, -1, -1).shape
print("d:\t", d)
# -1 is the only meaningful negative size. A size such as -4 is not a valid
# target and, depending on the PyTorch version, either raises or yields a
# meaningless negative shape. Expand the last singleton dimension to 4 instead.
e = b.expand(-1, 32, -1, 4).shape
print("e:\t", e)
import torch
# Demo: Tensor.repeat tiles the tensor's data along each dimension.
a = torch.rand(32)
b = a.unsqueeze(1).unsqueeze(2).unsqueeze(0)  # (1, 32, 1, 1)
print("b.shape:\t", b.shape)
# Each argument is the number of times to repeat that dimension (not the
# target size): new_size[i] == b.size(i) * repeats[i].
c = b.repeat(4, 32, 1, 1).shape  # (4, 1024, 1, 1)
print("c:\t", c)
d = b.repeat(4, 1, 1, 1).shape  # (4, 32, 1, 1)
print("d:\t", d)
e = b.repeat(4, 1, 32, 32).shape  # (4, 32, 32, 32)
print("e:\t", e)
import torch
# Demo: transpose() swaps exactly two dimensions, while permute() reorders all
# dimensions in a single call.
a = torch.rand((4, 3, 28, 28))
b = a.transpose(dim0=1, dim1=3).size()
print("b:\t", b)

c = torch.rand((4, 3, 28, 32))
# Swap dims 1 and 3: (4, 3, 28, 32) -> (4, 32, 28, 3).
d = c.transpose(dim0=1, dim1=3).size()
print("d:\t", d)
# A second swap (dims 1 and 2) gives (4, 28, 32, 3) ...
e = c.transpose(dim0=1, dim1=3).transpose(dim0=1, dim1=2).size()
print("e:\t", e)
# ... the same layout that permute(0, 2, 3, 1) produces directly.
f = c.permute((0, 2, 3, 1)).size()
print("f:\t", f)