人工智能小白日记 TF学习篇之1 Estimator

前言

为啥又突然跳到这里,主要是最近在探索语音的时候发现一个问题,就是我训练好的模型,如何进行使用?之前的代码写完后都是,训练预测在一起,实际操作中在训练验证完之后,拿来使用肯定不能再次运行去训练,而是应该直接使用训练好的模型,对吧。这个问题tensorflow肯定会给答案。

第二个问题是,在搜集很多资料的时候,发现都是基于tensorflow,好像代码不太一样啊。这是因为tensorflow有不同层次的api,而之前的代码,大部分是直接使用了Estimator api。这是一个高级api,有现成可用的模型,比如前面用过的线性分类器,DNN分类器等。

因为,这些杂七杂八的问题多了,就得系统的去了解下了,不然一脸懵逼。当然,首先是我们比较熟悉的estimator了。
参考 https://www.tensorflow.org/guide/premade_estimators

正文内容

1 Premade Estimators

Estimator直译过来就是评估器,前面用过的诸多评估器,线性分类器,DNN分类器,这些已经封装好的Estimators就属于Premade Estimators。

在这里插入图片描述
另一类是Custom Estimators,也就是自定义的评估器,这类其实也已经用过了,还记得CNN卷积神经网络吗,那个是自主构建起来的,包含卷积层、池化层、密集层和对数层。

后面附上了对鸢尾花进行分类的案例:也就是通过4个特征来匹配3种标签
在这里插入图片描述
使用了预创建的DNN分类器,2层10个神经元,最后实现3分类

# Build a DNN with 2 hidden layers and 10 nodes in each hidden layer.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    # Two hidden layers of 10 nodes each.
    hidden_units=[10, 10],
    # The model must choose between 3 classes.
    n_classes=3)

这个已经很熟悉了,不说了。

2 CheckPoints保存和恢复模型

介绍了如何保存和恢复通过 Estimator 构建的 TensorFlow 模型。TensorFlow 提供了两种模型格式:

  • 检查点:这种格式依赖于创建模型的代码。
  • SavedModel:这种格式与创建模型的代码无关。

2-1 保存经过部分训练的模型

Estimator 自动将以下内容写入磁盘:

  • 检查点:训练期间所创建的模型版本。
  • 事件文件:其中包含 TensorBoard 用于创建可视化图表的信息。

ps:这里注意自动,不做任何措施,它也会保存,还记得之前跑了很多模型之后磁盘满了吗?就是因为保存了大量模型的内容。当没有设置保存的位置时,保存在默认位置上,比如mac系统会存储在这样的位置上:
/var/folders/0s/5q9kfzfj3gx2knj0vj8p68yc00dhcr/T/tmpYm1Rwa

1)当然也可以设置保存模型的位置

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris') #保存到当前文件夹下models/iris中

2)保存模型的时机
在这里插入图片描述
默认情况下,Estimator 按照以下时间安排将检查点保存到 model_dir 中:

  • 每 10 分钟(600 秒)写入一个检查点。
  • 在 train 方法开始(第一次迭代)和完成(最后一次迭代)时写入一个检查点。
  • 默认只在目录中保留 5 个最近写入的检查点。

当然,除了开始是在train中这点,其他配置都是是可以修改的

my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.
    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)

3)保存的内容
**保存的文件**

2-2 恢复模型

在这里插入图片描述

注意:如果保存之后,修改了模型配置,比如2层10个神经元[10,10],改成了[50,30,20]。就会导致恢复的时候出错。这个我碰过好几次了?。解决办法很容易,如果之前的过时了,可以直接删掉,如果还需要就换一个位置保存。

由此,可以知道如果想训练完保存,然后拿来用于服务,其实不需要做什么,指定好保存的文件夹,在另一段代码中恢复Estimator即可,然后无论是执行train,evaluate,predict都可以接着使用。

以下,尝试恢复了前面cnn模型,并获取了其中的所有变量

classifier2 = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="opt/audio_model")

variable_list = classifier2.get_variable_names()
print(variable_list)
for v in variable_list:
    print(v,classifier2.get_variable_value(v))

在这里插入图片描述
可以看出,这些变量包含了每一层上次训练结束后的状态。

2-2 提取某层的输出

那么问题来了,这只是保存了模型和模型变量,我需要它能提取每一层的输出怎么办?
其实cnn_model_fn函数中是我们定义的模型部分,里面有一段
在这里插入图片描述
如果是predict模式下,返回为我定义的predictions,来预测看下结果:

length = 10

predict_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": test_examples[:length]},
      num_epochs=1,
      shuffle=False)

predictions = list(classifier2.predict(input_fn=predict_input_fn))
print('predictions',predictions)
predicted_classes = [p["classes"] for p in predictions]

print(
      "New Samples, Class Predictions:    {}\n"
      .format(predicted_classes))

在这里插入图片描述
可以看到predict调用后会返回完整的predictions给我,那么我当然可以在predictions中添加任何我需要的内容,比如某一层的输出。

所以,修改代码,添加一个输出结果data
在这里插入图片描述
由于最后一层是分类了,所以我的data选取在它前面一点点。
在这里插入图片描述
当啷,大功告成。

3 创建自定义 Estimator

这个暂时用不上,大概了解下就行了:

相对于Premade Estimator,自定义的Estimator肯定更加灵活,不是现成的模型,前面见过吗?有,CNN模型识别手写字和后来做了语音情感分析哪个。

classifier2 = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="opt/audio_model")

上面就是典型的自定义Estimator的标志tf.estimator.Estimator

模型函数(即 model_fn)会实现机器学习算法。采用预创建的 Estimator 和自定义 Estimator 的唯一区别是:

  • 如果采用预创建的 Estimator,则有人已为您编写了模型函数。
  • 如果采用自定义 Estimator,则您必须自行编写模型函数。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值