[PaddleSeg source reading] On the small matter that PaddleSeg models always return a list

1. Working backwards

In paddleseg/core/infer.py, look at the inference and slide_inference functions:

(Note that aug_inference also calls inference, and inference in turn calls slide_inference when is_slide is True.)

def slide_inference(model, im, crop_size, stride):
    """
    Infer by sliding window.

    Args:
        model (paddle.nn.Layer): model to get logits of image.
        im (Tensor): the input image.
        crop_size (tuple|list). The size of sliding window, (w, h).
        stride (tuple|list). The size of stride, (w, h).

    Return:
        Tensor: The logit of input image.
    """
    h_im, w_im = im.shape[-2:]
    w_crop, h_crop = crop_size
    w_stride, h_stride = stride
    # calculate the crop nums
    rows = np.int(np.ceil(1.0 * (h_im - h_crop) / h_stride)) + 1
    cols = np.int(np.ceil(1.0 * (w_im - w_crop) / w_stride)) + 1
    # prevent negative sliding rounds when imgs after scaling << crop_size
    rows = 1 if h_im <= h_crop else rows
    cols = 1 if w_im <= w_crop else cols
    # TODO 'Tensor' object does not support item assignment. If support, use tensor to calculation.
    final_logit = None
    count = np.zeros([1, 1, h_im, w_im])
    for r in range(rows):
        for c in range(cols):
            h1 = r * h_stride
            w1 = c * w_stride
            h2 = min(h1 + h_crop, h_im)
            w2 = min(w1 + w_crop, w_im)
            h1 = max(h2 - h_crop, 0)
            w1 = max(w2 - w_crop, 0)
            im_crop = im[:, :, h1:h2, w1:w2]
            logits = model(im_crop)       # <------------------------- look here, and at the next few lines
            if not isinstance(logits, collections.abc.Sequence):
                raise TypeError(
                    "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
                    .format(type(logits)))
            logit = logits[0].numpy()
            if final_logit is None:
                final_logit = np.zeros([1, logit.shape[1], h_im, w_im])
            final_logit[:, :, h1:h2, w1:w2] += logit[:, :, :h2 - h1, :w2 - w1]
            count[:, :, h1:h2, w1:w2] += 1
    if np.sum(count == 0) != 0:
        raise RuntimeError(
            'There are pixel not predicted. It is possible that stride is greater than crop_size'
        )
    final_logit = final_logit / count
    final_logit = paddle.to_tensor(final_logit)
    return final_logit


def inference(model,
              im,
              ori_shape=None,
              transforms=None,
              is_slide=False,
              stride=None,
              crop_size=None):
    """
    Inference for image.

    Args:
        model (paddle.nn.Layer): model to get logits of image.
        im (Tensor): the input image.
        ori_shape (list): Origin shape of image.
        transforms (list): Transforms for image.
        is_slide (bool): Whether to infer by sliding window. Default: False.
        crop_size (tuple|list). The size of sliding window, (w, h). It should be probided if is_slide is True.
        stride (tuple|list). The size of stride, (w, h). It should be probided if is_slide is True.

    Returns:
        Tensor: If ori_shape is not None, a prediction with shape (1, 1, h, w) is returned.
            If ori_shape is None, a logit with shape (1, num_classes, h, w) is returned.
    """
    if hasattr(model, 'data_format') and model.data_format == 'NHWC':
        im = im.transpose((0, 2, 3, 1))
    if not is_slide:
        logits = model(im)                 # <------------------------- look here, and at the next few lines
        if not isinstance(logits, collections.abc.Sequence):
            raise TypeError(
                "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
                .format(type(logits)))
        logit = logits[0]
    else:
        logit = slide_inference(model, im, crop_size=crop_size, stride=stride)
    if hasattr(model, 'data_format') and model.data_format == 'NHWC':
        logit = logit.transpose((0, 3, 1, 2))
    if ori_shape is not None:
        logit = reverse_transform(logit, ori_shape, transforms, mode='bilinear')
        pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
        return pred, logit
    else:
        return logit
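
To make the window arithmetic in slide_inference concrete, here is a minimal pure-Python sketch that reproduces only the index computation (no model, no Paddle). The helper name sliding_windows and the example sizes are my own, chosen for illustration; the formulas are copied from the function above, with math.ceil standing in for np.ceil.

```python
import math


def sliding_windows(h_im, w_im, crop_size, stride):
    """Return the (h1, h2, w1, w2) boxes visited by the sliding window.

    Hypothetical helper for illustration; it mirrors the index
    arithmetic of slide_inference but runs no model.
    """
    w_crop, h_crop = crop_size
    w_stride, h_stride = stride
    # Number of window positions along each axis (1 if the image is small).
    rows = 1 if h_im <= h_crop else math.ceil((h_im - h_crop) / h_stride) + 1
    cols = 1 if w_im <= w_crop else math.ceil((w_im - w_crop) / w_stride) + 1
    boxes = []
    for r in range(rows):
        for c in range(cols):
            # Clamp the far edge to the image, then pull the near edge back
            # so the window keeps its full size whenever possible.
            h2 = min(r * h_stride + h_crop, h_im)
            w2 = min(c * w_stride + w_crop, w_im)
            h1 = max(h2 - h_crop, 0)
            w1 = max(w2 - w_crop, 0)
            boxes.append((h1, h2, w1, w2))
    return boxes


# A 5x7 (h x w) image, 4x4 windows, stride 2 in both directions.
boxes = sliding_windows(h_im=5, w_im=7, crop_size=(4, 4), stride=(2, 2))
for box in boxes:
    print(box)
```

The last window in each row and column is shifted back to stay inside the image, so some pixels are visited more than once. That is why the original function keeps a count array and divides final_logit by it at the end.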

If is_slide is False, this branch is taken (only the inference function is involved):

logits = model(im)                 # <------------------------- look here, and at the next few lines
if not isinstance(logits, collections.abc.Sequence):
    raise TypeError(
        "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
        .format(type(logits)))
logit = logits[0]

If is_slide is True, this branch is taken (inference calls the slide_inference function):

logits = model(im_crop)       # <------------------------- look here, and at the next few lines
if not isinstance(logits, collections.abc.Sequence):
    raise TypeError(
        "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
        .format(type(logits)))
logit = logits[0].numpy()

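Two things stand out:
First, logit is always element 0 of logits.
Second, logits must be an instance of collections.abc.Sequence, for example a list or a tuple.

In short, model(input) is expected to return a sequence, in practice a list, and most likely one with a single element.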
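
The check can be tried in isolation. The sketch below is not PaddleSeg code: FakeTensor is a hypothetical stand-in for paddle.Tensor so the example runs without Paddle installed, and first_logit is a name I made up that repeats the isinstance test and the [0] indexing seen in inference.

```python
import collections.abc


class FakeTensor:
    """Stand-in for paddle.Tensor; hypothetical, only for this sketch."""

    def __init__(self, name):
        self.name = name


def first_logit(model_output):
    """Apply the same type check as inference() and return element 0."""
    if not isinstance(model_output, collections.abc.Sequence):
        raise TypeError(
            "The type of logits must be one of collections.abc.Sequence, "
            "e.g. list, tuple. But received {}".format(type(model_output)))
    return model_output[0]


main = FakeTensor("main")

# A list or a tuple passes the check; only element 0 is used.
picked_from_list = first_logit([main])
picked_from_tuple = first_logit((main, FakeTensor("aux")))

# A bare tensor-like object is not a Sequence, so it is rejected.
try:
    first_logit(main)
    bare_tensor_rejected = False
except TypeError:
    bare_tensor_rejected = True
```

So a custom model whose forward returns a bare tensor would hit the TypeError branch, while one that returns [tensor] would not.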


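2. Examples

Inference from the calling code alone is not conclusive, so let's open a few models and check whether forward returns a list or a Tensor.

First, the recently popular pp_liteseg model.

It lives in paddleseg/models/pp_liteseg.py.

We only need to find the class decorated with @manager.MODELS.add_component:
[Image: the decorated class definition]
Then look directly at the return value of its forward function:
[Image: the return statements of forward]
Both branches return a list. We don't need to know what is inside the list, only that a list is returned.

Also, outside of training, the list contains only one element.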
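
The pattern described here, a longer list in training and a one-element list otherwise, can be sketched without Paddle. ToySegModel below is entirely hypothetical: the tuples stand in for tensors, and the two auxiliary outputs are an assumption of mine, not the actual heads of pp_liteseg.

```python
class ToySegModel:
    """Hypothetical stand-in for a segmentation model with auxiliary heads."""

    def __init__(self):
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def forward(self, x):
        # Tuples stand in for logit tensors in this sketch.
        main_logit = ("main", x)
        if self.training:
            # Assumed: extra outputs used only for auxiliary losses.
            aux_logits = [("aux1", x), ("aux2", x)]
            return [main_logit] + aux_logits
        # Outside training: still a list, but with a single element.
        return [main_logit]

    def __call__(self, x):
        return self.forward(x)


model = ToySegModel()

model.train()
train_out = model("img")

model.eval()
eval_out = model("img")
```

In both modes the caller can take element 0 to get the main prediction, which matches how inference uses logits[0].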


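Next, the pphumanseg_lite model, which was popular a little earlier:
paddleseg/models/pphumanseg_lite.py

To find the main class, first look for the class decorated with @manager.MODELS.add_component,
and second check what __all__ contains: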

__all__ = ['PPHumanSegLite']

Now look at its forward function:

    def forward(self, x):
        # Encoder
        input_shape = paddle.shape(x)[2:]

        x = self.conv_bn0(x)  # 1/2
        shortcut = self.conv_bn1(x)  # shortcut
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)  # 1/4
        x = self.block1(x)  # 1/8
        x = self.block2(x)  # 1/16

        # Decoder
        x = self.depthwise_separable0(x)
        shortcut_shape = paddle.shape(shortcut)[2:]
        x = F.interpolate(
            x,
            shortcut_shape,
            mode='bilinear',
            align_corners=self.align_corners)
        x = paddle.concat(x=[shortcut, x], axis=1)
        x = self.depthwise_separable1(x)

        logit = self.depthwise_separable2(x)
        logit = F.interpolate(
            logit,
            input_shape,
            mode='bilinear',
            align_corners=self.align_corners)

        return [logit]   # <---------- logit is clearly a Tensor; it is wrapped in brackets here to make a list

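3. Summary
  • In infer.inference, logit is always taken as element 0 of logits, which suggests that a model's return value should be a sequence (a list)
  • The PPHumanSegLite and pp_liteseg models confirm that the return value is indeed a list
  • My guess is that this is for compatibility: every model returns a list, so callers can treat models with one output and models with several outputs the same way
  • Maybe I'm being a bit pedantic here, haha; admittedly this level of detail wasn't really necessary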
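
To illustrate the compatibility guess, here is a small sketch of a consumer that handles one-output and multi-output models with the same code path. It is not PaddleSeg's training code: total_loss, coefs, and the toy squared-error loss are all hypothetical, and plain floats stand in for logit tensors.

```python
def total_loss(logits_list, loss_fn, coefs):
    """Weighted sum of a loss over every logit in the list.

    Hypothetical helper: because the model always returns a list,
    the same loop works for one output or for several.
    """
    if len(logits_list) != len(coefs):
        raise ValueError("need exactly one coefficient per logit")
    return sum(coef * loss_fn(logit)
               for logit, coef in zip(logits_list, coefs))


def toy_loss(logit):
    """Squared distance to a fixed target of 1.0 (illustration only)."""
    return (logit - 1.0) ** 2


# Model with a main output and one auxiliary output.
multi_output_loss = total_loss([2.0, 4.0], toy_loss, coefs=[1.0, 0.5])

# Model with a single output: same call, just shorter lists.
single_output_loss = total_loss([2.0], toy_loss, coefs=[1.0])
```

If some models returned a bare tensor instead, every caller like this would need a special case, which is one plausible reason for the list convention.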