张量的拼接和分割
- torch.stack
通过传入张量列表和创建的维度,会把列表中的张量沿着指定的维度进行堆砌,列表中的张量大小 必须相同
a = torch.randn(3,4)
b = torch.randn(3,4)
torch.stack([a,b],-1)
- torch.cat
列表中的张量沿着指定的维度进行堆砌,除了指定堆砌的维度参数可以不同外,其他维度参数必须相同
a = torch.randn(3,4)
b = torch.randn(3,5)
torch.cat([a,b],1).shape
- torch.split
用于分割张量,第一个参数为张量,第二个参数为分割的大小,可以是整数,也可以是列表,第三个是分割的维度,从哪个维度进行分割
a = torch.randn(2,3,6)
for i in torch.split(a,[1,2,3],-1):
print(i.shape)
a = torch.randn(2,3,6)
for i in torch.split(a,3,-1):
print(i.shape)
- torch.chunk
和torch.split的功能类似,参数也差不多,唯一区别是第二个参数为整数,表示分割的数量
a = torch.randn(2,3,6)
for i in torch.chunk(a,2,-1):
print(i.shape)
torch.unsqueeze
该函数用于扩增张量维度,在指定的维度增加一个维度,维度参数为1
a = torch.randn(3,4)
a.unsqueeze(-1).shape
torch.squeeze
该函数用于压缩维度,凡是维度为1的都会被压缩掉
a = torch.randn(3,4,1,1)
print(a.shape)
a.squeeze().shape