Below is the configuration file. I recommend keeping it in a separate file so you can tweak it at any time.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/4/6 23:05
# @Author : Zehan Song
# @Site :
# @File : configs.py
# @Software: PyCharm
import warnings
class Config(object):
    INPUT_NODE = 784
    OUTPUT_NODE = 10
    LAYER1_NODE = 500
    BATCH_SIZE = 64
    SAVE_PATH = './model_path/'
    MAX_EPOCH = 100
    LEARNING_BASE = 0.05
    DECAY_RATE = 0.99
    MOVING_AVERAGE_DECAY = 0.99
    MODEL_NAME = "two_linear_layers_mnist.ckpt"
    REGULARIZATION_RATE = 0.0001
def Parse(self, kwargs):  # parse the dict and override matching config attributes
    print("user config:")
    for k, v in kwargs.items():
        if not hasattr(self, k):
            warnings.warn("Config has no attribute: %s" % k)
        else:
            setattr(self, k, v)
    for k, v in self.__class__.__dict__.items():
        if not (k.startswith("__") or k == "Parse"):
            print(k, getattr(self, k))
Config.Parse = Parse  # attach as a class attribute instead of an object attribute
configs = Config()
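As a quick sanity check you can call Parse on its own; the override values in this sketch are purely illustrative:
from configs import configs
configs.Parse({'BATCH_SIZE': 128, 'LEARNING_BASE': 0.1})
# prints "user config:" followed by every attribute, with the two overridden
# values applied; an unknown key only raises a warning instead of crashing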
Next comes the model-training code; if anything is unclear, feel free to ask in the comments below.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/4/6 23:05
# @Author : Zehan Song
# @Site :
# @File : best_tensorflow_practice_example.py
# @Software: PyCharm
from configs import configs
import tensorflow as tf
import time
from tqdm import tqdm
import fire
from tensorflow.examples.tutorials.mnist import input_data
# inference
def get_weight_variable(shape, regularizer):
    weights = tf.get_variable(
        name="weights", shape=shape, initializer=tf.truncated_normal_initializer(stddev=0.1)
    )
    if regularizer is not None:
        # collect each layer's L2 penalty; they are summed into the total
        # loss later via tf.add_n(tf.get_collection("losses"))
        tf.add_to_collection("losses", regularizer(weights))
    return weights
def inference(input_tensor, regularizer):
    # first layer
    # if IsTrain:
    #     reuse = False
    #     regularizer = regularizer
    # else:
    #     reuse = True
    #     regularizer = None
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([configs.INPUT_NODE, configs.LAYER1_NODE], regularizer)
        biases = tf.get_variable("biases", [configs.LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
    # second layer
    with tf.variable_scope('layer2'):
        weights = get_weight_variable([configs.LAYER1_NODE, configs.OUTPUT_NODE], regularizer)
        biases = tf.get_variable("biases", [configs.OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases
    return layer2
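One detail worth noting: because both layers create their variables through tf.get_variable inside a tf.variable_scope, the variables receive hierarchical names such as layer1/weights, which is what the Saver later uses to match checkpoint entries. A throwaway sketch (the demo scope name is just for illustration):
with tf.variable_scope('demo'):
    w = tf.get_variable('weights', shape=[2, 2], initializer=tf.zeros_initializer())
print(w.name)  # -> demo/weights:0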
def train(**kwargs):
    configs.Parse(kwargs)
    # set placeholders and the regularizer
    x = tf.placeholder(dtype=tf.float32, shape=[None, configs.INPUT_NODE], name='x-input')
    y = tf.placeholder(dtype=tf.float32, shape=[None, configs.OUTPUT_NODE], name='y-input')
    # flag = tf.placeholder(dtype=tf.bool, shape=[1], name='flag')
    regularizer = tf.contrib.layers.l2_regularizer(configs.REGULARIZATION_RATE)
    # prepare data
    mnist = input_data.read_data_sets("./data", one_hot=True)
    validation_feed = {x: mnist.validation.images, y: mnist.validation.labels}
    test_feed = {x: mnist.test.images, y: mnist.test.labels}
    # inference
    y_ = inference(x, regularizer)
    # set useful tricks
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(configs.LEARNING_BASE, global_step,
                                    mnist.train.num_examples / configs.BATCH_SIZE,
                                    configs.DECAY_RATE)
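    # the decayed rate follows lr = LEARNING_BASE * DECAY_RATE ** (global_step / decay_steps);
    # with decay_steps set to one epoch's worth of batches, the rate shrinks by DECAY_RATE per epoch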
    saver = tf.train.Saver()
    # compute loss and accuracy and set the optimizer; remember to add the regularization loss
    cross_entropy_mean = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.argmax(y, 1), logits=y_))
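    # note: sparse_softmax_cross_entropy_with_logits expects integer class ids rather
    # than one-hot vectors, which is why the labels are converted back with argmax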
    loss = cross_entropy_mean + tf.add_n(tf.get_collection("losses"))
    optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss, global_step=global_step)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))
    # start session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # grab GPU memory on demand instead of reserving it all
    config.allow_soft_placement = True  # fall back to another device when an op has no kernel here
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()  # very important!!!
        # start training
        train_op = tf.group(optimizer, name='train_op')
        # TRAINING_STEPS = mnist.train.num_examples / configs.BATCH_SIZE
        TRAINING_STEPS = 5000
        best_accuracy = 0
        # for epoch in range(configs.MAX_EPOCH):
        for i in tqdm(range(TRAINING_STEPS)):
            xs, ys = mnist.train.next_batch(configs.BATCH_SIZE)
            _, loss_value, step, accuracy_value = sess.run([train_op, loss, global_step, accuracy],
                                                           feed_dict={x: xs, y: ys})
            if i % 50 == 0:
                # saver.save(sess, "./saveforema.ckpt")
                print("iters[%d/%d], loss:%.6f, accuracy:%.6f"
                      % (i, TRAINING_STEPS, loss_value, accuracy_value))
            # if epoch % 10 == 0:  # validate without the regularization loss
            #     saver.restore(sess, "./saveforema.ckpt")  # use ema
            if i % 100 == 0:
                loss_value, accuracy_value = sess.run([cross_entropy_mean, accuracy], feed_dict=validation_feed)
                print("val_loss:%.6f, val_accuracy:%.6f" % (loss_value, accuracy_value))
                # save the model whenever the validation accuracy improves
                if best_accuracy < accuracy_value:
                    best_accuracy = accuracy_value
                    saver.save(sess, configs.SAVE_PATH + "best.ckpt")
            # print("validation accuracy:%.6f validation loss:%.6f" % (accuracy_value, loss_value))
        # start test: restore the best checkpoint and evaluate once on the test set
        ckpt = tf.train.get_checkpoint_state(configs.SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        loss_value, accuracy_value = sess.run([cross_entropy_mean, accuracy], feed_dict=test_feed)
        print("test accuracy:%.6f test loss:%.6f" % (accuracy_value, loss_value))
if __name__ == "__main__":
    fire.Fire()
# to run the code above, enter a line like the following on your server:
# python best_tensorflow_practice_example.py train --BATCH_SIZE=100
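Since fire forwards every command-line flag to train(**kwargs) and Parse copies the matching keys onto the config object, any attribute of Config can be overridden from the shell in the same way; the values below are only illustrative:
# python best_tensorflow_practice_example.py train --BATCH_SIZE=128 --LEARNING_BASE=0.1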
Below is a screenshot of the training run.