Errors hit while training a semantic segmentation model

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1,3,840,840]

        The error above appeared while training a semantic segmentation model. After digging around, I found only one explanation that seemed reliable.

        The rough cause: the label (target) image must be single-channel (a grayscale class-index map), but the labels were still being fed in as 3-channel RGB images that had never been converted. The links to that explanation are below.

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 3, 96, 128] - vision - PyTorch Forums
https://discuss.pytorch.org/t/runtimeerror-1only-batches-of-spatial-targets-supported-3d-tensors-but-got-targets-of-size-1-3-96-128/95030

Training Semantic Segmentation - #3 by WeiQin_Chuah - vision - PyTorch Forums
https://discuss.pytorch.org/t/training-semantic-segmentation/49275/3
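In short, for segmentation nn.CrossEntropyLoss expects the target to be a LongTensor of shape [N, H, W] holding class indices in [0, n_classes), not an RGB mask of shape [N, 3, H, W]. Below is a minimal sketch of one way to do that conversion; the COLOR2CLASS palette and its four colors are made up for illustration and must be replaced with whatever encoding your own masks use.

import numpy as np
import torch

COLOR2CLASS = {          # hypothetical palette: RGB color -> class index
    (0, 0, 0): 0,        # background
    (128, 0, 0): 1,
    (0, 128, 0): 2,
    (0, 0, 128): 3,
}

def rgb_mask_to_index(mask_rgb):
    """mask_rgb: uint8 array of shape (H, W, 3) -> LongTensor of shape (H, W)."""
    h, w, _ = mask_rgb.shape
    index = np.zeros((h, w), dtype=np.int64)
    for color, cls in COLOR2CLASS.items():
        index[np.all(mask_rgb == color, axis=-1)] = cls
    return torch.from_numpy(index)

# Target for CrossEntropyLoss: shape [1, H, W], dtype long, values in [0, n_classes)
# target = rgb_mask_to_index(mask_rgb).unsqueeze(0)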

/tmp/pip-req-build-xlj_h8ax/aten/src/THCUNN/SpatialClassNLLCriterion.cu:106: cunn_SpatialClassNLLCriterion_updateOutput_kernel: block: [4,0,0], thread: [417,0,0] Assertion `t >= 0 && t < n_classes` failed.

RuntimeError: CUDA error: device-side assert triggered

        The error above occurred because training used 4 classes, but the label data had been resized with cv2.resize using the default cv2.INTER_LINEAR interpolation, which blended in values that do not belong to any class. Switching to nearest-neighbour interpolation, cv2.INTER_NEAREST, fixes it. The call looks like this:

cv2.resize(image, (tw, th), interpolation=cv2.INTER_NEAREST)

# Do not shorten this to cv2.resize(image, (tw, th), cv2.INTER_NEAREST): the third
# positional argument of cv2.resize is dst, not interpolation, so the flag is not
# applied and the default cv2.INTER_LINEAR is still used
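To see why the default interpolation breaks the labels, here is a toy check (a minimal sketch, not the data from this post): bilinear resizing invents intermediate values that were never valid labels, while nearest-neighbour keeps the label set unchanged. In this toy mask the invented values happen to stay inside [0, n_classes), but with real label encodings (or an ignore value such as 255) they can fall outside that range and trip the `t >= 0 && t < n_classes` assert shown above.

import numpy as np
import cv2

label = np.zeros((100, 100), dtype=np.uint8)
label[40:60, 40:60] = 3                      # only classes {0, 3} exist in this mask

bad  = cv2.resize(label, (840, 840), interpolation=cv2.INTER_LINEAR)
good = cv2.resize(label, (840, 840), interpolation=cv2.INTER_NEAREST)

print(np.unique(bad))   # e.g. [0 1 2 3]: 1 and 2 are fabricated, not real labels
print(np.unique(good))  # [0 3]: only the original classes survive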

Reference:

Is there anybody happen this error? - #14 by XiaoAHeng - autograd - PyTorch Forums
https://discuss.pytorch.org/t/is-there-anybody-happen-this-error/17416/14

        Because PIL.Image and cv2 return image data in different formats (a PIL Image's size is (width, height), while a cv2 array's shape is (height, width)), I clumsily went and rewrote all of my image-processing code (which turned out to be completely unnecessary). Below are the methods whose usage differs the most, with a small ordering sketch after the snippets:

# Crop a region
label = label[h:h+h, w:w+w]               # cv2 / NumPy: slicing is [row_start:row_end, col_start:col_end]
label = label.crop((w, h, w + w, h + h))  # PIL: box = (left, upper, right, lower)

# Horizontal (left-right) flip
image = cv2.flip(image, 1)                      # cv2: flipCode=1 flips around the vertical axis
image = image.transpose(Image.FLIP_LEFT_RIGHT)  # PIL: FLIP_LEFT_RIGHT (== 0); both calls return a new image

# Resize: pay special attention to how the interpolation/resample argument is passed
image = cv2.resize(image, (tw, th), interpolation=cv2.INTER_NEAREST)  # cv2: dsize is (width, height)
image = image.resize((tw, th), Image.NEAREST)                         # PIL: size is (width, height), resample constant
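As a reminder of the width/height ordering mentioned above, here is a small sketch (the file name label.png is just a placeholder): cv2.imread returns a NumPy array whose shape is (height, width, channels), while PIL.Image.open returns an Image whose size is (width, height); both resize calls, however, take (width, height).

import cv2
from PIL import Image

img_cv = cv2.imread("label.png")      # NumPy array, shape = (height, width, channels), channels in BGR order
img_pil = Image.open("label.png")     # PIL Image, size = (width, height)

h, w = img_cv.shape[:2]
assert (w, h) == img_pil.size         # same image, opposite ordering

# Both resize APIs take (width, height):
small_cv = cv2.resize(img_cv, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST)
small_pil = img_pil.resize((w // 2, h // 2), Image.NEAREST)  # Image.Resampling.NEAREST in newer Pillow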
