import torch
# 假设 x 的形状为 (batch_size, channels, height, width)
x = torch.arange(12, dtype=torch.float32).reshape(1,1,3,4)
print(x)
# 计算最后两个维度上的方差
x_v = torch.var(x, dim=(-2, -1), keepdim=True)
print(x_v) # 输出形状为 (batch_size, channels, 1, 1)
# tensor([[[[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]]]])
# tensor([[[[13.]]]])
计算最后两个维度上的方差 《=====》计算整张图片的方差