torch.matmul()
torch.matmul()
是 PyTorch 中的矩阵乘法函数,用于进行两个矩阵(或张量)的矩阵乘法。它的使用方式如下:
output = torch.matmul(input1, input2)
其中,input1
和 input2
分别表示两个矩阵或张量,output
表示它们的矩阵乘积结果。需要注意的是,两个矩阵的维度必须满足矩阵乘法的要求,即第一个矩阵的列数必须等于第二个矩阵的行数。举个例子,如果 input1
的形状是 (m,n)
,input2
的形状是 (n,p)
,则它们的矩阵乘积结果的形状是 (m,p)
。
除了处理矩阵乘法,torch.matmul()
还可以用于进行批量矩阵乘法等复杂计算。例如,可以通过设置维度参数来进行矩阵乘法的批处理:
batch_input1 = torch.randn(3, 4, 5)
batch_input2 = torch.randn(3, 5, 6)
batch_output = torch.matmul(batch_input1, batch_input2)
其中,batch_input1 和 batch_input2 表示两个形状分别为 (3,4,5) 和 (3,5,6) 的矩阵组成的批次数据,batch_output 表示它们的矩阵乘积结果,形状为 (3,4,6)。
需要注意的是,如果 torch.matmul()
的输入张量维度大于 2,则它会自动对最后两个维度进行矩阵乘法并保留其它维度。如果需要在其它维度上进行批处理,可以通过设置维度参数来指定。此外,torch.matmul() 还支持不同维度的广播和转置操作等高级功能,具体使用方法可以参考 PyTorch 官方文档中的说明。
torch.nn.functional.interpolate()
torch.nn.functional.interpolate()
是 PyTorch 中的一个函数,用于对输入张量进行线性插值操作。它的作用类似于 tf.image.resize_images() 函数。
函数的调用形式为:
torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None)
其中,各参数的含义如下:
函数的返回值为一个新的张量,其形状由输入张量和参数 size
、scale_factor
及 align_corners
决定。常见的用法有:
import torch
# 创建形状为 (1,3,4,4) 的张量 x
x = torch.tensor([
[
[[1,2,3,4],
[2,3,4,5],
[3,4,5,6],
[4,5,6,7]],
[[2,3,4,5],
[3,4,5,6],
[4,5,6,7],
[5,6,7,8]],
[[3,4,5,6],
[4,5,6,7],
[5,6,7,8],
[6,7,8,9]]
]
])
#将 x 变形为形状为 (1,3,5,5) 的张量 y
y = torch.nn.functional.interpolate(x, size=(5,5), mode='nearest')
print(y)
#将 x 缩放为形状为 (1,3,8,8) 的张量 z
z = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
print(z)
其中,torch.nn.functional.interpolate(x, size=(5,5), mode='nearest')
将输入张量 x 缩放到大小为 (5,5) 的输出张量,插值方式为最近邻插值;torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
将输入张量 x 缩放到大小为原来的两倍,插值方式为双线性插值,不校正角点的位置。