chainer-目标检测-centerNet


前言

本文简要介绍如何基于 Chainer 实现 CenterNet 目标检测算法。


一、centerNet代码实现

from collections import defaultdict
import chainer
import numpy as np
from chainer import Chain, Variable, reporter
from typing import Callable, Dict
import chainer.functions as F
from core.loss import center_detection_loss
from utils.bbox import resize_bbox
from utils.image import resize

def find_peak(map, x, y):
    """Climb from ``(x, y)`` to a local maximum of a 2-D score map.

    At every step the 3x3 neighbourhood around the current position
    (clamped to the map borders) is inspected and the position moves to the
    neighbour holding the largest value. The walk stops as soon as the
    selected neighbour is the current position itself.

    Args:
        map: Array indexed as ``map[row, col]``; ``map.shape[0]`` is used as
            the row count and ``map.shape[1]`` as the column count.
        x: Starting column index.
        y: Starting row index.

    Returns:
        Tuple ``(x, y)`` with the column and row of the reached peak.
    """
    # Candidate offsets, paired element-wise: the row above, the current
    # row (centre first), then the row below. np.argmax returns the first
    # maximum, so this ordering also decides how ties are resolved.
    step_x = np.array([-1,  0,  1, 0, -1,  1, -1, 0, 1])
    step_y = np.array([-1, -1, -1, 0,  0,  0,  1, 1, 1])
    last_col = map.shape[1] - 1
    last_row = map.shape[0] - 1

    while True:
        # Clamp so that border cells only look at in-bounds neighbours.
        cand_x = np.clip(x + step_x, 0, last_col)
        cand_y = np.clip(y + step_y, 0, last_row)

        best = np.argmax(map[cand_y, cand_x])
        best_x, best_y = cand_x[best], cand_y[best]

        if best_x == x and best_y == y:
            return x, y
        x, y = best_x, best_y

class CenterDetector(Chain):
    """Detector that decodes boxes from 'hm', 'wh' and 'offset' output maps.

    The wrapped network is built by ``base_network_factory``, which receives
    a mapping from head name to channel count (``'hm'`` -> ``num_classes``,
    ``'wh'`` -> 2, ``'offset'`` -> 2).

    Args:
        base_network_factory: Callable that builds the backbone from the
            head-name -> channel-count mapping.
        insize: Side length images are resized to before the forward pass
            (the same value is used for both dimensions).
        num_classes: Number of heatmap channels / object classes.
        downratio: Factor used to map output-map coordinates and sizes back
            to the resized input image.
        dtype: dtype images are cast to in ``_prepare``.
    """

    def __init__(self, base_network_factory: Callable[[Dict[str, int]], Chain], insize, num_classes, downratio=4, dtype=np.float32):
        super().__init__()
        self.num_classes = num_classes
        self.insize = insize
        self.downratio = downratio
        self.dtype = dtype
        head_channels = {
            'hm': num_classes,
            'wh': 2,
            'offset': 2,
        }
        with self.init_scope():
            self.base_network = base_network_factory(head_channels)

    def forward(self, x):
        """Run the backbone on ``x`` and return its output unchanged."""
        return self.base_network(x)

    def predict(self, imgs, k=100, detail=False, output_index=-1):
        """Detect objects in a sequence of images.

        Args:
            imgs: Sequence of 3-D image arrays; the last two axes are read
                as height and width (presumably CHW -- confirm with caller).
            k: Number of highest-scoring heatmap cells examined per class.
            detail: When true, the selected network output is returned as a
                fourth element.
            output_index: Index into the sequence returned by ``forward``
                selecting which output dict to decode.

        Returns:
            ``(bboxes, labels, scores)`` -- three lists with one array per
            image -- plus the output dict when ``detail`` is true. Boxes are
            ``[y_min, x_min, y_max, x_max]`` rescaled to each image's
            original size.
        """
        batch = []
        sizes = []
        for img in imgs:
            _, height, width = img.shape
            sizes.append((height, width))
            batch.append(self.xp.array(self._prepare(img)))

        with chainer.using_config('train', False), chainer.function.no_backprop_mode():
            net_in = Variable(self.xp.stack(batch))
            output = self.forward(net_in)[output_index]

        # The heatmap entry of the output dict is replaced in place by its
        # sigmoid and moved to the CPU, since decoding runs on NumPy.
        output['hm'] = F.sigmoid(output['hm'])
        output['hm'].to_cpu()

        bboxes, labels, scores = [], [], []
        resized_shape = (self.insize, self.insize)
        for i, original_size in enumerate(sizes):
            bbox, label, score = self._decode_output(output, i, k)
            bboxes.append(resize_bbox(bbox, resized_shape, original_size))
            labels.append(label)
            scores.append(score)

        if detail:
            return bboxes, labels, scores, output
        return bboxes, labels, scores

    def _decode_output(self, output, index, k):
        """Decode boxes, labels and scores for one image of the batch.

        For every class the ``k`` highest heatmap cells are each climbed to
        a local peak with ``find_peak``; every distinct peak yields one box.

        Returns:
            ``(bboxes, labels, scores)`` as NumPy arrays sorted by
            descending score.
        """
        bboxes, labels, scores = [], [], []
        for class_id in range(self.num_classes):
            hm = output['hm'].array[index, class_id]
            n_cols = hm.shape[1]

            top_cells = np.argsort(hm.flatten())[::-1][:k]
            seen_peaks = set()
            for flat_pos in top_cells:
                y, x = divmod(flat_pos, n_cols)
                peak_x, peak_y = find_peak(hm, x, y)
                peak = (peak_y, peak_x)
                # Several starting cells can climb to the same peak; only
                # the first (highest-scoring) one produces a detection.
                if peak in seen_peaks:
                    continue
                seen_peaks.add(peak)

                cx, cy, w, h = self._decode_bbox(output, peak_x, peak_y, index)
                bboxes.append([cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2])
                labels.append(class_id)
                scores.append(hm[y, x])

        scores = np.array(scores)
        order = scores.argsort()[::-1]
        return np.array(bboxes)[order], np.array(labels)[order], scores[order]

    def _decode_bbox(self, output, x, y, index):
        """Return ``(center_x, center_y, width, height)`` for one peak.

        Values are read from the 'wh' and 'offset' maps at ``(y, x)`` for
        batch item ``index`` and scaled by ``downratio``; both maps are
        moved to the CPU before being read.
        """
        wh_var = output['wh']
        wh_var.to_cpu()
        offset_var = output['offset']
        offset_var.to_cpu()
        wh_map = wh_var.array
        offset_map = offset_var.array

        ratio = self.downratio
        center_x = x * ratio + offset_map[index, 0, y, x]
        center_y = y * ratio + offset_map[index, 1, y, x]
        width = wh_map[index, 0, y, x] * ratio
        height = wh_map[index, 1, y, x] * ratio
        return (center_x, center_y, width, height)

    def _prepare(self, img):
        """Cast ``img`` to ``self.dtype`` and resize it to ``insize`` x ``insize``."""
        target_size = (self.insize, self.insize)
        return resize(img.astype(self.dtype), target_size)

class CenterDetectorTrain(Chain):
    """Training wrapper that computes and reports the detection loss.

    Args:
        center_detector: Detector link whose call returns a sequence of
            output dicts (the last one is used for the heatmap metric).
        hm_weight: Weight passed to ``center_detection_loss`` for the
            heatmap term.
        wh_weight: Weight passed for the width/height term.
        offest_weight: Weight passed for the offset term (parameter name
            kept as-is for caller compatibility).
        focial_loss_alpha: Accepted but not used by this class.
        focial_loss_beta: Accepted but not used by this class.
        comm: Forwarded to ``center_detection_loss`` as ``comm``.
    """

    def __init__(self,
                 center_detector,
                 hm_weight,
                 wh_weight,
                 offest_weight,
                 focial_loss_alpha=2,
                 focial_loss_beta=4,
                 comm=None
                 ):
        super().__init__()

        self.comm = comm
        self.hm_weight, self.wh_weight = hm_weight, wh_weight
        self.offset_weight = offest_weight

        with self.init_scope():
            self.center_detector = center_detector

    def forward(self, **indata):
        """Compute the total loss for one batch and report its components.

        Args:
            **indata: Batch fields; ``'image'`` and ``'hm'`` are read here
                and the whole mapping is passed to ``center_detection_loss``.

        Returns:
            The total loss returned by ``center_detection_loss``.
        """
        outputs = self.center_detector(indata['image'])
        loss_terms = center_detection_loss(
            outputs, indata,
            self.hm_weight, self.wh_weight, self.offset_weight, comm=self.comm
        )
        loss, hm_loss, wh_loss, offset_loss, detail_losses = loss_terms

        # NOTE(review): the heatmap is compared exactly as emitted by the
        # detector; confirm whether a sigmoid is expected before this metric.
        hm_mae = F.mean_absolute_error(outputs[-1]["hm"], indata["hm"])

        observation = {
            'loss': loss,
            'hm_loss': hm_loss,
            'hm_pos_loss': detail_losses['hm_pos_loss'],
            'hm_neg_loss': detail_losses['hm_neg_loss'],
            'hm_mae': hm_mae,
            'wh_loss': wh_loss,
            'offset_loss': offset_loss
        }
        reporter.report(observation, self)
        return loss


总结

源码:chainer-centerNet

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱学习的广东仔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值