The Execution Flow of ConvolutionBackwardData in MegEngine

For F32 data in NCHW format with workspace limit = 0, execution starts at the graph operator level:

void ConvolutionBackwardData::scn_do_execute() {
    // Empty inputs mean an empty gradient; nothing to compute.
    if (input(0)->dev_tensor().empty() || input(1)->dev_tensor().empty()) {
        mgb_assert(output(0)->dev_tensor().empty());
        return;
    }
    // input(0) is the filter and input(1) the output diff; output(0) receives
    // the data gradient and output(1) provides the workspace.
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
            output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output(1)));
}

The megdnn operator behind megdnn_opr() here is fallback::ConvolutionBackwardDataImpl.

Inside its exec, note that for the NCHW (or NHWC) format the gradient dtype must be QuantizedS8 for the call to be forwarded to naive::ConvolutionBackwardDataImpl::exec; our F32 case falls through to the NCB kernel path instead:

void ConvolutionBackwardDataImpl::exec(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        ((param().format == param::Convolution::Format::NCHW ||
          param().format == param::Convolution::Format::NHWC) &&
         grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}

Execution then enters exec_with_ncb_kern:

void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param) {
    // Algorithm selection and kernel dispatch always work on a single-group
    // view of the problem.
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
    auto&& algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (group == 1 || static_cast<AlgoBase*>(algo)->is_naive()) {
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
        ...
    }
}
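
The elided else branch handles group > 1. As a rough sketch of the idea only (an extrapolation from the single-group call above, not the verbatim MegEngine code), the single-group kernel can be run once per group, with the filter/diff/grad pointers shifted by one group's stride on each iteration:

        // Hypothetical sketch of the group > 1 branch: reuse the 1-group kernel
        // `group` times on shifted per-group views of the problem.
        auto run = [kptr, param, group]() {
            NCBKernParam p = param;
            for (size_t g = 0; g < group; ++g) {
                kptr(p);
                advance_one_group(p);  // hypothetical helper: offset the
                                       // filter/diff/grad pointers by one
                                       // group's stride
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);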

ncb_1g_dispatch_kern selects the kernel function:

ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::
        ncb_1g_dispatch_kern(Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);

    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }

    megdnn_throw("no suitable ConvolutionBackwardData algorithm");
}

Here static_cast<AlgoBase*>(algo) is one of the concrete algorithm classes: either class ConvolutionBackwardDataImpl::AlgoNaive or class ConvolutionBackwardDataImpl::AlgoDirect. In our F32/NCHW case the chosen algorithm is AlgoDirect:

class ConvolutionBackwardDataImpl::AlgoDirect final : public AlgoBase {
public:
    const char* name() const override { return "DeconvDirect"; }
    bool usable(ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param)
            const override;
    size_t get_workspace(
            ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override;
    ncb_kern_t dispatch_kern(
            ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override;
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    MEGDNN_DECL_ALGO_TYPE(FB_DIRECT)
};
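
For AlgoDirect, dispatch_kern plausibly just hands back the kern_direct function shown below; a sketch of the expected shape (not the verbatim source):

// Sketch: AlgoDirect::dispatch_kern is expected to return the direct kernel.
ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::AlgoDirect::
        dispatch_kern(ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
    return kern_direct;
}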

The kernel implementation:

void kern_direct(const NCBKernParam& param) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    auto diff = param.diff<float>(), filter = param.filter<float>();
    auto grad = param.grad<float>();
    // Process the batch one sample at a time; inp_bs and out_bs are the
    // per-sample strides of diff and grad respectively.
    for (size_t n = 0; n < N; ++n) {
        convolution::run_conv_backward_data(
                diff + n * param.inp_bs, filter, grad + n * param.out_bs,
                param.workspace_ptr, IH, IW, IC, FH, FW, OH, OW, OC, PH, PW, SH, SW,
                !param.filter_meta.should_flip);
    }
}
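
UNPACK_CONV_F32_NCB_KERN_SIZES(param) has to introduce the local sizes used above. A rough sketch of what it must expand to (the field names are my assumptions, not the verbatim megdnn macro):

    // Assumed expansion: pull the shape parameters out of the NCB kernel param.
    size_t N = param.n;                           // batch size
    size_t IH = param.isz[0], IW = param.isz[1];  // spatial size of diff
    size_t OH = param.osz[0], OW = param.osz[1];  // spatial size of grad
    size_t IC = param.filter_meta.icpg;           // channels of grad
    size_t OC = param.filter_meta.ocpg;           // channels of diff
    size_t FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1];
    size_t PH = param.filter_meta.padding[0], PW = param.filter_meta.padding[1];
    size_t SH = param.filter_meta.stride[0], SW = param.filter_meta.stride[1];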

Finally, run_conv_backward_data performs the actual computation, the purest naive implementation: for every (oc, ic) channel pair, the oc-th diff plane is accumulated through the (possibly flipped) FH x FW filter into the ic-th grad plane.

void run_conv_backward_data(
        const float* diff, const float* filter, float* grad, void* workspace, size_t IH,
        size_t IW, size_t IC, size_t FH, size_t FW, size_t OH, size_t OW, size_t OC,
        size_t PH, size_t PW, size_t SH, size_t SW, bool xcorr) {
    // grad is accumulated into below, so clear it first
    std::memset(grad, 0, sizeof(float) * IC * OH * OW);
    for (size_t oc = 0; oc < OC; ++oc)
        for (size_t ic = 0; ic < IC; ++ic) {
            // ut for untransposed
            const float* fut = filter + oc * IC * FH * FW + ic * FH * FW;
            const float* f;
            if (!xcorr) {
                // need transpose
                f = (float*)workspace;
                for (size_t fh = 0; fh < FH; ++fh)
                    for (size_t fw = 0; fw < FW; ++fw) {
                        ((float*)f)[fh * FW + fw] =
                                fut[(FH - fh - 1) * FW + (FW - fw - 1)];
                    }
            } else {
                // do not need transpose
                f = fut;
            }
            conv_backdata_single_channel(
                    diff + oc * IH * IW, f, grad + ic * OH * OW, IH, IW, FH, FW, OH, OW,
                    PH, PW, SH, SW);
        }
}
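
As a tiny sanity check of the semantics (my own example, not from the MegEngine sources): with IC = OC = 1, a single diff element, a 2x2 filter, stride 1 and no padding, the forward convolution maps a 2x2 input to a 1x1 output, so the backward pass just scales the filter by the diff value:

// Usage sketch: grad (2x2) ends up equal to diff[0] * filter.
float diff[] = {2.f};
float filter[] = {1.f, 2.f, 3.f, 4.f};
float grad[4];
float workspace[4];  // room for the flipped filter when xcorr == false
run_conv_backward_data(
        diff, filter, grad, workspace,
        /*IH=*/1, /*IW=*/1, /*IC=*/1, /*FH=*/2, /*FW=*/2,
        /*OH=*/2, /*OW=*/2, /*OC=*/1, /*PH=*/0, /*PW=*/0, /*SH=*/1, /*SW=*/1,
        /*xcorr=*/true);
// grad is now {2, 4, 6, 8}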

void conv_backdata_single_channel(
        const float* diff, const float* filter, float* grad, size_t IH, size_t IW,
        size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW) {
    if (can_run_xcorr_single_channel_templated(
                IH, IW, FH, FW, OH, OW, PH, PW, SH, SW)) {
        conv_backdata_single_channel_templated(
                diff, filter, grad, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
    } else {
        MIDOUT_BEGIN(megdnn_fallback_conv, void) {
            conv_backdata_single_channel_nontemplated(
                    diff, filter, grad, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
        }
        MIDOUT_END();
    }
}
void conv_backdata_single_channel_templated(
        const float* src, const float* filter, float* dst, size_t IH, size_t IW,
        size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW) {
    megdnn_ignore(FW);
#define DISPATCH(ker_size)                                             \
    if (FH == ker_size) {                                              \
        MIDOUT_BEGIN(megdnn_fallback_conv, ker_size) {                 \
            conv_backdata_single_channel_templated_impl<ker_size>(     \
                    src, filter, dst, IH, IW, OH, OW, PH, PW, SH, SW); \
        }                                                              \
        MIDOUT_END();                                                  \
        return;                                                        \
    }
    DISPATCH(1)
    DISPATCH(2)
    DISPATCH(3)
    DISPATCH(4)
    DISPATCH(5)
    DISPATCH(6)
    DISPATCH(7)
#undef DISPATCH
    megdnn_throw("internal error in conv_backdata template dispatching: impossible");
}
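
The innermost kernel, conv_backdata_single_channel_templated_impl, is not shown above. To make the math concrete, here is a minimal reconstruction of what it has to compute under the naming used so far (my own sketch, not the MegEngine source): every diff element is scattered through the filter taps back to the grad positions that produced it in the forward pass.

// Hypothetical reference version of the single-channel backward-data kernel.
// diff:   IH x IW plane (gradient of the forward output)
// filter: FH x FW taps (already flipped by run_conv_backward_data if needed)
// grad:   OH x OW plane, accumulated into (zeroed by the caller)
void conv_backdata_single_channel_ref(
        const float* diff, const float* filter, float* grad, size_t IH, size_t IW,
        size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW) {
    for (size_t ih = 0; ih < IH; ++ih)
        for (size_t iw = 0; iw < IW; ++iw) {
            float d = diff[ih * IW + iw];
            for (size_t fh = 0; fh < FH; ++fh)
                for (size_t fw = 0; fw < FW; ++fw) {
                    // grad position that fed diff[ih][iw] through tap (fh, fw)
                    ptrdiff_t oh = static_cast<ptrdiff_t>(ih * SH + fh) -
                                   static_cast<ptrdiff_t>(PH);
                    ptrdiff_t ow = static_cast<ptrdiff_t>(iw * SW + fw) -
                                   static_cast<ptrdiff_t>(PW);
                    if (oh >= 0 && oh < static_cast<ptrdiff_t>(OH) && ow >= 0 &&
                        ow < static_cast<ptrdiff_t>(OW))
                        grad[oh * OW + ow] += d * filter[fh * FW + fw];
                }
        }
}

The template dispatch on FH (1 through 7) presumably exists so that the tap loops see a compile-time kernel size and can be fully unrolled; larger kernels fall back to the non-templated version.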
