Building a Batch Normalization (BN) Layer in TensorRT

Batch Normalization in PyTorch

torch.nn.BatchNorm2d

BatchNorm2d computes y = (x − E[x]) / √(Var[x] + ε) · γ + β per channel, where E[x] is the batch mean, Var[x] is the batch variance, ε guards against division by zero, γ is the learned scale (weight), and β is the learned bias. At inference time E[x] and Var[x] are the running estimates stored as running_mean and running_var, which is what TensorRT needs.
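The formula is easy to verify against an eval-mode BatchNorm2d before involving TensorRT at all (a minimal sketch; the channel count and input shape are arbitrary):

import torch

bn = torch.nn.BatchNorm2d(num_features=4).eval()   # eval mode uses the running stats
x  = torch.randn(2, 4, 8, 8)

# y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, broadcast over (N, C, H, W)
mean  = bn.running_mean.view(1, -1, 1, 1)
var   = bn.running_var.view(1, -1, 1, 1)
gamma = bn.weight.view(1, -1, 1, 1)
beta  = bn.bias.view(1, -1, 1, 1)
y = (x - mean) / torch.sqrt(var + bn.eps) * gamma + beta

assert torch.allclose(bn(x), y, atol=1e-6)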

torch.load(weight)

weights  = torch.load(your_model_dict_state_path)
in_gamma = weights['in.weight'].numpy()        # in gamma
in_beta  = weights['in.bias'].numpy()          # in beta
in_mean  = weights['in.running_mean'].numpy()  # in mean
in_var   = weights['in.running_var'].numpy()   # in running var
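Here 'in' is just this model's name for the BN layer inside the state dict; substitute your own layer name. All four arrays are 1-D with one entry per channel, which is worth asserting before the engine build (a minimal sketch):

num_channels = in_gamma.shape[0]
assert in_beta.shape == in_mean.shape == in_var.shape == (num_channels,)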


Implementing BN in TensorRT:

TensorRT's network-definition API has no dedicated BN layer, but inference-mode BN is just a per-channel affine transform, so it maps onto IScaleLayer (documentation: https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html#iscalelayer), which in CHANNEL mode computes output = (input × scale + shift)^power per channel. Setting power to 1 leaves only the affine part, and the BN parameters fold into scale and shift as follows.
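Matching the two expressions term by term gives the fold:

y = (x − E[x]) / √(Var[x] + ε) · γ + β = x · scale + shift

scale = γ / √(Var[x] + ε)
shift = β − E[x] · γ / √(Var[x] + ε)
power = 1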

Python API

import numpy as np
import tensorrt as trt
import torch

weights  = torch.load(your_model_dict_state_path)
in_gamma = weights['in.weight'].numpy()        # in gamma
in_beta  = weights['in.bias'].numpy()          # in beta
in_mean  = weights['in.running_mean'].numpy()  # in running mean
in_var   = weights['in.running_var'].numpy()   # in running var
eps      = 1e-05
in_std   = np.sqrt(in_var + eps)               # per-channel standard deviation

in_scale = in_gamma / in_std
in_shift = -in_mean / in_std * in_gamma + in_beta
# 'in' is a reserved word in Python, so the layer variable needs another name
bn = network.add_scale(input=last_layer.get_output(0), mode=trt.ScaleMode.CHANNEL, shift=in_shift, scale=in_scale)
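Before building the engine it is cheap to confirm that the folded scale/shift really reproduce BN, using plain NumPy broadcasting over a random (C, H, W) block (a minimal sketch):

x = np.random.randn(len(in_scale), 4, 4).astype(np.float32)
y_fold = x * in_scale[:, None, None] + in_shift[:, None, None]
y_bn   = (x - in_mean[:, None, None]) / in_std[:, None, None] * in_gamma[:, None, None] + in_beta[:, None, None]
assert np.allclose(y_fold, y_bn, atol=1e-5)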

C++ API

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <map>
#include <string>

#include "NvInfer.h"

using namespace nvinfer1;

// Folds the BN parameters stored under `layerName` into an IScaleLayer:
// scale = gamma / sqrt(var + eps), shift = beta - mean * scale, power = 1.
IScaleLayer* addBatchNorm2d(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor& input, std::string layerName, float eps) {
    float* gamma = (float*)weightMap[layerName + ".weight"].values;
    float* beta  = (float*)weightMap[layerName + ".bias"].values;
    float* mean  = (float*)weightMap[layerName + ".running_mean"].values;
    float* var   = (float*)weightMap[layerName + ".running_var"].values;
    int len = weightMap[layerName + ".running_var"].count;

    // scale = gamma / sqrt(var + eps)
    float* scval = reinterpret_cast<float*>(malloc(sizeof(float) * len));
    for (int i = 0; i < len; ++i) {
        scval[i] = gamma[i] / std::sqrt(var[i] + eps);
    }
    Weights scale{DataType::kFLOAT, scval, len};

    // shift = beta - mean * gamma / sqrt(var + eps)
    float* shval = reinterpret_cast<float*>(malloc(sizeof(float) * len));
    for (int i = 0; i < len; ++i) {
        shval[i] = beta[i] - mean[i] * gamma[i] / std::sqrt(var[i] + eps);
    }
    Weights shift{DataType::kFLOAT, shval, len};

    // power = 1, so the scale layer applies only the affine part
    float* pval = reinterpret_cast<float*>(malloc(sizeof(float) * len));
    for (int i = 0; i < len; ++i) {
        pval[i] = 1.0f;
    }
    Weights power{DataType::kFLOAT, pval, len};

    // Keep the new buffers in the weight map so they outlive the engine build
    weightMap[layerName + ".scale"] = scale;
    weightMap[layerName + ".shift"] = shift;
    weightMap[layerName + ".power"] = power;
    IScaleLayer* scale1 = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, power);
    assert(scale1);
    return scale1;
}
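Note the design choice of inserting the freshly allocated scale/shift/power buffers back into weightMap: TensorRT does not copy weight memory when addScale is called, so every buffer handed to the builder must stay valid until the engine is built, and parking them in the map both guarantees that lifetime and gives the caller a single place to free everything afterwards.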

 
