torch.norm()
是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数。具体而言,当给定一个输入张量 x
和一个整数 p
时,torch.norm(x, p)
将返回输入张量 x
沿着最后一个维度(默认为所有维度)上所有元素的 p
范数。
除了使用标量 p
之外,torch.norm()
还接受以下参数:
dim
:指定沿哪个轴计算范数,默认对所有维度计算。keepdim
:如果设置为 True,则输出张量维度与输入张量相同,其中指定轴尺寸为 1;否则,将从输出张量中删除指定轴。out
:可选输出张量结果。
以下是一个示例:
import torch
# 创建一个形状为 (3, 4) 的二维张量
x = torch.tensor([[2., 3., 5., -1.], [-1., -2., 1., 4.], [0.5, -2., 7., 2.]])
# 计算所有元素的 L2 范数
l2_norm_all = torch.norm(x)
print("L2 norm of all elements:", l2_norm_all.item())
# 计算第一个维度上每个子数组的 L2 范数(即按行计算)
l2_norm_rows = torch.norm(x, dim=1)
print("L2 norm of rows:", l2_norm_rows.numpy())
# 计算最后一个维度上每个子数组的 L1 范数(即按列计算)
l1_norm_cols = torch.norm(x, p=1, dim=-1)
print("L1 norm of columns:", l1_norm_cols.numpy())
在这个示例中,我们首先创建了一个形状为 (3, 4)
的二维张量 x
,然后使用 torch.norm()
函数计算了不同维度上的范数。注意,我们将 dim
参数设置为 1 和 -1 以分别按行和按列计算范数,并将 p
参数设置为 1 来计算 L1 范数。在输出结果中,我们使用 .item()
将标量张量转换回 Python 中的浮点数,用 .numpy()
将张量转换回 NumPy 数组。