最近在跑PSPNet语义分割的网络时,用作者给的源代码运行出现错误。错误为:
/opt/conda/conda-bld/pytorch_1565272271120/work/aten/src/THCUNN/SpatialClassNLLCriterion.cu:103: void cunn_SpatialClassNLLCriterion_updateOutput_kernel(T *, T *, T *, long *, T *, int, int, int, int, int, long) [with T = float, AccumT = float]: block: [1,0,0], thread: [671,0,0] Assertion
t >= 0 && t < n_classes
failed.
主要原因是读取的label标签值超过了要分割的总类别num_classes。
源程序读取的label图是用颜色标记好的,如图:
PSPNet中读取label图像的代码为:
label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) # GRAY 1 channel ndarray with shape H * W
调试后发现读出来的label为一个(468,625)的numpy数组,其中部分值如下所示,可见根本不是[0,num_classes-1]范围里的数值,当然会报错。