TensorFlow下的API结构
前言:
在使用tensorflow
的函数时,对它整体API的结构比较模糊,搜索了一遍之后官方文档解答了我的疑惑,以下为小总结,如有错误欢迎指正。
一、tf 下面有三部分内容:模块、类、常用的函数
|--- tf
|---- 大模块
tf.nn,神经网络模块
tf.keras,高阶API
tf.math,数学工具模块
tf.losses,计算误差
tf.data,数据模块
tf.random
tf.summary,展示模块信息
tf.train,训练函数
tf.contrib,实验性质常变动的函数模块
....
|--- 类
Dtype
Variable
...
|--- 常用的函数
tf.argmax()
tf.add()
tf.constant()
tf.one_hot()
tf.cast()
tf.reduce_mean()
tf.square()
...
二、其中像比较常用的tf.keras
中
tf,keras.datasets(下载数据集)
tf.keras.metrics(计算精度,评估性能)
tf.keras.layers
自定义层
tf.keras.layers.Layer
tf.keras.Squentital
tf.keras.Model
1. Model
母类中有针对训练的函数
compile(),训练
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
fit(),训练期间测试
network.fit(db, epochs=5, validation_data=ds_val, validation_freq=2)
evaluate(),训练结束最终评估
network.evaluate(ds_val)
predict(),预测值
2. tf.metrics
中测量三步走
1、生成测量器
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()
2、喂数据
loss_meter.update_state(loss)
acc_meter.update_state(y, pred)
3、取结果
loss_meter.result().numpy()
acc_meter.result().numpy()
4、一个迭代之后重置
loss_meter.reset_states()
acc_meter.reset_states()
三、数据处理的tf.data.Dataset
下的
tf.data.Dataset.from_tensor_slices(),切分数据
tf.data.Dataset.shuffle(),打散数据
tf.data.Dataset.map(),预处理
tf.data.Dataset.batch(),分批
tf.data.Dataset.repeat(),重复迭代
四、Tensorflow构建神经网络和全连接层常用的函数
1.数据集操作
- 数据加载,返回numpy类型的数据
(x_data,y_data),(x_test,y_test) = tf.keras.datasets.mnist.load_data()
- 数据预处理(类型转换。。)
- 打散、分批
2.搭建网络层
- 搭建线性神经网络层,获取输出
tf.keras.Squentital()
- 对输出数据进行
tf.nn.softmax()
,将数值映射到0-1
,并且和为1
(这是和tf.softmax
的区别)
3.计算误差
- MSE:
tf.losses.MSE()
有时用MSE会出现梯度消失的情况,所以交叉熵也很好 - 交叉熵:
tf.losses.categorical_cossentropy()
,交叉熵越小,说明信息量越大,不可知的东西多,既误差很大
推荐:Tensorflow的龙良曲老师GitHub
老师讲的很好,资料很全,能让自己学习更清晰