基于tensorflow的mnist数据集手写字体分类level-1

本文属于学习tensorflow框架系列的文章,不是注重于算法~
基于之前博文中的工作,已经安装好tensorflow等等的配置工作,开始学习tensorflow框架的使用,本文参考了以下链接,致以敬意
http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
https://github.com/GoogleCloudPlatform/tensorflow-without-a-phd

(建议先读完整个文章再进行实际操作)
1、直接上参考代码:
(1)执行函数代码

import tensorflow as tf
import mnist_data_process


# Directory handed to mnist_data_process.read_data_sets as the MNIST data location.
data_path = r'./mnist_data'


def mnist_v1():
    """Train a single-layer softmax classifier on MNIST and print test accuracy.

    The model is ``softmax(x @ W + b)`` with ``W`` of shape [784, 10] and
    ``b`` of shape [10], both initialised to zero. It is trained for 1000
    steps of plain gradient descent (learning rate 0.01) on mini-batches of
    100 samples, then evaluated once on the whole test split.

    Returns:
        None. The test accuracy is printed to stdout.
    """
    input_data = tf.placeholder('float', [None, 784])
    input_labels = tf.placeholder('float', [None, 10])

    # 1. Build the graph: one fully connected layer followed by softmax.
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(input_data, W) + b
    out_put = tf.nn.softmax(logits)

    # 2. Define the loss. The fused op computes the same summed cross entropy
    #    as -sum(labels * log(softmax(logits))) but works on the logits, so it
    #    cannot produce NaN when a predicted probability underflows to zero.
    cross_entropy_loss = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits(labels=input_labels, logits=logits))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy_loss)

    # 3. Define the accuracy metric as part of the graph, driven by the same
    #    placeholders, so evaluation does not add new ops after training or
    #    embed the test labels as graph constants.
    correct_prediction = tf.equal(tf.argmax(out_put, 1), tf.argmax(input_labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    # 4. Prepare the data: read_data_sets loads both splits into in-memory
    #    arrays (one-hot labels, requested with reshape=True).
    mnist_data = mnist_data_process.read_data_sets(data_path, one_hot=True, reshape=True)

    # The context manager releases the session's resources even if training fails.
    with tf.Session() as sess:
        # 5. Initialise all variables.
        sess.run(tf.global_variables_initializer())

        # 6. Train for a fixed number of iterations.
        iter_num = 1000
        for _ in range(iter_num):
            batch_xs, batch_ys = mnist_data.train.next_batch(100)
            sess.run(train_step, feed_dict={input_data: batch_xs, input_labels: batch_ys})

        # 7. Evaluate the trained parameters on the test split.
        test_feed = {input_data: mnist_data.test.images,
                     input_labels: mnist_data.test.labels}
        print(sess.run(accuracy, feed_dict=test_feed))


# Run the training/evaluation routine only when this file is executed as a script.
if __name__ == '__main__':
    mnist_v1()

(2)第二级函数代码:mnist_data_process.py

# encoding: UTF-8
# Copyright 2018 Google.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np

from mnist_input_data import load_mnist_data 
from mnist_input_data import load_dataset    
# This loads entire dataset to an in-memory numpy array.
# This uses tf.data.Dataset to avoid duplicating code.
# Normally, if you already have a tf.data.Dataset, loading
# it to memory is not useful. The goal here is educational:
# teach about neural network basics without having to
# explain tf.data.Dataset now. The concept will be introduced
# later.
# The proper way of using tf.data.Dataset is to call
# features, labels = tf_dataset.make_one_shot_iterator().get_next()
# and then to use "features" and "labels" in your Tensorflow
# model directly. These tensorflow nodes, when executed, will
# automatically trigger the loading of the next batch of data.
# The sample that uses tf.data.Dataset correctly is in mlengine/trainer.


class MnistData(object):
    """In-memory copy of a dataset of (features, labels) pairs.

    After construction ``images`` and ``labels`` hold every record of the
    dataset as arrays (or ``None`` if the dataset produced nothing), and
    ``pos`` is the index of the next sample ``next_batch`` will hand out.
    """

    def __init__(self, tf_dataset, one_hot, reshape):
        """Drain ``tf_dataset`` once and keep the result in memory.

        Args:
            tf_dataset: dataset yielding (features, labels) pairs.
            one_hot: when true, labels are expanded with ``tf.one_hot(labels, 10)``.
            reshape: when false, features are reshaped to ``[-1, 28, 28, 1]``;
                when true they are kept as the dataset produces them.
        """
        self.pos = 0
        self.images = None
        self.labels = None

        # Read the dataset a single time, 10000 records per session run.
        chunked = tf_dataset.batch(10000).repeat(1)
        features, labels = chunked.make_one_shot_iterator().get_next()
        if not reshape:
            features = tf.reshape(features, [-1, 28, 28, 1])
        if one_hot:
            labels = tf.one_hot(labels, 10)

        image_chunks = []
        label_chunks = []
        with tf.Session() as sess:
            while True:
                try:
                    feats, labs = sess.run([features, labels])
                except tf.errors.OutOfRangeError:
                    # The one-shot iterator is exhausted: everything was read.
                    break
                image_chunks.append(feats)
                label_chunks.append(labs)

        if image_chunks:
            self.images = np.concatenate(image_chunks)
            self.labels = np.concatenate(label_chunks)

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` samples as an (images, labels) tuple.

        When fewer than ``batch_size`` samples remain, reading restarts from
        the beginning; the leftover tail is skipped for that pass.
        """
        available = min(len(self.images), len(self.labels))
        if self.pos + batch_size > available:
            self.pos = 0
        start = self.pos
        stop = start + batch_size
        self.pos = stop
        return self.images[start:stop], self.labels[start:stop]


class Mnist(object):
    """Container pairing the in-memory train and test splits.

    ``train`` and ``test`` are ``MnistData`` instances built from the
    corresponding datasets with the same ``one_hot`` / ``reshape`` options.
    """

    def __init__(self, train_dataset, test_dataset, one_hot, reshape):
        # The train split is materialised first, then the test split.
        self.train = MnistData(train_dataset, one_hot=one_hot, reshape=reshape)
        self.test = MnistData(test_dataset, one_hot=one_hot, reshape=reshape)


def read_data_sets(data_dir, one_hot, reshape):
    """Load MNIST from ``data_dir`` and return it as an in-memory ``Mnist``.

    Args:
        data_dir: directory passed to ``load_mnist_data`` to locate the files.
        one_hot: forwarded to ``Mnist``; controls one-hot label encoding.
        reshape: forwarded to ``Mnist``; controls the feature layout.

    Returns:
        A ``Mnist`` object whose train split is shuffled with a buffer of
        60000 records and whose test split keeps the file order.
    """
    (train_images_path, train_labels_path,
     test_images_path, test_labels_path) = load_mnist_data(data_dir)

    shuffled_train = load_dataset(train_images_path, train_labels_path).shuffle(60000)
    ordered_test = load_dataset(test_images_path, test_labels_path)
    return Mnist(shuffled_train, ordered_test, one_hot, reshape)

(3)第三级函数代码:mnist_input_data.py

import os
import gzip
import shutil
from six.moves import urllib
from tensorflow.python.platform import gfile

import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging

# Import-time side effects: switch TensorFlow logging to INFO verbosity and
# log the TensorFlow version in use.
logging.set_verbosity(logging.INFO)
logging.log(logging.INFO, "Tensorflow version " + tf.__version__)


def maybe_download_and_ungzip(filename, work_directory, source_url):
    """Make sure the unzipped form of ``filename`` exists in ``work_directory``.

    If the unzipped file is missing, the archive is downloaded from
    ``source_url`` (unless it is already present locally) and, when
    ``filename`` ends in ``.gz``, decompressed next to it.

    Args:
        filename: name of the file to fetch, optionally ending in ``.gz``.
        work_directory: directory holding the downloaded and unzipped files;
            created if it does not exist.
        source_url: full URL the file is downloaded from.

    Returns:
        Path of the unzipped file inside ``work_directory``.
    """
    if filename.endswith(".gz"):
        unzipped_filename = filename[:-3]
    else:
        unzipped_filename = filename

    if not gfile.Exists(work_directory):
        gfile.MakeDirs(work_directory)

    filepath = os.path.join(work_directory, filename)
    unzipped_filepath = os.path.join(work_directory, unzipped_filename)

    if not gfile.Exists(unzipped_filepath):
        # Check the file system for the archive. The previous os._exists()
        # call is a private helper that looks a *name* up in the os module's
        # globals, so it was never true for a path and an archive that was
        # already on disk got downloaded again every time.
        if not gfile.Exists(filepath):
            urllib.request.urlretrieve(source_url, filepath)

        if filename != unzipped_filename:
            with gzip.open(filepath, 'rb') as f_in:
                with open(unzipped_filepath, 'wb') as f_out:  # remove .gz
                    shutil.copyfileobj(f_in, f_out)

        # The reported size is that of the file named ``filename`` (the
        # archive when one was used), not of the unzipped output.
        with gfile.GFile(filepath) as f:
            size = f.size()
        print('Successfully downloaded and unzipped', filename, size, 'bytes.')
    return unzipped_filepath


def read_label(tf_bytestring):
    """Decode one label record into a scalar uint8 tensor."""
    decoded = tf.decode_raw(tf_bytestring, tf.uint8)
    # An empty shape collapses the decoded record to a scalar.
    scalar_label = tf.reshape(decoded, [])
    return scalar_label


def read_image(tf_bytestring):
    """Decode one image record into float32 values scaled by 1/256."""
    raw_pixels = tf.decode_raw(tf_bytestring, tf.uint8)
    float_pixels = tf.cast(raw_pixels, tf.float32)
    return float_pixels / 256.0


def load_mnist_data(data_dir):
    """Fetch the four MNIST files into ``data_dir`` and return their local paths.

    Args:
        data_dir: directory the files are downloaded to and unzipped in.

    Returns:
        Tuple ``(train_images, train_labels, test_images, test_labels)`` of
        the paths returned by ``maybe_download_and_ungzip``.
    """
    base_url = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
    # Order matters: it fixes both the download order and the returned tuple.
    archive_names = (
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
    )
    local_paths = [
        maybe_download_and_ungzip(archive, data_dir, base_url + archive)
        for archive in archive_names
    ]
    return tuple(local_paths)


def load_dataset(image_file, label_file):
    """Build a tf.data.Dataset of (image, label) pairs from two record files.

    Images are read as fixed 28*28-byte records after a 16-byte header and
    decoded with ``read_image``; labels are read as 1-byte records after an
    8-byte header and decoded with ``read_label``. The two streams are
    zipped so that element i pairs image i with label i.
    """
    read_buffer_bytes = 1024*16

    image_records = tf.data.FixedLengthRecordDataset(
        image_file, 28*28, header_bytes=16, buffer_size=read_buffer_bytes)
    label_records = tf.data.FixedLengthRecordDataset(
        label_file, 1, header_bytes=8, buffer_size=read_buffer_bytes)

    images = image_records.map(read_image)
    labels = label_records.map(read_label)
    return tf.data.Dataset.zip((images, labels))

2、总结
这只是一个简单使用softmax来对手写字体进行分类的过程,这里相当于将原始灰度图像的像素值作为了特征,softmax相当于一个激活的操作,实验测试结果的acc在91%左右,实际操作该代码主要为了理解tensorflow的运行机制,个人认为从头学习tensorflow的时候直接拿复杂的深度网络模型来调试分析会很容易被搞晕,先用一个简单的例子来学习是比较好的,然后再一步一步的加深。

ps.如果代码里面的路径下载数据集下载不下来就先在http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html这里下载好数据集放在对应的路径下,然后再运行上面的代码。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值