用MNIST数据集实现简单交叉验证

重点在于将数据集元素的再次划分。将数据集中的训练集分成训练集和验证集两部分。主要使用tf.split()函数。主用途是把一个张量分成几个子张量。

tf.split(
    value,
    num_or_size_splits,
    axis=0
    }

value为准备切分的张量
num_or_size_splits用来确定切割方式
axis指切割的维度

分割方式分为两种:

  1. 如果num_or_size_splits 传入的是整数,将直接在axis=x这个维度上把张量平均切分成几个小张量
  2. 如果num_or_size_splits 传入的是向量(向量的各个元素之和要跟原本这个维度的数值相等)就根据这个向量依次在axis=x这个维度切分)

最终代码如下:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, Sequential, optimizers

# load data
(x, y), (x_test, y_test) = datasets.mnist.load_data()

# build datasets


def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32)/255.
    x = tf.reshape(x, [-1, 28*28])
    y = tf.cast(y, dtype=tf.int64)
    y = tf.one_hot(y, depth=10)
    return x, y

batchsizie = 128
x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000],axis=0) #cut the data
y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])

db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(50000).batch(batchsizie).map(preprocess)

db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
db_val = db_val.shuffle(10000).batch(batchsizie).map(preprocess)

db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsizie).map(preprocess)


#build network
network = Sequential([
    layers.Dense(256, activation=tf.nn.relu),  # [b, 784] to [b, 256]
    layers.Dense(128, activation=tf.nn.relu),  # [b, 256] to [b, 128]
    layers.Dense(64, activation=tf.nn.relu),  # [b, 128] to [b, 64]
    layers.Dense(32, activation=tf.nn.relu),  # [b, 64] to [b, 32]
    layers.Dense(10)  # [b, 32] to [b, 10]
])

network.build(input_shape=[None,28*28])
network.summary()


# train and text
network.compile(optimizer=optimizers.Adam(lr=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['acc'])
network.fit(db_train, epochs=10, validation_data=db_val, validation_freq=2)
network.evaluate(db_test)

最终调用测试集得到测试准确度

1/79 [..............................] - ETA: 0s - loss: 0.0861 - acc: 0.9766
23/79 [=======>......................] - ETA: 0s - loss: 0.2205 - acc: 0.9647
45/79 [================>.............] - ETA: 0s - loss: 0.2026 - acc: 0.9655
67/79 [========================>.....] - ETA: 0s - loss: 0.1686 - acc: 0.9703
79/79 [==============================] - 0s 2ms/step - loss: 0.1625 - acc: 0.9714
  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
Mnist交叉验证是一种常用的验证集方法,用于评估分类模型在Mnist数据集上的性能。该数据集包含手写数字图片和对应的标签,其中训练集用于训练模型,而验证集则用于在训练过程中评估模型的效果。 Mnist数据集通常通过交叉验证来评估模型的性能。交叉验证是将训练数据分为K个子集,其中K-1个子集用于训练模型,而剩余的一个子集用于验证模型。这个过程会多次进行,每次都会将一个不同的子集作为验证集。通过对每次验证的结果求平均值,可以得到模型在整个训练数据上的性能评估。 具体实施Mnist交叉验证的步骤如下:首先将训练集分为K个子集,通常选择K=5或K=10。然后,对于每个子集i,将其作为验证集,其他K-1个子集作为训练集。接着,在训练集上训练分类模型,并在验证集上评估模型性能,通常使用准确率或其他指标来评估模型。 重复上述步骤K次,每次都选择一个不同的子集作为验证集,并在所有验证结果上求平均值,得到模型的最终性能评估。这样可以有效地减少模型性能评估的方差,并更好地评估模型在整个数据集上的性能。 总之,Mnist交叉验证是一种评估分类模型性能的方法,通过将训练数据划分为多个子集,并在每次验证中选择不同的子集作为验证集,可以更准确地评估模型在整个Mnist数据集上的性能。这种方法可以帮助我们选择合适的模型和参数,并提高模型的泛化能力。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值