用 Python 从零开始创建神经网络(十九):真实数据集

引言

在实践中,深度学习通常涉及庞大的数据集(通常以TB甚至更多为单位),模型的训练可能需要数天、数周甚至数月。这就是为什么到目前为止,我们使用了程序生成的数据集来使学习过程更易管理并保持快速,同时学习深度学习的数学和其他相关方面的知识。本书的主要目标是教授神经网络的工作原理,而不是深度学习在各种问题中的应用。话虽如此,现在我们将探索一个更实际的数据集,因为这将带来一些我们尚未考虑的深度学习新挑战。

如果你在阅读本书之前已经探索过深度学习,你可能已经熟悉(也可能感到厌倦)MNIST数据集,这是一个包含手写数字(0到9)的图像数据集,每张图像的分辨率为28x28像素。它是一个相对较小的数据集,对模型来说也相对容易学习。这个数据集曾成为深度学习的“Hello World”,并且一度是机器学习算法的基准。然而,这个数据集的问题在于,获得99%以上的准确率变得极其容易,因此它无法提供足够的空间来学习各种参数如何影响模型的学习过程。然而,在2017年,一家名为Zalando的公司发布了一个名为Fashion MNIST的数据集(https://arxiv.org/abs/1708.07747),这是MNIST数据集的直接替代品(https://github.com/zalandoresearch/fashion-mnist)。

Fashion MNIST数据集包含60,000个训练样本和10,000个测试样本,这些样本是28x28像素的图像,涵盖了10种不同的服装类别,例如鞋子、靴子、衬衫、包等。我们稍后会看到一些示例,但首先我们需要获取实际的数据。由于原始数据集由包含特定格式编码图像数据的二进制文件组成,为了本书的使用,我们已经准备并托管了一个预处理数据集,其中包含以.png格式保存的图像。通常,对于图像来说,使用无损压缩是明智的,因为有损压缩(例如JPEG)会通过更改图像数据对图像造成影响。这些图像还根据标签分组,并被分成训练组和测试组。样本是服装物品的图像,而标签是分类信息。以下是数字标签及其对应的描述:


在这里插入图片描述


数据准备

首先,我们将从nnfs.io网站获取数据。让我们定义数据集的URL、本地保存的文件名以及解压图像的文件夹:

URL = 'https://nnfs.io/datasets/fashion_mnist_images.zip'
FILE = 'fashion_mnist_images.zip'
FOLDER = 'fashion_mnist_images'

接下来,使用Python的标准库urllib下载压缩数据(如果指定路径下的文件不存在):

import os
import urllib
import urllib.request

if not os.path.isfile(FILE):
    print(f'Downloading {
     
     URL} and saving as {
     
     FILE}...')
    urllib.request.urlretrieve(URL, FILE)

接下来,我们将使用另一个标准的Python库zipfile来解压文件。我们会使用上下文管理器(即with关键字,它会为我们打开和关闭文件)来获取压缩文件的句柄,并使用.extractall方法和指定的FOLDER提取所有包含的文件:

from zipfile import ZipFile

print('Unzipping images...')
with ZipFile(FILE) as zip_images:
    zip_images.extractall(FOLDER)

检索数据的完整代码:

from zipfile import ZipFile
import os
import urllib
import urllib.request

URL = 'https://nnfs.io/datasets/fashion_mnist_images.zip'
FILE = 'fashion_mnist_images.zip'
FOLDER = 'fashion_mnist_images'
if not os.path.isfile(FILE):
	print(f'Downloading {
     
     URL} and saving as {
     
     FILE}...')
	urllib.request.urlretrieve(URL, FILE)

print('Unzipping images...')
with ZipFile(FILE) as zip_images:
	zip_images.extractall(FOLDER)
	
print('Done!')

运行:

>>>
Downloading https://nnfs.io/datasets/fashion_mnist_images.zip and saving as fashion_mnist_images.zip...
Unzipping images...
Done!

现在你应该有一个名为fashion_mnist_images的目录,其中包含test和train目录以及数据许可文件。在test和train目录中各有10个子目录,编号从0到9。这些数字是与其中图像对应的分类。例如,如果我们打开目录0,可以看到这些是短袖或无袖衬衫的图像。例如:

在这里插入图片描述
在目录 7 中,我们有非靴子鞋,或本数据集创建者分类的运动鞋。例如:

在这里插入图片描述
将图像转换为灰度图(即将每像素的三通道RGB值转换为单一的黑白范围,像素值为0到255)是一种常见的做法,不过这些图像已经是灰度图。另外,将图像调整大小以规范其尺寸也是一种常见的做法,但同样地,Fashion MNIST数据集已经经过处理,所有图像的尺寸都相同(28x28)。


数据加载

接下来,我们需要将这些图像读入Python,并将图像(像素)数据与相应的标签关联起来。我们可以通过以下方式访问这些目录:

import os

labels = os.listdir('fashion_mnist_images/train')
print(labels)
>>>
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

由于子目录名称本身就是标签,我们可以通过查看每个编号子目录中的文件来引用每个类别的单个样本:

files = os.listdir('fashion_mnist_images/train/0')
print(files[:10])
print(len(files))
>>>
['0000.png', '0001.png', '0002.png', '0003.png', '0004.png', '0005.png', '0006.png', '0007.png', '0008.png', '0009.png']
6000

如你所见,我们有6,000个类别0的样本。总共我们有60,000个样本,每个类别6,000个。这意味着我们的数据集已经是平衡的;每个类别出现的频率相同。如果数据集未平衡,神经网络可能会倾向于预测包含最多图像的类别。这是因为神经网络本质上会寻找最陡峭且最快的梯度下降以减少损失,这可能导致陷入局部最小值,使模型无法找到全局损失的最小值。我们这里总共有10个类别,因此在一个平衡的数据集中,随机预测的准确率大约为10%。

然而,假设数据集中类别的不平衡程度为类别0占64%,而类别1到9分别仅占4%。神经网络可能会很快学会始终预测类别0。尽管模型的损失最初会迅速降低,但它可能会一直停留在预测类别0上,准确率接近64%。在这种情况下,我们最好通过削减高频类别的样本数量,使每个类别的样本数量相同。

另一个选择是使用类别权重,在计算损失时为频率较高的类别赋予小于1的权重。然而,在实践中我们几乎没有见过这种方法效果很好。对于图像数据,另一个选择是通过裁剪、旋转、水平或垂直翻转等操作来扩充样本。在应用这些变换之前,需确保它们会生成符合目标的有效样本。幸运的是,我们无需担心这一点,因为Fashion MNIST数据集已经完全平衡。现在,我们将通过查看单个样本来探索数据。为处理图像数据,我们将使用包含OpenCV的Python包,即cv2库,你可以通过pip/pip3安装它:

pip3 install opencv-python

并加载图像数据:

import cv2
image_data = cv2.imread('fashion_mnist_images/train/7/0002.png', cv2.IMREAD_UNCHANGED)
print(image_data)

我们使用cv2.imread()读取图像,其中第一个参数是图像的路径。参数cv2.IMREAD_UNCHANGED通知cv2包,我们希望以图像保存时的格式读取它们(在这种情况下是灰度图)。默认情况下,即使是灰度图,OpenCV也会将这些图像转换为使用所有三个颜色通道。因此,我们得到的是一个二维数组——灰度像素值。如果我们在打印之前使用以下代码行格式化这个杂乱的数组,NumPy会知道打印更多的字符在一行中,因为加载的图像是一个NumPy数组对象:

import numpy as np
np.set_printoptions(linewidth=200)

我们仍然可能能够识别出主题:

>>>
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   1   0   0   0  49 135 182 150  59   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  78 255 220 212 219 255 246 191 155  87   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   1   0   0  57 206 215 203 191 203 212 216 217 220 211  15   0]
 [  0   0   0   0   0   0   0   0   0   0   1   0   0   0  58 231 220 210 199 209 218 218 217 208 200 215  56   0]
 [  0   0   0   0   1   2   0   0   4   0   0   0   0 145 213 207 199 187 203 210 216 217 215 215 206 215 130   0]
 [  0   0   0   0   1   2   4   0   0   0   3 105 225 205 190 201 210 214 213 215 215 212 211 208 205 207 218   0]
 [  1   5   7   0   0   0   0   0  52 162 217 189 174 157 187 198 202 217 220 223 224 222 217 211 217 201 247  65]
 [  0   0   0   0   0   0  21  72 185 189 171 171 185 203 200 207 208 209 214 219 222 222 224 215 218 211 212 148]
 [  0  70 114 129 145 159 179 196 172 176 185 196 199 206 201 210 212 213 216 218 219 217 212 207 208 200 198 173]
 [  0 122 158 184 194 192 193 196 203 209 211 211 215 218 221 222 226 227 227 226 226 223 222 216 211 208 216 185]
 [ 21   0   0  12  48  82 123 152 170 184 195 211 225 232 233 237 242 242 240 240 238 236 222 209 200 193 185 106]
 [ 26  47  54  18   5   0   0   0   0   0   0   0   0   0   2   4   6   9   9   8   9   6   6   4   2   0   0   0]
 [  0  10  27  45  55  59  57  50  44  51  58  62  65  56  54  57  59  61  60  63  68  67  66  73  77  74  65  39]
 [  0   0   0   0   4   9  18  23  26  25  23  25  29  37  38  37  39  36  29  31  33  34  28  24  20  14   7   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]

在这种情况下,这是一个运动鞋。与其通过格式化原始值来查看图像,我们可以使用Matplotlib来可视化它。例如:

import matplotlib.pyplot as plt
plt.imshow(image_data)
plt.show()

在这里插入图片描述
我们可以检查另一个样本:

import matplotlib.pyplot as plt
image_data = cv2.imread('fashion_mnist_images/train/4/0011.png', cv2.IMREAD_UNCHANGED)
plt.imshow(image_data)
plt.show()

在这里插入图片描述

看起来像是一件夹克。如果我们查看之前的表格,类别4是“外套”。你可能会对奇怪的颜色感到疑惑,但这只是因为Matplotlib默认不期望灰度图像。我们可以在调用plt.imshow()时通过指定cmap(颜色映射)来通知Matplotlib这是灰度图像:

import matplotlib.pyplot as plt
image_data = cv2.imread('fashion_mnist_images/train/4/0011.png', cv2.IMREAD_UNCHANGED)
plt.imshow(image_data, cmap='gray')
plt.show()

在这里插入图片描述

现在我们可以遍历所有样本,将它们加载到输入数据( X X X)和目标( y y y)列表中。首先,我们扫描训练文件夹,正如之前提到的,该文件夹包含从0到9命名的子文件夹,这些子文件夹同时也充当样本标签。我们遍历这些文件夹及其中的图像,将图像添加到一个列表变量(命名为 X X X)中,并将其对应的标签添加到另一个列表变量(命名为 y y y)中,从而形成我们的样本和真实标签(目标标签):

# Scan all the directories and create a list of labels
labels = os.listdir('fashion_mnist_images/train')
# Create lists for samples and labels
X = []
y = []
# For each label folder
for label in labels:
    # And for each image in given folder
    for file in os.listdir(os.path.join('fashion_mnist_images', 'train', label)):
        # Read the image
        image = cv2.imread(os.path.join('fashion_mnist_images/train', label, file), cv2.IMREAD_UNCHANGED)
        # And append it and a label to the lists
        X.append(image)
        y.append(label)

我们需要对测试数据和训练数据执行相同的操作。幸运的是,它们已经为我们很好地分开了。很多时候,你需要自己将数据分成训练组和测试组。我们将把上述代码转换为一个函数,以避免为训练和测试目录重复编写代码。这个函数将接收一个数据集类型(训练或测试)作为参数,以及这些数据集所在路径:

import numpy as np
import cv2
import os

# Loads a MNIST dataset
def load_mnist_dataset(dataset, path):
    # Scan all the directories and create a list of labels
    labels = os.listdir(os.path.join(path, dataset))
    # Create lists for samples and labels
    X = []
    y = []
    # For each label folder
    for label in labels:
        # And for each image in given folder
        for file in os.listdir(os.path.join(path, dataset, label)):
            # Read the image
            image = cv2.imread(os.path.join(path, dataset, label, file), cv2.IMREAD_UNCHANGED)
            # And append it and a label to the lists
            X.append(image)
            y.append(label)
    # Convert the data to proper numpy arrays and return
    return np.array(X), np.array(y).astype('uint8')

由于 X X X被定义为一个列表,并且我们将以NumPy数组形式表示的图像添加到这个列表中,因此我们会在最后调用np.array() X X X从列表转换为一个正式的NumPy数组。对于标签( y y y)也会执行相同的操作,因为它们是一个数字列表,我们还需要告知NumPy这些标签是整数(而非浮点数)值。

然后,我们可以编写一个函数,用于创建并返回我们的训练和测试数据:

# MNIST dataset (train + test)
def create_data_mnist(path):
    # Load both sets separately
    X, y = load_mnist_dataset('train', path)
    X_test, y_test = load_mnist_dataset('test', path)
    # And return all the data
    return X, y, X_test, y_test

到目前为止,针对我们新数据的代码:

import numpy as np
import cv2
import os

# Loads a MNIST dataset
def load_mnist_dataset(dataset, path):
    # Scan all the directories and create a list of labels
    labels = os.listdir(os.path.join(path, dataset))
    # Create lists for samples and labels
    X = []
    y = []
    # For each label folder
    for label in labels:
        # And for each image in given folder
        for file in os.listdir(os.path.join(path, dataset, label)):
            # Read the image
            image = cv2.imread(os.path.join(path, dataset, label, file), cv2.IMREAD_UNCHANGED)
            # And append it and a label to the lists
            X.append(image)
            y.append(label)
    # Convert the data to proper numpy arrays and return
    return np.array(X), np.array(y).astype('uint8')


# MNIST dataset (train + test)
def create_data_mnist(path):
    # Load both sets separately
    X, y = load_mnist_dataset('train', path)
    X_test, y_test = load_mnist_dataset('test', path)
    # And return all the data
    return X, y, X_test, y_test

有了这个函数,我们就可以通过以下操作加载数据:

# Create dataset
X, y, X_test, y_test = create_data_mnist('fashion_mnist_images')

数据预处理

接下来,我们将对数据进行缩放(不是对图像本身,而是表示它们的数字)。神经网络在数据范围为0到1或-1到1时通常表现最佳。在这里,图像数据的范围是0到255。我们需要决定如何对这些数据进行缩放。通常,这一过程需要一些实验和反复试验。例如,我们可以将图像缩放到-1到1的范围,通过对每个像素值减去所有像素值的最大值的一半(即 255 / 2 = 127.5 255/2 = 127.5 255/2=127.5),然后除以这一半,从而生成一个范围为-1到1的值。我们也可以通过简单地将数据除以255(最大值)将其缩放到0到1的范围。首先,我们选择将数据缩放到-1到1的范围。在执行这一操作之前,我们需要更改NumPy数组的数据类型,当前的数据类型是uint8(无符号整数,范围为0到255的整数值)。如果我们不更改,NumPy会将其转换为float64数据类型,而我们的目的是使用float32(32位浮点值)。可以通过在NumPy数组对象上调用.astype(np.float32)实现。标签将保持不变:

# Create dataset
X, y, X_test, y_test = create_data_mnist('fashion_mnist_images')
# Scale features
X = (X.astype(np.float32) - 127.5) / 127.5
X_test = (X_test.astype(np.float32) - 127.5) / 127.5

确保使用相同的方法对训练数据和测试数据进行缩放。稍后,在进行预测时,你还需要对用于推断的输入数据进行缩放。在不同的地方忘记对数据进行缩放是很常见的错误。同时,你需要确保任何预处理操作(例如缩放)仅基于训练数据集的信息。在这个例子中,我们知道最小值(min)和最大值(max)分别为0和255,并执行了线性缩放。然而,通常你需要首先查询数据集的最小值和最大值以用于缩放。如果你的数据集中存在极端异常值,最小/最大值方法可能效果不佳。在这种情况下,你可以使用平均值和标准差的某种组合来创建缩放方法。

缩放时一个常见的错误是允许测试数据集影响对训练数据集所做的变换。对此规则唯一的例外是当数据以线性方式缩放,例如通过提到的除以常数的方式。如果使用的是非线性缩放函数,可能会将测试或验证数据的信息泄露到训练数据中。任何预处理规则都应该在不了解测试数据集的情况下得出,但随后应用于测试数据集。例如,你的整个数据集可能最小值为0,最大值为125,而训练数据集的最小值为0,最大值为100。在这种情况下,你仍然会使用100作为缩放测试数据集的值。这意味着你的测试数据集在缩放后可能不会完全适合-1到1的范围,但这通常不是问题。如果差异较大,你可以通过将数据线性缩放再除以某个数值来进行额外的缩放。

回到我们的数据,让我们检查一下数据是否已经缩放:

print(X.min(), X.max())
>>>
-1.0 1.0

接下来,我们检查输入数据的形状:

print(X.shape)
>>>
(60000, 28, 28)

我们的Dense层处理的是一维向量的批量数据,无法直接操作形状为28x28的二维数组图像。我们需要将这些28x28的图像“展平”,这意味着将图像数组的每一行依次附加到数组的第一行,从而将图像的二维数组转换为一维数组(即向量)。换句话说,这可以看作是将二维数组中的数字展开为类似列表的形式。有一种叫做卷积神经网络的模型,可以直接处理二维图像数据,但像我们这里的全连接神经网络(Dense网络)需要一维的样本数据。即使在卷积神经网络中,你通常也需要在将数据传递到输出层或Dense层之前对数据进行展平。

在NumPy中展平数组可以使用reshape方法,并将第一个维度设置为-1,表示“根据实际元素数量决定”,从而将所有元素放在第一维度中,形成一个一维数组。以下是这种概念的一个示例:

example = np.array([[1,2],[3,4]])
flattened = example.reshape(-1)
print(example)
print(example.shape)
print(flattened)
print(flattened.shape)
>>>
[[1 2]
 [3 4]]
(2, 2)
[1 2 3 4]
(4,)

我们也可以使用np.flatten()方法,但当处理一批样本时,我们的意图有所不同。在样本的情况下,我们希望保留所有60,000个样本,因此我们需要将训练数据的形状调整为(60000, -1)。这将通知NumPy我们希望保留60,000个样本(第一维度),但将其余的部分展平(-1作为第二维度意味着我们希望将所有样本数据放入这个单一维度中,形成一维数组)。这将创建60,000个样本,每个样本包含784个特征。这784个特征是28·28的结果。为此,我们将分别使用训练数据(X.shape[0])和测试数据(X_test.shape[0])的样本数量,并对它们进行reshape操作:

# Reshape to vectors
X = X.reshape(X.shape[0], -1)
X_test = X_test.reshape(X_test.shape[0], -1)

你也可以通过显式定义形状来实现相同的结果,而不是依赖NumPy的推断:

.reshape(X.shape[0], X.shape[1]*X.shape[2])

这样会更明确,但我们认为这样不太清晰。


数据洗牌

我们当前的数据集由样本及其目标分类组成,按顺序从0到9排列。为了说明这一点,我们可以在不同位置查询 y y y数据。前6000个样本的标签都将是0。例如:

print(y[0:10])
>>>
[0 0 0 0 0 0 0 0 0 0]

如果我们稍后再进行查询:

print(y[6000:6010])
>>>
[1 1 1 1 1 1 1 1 1 1]

如果我们按这种顺序训练网络,会导致问题;原因与数据集不平衡的问题类似。

在训练前6,000个样本时,模型会学到最快减少损失的方法是始终预测为0,因为它会看到多个仅包含类别0的数据批次。然后,在6,000到12,000之间,损失会因为标签的变化而最初上升,而模型仍然会错误地预测标签为0。随后,模型可能会学到现在需要始终预测类别1(因为它在优化过程中看到的标签批次全是类别1)。模型会在当前批次中重复的标签附近循环于局部最小值,并且很可能永远找不到全局最小值。这一过程会一直持续,直到我们完成所有样本,并重复我们选择的训练轮数(epochs)。

理想情况下,每次拟合的样本中应该包含多个类别(最好每个类别都有一些),以防止模型因为最近看到某个类别较多而对该类别产生偏向。因此,我们通常会随机打乱数据。在之前的训练数据(例如螺旋数据)中,我们不需要打乱数据,因为我们是一次性对整个数据集进行训练,而不是分批次训练。但对于这个更大的数据集,我们是以批次进行训练的,因此需要打乱数据,因为目前数据是按每个标签的6,000个样本块顺序排列的。

在打乱数据时,我们需要确保样本数组和目标数组同步打乱,否则标签将不再与样本匹配,导致模型非常混乱(在大多数情况下结果也非常错误)。因此,我们不能简单地分别对它们调用shuffle()方法。有许多方法可以实现这一点,但我们的方法是获取所有的“键”(在这里是样本和目标数组的索引),然后对这些键进行打乱。

在这个例子中,这些键的值范围是从0到59999。

keys = np.array(range(X.shape[0]))
print(keys[:10])
>>>
array([0 1 2 3 4 5 6 7 8 9])

然后,我们就可以对这些密钥进行洗牌:

import nnfs
nnfs.init()
np.random.shuffle(keys)
print(keys[:10])
>>>
[ 3048 19563 58303  8870 40228 31488 21860 56864   845 25770]

现在,这基本上就是新的索引顺序,我们可以通过操作来应用它:

X = X[keys]
y = y[keys]

这告诉NumPy按照给定的索引返回对应的值,就像我们通常对NumPy数组进行索引一样,但这里我们使用的是一个包含随机顺序索引的变量。然后我们可以检查目标数据的一部分切片:

print(y[:15])
>>>
[0 3 9 1 6 5 3 9 0 4 8 9 0 6 6]

它们似乎被洗过。我们也可以检查个别样本:

import matplotlib.pyplot as plt
plt.imshow((X[4].reshape(28, 28))) # Reshape as image is a vector already
plt.show()
洗牌之后随机的shirt图片

在这里插入图片描述

然后,我们就可以在同一索引下检查该类:

print(y[4])
>>>
6

类别6确实是“衬衫”,因此这些数据看起来已经正确地打乱了。你可以手动再检查一些数据,以确保数据符合预期。如果模型无法训练或表现异常,你需要仔细检查数据的预处理过程。


批次(Batches)

到目前为止,我们通过将整个数据集作为一个单一的“批次”传递给模型来训练我们的模型。我们在第2章中讨论过为什么一次处理多个样本是更优的,但是否存在一个过大的批次大小呢?我们的数据集足够小,可以一次性传递整个数据集,但真实世界中的数据集通常可能有TB或更多的规模,这对于大多数计算机来说无法作为一个单一批次处理。

批次是数据的一个固定大小的切片。当我们使用批次进行训练时,我们一次以一个数据块或“批次”来迭代数据集,依次执行前向传播、损失计算、反向传播和优化。如果数据已经被打乱,并且每个批次足够大且能在一定程度上代表整个数据集,那么可以合理地假设每个批次的梯度方向是朝向全局最小值的良好近似。如果批次太小,梯度下降的方向可能会在不同批次之间波动过大,导致模型训练耗时较长。

常见的批次大小范围是32到128个样本。如果你的内存不足,可以使用更小的批次;如果想让训练更快,可以使用更大的批次,但通常这个范围是典型的批次大小范围。通过将批次大小从2增加到8,或者从8增加到32,通常可以看到准确率和损失的改善。然而,继续增大批次大小时,关于准确率和损失的提升会逐渐减少。此外,与较小的批次相比,使用大批次进行训练会变得更慢——就像我们之前使用螺旋数据的例子,需要1万轮训练!在神经网络中,很多时候需要针对具体的数据和模型进行大量的试验和调整。

例如,假设我们选择批次大小为128,并选择进行10轮训练。这意味着,在每一轮训练中,我们会遍历数据集,每次拟合128个样本来训练模型。每次训练的批次被称为一个“步骤”。我们可以通过将样本数量除以批次大小来计算步骤的数量:

steps = X.shape[0] // BATCH_SIZE

我们使用整数除法运算符//(而不是浮点除法运算符/)来返回整数,因为步骤的数量不能包含小数部分。这是我们在每一轮训练中循环执行的迭代次数。如果有一些剩余的样本无法被整除,我们可以通过简单地增加一步来将它们包含进去:

if steps * BATCH_SIZE < X.shape[0]:
    steps += 1

我们可以通过一个简单的例子来说明为什么要添加这个 1:

batch_size = 2
X = [1, 2, 3, 4]
print(len(X) // batch_size)
>>>
2
X = [1, 2, 3, 4, 5]
print(len(X) 
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值