TensorFlow batch normalization usage example with tf.nn.fused_batch_norm
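This is a minimal, self-contained example of tf.nn.fused_batch_norm in TensorFlow 1.x on MNIST. The idea: build the network once in training mode (where fused_batch_norm computes and returns each batch's statistics), capture the mean and variance during training, then rebuild the same graph in inference mode with those statistics fixed, restore the trained weights, and finally freeze everything into a standalone .pb file.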

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
from PIL import Image  # only needed for the commented-out visualization below

mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)

how_much = 100  # batch size

class MLPNet:  # despite the name, this is a small CNN
    def __init__(self, is_train, mean=None, variance=None):
        # is_train selects the batch-norm mode; mean/variance only need to
        # be supplied for inference (is_train=False)
        self.is_train = is_train
        self.mean = mean
        self.variance = variance

        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10])  # y is one-hot

        # first conv kernel plus the batch-norm scale (gamma) and offset (beta)
        self.in_w = tf.Variable(tf.truncated_normal(shape=[3, 3, 1, 10], stddev=0.1))
        self.in_w_scale = tf.Variable(tf.truncated_normal(shape=[10], stddev=0.1))
        self.in_w_offset = tf.Variable(tf.truncated_normal(shape=[10], stddev=0.1))

        self.out_w = tf.Variable(tf.truncated_normal(shape=[3, 3, 10, 100], stddev=0.1))
        self.out_w1 = tf.Variable(tf.truncated_normal(shape=[3, 3, 100, 10], stddev=0.1))
        self.out_w2 = tf.Variable(tf.truncated_normal(shape=[7, 7, 10, 10], stddev=0.1))

    def forward(self):
        # conv -> batch norm -> tanh; 28x28 -> 14x14
        x = tf.nn.conv2d(input=self.x, filter=self.in_w, strides=[1, 2, 2, 1], padding="SAME")
        if self.is_train:
            # training mode: normalize with the current batch's statistics,
            # and keep them so the caller can reuse them at inference time
            x, self.mean_train, self.variance_train = tf.nn.fused_batch_norm(
                x, self.in_w_scale, self.in_w_offset, is_training=True)
        else:
            # inference mode: normalize with the fixed statistics passed in
            x, _, _ = tf.nn.fused_batch_norm(
                x, self.in_w_scale, self.in_w_offset, self.mean, self.variance,
                is_training=False)
        x = tf.nn.tanh(x)

        # 14x14 -> 7x7
        x = tf.nn.conv2d(input=x, filter=self.out_w, strides=[1, 2, 2, 1], padding="SAME")
        x = tf.nn.relu(x)

        x = tf.nn.conv2d(input=x, filter=self.out_w1, strides=[1, 1, 1, 1], padding="SAME")
        x = tf.nn.relu(x)

        # 7x7 -> 1x1: a fully connected readout implemented as a convolution
        x = tf.nn.conv2d(input=x, filter=self.out_w2, strides=[1, 7, 7, 1], padding="SAME")

        self.output_f = tf.reshape(x, (-1, 10))
        self.output = tf.nn.softmax(self.output_f)

    def backward(self):
        # mean squared error on the softmax output; cross-entropy on
        # output_f would be the more usual choice
        self.loss = tf.reduce_mean((self.output - self.y) ** 2)
        self.opt = tf.train.GradientDescentOptimizer(0.1).minimize(self.loss)

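# Note on the statistics handoff: when is_train is True, forward() keeps the
# batch mean/variance that fused_batch_norm returns. The training loop below
# simply stores whatever the *last* batch produced and feeds that into a
# second, inference-mode copy of the graph. That is the simplest way to demo
# fused_batch_norm; a production model would track an exponential moving
# average over all batches instead (see the sketch at the end of this post).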
if __name__ == '__main__':
    mean = None
    variance = None

    # phase 1: build the graph in training mode
    net = MLPNet(True)
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)

        for epoch in range(10000):
            xs, ys = mnist.train.next_batch(how_much)
            xs = np.reshape(xs, (how_much, 28, 28, 1))
            # my_image = Image.fromarray(np.uint8(xs[0, :, :, 0] * 255))
            # my_image.resize((100, 100))
            # my_image.show()

            # mean/variance are overwritten every step, so after the loop
            # they hold the statistics of the *last* training batch
            _loss, _, mean, variance = sess.run(
                [net.loss, net.opt, net.mean_train, net.variance_train],
                feed_dict={net.x: xs, net.y: ys})

            if epoch % 100 == 0:
                print("loss:", _loss)

                test_xs, test_ys = mnist.test.next_batch(how_much)
                test_xs = np.reshape(test_xs, (how_much, 28, 28, 1))
                # note: the graph is still in training mode here, so this
                # evaluation normalizes with the test batch's own statistics
                test_output = sess.run(net.output, feed_dict={net.x: test_xs})

                test_y = np.argmax(test_ys, axis=1)
                test_out = np.argmax(test_output, axis=1)
                print("accuracy:", np.mean(np.array(test_y == test_out, dtype=np.float32)))

        # dump the (unfrozen) training graph and the variable checkpoint
        with tf.gfile.FastGFile("./train.pb", mode='wb') as fw:
            fw.write(tf.get_default_graph().as_graph_def().SerializeToString())
        saver = tf.train.Saver()
        saver.save(sess, "./train.ckpt")

    print("*****************************************************************")
    with tf.Graph().as_default() as g:
        net = MLPNet(False,mean,variance)
        net.forward()
        net.backward()
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            saver = tf.train.Saver()
            saver.restore(sess=sess, save_path="./train.ckpt")
            # with range(100) and the modulus test, only epoch 0 actually
            # evaluates; the structure is kept parallel to the training loop
            for epoch in range(100):
                xs, ys = mnist.train.next_batch(how_much)
                xs = np.reshape(xs, (how_much, 28, 28, 1))

                if epoch % 100 == 0:
                    test_xs, test_ys = mnist.test.next_batch(how_much)
                    test_xs = np.reshape(test_xs, (how_much, 28, 28, 1))
                    test_output = sess.run(net.output, feed_dict={net.x: test_xs})

                    test_y = np.argmax(test_ys, axis=1)
                    test_out = np.argmax(test_output, axis=1)
                    print("accuracy:", np.mean(np.array(test_y == test_out, dtype=np.float32)))

            # freeze: bake the restored variables into constants and write a
            # standalone inference graph
            constant_graph = tf.graph_util.convert_variables_to_constants(
                sess,
                tf.get_default_graph().as_graph_def(),
                output_node_names=[net.output.name[:-2]])  # strip the ":0" suffix

            with tf.gfile.FastGFile("./inf.pb", mode='wb') as fw:
                fw.write(constant_graph.SerializeToString())
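Once inf.pb exists, it can be loaded without any of the model code. Below is a minimal sketch of doing so; the tensor names "Placeholder:0" and "Softmax:0" are assumptions based on TensorFlow's default op naming for the graph above, so print the node names from the GraphDef if yours differ.

import numpy as np
import tensorflow as tf

with tf.gfile.FastGFile("./inf.pb", "rb") as fr:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(fr.read())

with tf.Graph().as_default() as g:
    tf.import_graph_def(graph_def, name="")
    # "Placeholder:0" / "Softmax:0" are assumed default names; check with
    # print([n.name for n in graph_def.node]) if the lookup fails
    x = g.get_tensor_by_name("Placeholder:0")
    output = g.get_tensor_by_name("Softmax:0")
    with tf.Session() as sess:
        probs = sess.run(output, feed_dict={x: np.zeros((1, 28, 28, 1), np.float32)})
        print("predicted digit:", np.argmax(probs, axis=1))

As noted above, reusing the last batch's statistics works for a demo but is noisy. The standard fix is an exponential moving average of the statistics, which tf.layers.batch_normalization maintains for you. A minimal sketch (the stand-in loss is just for illustration):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
is_training = tf.placeholder(tf.bool)

# training=True normalizes with batch statistics and updates the moving
# averages; training=False normalizes with the stored moving averages
h = tf.layers.batch_normalization(x, training=is_training)
loss = tf.reduce_mean(h ** 2)  # stand-in loss, just for illustration

# the moving-average update ops live in UPDATE_OPS and must be attached to
# the train step explicitly, or the averages never get updated
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)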
