MNIST数据集的简单操作(2)

MNIST数据集的简单操作(2)

看了一遍这位博主的博文,自己重新学习并整理了以下内容:
1.引入mnist集数据

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./MNIST_data",one_hot=True)

a = mnist.train.num_examples
b = mnist.validation.num_examples
c = mnist.test.num_examples
d = mnist.train.images.shape
e = mnist.train.labels.shape

print(a,b,c,d,e)

a——训练集数量
b——验证集数量
c——测试集数量
d——训练集图像形状
e——训练集标签形状

结果如下:
在这里插入图片描述
2.创建模型

#为输入数据和标签提供占位符
x = tf.placeholder(tf.float32,[None,784],name="X")
y = tf.placeholder(tf.float32,[None,10],name="Y")

# w = tf.Variable(tf.random_normal([784,10]),name="W")#正态分布
# b = tf.Variable(tf.zeros([10]),name="B")#常数为0
#初始化权重和偏置,进行第一层卷积操作
conv1_w = tf.Variable(tf.random_normal([3, 
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
根据提供的引用内容,我可以为您介绍一下mnist数据集分类和数据扩增的相关内容。 mnist数据集是一个常用的手写数字识别数据集,包含了60000个训练样本和10000个测试样本。每个样本都是一个28x28像素的灰度图像,表示了0到9之间的一个手写数字。 对于mnist数据集的分类任务,可以使用各种机器学习算法或深度学习模型来进行训练和预测。常见的深度学习模型包括卷积神经网络(CNN)和循环神经网络(RNN)。这些模型可以通过对训练样本进行学习,来预测测试样本的数字标签。 数据扩增是指通过对原始数据进行一系列变换和操作,生成更多的训练样本,以增加模型的泛化能力和鲁棒性。在mnist数据集中,常见的数据扩增操作包括旋转、平移、缩放、翻转等。这些操作可以通过图像处理库(如OpenCV)或深度学习框架(如TensorFlow、PyTorch)中的函数来实现。 通过数据扩增,可以生成更多样本,使得模型在训练过程中能够学习到更多的变化和特征,从而提高模型的分类性能。 下面是一个使用Python和TensorFlow进行mnist数据集分类和数据扩增的示例代码: ```python import tensorflow as tf from tensorflow.keras.datasets import mnist from tensorflow.keras.preprocessing.image import ImageDataGenerator # 加载mnist数据集 (x_train, y_train), (x_test, y_test) = mnist.load_data() # 数据扩增 datagen = ImageDataGenerator( rotation_range=10, # 随机旋转角度范围 width_shift_range=0.1, # 随机水平平移范围 height_shift_range=0.1, # 随机垂直平移范围 zoom_range=0.1, # 随机缩放范围 horizontal_flip=False # 不进行水平翻转 ) # 对训练集进行数据扩增 x_train_augmented = datagen.flow(x_train, y_train, batch_size=32) # 构建模型 model = tf.keras.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) # 编译模型 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # 训练模型 model.fit(x_train_augmented, epochs=10) # 在测试集上评估模型 model.evaluate(x_test, y_test) ``` 这段代码首先加载mnist数据集,然后使用ImageDataGenerator进行数据扩增操作。接下来,构建一个简单神经网络模型,并使用编译后的模型对扩增后的训练集进行训练。最后,使用测试集评估模型的性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值