TensorFlow 2.0: MNIST Handwritten Digit Recognition


Note: the complete code is at the end of the post.

Introduction to the MNIST Dataset

 The MNIST dataset was published by LeCun, Bottou, Bengio, & Haffner (1998). It contains handwritten images of the 10 digits 0~9, with 7,000 images per digit collected from real handwriting in different styles, for a total of 70,000 images. Of these, 60,000 images form the training set 𝔻^train (Training Set), used to train the model, and the remaining 10,000 images form the test set 𝔻^test (Test Set), used for prediction and testing. Together, the training set and the test set make up the whole MNIST dataset.

 Since handwritten digit images contain relatively simple information, each image is scaled down to 28 × 28 and only the grayscale information is kept.

 Now let's look at how an image is represented. An image has ℎ rows (Height/Row) and 𝑤 columns (Width/Column); each position stores a pixel (Pixel) value, usually an integer in 0~255 expressing intensity (0 is the lowest intensity, 255 the highest). For a color image, each pixel holds the intensities of the R, G and B channels (the red, green and blue channels), so unlike a grayscale image, each pixel is represented by a 1-D vector (Vector) of length 3 whose elements are the R, G, B intensities of that pixel; a color image therefore needs to be stored as a tensor of shape [ℎ, 𝑤, 3] (a Tensor, which can be loosely understood as a 3-D array). A grayscale image uses a single value per pixel for the gray intensity, e.g. 0 for pure black and 255 for pure white, so a 2-D matrix (Matrix) of shape [ℎ, 𝑤] is enough to represent one image (it can also be stored as a tensor of shape [ℎ, 𝑤, 1]). Figure 3.3 shows the matrix content of an image of the digit 8: black pixels are stored as 0, gray values range over 0~255, and the whiter a pixel is, the larger the corresponding matrix entry.

[Figure 3.3: matrix content of a handwritten image of the digit 8]
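As a quick check of the shapes described above, here is a minimal sketch (using the same datasets.mnist.load_data() call that appears later in this post) that loads the dataset and prints the image shapes and pixel range:

from tensorflow.keras import datasets

(x, y), (x_test, y_test) = datasets.mnist.load_data()
print(x.shape, y.shape)            # (60000, 28, 28) (60000,)
print(x_test.shape, y_test.shape)  # (10000, 28, 28) (10000,)
print(x.dtype, x.min(), x.max())   # uint8 0 255 -> grayscale [h, w] images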

Network Architecture

 This post uses a simple three-layer neural network:

out = relu{ relu{ [X@W_1 + b_1] @ W_2 + b_2 } @ W_3 + b_3 }

 The final out may or may not be passed through an activation function. The dataset is the MNIST handwritten digit image set: the network has 784 input nodes, the first layer has 256 output nodes, the second layer has 128 output nodes, and the third layer has 10 output nodes, which represent the probabilities of the current sample belonging to each of the 10 classes.
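For reference only, the same 784 → 256 → 128 → 10 architecture could also be expressed with tf.keras layers; this is just a sketch of an equivalent model, while the post below builds the network manually with tf.Variable:

import tensorflow as tf
from tensorflow.keras import layers

# Equivalent architecture, for illustration only (not used in the code below)
model = tf.keras.Sequential([
    layers.Dense(256, activation='relu', input_shape=(784,)),  # layer 1: 784 -> 256
    layers.Dense(128, activation='relu'),                      # layer 2: 256 -> 128
    layers.Dense(10),                                          # output logits: 128 -> 10
])
model.summary()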

Code

Import the required packages

from matplotlib import pyplot as mp
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers

 
 

Preprocessing function

 In most cases, datasets loaded from keras.datasets are not in a format that satisfies the model's input requirements, so a preprocessing function implementing your own logic is needed. The Dataset object provides the map(func) utility, which makes it easy to apply user-defined preprocessing implemented inside func:

# The preprocessing logic is implemented in the preprocess function; just pass in the function reference
train_db = train_db.map(preprocess)

 
 

 For the MNIST handwritten digit images loaded from keras.datasets, after .batch() the image tensor x has shape [𝑏, 28, 28] with pixel values stored as integers in 0~255, and the label tensor has shape [𝑏], holding the digit class of each sample. For the actual network input, image data generally needs to be normalized to an interval around 0 such as [0, 1] or [−1, 1], and the [28, 28] input needs to be reshaped into a format the network accepts; the labels can be one-hot encoded either during preprocessing or when computing the loss.

 Here we map the MNIST image data into 𝑥 ∈ [0, 1] and reshape the view to [𝑏, 28 ∗ 28]; for the labels y, we do the one-hot encoding inside the preprocessing function:

def preprocess(x, y): # user-defined preprocessing function
	# x and y are passed in automatically; their shapes are [b, 28, 28] and [b]
	# normalize to 0~1
	x = tf.cast(x, dtype=tf.float32) / 255.
	x = tf.reshape(x, [-1, 28*28]) # flatten
	y = tf.cast(y, dtype=tf.int32) # cast to an integer tensor
	y = tf.one_hot(y, depth=10) # one-hot encoding
	# the returned x, y replace the incoming x, y arguments, which implements the preprocessing
	return x, y

 
 

Load the handwritten digit dataset and process the data

(x, y), (x_test, y_test) = datasets.mnist.load_data()  # load the MNIST images and labels

batchsz = 512
train_db = tf.data.Dataset.from_tensor_slices((x, y))  # convert to a Dataset object
train_db = train_db.shuffle(1000)  # shuffle randomly
train_db = train_db.batch(batchsz)  # batch for training
train_db = train_db.map(preprocess)  # apply preprocessing
train_db = train_db.repeat(20)  # repeat the dataset 20 times (20 epochs)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(1000).batch(batchsz).map(preprocess)
x, y = next(iter(train_db))
print('train sample:', x.shape, y.shape)

 
 

 Random shuffling, batching and related topics are explained in another of my posts: https://blog.csdn.net/python_LC_nohtyp/article/details/104106498

The main() function

 In this network we set the learning rate lr = 1e-2 and use two lists, accs and losses, to store the accuracy and loss values for plotting later.
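A minimal sketch of this setup (these lines also appear at the top of main() in the complete code):

lr = 1e-2              # learning rate
accs, losses = [], []  # accuracy and loss history, used for plotting later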

Set up the network layers

The network has 784 input nodes and 10 output nodes.

# 784 => 256
w1, b1 = tf.Variable(tf.random.normal([784, 256], stddev=0.1)), tf.Variable(tf.zeros([256]))
# 256 => 128
w2, b2 = tf.Variable(tf.random.normal([256, 128], stddev=0.1)), tf.Variable(tf.zeros([128]))
# 128 => 10
w3, b3 = tf.Variable(tf.random.normal([128, 10], stddev=0.1)), tf.Variable(tf.zeros([10]))

 
 
Training loop

 Now we run the update loop: a for loop iterates over the train_db built above and updates w1, w2, w3, b1, b2 and b3.

for step, (x, y) in enumerate(train_db):
	...

 
 

 Everything below happens inside this for loop.

First, flatten the image tensor:

x = tf.reshape(x, (-1, 784))

 
 

Then build the network and compute the loss:

with tf.GradientTape() as tape:
    # layer1.
    h1 = x @ w1 + b1
    h1 = tf.nn.relu(h1)
    # layer2
    h2 = h1 @ w2 + b2
    h2 = tf.nn.relu(h2)
    # output
    out = h2 @ w3 + b3
    # compute loss
    # [b, 10] - [b, 10]
    loss = tf.square(y - out)
    # [b, 10] => scalar
    loss = tf.reduce_mean(loss)

 
 

 Compute the gradients (partial derivatives) with automatic differentiation and update the parameters.

The parameters are updated according to the rule

θ′ = θ − η · ∂L/∂θ

where η is the learning rate.

# compute the gradients
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
# update the parameters
for p, g in zip([w1, b1, w2, b2, w3, b3], grads):
    p.assign_sub(lr * g)

 
 

 Whenever step is divisible by 100, we print the loss and append it to the list, and we also evaluate the accuracy.

# print
if step % 100 == 0:
    print(step, 'loss:', float(loss))
    losses.append(float(loss))
if step % 100 == 0:
    ...

 
 

Next, let's go over what goes inside the second if.

First, define two variables for computing the accuracy:

total, total_correct = 0., 0

 
 

Then we iterate over the test set to compute the accuracy.

 We feed the test-set images through the current network. The output is a tensor of shape [b, 10], where each row holds the scores of the 10 classes for one sample, so we take the class with the largest score as the prediction.
tf.argmax returns the index of the maximum value, i.e. the most likely class of each sample:
 pred = tf.argmax(out, axis=1)
Since the labels y were one-hot encoded during preprocessing (which is not actually needed for testing), tf.argmax also recovers the integer-encoded labels y:
 y = tf.argmax(y, axis=1)
tf.equal compares the two results element-wise:
 correct = tf.equal(pred, y)
Summing all True values in the comparison (converted to 1) gives the number of correct predictions:
total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
Dividing the number of correct predictions by the total number of test samples gives the accuracy:
print(step, 'Evaluate Acc:', total_correct/total)

if step % 100 == 0:
    # evaluate/test
    total, total_correct = 0., 0
    # compute the accuracy
    for x, y in test_db:
        # layer1.
        h1 = x @ w1 + b1
        h1 = tf.nn.relu(h1)
        # layer2
        h2 = h1 @ w2 + b2
        h2 = tf.nn.relu(h2)
        # output
        out = h2 @ w3 + b3
        # [b, 10] => [b]
        pred = tf.argmax(out, axis=1)
        # convert one_hot y to number y
        y = tf.argmax(y, axis=1)
        # bool type
        correct = tf.equal(pred, y)
        # bool tensor => int tensor => numpy
        total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
        total += x.shape[0]
    print(step, 'Evaluate Acc:', total_correct / total)
    accs.append(total_correct / total)

 
 

 That completes the training loop.
 With this simple 3-layer network, after training for 20 epochs we obtain 87.25% accuracy on the test set. Using a more sophisticated network model, adding data augmentation, fine-tuning the hyperparameters and applying other tricks can yield even better performance.

Generate SVG figures

mp.figure()
x = [i * 80 for i in range(len(losses))]
mp.plot(x, losses, color='C0', marker='s', label='train')
mp.ylabel('MSE')
mp.xlabel('Step')
mp.legend()
mp.savefig('train.svg')
mp.figure()
mp.plot(x, accs, color='C1', marker='s', label='test')
mp.ylabel('Acc')
mp.xlabel('Step')
mp.legend()
mp.savefig('test.svg')

 
 

Complete code

from matplotlib import pyplot as mp
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers


def preprocess(x, y):
    """
    Preprocessing function
    """
    # [b, 28, 28], [b]
    print(x.shape, y.shape)
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [-1, 28 * 28])  # flatten the images
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


(x, y), (x_test, y_test) = datasets.mnist.load_data()  # load the MNIST images and labels
print('x:', x.shape, 'y:', y.shape, 'x test:', x_test.shape, 'y test:', y_test.shape)

batchsz = 512
train_db = tf.data.Dataset.from_tensor_slices((x, y))  # convert to a Dataset object
train_db = train_db.shuffle(1000)  # shuffle randomly
train_db = train_db.batch(batchsz)  # batch for training
train_db = train_db.map(preprocess)  # apply preprocessing
train_db = train_db.repeat(20)  # repeat the dataset 20 times (20 epochs)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(1000).batch(batchsz).map(preprocess)
x, y = next(iter(train_db))
print('train sample:', x.shape, y.shape)


def main():
    # learning rate
    lr = 1e-2
    accs, losses = [], []
    # 784 => 256
    w1, b1 = tf.Variable(tf.random.normal([784, 256], stddev=0.1)), tf.Variable(tf.zeros([256]))
    # 256 => 128
    w2, b2 = tf.Variable(tf.random.normal([256, 128], stddev=0.1)), tf.Variable(tf.zeros([128]))
    # 128 => 10
    w3, b3 = tf.Variable(tf.random.normal([128, 10], stddev=0.1)), tf.Variable(tf.zeros([10]))
    for step, (x, y) in enumerate(train_db):
        # [b, 28, 28] => [b, 784]
        x = tf.reshape(x, (-1, 784))
        with tf.GradientTape() as tape:
            # layer1
            h1 = x @ w1 + b1
            h1 = tf.nn.relu(h1)
            # layer2
            h2 = h1 @ w2 + b2
            h2 = tf.nn.relu(h2)
            # output
            out = h2 @ w3 + b3
            # compute loss
            # [b, 10] - [b, 10]
            loss = tf.square(y - out)
            # [b, 10] => scalar
            loss = tf.reduce_mean(loss)
        # compute the gradients
        grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
        # update the parameters
        for p, g in zip([w1, b1, w2, b2, w3, b3], grads):
            p.assign_sub(lr * g)

        # print
        if step % 100 == 0:
            print(step, 'loss:', float(loss))
            losses.append(float(loss))

        if step % 100 == 0:
            # evaluate/test
            total, total_correct = 0., 0
            # compute the accuracy
            for x, y in test_db:
                # layer1
                h1 = x @ w1 + b1
                h1 = tf.nn.relu(h1)
                # layer2
                h2 = h1 @ w2 + b2
                h2 = tf.nn.relu(h2)
                # output
                out = h2 @ w3 + b3
                # [b, 10] => [b]
                pred = tf.argmax(out, axis=1)
                # convert one-hot y to number y
                y = tf.argmax(y, axis=1)
                # bool type
                correct = tf.equal(pred, y)
                # bool tensor => int tensor => numpy
                total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
                total += x.shape[0]
            print(step, 'Evaluate Acc:', total_correct / total)
            accs.append(total_correct / total)

    mp.figure()
    x = [i * 80 for i in range(len(losses))]
    mp.plot(x, losses, color='C0', marker='s', label='train')
    mp.ylabel('MSE')
    mp.xlabel('Step')
    mp.legend()
    mp.savefig('train.svg')

    mp.figure()
    mp.plot(x, accs, color='C1', marker='s', label='test')
    mp.ylabel('Acc')
    mp.xlabel('Step')
    mp.legend()
    mp.savefig('test.svg')


if __name__ == '__main__':
    main()
