深度学习_mini-batch实现&评价_详解

参考书:深度学习入门:基于Python的理论与实现
mini-batch是在训练数据中随机选择小批量的数据,进行深度学习找到合适权重值。
主要步骤
1.每次从6万多张MNIST数据集图片中挑选100张图片
2.计算梯度
3.根据梯度更新参数
4.如果数据经过一个epoch,则计算识别精度
5.重复以上步骤,通过梯度对参数更新10000次

import numpy as np
import os
import sys
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
from P110_2层神经网络 import TwoLayerNet

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
# load_mnist函数返回(训练图像,训练标签),(测试图像,测试标签)
# normalize为True将图像正规化为0-1的值,为False保持原来的0-255
# one_hot_label设置标签用one-hot表示
# MNIST数据集训练图像6万张,测试图像1万张,每个图像是28*28的像素
# 这里x_train.shape=(60000,784);t_train.shape=(60000,);x_test.shape=(10000,784);t_test.shape=(10000,)


# 超参数
iters_num = 10000       # 梯度更新次数
train_size = x_train.shape[0]       # 训练数据大小
batch_size = 100        # mini-batch大小为100
learning_rate = 0.1     # 学习率(梯度变化多少)
network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)
# 输入层神经元784个,隐藏层神经元50个,输出层神经元10个
# 输入层神经元数:每张图像的像元数确定;隐藏层神经元数:要尝试出合适值;输出层神经元数:每张图像想要的类别数


train_loss_list = []        # 训练数据的损失函数值
train_acc_list = []         # 训练数据的识别精度
test_acc_list = []          # 测试数据的识别精度
# 平均每个epoch的重复次数
iter_per_epoch = max(train_size / batch_size, 1)        # 训练数据除以小批量数据,这里等于600


for i in range(iters_num):      # 进行梯度更新,循环iters_num次
    # 获取mini-batch
    batch_mask = np.random.choice(train_size, batch_size)       # 从train_size中抽选batch_size个数
    x_batch = x_train[batch_mask]       # 小批量训练数据
    t_batch = t_train[batch_mask]       # 小批量测试数据

    # 计算梯度
    grad = network.numerical_gradient(x_batch, t_batch)
    # grad = network.gradient(x_batch, t_batch)  #高速版

    # 更新参数
    for key in ('W1', 'b1', 'W2', 'b2'):
        network.params[key] -= learning_rate * grad[key]
        # network.params[key] = network.params[key] - learning_rate * grad[key]

    # 记录学习过程;更新参数后对训练数据计算损失函数的值
    loss = network.loss(x_batch, t_batch)
    train_loss_list.append(loss)

    # 计算每个epoch的识别精度
    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)      # 训练数据精度
        test_acc = network.accuracy(x_test, t_test)         # 测试数据精度
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print("train acc, test acc |" + str(train_acc) + "," + str(test_acc))


  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值