torch.histc
在跑小样本分割ASGNet的时候,计算mIOU的时候,在intersectionAndUnionGPU(output, target, K, ignore_index=255)
中,使用到了torch.histc计算直方图:
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
# 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
assert (output.dim() in [1, 2, 3])
assert output.shape == target.shape
output = output.view(-1)
target = target.view(-1)
output[target == ignore_index] = ignore_index
intersection = output[output == target]
area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1)
area_output = torch.histc(output, bins=K, min=0, max=K-1)
area_target = torch.histc(target, bins=K, min=0, max=K-1)
area_union = area_output + area_target - area_intersection
return area_intersection, area_union, area_target
torch.histc的API:
torch.histc(input, bins=100, min=0, max=0, *, out=None) → Tensor
计算张量的直方图。
元素被分类为 min 和 max 之间相等宽度的单元格。如果 min 和 max 均为零,则使用数据的最小值和最大值。
小于最小值和高于最大值的元素将被忽略。
Parameters:
input(Tensor)
–输入 张 量 \color{red}{张量} 张量。bins(int)
–直方图箱数min(int)
–范围的下限(包括)max(int)
–范围的上限(包括)
Returns
output
:直方图用 张 量 \color{red}{张量} 张量Tensor输出表示
示例:
>>> torch.histc(torch.tensor([1., 2, 1]), bins=4, min=0, max=3)
输出:tensor([ 0., 2., 1., 0.])
解释:直方图箱为0、1、2、3;
其中
- 0的数量为0
- 1的数量为2
- 2的数量为1
- 3的数量为0
所以输出tensor([0, 2, 1, 0])