算法面经手撕系列(2)--手撕BatchNormlization

BatchNormlization

  BatchNormlization的编码流程:

  1. init阶段初始化 C i n C_in Cin大小的scale向量和shift向量,同时初始化相同大小的滑动均值向量和滑动标准差向量;
  2. forward时沿着非channel维度计算均值、有偏方差
  3. 依据得到均值和有偏方差进行归一化
  4. 对归一化的结果进行缩放和平移

代码

 代码如下:

class BN(nn.Module):
    def __init__(self,C_in):
        super(BN,self).__init__()

        self.scale=nn.Parameter(torch.ones(C_in).view(1,-1,1,1))
        self.shift=nn.Parameter(torch.zeros(C_in).view(1,-1,1,1))

        self.momentum=0.9

        self.register_buffer('running_mean',torch.zeros(C_in).view(1,-1,1,1))
        self.register_buffer('running_var',torch.zeros(C_in).view(1,-1,1,1))
        self.eps=1e-9
    def forward(self,x):
        if self.training:
            N,C,H,W=x.shape

            mean=x.mean(dim=[0,2,3],keepdim=True)
            var=x.var(dim=[0,2,3],keepdim=True,unbiased=False)

            x=(x-mean)/torch.sqrt(var+self.eps)

            self.running_mean=self.momentum*self.running_mean+(1-self.momentum)*mean
            self.running_var=self.momentum*self.running_var+(1-self.momentum)*var
        else:
            x=(x-self.running_mean)/torch.sqrt(self.running_var+self.eps)

        return x
 if __name__=="__main__":
    input=torch.rand(10,3,5,5)
    model=BN(3)
    res=model(input)
    print('cool')


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值