TensorFlow2.0 入门笔记(2)

创建自己的网络层

本笔记使用 Keras,因为 TensorFlow 2.0 建议使用 tf.keras 作为构建神经网络的高级 API。

网络层layer的常见操作

tf.keras.layers 官方文档(见 TensorFlow 官网)。

tf.keras.layers.Dense(
    units, activation=None, use_bias=True, kernel_initializer='glorot_uniform',
    bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None,
    activity_regularizer=None, kernel_constraint=None, bias_constraint=None,
    **kwargs
)

这里的dense函数实行的操作: output = activation(dot(input, kernel) + bias) where activation is the element-wise activation function passed as the activation argument, kernel is a weights matrix created by the layer, and bias is a bias vector created by the layer (only applicable if use_bias is True).

  • 输入:N-D tensor with shape: (batch_size, …, input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).
  • 输出:N-D tensor with shape: (batch_size, …, units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).

我们要构造一个简单的全连接网络,只需要指定网络的神经元个数。
下面展示 构造一个简单的全连接网络

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

# Equivalent form without declaring an input shape:
# layer = tf.keras.layers.Dense(100)
# None stands for an unknown number of samples (batch size); 20 is the number of input features per sample.
# NOTE(review): the declared 20 input features conflict with the 6-feature input fed below.
# Keras builds the layer from the actual call input (the printed kernel shape is (6, 100)) — confirm intent.
layer = tf.keras.layers.Dense(100, input_shape=(None, 20))

layer(tf.ones([6, 6]))#feed a 6*6 matrix: 6 samples with 6 features each. Every layer acts like a function applied to the input data.
#print(layer.variables)# prints all parameters of the layer, including the weights and the biases
print(layer.kernel, layer.bias)  # the weights and the biases can also be retrieved separately

结果如下:
[<tf.Variable ‘dense_4/kernel:0’ shape=(6, 100) dtype=float32, numpy=
array([[-1.17631599e-01, 1.48546174e-01, -6.12933934e-02,
1.95379272e-01, -7.78010637e-02, -7.37012029e-02,
-1.51678920e-03, -1.56111956e-01, 7.11926073e-02,
-1.37753904e-01, -1.92362726e-01, -5.62874377e-02,
-1.66875511e-01, 8.67448300e-02, 1.13398716e-01,
-1.14426605e-01, -1.22155465e-01, -9.70358700e-02,
-2.28686333e-02, 1.98045239e-01, 1.30869493e-01,
1.21884897e-01, -6.00987375e-02, 2.22415999e-01,
2.11583361e-01, 1.22698352e-01, -1.50723904e-01,
1.73112080e-01, 2.26724848e-01, 1.40978858e-01,
7.25988895e-02, 1.46507517e-01, -1.78162634e-01,
2.23117918e-02, 2.72955149e-02, 1.24480233e-01,
2.37727836e-01, -4.67860848e-02, 1.87744349e-02,
-3.76515687e-02, 5.05203158e-02, -9.94962007e-02,
2.07418337e-01, -6.09074980e-02, 8.31539482e-02,
-2.31687725e-03, -1.74317241e-01, 1.88781783e-01,
-1.37958512e-01, -6.88468069e-02, 2.01815292e-01,
2.16439292e-01, -1.49591118e-02, -1.38965636e-01,
1.86278060e-01, -5.31143248e-02, 9.71562415e-02,
1.87845662e-01, -2.28760928e-01, 1.92107365e-01,
5.91409057e-02, 2.18419090e-01, -2.35403761e-01,
-2.24688232e-01, -1.04708180e-01, -2.00229019e-01,
1.43962130e-01, -1.80692941e-01, 7.19535500e-02,
-3.52407545e-02, -1.24557026e-01, 1.64803788e-01,
-2.35195473e-01, 1.82727382e-01, 3.48411351e-02,
1.18214831e-01, -1.94926798e-01, -2.27142259e-01,
4.94295806e-02, 1.55736655e-02, -2.07010657e-01,
9.34260041e-02, 2.25032255e-01, 2.03277275e-01,
1.02041051e-01, -7.37348348e-02, -1.99376807e-01,
-1.26140118e-01, -7.84487873e-02, -5.88199496e-03,
-3.98484021e-02, 2.16260597e-01, 1.44064292e-01,
1.22691885e-01, -2.28717595e-01, 3.54599506e-02,
1.39095470e-01, -1.82361931e-01, 8.97107273e-02,
-1.13304555e-02],
[-1.09308854e-01, -1.68572620e-01, -2.31104180e-01,
2.16986045e-01, -1.83715373e-02, 1.62551776e-01,
-2.06275284e-01, -1.59556091e-01, -1.40683502e-01,
1.41831085e-01, -8.34729522e-02, 2.91769207e-03,
1.75670877e-01, -5.32847196e-02, 6.74460083e-02,
-2.28857294e-01, -8.12378228e-02, 1.99883416e-01,
-1.66821688e-01, -1.78879052e-01, -1.88803554e-01,
9.88317877e-02, 2.50957757e-02, 1.48942322e-02,
2.85393000e-03, 1.02169231e-01, -1.58371240e-01,
2.02655211e-01, 1.69222221e-01, -1.77333444e-01,
6.16299957e-02, -1.81179792e-02, -6.17364049e-02,
2.55296975e-02, 2.13537142e-01, 1.96126059e-01,
1.05863795e-01, 2.91235298e-02, 1.00673780e-01,
-2.32420102e-01, 9.96763557e-02, 2.02711537e-01,
1.69935003e-01, 1.84762791e-01, -1.76676422e-01,
1.80472061e-01, 6.77751154e-02, -1.36237860e-01,
2.14780137e-01, 2.21189186e-01, -6.79217577e-02,
-1.57148570e-01, -9.54127312e-02, -1.23638615e-01,
9.31884199e-02, 1.69739321e-01, 2.18936309e-01,
-1.09364316e-01, -1.28934890e-01, 2.12897584e-01,
-1.67741954e-01, 8.08878988e-02, 1.35600522e-01,
2.07836881e-01, 4.69127446e-02, 1.44402400e-01,
6.90435916e-02, -1.50300562e-03, -3.69532406e-02,
2.00135365e-01, 5.28295189e-02, -7.80242682e-03,
-1.60380393e-01, 1.61641225e-01, -2.08564848e-02,
6.25326186e-02, -1.88183680e-01, -1.40430033e-03,
-9.06014442e-02, -2.05036968e-01, -1.32530525e-01,
1.57640651e-01, -1.14886634e-01, 7.88477212e-02,
8.14661235e-02, 2.24346325e-01, -1.84483975e-02,
-1.97269812e-01, -1.26243979e-01, 2.17142984e-01,
-2.29154646e-01, -1.43356085e-01, -8.43265206e-02,
-1.96521401e-01, 7.27632195e-02, 2.04385504e-01,
1.71724126e-01, 1.94900587e-01, 2.17406407e-01,
-1.32997781e-02],
[ 1.15572825e-01, 1.94951013e-01, 9.94793624e-02,
-7.17988014e-02, -1.19437844e-01, -1.88682675e-01,
-1.92734897e-01, -1.58719569e-02, -1.97600394e-01,
-6.43629879e-02, -2.79225856e-02, 1.07840970e-01,
1.93278477e-01, -7.73376822e-02, 1.99129775e-01,
-1.89572379e-01, -1.91574037e-01, 1.90130934e-01,
2.24429056e-01, 1.10003248e-01, 2.35215560e-01,
2.20825598e-01, 2.35394850e-01, 2.03409687e-01,
-3.31312716e-02, -1.86333641e-01, 2.16726944e-01,
-4.50666845e-02, -1.94139361e-01, 1.74561307e-01,
-2.19403818e-01, -9.53933895e-02, -1.47832930e-03,
1.79178044e-01, 2.06605092e-01, -7.54690468e-02,
1.39100805e-01, 9.60601717e-02, -1.73031092e-01,
1.99856237e-01, -2.07695648e-01, -2.28156596e-02,
-1.97354898e-01, -1.70395434e-01, -1.02708101e-01,
-1.04656383e-01, 1.82640329e-01, 1.77674457e-01,
-7.42921382e-02, 6.73636049e-02, -1.40101179e-01,
-2.81335860e-02, 1.80193782e-03, 1.47265062e-01,
2.31643334e-01, 6.16405457e-02, -1.36191398e-02,
1.83102801e-01, -1.17873520e-01, 4.82790917e-02,
1.02845594e-01, 2.15772107e-01, -1.45767123e-01,
-1.90595731e-01, 7.43748993e-02, -7.43917525e-02,
1.29948407e-02, -1.48897469e-02, -1.19628318e-01,
1.19773075e-01, -1.23455450e-01, 2.09888071e-02,
5.35250455e-02, -1.14944324e-01, -1.72280744e-01,
-2.12794513e-01, 7.70083517e-02, -4.16543186e-02,
-1.00999311e-01, -9.15705115e-02, 1.57599077e-01,
9.98326689e-02, -1.04623318e-01, -2.35660315e-01,
1.08594760e-01, -1.92205086e-01, 2.01606169e-01,
1.04561433e-01, 1.58496052e-02, -2.16812238e-01,
-1.02460504e-01, -1.08074486e-01, 1.17864609e-02,
2.11085960e-01, 9.39476639e-02, 7.27449507e-02,
3.75273377e-02, 1.72646508e-01, -6.92816973e-02,
1.68074653e-01],
[-1.75571159e-01, 6.66721314e-02, 1.23318687e-01,
1.78280279e-01, 1.78144976e-01, 1.37159839e-01,
2.37730846e-01, -2.03294367e-01, 1.27393082e-01,
2.34248176e-01, 1.71211407e-01, 1.53387636e-02,
1.95342705e-01, 2.36010358e-01, -2.05670163e-01,
-7.56098330e-02, -4.85904515e-03, -9.29171294e-02,
-9.33329612e-02, -2.35219866e-01, -1.79157555e-01,
-1.41761363e-01, 1.84406117e-01, -9.71422344e-02,
-1.52498782e-02, -1.67628080e-02, 4.72012311e-02,
1.39710709e-01, 5.93284816e-02, -1.31768286e-01,
-8.89754146e-02, -2.33653039e-01, -6.26356900e-03,
-1.46797836e-01, 1.47012547e-01, -1.45171970e-01,
-1.26264960e-01, -1.50665924e-01, 1.35475054e-01,
-3.33644599e-02, -1.85384780e-01, -3.93804908e-02,
1.46056309e-01, 2.32378229e-01, 1.01052120e-01,
6.49079680e-04, -2.27865830e-01, -1.98620632e-01,
2.07273796e-01, 2.02781573e-01, 4.13840860e-02,
9.72868055e-02, -4.66622561e-02, -1.99104249e-01,
-1.94209814e-03, 7.90973157e-02, -1.24301612e-02,
1.64668635e-01, -4.00311202e-02, 6.06928617e-02,
2.21985146e-01, 2.34202817e-01, -2.09789366e-01,
-3.54906917e-03, -1.55264437e-01, 1.52498886e-01,
-2.00151145e-01, -1.61675960e-01, -1.49074152e-01,
-8.24124515e-02, -1.66475505e-01, 3.90331894e-02,
-2.08077848e-01, 1.91506997e-01, 4.42447513e-02,
1.45575628e-01, 5.55175990e-02, -9.78842229e-02,
1.58216402e-01, -1.50803030e-01, 8.68480057e-02,
-1.89485595e-01, -1.16623677e-01, -5.50612807e-04,
1.25686675e-02, 2.35135391e-01, 3.52956802e-02,
2.11477354e-01, 1.26629248e-01, 1.73846528e-01,
-5.78700155e-02, -5.97620308e-02, 3.99827808e-02,
-2.23471820e-02, 1.04051232e-02, 2.23558888e-01,
-5.69698811e-02, -1.58035010e-02, 4.88346070e-02,
1.51560456e-02],
[-1.97800457e-01, 9.48752910e-02, -1.04982153e-01,
-1.91326842e-01, -9.19827819e-03, 1.39469907e-01,
8.20310265e-02, 7.41260499e-02, 1.73021093e-01,
1.51284501e-01, -1.22519344e-01, 1.53674498e-01,
9.94798690e-02, 2.04582319e-01, 1.04332492e-01,
2.03575417e-01, -6.68893754e-03, 6.42802864e-02,
1.38232157e-01, -8.60834271e-02, -3.10130417e-02,
1.40305206e-01, 1.20015785e-01, 2.39931196e-02,
1.32594213e-01, 3.12179178e-02, 5.40799946e-02,
7.80204087e-02, 3.65999788e-02, -2.21258402e-01,
-2.19913319e-01, 1.47448346e-01, -1.82560295e-01,
-1.35670513e-01, 1.13405481e-01, 2.34997347e-01,
1.19831517e-01, 3.55754942e-02, 7.28644282e-02,
-4.33939695e-02, -6.19581342e-02, -1.30101055e-01,
-1.31573036e-01, 2.06390962e-01, -1.32028356e-01,
4.82912511e-02, 1.86986819e-01, -2.31462672e-01,
-4.31398451e-02, -2.85604894e-02, -9.71522629e-02,
1.37412712e-01, 4.21222895e-02, -1.10199973e-01,
-8.97502601e-02, 8.65774900e-02, 6.18141145e-02,
1.26255766e-01, 1.69398442e-01, 1.82763025e-01,
1.85254022e-01, -6.47889823e-02, 1.77497551e-01,
2.90944427e-02, -7.22154379e-02, -1.56010523e-01,
1.57919660e-01, -1.13144882e-01, -2.36783326e-01,
1.77768245e-01, 9.11373049e-02, -2.90834904e-02,
6.39334172e-02, -1.94158882e-01, -3.03351432e-02,
-1.11179650e-01, -2.21896023e-02, -1.64513648e-01,
-2.42010653e-02, 3.53455544e-05, -1.01951763e-01,
1.55372605e-01, 2.01222882e-01, 2.01070473e-01,
1.39705643e-01, 7.08507448e-02, -1.89356267e-01,
9.92354751e-03, -4.84340191e-02, 1.96556345e-01,
-1.19340055e-01, -4.40516770e-02, 2.06891581e-01,
-1.30336851e-01, 6.14485294e-02, 2.94484347e-02,
-1.94941089e-01, -1.21488571e-02, -6.04876876e-03,
-2.02569202e-01],
[-1.70355320e-01, -1.52200967e-01, 1.75795957e-01,
1.42669827e-02, -1.51372015e-01, 8.99345428e-02,
-1.71248779e-01, -7.61070848e-03, -1.72162414e-01,
-1.60347044e-01, -6.70864433e-02, 1.51581600e-01,
-1.40151829e-01, 1.99005380e-01, -3.14210057e-02,
1.82582691e-01, -4.57919985e-02, 1.36525497e-01,
-1.67543545e-01, -1.27201408e-01, -1.60716817e-01,
1.74604818e-01, -2.32298777e-01, 2.24723473e-01,
1.99804947e-01, -1.54197589e-01, -4.90170121e-02,
7.25335330e-02, -2.56796777e-02, 2.19955161e-01,
-2.08994269e-01, -2.01605529e-01, -8.05292279e-02,
1.01273105e-01, 2.40100175e-02, -6.87536150e-02,
1.18179098e-01, -1.67855740e-01, -1.83153331e-01,
-2.72006094e-02, 2.03925520e-02, -8.65773112e-02,
2.23498926e-01, 7.72719532e-02, -2.00271517e-01,
2.12461799e-02, 2.75200754e-02, 3.09939831e-02,
1.55598149e-01, -2.12378561e-01, -1.22833595e-01,
1.49044320e-01, 5.89697808e-02, -1.88215330e-01,
2.16184989e-01, 1.81404993e-01, -2.16233894e-01,
2.29941756e-02, 1.10163614e-01, 1.74590513e-01,
-1.54663295e-01, -1.71188205e-01, -1.66733652e-01,
-1.54885352e-01, -2.32766688e-01, 8.06075186e-02,
-9.49109048e-02, -1.91170171e-01, -3.90950590e-02,
-2.08050564e-01, 2.14296713e-01, 2.26463392e-01,
-1.05332702e-01, -2.34366849e-01, -1.03073299e-01,
-1.81952447e-01, -4.09221798e-02, -2.12964118e-01,
1.08327076e-01, -3.87338996e-02, 1.24713287e-01,
-1.83723286e-01, 1.55386403e-01, -8.83301795e-02,
-1.69616103e-01, 5.01530617e-02, -6.18195683e-02,
-2.03908950e-01, 1.83558151e-01, -1.70404911e-01,
1.73030064e-01, 2.12534264e-01, 2.25581393e-01,
1.05500773e-01, -1.03059620e-01, -7.65965879e-03,
7.22167045e-02, -2.29688361e-01, 1.81469172e-02,
-5.53844571e-02]], dtype=float32)>, <tf.Variable ‘dense_4/bias:0’ shape=(100,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
dtype=float32)>]

以及
[<tf.Variable ‘dense_4/kernel:0’ shape=(6, 100) dtype=float32, numpy=
array([[-1.17631599e-01, 1.48546174e-01, -6.12933934e-02,
1.95379272e-01, -7.78010637e-02, -7.37012029e-02,
-1.51678920e-03, -1.56111956e-01, 7.11926073e-02,
-1.37753904e-01, -1.92362726e-01, -5.62874377e-02,
-1.66875511e-01, 8.67448300e-02, 1.13398716e-01,
-1.14426605e-01, -1.22155465e-01, -9.70358700e-02,
-2.28686333e-02, 1.98045239e-01, 1.30869493e-01,
1.21884897e-01, -6.00987375e-02, 2.22415999e-01,
2.11583361e-01, 1.22698352e-01, -1.50723904e-01,
1.73112080e-01, 2.26724848e-01, 1.40978858e-01,
7.25988895e-02, 1.46507517e-01, -1.78162634e-01,
2.23117918e-02, 2.72955149e-02, 1.24480233e-01,
2.37727836e-01, -4.67860848e-02, 1.87744349e-02,
-3.76515687e-02, 5.05203158e-02, -9.94962007e-02,
2.07418337e-01, -6.09074980e-02, 8.31539482e-02,
-2.31687725e-03, -1.74317241e-01, 1.88781783e-01,
-1.37958512e-01, -6.88468069e-02, 2.01815292e-01,
2.16439292e-01, -1.49591118e-02, -1.38965636e-01,
1.86278060e-01, -5.31143248e-02, 9.71562415e-02,
1.87845662e-01, -2.28760928e-01, 1.92107365e-01,
5.91409057e-02, 2.18419090e-01, -2.35403761e-01,
-2.24688232e-01, -1.04708180e-01, -2.00229019e-01,
1.43962130e-01, -1.80692941e-01, 7.19535500e-02,
-3.52407545e-02, -1.24557026e-01, 1.64803788e-01,
-2.35195473e-01, 1.82727382e-01, 3.48411351e-02,
1.18214831e-01, -1.94926798e-01, -2.27142259e-01,
4.94295806e-02, 1.55736655e-02, -2.07010657e-01,
9.34260041e-02, 2.25032255e-01, 2.03277275e-01,
1.02041051e-01, -7.37348348e-02, -1.99376807e-01,
-1.26140118e-01, -7.84487873e-02, -5.88199496e-03,
-3.98484021e-02, 2.16260597e-01, 1.44064292e-01,
1.22691885e-01, -2.28717595e-01, 3.54599506e-02,
1.39095470e-01, -1.82361931e-01, 8.97107273e-02,
-1.13304555e-02],
[-1.09308854e-01, -1.68572620e-01, -2.31104180e-01,
2.16986045e-01, -1.83715373e-02, 1.62551776e-01,
-2.06275284e-01, -1.59556091e-01, -1.40683502e-01,
1.41831085e-01, -8.34729522e-02, 2.91769207e-03,
1.75670877e-01, -5.32847196e-02, 6.74460083e-02,
-2.28857294e-01, -8.12378228e-02, 1.99883416e-01,
-1.66821688e-01, -1.78879052e-01, -1.88803554e-01,
9.88317877e-02, 2.50957757e-02, 1.48942322e-02,
2.85393000e-03, 1.02169231e-01, -1.58371240e-01,
2.02655211e-01, 1.69222221e-01, -1.77333444e-01,
6.16299957e-02, -1.81179792e-02, -6.17364049e-02,
2.55296975e-02, 2.13537142e-01, 1.96126059e-01,
1.05863795e-01, 2.91235298e-02, 1.00673780e-01,
-2.32420102e-01, 9.96763557e-02, 2.02711537e-01,
1.69935003e-01, 1.84762791e-01, -1.76676422e-01,
1.80472061e-01, 6.77751154e-02, -1.36237860e-01,
2.14780137e-01, 2.21189186e-01, -6.79217577e-02,
-1.57148570e-01, -9.54127312e-02, -1.23638615e-01,
9.31884199e-02, 1.69739321e-01, 2.18936309e-01,
-1.09364316e-01, -1.28934890e-01, 2.12897584e-01,
-1.67741954e-01, 8.08878988e-02, 1.35600522e-01,
2.07836881e-01, 4.69127446e-02, 1.44402400e-01,
6.90435916e-02, -1.50300562e-03, -3.69532406e-02,
2.00135365e-01, 5.28295189e-02, -7.80242682e-03,
-1.60380393e-01, 1.61641225e-01, -2.08564848e-02,
6.25326186e-02, -1.88183680e-01, -1.40430033e-03,
-9.06014442e-02, -2.05036968e-01, -1.32530525e-01,
1.57640651e-01, -1.14886634e-01, 7.88477212e-02,
8.14661235e-02, 2.24346325e-01, -1.84483975e-02,
-1.97269812e-01, -1.26243979e-01, 2.17142984e-01,
-2.29154646e-01, -1.43356085e-01, -8.43265206e-02,
-1.96521401e-01, 7.27632195e-02, 2.04385504e-01,
1.71724126e-01, 1.94900587e-01, 2.17406407e-01,
-1.32997781e-02],
[ 1.15572825e-01, 1.94951013e-01, 9.94793624e-02,
-7.17988014e-02, -1.19437844e-01, -1.88682675e-01,
-1.92734897e-01, -1.58719569e-02, -1.97600394e-01,
-6.43629879e-02, -2.79225856e-02, 1.07840970e-01,
1.93278477e-01, -7.73376822e-02, 1.99129775e-01,
-1.89572379e-01, -1.91574037e-01, 1.90130934e-01,
2.24429056e-01, 1.10003248e-01, 2.35215560e-01,
2.20825598e-01, 2.35394850e-01, 2.03409687e-01,
-3.31312716e-02, -1.86333641e-01, 2.16726944e-01,
-4.50666845e-02, -1.94139361e-01, 1.74561307e-01,
-2.19403818e-01, -9.53933895e-02, -1.47832930e-03,
1.79178044e-01, 2.06605092e-01, -7.54690468e-02,
1.39100805e-01, 9.60601717e-02, -1.73031092e-01,
1.99856237e-01, -2.07695648e-01, -2.28156596e-02,
-1.97354898e-01, -1.70395434e-01, -1.02708101e-01,
-1.04656383e-01, 1.82640329e-01, 1.77674457e-01,
-7.42921382e-02, 6.73636049e-02, -1.40101179e-01,
-2.81335860e-02, 1.80193782e-03, 1.47265062e-01,
2.31643334e-01, 6.16405457e-02, -1.36191398e-02,
1.83102801e-01, -1.17873520e-01, 4.82790917e-02,
1.02845594e-01, 2.15772107e-01, -1.45767123e-01,
-1.90595731e-01, 7.43748993e-02, -7.43917525e-02,
1.29948407e-02, -1.48897469e-02, -1.19628318e-01,
1.19773075e-01, -1.23455450e-01, 2.09888071e-02,
5.35250455e-02, -1.14944324e-01, -1.72280744e-01,
-2.12794513e-01, 7.70083517e-02, -4.16543186e-02,
-1.00999311e-01, -9.15705115e-02, 1.57599077e-01,
9.98326689e-02, -1.04623318e-01, -2.35660315e-01,
1.08594760e-01, -1.92205086e-01, 2.01606169e-01,
1.04561433e-01, 1.58496052e-02, -2.16812238e-01,
-1.02460504e-01, -1.08074486e-01, 1.17864609e-02,
2.11085960e-01, 9.39476639e-02, 7.27449507e-02,
3.75273377e-02, 1.72646508e-01, -6.92816973e-02,
1.68074653e-01],
[-1.75571159e-01, 6.66721314e-02, 1.23318687e-01,
1.78280279e-01, 1.78144976e-01, 1.37159839e-01,
2.37730846e-01, -2.03294367e-01, 1.27393082e-01,
2.34248176e-01, 1.71211407e-01, 1.53387636e-02,
1.95342705e-01, 2.36010358e-01, -2.05670163e-01,
-7.56098330e-02, -4.85904515e-03, -9.29171294e-02,
-9.33329612e-02, -2.35219866e-01, -1.79157555e-01,
-1.41761363e-01, 1.84406117e-01, -9.71422344e-02,
-1.52498782e-02, -1.67628080e-02, 4.72012311e-02,
1.39710709e-01, 5.93284816e-02, -1.31768286e-01,
-8.89754146e-02, -2.33653039e-01, -6.26356900e-03,
-1.46797836e-01, 1.47012547e-01, -1.45171970e-01,
-1.26264960e-01, -1.50665924e-01, 1.35475054e-01,
-3.33644599e-02, -1.85384780e-01, -3.93804908e-02,
1.46056309e-01, 2.32378229e-01, 1.01052120e-01,
6.49079680e-04, -2.27865830e-01, -1.98620632e-01,
2.07273796e-01, 2.02781573e-01, 4.13840860e-02,
9.72868055e-02, -4.66622561e-02, -1.99104249e-01,
-1.94209814e-03, 7.90973157e-02, -1.24301612e-02,
1.64668635e-01, -4.00311202e-02, 6.06928617e-02,
2.21985146e-01, 2.34202817e-01, -2.09789366e-01,
-3.54906917e-03, -1.55264437e-01, 1.52498886e-01,
-2.00151145e-01, -1.61675960e-01, -1.49074152e-01,
-8.24124515e-02, -1.66475505e-01, 3.90331894e-02,
-2.08077848e-01, 1.91506997e-01, 4.42447513e-02,
1.45575628e-01, 5.55175990e-02, -9.78842229e-02,
1.58216402e-01, -1.50803030e-01, 8.68480057e-02,
-1.89485595e-01, -1.16623677e-01, -5.50612807e-04,
1.25686675e-02, 2.35135391e-01, 3.52956802e-02,
2.11477354e-01, 1.26629248e-01, 1.73846528e-01,
-5.78700155e-02, -5.97620308e-02, 3.99827808e-02,
-2.23471820e-02, 1.04051232e-02, 2.23558888e-01,
-5.69698811e-02, -1.58035010e-02, 4.88346070e-02,
1.51560456e-02],
[-1.97800457e-01, 9.48752910e-02, -1.04982153e-01,
-1.91326842e-01, -9.19827819e-03, 1.39469907e-01,
8.20310265e-02, 7.41260499e-02, 1.73021093e-01,
1.51284501e-01, -1.22519344e-01, 1.53674498e-01,
9.94798690e-02, 2.04582319e-01, 1.04332492e-01,
2.03575417e-01, -6.68893754e-03, 6.42802864e-02,
1.38232157e-01, -8.60834271e-02, -3.10130417e-02,
1.40305206e-01, 1.20015785e-01, 2.39931196e-02,
1.32594213e-01, 3.12179178e-02, 5.40799946e-02,
7.80204087e-02, 3.65999788e-02, -2.21258402e-01,
-2.19913319e-01, 1.47448346e-01, -1.82560295e-01,
-1.35670513e-01, 1.13405481e-01, 2.34997347e-01,
1.19831517e-01, 3.55754942e-02, 7.28644282e-02,
-4.33939695e-02, -6.19581342e-02, -1.30101055e-01,
-1.31573036e-01, 2.06390962e-01, -1.32028356e-01,
4.82912511e-02, 1.86986819e-01, -2.31462672e-01,
-4.31398451e-02, -2.85604894e-02, -9.71522629e-02,
1.37412712e-01, 4.21222895e-02, -1.10199973e-01,
-8.97502601e-02, 8.65774900e-02, 6.18141145e-02,
1.26255766e-01, 1.69398442e-01, 1.82763025e-01,
1.85254022e-01, -6.47889823e-02, 1.77497551e-01,
2.90944427e-02, -7.22154379e-02, -1.56010523e-01,
1.57919660e-01, -1.13144882e-01, -2.36783326e-01,
1.77768245e-01, 9.11373049e-02, -2.90834904e-02,
6.39334172e-02, -1.94158882e-01, -3.03351432e-02,
-1.11179650e-01, -2.21896023e-02, -1.64513648e-01,
-2.42010653e-02, 3.53455544e-05, -1.01951763e-01,
1.55372605e-01, 2.01222882e-01, 2.01070473e-01,
1.39705643e-01, 7.08507448e-02, -1.89356267e-01,
9.92354751e-03, -4.84340191e-02, 1.96556345e-01,
-1.19340055e-01, -4.40516770e-02, 2.06891581e-01,
-1.30336851e-01, 6.14485294e-02, 2.94484347e-02,
-1.94941089e-01, -1.21488571e-02, -6.04876876e-03,
-2.02569202e-01],
[-1.70355320e-01, -1.52200967e-01, 1.75795957e-01,
1.42669827e-02, -1.51372015e-01, 8.99345428e-02,
-1.71248779e-01, -7.61070848e-03, -1.72162414e-01,
-1.60347044e-01, -6.70864433e-02, 1.51581600e-01,
-1.40151829e-01, 1.99005380e-01, -3.14210057e-02,
1.82582691e-01, -4.57919985e-02, 1.36525497e-01,
-1.67543545e-01, -1.27201408e-01, -1.60716817e-01,
1.74604818e-01, -2.32298777e-01, 2.24723473e-01,
1.99804947e-01, -1.54197589e-01, -4.90170121e-02,
7.25335330e-02, -2.56796777e-02, 2.19955161e-01,
-2.08994269e-01, -2.01605529e-01, -8.05292279e-02,
1.01273105e-01, 2.40100175e-02, -6.87536150e-02,
1.18179098e-01, -1.67855740e-01, -1.83153331e-01,
-2.72006094e-02, 2.03925520e-02, -8.65773112e-02,
2.23498926e-01, 7.72719532e-02, -2.00271517e-01,
2.12461799e-02, 2.75200754e-02, 3.09939831e-02,
1.55598149e-01, -2.12378561e-01, -1.22833595e-01,
1.49044320e-01, 5.89697808e-02, -1.88215330e-01,
2.16184989e-01, 1.81404993e-01, -2.16233894e-01,
2.29941756e-02, 1.10163614e-01, 1.74590513e-01,
-1.54663295e-01, -1.71188205e-01, -1.66733652e-01,
-1.54885352e-01, -2.32766688e-01, 8.06075186e-02,
-9.49109048e-02, -1.91170171e-01, -3.90950590e-02,
-2.08050564e-01, 2.14296713e-01, 2.26463392e-01,
-1.05332702e-01, -2.34366849e-01, -1.03073299e-01,
-1.81952447e-01, -4.09221798e-02, -2.12964118e-01,
1.08327076e-01, -3.87338996e-02, 1.24713287e-01,
-1.83723286e-01, 1.55386403e-01, -8.83301795e-02,
-1.69616103e-01, 5.01530617e-02, -6.18195683e-02,
-2.03908950e-01, 1.83558151e-01, -1.70404911e-01,
1.73030064e-01, 2.12534264e-01, 2.25581393e-01,
1.05500773e-01, -1.03059620e-01, -7.65965879e-03,
7.22167045e-02, -2.29688361e-01, 1.81469172e-02,
-5.53844571e-02]], dtype=float32)>, <tf.Variable ‘dense_4/bias:0’ shape=(100,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
dtype=float32)>]

自定义网络层

实现自己的层的最佳方法是扩展 tf.keras.layers.Layer 类(见官方文档):

tf.keras.layers.Layer(
    trainable=True, name=None, dtype=None, dynamic=False, **kwargs
)
  • __init__():Save configuration in member variables。可以在其中执行所有与输入无关的初始化,也可以在这里创建网络结构,但这意味着要明确指定变量的 shape。

  • build():Called once from __call__, when we know the shapes of inputs and dtype. Should have the calls to add_weight(), and then call the super’s build() (which sets self.built = True, which is nice in case the user wants to call build() manually before the first __call__). 在 build() 中可以获得输入张量的形状,并可以进行其余的初始化。在 build() 中创建网络结构的优点是:它可以根据该层将要操作的输入的形状,实现延迟(后期)的网络构建。

  • call():Called in __call__ after making sure build() has been called once. Should actually perform the logic of applying the layer to the input tensors (which should be passed in as the first argument). 构建网络结构,进行前向传播。

  • Attributes:

    • activity_regularizer: Optional regularizer function for the output of this layer.

    • dtype

    • dynamic

    • input: Retrieves the input tensor(s) of a layer.Only applicable if the layer has exactly one input, i.e. if it is connected to one incoming layer.

    • input_mask: Retrieves the input mask tensor(s) of a layer. Only applicable if the layer has exactly one inbound node, i.e. if it is connected to one incoming layer.

    • input_shape: Retrieves the input shape(s) of a layer. Only applicable if the layer has exactly one input, i.e. if it is connected to one incoming layer, or if all inputs have the same shape.

    • input_spec

    • losses: Losses which are associated with this Layer.
      Variable regularization tensors are created when this property is accessed, so it is eager safe: accessing losses under a tf.GradientTape will propagate gradients back to the corresponding variables.

    • metrics

    • name: Returns the name of this module as passed or determined in the ctor. NOTE: This is not the same as the self.name_scope.name which includes parent module names.

    • non_trainable_variables

    • non_trainable_weights

    • output: Retrieves the output tensor(s) of a layer. Only applicable if the layer has exactly one output, i.e. if it is connected to one incoming layer.

    • output_mask: Retrieves the output mask tensor(s) of a layer.Only applicable if the layer has exactly one inbound node, i.e. if it is connected to one incoming layer.

    • output_shape: Retrieves the output shape(s) of a layer. Only applicable if the layer has one output, or if all outputs have the same shape.

    • trainable

    • trainable_variables: Sequence of variables owned by this module and it’s submodules.

    • trainable_weights

    • updates

    • variables: Returns the list of all layer variables/weights.

    Alias of self.weights.

    • weights: Returns the list of all layer variables/weights.

下面展示 自定义网络层

class MyDense(tf.keras.layers.Layer):
    """A minimal fully connected layer: output = input @ kernel (no bias, no activation)."""

    def __init__(self, n_outputs):
        """Save the configuration; weight creation is deferred to build().

        Args:
            n_outputs: Number of output units (neurons) of the layer.
        """
        super(MyDense, self).__init__()
        self.n_outputs = n_outputs

    def build(self, input_shape):
        # Create the weight matrix once the input feature dimension is known
        # (last axis of input_shape), enabling late network construction.
        self.kernel = self.add_weight('kernel',
                                      shape=[int(input_shape[-1]),
                                             self.n_outputs])
        # Mark the layer as built (sets self.built = True) so build()
        # is not invoked again on subsequent calls.
        super(MyDense, self).build(input_shape)

    def call(self, input):
        # Forward pass: (batch, in_features) @ (in_features, n_outputs)
        # -> (batch, n_outputs).
        return tf.matmul(input, self.kernel)
    
layer = MyDense(10)# a layer with 10 neurons
print(layer(tf.ones([6, 5]))) # 6 samples with 5 features each from the previous layer; the output tensor is 6*10
print(layer.trainable_variables) # print the trainable parameters: 10 neurons times 5 inputs = 50 weights in total

结果:
tf.Tensor(
[[ 0.46496624 -1.8100357 -0.29889947 -1.5847251 -0.23105323 -0.39976877
-0.30623198 1.4322932 -1.2614003 -1.5413039 ]
[ 0.46496624 -1.8100357 -0.29889947 -1.5847251 -0.23105323 -0.39976877
-0.30623198 1.4322932 -1.2614003 -1.5413039 ]
[ 0.46496624 -1.8100357 -0.29889947 -1.5847251 -0.23105323 -0.39976877
-0.30623198 1.4322932 -1.2614003 -1.5413039 ]
[ 0.46496624 -1.8100357 -0.29889947 -1.5847251 -0.23105323 -0.39976877
-0.30623198 1.4322932 -1.2614003 -1.5413039 ]
[ 0.46496624 -1.8100357 -0.29889947 -1.5847251 -0.23105323 -0.39976877
-0.30623198 1.4322932 -1.2614003 -1.5413039 ]
[ 0.46496624 -1.8100357 -0.29889947 -1.5847251 -0.23105323 -0.39976877
-0.30623198 1.4322932 -1.2614003 -1.5413039 ]], shape=(6, 10), dtype=float32)
[<tf.Variable ‘my_dense_4/kernel:0’ shape=(5, 10) dtype=float32, numpy=
array([[ 0.5684877 , 0.05711246, 0.12932414, -0.5744685 , -0.25184456,
-0.46761584, -0.2980388 , 0.5195171 , -0.41697842, -0.5735689 ],
[-0.51820725, -0.5548857 , 0.45287377, -0.45026648, -0.3958758 ,
-0.24954188, 0.26821226, 0.12655854, -0.54505545, -0.5809149 ],
[ 0.4248591 , -0.2728017 , -0.40157926, 0.01894468, -0.03666633,
0.5274741 , -0.2039176 , 0.5560588 , -0.08723891, 0.150118 ],
[ 0.230798 , -0.4505319 , -0.08260292, -0.49669462, 0.08722687,
-0.6181883 , 0.0394327 , 0.22648138, 0.23067093, -0.2805322 ],
[-0.24097133, -0.5889289 , -0.3969152 , -0.0822404 , 0.36610657,
0.40810317, -0.11192054, 0.00367731, -0.4427985 , -0.2564058 ]],
dtype=float32)>]

一个实例

以残差网络 ResNet 为例。若将输入设为 X,将某一有参网络层设为 H,那么以 X 为输入的此层的输出将为 H(X)。一般的 CNN 网络如 AlexNet/VGG 等会直接通过训练学习出参数函数 H 的表达,从而直接学习 X -> H(X)。残差学习则是致力于使用多个有参网络层来学习输入、输出之间的残差,即 H(X) - X,也就是学习 X -> (H(X) - X) + X。其中 X 这一部分为直接的 identity mapping,而 H(X) - X 则为有参网络层要学习的输入输出间残差。
residual learning
本实例用 keras.Model 实现。每个残差块就是“卷积 + 批归一化(BatchNormalization)+ 残差连接”的组合。

  • tf.keras.layers.Conv2D link 将二维向量进行卷积
tf.keras.layers.Conv2D(
    filters, kernel_size, strides=(1, 1), padding='valid', data_format=None,
    dilation_rate=(1, 1), activation=None, use_bias=True,
    kernel_initializer='glorot_uniform', bias_initializer='zeros',
    kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None,
    kernel_constraint=None, bias_constraint=None, **kwargs
)

cov2d这个层创建了一个卷积核,将输入进行卷积来输出一个 tensor。

  • tf.keras.layers.BatchNormalization link
tf.keras.layers.BatchNormalization(
   axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
   beta_initializer='zeros', gamma_initializer='ones',
   moving_mean_initializer='zeros', moving_variance_initializer='ones',
   beta_regularizer=None, gamma_regularizer=None, beta_constraint=None,
   gamma_constraint=None, renorm=False, renorm_clipping=None, renorm_momentum=0.99,
   fused=None, trainable=True, virtual_batch_size=None, adjustment=None, name=None,
   **kwargs
)

下面是一个残差学习块的实现:

class ResnetBlock(tf.keras.Model):
    """A bottleneck residual block: three conv + batch-norm stages plus a skip connection.

    NOTE(review): the skip connection `x += inputs` requires the last filter
    count (filters[2]) to equal the number of input channels — confirm at
    call sites (the demo below uses 9 channels with filters[2] = 9).
    """

    def __init__(self, kernel_size, filters):
        super(ResnetBlock, self).__init__(name='resnet_block')

        # Number of convolution filters for each of the three sub-layers.
        filter1, filter2, filter3 = filters

        # Three sub-layers, each a convolution followed by batch normalization.
        # Stage 1: 1x1 convolution.
        self.conv1 = tf.keras.layers.Conv2D(filter1, (1,1))
        self.bn1 = tf.keras.layers.BatchNormalization()
        # Stage 2: convolution with the caller-supplied kernel_size.
        self.conv2 = tf.keras.layers.Conv2D(filter2, kernel_size, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()
        # Stage 3: 1x1 convolution back out to filter3 channels.
        self.conv3 = tf.keras.layers.Conv2D(filter3, (1,1))
        self.bn3 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=False):
        # Run the three conv -> batch-norm stages in order.
        x = inputs
        stages = ((self.conv1, self.bn1),
                  (self.conv2, self.bn2),
                  (self.conv3, self.bn3))
        for conv, bn in stages:
            x = conv(x)
            x = bn(x, training=training)

        # Residual (skip) connection, then the final activation.
        x += inputs
        return tf.nn.relu(x)

resnetBlock = ResnetBlock(2, [6,4,9])
# Smoke test with dummy data: batch of 1, 3x9 spatial grid, 9 input channels
# (9 matches filters[2] = 9, so the residual addition x += inputs is shape-valid).
print(resnetBlock(tf.ones([1,3,9,9])))
# List the names of all trainable variables in the block
print([x.name for x in resnetBlock.trainable_variables])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值