import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
def build_generator(latent_dim, img_shape=(28, 28, 1)):
    """Build the fully connected GAN generator (latent noise -> image).

    Three Dense -> LeakyReLU(0.2) -> BatchNormalization stages (256, 512 and
    1024 units) feed a tanh Dense layer whose output is reshaped to
    ``img_shape``, so every pixel lies in (-1, 1).

    Args:
        latent_dim: Length of each input noise vector.
        img_shape: Shape of one generated image. Defaults to ``(28, 28, 1)``,
            the size previously hard-coded here (MNIST).

    Returns:
        An uncompiled ``keras.Sequential`` mapping ``(batch, latent_dim)``
        noise to ``(batch, *img_shape)`` images.
    """
    import math  # local import keeps this block self-contained

    img_shape = tuple(img_shape)
    model = keras.Sequential(
        [
            # Explicit Input layer instead of the legacy ``input_dim`` kwarg,
            # which Keras 3 deprecates with a warning; tf.keras 2.x accepts both.
            keras.Input(shape=(latent_dim,)),
            layers.Dense(256),
            # ``alpha`` (rather than Keras 3's ``negative_slope``) is kept so the
            # layer is still accepted by tf.keras 2.x.
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Dense(1024),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            # One unit per output pixel; tanh matches real data scaled to [-1, 1].
            layers.Dense(math.prod(img_shape), activation="tanh"),
            layers.Reshape(img_shape),
        ]
    )
    return model
def build_discriminator(img_shape):
    """Build the fully connected GAN discriminator (image -> realness score).

    The image is flattened and passed through Dense(512) and Dense(256)
    layers, each followed by LeakyReLU(0.2). A final single-unit sigmoid
    layer outputs the predicted probability that the image is real.

    Args:
        img_shape: Shape of one input image, e.g. ``(28, 28, 1)``.

    Returns:
        An uncompiled ``keras.Sequential`` mapping ``(batch, *img_shape)``
        images to ``(batch, 1)`` probabilities in (0, 1).
    """
    model = keras.Sequential(
        [
            # Explicit Input layer instead of ``Flatten(input_shape=...)``, a
            # legacy kwarg that Keras 3 deprecates with a warning.
            keras.Input(shape=tuple(img_shape)),
            layers.Flatten(),
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.Dense(256),
            layers.LeakyReLU(alpha=0.2),
            # Sigmoid output pairs with BinaryCrossentropy(from_logits=False).
            layers.Dense(1, activation="sigmoid"),
        ]
    )
    return model
def build_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one end-to-end model.

    Side effect: sets ``discriminator.trainable = False`` on the object passed
    in (not on a copy) before the combined model is built.

    NOTE(review): that freeze belongs to the compile()/fit() GAN recipe. In
    Keras, a model with ``trainable = False`` reports no ``trainable_weights``.
    A custom GradientTape loop that updates the discriminator after calling
    this would therefore have nothing to train. Re-enable ``trainable``
    before such updates. This script's training loop does not call build_gan.

    Args:
        generator: Model producing images from noise.
        discriminator: Model scoring images; mutated as described above.

    Returns:
        A ``keras.Sequential`` feeding the generator's output into the
        discriminator.
    """
    # Freeze first, then chain the two sub-models in order.
    discriminator.trainable = False
    stacked_models = [generator, discriminator]
    combined = keras.Sequential(stacked_models)
    return combined
# Load the MNIST training split and scale pixels from [0, 255] to [-1, 1],
# the same range as the generator's tanh output.
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 127.5 - 1.0

BATCH_SIZE = 32
dataset = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(buffer_size=1024)
    .batch(BATCH_SIZE)
    # Overlap input preparation with training.
    .prefetch(tf.data.AUTOTUNE)
)

# Create the generator and the discriminator.
latent_dim = 128
generator = build_generator(latent_dim)
discriminator = build_discriminator(x_train[0].shape)

# Loss function and one Adam optimizer per network.
loss_fn = keras.losses.BinaryCrossentropy()
generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)


@tf.function  # run the step as a graph; eager per-batch execution is much slower
def train_step(real_images):
    """Run one discriminator update, then one generator update, on a batch.

    Args:
        real_images: A batch of real images scaled to [-1, 1].

    Returns:
        ``(d_loss, g_loss)``: the discriminator loss (real + fake BCE) and the
        generator loss for this batch, as scalar tensors.
    """
    # Size labels and noise from the actual batch instead of a hard-coded 32,
    # so a smaller (partial) batch cannot cause a shape mismatch.
    batch_size = tf.shape(real_images)[0]
    real_labels = tf.ones((batch_size, 1))
    fake_labels = tf.zeros((batch_size, 1))

    # Discriminator step: push real images toward 1 and generated ones toward 0.
    noise = tf.random.normal(shape=(batch_size, latent_dim))
    # training=True so the generator's BatchNormalization layers use batch
    # statistics and update their moving averages. When called without it,
    # they run in inference mode.
    fake_images = generator(noise, training=True)
    with tf.GradientTape() as tape:
        real_loss = loss_fn(real_labels, discriminator(real_images, training=True))
        fake_loss = loss_fn(fake_labels, discriminator(fake_images, training=True))
        d_loss = real_loss + fake_loss
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    discriminator_optimizer.apply_gradients(
        zip(grads, discriminator.trainable_weights)
    )

    # Generator step: train the generator to make the discriminator output 1
    # ("real") for fresh fakes. Only generator weights are updated here.
    noise = tf.random.normal(shape=(batch_size, latent_dim))
    with tf.GradientTape() as tape:
        fake_images = generator(noise, training=True)
        g_loss = loss_fn(real_labels, discriminator(fake_images, training=True))
    grads = tape.gradient(g_loss, generator.trainable_weights)
    generator_optimizer.apply_gradients(zip(grads, generator.trainable_weights))

    return d_loss, g_loss


# Train the GAN.
epochs = 50
for epoch in range(epochs):
    for real_images in dataset:
        d_loss, g_loss = train_step(real_images)
    # Report the last batch's losses as a cheap progress signal.
    print(
        f"epoch {epoch + 1}/{epochs}  "
        f"d_loss={float(d_loss):.4f}  g_loss={float(g_loss):.4f}"
    )
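# Page tags (scraped web-page residue): Python, AI绘画 ("AI painting")
# Page footer (scraped web-page residue): 最新推荐文章于 2023-12-11 08:00:00 发布 ("latest recommended article published 2023-12-11 08:00:00")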