tensorflow入门实战MNIST

  • 这是github上面的一个开源项目,刚开始运行出错,我自己改了几句代码之后,亲测可运行。
    使用方法:先运行input_data,然后再运行code1,code2。

code1

import tensorflow as tf
import numpy as np
import input_data
import matplotlib.pyplot as plt

# 启用动态图机制
tf.enable_eager_execution()
#调用input_data.py中的read_data_sets函数获取数据
mnist_data = input_data.read_data_sets('mnist_data/', one_hot=False)

train_images = mnist_data.train.images
train_labels = mnist_data.train.labels

test_images = mnist_data.test.images
test_labels = mnist_data.test.labels

#tf.keras.Input函数用于向模型中输入数据,并指定数据的形状、数据类型等信息。
input_ = tf.keras.Input(shape=(784, ))

fc1 = tf.keras.layers.Dense(128, activation='tanh')(input_)
fc2 = tf.keras.layers.Dense(32, activation='tanh')(fc1)
out = tf.keras.layers.Dense(1)(fc2)

# 使用inputs与outputs建立函数链式模型;
model = tf.keras.Model(inputs=input_, outputs=out)

#使用keras构建深度学习模型,我们会通过model.summary()输出模型各层的参数状况
model.summary()

#构建模型后,通过调用compile方法配置其训练过程:
model.compile(loss='mse',optimizer='adam')#mean_squared_error=mse
                        # 顾名思义,意为均方误差,也称标准差,缩写为MSE,可以反映一个数据集的离散程度。
                        #标准误差定义为各测量值误差的平方和的平均值的平方根,故又称为均方误差。
                        # model.compile (optimizer=Adam(lr=1e-4), loss=’binary_crossentropy’, metrics=[‘accuracy’])

#模型拟合
model.fit(x=train_images, y=train_labels, epochs=5)

for i in range(10):
    #tf.expand_dims用来增加维度,
    pred = model(tf.expand_dims(test_images[i], axis=0))
    img = np.reshape(test_images[i], (28, 28))
    lab = test_labels[i]
    print('真实标签: ', lab, ', 网络预测: ', pred.numpy())
    plt.imshow(img)
    plt.show()


#
# import tensorflow.examples.tutorials.mnist.input_data as input_data
# mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
#
# x = tf.placeholder(tf.float32, [None, 784])
# W = tf.Variable(tf.zeros([784,10]))
# b = tf.Variable(tf.zeros([10])+0.1)
# y = tf.nn.softmax(tf.matmul(x,W) + b)
#
# y_ = tf.placeholder("float", [None,10])
#
# cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#
# train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#
# init = tf.global_variables_initializer()
#
# sess = tf.Session()
# sess.run(init)
#
# for i in range(1000):
#     batch_xs, batch_ys = mnist.train.next_batch(100)
#     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#
# correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
#
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
#
# print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

code2

import tensorflow as tf
import numpy as np
import input_data
import matplotlib.pyplot as plt
# 启用动态图机制,不可删除
tf.enable_eager_execution()

mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)

train_images = mnist_data.train.images
train_labels = mnist_data.train.labels

test_images = mnist_data.test.images
test_labels = mnist_data.test.labels

input_ = tf.keras.Input(shape=(784, ))
dense = tf.keras.layers.Dense(128, activation='tanh')(input_)
out = tf.keras.layers.Dense(10, activation='softmax')(dense)

model = tf.keras.Model(inputs=input_, outputs=out)
model.summary()
model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,optimizer='adam', metrics=['accuracy'])

model.fit(x=train_images, y=train_labels, epochs=5)

for i in range(10):
    pred = model(tf.expand_dims(test_images[i], axis=0))
    img = np.reshape(test_images[i], (28, 28))
    lab = test_labels[i]
    print('真实标签: ', lab, ', 网络预测: ', np.argmax(pred.numpy()))
    '''
    import numpy as np
    a = np.array([3, 1, 2, 4, 6, 1])
    b=np.argmax(a)#取出a中元素最大值所对应的索引,此时最大值位6,其对应的位置索引值为4,(索引值默认从0开始)
    print(b)#4
    '''
    plt.imshow(img)
    plt.show()

input_data

  # #!/usr/bin/python
  # # coding:utf-8input_data

  # 用于下载和读取MNIST数据的函数
  from __future__ import absolute_import
  from __future__ import division
  from __future__ import print_function
  import gzip
  import os
  import tensorflow.python.platform
  import numpy
  from six.moves import urllib
  from six.moves import xrange  # pylint: disable=redefined-builtin
  import tensorflow as tf
  SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


  # 若数据不存在,则从Yann的网站下载数据
  def maybe_download(filename, work_directory):
    if not os.path.exists(work_directory):
      os.mkdir(work_directory)
    filepath = os.path.join(work_directory, filename)
    # 若指定路径不存在,则开始从原网站上下载
    if not os.path.exists(filepath):
      filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
      statinfo = os.stat(filepath)
      print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath

  def _read32(bytestream):
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


  # 将图像提取到一个4维uint8类型的numpy数组[index, y, x, depth]
  def extract_images(filename):
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
      magic = _read32(bytestream)
      if magic != 2051:
        raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, filename))
      num_images = _read32(bytestream)
      rows = _read32(bytestream)
      cols = _read32(bytestream)
      buf = bytestream.read(rows * cols * num_images)
      data = numpy.frombuffer(buf, dtype=numpy.uint8)
      data = data.reshape(num_images, rows, cols, 1)
      return data

  # 将类标签从标量转换为一个one-hot向量
  def dense_to_one_hot(labels_dense, num_classes=10):
    num_labels = labels_dense.shape[0]
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot

  # 将标签提取到一维uint8类型的numpy数组[index]中
  def extract_labels(filename, one_hot=False):
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
      magic = _read32(bytestream)
      if magic != 2049:
        raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, filename))
      num_items = _read32(bytestream)
      buf = bytestream.read(num_items)
      labels = numpy.frombuffer(buf, dtype=numpy.uint8)
      if one_hot:
        return dense_to_one_hot(labels)
      return labels

  # 构造DataSet类
  # one_hot arg仅在fake_data为true时使用
  # `dtype`可以是`uint8`,将输入保留为`[0,255]`,或`float32`以重新调整为[0,1]class DataSet(object):
    def __init__(self, images, labels, fake_data=False, one_hot=False, dtype=tf.float32):
      dtype = tf.as_dtype(dtype).base_dtype
      if dtype not in (tf.uint8, tf.float32):
        raise TypeError('Invalid image dtype %r, expected uint8 or float32' % dtype)
      if fake_data:
        self._num_examples = 10000
        self.one_hot = one_hot
      else:
        assert images.shape[0] == labels.shape[0], ('images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
        self._num_examples = images.shape[0]
        # 将[num examples, rows, columns, depth]转换形状成[num examples, rows*columns] (assuming depth == 1)
        assert images.shape[3] == 1
        images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
        if dtype == tf.float32:
          # 将[0, 255]转换为[0.0, 1.0].
          images = images.astype(numpy.float32)
          images = numpy.multiply(images, 1.0 / 255.0)
      self._images = images
      self._labels = labels
      self._epochs_completed = 0
      self._index_in_epoch = 0
    @property
    def images(self):
      return self._images
    @property
    def labels(self):
      return self._labels
    @property
    def num_examples(self):
      return self._num_examples
    @property
    def epochs_completed(self):
      return self._epochs_completed

    # 从数据集返回下一个`batch_size`示例
    def next_batch(self, batch_size, fake_data=False):
      if fake_data:
        fake_image = [1] * 784
        if self.one_hot:
          fake_label = [1] + [0] * 9
        else:
          fake_label = 0
        return [fake_image for _ in xrange(batch_size)], [fake_label for _ in xrange(batch_size)]
      start = self._index_in_epoch
      self._index_in_epoch += batch_size
      # 完成一个epoch
      if self._index_in_epoch > self._num_examples:
        # 随机抽取数据
        self._epochs_completed += 1
        perm = numpy.arange(self._num_examples)
        numpy.random.shuffle(perm)
        self._images = self._images[perm]
        self._labels = self._labels[perm]
        # 开始下一个epoch
        start = 0
        self._index_in_epoch = batch_size
        assert batch_size <= self._num_examples
      end = self._index_in_epoch
      return self._images[start:end], self._labels[start:end]

  # 读取训练数据
  def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
    class DataSets(object):
      pass
    data_sets = DataSets()
    # 若fake_data为true则返回空数据
    if fake_data:
      def fake():
        return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
      data_sets.train = fake()
      data_sets.validation = fake()
      data_sets.test = fake()
      return data_sets
    # 训练和测试数据文件名
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    VALIDATION_SIZE = 5000
    # 读取训练和测试数据
    local_file = maybe_download(TRAIN_IMAGES, train_dir)
    train_images = extract_images(local_file)
    local_file = maybe_download(TRAIN_LABELS, train_dir)
    train_labels = extract_labels(local_file, one_hot=one_hot)
    local_file = maybe_download(TEST_IMAGES, train_dir)
    test_images = extract_images(local_file)
    local_file = maybe_download(TEST_LABELS, train_dir)
    test_labels = extract_labels(local_file, one_hot=one_hot)
    # 取前5000个作为验证数据
    validation_images = train_images[:VALIDATION_SIZE]
    validation_labels = train_labels[:VALIDATION_SIZE]
    # 取前5000个以后的作为训练数据
    train_images = train_images[VALIDATION_SIZE:]
    train_labels = train_labels[VALIDATION_SIZE:]
    # 定义训练,验证和测试
    data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
    data_sets.validation = DataSet(validation_images, validation_labels, dtype=dtype)
    data_sets.test = DataSet(test_images, test_labels, dtype=dtype)
    return data_sets

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值