input打加载json文件_TensorFlow2学习四、Keras 保存和加载模型

0b7cadd9388dce34ccb91ec1fa6dbd43.png

1. 权重保存和加载

# 保存为TensorFlow checkpoint格式model.save_weights('./my_model')# 保存为TensorFlow HDF5格式model.save_weights('model.h5', save_format='h5')# 加载model.load_weights('my_model')

2. 保存和加载网络结构

保存一个模型的配置,序列化过程中不包含权重。保存的配置可以用来重新创建、初始化出相同的模型,即使没有模型原始的定义代码。 Keras支持JSON、YAML序列化格式。

保存

import tensorflow as tfimport numpy as npfrom tensorflow import kerasimport jsonmodel = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])model.compile(optimizer='sgd', loss='mean_squared_error')xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)model.fit(xs, ys, epochs=100)print(model.predict([10.0]))json_string = model.to_json()print(json_string)'''{"class_name": "Sequential", "config": {"name": "sequential_8", "layers": [{"class_name": "Dense", "config": {"name": "dense_8", "trainable": true, "batch_input_shape": [null, 1], "dtype": "float32", "units": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null, "dtype": "float32"}}, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "keras_version": "2.2.4-tf", "backend": "tensorflow"}'''with open("model.json","w") as json_file: json_file.write(json_string)

加载

# load json and create modeljson_file = open('model.json', 'r')loaded_model_json = json_file.read()json_file.close()loaded_model = model_from_json(loaded_model_json)

如果使用yaml格式,将model.to_json()换成model.to_yaml(),model_from_json()换成model_from_yaml()

3. 保存整个模型

整个模型可以保存到一个文件里,包含:权重、模型配置、优化器配置等。
可以保存状态后从完全相同的状态恢复训练。

# Create a trivial modelmodel = keras.Sequential([ keras.layers.Dense(10, activation='softmax', input_shape=(32,)), keras.layers.Dense(10, activation='softmax')])model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])model.fit(data, targets, batch_size=32, epochs=5)# Save entire model to a HDF5 filemodel.save('my_model.h5')# Recreate the exact same model, including weights and optimizer.model = keras.models.load_model('my_model.h5')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值