通过MNIST理解batch_size的作用

batch_size 的选择在训练神经网络时是一个重要的超参数,它会影响训练的速度、模型的泛化能力以及内存的消耗。

以下是不同 batch_size 对训练过程的影响:

  1. batch_size 太大

    • 优点
      • 训练过程更快,因为可以充分利用硬件的并行计算能力。
    • 缺点
      • 内存需求更高,特别是对于大型模型和大规模数据集。
      • 每次更新所需的计算量增大,可能导致训练过程不稳定,容易陷入局部最优解。
  2. batch_size 太小

    • 优点
      • 内存消耗较小,适用于资源受限的情况。
      • 可以更频繁地更新模型参数,有助于平滑梯度,使训练过程更加稳定。
    • 缺点
      • 训练速度相对较慢,因为不能充分利用硬件的并行计算能力。
      • 容易受到噪声样本的影响,使得模型参数更新的方向不稳定。
  3. 适当选择

    • 一般来说,合适的 batch_size 取决于许多因素,包括模型的大小、硬件资源、数据集大小等。
    • 常用的一些取值范围是 32、64、128 等。在实践中,你可能需要通过实验来找到最适合你的具体情况的 batch_size

总的来说,合适的 batch_size 可以提高训练效率、稳定性,同时避免内存消耗过大。在实践中,通常需要通过实验和调参来选择最合适的 batch_size

学习测试代码

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax


def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


def predict(network, x):
    w1, w2, w3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, w2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, w3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()

batch_size = 10 # 批数量
accuracy_cnt = 0

for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)

    p = np.argmax(y_batch, axis=1) # 将预测结果组合成为数组
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

王摇摆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值