- 注:本文翻译自Tensorflow官网教程,原文链接:
翻译有误之处请多指正,以下正文开始。
与往常一样,本例中的代码将使用tf.keras
API,可以在TensorFlow Keras guide 中了解更多。
在之前的分类文本(classifying text)和预测燃料效率(predicting fuel efficiency)的例子中,我们看到我们的模型在验证数据上的准确性会在经过几个 Epoch 的训练后达到顶峰,然后会停滞不前或开始下降。
换句话说,我们的模型会过度拟合训练数据。学习如何处理过拟合很重要。尽管在训练集上获得高精度是可能的,但我们真正想要的是,开发一个能够很好地推广到测试集(或模型没有见过的数据)的模型。
过拟合的反义词是欠拟合。当测试数据仍有改进空间时,就会出现拟合不足。出现这种情况的原因有很多:模型不够强大、过于标准化,或者只是训练的时间不够长。这意味着神经网络没有(完全)学习到训练数据中的模式。
但是,如果训练时间过长,模型就会开始过度适应,从训练数据中学习到那些不能适应到测试数据中的特征。我们需要找到一个平衡点。在下面的探索中,我们将理解如何训练一个适当数量的 Epoch 。
为了防止过拟合,最好的解决方案是使用更完整的训练数据。数据集应该涵盖我们期望让模型进行处理的所有输入情况。其他额外的数据可能只有在涉及新的情况时才有用。
在更完整的数据上训练的模型自然会拟合得更好。在我们不能再完善数据集的情况下,下一个最好的解决方案是使用正则化之类的技术。这些限制了模型可以存储的信息的数量和类型。如果一个网络只能记住少量的识别模式,那么优化过程将迫使它专注于那些最突出的模式,这些模式具有更好的推广机会。
在本文中,我们将探讨几种常见的正则化技术,并使用它们来改进分类模型。
设置
在开始之前,导入必要的包:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import regularizers
print(tf.__version__)
!pip install -q git+https://github.com/tensorflow/docs
import tensorflow_docs as tfdocs
import tensorflow_docs.modeling
import tensorflow_docs.plots
from IPython import display
from matplotlib import pyplot as plt
import numpy as np
import pathlib
import shutil
import tempfile
logdir = pathlib.Path(tempfile.mkdtemp())/"tensorboard_logs"
shutil.rmtree(logdir, ignore_errors=True)
希格斯粒子数据集
本教程的目标不是做粒子物理,所以不需要详述数据集的细节。它包含了1100万个例子,每个例子含有28个特性和一个二进制类标签。
gz = tf.keras.utils.get_file('HIGGS.csv.gz', 'http://mlphysics.ics.uci.edu/data/higgs/HIGGS.csv.gz')
FEATURES = 28
tf.data.experimental.CsvDataset
类可以直接从gzip文件中读取csv记录,并且不需要中间的解压缩步骤。
ds = tf.data.experimental.CsvDataset(gz,[float(),]*(FEATURES+1), compression_type="GZIP")
这个csv阅读器类为每条记录返回一个标量列表。下面的函数将该标量列表重新打包为(feature_vector, label) 对。
def pack_row(*row):
label = row[0]
features = tf.stack(row[1:],1)
return features, label
TensorFlow在处理大量数据时的效率最高。
因此,与其单独重新打包每一行,不如创建一个新的数据集(Dataset
),接收10,000个例子,对每一个batch应用pack_row
函数,然后将这些batches分割回单个记录:
packed_ds = ds.batch(10000).map(pack_row).unbatch()
看看这个新的packed_ds
中的一些记录。
这些特性并没有完全规范化,但对于本文来说这已经足够了。
for features,label in packed_ds.batch(1000).take(1):
print(features[0])
plt.hist(features.numpy().flatten(), bins = 101)
为了使本文相对简短,我们只使用前1000个示例进行验证,接下来10000个示例用于训练:
N_VALIDATION = int(1e3)
N_TRAIN