# NOTE(review): `sys` is not referenced anywhere in the visible source, and
# `Image` only appears in commented-out debugging code further down.
import sys, os
# Expose only GPU 0 to this process. The assignment is placed before the
# TensorFlow import, presumably so it takes effect before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
# Load the MNIST dataset from ./MNIST_data/ with one-hot encoded labels.
# This runs at import time, not only when the file is executed as a script.
mnist = input_data.read_data_sets("./MNIST_data/",one_hot=True)
import numpy as np
from PIL import Image
# Batch size used for every training and test batch drawn in the script below.
how_much=100
class MLPNet:
    """Small all-convolutional MNIST classifier built with the TF1 graph API.

    Despite its name the network contains no dense layers: four convolutions
    reduce a 28x28x1 image to a 1x1x10 score map (28 -> 14 -> 7 -> 7 -> 1 with
    SAME padding), which is reshaped to ``[batch, 10]`` and passed through a
    softmax. The first convolution is followed by fused batch normalisation.

    Args:
        is_train: Truthy to build the training variant, in which batch norm
            uses the statistics of the current batch and exposes them as
            ``mean_train`` / ``varaince_train``. Falsy to build the inference
            variant, which normalises with the fixed ``mean`` / ``variance``.
        mean: Per-channel mean (length 10) for inference-mode batch norm.
            Must be left as ``None`` in training mode.
        variance: Per-channel variance (length 10) for inference-mode batch
            norm. Must be left as ``None`` in training mode.

    Typical usage is ``net = MLPNet(...); net.forward(); net.backward()``.
    """

    def __init__(self, is_train, mean=None, variance=None):
        self.is_train = is_train
        self.mean = mean
        self.variance = variance

        # Batch statistics tensors; only populated by forward() in training
        # mode. ("varaince" keeps the original spelling because callers read
        # the attribute under that name.)
        self.mean_train = None
        self.varaince_train = None

        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10])  # one-hot labels

        # NOTE: the creation order of the variables below determines their
        # auto-generated names (Variable, Variable_1, ...), which is what a
        # checkpoint restore into a second graph matches on. Do not reorder.
        self.in_w = tf.Variable(
            tf.truncated_normal(shape=[3, 3, 1, 10], stddev=0.1))
        # Batch-norm scale and offset for the 10 channels of the first conv.
        self.in_w_sacle = tf.Variable(
            tf.truncated_normal(shape=[10], stddev=0.1))
        self.in_w_offset = tf.Variable(
            tf.truncated_normal(shape=[10], stddev=0.1))
        self.out_w = tf.Variable(
            tf.truncated_normal(shape=[3, 3, 10, 100], stddev=0.1))
        self.out_w1 = tf.Variable(
            tf.truncated_normal(shape=[3, 3, 100, 10], stddev=0.1))
        self.out_w2 = tf.Variable(
            tf.truncated_normal(shape=[7, 7, 10, 10], stddev=0.1))

    def forward(self):
        """Build the inference path and set ``output_f`` and ``output``.

        ``output_f`` holds the raw ``[batch, 10]`` scores and ``output`` their
        softmax. In training mode ``mean_train`` and ``varaince_train`` are
        additionally set to the batch statistics returned by batch norm.

        Raises:
            ValueError: If the network was created in inference mode without
                both ``mean`` and ``variance``.
        """
        # Fail early and clearly, before any ops are added to the graph,
        # instead of letting the batch-norm op receive missing statistics.
        if not self.is_train and (self.mean is None or self.variance is None):
            raise ValueError(
                "MLPNet in inference mode (is_train=False) requires both "
                "'mean' and 'variance' for batch normalisation.")

        x = tf.nn.conv2d(input=self.x, filter=self.in_w,
                         strides=[1, 2, 2, 1], padding="SAME")
        if self.is_train:
            x, self.mean_train, self.varaince_train = tf.nn.fused_batch_norm(
                x, self.in_w_sacle, self.in_w_offset, is_training=True)
        else:
            x, _, _ = tf.nn.fused_batch_norm(
                x, self.in_w_sacle, self.in_w_offset,
                self.mean, self.variance, is_training=False)
        x = tf.nn.tanh(x)

        x = tf.nn.conv2d(input=x, filter=self.out_w,
                         strides=[1, 2, 2, 1], padding="SAME")
        x = tf.nn.relu(x)
        x = tf.nn.conv2d(input=x, filter=self.out_w1,
                         strides=[1, 1, 1, 1], padding="SAME")
        x = tf.nn.relu(x)
        # A 7x7 kernel with stride 7 collapses the 7x7 map to one position,
        # leaving 10 values per image.
        x = tf.nn.conv2d(input=x, filter=self.out_w2,
                         strides=[1, 7, 7, 1], padding="SAME")

        self.output_f = tf.reshape(x, (-1, 10))
        self.output = tf.nn.softmax(self.output_f)

    def backward(self):
        """Build the loss and training op; requires ``forward()`` first.

        ``loss`` is the mean squared error between the softmax output and the
        one-hot labels, and ``opt`` is one gradient-descent step on it with a
        learning rate of 0.1.
        """
        self.loss = tf.reduce_mean((self.output - self.y) ** 2)
        self.opt = tf.train.GradientDescentOptimizer(0.1).minimize(self.loss)
def _next_batch(dataset):
    """Draw ``how_much`` samples from *dataset* as (NHWC images, labels)."""
    images, labels = dataset.next_batch(how_much)
    return np.reshape(images, (how_much, 28, 28, 1)), labels


def _batch_accuracy(sess, net):
    """Return the accuracy of *net* on one fresh batch of MNIST test data."""
    test_xs, test_ys = _next_batch(mnist.test)
    test_output = sess.run(net.output, feed_dict={net.x: test_xs})
    test_y = np.argmax(test_ys, axis=1)
    test_out = np.argmax(test_output, axis=1)
    return np.mean(np.array(test_y == test_out, dtype=np.float32))


if __name__ == '__main__':
    # ---- Stage 1: train, then save the graph definition and a checkpoint.
    mean = None
    variance = None
    net = MLPNet(True)
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(10000):
            xs, ys = _next_batch(mnist.train)
            # Fetch the batch-norm statistics alongside the training step so
            # the values from the final step are available to stage 2.
            # NOTE(review): these are the statistics of the *last* training
            # batch only, not a running average over training.
            _loss, _, mean, variance = sess.run(
                [net.loss, net.opt, net.mean_train, net.varaince_train],
                feed_dict={net.x: xs, net.y: ys})
            if epoch % 100 == 0:
                print("loss: ", _loss)
                print("acuracy: ", _batch_accuracy(sess, net))

        # The graph structure does not change during training, so the
        # (weight-free) graph definition is written once.
        with tf.gfile.GFile("./train.pb", mode='wb') as fw:
            fw.write(sess.graph.as_graph_def().SerializeToString())
        saver = tf.train.Saver()
        saver.save(sess, "./train.ckpt")

    print("*****************************************************************")

    # ---- Stage 2: rebuild the network in inference mode inside a fresh
    # graph, load the trained weights, evaluate once and export a frozen
    # graph with the variables converted to constants.
    with tf.Graph().as_default():
        net = MLPNet(False, mean, variance)
        # Only the forward path is needed here; no optimizer is built and no
        # initializer is run because restore() assigns every variable.
        net.forward()
        with tf.Session() as sess:
            saver = tf.train.Saver()
            saver.restore(sess=sess, save_path="./train.ckpt")

            print("acuracy: ", _batch_accuracy(sess, net))

            constant_graph = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(),
                # op.name is the node name without the ":0" output suffix.
                output_node_names=[net.output.op.name])
            with tf.gfile.GFile("./inf.pb", mode='wb') as fw:
                fw.write(constant_graph.SerializeToString())
12-15
2万+
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
“相关推荐”对你有帮助么?
-
非常没帮助
-
没帮助
-
一般
-
有帮助
-
非常有帮助
提交