针对F32, NCHW格式, workspace limit = 0, 首先执行
//! Runs the megdnn ConvolutionBackwardData operator on this op's device tensors.
//! Empty inputs short-circuit: the output must then be empty as well.
void ConvolutionBackwardData::scn_do_execute() {
    auto&& filter = input(0)->dev_tensor();
    auto&& diff = input(1)->dev_tensor();
    if (filter.empty() || diff.empty()) {
        // degenerate (zero-sized) case: nothing to compute
        mgb_assert(output(0)->dev_tensor().empty());
        return;
    }
    // output(1) is the workspace var allocated by the framework
    megdnn_opr()->exec(
            filter.as_megdnn(), diff.as_megdnn(),
            output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output(1)));
}
其中包含的算子 megdnn_opr 为 fallback::ConvolutionBackwardDataImpl。
对于 Format 为 NCHW 的情形,注意只有当梯度的数据类型是 QuantizedS8 时,才会走 naive::ConvolutionBackwardDataImpl::exec
//! Entry point of the fallback backward-data convolution.
//! Formats/dtypes the fallback path cannot handle are forwarded to the naive
//! implementation; everything else goes through the NCB kernel machinery.
void ConvolutionBackwardDataImpl::exec(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    using Fmt = param::Convolution::Format;
    const auto fmt = param().format;
    // NCHW/NHWC with a QuantizedS8 gradient also falls back to naive
    const bool nchw_or_nhwc = fmt == Fmt::NCHW || fmt == Fmt::NHWC;
    const bool use_naive =
            fmt == Fmt::NHWCD4 || fmt == Fmt::NCHW4 ||
            (nchw_or_nhwc && grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8);
    if (use_naive) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}
exec_with_ncb_kern:
//! Dispatches the backward-data computation through the NCB kernel interface.
//! NOTE(review): the multi-group branch is elided ("...") in this excerpt.
void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param) {
// work on a copy with group forced to 1: algorithm selection and kernel
// dispatch below operate on a single-group view of the problem
auto p1g = param;
auto group = p1g.filter_meta.group;
p1g.filter_meta.group = 1;
auto&& algo = get_algorithm(p1g);
auto kptr = ncb_1g_dispatch_kern(algo, p1g);
if (group == 1 || static_cast<AlgoBase*>(algo)->is_naive()) {
// single group (or a naive algorithm): one kernel call covers the whole
// problem; hand it to the naive handle's dispatcher to run
auto run = [kptr, param]() { kptr(param); };
static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
} else {
// grouped path (not shown in this excerpt)
...
}
}
ncb_1g_dispatch_kern 选择核函数
//! Resolves the kernel function of \p algo for a single-group problem.
//! Only algorithms owned by the fallback handle can be dispatched here.
ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::
        ncb_1g_dispatch_kern(Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() != Handle::HandleType::FALLBACK) {
        megdnn_throw("no suitable ConvolutionBackwardData algorithm");
    }
    return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
}
static_cast<AlgoBase*>(algo) 可能为以下算法之一:
如果算法为:class ConvolutionBackwardDataImpl::AlgoNaive
如果算法为:class ConvolutionBackwardDataImpl::AlgoDirect
//! Direct (non-matmul) fallback algorithm for backward-data convolution.
//! Selected through the NCB algorithm machinery; the actual kernel is
//! kern_direct (see below).
class ConvolutionBackwardDataImpl::AlgoDirect final : public AlgoBase {
public:
    const char* name() const override { return "DeconvDirect"; }
    //! whether this algorithm can handle the given problem size
    bool usable(ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param)
            const override;
    //! extra workspace (bytes) required for the given problem size
    size_t get_workspace(
            ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override;
    //! returns the kernel implementing this algorithm
    ncb_kern_t dispatch_kern(
            ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override;
    // deterministic: same inputs always produce the same output
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    MEGDNN_DECL_ALGO_TYPE(FB_DIRECT)
};
实际实现为 :
void kern_direct(const NCBKernParam& param) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
auto diff = param.diff<float>(), filter = param.filter<float>();
auto grad = param.grad<float>();
for (size_t n = 0; n < N; ++n) {
convolution::run_conv_backward_data(
diff + n * param.inp_bs, filter, grad + n * param.out_bs,
param.workspace_ptr, IH, IW, IC, FH, FW, OH, OW, OC, PH, PW, SH, SW,
!param.filter_meta.should_flip);
}
}
最后调用run_conv_backward_data进行计算,最纯粹的 naive计算。
void run_conv_backward_data(
const float* diff, const float* filter, float* grad, void* workspace, size_t IH,
size_t IW, size_t IC, size_t FH, size_t FW, size_t OH, size_t OW, size_t OC,
size_t PH, size_t PW, size_t SH, size_t SW, bool xcorr) {
std::memset(grad, 0, sizeof(float) * IC * OH * OW);
for (size_t oc = 0; oc < OC; ++oc)
for (size_t ic = 0; ic < IC; ++ic) {
// ut for untransposed
const float* fut = filter + oc * IC * FH * FW + ic * FH * FW;
const float* f;
if (!xcorr) {
// need transpose
f = (float*)workspace;
for (size_t fh = 0; fh < FH; ++fh)
for (size_t fw = 0; fw < FW; ++fw) {
((float*)f)[fh * FW + fw] =
fut[(FH - fh - 1) * FW + (FW - fw - 1)];
}
} else {
// do not need transpose
f = fut;
}
conv_backdata_single_channel(
diff + oc * IH * IW, f, grad + ic * OH * OW, IH, IW, FH, FW, OH, OW,
PH, PW, SH, SW);
}
}
//! Backward-data convolution for one (diff channel, grad channel) pair.
//! Picks the templated fast path when the problem size allows it, otherwise
//! the non-templated fallback.
void conv_backdata_single_channel(
        const float* diff, const float* filter, float* grad, size_t IH, size_t IW,
        size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW) {
    if (!can_run_xcorr_single_channel_templated(
                IH, IW, FH, FW, OH, OW, PH, PW, SH, SW)) {
        // generic fallback for sizes the templated kernels do not cover
        MIDOUT_BEGIN(megdnn_fallback_conv, void) {
            conv_backdata_single_channel_nontemplated(
                    diff, filter, grad, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
        }
        MIDOUT_END();
        return;
    }
    conv_backdata_single_channel_templated(
            diff, filter, grad, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
}
void conv_backdata_single_channel_templated(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
size_t SW) {
megdnn_ignore(FW);
#define DISPATCH(ker_size) \
if (FH == ker_size) { \
MIDOUT_BEGIN(megdnn_fallback_conv, ker_size) { \
conv_backdata_single_channel_templated_impl<ker_size>( \
src, filter, dst, IH, IW, OH, OW, PH, PW, SH, SW); \
} \
MIDOUT_END(); \
return; \
}
DISPATCH(1)
DISPATCH(2)
DISPATCH(3)
DISPATCH(4)
DISPATCH(5)
DISPATCH(6)
DISPATCH(7)
#undef DISPATCH
megdnn_throw("internal error in conv_backdata template dispatching: impossible");
}