Day 3 of learning TensorFlow
Progress has been a little slow; I'll try to cover more today.
Broadcasting
import tensorflow as tf
import numpy as np
Broadcasting, also called the broadcast (or automatic expansion) mechanism, is a lightweight way of replicating tensors: it expands the shape of a tensor logically, and only performs an actual storage copy when one is truly needed. In most scenarios, the framework's optimizations let Broadcasting complete the logical computation without physically copying any data, so compared with the tf.tile function it saves a great deal of computation.
For every dimension of length 1, Broadcasting has the same effect as tf.tile: the data is logically replicated along that dimension. The difference is that tf.tile creates a new tensor, performs the copy I/O, and stores the replicated data, whereas Broadcasting does not copy anything immediately; it only changes the tensor's shape logically, so that the view looks like the replicated result. How the framework avoids the physical copy while completing the logical computation is an implementation detail the user need not care about; to the user, Broadcasting and tf.tile produce the same final result and the operation is transparent. Because Broadcasting saves a large amount of compute resources, it is recommended to rely on it wherever possible to improve efficiency.
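A minimal sketch of this equivalence: the broadcast view and the tf.tile copy produce identical values, even though only the latter physically replicates data.

```python
import tensorflow as tf

b = tf.constant([1., 2., 3.])                          # shape [3]
# Logical expansion: no data is copied until actually needed
b_broadcast = tf.broadcast_to(b, [2, 3])
# Physical copy: tf.tile allocates and writes a new tensor
b_tiled = tf.tile(tf.reshape(b, [1, 3]), [2, 1])
# The two results are numerically identical
print(bool(tf.reduce_all(tf.equal(b_broadcast, b_tiled))))  # True
```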
x = tf.random.normal([2,4])
w = tf.random.normal([4,3])
b = tf.random.normal([3])
x@w+b
<tf.Tensor: id=282, shape=(2, 3), dtype=float32, numpy=
array([[ 0.7105683 , -1.0936362 , -1.2084496 ],
[ 0.31603348, 0.09385422, -0.59150743]], dtype=float32)>
The addition above does not produce a logic error. How is it implemented? TensorFlow automatically calls the Broadcasting function tf.broadcast_to(x, new_shape) to expand both operands to the common shape [2,3]; that is, the expression above is equivalent to:
y = x@w + tf.broadcast_to(b,[2,3]) # expand manually, then add
Now that we have the Broadcasting mechanism, can all tensors with mismatched shapes be combined directly? Clearly not: every operation must still follow correct computation logic. Broadcasting never disturbs that logic; it only automates inserting dimensions and replicating data for the most common scenario, improving both development and runtime efficiency. What is that most common scenario? This is the core design idea of Broadcasting: shapes are aligned starting from the trailing (rightmost) dimension, and two dimensions are compatible when they are equal, or when one of them is 1 (or missing).
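The alignment rule can be sketched as a small plain-Python helper (the function name broadcast_shape is my own, not a TensorFlow API):

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, aligning dimensions from the right.

    Two dimensions are compatible when they are equal, or when one of them
    is 1 (or missing, which is treated as 1)."""
    result = []
    # Walk both shapes from the trailing dimension leftwards
    for i in range(max(len(shape_a), len(shape_b))):
        da = shape_a[-1 - i] if i < len(shape_a) else 1
        db = shape_b[-1 - i] if i < len(shape_b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {shape_a} and {shape_b} are not broadcastable")
        result.append(max(da, db))
    return list(reversed(result))

print(broadcast_shape([2, 3], [3]))              # [2, 3]  -- the x@w + b case above
print(broadcast_shape([32, 1], [2, 32, 32, 3]))  # [2, 32, 32, 3]
```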
The tf.broadcast_to(x, new_shape) function performs the expansion explicitly, growing the current shape into new_shape, as follows:
A = tf.random.normal([32,1])
print(A)
tf.Tensor(
[[-0.90817887]
 [-0.10580511]
 [-1.3813894 ]
 ...
 [ 0.63582766]
 [-0.01303059]
 [ 0.47545263]], shape=(32, 1), dtype=float32)
tf.broadcast_to(A,[2,32,32,3])
<tf.Tensor: id=290, shape=(2, 32, 32, 3), dtype=float32, numpy=
array([[[[-0.90817887, -0.90817887, -0.90817887],
         [-0.10580511, -0.10580511, -0.10580511],
         [-1.3813894 , -1.3813894 , -1.3813894 ],
         ...,
         [ 0.63582766,  0.63582766,  0.63582766],
         [-0.01303059, -0.01303059, -0.01303059],
         [ 0.47545263,  0.47545263,  0.47545263]],
        ...]]], dtype=float32)>
Not every shape can be broadcast to every target. If a dimension violates the compatibility rule, the expansion fails. For example, a tensor of shape [32,2] cannot be broadcast to [2,32,32,3]: aligned from the right, its last dimension has length 2, which is neither 3 nor 1.
A = tf.random.normal([32,2]) # broadcasting this to [2,32,32,3] would fail
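A small sketch of the failure case: attempting the expansion raises an error because the trailing dimensions are incompatible.

```python
import tensorflow as tf

A = tf.random.normal([32, 2])
try:
    # Trailing dim of A is 2, target trailing dim is 3: neither equal nor 1
    tf.broadcast_to(A, [2, 32, 32, 3])
except tf.errors.InvalidArgumentError as e:
    print("broadcast failed as expected")
```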
Math operations
Addition, subtraction, multiplication, and division are the most basic math operations, implemented by the tf.add, tf.subtract, tf.multiply, and tf.divide functions. TensorFlow overloads the +, -, *, and / operators, and using the operators directly is generally recommended.
Floor division and modulo are also common operations, implemented by the // and % operators. Let's demonstrate floor division, for example:
a = tf.range(5)
a
<tf.Tensor: id=301, shape=(5,), dtype=int32, numpy=array([0, 1, 2, 3, 4])>
b = tf.constant(2)
b
<tf.Tensor: id=302, shape=(), dtype=int32, numpy=2>
a//b # floor division
<tf.Tensor: id=303, shape=(5,), dtype=int32, numpy=array([0, 0, 1, 1, 2])>
a%b # modulo (remainder)
<tf.Tensor: id=304, shape=(5,), dtype=int32, numpy=array([0, 1, 0, 1, 0])>
Power
tf.pow(x, a) conveniently computes the power y = x^a; the ** operator implements the same x**a operation, as follows:
x = tf.range(4)
tf.pow(x,2)
<tf.Tensor: id=310, shape=(4,), dtype=int32, numpy=array([0, 1, 4, 9])>
tf.pow(x,3)
<tf.Tensor: id=312, shape=(4,), dtype=int32, numpy=array([ 0, 1, 8, 27])>
x**2
<tf.Tensor: id=314, shape=(4,), dtype=int32, numpy=array([0, 1, 4, 9])>
Square root
Setting the exponent to the fractional form 1/a computes the a-th root; in particular, x**0.5 computes the square root √x. For example:
x = tf.constant([1.,4.,9.])
x
<tf.Tensor: id=315, shape=(3,), dtype=float32, numpy=array([1., 4., 9.], dtype=float32)>
x**0.5 # a fractional exponent computes the square root
<tf.Tensor: id=317, shape=(3,), dtype=float32, numpy=array([1., 2., 3.], dtype=float32)>
In particular, the common square and square-root operations can be computed with tf.square(x) and tf.sqrt(x). The square operation is implemented as follows:
x = tf.range(5)
x = tf.cast(x,tf.float32)
x = tf.square(x) # element-wise square
x
<tf.Tensor: id=323, shape=(5,), dtype=float32, numpy=array([ 0., 1., 4., 9., 16.], dtype=float32)>
The square-root operation:
tf.sqrt(x)
<tf.Tensor: id=324, shape=(5,), dtype=float32, numpy=array([0., 1., 2., 3., 4.], dtype=float32)>
Exponent and logarithm
The exponential a^x can also be computed conveniently with tf.pow(a, x) or the ** operator, for example:
x = tf.constant([1.,2.,3.])
2**x # exponential with base 2
<tf.Tensor: id=327, shape=(3,), dtype=float32, numpy=array([2., 4., 8.], dtype=float32)>
In particular, the natural exponential e^x can be computed with tf.exp(x), for example:
tf.exp(1.)
<tf.Tensor: id=329, shape=(), dtype=float32, numpy=2.7182817>
In TensorFlow, the natural logarithm log_e x can be computed with tf.math.log(x), for example:
x = tf.exp(3.)
x
<tf.Tensor: id=331, shape=(), dtype=float32, numpy=20.085537>
tf.math.log(x)
<tf.Tensor: id=332, shape=(), dtype=float32, numpy=3.0>
To compute a logarithm with another base, use the change-of-base formula

log_a x = log_e x / log_e a

indirectly through tf.math.log(x). For example, log_10 x can be computed as log_e x / log_e 10, implemented as follows:
x = tf.constant([1.,2.])
x = 10**x
x
<tf.Tensor: id=335, shape=(2,), dtype=float32, numpy=array([ 10., 100.], dtype=float32)>
tf.math.log(x)/tf.math.log(10.)
<tf.Tensor: id=339, shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>
This is somewhat cumbersome to implement; perhaps TensorFlow will offer a log function with an arbitrary base in the future.
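The change-of-base trick can be wrapped in a small helper (the name log_base is my own, not a TensorFlow API):

```python
import tensorflow as tf

def log_base(x, base):
    # Change of base: log_a(x) = ln(x) / ln(a)
    return tf.math.log(x) / tf.math.log(tf.cast(base, x.dtype))

x = tf.constant([10., 100., 1000.])
print(log_base(x, 10.).numpy())   # approximately [1. 2. 3.]
```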
Matrix multiplication
Neural networks involve a large number of matrix multiplications. As introduced earlier, they can be implemented conveniently with the @ operator, or with the tf.matmul(a, b) function. Note that matrix multiplication in TensorFlow supports batching: the tensors a and b may have more than 2 dimensions. In that case, TensorFlow multiplies the last two dimensions of a and b as matrices and treats all leading dimensions as batch dimensions.
By the definition of matrix multiplication, a and b can be multiplied only if the length of a's last dimension (columns) equals the length of b's second-to-last dimension (rows). For example:
a = tf.random.normal([4,3,28,32])
b = tf.random.normal([4,3,32,2])
a@b # the result shape is [4,3,28,2]
<tf.Tensor: id=352, shape=(4, 3, 28, 2), dtype=float32, numpy=
array([[[[-7.35481071e+00, -2.03977299e+00],
         [-3.53628492e+00,  5.42532158e+00],
         [ 6.02580070e+00,  1.65104294e+01],
         ...,
         [-3.37379670e+00, -7.84966469e-01],
         [ 1.41078424e+00, -1.48273158e+00],
         [-5.27147245e+00,  1.78152790e+01]],
        ...]]], dtype=float32)>
The matrix multiplication function also supports automatic Broadcasting, for example:
a = tf.random.normal([4,28,32])
b = tf.random.normal([32,16])
a@b # b is automatically expanded here, then batch-multiplied
<tf.Tensor: id=365, shape=(4, 28, 16), dtype=float32, numpy=
array([[[ -2.1727865 ,  -1.1229348 ,   0.5529645 , ...,   6.850294  ,
          -5.080549  ,   1.3870585 ],
        [ -0.7776412 ,   3.3607047 ,  -1.3229356 , ...,  -0.33060586,
          -1.718039  , -12.170118  ],
        ...,
        [  4.985075  ,  -0.63191354,  -4.784621  , ...,  -1.2277026 ,
           3.8447945 ,  -4.181891  ]],
       ...]], dtype=float32)>
The operation above automatically expands the variable b to the common shape [4,32,16], then performs a batched matrix multiplication with the variable a, producing a result of shape [4,28,16].
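To check this equivalence by hand, the automatic expansion can be compared against an explicit tf.broadcast_to (a small sketch):

```python
import tensorflow as tf

a = tf.random.normal([4, 28, 32])
b = tf.random.normal([32, 16])

auto = a @ b                                    # b is broadcast automatically
manual = a @ tf.broadcast_to(b, [4, 32, 16])    # explicit expansion

print(auto.shape)                               # (4, 28, 16)
print(bool(tf.reduce_all(tf.abs(auto - manual) < 1e-3)))
```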
Forward propagation in practice
So far we have covered creating tensors, indexing and slicing, dimension transformations, and common math operations. Now we put that knowledge to use and implement a three-layer neural network.
The dataset is the MNIST handwritten digit images. The number of input nodes is 784; the first layer has 256 output nodes, the second layer has 128, and the third layer has 10, i.e. the probabilities of the current sample belonging to each of the 10 classes.
First create the W and b tensor parameters for each nonlinear layer, as follows:
# Every layer's tensors need to be optimized, so use the Variable type; initialize the weights with a truncated normal distribution
# The bias vectors can simply be initialized to 0
# First-layer parameters
w1 = tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1)) # mean is the mean and stddev the standard deviation of the underlying normal distribution
w1
<tf.Variable 'Variable:0' shape=(784, 256) dtype=float32, numpy=
array([[ 0.05604904, 0.1754755 , -0.15889499, ..., -0.13000603,
0.12375573, 0.11799996],
[ 0.01727579, -0.02990366, 0.18539117, ..., 0.10088935,
0.0123363 , 0.07507049],
[-0.03101482, 0.06787052, -0.17878658, ..., 0.05300893,
0.03232927, 0.14271668],
...,
[ 0.00183942, -0.12539563, -0.00482116, ..., 0.00135334,
-0.08737514, -0.01430448],
[-0.08277363, -0.04015374, -0.06549906, ..., 0.08081874,
-0.00835462, 0.15705827],
[-0.10950299, -0.15212834, 0.15279026, ..., -0.16317515,
-0.17880735, 0.12734374]], dtype=float32)>
b1 = tf.Variable(tf.zeros([256]))
# Second-layer parameters
w2 = tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
# Third-layer parameters
w3 = tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))
For the forward computation, first change the view of the input tensor from shape [b,28,28] to [b,784], i.e. flatten each image matrix into a feature vector, which is the input format the network expects.
# change the view: [b,28,28] -> [b,28*28]
from tensorflow import keras
from tensorflow.keras import layers,optimizers,datasets
(x,y),_ = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
print(x.shape, y.shape, x.dtype, y.dtype)
print(tf.reduce_min(x), tf.reduce_max(x))
print(tf.reduce_min(y), tf.reduce_max(y))
(60000, 28, 28) (60000,) <dtype: 'float32'> <dtype: 'int32'>
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32)
tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
x = tf.reshape(x,[-1,28*28])
x
<tf.Tensor: id=616, shape=(60000, 784), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>
With step one complete, we now perform the automatic expansion explicitly:
# 第一层计算,[b, 784]@[784, 256] + [256] => [b, 256] + [256] => [b,256] + [b, 256]
h1 = x@w1+tf.broadcast_to(b1,[x.shape[0],256])
h1 = tf.nn.relu(h1)
h2 = h1@w2+b2
h2 = tf.nn.relu(h2)
out = h2@w3+b3
y
<tf.Tensor: id=606, shape=(60000,), dtype=int32, numpy=array([5, 0, 4, ..., 5, 6, 8])>
out
<tf.Tensor: id=648, shape=(60000, 10), dtype=float32, numpy=
array([[ 0.33768478, 0.7924377 , -0.26547465, ..., 0.7518244 ,
1.0822353 , 0.27058533],
[ 0.50584865, 0.07583661, -0.95454806, ..., 0.8536314 ,
0.5347482 , 0.01661541],
[-0.39949498, 0.5179445 , -0.06265643, ..., 0.56090957,
1.1628726 , 1.0270436 ],
...,
[-0.03918545, -0.03424276, -0.3795507 , ..., 0.64348745,
1.0213664 , -0.242055 ],
[ 0.1548642 , -0.415146 , -0.50202966, ..., 0.6150216 ,
1.1578825 , 0.27539507],
[ 0.16750555, -0.47161245, -0.13367139, ..., 0.7650149 ,
0.5568352 , 0.10468676]], dtype=float32)>
y = tf.one_hot(y,depth=10)
y
<tf.Tensor: id=652, shape=(60000, 10), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.]], dtype=float32)>
# compute the mean squared error
#[b,10]
loss = tf.square(y-out)
# reduce the error to a scalar
loss = tf.reduce_mean(loss)
loss
<tf.Tensor: id=662, shape=(), dtype=float32, numpy=0.44232228>
The forward computation above needs to be wrapped in a with tf.GradientTape() as tape context, so that the computation-graph information is recorded during the forward pass for automatic differentiation.
The tape.gradient() function computes the gradients of the loss with respect to the network parameters, saving the results in the grads list, as follows:
# automatic gradients; the tensors to differentiate are [w1, b1, w2, b2, w3, b3]
grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-292-ef23df5cadc0> in <module>
----> 1 grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])
NameError: name 'tape' is not defined
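The NameError above occurs because the forward pass was never actually recorded inside a tape context. A minimal self-contained sketch of the corrected flow, using a small random stand-in batch instead of the real MNIST data:

```python
import tensorflow as tf

# Parameters, as above: truncated-normal weights, zero biases
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

# Small stand-in batch (hypothetical data, just for the sketch)
x = tf.random.normal([8, 784])
y = tf.one_hot(tf.constant([0, 1, 2, 3, 4, 5, 6, 7]), depth=10)

# The forward pass runs inside the tape so its ops are recorded
with tf.GradientTape() as tape:
    h1 = tf.nn.relu(x @ w1 + b1)
    h2 = tf.nn.relu(h1 @ w2 + b2)
    out = h2 @ w3 + b3
    loss = tf.reduce_mean(tf.square(y - out))

# Now tape is defined, and the gradients can be computed
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
print([tuple(g.shape) for g in grads])
```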