PyTorch库学习之nn.BatchNorm2d模块

PyTorch库学习之nn.BatchNorm2d模块

一、简介

nn.BatchNorm2d 是 PyTorch 深度学习框架中的一个模块,用于对二维卷积层的输出进行批量归一化(Batch Normalization)。批量归一化是一种常用的正则化技术,可以加速训练过程,提高模型的泛化能力。它通过规范化层的输入来减少内部协变量偏移,即确保网络的每一层输入数据的分布保持相对稳定。

二、语法和参数

nn.BatchNorm2d 模块的语法如下:

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

参数说明:

  • num_features:输入的通道数。
  • eps:数值稳定性的小常数,用于防止除以零。
  • momentum:移动平均的动量参数。
  • affine:布尔值,如果为True,则使用可学习的仿射变换参数。
  • track_running_stats:布尔值,如果为True,则跟踪整个训练过程中的均值和方差。

三、实例

3.1 基本使用
import torch
import torch.nn as nn

# 定义 BatchNorm2d 模块
bn = nn.BatchNorm2d(3)

# 创建一个假的输入数据,假设输入的维度是 [batch_size, num_channels, height, width]
input = torch.randn(4, 3, 10, 10)

# 应用 BatchNorm2d 模块
output = bn(input)

输出:

torch.Size([4, 3, 10, 10])
3.2 使用affine参数
# 定义 BatchNorm2d 模块,不使用仿射变换
bn_no_affine = nn.BatchNorm2d(3, affine=False)

# 应用 BatchNorm2d 模块
output_no_affine = bn_no_affine(input)

输出:

torch.Size([4, 3, 10, 10])
3.3 使用track_running_stats参数
# 定义 BatchNorm2d 模块,不跟踪运行时的统计数据
bn_no_running_stats = nn.BatchNorm2d(3, track_running_stats=False)

# 应用 BatchNorm2d 模块
output_no_running_stats = bn_no_running_stats(input)

输出:

torch.Size([4, 3, 10, 10])

四、注意事项

  • nn.BatchNorm2d 通常在卷积层之后使用。
  • 在训练过程中,nn.BatchNorm2d 会更新均值和方差的运行估计值;在评估模式下,使用这些估计值进行归一化。
  • 当设置 affine=True 时,BatchNorm2d 会学习两个可训练参数,即缩放因子和偏移因子,以允许对归一化输出进行仿射变换。
  • 使用 track_running_stats=True 可以使得 BatchNorm2d 在训练过程中动态更新统计数据,这对于模型的泛化能力是有益的。在推理时,可以使用训练过程中得到的均值和方差进行归一化。
  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值