ResNet(残差网络)是一种深度卷积神经网络结构,通过引入残差块(Residual Block)解决了深度网络中的梯度消失和梯度爆炸问题,使得可以训练非常深的网络。
使用PyTorch实现的ResNet-50模型
FixedBatchNorm类
定义FixedBatchNorm类,继承自nn.BatchNorm2d
,并在forward
方法中使用F.batch_norm
函数,这是为了固定批量归一化的行为。
class FixedBatchNorm(nn.BatchNorm2d):
def forward(self, input):
output = F.batch_norm(input, # 输入数据,即卷积层的输出
self.running_mean, # 训练过程中累积的样本均值,训练过程中被动更新,推理阶段用于标准化数据。
self.running_var, # 训练过程中累积的样本方差,训练过程中被动更新,推理阶段用于标准化数据。
self.weight,
self.bias,
training=False, # 推理阶段进行批量归一化,因此不需要计算新的均值和方差,而是使用之前训练时计算得到的self.running_mean和self.running_var。
eps=self.eps # eps是为了数值稳定性而添加到方差的小常量。这可以防止除以接近于零的方差,避免数值不稳定性的问题。
)
return output
计算标准化的值(normalized value):
F.batch_norm
操作在模型的卷积层之后,常常与激活函数一起使用,以促使模型更快地学习和更好