PyTorch库学习之torch.mean函数
一、简介
torch.mean
是 PyTorch 库中的一个函数,用于计算张量的均值。它可以沿着指定的维度或者整个张量计算均值,是数据分析和机器学习中常用的操作之一。
二、语法和参数
语法:
torch.mean(input, dim=None, keepdim=False, *, out=None)
参数:
input
(torch.Tensor): 输入张量。dim
(int, 可选): 沿着哪个维度计算均值。如果为None
,则计算整个张量的均值。keepdim
(bool, 可选): 如果为True
,则输出张量与输入张量具有相同的维度,但是指定维度的大小为 1。out
(Tensor, 可选): 输出张量,用于存储计算结果。
返回值:
- 返回一个新的张量,包含计算得到的均值。
三、实例
3.1 计算一维张量的全部元素的均值
import torch
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
result = torch.mean(x)
print(result)
输出:
tensor(3.)
3.2 计算二维张量沿特定维度的均值
import torch
y = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
result = torch.mean(y, dim=1)
print(result)
输出:
tensor([2., 5.])
3.3 计算二维张量均值并保持维度
import torch
y = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
result = torch.mean(y, dim=0, keepdim=False)
result_keepdim = torch.mean(y, dim=0, keepdim=True)
print(result.shape)
print(result_keepdim.shape)
输出:
torch.Size([3])
torch.Size([1, 3])
四、注意事项
- 当
dim
参数为None
时,torch.mean
会计算所有元素的均值。 keepdim
参数在处理多维数据时很有用,特别是需要与原始数据维度对齐的操作。- 如果指定了
out
参数,计算结果将直接写入该张量中,而不是创建新的张量。 - 确保输入张量
input
不是零维的,因为零维张量没有元素可以计算均值。