模型构建
class Encoder(layers.Layer):
    """Map inputs to the VAE latent triplet ``(z_mean, z_log_var, z)``.

    ``VariationalAutoEncoder.call`` unpacks exactly this triplet, so ``call``
    must return all three tensors.

    Notes on creating weights by hand in a custom layer:
      * Verbose form: wrap an initializer in ``tf.Variable``, e.g.
        ``tf.Variable(initial_value=tf.random_normal_initializer()(shape=(input_dim, units), dtype="float32"), trainable=True)``.
      * Concise form: ``self.add_weight(shape=(input_dim, units), initializer="random_normal", trainable=True)``.
      * A layer may also own non-trainable weights (``trainable=False``).
      * Weight creation can be deferred to ``build(input_shape)`` until the
        input shape is known:
        https://www.tensorflow.org/guide/keras/custom_layers_and_models

    Args:
        latent_dim: size of the latent code ``z``.
        intermediate_dim: width of the shared hidden projection.
        name: layer name.
    """

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        # Shared hidden projection followed by two heads for the Gaussian parameters.
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)

    def call(self, inputs):
        """Return ``(z_mean, z_log_var, z)`` for a batch of inputs."""
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        # Reparameterization trick: z = mean + std * eps with eps ~ N(0, I) and
        # std = exp(0.5 * log_var). Sampling stays differentiable with respect to
        # z_mean and z_log_var, which the KL term in the model also consumes.
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        z = z_mean + tf.exp(0.5 * z_log_var) * epsilon
        return z_mean, z_log_var, z
class Decoder(layers.Layer):
    """Project a latent vector back into the original input space.

    A ReLU hidden layer of width ``intermediate_dim`` feeds a sigmoid output
    layer of width ``original_dim``.

    Args:
        original_dim: number of output features (size of the reconstructed input).
        intermediate_dim: width of the hidden layer.
        name: layer name.
    """

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        """Return the sigmoid reconstruction of ``inputs``."""
        return self.dense_output(self.dense_proj(inputs))
class VariationalAutoEncoder(keras.Model):
    """Encoder + decoder trained end-to-end as a variational autoencoder.

    ``call`` returns the reconstruction and registers the KL-divergence
    regularizer through ``add_loss``, so it appears in ``self.losses``.

    Args:
        original_dim: size of each (flattened) input example.
        intermediate_dim: hidden width shared by encoder and decoder.
        latent_dim: size of the latent code.
        name: model name.
    """

    def __init__(self, original_dim, intermediate_dim=64, latent_dim=32, name="autoencoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        """Encode, sample, decode; record the KL term as a model loss."""
        z_mean, z_log_var, z = self.encoder(inputs)
        output = self.decoder(z)
        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
        # averaged over all elements.
        kl_terms = z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        self.add_loss(-0.5 * tf.reduce_mean(kl_terms))
        return output
模型训练
# Load the dataset: flatten each MNIST image to a 784-vector scaled to [0, 1].
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
# Model initialization.
model = VariationalAutoEncoder(784, 64, 32)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()
# Running mean of the per-batch loss since the start of training, so the
# printed "mean loss" really is a mean (not just the last batch's value).
loss_metric = tf.keras.metrics.Mean()
# Custom training loop.
for epoch in range(3):
    print("Start of epoch %d" % (epoch,))
    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = model(x_batch_train)
            loss = loss_fn(x_batch_train, reconstructed)  # Compute reconstruction loss
            loss += sum(model.losses)  # Add KLD regularization loss
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        loss_metric.update_state(loss)
        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, float(loss_metric.result())))
# Because the model subclasses keras.Model it also has the built-in training
# loop, so it can alternatively be trained like this:
model.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
model.fit(x_train, x_train, epochs=2, batch_size=64)
模型保存和加载
# Save the model, then load it back from the same location.
# NOTE(review): an extension-less path targets the TF2 SavedModel format;
# Keras 3 (bundled from TF 2.16) requires a ".keras" suffix here — confirm
# which TF/Keras version this runs on.
save_path = 'path/to/location'
model.save(save_path)
model = keras.models.load_model(save_path)
# More details: https://www.tensorflow.org/guide/keras/save_and_serialize
案例二
# A custom Layer.
class Linear(keras.layers.Layer):
    """Fully connected layer computing ``inputs @ w + b``.

    Args:
        units: number of output features.
        input_dim: number of input features (size of the last axis of ``inputs``).
    """

    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        # add_weight registers the variables with the layer so they show up in
        # trainable_weights and get updated by the optimizer.
        self.w = self.add_weight(shape=(input_dim, units), initializer="random_normal", trainable=True)
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        """Return the affine transform of ``inputs``."""
        return tf.matmul(inputs, self.w) + self.b
# Recursive composition of layers.
class MLPBlock(keras.Model):
    """Three stacked ``Linear`` layers (64 -> 32 -> 16 -> 1) with ReLU in between.

    The last axis of ``inputs`` must have size 64; the output has one unit.
    """

    def __init__(self):
        super(MLPBlock, self).__init__()
        # Linear's signature is (units, input_dim); pass keywords so each
        # layer's input_dim matches the previous layer's units.
        self.linear_1 = Linear(units=32, input_dim=64)
        self.linear_2 = Linear(units=16, input_dim=32)
        self.linear_3 = Linear(units=1, input_dim=16)

    def call(self, inputs):
        """Run the three layers; the final layer has no activation."""
        x = tf.nn.relu(self.linear_1(inputs))
        x = tf.nn.relu(self.linear_2(x))
        return self.linear_3(x)
# Custom losses and metrics via add_loss()/add_metric(): https://www.tensorflow.org/guide/keras/custom_layers_and_models
d_optimizer = keras.optimizers.Adam(learning_rate=0.001)
# Binary cross-entropy on raw logits; presumably the model ends in a
# single-unit layer without activation — confirm against MLPBlock.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
# Accuracy metric matching the binary logit loss: a threshold of 0.0 on a
# logit corresponds to probability 0.5. (SparseCategoricalAccuracy expects
# integer class ids and multi-class scores, which does not fit this loss.)
val_acc_metric = keras.metrics.BinaryAccuracy(threshold=0.0)
model = MLPBlock()
@tf.function
def train_step(x, y):
    """Run one optimization step on a batch and return the batch loss.

    Uses the module-level ``model``, ``loss_fn`` and ``d_optimizer``.
    Returning the loss lets callers log or aggregate it; callers that
    ignore the return value are unaffected.
    """
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss_value = loss_fn(y, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value
@tf.function
def test_step(x, y):
    """Accumulate validation accuracy for one batch, running the model in inference mode."""
    val_acc_metric.update_state(y, model(x, training=False))