1、torch.stack
torch.stack((tensor1,tensor2),dim=0)
1、把tensor1和tensor2叠在一起生成一个新张量(如tensor1维度为2,则新张量维度为3)。
2、tensor1和tensor2必须大小完全一致。
3、dim决定新生成维度的位置,如:dim=0,新生成的张量为[新维度,3,3]。dim=1,新生成的张量为[3,新维度,3]
2、torch.cat
torch.cat((tensor1,tensor2),dim)
1、把tensor1和tensor2按照dim给定的维度拼接在一起。
2、和torch.stack不同的是拼接后的张量维度不会增加。
torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)
(1) 函数的作用是对输入的sequences进行补长填充。
(2) padding_value是填充的值,batch_first=True默认在第0维度进行填充。
(3) 例子:
a=tensor([ 0.7160, 1.2006, -1.8447])
b=tensor([ 0.3941, 0.3839, 0.1166, -0.7221, 1.8661])
c=tensor([-0.6521, 0.0681, 0.6626, -0.3679, -0.6042, 1.6951, 0.4937])
>>>pad_sequence([a,b,c],batch_first=True)
tensor([[ 0.7160, 1.2006, -1.8447, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.3941, 0.3839, 0.1166, -0.7221, 1.8661, 0.0000, 0.0000],
[-0.6521, 0.0681, 0.6626, -0.3679, -0.6042, 1.6951, 0.4937]])
>>>pad_sequence([a,b,c],batch_first=False)
tensor([[ 0.7160, 0.3941, -0.6521],
[ 1.2006, 0.3839, 0.0681],
[-1.8447, 0.1166, 0.6626],
[ 0.0000, -0.7221, -0.3679],
[ 0.0000, 1.8661, -0.6042],
[ 0.0000, 0.0000, 1.6951],
[ 0.0000, 0.0000, 0.4937]])
(1) torch.squeeze(tensor,[dim])
去掉tensor中维数dim为1的维度,如果没有dim则默认去掉tensor中所有为1的维度。
(2) torch.unsqueeze(tensor,[dim])
在指定维度dim添加维数为一的维度。
其中pytorch的维度顺序定义方式如下所示:
a = torch.tensor([ [ [1],[2],[3] ],[ [4],[5],[6] ] ])
↑ ↑ ↑
0 1 2
>>> torch.squeeze(a, 2)
tensor([[1, 2, 3],
[4, 5, 6]])
pytorch中的张量乘法主要有以下几个:
(1) 二维矩阵乘法 torch.mm
(2) 三维batch矩阵乘法 torch.bmm
(3) 逐元素乘法 torch.mul
(4) 逐元素乘法 *
原文地址:https://zhuanlan.zhihu.com/p/100069938
torch.mm
a = torch.tensor([[1,2,3,4]])
b = torch.tensor([1],[2],[3],[4])
>>> torch.mm(b,a)
tensor([[ 1, 2, 3, 4],
[ 2, 4, 6, 8],
[ 3, 6, 9, 12],
[ 4, 8, 12, 16]])
torch.bmm
a = torch.tensor([[[1,2,3],[1,2,3]],[[3,4,5],[3,4,5]]]) #2×2×3
b = torch.tensor([[[1,2],[1,2],[1,2]],[[3,4],[3,4],[3,4]]]) #2×3×2
>>> torch.bmm(a,b)
tensor([[[ 6, 12],
[ 6, 12]],
[[36, 48],
[36, 48]]]) #2×2×2
6、torch.t() 或 张量.t()
torch.t() 或 t()用于对2维的张量进行转置
a = torch.tensor([[1,2,3,4]])
>>> a.t()
>>> torch.t(a)
tensor([[1],
[2],
[3],
[4]])