Learning TensorFlow 2 (3)

Day 3 of learning TensorFlow.

Progress has been a little slow, so today I'll try to cover more ground.

Broadcasting

import tensorflow as tf
import numpy as np

Broadcasting (also called automatic expansion) is a lightweight tensor-replication mechanism: it expands the shape of a tensor logically, and performs the actual storage copy only when it is truly needed. In most cases the framework's optimizations let a broadcasted operation complete without physically copying any data, so compared with the tf.tile function it saves a great deal of computation.

For every dimension of length 1, Broadcasting has the same effect as tf.tile: the data is logically replicated along that dimension. The difference is that tf.tile creates a new tensor, performs the copy I/O, and stores the replicated data, whereas Broadcasting does not copy anything immediately; it only changes the tensor's logical shape, so the tensor merely appears to have been replicated. The framework's optimizations then complete the computation without an actual copy; how this is implemented need not concern the user. To the user, Broadcasting and tf.tile produce the same final result, and the mechanism is transparent, but Broadcasting saves a large amount of compute, so it is recommended to exploit it wherever possible to improve efficiency.
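To make the equivalence concrete, here is a minimal sketch (shapes chosen arbitrarily for illustration) comparing tf.tile with implicit broadcasting:

```python
import tensorflow as tf

b = tf.constant([1., 2., 3.])  # shape [3]
x = tf.zeros([2, 3])           # shape [2, 3]

# tf.tile materializes the copies: [3] -> [1, 3] -> [2, 3]
tiled = tf.tile(tf.reshape(b, [1, 3]), [2, 1])

# Broadcasting performs the same logical expansion without an upfront copy
broadcast_sum = x + b

# Both routes produce identical results
print(bool(tf.reduce_all(x + tiled == broadcast_sum)))  # True
```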

x = tf.random.normal([2,4])
w = tf.random.normal([4,3])
b = tf.random.normal([3])
x@w+b  # b of shape [3] is broadcast to [2,3] before the addition
<tf.Tensor: id=282, shape=(2, 3), dtype=float32, numpy=
array([[ 0.7105683 , -1.0936362 , -1.2084496 ],
       [ 0.31603348,  0.09385422, -0.59150743]], dtype=float32)>

The addition above did not raise a shape error. How is that possible? TensorFlow automatically calls the Broadcasting function tf.broadcast_to(x, new_shape) to expand b to the common shape [2,3], so the expression is equivalent to:
y = x@w + tf.broadcast_to(b,[2,3])  # expand manually, then add

Now that we have Broadcasting, can any two tensors with mismatched shapes be combined directly? Obviously not: every operation must still be logically valid. Broadcasting never distorts the normal computation logic; it only automatically inserts dimensions and replicates data for the most common scenario, improving both development and runtime efficiency. What is that most common scenario? That brings us to the core design idea behind Broadcasting.
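That core idea is the right-aligned shape rule: shapes are compared from the last dimension backwards, and two dimensions are compatible if their lengths are equal or one of them is 1 (missing leading dimensions count as 1). A small pure-Python sketch of this check (broadcast_shape is a name made up here, not a TensorFlow function):

```python
def broadcast_shape(shape_a, shape_b):
    """Compute the broadcast result of two shapes, aligning from the right."""
    n = max(len(shape_a), len(shape_b))
    a = (1,) * (n - len(shape_a)) + tuple(shape_a)  # pad with leading 1s
    b = (1,) * (n - len(shape_b)) + tuple(shape_b)
    out = []
    for da, db in zip(a, b):
        if da == db or da == 1 or db == 1:
            out.append(max(da, db))                 # length-1 dims expand
        else:
            raise ValueError(f"incompatible dims {da} and {db}")
    return tuple(out)

print(broadcast_shape((32, 1), (2, 32, 32, 3)))  # (2, 32, 32, 3)
print(broadcast_shape((2, 4), (4,)))             # (2, 4)
```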

The tf.broadcast_to(x, new_shape) function performs the automatic expansion explicitly, expanding the current shape to new_shape. For example:

A = tf.random.normal([32,1])
print(A)
tf.Tensor(
[[-0.90817887]
 [-0.10580511]
 [-1.3813894 ]
 [ 0.2078821 ]
 ...
 [ 0.63582766]
 [-0.01303059]
 [ 0.47545263]], shape=(32, 1), dtype=float32)
tf.broadcast_to(A,[2,32,32,3])
<tf.Tensor: id=290, shape=(2, 32, 32, 3), dtype=float32, numpy=
array([[[[-0.90817887, -0.90817887, -0.90817887],
         [-0.10580511, -0.10580511, -0.10580511],
         [-1.3813894 , -1.3813894 , -1.3813894 ],
         ...,
         [ 0.63582766,  0.63582766,  0.63582766],
         [-0.01303059, -0.01303059, -0.01303059],
         [ 0.47545263,  0.47545263,  0.47545263]],

        ...,

        [[-0.90817887, -0.90817887, -0.90817887],
         ...,
         [ 0.47545263,  0.47545263,  0.47545263]]],


       [[[-0.90817887, -0.90817887, -0.90817887],
         ...,
         [ 0.47545263,  0.47545263,  0.47545263]],

        ...,

        [[-0.90817887, -0.90817887, -0.90817887],
         ...,
         [ 0.47545263,  0.47545263,  0.47545263]]]], dtype=float32)>

Not every shape can be expanded this way. If a dimension is neither of length 1 nor equal to the target length, broadcasting is invalid. For example, a tensor of shape [32,2] cannot be broadcast to [2,32,32,3], because its last dimension (2) matches neither 1 nor 3:

    A = tf.random.normal([32,2])
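Attempting the expansion confirms this; TensorFlow raises an error in eager mode (the exact message wording may vary by version):

```python
import tensorflow as tf

A = tf.random.normal([32, 2])
try:
    tf.broadcast_to(A, [2, 32, 32, 3])  # last dim 2 matches neither 1 nor 3
except tf.errors.InvalidArgumentError as e:
    print("broadcast failed:", type(e).__name__)
```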

Mathematical operations

Addition, subtraction, multiplication, and division are the most basic operations, implemented by the tf.add, tf.subtract, tf.multiply, and tf.divide functions respectively. TensorFlow has already overloaded the +, -, * and / operators, and using the operators directly is generally recommended.
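A quick check (values chosen arbitrarily) that the operator and function forms agree:

```python
import tensorflow as tf

a = tf.constant([4., 9.])
b = tf.constant([2., 3.])

print(bool(tf.reduce_all(a + b == tf.add(a, b))))       # True
print(bool(tf.reduce_all(a - b == tf.subtract(a, b))))  # True
print(bool(tf.reduce_all(a * b == tf.multiply(a, b))))  # True
print(bool(tf.reduce_all(a / b == tf.divide(a, b))))    # True
```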

Floor division and modulo are also common operations, implemented by the // and % operators respectively. Floor division, for example:

a = tf.range(5)
a
<tf.Tensor: id=301, shape=(5,), dtype=int32, numpy=array([0, 1, 2, 3, 4])>
b = tf.constant(2)
b
<tf.Tensor: id=302, shape=(), dtype=int32, numpy=2>
a//b  # floor division
<tf.Tensor: id=303, shape=(5,), dtype=int32, numpy=array([0, 0, 1, 1, 2])>
a%b   # modulo
<tf.Tensor: id=304, shape=(5,), dtype=int32, numpy=array([0, 1, 0, 1, 0])>

Powers

tf.pow(x, a) conveniently computes the power y = x^a; the ** operator implements x**a as well. For example:

x = tf.range(4)
tf.pow(x,2)
<tf.Tensor: id=310, shape=(4,), dtype=int32, numpy=array([0, 1, 4, 9])>
tf.pow(x,3)
<tf.Tensor: id=312, shape=(4,), dtype=int32, numpy=array([ 0,  1,  8, 27])>
x**2
<tf.Tensor: id=314, shape=(4,), dtype=int32, numpy=array([0, 1, 4, 9])>

Square root

Setting the exponent to 1/a yields the a-th root, x**(1/a); in particular x**0.5 is the square root. For example:

x = tf.constant([1.,4.,9.])
x
<tf.Tensor: id=315, shape=(3,), dtype=float32, numpy=array([1., 4., 9.], dtype=float32)>
x**0.5  # a fractional exponent gives the square root
<tf.Tensor: id=317, shape=(3,), dtype=float32, numpy=array([1., 2., 3.], dtype=float32)>

In particular, for the common square and square-root operations, tf.square(x) and tf.sqrt(x) can be used. Squaring is implemented as follows:

x = tf.range(5)
x = tf.cast(x,tf.float32)
x = tf.square(x)  # element-wise square
x
<tf.Tensor: id=323, shape=(5,), dtype=float32, numpy=array([ 0.,  1.,  4.,  9., 16.], dtype=float32)>

Taking the square root:

tf.sqrt(x)
<tf.Tensor: id=324, shape=(5,), dtype=float32, numpy=array([0., 1., 2., 3., 4.], dtype=float32)>

Exponentials and logarithms

tf.pow(a, x) or the ** operator also conveniently implements the exponential a^x, for example:

x = tf.constant([1.,2.,3.])
2**x  # exponentiation
<tf.Tensor: id=327, shape=(3,), dtype=float32, numpy=array([2., 4., 8.], dtype=float32)>

In particular, the natural exponential e^x is provided by tf.exp(x), for example:

tf.exp(1.)
<tf.Tensor: id=329, shape=(), dtype=float32, numpy=2.7182817>

In TensorFlow, the natural logarithm ln x can be computed with tf.math.log(x), for example:

x = tf.exp(3.)
x
<tf.Tensor: id=331, shape=(), dtype=float32, numpy=20.085537>
tf.math.log(x)
<tf.Tensor: id=332, shape=(), dtype=float32, numpy=3.0>

To compute a logarithm with another base, use the change-of-base formula

    log_a x = log_e x / log_e a

indirectly via tf.math.log(x). For example, log_10 x can be computed as tf.math.log(x)/tf.math.log(10.):

x = tf.constant([1.,2.])
x = 10**x
x
<tf.Tensor: id=335, shape=(2,), dtype=float32, numpy=array([ 10., 100.], dtype=float32)>
tf.math.log(x)/tf.math.log(10.)
<tf.Tensor: id=339, shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>

This is relatively cumbersome; perhaps TensorFlow will add a log function with an arbitrary base in the future.
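Until then, the change-of-base formula is easy to wrap in a small helper (log_b is a name made up for this sketch, not a TensorFlow API):

```python
import tensorflow as tf

def log_b(x, base):
    """log_base(x) via the change-of-base formula: ln(x) / ln(base)."""
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    return tf.math.log(x) / tf.math.log(tf.cast(base, tf.float32))

print(log_b([10., 100., 1000.], 10))  # approximately [1., 2., 3.]
print(log_b(8., 2))                   # approximately 3.
```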

Matrix multiplication

Neural networks involve a large number of matrix multiplications. We have already seen that the @ operator performs matrix multiplication conveniently; the same operation is also available as tf.matmul(a, b). Note that matrix multiplication in TensorFlow supports batching: the tensors a and b may have more than two dimensions. In that case TensorFlow multiplies the last two dimensions of a and b as matrices and treats all leading dimensions as batch dimensions.

By the definition of matrix multiplication, a and b can be multiplied only if the length of a's last dimension (its columns) equals the length of b's second-to-last dimension (its rows). For example:

a = tf.random.normal([4,3,28,32])
b = tf.random.normal([4,3,32,2])
a@b  # the result's shape is [4, 3, 28, 2]
<tf.Tensor: id=352, shape=(4, 3, 28, 2), dtype=float32, numpy=
array([[[[-7.35481071e+00, -2.03977299e+00],
         [-3.53628492e+00,  5.42532158e+00],
         ...,
         [-5.27147245e+00,  1.78152790e+01]],

        ...,

        [[-2.50831032e+00,  5.23478985e+00],
         ...,
         [ 9.96023178e-01, -5.94730377e+00]]],

       ...,

       [[[ 8.55044174e+00, -5.81358528e+00],
         ...,
         [-2.88917279e+00, -5.81262493e+00]],

        ...,

        [[ 1.86562717e+00,  2.10286880e+00],
         ...,
         [-3.08059907e+00,  2.53280663e+00]]]], dtype=float32)>

The matrix multiplication function also supports automatic Broadcasting, for example:

a = tf.random.normal([4,28,32])
b = tf.random.normal([32,16])
a@b   # b is broadcast automatically, then batch-multiplied
<tf.Tensor: id=365, shape=(4, 28, 16), dtype=float32, numpy=
array([[[ -2.1727865 ,  -1.1229348 ,   0.5529645 , ...,   6.850294  ,
          -5.080549  ,   1.3870585 ],
        [ -0.7776412 ,   3.3607047 ,  -1.3229356 , ...,  -0.33060586,
          -1.718039  , -12.170118  ],
        [ -1.6537416 ,   3.1806996 ,  -4.479284  , ...,  -3.0433881 ,
           0.23675452,  -5.6958113 ],
        ...,
        [  3.659226  ,  -1.4129915 ,  -9.812353  , ...,  -0.89915407,
          10.7898855 ,   2.9662404 ],
        [  4.895536  ,  -2.4533572 ,   2.7401347 , ...,   0.86894536,
         -10.848548  ,   1.3184551 ],
        [  4.985075  ,  -0.63191354,  -4.784621  , ...,  -1.2277026 ,
           3.8447945 ,  -4.181891  ]],

       [[  0.51406497,   5.8773036 ,   0.7100219 , ...,  -3.0020056 ,
          -0.21059486,   2.430263  ],
        [ -0.27355823,   6.9391675 ,  -2.6067553 , ...,   3.5078154 ,
           4.6054435 ,  -1.3159671 ],
        [  0.23995495,  -1.2228749 ,  -1.578268  , ...,  11.343817  ,
           2.5552132 ,   6.416268  ],
        ...,
        [  0.7496167 ,   2.6149487 ,  -3.3291793 , ...,  -1.6913137 ,
          -4.6849575 ,  -0.07390194],
        [  3.2651026 , -10.189845  ,  -3.1388736 , ...,   3.340275  ,
          -3.5930347 ,   5.0638633 ],
        [ 14.000745  ,  -7.6722584 ,  -4.071188  , ...,  -2.4604309 ,
          -2.6787014 ,   4.1632576 ]],

       [[ -2.8267236 ,  -1.8276253 ,   6.803973  , ...,   6.998722  ,
          -0.13112731,   1.3310764 ],
        [  2.563221  ,  -1.0456358 ,   4.9088545 , ...,  10.572786  ,
          -1.4405905 ,   2.2257853 ],
        [  3.8990169 ,  -1.4742144 ,   2.1636624 , ...,  -3.6045642 ,
          -1.984739  ,  -6.500014  ],
        ...,
        [ -5.0625362 ,  -5.6125216 ,   5.606078  , ...,  -8.171687  ,
          -1.0461836 ,  -0.2186041 ],
        [  0.7595123 ,  -4.489399  ,   7.729659  , ...,  -0.26062453,
           2.0305562 ,  12.910347  ],
        [  0.5906749 ,  -9.632759  ,  -1.6593897 , ...,   2.9834595 ,
          -0.27359906,   3.4350774 ]],

       [[ -1.3969722 ,  -4.006015  ,  -0.78274345, ...,  -1.9209927 ,
           2.7273333 ,   3.325152  ],
        [  2.0444467 ,  -2.7864232 ,   2.97246   , ...,   3.4784262 ,
          -3.8851411 ,   7.043869  ],
        [  4.168339  ,  -3.8495026 ,   2.021028  , ...,  -0.03035581,
          -9.135693  ,   0.9983522 ],
        ...,
        [  5.2611065 ,  -1.1056949 ,   5.6409683 , ...,   6.9540353 ,
           2.3831298 ,   6.377899  ],
        [ 11.992814  ,   2.197509  ,   4.5876446 , ...,  -2.6564171 ,
          -6.70335   ,  12.283043  ],
        [ 13.1055355 ,  -8.137001  ,   5.915893  , ...,  -0.98460674,
          -2.5176544 ,  -1.723175  ]]], dtype=float32)>

In the computation above, b is automatically expanded to the common shape [4,32,16] and then batch-multiplied with a, giving a result of shape [4,28,16].
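The same result can be reproduced by expanding b by hand, which is a useful way to confirm what the batch rule does (shapes as in the example above):

```python
import tensorflow as tf

a = tf.random.normal([4, 28, 32])
b = tf.random.normal([32, 16])

auto = a @ b                                  # b is broadcast to [4, 32, 16]
manual = a @ tf.broadcast_to(b, [4, 32, 16])  # explicit expansion

print(auto.shape)  # (4, 28, 16)
print(bool(tf.reduce_all(tf.abs(auto - manual) < 1e-3)))
```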

Forward propagation in practice

So far we have covered creating tensors, indexing and slicing, reshaping, and the common mathematical operations. Let us now use what we have learned to implement a three-layer neural network.

The dataset is the MNIST handwritten-digit images. The number of input nodes is 784; the first layer outputs 256 nodes, the second layer 128, and the third layer 10, i.e. the scores over the 10 classes the current sample may belong to.
First create the W and b parameter tensors for each layer, as follows:

# Each layer's tensors need to be optimized, so use the Variable type; the weights
# are initialized from a truncated normal distribution, and the biases to zero
# first-layer parameters
w1 = tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1))  # mean is the distribution's mean, stddev its standard deviation
w1
<tf.Variable 'Variable:0' shape=(784, 256) dtype=float32, numpy=
array([[ 0.05604904,  0.1754755 , -0.15889499, ..., -0.13000603,
         0.12375573,  0.11799996],
       [ 0.01727579, -0.02990366,  0.18539117, ...,  0.10088935,
         0.0123363 ,  0.07507049],
       [-0.03101482,  0.06787052, -0.17878658, ...,  0.05300893,
         0.03232927,  0.14271668],
       ...,
       [ 0.00183942, -0.12539563, -0.00482116, ...,  0.00135334,
        -0.08737514, -0.01430448],
       [-0.08277363, -0.04015374, -0.06549906, ...,  0.08081874,
        -0.00835462,  0.15705827],
       [-0.10950299, -0.15212834,  0.15279026, ..., -0.16317515,
        -0.17880735,  0.12734374]], dtype=float32)>
b1 = tf.Variable(tf.zeros([256]))
# second-layer parameters
w2 = tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
# third-layer parameters
w3 = tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

In the forward pass, the input tensor of shape [b,28,28] is first reshaped into a view of shape [b,784], flattening each image matrix into a feature vector to match the network's input format.

First load the data, convert it to tensors, and normalize it:

# change the view: [b,28,28] -> [b,28*28]
from tensorflow import keras
from tensorflow.keras import layers,optimizers,datasets
(x,y),_ = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
print(x.shape, y.shape, x.dtype, y.dtype)
print(tf.reduce_min(x), tf.reduce_max(x))
print(tf.reduce_min(y), tf.reduce_max(y))
(60000, 28, 28) (60000,) <dtype: 'float32'> <dtype: 'int32'>
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32)
tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
x = tf.reshape(x,[-1,28*28])
x
<tf.Tensor: id=616, shape=(60000, 784), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>

With the data ready, perform the first layer's expansion explicitly:

# 第一层计算,[b, 784]@[784, 256] + [256] => [b, 256] + [256] => [b,256] + [b, 256]
h1 = x@w1+tf.broadcast_to(b1,[x.shape[0],256])
h1 = tf.nn.relu(h1)
h2 = h1@w2+b2
h2 = tf.nn.relu(h2)
out = h2@w3+b3
y
<tf.Tensor: id=606, shape=(60000,), dtype=int32, numpy=array([5, 0, 4, ..., 5, 6, 8])>
out
<tf.Tensor: id=648, shape=(60000, 10), dtype=float32, numpy=
array([[ 0.33768478,  0.7924377 , -0.26547465, ...,  0.7518244 ,
         1.0822353 ,  0.27058533],
       [ 0.50584865,  0.07583661, -0.95454806, ...,  0.8536314 ,
         0.5347482 ,  0.01661541],
       [-0.39949498,  0.5179445 , -0.06265643, ...,  0.56090957,
         1.1628726 ,  1.0270436 ],
       ...,
       [-0.03918545, -0.03424276, -0.3795507 , ...,  0.64348745,
         1.0213664 , -0.242055  ],
       [ 0.1548642 , -0.415146  , -0.50202966, ...,  0.6150216 ,
         1.1578825 ,  0.27539507],
       [ 0.16750555, -0.47161245, -0.13367139, ...,  0.7650149 ,
         0.5568352 ,  0.10468676]], dtype=float32)>
y = tf.one_hot(y,depth=10)
y
<tf.Tensor: id=652, shape=(60000, 10), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.]], dtype=float32)>
# compute the per-entry squared error
# [b,10]
loss = tf.square(y-out)
# reduce to the scalar mean squared error
loss = tf.reduce_mean(loss)
loss
<tf.Tensor: id=662, shape=(), dtype=float32, numpy=0.44232228>

All of the forward computation above needs to be wrapped in a with tf.GradientTape() as tape context, so that the computation-graph information is recorded during the forward pass for automatic differentiation.

The gradients of the loss with respect to the network parameters are obtained with the tape.gradient() function and stored in the grads list:

# automatic gradients; the tensors to differentiate are [w1, b1, w2, b2, w3, b3]
grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])
---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

<ipython-input-292-ef23df5cadc0> in <module>
----> 1 grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])


NameError: name 'tape' is not defined
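The NameError above is expected: tape exists only inside a with tf.GradientTape() as tape: block, and the forward pass earlier was never recorded. A minimal self-contained sketch of the corrected flow (random stand-in data here instead of the full MNIST batch, just to keep it light):

```python
import tensorflow as tf

# Toy stand-ins for one MNIST batch, so the sketch runs on its own
x = tf.random.normal([8, 784])
y = tf.one_hot(tf.random.uniform([8], maxval=10, dtype=tf.int32), depth=10)

w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

with tf.GradientTape() as tape:     # record the forward pass
    h1 = tf.nn.relu(x @ w1 + b1)    # [8, 784] -> [8, 256]
    h2 = tf.nn.relu(h1 @ w2 + b2)   # [8, 256] -> [8, 128]
    out = h2 @ w3 + b3              # [8, 128] -> [8, 10]
    loss = tf.reduce_mean(tf.square(y - out))

# gradients of the loss w.r.t. every trainable parameter
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
print([tuple(g.shape) for g in grads])
# [(784, 256), (256,), (256, 128), (128,), (128, 10), (10,)]
```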
