Tensorflow2.x 训练网络时的指标输出,以及模型结构图导出

训练指标输出

1. 使用TensorBoard
2. 使用History类

TensorBorad

TensorBoard的Scalars可以可视化这些指标

使用步骤:

记录训练中的指标,需要执行以下操作:

  1. 创建KerasTensorBoard回调
  2. 指定日志目录
  3. 将TensorBoard回调传递Keras的Model.fit()

回调函数:

tf.keras.callbacks.TensorBoard(
           log_dir='logs', histogram_freq=0, write_graph=True, write_images=False,
          update_freq='epoch', profile_batch=2, embeddings_freq=0,
    embeddings_metadata=None, **kwargs
)

参数

  1. log_dir:将要由TensorBoard解析的日志文件保存到的目录路径。
  2. histogram_freq:计算模型各层的激活度和权重直方图的频率。如果设置为0,将不计算直方图。必须为直方图可视化指定验证数据。
  3. write_graph:是否在TensorBoard中可视化图形。当write_graph设置为True时,日志文件可能会变得很大。
  4. write_images:是否编写模型权重以在TensorBoard中可视化为图像。
  5. update_freq:‘batch’或’epoch’或整数。使用时’batch’,每批之后将损失和指标写入TensorBoard。同样适用于’epoch’。如果使用整数,假设1000,回调将每1000批将指标和损失写入TensorBoard。请注意,过于频繁地向TensorBoard写入可能会减慢训练速度。
  6. profile_batch:分析批次以采样计算特征。默认情况下,它将配置第二批。将profile_batch = 0设置为禁用分析。必须在TensorFlow急切模式下运行。
  7. embeddings_freq:嵌入层可视化的频率(以历元计)。如果设置为0,则嵌入将不可见。
  8. embeddings_metadata:将层名称映射到文件名的字典,该嵌入层的元数据保存在该文件名中。查看 有关元数据文件格式的 详细信息。如果相同的元数据文件用于所有嵌入层,则可以传递字符串。

定义好回调函数后,在fit()函数中加入参数
如下:

 logdir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
 tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
.....(省略掉的代码)
history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
             callbacks=[tensorboard_callback]
  )

然后在终端,使用 tensorboard --logdir log/,就会出现下面的信息:

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.1.0 at http://localhost:6007/ (Press CTRL+C to quit)

进入连接即可出现网页,就会显示图表。
模型结构图
在这里插入图片描述

使用History类

这种方式比较简单,

history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
             callbacks=[tensorboard_callback]

fit() 会返回一个History的类,它的History.history属性记录了训练时期(每个epoch),训练损失和准确率以及验证损失和验证准确率。
如下:

model.compile(optimizer=keras.optimizers.Adam(lr=0.01),
        loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
)
history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
                    # callbacks=[tensorboard_callback]
                    )
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label = 'val_loss')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()

在这里插入图片描述

  • 1
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
要将训练好的 TensorFlow 模型保存为 .pb 文件,您可以按照以下步骤进行操作: 1. 定义模型结构:在保存模型之前,您需要定义模型的结构,包括输入和输出节点的名称、形状和数据类型。您可以使用 TensorFlow 的高级 API(如 Keras)或自定义模型来定义模型结构。 2. 加载模型权重:将训练好的模型权重加载到定义的模型结构。这可以通过加载已保存的模型权重文件(如 .h5、.ckpt 等)或通过重新训练模型来实现。 3. 创建 SavedModel:使用 TensorFlow 的 `tf.saved_model.save` 函数将模型保存为 SavedModel 格式。SavedModel 是 TensorFlow 的一种标准模型保存格式,可以包含模型的计算图和变量值。 ```python import tensorflow as tf # 定义和加载模型权重 model = ... # 定义模型结构 model.load_weights('model_weights.h5') # 加载模型权重 # 保存为 SavedModel 格式 tf.saved_model.save(model, 'saved_model') ``` 这将会在指定路径下创建一个名为 `saved_model` 的文件夹,其包含了模型的计算图和变量值。 4. 导出为 .pb 文件:从 SavedModel 导出所需的 .pb 文件。可以使用 TensorFlow 的 `tf.compat.v1.graph_util.convert_variables_to_constants` 函数将 SavedModel 的计算图和变量值转换为常量,并保存为 .pb 文件。 ```python from tensorflow.python.framework import graph_util # 加载 SavedModel saved_model_dir = 'saved_model' saved_model = tf.saved_model.load(saved_model_dir) # 将 SavedModel 转换为 .pb 文件 output_pb_file = 'model.pb' graph_def = graph_util.convert_variables_to_constants( saved_model.sess, saved_model.sess.graph_def, ['output_node_name']) with tf.io.gfile.GFile(output_pb_file, 'wb') as f: f.write(graph_def.SerializeToString()) ``` 将上述代码的 `'output_node_name'` 替换为模型输出节点的名称。 现在,您应该已经成功将训练好的 TensorFlow 模型保存为 .pb 文件。请注意,这只是一个基本示例,具体的实现细节可能因您的模型结构和需求而有所不同。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

FlyDremever

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值