自己动手实现BatchNorm(pytorch实现)

BatchNorm可以加速模型的收敛并且缓解梯度消失问题,是深度学习领域常用的一个技术

最近仔细学习了BatchNorm的原理,因此想自己动手实现一下它,加深理解

代码如下:

import torch
import torch.nn as nn


class MyBatchNorm(nn.Module):
    # def __init__(self, dim):
    def __init__(self, dim):
        super().__init__()
        # 可训练参数 gamma和beta
        self.gamma = nn.Parameter(data=torch.randn((dim)))
        self.beta = nn.Parameter(data=torch.randn((dim)))
        # 全局的均值和方差
        self.mean_whole = torch.zeros((dim))
        self.var_whole = torch.zeros((dim))
        self.lba = 0.99
        # 防止除零错误
        self.eps = 1e-7
   
    def forward(self, x):
        # 检查形状
        if x.dim() == 4:
            x = x.reshape(x.size(0), x.size(1), -1)
        assert x.dim() == 3

        # 处于训练状态
        if self.training:
            # 首先计算每个通道的均值和方差
            # (b, c, d) -> (1, c, 1)
            mean_batch = torch.mean(x, dim=[0, 2], keepdim=True)
            var_batch = torch.var(x, dim=[0, 2], keepdim=True, unbiased=False)
            # 使用滑动平均办法计算全局均值和方差
            n = x.numel() / x.size(1)
            self.mean_whole = self.lba * self.mean_whole + (1 - self.lba) * mean_batch
            self.var_whole = self.lba * self.var_whole + (1 - self.lba) * var_batch * n / (n-1)
            # 然后归一化数据
            x = (x - mean_batch) / torch.sqrt((var_batch + self.eps))
        else:
            # 归一化数据
            x = (x - self.mean_whole[None, ..., None]) / torch.sqrt((self.var_whole[None, ..., None] + self.eps))

        # 放缩平移
        x = x * self.gamma[None, ..., None] + self.beta[None, ..., None]
        return x


x = torch.randn((2, 3, 4))

batch_norm = MyBatchNorm(dim=3)
batch_norm = batch_norm.train()

b = batch_norm(x)

print(b.shape)

参考资料:

1. 原理

https://zhuanlan.zhihu.com/p/34879333

2. 代码

https://zhuanlan.zhihu.com/p/337732517

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值