手写数字识别

1、导入相关库及数据文件

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot = True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

2、构建输入层

定义标签数据占位符

x = tf.placeholder(tf.float32,[None,784],name = "X")
y = tf.placeholder(tf.float32,[None,10],name = "Y")

3、构建隐藏层

H1_NH = 256
W1 = tf.Variable(tf.random_normal([784,H1_NH]))
b1 = tf.Variable(tf.zeros([H1_NH]))
Y1 = tf.nn.relu(tf.matmul(x,W1)+b1)

4、构建输出层

W2 = tf.Variable(tf.random_normal([H1_NH,10]))
b2 = tf.Variable(tf.zeros([10]))

forward = tf.matmul(Y1,W2)+b2
pred = tf.nn.softmax(forward)

5、定义损失函数

# loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices = 1))

# Tensorflow提供了结合softmax函数。用于避免log(0)为NaA造成的数据的不稳定
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=forward,labels=y))

6、训练模型及其训练模型的保存


# 设置训练参数
train_epochs = 40
batch_size = 50
total_batch = int(mnist.train.num_examples/batch_size)
display_step = 1
learning_rate = 0.01


#存储模型粒度
save_step = 5

#创建保持模型文件目录
import os
ckpt_dir = "./ckpt_dir/"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

# 声明所有变量后,调用
saver = tf.train.Saver()
    
# 选择优化器
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)


#定义准确率
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))




from time import time
startTime = time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs,ys = mnist.train.next_batch(batch_size)
        sess.run(optimizer,feed_dict = {x:xs,y:ys})
        
    loss,acc = sess.run([loss_function,accuracy],feed_dict = {x:mnist.validation.images,y:mnist.validation.labels})
    if(epoch+1)%display_step == 0:
        print("Train Epoch:","%02d"%(epoch+1),"Loss = ","{:.9f}".format(loss),"Accuracy = ","{:.4f}".format(acc))
    
    if(epoch+1)%save_step == 0:
        saver.save(sess,os.path.join(ckpt_dir,"mnist_h256_model_{:06d}.ckpt".format(epoch+1)))#存储模型
        print("mnist_h256_model_{:06d}.ckpt saved".format(epoch+1))
saver.save(sess,os.path.join(ckpt_dir,"mnist_h256_model.ckpt"))
print("Model saved!")



duration = time()-startTime
print("Train Finished takes:","{:.2f}".format(duration))







Train Epoch: 01 Loss =  1.582666516 Accuracy =  0.9302
Train Epoch: 02 Loss =  0.805085957 Accuracy =  0.9494
Train Epoch: 03 Loss =  0.570328474 Accuracy =  0.9570
Train Epoch: 04 Loss =  0.505634725 Accuracy =  0.9552
Train Epoch: 05 Loss =  0.395196378 Accuracy =  0.9596
mnist_h256_model_000005.ckpt saved
Train Epoch: 06 Loss =  0.426108211 Accuracy =  0.9606
Train Epoch: 07 Loss =  0.454120547 Accuracy =  0.9598
Train Epoch: 08 Loss =  0.418259770 Accuracy =  0.9616
Train Epoch: 09 Loss =  0.456357628 Accuracy =  0.9608
Train Epoch: 10 Loss =  0.426814556 Accuracy =  0.9636
mnist_h256_model_000010.ckpt saved
Train Epoch: 11 Loss =  0.400999218 Accuracy =  0.9606
Train Epoch: 12 Loss =  0.318266839 Accuracy =  0.9694
Train Epoch: 13 Loss =  0.385480493 Accuracy =  0.9676
Train Epoch: 14 Loss =  0.386757165 Accuracy =  0.9644
Train Epoch: 15 Loss =  0.444276184 Accuracy =  0.9652
mnist_h256_model_000015.ckpt saved
Train Epoch: 16 Loss =  0.438010752 Accuracy =  0.9692
Train Epoch: 17 Loss =  0.432323277 Accuracy =  0.9680
Train Epoch: 18 Loss =  0.419930577 Accuracy =  0.9712
Train Epoch: 19 Loss =  0.378582388 Accuracy =  0.9708
Train Epoch: 20 Loss =  0.457510829 Accuracy =  0.9710
mnist_h256_model_000020.ckpt saved
Train Epoch: 21 Loss =  0.456655473 Accuracy =  0.9726
Train Epoch: 22 Loss =  0.487205356 Accuracy =  0.9676
Train Epoch: 23 Loss =  0.547173560 Accuracy =  0.9708
Train Epoch: 24 Loss =  0.543109298 Accuracy =  0.9702
Train Epoch: 25 Loss =  0.673086405 Accuracy =  0.9710
mnist_h256_model_000025.ckpt saved
Train Epoch: 26 Loss =  0.768935084 Accuracy =  0.9704
Train Epoch: 27 Loss =  0.695913255 Accuracy =  0.9698
Train Epoch: 28 Loss =  0.694508731 Accuracy =  0.9708
Train Epoch: 29 Loss =  0.640166819 Accuracy =  0.9718
Train Epoch: 30 Loss =  0.749314189 Accuracy =  0.9714
mnist_h256_model_000030.ckpt saved
Train Epoch: 31 Loss =  0.683746219 Accuracy =  0.9742
Train Epoch: 32 Loss =  0.681889772 Accuracy =  0.9722
Train Epoch: 33 Loss =  0.857060552 Accuracy =  0.9722
Train Epoch: 34 Loss =  0.748108029 Accuracy =  0.9750
Train Epoch: 35 Loss =  0.782744944 Accuracy =  0.9718
mnist_h256_model_000035.ckpt saved
Train Epoch: 36 Loss =  0.842093468 Accuracy =  0.9748
Train Epoch: 37 Loss =  0.949436128 Accuracy =  0.9688
Train Epoch: 38 Loss =  0.856811941 Accuracy =  0.9744
Train Epoch: 39 Loss =  0.851330876 Accuracy =  0.9724
Train Epoch: 40 Loss =  0.948819518 Accuracy =  0.9726
mnist_h256_model_000040.ckpt saved
Model saved!
Train Finished takes: 129.34

评估模型

accu_test = sess.run(accuracy,feed_dict = {x:mnist.test.images,y:mnist.test.labels})
print("Test Accuracy:",accu_test)
Test Accuracy: 0.9686

进行预测

prediction_result = sess.run(tf.argmax(pred,1),feed_dict = {x:mnist.test.images})


print(prediction_result[0:10])
[7 2 1 0 4 1 4 9 5 9]

找出预测错误

import numpy as np
compare_lists = prediction_result==np.argmax(mnist.test.labels,1)
print(compare_lists)


err_lists = [i for i in range(len(compare_lists)) if compare_lists[i]==False]
print(err_lists,len(err_lists))
[ True  True  True ...,  True  True  True]
[115, 211, 247, 259, 321, 381, 445, 447, 448, 449, 450, 495, 507, 543, 582, 659, 707, 720, 774, 781, 846, 874, 883, 938, 947, 951, 956, 1014, 1032, 1039, 1050, 1068, 1112, 1156, 1219, 1226, 1232, 1247, 1260, 1299, 1433, 1500, 1520, 1530, 1531, 1549, 1568, 1571, 1600, 1717, 1730, 1748, 1754, 1813, 1828, 1850, 1855, 1899, 1901, 1917, 1941, 2004, 2018, 2033, 2035, 2037, 2040, 2043, 2063, 2073, 2098, 2109, 2118, 2129, 2130, 2135, 2148, 2266, 2272, 2280, 2292, 2293, 2308, 2314, 2325, 2334, 2369, 2387, 2406, 2414, 2422, 2433, 2488, 2512, 2552, 2573, 2574, 2582, 2597, 2607, 2648, 2654, 2671, 2684, 2721, 2730, 2735, 2743, 2770, 2771, 2809, 2863, 2877, 2896, 2921, 2944, 2952, 2953, 2979, 3030, 3056, 3060, 3073, 3117, 3225, 3243, 3289, 3316, 3329, 3333, 3369, 3422, 3503, 3520, 3533, 3558, 3559, 3597, 3662, 3702, 3762, 3767, 3776, 3778, 3780, 3785, 3806, 3808, 3831, 3853, 3902, 3906, 3941, 3943, 3976, 4007, 4065, 4078, 4124, 4140, 4159, 4163, 4176, 4201, 4256, 4259, 4271, 4294, 4300, 4369, 4382, 4433, 4487, 4497, 4500, 4534, 4552, 4571, 4584, 4601, 4615, 4639, 4690, 4699, 4731, 4740, 4761, 4789, 4807, 4814, 4823, 4860, 4861, 4874, 4876, 4879, 4880, 4890, 4915, 4956, 4966, 4990, 5068, 5201, 5331, 5457, 5586, 5642, 5654, 5710, 5719, 5734, 5749, 5835, 5888, 5935, 5936, 5950, 5955, 5973, 6011, 6023, 6030, 6046, 6059, 6065, 6071, 6075, 6081, 6091, 6166, 6172, 6174, 6347, 6421, 6532, 6555, 6557, 6571, 6574, 6576, 6578, 6597, 6608, 6625, 6632, 6651, 6783, 7031, 7216, 7268, 7338, 7434, 7529, 7545, 7574, 7797, 7849, 7858, 7886, 7899, 7902, 7905, 7915, 7921, 7928, 7990, 8020, 8059, 8062, 8091, 8094, 8115, 8198, 8246, 8253, 8255, 8290, 8339, 8353, 8362, 8397, 8408, 8426, 8453, 8456, 8522, 8527, 9009, 9015, 9019, 9024, 9280, 9587, 9634, 9642, 9664, 9669, 9679, 9692, 9701, 9716, 9729, 9745, 9764, 9768, 9770, 9779, 9792, 9839, 9883, 9944, 9959, 9975] 314

定义一个输出错误分类的函数

def print_predict_errs(labels,prediction):
    count = 0
    compare_lists = (prediction == np.argmax(labels,1))
    err_lists = [i for i in range(len(compare_lists)) if compare_lists[i] == False]
    for x in err_lists:
        print("index"+str(x)+" 标签值=",np.argmax(labels[x]),"预测值=",prediction[x])
        count = count+1
    print("总计:"+str(count))
print_predict_errs(labels = mnist.test.labels,prediction=prediction_result)
index115 标签值= 4 预测值= 9
index211 标签值= 5 预测值= 8
index247 标签值= 4 预测值= 2
index259 标签值= 6 预测值= 0
index321 标签值= 2 预测值= 7
index381 标签值= 3 预测值= 7
index445 标签值= 6 预测值= 0
index447 标签值= 4 预测值= 9
index448 标签值= 9 预测值= 5
index449 标签值= 3 预测值= 5
index450 标签值= 3 预测值= 5
index495 标签值= 8 预测值= 2
index507 标签值= 3 预测值= 5
index543 标签值= 8 预测值= 5
index582 标签值= 8 预测值= 2
index659 标签值= 2 预测值= 1
index707 标签值= 4 预测值= 9
index720 标签值= 5 预测值= 8
index774 标签值= 4 预测值= 9
index781 标签值= 8 预测值= 9
index846 标签值= 7 预测值= 4
index874 标签值= 9 预测值= 4
index883 标签值= 3 预测值= 5
index938 标签值= 3 预测值= 5
index947 标签值= 8 预测值= 9
index951 标签值= 5 预测值= 4
index956 标签值= 1 预测值= 2
index1014 标签值= 6 预测值= 0
index1032 标签值= 5 预测值= 6
index1039 标签值= 7 预测值= 2
index1050 标签值= 2 预测值= 3
index1068 标签值= 8 预测值= 1
index1112 标签值= 4 预测值= 6
index1156 标签值= 7 预测值= 8
index1219 标签值= 8 预测值= 1
index1226 标签值= 7 预测值= 2
index1232 标签值= 9 预测值= 6
index1247 标签值= 9 预测值= 5
index1260 标签值= 7 预测值= 1
index1299 标签值= 5 预测值= 7
index1433 标签值= 8 预测值= 1
index1500 标签值= 7 预测值= 1
index1520 标签值= 7 预测值= 2
index1530 标签值= 8 预测值= 7
index1531 标签值= 3 预测值= 5
index1549 标签值= 4 预测值= 6
index1568 标签值= 8 预测值= 3
index1571 标签值= 4 预测值= 9
index1600 标签值= 3 预测值= 5
index1717 标签值= 8 预测值= 0
index1730 标签值= 3 预测值= 5
index1748 标签值= 0 预测值= 4
index1754 标签值= 7 预测值= 2
index1813 标签值= 8 预测值= 3
index1828 标签值= 3 预测值= 5
index1850 标签值= 8 预测值= 7
index1855 标签值= 8 预测值= 3
index1899 标签值= 8 预测值= 3
index1901 标签值= 9 预测值= 4
index1917 标签值= 5 预测值= 8
index1941 标签值= 7 预测值= 8
index2004 标签值= 8 预测值= 3
index2018 标签值= 1 预测值= 8
index2033 标签值= 0 预测值= 4
index2035 标签值= 5 预测值= 3
index2037 标签值= 5 预测值= 8
index2040 标签值= 5 预测值= 8
index2043 标签值= 4 预测值= 8
index2063 标签值= 7 预测值= 2
index2073 标签值= 5 预测值= 6
index2098 标签值= 2 预测值= 0
index2109 标签值= 3 预测值= 7
index2118 标签值= 6 预测值= 0
index2129 标签值= 9 预测值= 4
index2130 标签值= 4 预测值= 9
index2135 标签值= 6 预测值= 1
index2148 标签值= 4 预测值= 9
index2266 标签值= 1 预测值= 6
index2272 标签值= 8 预测值= 0
index2280 标签值= 3 预测值= 5
index2292 标签值= 9 预测值= 5
index2293 标签值= 9 预测值= 0
index2308 标签值= 3 预测值= 5
index2314 标签值= 7 预测值= 9
index2325 标签值= 7 预测值= 2
index2334 标签值= 7 预测值= 3
index2369 标签值= 5 预测值= 4
index2387 标签值= 9 预测值= 1
index2406 标签值= 9 预测值= 4
index2414 标签值= 9 预测值= 4
index2422 标签值= 6 预测值= 4
index2433 标签值= 2 预测值= 1
index2488 标签值= 2 预测值= 4
index2512 标签值= 8 预测值= 5
index2552 标签值= 8 预测值= 3
index2573 标签值= 5 预测值= 1
index2574 标签值= 5 预测值= 3
index2582 标签值= 9 预测值= 5
index2597 标签值= 5 预测值= 3
index2607 标签值= 7 预测值= 1
index2648 标签值= 9 预测值= 5
index2654 标签值= 6 预测值= 1
index2671 标签值= 7 预测值= 2
index2684 标签值= 3 预测值= 7
index2721 标签值= 6 预测值= 5
index2730 标签值= 7 预测值= 4
index2735 标签值= 9 预测值= 7
index2743 标签值= 5 预测值= 8
index2770 标签值= 3 预测值= 8
index2771 标签值= 4 预测值= 9
index2809 标签值= 8 预测值= 1
index2863 标签值= 9 预测值= 4
index2877 标签值= 4 预测值= 7
index2896 标签值= 8 预测值= 0
index2921 标签值= 3 预测值= 8
index2944 标签值= 0 预测值= 2
index2952 标签值= 3 预测值= 5
index2953 标签值= 3 预测值= 5
index2979 标签值= 9 预测值= 7
index3030 标签值= 6 预测值= 0
index3056 标签值= 9 预测值= 7
index3060 标签值= 9 预测值= 4
index3073 标签值= 1 预测值= 2
index3117 标签值= 5 预测值= 3
index3225 标签值= 7 预测值= 9
index3243 标签值= 0 预测值= 6
index3289 标签值= 8 预测值= 5
index3316 标签值= 7 预测值= 4
index3329 标签值= 7 预测值= 2
index3333 标签值= 7 预测值= 9
index3369 标签值= 9 预测值= 7
index3422 标签值= 6 预测值= 0
index3503 标签值= 9 预测值= 1
index3520 标签值= 6 预测值= 4
index3533 标签值= 4 预测值= 7
index3558 标签值= 5 预测值= 0
index3559 标签值= 8 预测值= 5
index3597 标签值= 9 预测值= 5
index3662 标签值= 8 预测值= 2
index3702 标签值= 5 预测值= 3
index3762 标签值= 6 预测值= 8
index3767 标签值= 7 预测值= 3
index3776 标签值= 5 预测值= 8
index3778 标签值= 5 预测值= 2
index3780 标签值= 4 预测值= 6
index3785 标签值= 2 预测值= 1
index3806 标签值= 5 预测值= 8
index3808 标签值= 7 预测值= 1
index3831 标签值= 9 预测值= 4
index3853 标签值= 6 预测值= 0
index3902 标签值= 5 预测值= 3
index3906 标签值= 1 预测值= 3
index3941 标签值= 4 预测值= 6
index3943 标签值= 3 预测值= 5
index3976 标签值= 7 预测值= 1
index4007 标签值= 7 预测值= 8
index4065 标签值= 0 预测值= 9
index4078 标签值= 9 预测值= 8
index4124 标签值= 8 预测值= 9
index4140 标签值= 8 预测值= 2
index4159 标签值= 8 预测值= 2
index4163 标签值= 9 预测值= 7
index4176 标签值= 2 预测值= 7
index4201 标签值= 1 预测值= 7
index4256 标签值= 3 预测值= 2
index4259 标签值= 9 预测值= 4
index4271 标签值= 5 预测值= 3
index4294 标签值= 9 预测值= 7
index4300 标签值= 5 预测值= 8
index4369 标签值= 9 预测值= 4
index4382 标签值= 4 预测值= 9
index4433 标签值= 7 预测值= 4
index4487 标签值= 7 预测值= 8
index4497 标签值= 8 预测值= 7
index4500 标签值= 9 预测值= 1
index4534 标签值= 9 预测值= 8
index4552 标签值= 3 预测值= 5
index4571 标签值= 6 预测值= 3
index4584 标签值= 9 预测值= 4
index4601 标签值= 8 预测值= 4
index4615 标签值= 2 预测值= 4
index4639 标签值= 8 预测值= 9
index4690 标签值= 7 预测值= 2
index4699 标签值= 6 预测值= 2
index4731 标签值= 8 预测值= 2
index4740 标签值= 3 预测值= 4
index4761 标签值= 9 预测值= 4
index4789 标签值= 8 预测值= 5
index4807 标签值= 8 预测值= 2
index4814 标签值= 6 预测值= 0
index4823 标签值= 9 预测值= 6
index4860 标签值= 4 预测值= 9
index4861 标签值= 7 预测值= 3
index4874 标签值= 9 预测值= 6
index4876 标签值= 2 预测值= 4
index4879 标签值= 8 预测值= 6
index4880 标签值= 0 预测值= 8
index4890 标签值= 8 预测值= 2
index4915 标签值= 5 预测值= 8
index4956 标签值= 8 预测值= 4
index4966 标签值= 7 预测值= 4
index4990 标签值= 3 预测值= 2
index5068 标签值= 4 预测值= 7
index5201 标签值= 4 预测值= 9
index5331 标签值= 1 预测值= 6
index5457 标签值= 1 预测值= 0
index5586 标签值= 8 预测值= 2
index5642 标签值= 1 预测值= 8
index5654 标签值= 7 预测值= 2
index5710 标签值= 8 预测值= 1
index5719 标签值= 9 预测值= 7
index5734 标签值= 3 预测值= 7
index5749 标签值= 8 预测值= 5
index5835 标签值= 7 预测值= 2
index5888 标签值= 4 预测值= 6
index5935 标签值= 3 预测值= 5
index5936 标签值= 4 预测值= 9
index5950 标签值= 8 预测值= 1
index5955 标签值= 3 预测值= 8
index5973 标签值= 3 预测值= 8
index6011 标签值= 3 预测值= 5
index6023 标签值= 3 预测值= 5
index6030 标签值= 3 预测值= 1
index6046 标签值= 3 预测值= 8
index6059 标签值= 3 预测值= 9
index6065 标签值= 3 预测值= 8
index6071 标签值= 9 预测值= 3
index6075 标签值= 3 预测值= 5
index6081 标签值= 9 预测值= 5
index6091 标签值= 9 预测值= 5
index6166 标签值= 9 预测值= 5
index6172 标签值= 9 预测值= 5
index6174 标签值= 3 预测值= 5
index6347 标签值= 8 预测值= 6
index6421 标签值= 3 预测值= 2
index6532 标签值= 0 预测值= 7
index6555 标签值= 8 预测值= 9
index6557 标签值= 0 预测值= 2
index6571 标签值= 9 预测值= 7
index6574 标签值= 2 预测值= 6
index6576 标签值= 7 预测值= 1
index6578 标签值= 8 预测值= 3
index6597 标签值= 0 预测值= 7
index6608 标签值= 9 预测值= 5
index6625 标签值= 8 预测值= 4
index6632 标签值= 9 预测值= 8
index6651 标签值= 0 预测值= 5
index6783 标签值= 1 预测值= 6
index7031 标签值= 0 预测值= 2
index7216 标签值= 0 预测值= 5
index7268 标签值= 7 预测值= 4
index7338 标签值= 4 预测值= 6
index7434 标签值= 4 预测值= 8
index7529 标签值= 3 预测值= 5
index7545 标签值= 8 预测值= 3
index7574 标签值= 4 预测值= 1
index7797 标签值= 5 预测值= 6
index7849 标签值= 3 预测值= 2
index7858 标签值= 3 预测值= 9
index7886 标签值= 2 预测值= 4
index7899 标签值= 1 预测值= 8
index7902 标签值= 7 预测值= 8
index7905 标签值= 3 预测值= 2
index7915 标签值= 7 预测值= 8
index7921 标签值= 8 预测值= 2
index7928 标签值= 1 预测值= 8
index7990 标签值= 1 预测值= 8
index8020 标签值= 1 预测值= 8
index8059 标签值= 2 预测值= 1
index8062 标签值= 5 预测值= 8
index8091 标签值= 2 预测值= 1
index8094 标签值= 2 预测值= 1
index8115 标签值= 3 预测值= 5
index8198 标签值= 2 预测值= 4
index8246 标签值= 3 预测值= 9
index8253 标签值= 2 预测值= 4
index8255 标签值= 4 预测值= 8
index8290 标签值= 3 预测值= 1
index8339 标签值= 8 预测值= 0
index8353 标签值= 2 预测值= 4
index8362 标签值= 3 预测值= 5
index8397 标签值= 3 预测值= 5
index8408 标签值= 8 预测值= 2
index8426 标签值= 9 预测值= 4
index8453 标签值= 5 预测值= 3
index8456 标签值= 8 预测值= 6
index8522 标签值= 8 预测值= 6
index8527 标签值= 4 预测值= 9
index9009 标签值= 7 预测值= 2
index9015 标签值= 7 预测值= 2
index9019 标签值= 7 预测值= 2
index9024 标签值= 7 预测值= 2
index9280 标签值= 8 预测值= 5
index9587 标签值= 9 预测值= 4
index9634 标签值= 0 预测值= 1
index9642 标签值= 9 预测值= 7
index9664 标签值= 2 预测值= 7
index9669 标签值= 4 预测值= 7
index9679 标签值= 6 预测值= 2
index9692 标签值= 9 预测值= 7
index9701 标签值= 9 预测值= 7
index9716 标签值= 2 预测值= 5
index9729 标签值= 5 预测值= 6
index9745 标签值= 4 预测值= 0
index9764 标签值= 4 预测值= 0
index9768 标签值= 2 预测值= 0
index9770 标签值= 5 预测值= 0
index9779 标签值= 2 预测值= 0
index9792 标签值= 4 预测值= 9
index9839 标签值= 2 预测值= 7
index9883 标签值= 5 预测值= 1
index9944 标签值= 3 预测值= 8
index9959 标签值= 8 预测值= 1
index9975 标签值= 3 预测值= 2
总计:314

定义可视化函数

import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,prediction,index,num=10):
    fig = plt.gcf()
    fig.set_size_inches(10,12)
    if num>25:
        num = 25
    for i in range(0,num):
        ax = plt.subplot(5,5,i+1)
        ax.imshow(np.reshape(images[index],(28,28)),cmap = "binary")
        title = "label="+ str(np.argmax(labels[index]))
        if len(prediction)>0:
            title+=",predict="+str(prediction[index])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        index +=1
    plt.show()
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,610,20)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Zhr3px8r-1621262538097)(output_22_0.png)]

TensorBoard

构建输入层

# x = tf.placeholder(tf.float32, [None,784], name = "X")
image_shaped_input = tf.reshape(x,[-1,28,28,1])
tf.summary.image("input",image_shaped_input,10)
<tf.Tensor 'input:0' shape=() dtype=string>
tf.summary.histogram('forward',forward)
<tf.Tensor 'forward:0' shape=() dtype=string>
tf.summary.scalar("loss",loss_function)
<tf.Tensor 'loss:0' shape=() dtype=string>
tf.summary.scalar("accuracy",accuracy)
<tf.Tensor 'accuracy:0' shape=() dtype=string>
sess = tf.Session()
sess.run(tf.global_variables_initializer())
merged_summary_op = tf.summary.merge_all()
writer = tf.summary.FileWriter('log/',sess.graph)
for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs,ys = mnist.train.next_batch(batch_size)
        sess.run(optimizer,feed_dict = {x:xs,y:ys})
        
        summary_str = sess.run(merged_summary_op,feed_dict={x:xs,y:ys})
        writer.add_summary(summary_str,epoch)
    loss,acc = sess.run([loss_function,accuracy],feed_dict = {x:mnist.validation.images,y:mnist.validation.labels})
print("保存成功")
保存成功

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值