tf.keras是什么?tf.keras怎样实现深度学习?

tf.keras是TensorFlow 2.0的高阶API接口,为TensorFlow的代码提供了新的风格和设计模式,大大提升了TF代码的简洁性和复用性,官方也推荐使用tf.keras来进行模型设计和开发。

常用模块

tf.keras中常用模块如下表所示:

模块概述
activations激活函数
applications预训练网络模块
Callbacks在模型训练期间被调用
datasetstf.keras数据集模块,包括boston_housing,cifar10,fashion_mnist,imdb ,mnist
layersKeras层API
losses各种损失函数
metircs各种评价指标
models模型创建模块,以及与模型相关的API
optimizers优化方法
preprocessingKeras数据的预处理模块
regularizers正则化,L1,L2等
utils辅助功能实现

常用方法

深度学习实现的主要流程:1.数据获取,2,数据处理,3.模型创建与训练,4 模型测试与评估,5.模型预测。

1.导入tf.keras

使用 tf.keras,首先需要在代码开始时导入tf.keras。

<span style="background-color:#2d2d2d"><span style="color:#cccccc"><code class="language-jfx">import tensorflow as tf
from tensorflow import keras</code></span></span>

2.数据输入

对于小的数据集,可以直接使用numpy格式的数据进行训练、评估模型,对于大型数据集或者要进行跨设备训练时使用tf.data.datasets来进行数据输入。

3.模型构建

  • 简单模型使用Sequential进行构建
  • 复杂模型使用函数式编程来构建
  • 自定义layers

4.训练与评估

  • 配置训练过程:
<span style="background-color:#2d2d2d"><span style="color:#cccccc"><code class="language-js"># 配置优化方法,损失函数和评价指标
model<span style="color:#cccccc">.</span><span style="color:#f08d49">compile</span><span style="color:#cccccc">(</span>optimizer<span style="color:#67cdcc">=</span>tf<span style="color:#cccccc">.</span>train<span style="color:#cccccc">.</span><span style="color:#f08d49">AdamOptimizer</span><span style="color:#cccccc">(</span><span style="color:#f08d49">0.001</span><span style="color:#cccccc">)</span><span style="color:#cccccc">,</span>
              loss<span style="color:#67cdcc">=</span><span style="color:#7ec699">'categorical_crossentropy'</span><span style="color:#cccccc">,</span>
              metrics<span style="color:#67cdcc">=</span><span style="color:#cccccc">[</span><span style="color:#7ec699">'accuracy'</span><span style="color:#cccccc">]</span><span style="color:#cccccc">)</span></code></span></span>

模型训练:

<span style="background-color:#2d2d2d"><span style="color:#cccccc"><code class="language-js"># 指明训练数据集,训练epoch<span style="color:#cccccc">,</span>批次大小和验证集数据model<span style="color:#cccccc">.</span>fit<span style="color:#67cdcc">/</span><span style="color:#f08d49">fit_generator</span><span style="color:#cccccc">(</span>dataset<span style="color:#cccccc">,</span> epochs<span style="color:#67cdcc">=</span><span style="color:#f08d49">10</span><span style="color:#cccccc">,</span> 
                        batch_size<span style="color:#67cdcc">=</span><span style="color:#f08d49">3</span><span style="color:#cccccc">,</span>
          validation_data<span style="color:#67cdcc">=</span>val_dataset<span style="color:#cccccc">,</span>
          <span style="color:#cccccc">)</span></code></span></span>

模型评估:

<span style="background-color:#2d2d2d"><span style="color:#cccccc"><code class="language-js"># 指明评估数据集和批次大小
model<span style="color:#cccccc">.</span><span style="color:#f08d49">evaluate</span><span style="color:#cccccc">(</span>x<span style="color:#cccccc">,</span> y<span style="color:#cccccc">,</span> batch_size<span style="color:#67cdcc">=</span><span style="color:#f08d49">32</span><span style="color:#cccccc">)</span></code></span></span>

模型预测:

<span style="background-color:#2d2d2d"><span style="color:#cccccc"><code class="language-js"># 对新的样本进行预测
model<span style="color:#cccccc">.</span><span style="color:#f08d49">predict</span><span style="color:#cccccc">(</span>x<span style="color:#cccccc">,</span> batch_size<span style="color:#67cdcc">=</span><span style="color:#f08d49">32</span><span style="color:#cccccc">)</span></code></span></span>

5.回调函数(callbacks)

回调函数用在模型训练过程中,来控制模型训练行为,可以自定义回调函数,也可使用tf.keras.callbacks 内置的 callback :

ModelCheckpoint:定期保存 checkpoints。 LearningRateScheduler:动态改变学习速率。 EarlyStopping:当验证集上的性能不再提高时,终止训练。 TensorBoard:使用 TensorBoard 监测模型的状态。

6.模型的保存和恢复

只保存参数:

<span style="background-color:#2d2d2d"><span style="color:#cccccc"><code class="language-js"># 只保存模型的权重
model<span style="color:#cccccc">.</span><span style="color:#f08d49">save_weights</span><span style="color:#cccccc">(</span><span style="color:#7ec699">'./my_model'</span><span style="color:#cccccc">)</span>
# 加载模型的权重
model<span style="color:#cccccc">.</span><span style="color:#f08d49">load_weights</span><span style="color:#cccccc">(</span><span style="color:#7ec699">'my_model'</span><span style="color:#cccccc">)</span></code></span></span>

保存整个模型:

<span style="background-color:#2d2d2d"><span style="color:#cccccc"><code class="language-js"># 保存模型架构与权重在h5文件中
model<span style="color:#cccccc">.</span><span style="color:#f08d49">save</span><span style="color:#cccccc">(</span><span style="color:#7ec699">'my_model.h5'</span><span style="color:#cccccc">)</span>
# 加载模型:包括架构和对应的权重
model <span style="color:#67cdcc">=</span> keras<span style="color:#cccccc">.</span>models<span style="color:#cccccc">.</span><span style="color:#f08d49">load_model</span><span style="color:#cccccc">(</span><span style="color:#7ec699">'my_model.h5'</span><span style="color:#cccccc">)</span></code></span></span>
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值