def optimize(num_iterations):
    """Run the optimizer on ``num_iterations`` random training batches.

    Relies on the module-level ``data``, ``batch_size``, ``x``, ``y_true``,
    ``optimizer`` and ``session`` objects.

    Args:
        num_iterations: Number of batches to draw; the optimizer is run
            once per batch.
    """
    for _step in range(num_iterations):
        # Get a batch of training examples.
        # x_batch now holds a batch of images and
        # y_true_batch are the true labels for those images.
        x_batch, y_true_batch, _ = data.random_batch(batch_size=batch_size)

        # Put the batch into a dict with the proper names
        # for placeholder variables in the TensorFlow graph.
        # Note that the placeholder for y_true_cls is not set
        # because it is not used during training.
        feed_dict_train = {x: x_batch, y_true: y_true_batch}

        # Run the optimizer using this batch of training data.
        # TensorFlow assigns the variables in feed_dict_train
        # to the placeholder variables and then runs the optimizer.
        #
        # What a single run of the optimizer does:
        #   - find the elements bound to the loss function;
        #   - trace back which graph nodes are involved;
        #   - trace back which inputs they depend on;
        #   - compute the loss value;
        #   - compute the gradients;
        #   - update the parameters.
        # Note that each call updates the parameters only once.
        session.run(optimizer, feed_dict=feed_dict_train)
# 4.2 其他辅助函数 (Other helper functions)
def plot_images(images, cls_true, cls_pred=None):
    """Show exactly nine images in a 3x3 grid, labelled with their classes.

    Relies on the module-level ``plt`` and ``img_shape``.

    Args:
        images: Indexable collection of 9 images; each one is reshaped to
            ``img_shape`` before being drawn.
        cls_true: The 9 true class labels, in the same order as ``images``.
        cls_pred: Optional predicted class labels. When given, each label
            is shown next to the true class.

    Raises:
        AssertionError: If ``images`` and ``cls_true`` do not both have
            length 9.
    """
    assert len(images) == len(cls_true) == 9

    # Create figure with 3x3 sub-plots.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for i, ax in enumerate(axes.flat):
        # Plot image.
        ax.imshow(images[i].reshape(img_shape), cmap='binary')

        # Show true and predicted classes.
        if cls_pred is None:
            xlabel = "True: {0}".format(cls_true[i])
        else:
            xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])

        ax.set_xlabel(xlabel)

        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()
def print_confusion_matrix():
    """Print the test-set confusion matrix as text and plot it as an image.

    Relies on the module-level ``data``, ``session``, ``y_pred_cls``,
    ``feed_dict_test``, ``confusion_matrix``, ``num_classes``, ``np`` and
    ``plt``.
    """
    # Get the true classifications for the test-set.
    cls_true = data.y_test_cls

    # Get the predicted classifications for the test-set.
    cls_pred = session.run(y_pred_cls, feed_dict=feed_dict_test)

    # Get the confusion matrix using sklearn.
    cm = confusion_matrix(y_true=cls_true,
                          y_pred=cls_pred)

    # Print the confusion matrix as text.
    print(cm)

    # Plot the confusion matrix as an image.
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)

    # Make various adjustments to the plot.
    plt.tight_layout()
    plt.colorbar()
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, range(num_classes))
    plt.yticks(tick_marks, range(num_classes))
    plt.xlabel('Predicted')
    plt.ylabel('True')

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()
def plot_example_errors():
    """Plot the first nine misclassified test-set images.

    Relies on the module-level ``session``, ``correct_prediction``,
    ``y_pred_cls``, ``feed_dict_test`` and ``data``, and delegates the
    drawing to ``plot_images``.

    NOTE: ``plot_images`` asserts that it receives exactly 9 images, so this
    raises AssertionError when fewer than 9 test images are misclassified.
    """
    # Use TensorFlow to get a list of boolean values
    # whether each test-image has been correctly classified,
    # and a list for the predicted class of each image.
    correct, cls_pred = session.run([correct_prediction, y_pred_cls],
                                    feed_dict=feed_dict_test)

    # Negate the mask. The comparison is kept as `== False` on purpose:
    # `correct` is presumably a NumPy array (it is used as an index mask
    # below), for which `not correct` would not work -- TODO confirm.
    incorrect = (correct == False)

    # Get the images from the test-set that have been incorrectly classified.
    images = data.x_test[incorrect]

    # Get the predicted classes for those images.
    cls_pred = cls_pred[incorrect]

    # Get the true classes for those images.
    cls_true = data.y_test_cls[incorrect]

    # Plot the first 9 images.
    plot_images(images=images[0:9],
                cls_true=cls_true[0:9],
                cls_pred=cls_pred[0:9])
def print_accuracy():
    """Evaluate ``accuracy`` on the test-set feed and print it as a percentage.

    Relies on the module-level ``session``, ``accuracy`` and
    ``feed_dict_test``.
    """
    # Use TensorFlow to compute the accuracy.
    acc = session.run(accuracy, feed_dict=feed_dict_test)

    # Print the accuracy.
    print("Accuracy on test-set: {0:.1%}".format(acc))