先经过 BN 层、再经过 conv 层（BN → Conv）时，把 BN 融合进卷积的代码如下：
def _fuse_bn_conv_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fold a BatchNorm that runs *before* a conv (BN first, then conv) into the conv.

    The branch computes ``conv(bn(x))``. Using running statistics, BN is a
    per-input-channel affine map ``bn(x)_c = t_c * x_c + s_c`` with
    ``t = gamma / sqrt(running_var + eps)`` and ``s = beta - running_mean * t``.
    Since the conv is linear in its input::

        conv(W, t * x + s) = conv(W * t, x) + conv(W, s)

    the fused kernel scales each input-channel slice of ``W`` by ``t``. Away from
    padded borders, ``conv(W, s)`` equals the per-output-channel constant
    ``sum_{i, kh, kw} W[o, i, kh, kw] * s_i``, which is added to the conv bias.

    Grouped convolutions are supported. Output channel ``o`` of group ``g`` only
    reads input channels ``[g * C_in/groups, (g + 1) * C_in/groups)``, so the BN
    factors are gathered per group before scaling.

    Limitation: the conv zero-pads the BN *output*. Padded positions therefore
    hold 0 rather than ``s``, and ``conv(W, s)`` is not constant near the
    borders. The returned (kernel, bias) pair is exact only when the conv's
    padding is 0. An exact graph needs an extra spatially varying Add
    (``conv(W, s)`` evaluated on a constant map of the input's size).

    :param branch: object with ``conv`` (weight of shape
        ``(C_out, C_in // groups, kh, kw)``, optional ``bias``) and ``bn``
        (BatchNorm over the conv's ``C_in`` input channels).
    :return: Tuple of (kernel, bias) after fusing batchnorm.
    :raises ValueError: if the BN channel count or the conv output channel
        count is inconsistent with ``self.groups``.
    """
    kernel = branch.conv.weight
    bn = branch.bn
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std                    # t: per-input-channel multiplier
    shift = bn.bias - bn.running_mean * scale  # s: per-input-channel offset

    groups = self.groups
    out_channels, in_per_group = kernel.shape[0], kernel.shape[1]
    if scale.shape[0] != in_per_group * groups or out_channels % groups:
        raise ValueError(
            f"BN with {scale.shape[0]} channels cannot precede a conv with weight "
            f"shape {tuple(kernel.shape)} and groups={groups}"
        )
    out_per_group = out_channels // groups

    # Row o of these (C_out, C_in // groups, 1, 1) tables holds the BN factors of
    # exactly the input channels that output channel o reads (its group's slice).
    scale_k = scale.reshape(groups, in_per_group).repeat_interleave(out_per_group, dim=0)[:, :, None, None]
    shift_k = shift.reshape(groups, in_per_group).repeat_interleave(out_per_group, dim=0)[:, :, None, None]

    fused_kernel = kernel * scale_k
    fused_bias = (kernel * shift_k).sum(dim=(1, 2, 3))
    # Convs built with bias=False have conv.bias set to None.
    if branch.conv.bias is not None:
        fused_bias = fused_bias + branch.conv.bias
    return fused_kernel, fused_bias
这里还有一个问题：如果卷积的 pad 不为 0，融合后的 bias 就不再是 $W_{conv} * B_{bn} + B_{conv}$。原因是零填充发生在 BN 之后，边界处的填充值是 0，而不是 BN 的偏移量 $B_{bn}$，所以边界上的偏置并不是常数。正确的做法是在融合后的卷积后面再接一个 Add，其值为：
import numpy as np
import torch
import torch.nn.functional as F

# Exact bias for "BN -> zero-padded conv" fusion.
# The BN shift s = t_beta - t_running_mean * t is constant per input channel, but the
# conv zero-pads *after* BN, so conv(W, s) shrinks near the borders. Evaluate it
# explicitly on a constant map the size of the conv input; bc1 is the value of the
# Add that must follow the fused conv.
# NOTE(review): assumes `ir`/`conv` describe an ONNX-style Conv node, `W` is the ORIGINAL
# (unscaled) conv weight as a numpy array, and t, t_beta, t_running_mean are the
# (1, C_in, 1, 1) tensors from the fusion step — TODO confirm against the caller.
input_h, input_w = ir.get_shape(conv.input[0])[2], ir.get_shape(conv.input[0])[3]
# Cast both operands to float32 so F.conv2d never sees mixed dtypes.
shift = np.asarray(t_beta - t_running_mean * t, dtype=np.float32)
a = torch.from_numpy(np.ascontiguousarray(np.tile(shift, (1, 1, input_h, input_w))))
W1 = torch.from_numpy(np.asarray(W, dtype=np.float32))
# ONNX `pads` is [top, left, bottom, right]. F.conv2d's `padding` only expresses
# symmetric (pad_h, pad_w), and pads[0] alone would apply the top pad to both axes,
# so apply the (possibly asymmetric) zero padding explicitly.
pad_top, pad_left, pad_bottom, pad_right = (int(p) for p in ir.get_attr(conv, 'pads'))
a = F.pad(a, (pad_left, pad_right, pad_top, pad_bottom))
# NOTE(review): dilation is fixed at 1 and `groups` is not passed; for dilated or
# grouped convs read the node's 'dilations' / 'group' attributes here.
bc1 = F.conv2d(a, W1, stride=ir.get_attr(conv, 'strides'), padding=0, dilation=1).numpy()
bc1 = bc1.astype(np.float32, copy=False)
也就是说：把常量图 (t_beta - t_running_mean * t) 平铺到卷积输入的尺寸，再用原始卷积权重 W 按原卷积的 stride 和 pad 做一次卷积。得到的结果就是需要后接的 Add 的值。