Generating mnist_model_graph.pb with a dict-based network and saving the MNIST images

Code

TensorFlowMnist-master.zip

#!/usr/bin/env python
#coding: utf-8

import tensorflow as tf
import input_data  # custom MNIST loader bundled in TensorFlowMnist-master.zip
from tensorflow.python.framework import graph_util


def build_network(height, width):
    """
    Function: build the network model.

    Parameters
    ----------
        height: height of the MNIST image.
        width: width of the MNIST image.

    """

    x = tf.placeholder(tf.float32, [None, height, width], name='input')

    y_placeholder = tf.placeholder(tf.float32, shape=[None, 10],name='labels_placeholder')

    keep_prob_placeholder = tf.placeholder(tf.float32, name='keep_prob_placeholder')

    def weight_variable(shape):
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)

    def bias_variable(shape):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)

    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

    def max_pool_2x2(x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    x_image = tf.reshape(x, [-1, height, width, 1])

    # First Convolutional Layer
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])

    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    # Second Convolutional Layer
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    # Third Convolutional Layer
    W_conv3 = weight_variable([5, 5, 64, 128])
    b_conv3 = bias_variable([128])

    h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)
    h_pool3 = max_pool_2x2(h_conv3)

    # Densely Connected Layer
    shape = h_pool3.get_shape().as_list()
    print("h_pool3 shape (h, w, c):", (shape[1], shape[2], shape[3]))
    W_fc1 = weight_variable([shape[1] * shape[2] * shape[3], 1024])
    b_fc1 = bias_variable([1024])

    print("W_fc1.get_shape().as_list()", W_fc1.get_shape().as_list())
    # For 28x28 inputs, the three 2x2 max-pools leave 4x4x128 = 2048 features.
    h_pool3_flat = tf.reshape(h_pool3, [-1, W_fc1.get_shape().as_list()[0]])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool3_flat, W_fc1) + b_fc1)

    # Dropout
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob_placeholder)

    # Readout Layer
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])

    logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

    softmax_out = tf.nn.softmax(logits, name="out_softmax")

    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=y_placeholder))

    optimize = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)

    prediction_labels = tf.argmax(softmax_out, axis=1, name="output")
    real_labels= tf.argmax(y_placeholder, axis=1)

    correct_prediction = tf.equal(prediction_labels, real_labels)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Number of correct predictions in a batch
    correct_times_in_batch = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
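
    # Return every op in a single dict so that train_network can fetch tensors
    # by key (graph['optimize'], graph['cost'], ...) instead of relying on
    # module-level names; this is the dict-based approach named in the title.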

    return dict(
                keep_prob_placeholder = keep_prob_placeholder,
                x_placeholder= x,
                y_placeholder = y_placeholder,
                optimize = optimize,
                logits = logits,
                prediction_labels = prediction_labels,
                real_labels = real_labels,
                correct_prediction = correct_prediction,
                correct_times_in_batch = correct_times_in_batch,
                cost = cost,
                accuracy = accuracy,
    )

def train_network(graph,
                  dataset,
                  batch_size,
                  num_epochs,
                  pb_file_path):
    """
    Function: train the network.

    Parameters
    ----------
        graph: a dict, the return value of build_network.
        dataset: the dataset.
        batch_size: mini-batch size.
        num_epochs: number of training epochs.
        pb_file_path: path where the generated .pb file is saved.
    """

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        print ("batch size:",batch_size)

        #用于控制每epoch_delta轮在train set和test set上计算一下accuracy和cost
        epoch_delta = 10
        for epoch_index in range(num_epochs):

            #################################
            #    Train on the TRAIN set
            #################################
            for (batch_xs,batch_ys) in dataset.train.mini_batches(batch_size):
                sess.run([graph['optimize']], feed_dict={
                    graph['x_placeholder']: batch_xs,
                    graph['y_placeholder']: batch_ys,
                    graph['keep_prob_placeholder']:0.5,
                })


            # Every epoch_delta epochs, evaluate accuracy and cost on the train and test sets
            if epoch_index % epoch_delta == 0:
                #################################
                #    Evaluate accuracy and cost on the train set
                #################################
                # Number of batches in the train set
                total_batches_in_train_set = 0
                # Number of correct predictions on the train set
                total_correct_times_in_train_set = 0
                # Total cost on the train set
                total_cost_in_train_set = 0.
                for (train_batch_xs,train_batch_ys) in dataset.train.mini_batches(batch_size):
                    return_correct_times_in_batch = sess.run(graph['correct_times_in_batch'], feed_dict={
                        graph['x_placeholder']: train_batch_xs,
                        graph['y_placeholder']: train_batch_ys,
                        graph['keep_prob_placeholder']:1.0,
                    })
                    mean_cost_in_batch = sess.run(graph['cost'], feed_dict={
                        graph['x_placeholder']: train_batch_xs,
                        graph['y_placeholder']: train_batch_ys,
                        graph['keep_prob_placeholder']:1.0,
                    })

                    total_batches_in_train_set += 1
                    total_correct_times_in_train_set += return_correct_times_in_batch
                    # cost is a per-example mean, so scale by batch_size to accumulate a total
                    total_cost_in_train_set += (mean_cost_in_batch * batch_size)


                #################################
                #    Evaluate accuracy and cost on the test set
                #################################
                # Number of batches in the test set
                total_batches_in_test_set = 0
                # Number of correct predictions on the test set
                total_correct_times_in_test_set = 0
                # Total cost on the test set
                total_cost_in_test_set = 0.
                for (test_batch_xs,test_batch_ys) in dataset.test.mini_batches(batch_size):
                    return_correct_times_in_batch = sess.run(graph['correct_times_in_batch'], feed_dict={
                        graph['x_placeholder']: test_batch_xs,
                        graph['y_placeholder']: test_batch_ys,
                        graph['keep_prob_placeholder']:1.0,
                    })
                    mean_cost_in_batch = sess.run(graph['cost'], feed_dict={
                        graph['x_placeholder']: test_batch_xs,
                        graph['y_placeholder']: test_batch_ys,
                        graph['keep_prob_placeholder']:1.0,
                    })

                    total_batches_in_test_set += 1
                    total_correct_times_in_test_set += return_correct_times_in_batch
                    total_cost_in_test_set  += (mean_cost_in_batch*batch_size)

                ### summary and print
                acy_on_test = total_correct_times_in_test_set / float(total_batches_in_test_set * batch_size)
                acy_on_train = total_correct_times_in_train_set / float(total_batches_in_train_set * batch_size)
                print('Epoch - {:2d} , acy_on_test:{:6.2f}%({}/{}),loss_on_test:{:6.2f}, acy_on_train:{:6.2f}%({}/{}),loss_on_train:{:6.2f}'.
                      format(epoch_index, acy_on_test*100.0,total_correct_times_in_test_set,
                             total_batches_in_test_set * batch_size,total_cost_in_test_set, acy_on_train*100.0,
                             total_correct_times_in_train_set,total_batches_in_train_set * batch_size,total_cost_in_train_set))    

            # Freeze and save the graph as a .pb file every 50 epochs
            if epoch_index % 50 == 0:
                # Convert the trained variables to constants, keeping everything the
                # "output" (argmax) node depends on; use "out_softmax" instead to
                # expose the softmax probabilities.
                constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
                with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
                    f.write(constant_graph.SerializeToString())



def main():

    batch_size = 30
    num_epochs = 101

    # Path where the .pb file will be saved
    pb_file_path = "output/mnist-tf1.0.1.pb"

    g = build_network(height=28, width=28)
    dataset = input_data.read_data_sets()
    train_network(g, dataset, batch_size, num_epochs, pb_file_path)

if __name__ == '__main__':
    main()
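
Once the .pb file has been written, it can be reloaded for inference by tensor name; 'input', 'keep_prob_placeholder', and 'output' are the name= arguments given to the placeholders and the argmax op above. A minimal sketch (not part of the original code), assuming images is a [N, 28, 28] float32 array:

import tensorflow as tf

# Load the frozen graph written by train_network.
with tf.gfile.FastGFile("output/mnist-tf1.0.1.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    x = sess.graph.get_tensor_by_name('input:0')
    keep_prob = sess.graph.get_tensor_by_name('keep_prob_placeholder:0')
    output = sess.graph.get_tensor_by_name('output:0')
    # The dropout op is baked into the frozen graph, so keep_prob must still be fed.
    predicted_labels = sess.run(output, feed_dict={x: images, keep_prob: 1.0})
    print(predicted_labels)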

Results

(Two result screenshots appeared here.)
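
The title also mentions saving the MNIST images themselves, which the listing above never does; presumably that happens inside the bundled input_data.py. A minimal sketch of the idea (the function name and the [0, 1] pixel range are assumptions, not the post's actual code), writing one 28x28 sample to disk with Pillow:

import numpy as np
from PIL import Image

def save_mnist_image(image, path):
    # `image` is assumed to be a 28x28 array with pixel values in [0, 1];
    # scale to 8-bit grayscale before writing.
    arr = (np.asarray(image) * 255).astype(np.uint8)
    Image.fromarray(arr, mode='L').save(path)

# Hypothetical usage, assuming the dataset exposes raw image arrays:
# save_mnist_image(dataset.train.images[0], "mnist_0.png")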
