维度变换
view reshape
a = torch.rand(4,1,28,28)
print(a.view(-1,28*28).shape)
print(a.reshape(-1,28*28).shape)
unsqueeze ,sequeeze
增加或减少维度
sequeeze只能压缩对应维度大小为1的,不是1时则不操作
a = torch.rand(3,4)
print(a.shape)
print(a.unsqueeze(1).shape)
print(a.unsqueeze(1).squeeze(1).shape)
expand/repeat
expand 对对于维度进行扩充,只有在使用时,才填充值,-1表示对应维度值保持不变
a = torch.rand(4,32,14,14)
b = torch.rand(32)
b = b.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
print(b.shape)
print(b.expand(4,32,14,14).shape)
print(b.expand(4,32,14,-1).shape)
repeat表示对应维度copy的次数
a = torch.rand(4,32,14,14)
b = torch.rand(32)
b = b.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
print(b.shape)
print(b.repeat(4,1,1,1).shape)
拼接
cat
对应维度进行叠加
a = torch.rand(1,4,5)
b = torch.rand(2,4,5)
torch.cat([a,b],dim=0).shape
stack
新的维度叠加
a = torch.rand(4,5)
b = torch.rand(4,5)
torch.stack([a,b],dim=1).shape
运算操作
clamp
用于裁剪tensor
a = torch.randperm(10)
print(a)
a = a.clamp(5)
print(a)
a = a.clamp(5,8)
print(a)
统计属性
norm ,mean,sum,max,min,prod,argmax,argmin\where \gather
prob = torch.randn(4,10)
idx = prob.topk(dim=1,k=3)
label = torch.arange(10)+100
torch.gather(label.expand(4,10),dim=1,index=idx[1])