squeeze
import pytorch
torch.squeeze(input, dim=None)
squeeze是一个用来压缩维度的函数,input为指定的输入的张量,dim这个参数为输入参量中需要压缩的维度。举个栗子
x=torch.ones(1,2,5)
print(x)
'''
输出:
tensor([[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]])
'''
列出维度为(1,2,5)的三维张量。这里解释一下这个三维如何看,首先三维对应着三个中括号。拆掉第一层(最外面一层)中括号,会发现里面只有一个数据(一个中括号)。
拆掉第二层中括号,会发现里面有两个数据(两个中括号)。
拆掉第三层其中一个中括号,会发现有五个数据(1.)
正好对应(1,2,5)维。
print(x.squeeze(0))
'''
对上述张量操作
输出:
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
'''
print(x.squeeze(1))
'''
对上述张量操作
输出:
tensor([[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]])
'''
squeeze(n)本质上是将第n维的维数由1转化成0,若第n维不为1,则不进行变化。
第一个例子中。对(1,2,5)维张量,squeeze(0)将第一个维度进行压缩。由于第一个维度为1。则输出转化为(2,5)维的张量。
第二个例子中。对(1,2,5)维张量,squeeze(1)将第二个维度进行压缩。由于第二个维度是2,不为1,不进行变化。输出不变。
unsqueeze
print(x.unsqueeze(2))
'''
tensor([[[[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.]]]])
'''
print(x.unsqueeze(3))
'''
tensor([[[[1.],
[1.],
[1.],
[1.],
[1.]],
[[1.],
[1.],
[1.],
[1.],
[1.]]]])
'''
unsqueeze(n)表示的是在第n维度增加一个大小为1的维度
第一个例子中转化成(1,2,1,5)维
第二个例子转化成(1,2,5,1)维
菜鸟刚学,仅记录学习心得,欢迎交流~