Code
import prettytensor as pt
import time
from datetime import timedelta
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets('data/MNIST/', one_hot=True)
data.test.cls = np.argmax(data.test.labels, axis=1)
img_size = 28
img_size_flat = img_size * img_size
img_shape = (img_size, img_size)
num_channels = 1
num_classes = 10
x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')
x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])
x_pretty = pt.wrap(x_image)
y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
y_true_cls = tf.argmax(y_true, axis=1)
with pt.defaults_scope(activation_fn=tf.nn.relu):
    y_pred, loss = x_pretty.\
        conv2d(kernel=5, depth=16, name='layer_conv1').\
        max_pool(kernel=2, stride=2).\
        conv2d(kernel=5, depth=36, name='layer_conv2').\
        max_pool(kernel=2, stride=2).\
        flatten().\
        fully_connected(size=128, name='layer_fc1').\
        softmax_classifier(num_classes=num_classes, labels=y_true)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)
y_pred_cls = tf.argmax(y_pred, axis=1)
correct_prediction = tf.equal(y_pred_cls, y_true_cls)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
session = tf.Session()
session.run(tf.global_variables_initializer())
train_batch_size = 64
total_iterations = 0
def optimize(num_iterations):
    global total_iterations
    start_time = time.time()
    for i in range(total_iterations,
                   total_iterations + num_iterations):
        # Fetch a batch of training images and their one-hot labels.
        x_batch, y_true_batch = data.train.next_batch(train_batch_size)
        feed_dict_train = {x: x_batch,
                           y_true: y_true_batch}
        session.run(optimizer, feed_dict=feed_dict_train)
        # Print the training accuracy every 100 iterations.
        if i % 100 == 0:
            acc = session.run(accuracy, feed_dict=feed_dict_train)
            msg = "Optimization Iteration: {0:>6}, Training Accuracy: {1:>6.1%}"
            print(msg.format(i + 1, acc))
    total_iterations += num_iterations
    end_time = time.time()
    time_dif = end_time - start_time
    print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))
test_batch_size = 256
def print_test_accuracy():
    num_test = len(data.test.images)
    cls_pred = np.zeros(shape=num_test, dtype=int)
    # Predict the test set in batches to limit memory usage.
    i = 0
    while i < num_test:
        j = min(i + test_batch_size, num_test)
        images = data.test.images[i:j, :]
        labels = data.test.labels[i:j, :]
        feed_dict = {x: images,
                     y_true: labels}
        cls_pred[i:j] = session.run(y_pred_cls, feed_dict=feed_dict)
        i = j
    cls_true = data.test.cls
    correct = (cls_true == cls_pred)
    correct_sum = correct.sum()
    acc = float(correct_sum) / num_test
    msg = "Accuracy on Test-Set: {0:.1%} ({1} / {2})"
    print(msg.format(acc, correct_sum, num_test))
    plot_confusion_matrix(cls_pred)
    plot_example_errors(cls_pred, correct)
def plot_confusion_matrix(cls_pred):
    cls_true = data.test.cls
    cm = confusion_matrix(y_true=cls_true, y_pred=cls_pred)
    print(cm)
    plt.matshow(cm)
    plt.colorbar()
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, range(num_classes))
    plt.yticks(tick_marks, range(num_classes))
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()
def plot_example_errors(cls_pred, correct):
    # Select the mis-classified test images.
    incorrect = np.logical_not(correct)
    images = data.test.images[incorrect]
    cls_pred = cls_pred[incorrect]
    cls_true = data.test.cls[incorrect]
    # Plot the first nine errors in a 3x3 grid.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].reshape(img_shape), cmap='binary')
        xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])
        ax.set_xlabel(xlabel)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
optimize(num_iterations=5000)
print_test_accuracy()
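For readers who cannot (or prefer not to) install PrettyTensor, the chained model definition above builds a conv(5x5,16) -> max-pool -> conv(5x5,36) -> max-pool -> flatten -> fully-connected(128) -> softmax graph. Below is a minimal sketch of an equivalent definition in plain TensorFlow 1.x using tf.layers; it assumes 'same' padding and default initializers, so the learned weights will not be identical to the PrettyTensor version, but the architecture and loss are the same.

import tensorflow as tf

num_classes = 10

# Placeholders: flattened 28x28 grayscale images and one-hot labels.
x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
x_image = tf.reshape(x, [-1, 28, 28, 1])

# conv1 (5x5, 16 filters) followed by 2x2 max-pooling.
net = tf.layers.conv2d(x_image, filters=16, kernel_size=5, padding='same',
                       activation=tf.nn.relu, name='layer_conv1')
net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)

# conv2 (5x5, 36 filters) followed by 2x2 max-pooling.
net = tf.layers.conv2d(net, filters=36, kernel_size=5, padding='same',
                       activation=tf.nn.relu, name='layer_conv2')
net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)

# Flatten, fully-connected layer with 128 units, linear output layer.
net = tf.layers.flatten(net)
net = tf.layers.dense(net, units=128, activation=tf.nn.relu, name='layer_fc1')
logits = tf.layers.dense(net, units=num_classes, name='layer_fc_out')

# Softmax predictions and cross-entropy loss, corresponding to the
# (y_pred, loss) pair returned by softmax_classifier above.
y_pred = tf.nn.softmax(logits)
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_true, logits=logits))

optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

The rest of the tutorial (optimize, print_test_accuracy, the plotting helpers) works unchanged with this graph, since it only depends on x, y_true, y_pred, loss and the optimizer.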
Results