batch normalization手写代码

批归一化(BN)层的代码实现,pytorch里面处理会更复杂
这里只是写了一个较为简单的模块

# -*- coding: utf-8 -*-
# @Time    : 2021/3/9 11:59
# @Author  : Li Gang
# @File    : test4.py
import numpy as np

def batchnorm(X,params,mode):
    mode = mode
    D,N = params.shape
    running_mean = params.get('running_mean', np.zeros(D,dtype=X.dtype))
    running_var = params.get('running_var', np.zeros(D,dtype=X.dtype))
    gamma = params.get('gamma')
    beta = params.get('beta')
    eps = params.get('eps', 1e-5)
    if mode=='train':
        samples_mean = np.mean(X, axis=0)
        samples_var = np.var(X, axis=0)
        out_ = (X - samples_mean) / (np.sqrt(samples_var) + eps)
        momentum = params.get('momentum')
        out = gamma * out_ + beta
        running_mean = momentum * running_mean + (1 - momentum) * samples_mean
        running_var = momentum * running_var + (1 - momentum) * samples_var
        params['running_mean'] = running_mean
        params['running_var'] = running_var
    elif mode=='test':
        out_ = (X-running_mean)/(running_var+eps)
        out = gamma*out_+beta
    return out



  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值