torch.bmm
torch.bmm(batch1, batch2, out=None) → Tensor
对存储在两个批batch1和batch2内的矩阵进行批矩阵乘操作。batch1和 batch2都为包含相同数量矩阵的3维张量。 如果batch1是形为b×n×m的张量,batch1是形为b×m×p的张量,则out和mat的形状都是n×p(n×m m×p ->n×p),即 res=(beta∗M)+(alpha∗sum(batch1i@batch2i,i=0,b))
对类型为 FloatTensor 或 DoubleTensor 的输入,alphaand beta必须为实数,否则两个参数须为整数。
参数:
batch1 (Tensor) – 第一批相乘矩阵
batch2 (Tensor) – 第二批相乘矩阵
out (Tensor, optional) – 输出张量
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
res = torch.bmm(batch1, batch2)
>>> res.size()
torch.Size([10, 3, 5])
torch.matmul(mat1, mat2, out=None) → Tensor
torch.mm(mat1, mat2, out=None) → Tensor
对矩阵mat1和mat2进行相乘。 如果mat1 是一个n×m张量,mat2 是一个 m×p 张量,将会输出一个 n×p 张量out。
参数 :
mat1 (Tensor) – 第一个相乘矩阵
mat2 (Tensor) – 第二个相乘矩阵
out (Tensor, optional) – 输出张量
torch.Tensor.expand(*sizes)
返回tensor的一个新视图,单个维度扩大为更大的尺寸。
tensor也可以扩大为更高维,新增加的维度将附在前面。
扩大tensor不需要分配新内存,只是仅仅新建一个tensor的视图,其中通过将stride设为0,一维将会扩展位更高维。任何一个一维的在不分配新内存情况下可扩展为任意的数值。
参数: - sizes(torch.Size or int…)-需要扩展的大小
例:
x = torch.Tensor([[1], [2], [3]])
x.size()
**torch.Size([3, 1])x.expand(3, 4)**
1 1 1 1
2 2 2 2
3 3 3 3
[torch.FloatTensor of size 3x4]
torch.cat
torch.cat(inputs, dimension=0) → Tensor
在给定维度上对输入的张量序列seq 进行连接操作。
torch.cat()可以看做 torch.split() 和 torch.chunk()的反操作。 cat() 函数可以通过下面例子更好的理解。
参数:
inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
dimension (int, optional) – 沿着此维连接张量序列。
例子:
x = torch.randn(2, 3)
x
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]
torch.cat((x, x, x), 0)
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 6x3]
torch.cat((x, x, x), 1)
0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x9]
torch.split
torch.split(tensor, split_size, dim=0)
将输入张量分割成相等形状的chunks(如果可分)。 如果沿指定维的张量形状大小不能被split_size 整分, 则最后一个分块会小于其它分块。
参数:
tensor (Tensor) – 待分割张量
split_size (int) – 单个分块的形状大小
dim (int) – 沿着此维进行分割