Batch and Momentum

[Figure: shuffling the data into new batches before each epoch]

When splitting the data into multiple batches, the usual practice is to re-partition (shuffle) the data before every epoch, so the batches are different in every epoch.
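
A minimal sketch of this per-epoch re-shuffling (plain NumPy; the array names `X`, `y` and the helper function are illustrative assumptions, not from the original notes):

```python
import numpy as np

def iterate_minibatches(X, y, batch_size, rng):
    """Yield a fresh random partition of (X, y) into mini-batches."""
    order = rng.permutation(len(X))            # a new shuffle at the start of each epoch
    for start in range(0, len(X), batch_size):
        idx = order[start:start + batch_size]
        yield X[idx], y[idx]

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 1))                 # toy inputs
y = 3.0 * X[:, 0] + rng.normal(scale=0.5, size=1000)   # toy targets

for epoch in range(3):                         # the batches differ from epoch to epoch
    for xb, yb in iterate_minibatches(X, y, batch_size=64, rng=rng):
        pass                                   # compute the gradient on (xb, yb) and update here
```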

Why use batches?

Is it to save time? In fact, because of parallel computation, a large batch (in the extreme, putting all the data into a single batch) usually has the advantage in time over a small batch (extremely large datasets excepted): with a GPU, one update on a large batch takes roughly as long as one update on a small batch, yet far fewer updates are needed per epoch.

The answer is that small batches give better optimization. Small-batch updates are noisier, but it is precisely this noise that helps optimization. Why? With a large batch, the update stops once it reaches a local minimum or saddle point. With small batches, each update uses a different subset of the data and therefore optimizes a different loss function; when one batch's loss is stuck at a local minimum or saddle point, the next batch's loss usually is not, so the parameters keep moving (see the sketch below).
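
To make "a different batch is a different loss function" concrete, here is a toy least-squares sketch (the data, the parameter `w`, and the batch size are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(1000,))
y = 3.0 * X + rng.normal(scale=0.5, size=1000)   # toy 1-D regression data; optimum near w = 3

w = 3.0                                          # the full-data gradient is ~0 here ("stuck" for a large batch)
order = rng.permutation(len(X))
for b, start in enumerate(range(0, 160, 32)):    # inspect a few different mini-batches
    xb, yb = X[order[start:start + 32]], y[order[start:start + 32]]
    # Each batch defines its own loss L_b(w) = mean((w*xb - yb)**2); even where the
    # full-data gradient vanishes, the per-batch gradients below are nonzero and differ
    # from batch to batch, which is what keeps small-batch updates moving.
    g = 2.0 * np.mean((w * xb - yb) * xb)
    print(f"batch {b}: gradient = {g:+.3f}")
```

The same mechanism is what lets small-batch updates slip past saddle points and shallow local minima that would stall a full-batch update.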


Beyond that, small batches also do better at prediction (test) time. The reason is still debated; one view is that small batches tend to end up in flat minima, while large batches tend to fall into sharp minima. (A flat minimum is one whose neighborhood has small gradient magnitudes; a sharp minimum is the opposite. Since the test set may differ from the training set, a flat minimum is more forgiving of that shift, and the noisy updates of a small batch are more likely to jump out of a sharp minimum, so the test results are better.)

[Figure: "good" (flat) vs. "bad" (sharp) minima]

[Figure: comparison of small batch vs. large batch]

Momentum

With momentum, each movement of $\theta$ is determined not by the current gradient alone, but by the sum of all past gradients.

In other words, each step moves in the direction of the negative of the current gradient plus the previous movement, as written out below.
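
A standard way to write this out (heavy-ball momentum; $m$ is the accumulated movement, $\eta$ the learning rate, $\lambda$ the momentum coefficient, and $g^{t}$ the gradient at $\theta^{t}$; the symbols are the usual ones, not taken from the original figure):

$$
m^{0} = 0, \qquad
m^{t} = \lambda\, m^{t-1} - \eta\, g^{t-1}, \qquad
\theta^{t} = \theta^{t-1} + m^{t}.
$$

Unrolling the recursion gives $m^{t} = -\eta\left(g^{t-1} + \lambda g^{t-2} + \lambda^{2} g^{t-3} + \cdots\right)$: a weighted sum of all past gradients, which is exactly the statement above.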

[Figure: momentum]

The benefit of this approach is that it can help escape local minima and saddle points (think of it as inertia: the accumulated movement can carry $\theta$ past a point where the current gradient is zero).
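
A minimal sketch of this update rule (plain NumPy; the toy gradient function and the hyperparameter values are assumptions for illustration):

```python
import numpy as np

def gradient(theta):
    # Gradient of an assumed toy loss; in practice this comes from backprop on your model.
    return np.cos(theta) + 0.1 * theta

eta, lam = 0.05, 0.9        # learning rate and momentum coefficient (assumed values)
theta, m = 4.0, 0.0         # initial parameter and initial movement m^0 = 0

for step in range(200):
    g = gradient(theta)
    m = lam * m - eta * g   # movement = lambda * (previous movement) - eta * (current gradient)
    theta = theta + m       # step by the accumulated movement, not just -eta * g
```

Where a plain gradient step stops as soon as the current gradient `g` is zero, the accumulated movement `m` can remain nonzero and carry `theta` past that point, which is the inertia analogy above.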
