unet学习笔记(milesial/Pytorch-UNet)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

**

小白一个,只是学习记录,有问题欢迎大佬指出

**

所用到的代码Pytorch-UNet/predict.py at master · milesial/Pytorch-UNet ·GitHub

原理讲解

在这里插入图片描述

蓝/白色框表示 feature map;
蓝色箭头表示 3x3 卷积,用于特征提取;
灰色箭头表示skip-connection,用于特征融合;
红色箭头表示池化 pooling,用于降低维度; 绿色箭头表示上采样
upsample,用于恢复维度; 青色箭头表示 1x1 卷积,用于输出结果。

在这里插入图片描述

多分类问题

对于多分类问题,以下做法只是代码可以运行,不能从本质上解决多分类问题
代码能运行的原因也是因为强制将多分类问题改为了二分类方法

        if not is_mask:
            if img_ndarray.ndim == 2:
                img_ndarray = img_ndarray[np.newaxis, ...]
            else:
                img_ndarray = img_ndarray.transpose((2, 0, 1))

            img_ndarray = img_ndarray / 255

        return img_ndarray

在dataloading.py中除以255是为了让图片转换到(0,1)中
如果是多分类问题,该处就没有任何用处
所以在代码中需要修改

 img = self.preprocess(img, self.scale, is_mask=False)
 mask = self.preprocess(mask, self.scale, is_mask=False)

在train.py中需要修改

              with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    loss = criterion(masks_pred, true_masks.squeeze(1)) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks.squeeze(1), net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

在 true_masks.squeeze(1) 中加入.squeeze(1)
因为 criterion = nn.CrossEntropyLoss()中的定义如下

这里是引用
输入的predict的维度为(N,C,H,W),对应label输入的维度应该为(N,H,W),且label的值在[0,C-1]之间。

squeeze的作用看下面的代码

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    print(masks_pred)
                    print(masks_pred.shape)
                    print(true_masks)
                    print(masks_pred.shape)
                    print(true_masks.squeeze(1))
                    print(true_masks.squeeze(1).shape)
                    loss = criterion(masks_pred, true_masks.squeeze(1)) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks.squeeze(1), net.n_classes).permute(
  • 3
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值