# Source: https://tensorflow.google.cn/guide/checkpoint
import tensorflow as tf
# Model
class Net(tf.keras.Model):
    """A minimal linear model made of a single ``Dense(5)`` layer.

    No activation is passed to the layer, so the forward pass is a plain
    affine map of the input onto 5 output units.
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute name ``l1`` stable: TensorFlow's object-based
        # checkpoints locate variables through the attribute path.
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        """Forward pass: feed ``x`` through the ``l1`` dense layer."""
        projected = self.l1(x)
        return projected
# Build the model, then save its weights through the tf.keras training API.
# NOTE(review): Keras 3 requires the path passed to ``save_weights`` to end
# in ".weights.h5". This bare name relies on the TF-checkpoint format of
# tf.keras 2 — confirm against the installed TensorFlow/Keras version.
net = Net()
net.save_weights(filepath='easy_checkpoint')
# Load data
def toy_dataset():
    """Return an endlessly repeating, batched toy regression dataset.

    The features are the values 0..9 arranged as a column of shape (10, 1).
    Each target row is ``5 * x + [0, 1, 2, 3, 4]``, broadcast to shape
    (10, 5). Elements are dicts keyed ``'x'`` (feature) and ``'y'`` (target).
    The dataset repeats without end and is grouped into batches of 2.
    """
    features = tf.range(10.)[:, tf.newaxis]
    offsets = tf.range(5.)[tf.newaxis, :]
    targets = features * 5. + offsets
    examples = {'x': features, 'y': targets}
    return (
        tf.data.Dataset.from_tensor_slices(examples)
        .repeat()
        .batch(2)
    )


dataset = toy_dataset()
# 更新梯度步骤
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape()