博文配套视频课程:24小时实现从零到AI人工智能
zip函数
zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同
a = [1,2,3]
b = [4,5,6,7]
for i in zip(a,b):
print(i)
# 输出结果如下
# (1, 4)
# (2, 5)
# (3, 6)
enumerate函数
enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中
a = [1,2,3]
b = [4,5,6,7]
for i,j in enumerate(zip(a,b)):
print(i,j)
# 输出结果如下
# 0 (1, 4)
# 1 (2, 5)
# 2 (3, 6)
可视化完整代码
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
# y_predict 本质是一个Tensor
guess = sess.run(y_predict, feed_dict=d)
images_labels_guess = zip(mnist_x, mnist_y, guess)
# figsize=(20,18) 代表图片宽与高,英寸
plt.figure(figsize=(20, 18), dpi=200)
# 循环获取每一次 图片 + 目标值 + 预测值
for index, (image, label, guess) in enumerate(images_labels_guess):
# 55个显示区域
plt.subplot(5, 11, index + 1)
image = image.reshape(28, 28)
plt.imshow(image, cmap=plt.cm.gray_r)
if index == 0:
print(label, guess)
val_t, val_p = sess.run(tf.argmax(label)), sess.run(tf.argmax(guess))
if val_t != val_p:
color = '#ff0000'
else:
color = '#000000'
plt.title(f'真:{val_t},预:{val_p}',fontsize=20,color=color)
plt.show()
手写数字可视化结果
红色部分代表识别失败,采用DNN深度神经网络进行优化后,如果在训练次数达到2000时,训练的正确率可以达到95%左右。