MNIST_UPDATE_inference.py代码如下:
# -*- coding: utf-8 -*-
import tensorflow as tf
# NOTE(review): importing this module resets TensorFlow's default graph as an
# import-time side effect, so ops the importer created beforehand are no
# longer in the default graph. Kept for compatibility with existing callers.
tf.reset_default_graph()

# Layer widths of the fully connected network built by inference().
INPUT_NODE = 28 * 28  # 784 input features (presumably a flattened 28x28 image)
LAYER1_NODE = 500     # hidden-layer width
OUTPUT_NODE = 10      # output units (presumably one per digit class)
def get_weight_variable(shape, regularizer=None):
    """Get the "weights" variable in the current variable scope.

    The variable is obtained with ``tf.get_variable``. That call creates the
    variable, or reuses it if the enclosing scope is set to reuse. It is
    initialised from a truncated normal distribution with stddev 0.1.

    Args:
        shape: Shape of the weight tensor, e.g. ``[in_nodes, out_nodes]``.
        regularizer: Optional callable mapping the weight tensor to a scalar
            penalty (e.g. an L2 regularizer). ``None`` disables it.

    Returns:
        The weight variable.
    """
    weights = tf.get_variable(
        "weights", shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1))
    # Only the penalty term is stored here. It is put in the 'losses'
    # collection, presumably so the training code can add it to the data
    # loss (that code is not visible in this chunk).
    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights
def inference(input_tensor, regularizer=None):
    """Build the forward pass of a two-layer fully connected network.

    The network has two layers:

    * layer1: ``INPUT_NODE -> LAYER1_NODE`` followed by ReLU.
    * layer2: ``LAYER1_NODE -> OUTPUT_NODE`` with no activation. The result
      is therefore raw logits; no softmax is applied here.

    Variables are created with ``tf.get_variable`` under the scopes
    ``layer1`` and ``layer2``. Calling this function twice in the same graph
    fails with "Variable ... already exists" unless the caller wraps the
    second call in a reusing variable scope.

    Args:
        input_tensor: 2-D tensor whose second dimension is ``INPUT_NODE``.
            This is required by the ``tf.matmul`` with the
            ``[INPUT_NODE, LAYER1_NODE]`` weights. Its dtype must match the
            weights (float32 by default).
        regularizer: Optional weight regularizer forwarded to
            ``get_weight_variable``. ``None`` (the default) disables it.
            Biases are never regularized.

    Returns:
        Logits tensor of shape ``[batch, OUTPUT_NODE]``.
    """
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [LAYER1_NODE],
            initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [OUTPUT_NODE],
            initializer=tf.constant_initializer(0.0))
        # No activation: callers apply softmax / cross-entropy themselves.
        layer2 = tf.matmul(layer1, weights) + biases

    return layer2
MNIST_UPDATE_train.py代码如下:
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import MNIST_UPDATE_inference
# NOTE(review): a tf.reset_default_graph() call here is disabled. Importing
# MNIST_UPDATE_inference above already resets the default graph as an
# import-time side effect.
BATCH_SIZE=100  # training batch size (its use is not visible in this chunk)
LEARNING_RATE_BAS