4 神经网络的学习
这一章主要讲解神经网络的学习,包括第三章初步介绍的前向传播,已经这一章要将的反向传播等。
4.1 从数据中学习
神经网络的特征就是可以从数据中学习。所谓从数据中学习,是指可以由数据自动决定权重参数的取值。
4.1.1 数据驱动
利用数据相出一个可以识别数字的算法。一种方案是,先从图像中提取特征量,再用机器学习技术学习这些特征量的模式,最后对转换后的向量使用机器学习中的SVM、KNN等分类器进行学习。
神经网络可以将数据直接作为原始数据,进行“端到端”的学习。所谓端到端是指从一端到另一端,也就是从原始数据(输入)中获得目标结果(输出)的意思。
4.1.2 训练数据和测试数据
泛化能力或过拟合问题
4.2 损失函数
神经网络以某个指标为线索寻找最优权重参数。神经网络中学习所用的指标称为损失函数,这个损失函数可以使用任意函数。
4.2.1 均方误差
常用损失函数之一。
其中表示神经网络的输出,表示实际数据,表示数据的维数。
python实现:
def mean_squared_error(y, t):
return 0.5 * np.sum((y-t)**2)
4.2.2 交叉熵误差
另一个常用的误差函数。
表示神经网络的输出(是个概率,如sigmoid或者softmax的输出),是正确解的标签(采用one-hot表示)
代码实现:
def cross_entropy_error(y, t):
delta = 1e-7
return -np.sum(t * np.log(y + delta))
这里加上了一个微小值delta,因为当出现np.log(0)时会变为负无穷大,这样就导致后面无法计算。所以加入了保护性对策。
4.2.3 mini-batch学习
前面介绍的损失函数都是针对单个数据的,当采用批处理时,需要算出所有数据的损失函数的总和。
对于交叉熵:
这里假设有N个数据,表示第n个数据的第k个元素的值。是神经网络的输出,是对应的实际数据。实质上是将求单个数据的损失函数扩大到了N份数据,不过最后还要除以N进行正规化。通过除以N,可以求单个数据的“平均损失函数”。通过这样的平均化,可以获得和训练数据的数量无关的统一指标。
另外,对于有些数据集训练数据非常大,如果以全部数据为对象求损失函数的和,则计算过程需要花费较长的时间。因此从全部数据中选出一部分,作为全部数据的“近似”。神经网络的学习也是从训练数据中选出一批数据(称为mini-batch,小批量),然后对每个mini-batch进行学习。这种方式成为mini-batch学习。
从训练数据中随机选择指定个数的数据,以进行mini-batch学习
train_size = x_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
使用 np.random.choice(),可以从指定数量的数字中随机选择想要的数量的数字.
>>> np.random.choice(100, 10)
array([28, 64, 60, 53, 35, 87, 51, 67, 77, 56])
4.2.4 mini-batch版交叉熵误差的实现
这里实现一个可同时处理单个数据和批量数据的版本
def cross_entropy_error(y, t):
if y.ndim == 1:
t = t.reshape(1, t.size)
y = y.reshape(1, y.size)
batch_size = y.shape[0]
return -np.sum(t * np.log(y + 1e-7)) / batch_size
前面提到了这里的t应该是one-hot编码的方式的,即t中是一组01向量,正确数据的索引为1,其余为0.
如果不是这样的形式,而是只有一个数字,就是正确数据。
则可通过如下代码实现:
def cross_entropy_error(y, t):
if y.ndim == 1:
t = t.reshape(1, t.size)
y = y.reshape(1, y.size)
batch_size = y.shape[0]
return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size
4.2.5 为何要设定损失函数
在神经网络的学习中,寻找最优参数时,要寻找使损失函数的值尽可能小的参数。为了找到使损失参数的值尽可能小的地方,需要计算参数的导数(确切地讲是梯度),然后以这个导数为指引,逐步更新参数的值。
不以识别精度作为指标是因为,识别精度和参数之间没有规则公式化的隐含关系,无法通过导数这样的方式进行调参,调参方式没有合适的准则。