cnn+pooling, save and restore, 第一次结构比较完整的代码

cnn+pooling, save and restore, 第一次结构比较完整的代码

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

#hyper parameters
save_path = 'model'
max_globel_step = 1000


mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def add_layer(input, in_size, out_size, name='layer', activation=None):
    with tf.variable_scope(name):
        Weight = tf.get_variable(name='Weight', shape=(in_size, out_size), dtype=tf.float32, initializer=tf.random_normal_initializer())
        bias = tf.get_variable(name='bias', shape=(out_size), dtype=tf.float32, initializer=tf.constant_initializer(0.1))
        Wx_b = tf.matmul(input, Weight) + bias
        if activation != None:
            output = activation(Wx_b)
        else:
            output = Wx_b
    return output

def weight_variable(x):
    init = tf.truncated_normal(shape=x, stddev=0.1)
    w = tf.Variable(init)
    return w
def bias_variable(x):
    init = tf.constant(0.1, shape=x)
    return tf.Variable(init)
def conv2d(x, w):
    return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
def max_pooling_2X2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

#network structure
class nn(object):
    def __init__(self, x, y_label, trainning=True, reuse=False):
        with tf.variable_scope('nn', reuse=reuse):
            self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32, name='global_step')
            with tf.variable_scope('predict'):

                #x (batch, 784) -> (batch, 28, 28, 1)
                x_pic = tf.reshape(x, shape=(-1, 28, 28, 1))

                #1st conv+pooling  [-1, 28, 28 , 1] -> [-1, 14, 14, 32] patch: 5*5
                '''用tf.nn的conv莫名其妙的无法训练
                weight_conv1 = weight_variable([5, 5, 1, 32])
                bias_conv1 = bias_variable([32])
                conv_conv1 = tf.nn.relu(conv2d(x_pic, weight_conv1) + bias_conv1)
                pooling_conv1 = max_pooling_2X2(conv_conv1)
                '''
                conv_conv1 = tf.nn.relu(tf.layers.conv2d(x_pic, filters=32, kernel_size=[5, 5],
                                                         strides=[1, 1], padding='SAME'))
                pooling_conv1 = max_pooling_2X2(conv_conv1)


                # 2st conv+pooling  [-1, 14, 14 , 32] -> [-1, 7, 7, 64] patch: 5*5
                conv_conv2 = tf.nn.relu(tf.layers.conv2d(pooling_conv1, filters=64, kernel_size=[5, 5],
                                              strides=[1, 1], padding='SAME'))
                pooling_conv2 = max_pooling_2X2(conv_conv2)


                #nn fc (-1, 7, 7, 64) -> (-1, 7*7*64) -> (-1, 1024)
                pl_conv2 = tf.reshape(pooling_conv1, shape=(-1, 14*14*32))
                # pl_conv2 = tf.reshape(pooling_conv2, shape=(-1, 7*7*64))
                y = add_layer(pl_conv2, 14*14*32, 1024, activation=tf.nn.relu)
                # y = add_layer(pl_conv2, 7*7*64, 1024, activation=tf.nn.relu)
                y = tf.layers.dropout(y, rate=0, training=trainning)

                #nn fc (-1, 1024) -> (-1, 10)
                y = add_layer(y, 1024, 10, name='layer2', activation=tf.nn.softmax)
            with tf.variable_scope('loss'):
                self.cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_label * tf.log(y + 0.000001),
                                                              reduction_indices=[1]))  # loss

            with tf.variable_scope('accuracy'):
                correct = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_label, 1))
                self.acuracy = tf.reduce_mean(tf.cast(correct, 'float'))

    def get_summary(self):
        return tf.summary.scalar('loss', self.cross_entropy)

#placehoder
x = tf.placeholder(name='x', shape=(None, 784), dtype=tf.float32)
y_label = tf.placeholder(name='y_label', shape=(None, 10), dtype=tf.float32)

#model
train_model = nn(x, y_label, trainning=True, reuse=False)
dev_model = nn(x, y_label, trainning=False, reuse=True)

#optimize
with tf.variable_scope('optimizer'):
    train_opt = tf.train.AdamOptimizer(1e-4).minimize(train_model.cross_entropy, global_step=train_model.global_step)

#saver
saver = tf.train.Saver()

#run sess
with tf.Session() as sess:
    # summary
    train_model.get_summary()
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('logs/train', graph=sess.graph)
    test_writer = tf.summary.FileWriter('logs/test', graph=sess.graph)
    #restore or initailize
    ckpt = tf.train.get_checkpoint_state(save_path)
    if ckpt:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(save_path, ckpt_name))
    else:
        sess.run(tf.global_variables_initializer())

    global_step_val = sess.run(train_model.global_step)
    while global_step_val < max_globel_step:
        batch_x, batch_y = mnist.train.next_batch(100)
        sess.run(train_opt, feed_dict={x:batch_x, y_label:batch_y})
        global_step_val += 1
        if global_step_val % 50 == 0:
            train_summary = sess.run(merged, feed_dict={x:mnist.train.images[:1000], y_label:mnist.train.labels[:1000]})
            test_summary = sess.run(merged, feed_dict={x:mnist.test.images[:1000], y_label:mnist.test.labels[:1000]})
            train_writer.add_summary(train_summary, global_step_val)
            test_writer.add_summary(test_summary, global_step_val)
            print(global_step_val)
            print('train')
            print(sess.run(dev_model.acuracy, feed_dict={x:mnist.train.images[:1000], y_label:mnist.train.labels[:1000]}))
            print('test')
            print(sess.run(dev_model.acuracy, feed_dict={x:mnist.test.images[:1000], y_label:mnist.test.labels[:1000]}))
        if global_step_val % 100 == 0:
            print('save')
            saver.save(sess, os.path.join(save_path, 'mnist.ckpt'), global_step_val)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
毕业设计,基于SpringBoot+Vue+MySQL开发的纺织品企业财务管理系统,源码+数据库+毕业论文+视频演示 在如今社会上,关于信息上面的处理,没有任何一个企业或者个人会忽视,如何让信息急速传递,并且归档储存查询,采用之前的纸张记录模式已经不符合当前使用要求了。所以,对纺织品企业财务信息管理的提升,也为了对纺织品企业财务信息进行更好的维护,纺织品企业财务管理系统的出现就变得水到渠成不可缺少。通过对纺织品企业财务管理系统的开发,不仅仅可以学以致用,让学到的知识变成成果出现,也强化了知识记忆,扩大了知识储备,是提升自我的一种很好的方法。通过具体的开发,对整个软件开发的过程熟练掌握,不论是前期的设计,还是后续的编码测试,都有了很深刻的认知。 纺织品企业财务管理系统通过MySQL数据库与Spring Boot框架进行开发,纺织品企业财务管理系统能够实现对财务人员,员工,收费信息,支出信息,薪资信息,留言信息,报销信息等信息的管理。 通过纺织品企业财务管理系统对相关信息的处理,让信息处理变的更加的系统,更加的规范,这是一个必然的结果。已经处理好的信息,不管是用来查找,还是分析,在效率上都会成倍的提高,让计算机变得更加符合生产需要,变成人们不可缺少的一种信息处理工具,实现了绿色办公,节省社会资源,为环境保护也做了力所能及的贡献。 关键字:纺织品企业财务管理系统,薪资信息,报销信息;SpringBoot
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值