代码来自于CVPR2018的一篇文章Context Encoding for Semantic Segmentation。
github工程地址为:https://github.com/zhanghang1989/PyTorch-Encoding
很棒的工作。
我是新入门,对代码解读如下,有不对的地方请高手们批评指正。
我的标签图像序号是从0开始,依次编号,0代表背景类
def batch_intersection_union(predict, target, nclass):
"""Batch Intersection of Union
Args:
predict: input 4D tensor 具体地为B(batch大小)*C(通道)*H*W
target: label 3D tensor 具体地为B(batch大小)*H*W
nclass: number of categories (int)
"""
#在通道维上取最大值,注意predict为第二个返回值,因此是索引,从0开始
#此时predict和target一样,维度均为B*H*W,且值均为0,1,2.........
_, predict = torch.max(predict, 1)
mini = 1
maxi = nclass
nbins = nclass
#将predict和target放入cpu并转换为numpy数组,同时+1
#此时predict和target的值为1,2,......
predict = predict.cpu().nu