PyTorch的Batch Normalization
torch.nn.BatchNorm2d
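计算公式为 $y = \dfrac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$,其中 E[x] 是 batch 的均值,Var[x] 是 batch 的方差,ϵ 用于防止除 0,γ 是学习得到的权重(缩放系数),β 是偏置。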
torch.load(weight)
# Load the saved checkpoint (presumably a state_dict) and read the four
# BatchNorm tensors stored under the 'in.' key prefix as NumPy arrays.
weights = torch.load(your_model_dict_state_path)
in_gamma = weights['in.weight'].numpy() # gamma: the learned weight (scale factor)
in_beta = weights['in.bias'].numpy() # beta: the learned bias
in_mean = weights['in.running_mean'].numpy() # running mean
in_var = weights['in.running_var'].numpy() # running variance (raw value, no square root applied here)
TensorRT实现BN:
IScaleLayer的文档见链接,https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html#iscalelayer
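IScaleLayer 的计算为 $\text{output} = (\text{input} \cdot \text{scale} + \text{shift})^{\text{power}}$。把 BN 折叠进去即:$\text{scale} = \gamma / \sqrt{\mathrm{Var}[x] + \epsilon}$,$\text{shift} = \beta - \mathrm{E}[x] \cdot \gamma / \sqrt{\mathrm{Var}[x] + \epsilon}$,power 取 1。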
python API
import numpy as np
import tensorrt as trt
import torch

# Load the checkpoint onto the CPU so that .numpy() also works for tensors
# that were saved from a GPU.
weights = torch.load(your_model_dict_state_path, map_location='cpu')
in_gamma = weights['in.weight'].numpy()  # gamma: the learned weight (scale factor)
in_beta = weights['in.bias'].numpy()  # beta: the learned bias
in_mean = weights['in.running_mean'].numpy()  # running mean
in_var = weights['in.running_var'].numpy()  # running variance (square root is taken below)

eps = 1e-05
# From here on in_var holds sqrt(var + eps), i.e. the BatchNorm denominator.
in_var = np.sqrt(in_var + eps)

# Fold BatchNorm into the scale layer's per-channel coefficients:
#   y = (x - mean) / sqrt(var + eps) * gamma + beta = x * in_scale + in_shift
in_scale = in_gamma / in_var
in_shift = - in_mean / in_var * in_gamma + in_beta

# `in` is a reserved keyword in Python and cannot be assigned to, so the
# resulting layer is bound to `in_layer`. `power` is left unset.
in_layer = network.add_scale(input=last_layer.get_output(0), mode=trt.ScaleMode.CHANNEL, shift=in_shift, scale=in_scale)
C++ API
/**
 * Adds an inference-mode BatchNorm2d to `network` as a per-channel IScaleLayer.
 *
 * The four BatchNorm tensors stored under `layerName` are folded into the
 * scale layer's coefficients (the scale layer computes
 * (input * scale + shift) ^ power per the TensorRT IScaleLayer documentation):
 *   scale = gamma / sqrt(var + eps)
 *   shift = beta - mean * gamma / sqrt(var + eps)
 *   power = 1
 *
 * @param network   Network definition the scale layer is added to.
 * @param weightMap Must contain `<layerName>.weight`, `.bias`, `.running_mean`
 *                  and `.running_var`, all kFLOAT with the same element count.
 *                  On return it additionally holds `<layerName>.scale`,
 *                  `.shift` and `.power`, which point at the buffers allocated
 *                  here.
 * @param input     Tensor fed into the scale layer.
 * @param layerName Key prefix of the BatchNorm tensors in `weightMap`.
 * @param eps       Value added to the variance before the square root.
 * @return The created scale layer (asserted to be non-null).
 * @throws std::out_of_range if one of the four input tensors is missing.
 *
 * The coefficient buffers are allocated with malloc() and are not freed here:
 * they are handed to the network and recorded in `weightMap`. Presumably the
 * owner of `weightMap` releases them with free() once the network has been
 * built -- confirm against the caller.
 */
IScaleLayer* addBatchNorm2d(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor& input, std::string layerName, float eps) {
    // at() rather than operator[]: a missing key now fails loudly instead of
    // inserting an empty Weights whose null `values` would be dereferenced.
    const Weights& gammaW = weightMap.at(layerName + ".weight");
    const Weights& betaW = weightMap.at(layerName + ".bias");
    const Weights& meanW = weightMap.at(layerName + ".running_mean");
    const Weights& varW = weightMap.at(layerName + ".running_var");

    // Keep the count in its native (64-bit) type instead of narrowing to int.
    auto len = varW.count;
    assert(len >= 0);
    assert(gammaW.count == len && betaW.count == len && meanW.count == len);
    // The buffers are read as float below, so any other element type would be misinterpreted.
    assert(gammaW.type == DataType::kFLOAT && betaW.type == DataType::kFLOAT);
    assert(meanW.type == DataType::kFLOAT && varW.type == DataType::kFLOAT);

    const float* gamma = static_cast<const float*>(gammaW.values);
    const float* beta = static_cast<const float*>(betaW.values);
    const float* mean = static_cast<const float*>(meanW.values);
    const float* var = static_cast<const float*>(varW.values);

    // malloc() is kept (rather than new[] or a container) because ownership of
    // the buffers leaves this function through weightMap.
    float* scval = static_cast<float*>(malloc(sizeof(float) * len));
    float* shval = static_cast<float*>(malloc(sizeof(float) * len));
    float* pval = static_cast<float*>(malloc(sizeof(float) * len));
    // malloc(0) may legitimately return null, so only require buffers when there is data.
    assert(len == 0 || (scval && shval && pval));

    for (decltype(len) i = 0; i < len; ++i) {
        // Compute the denominator once per channel; the expressions are
        // otherwise unchanged.
        const auto denom = sqrt(var[i] + eps);
        scval[i] = gamma[i] / denom;
        shval[i] = beta[i] - mean[i] * gamma[i] / denom;
        pval[i] = 1.0f;
    }

    Weights scale{DataType::kFLOAT, scval, len};
    Weights shift{DataType::kFLOAT, shval, len};
    Weights power{DataType::kFLOAT, pval, len};

    // Record the derived buffers so whoever owns weightMap can reach them later.
    weightMap[layerName + ".scale"] = scale;
    weightMap[layerName + ".shift"] = shift;
    weightMap[layerName + ".power"] = power;

    IScaleLayer* scale1 = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, power);
    assert(scale1);
    return scale1;
}