Python !AI绘画

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# 定义生成器模型
def build_generator(latent_dim):
    model = keras.Sequential(
        [
            layers.Dense(256, input_dim=latent_dim),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Dense(1024),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Dense(28 * 28, activation="tanh"),
            layers.Reshape((28, 28, 1)),
        ]
    )
    return model

# 定义鉴别器模型
def build_discriminator(img_shape):
    model = keras.Sequential(
        [
            layers.Flatten(input_shape=img_shape),
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.Dense(256),
            layers.LeakyReLU(alpha=0.2),
            layers.Dense(1, activation="sigmoid"),
        ]
    )
    return model

# 定义GAN模型
def build_gan(generator, discriminator):
    discriminator.trainable = False
    model = keras.Sequential([generator, discriminator])
    return model

# 加载MNIST数据集
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 127.5 - 1.0
dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(buffer_size=1024).batch(32)

# 创建生成器和鉴别器
latent_dim = 128
generator = build_generator(latent_dim)
discriminator = build_discriminator(x_train[0].shape)

# 定义优化器和损失函数
loss_fn = keras.losses.BinaryCrossentropy()
generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

# 训练GAN模型
epochs = 50
for epoch in range(epochs):
    for real_images in dataset:
        # 训练鉴别器
        noise = tf.random.normal(shape=(32, latent_dim))
        fake_images = generator(noise)
        real_labels = tf.ones((32, 1))
        fake_labels = tf.zeros((32, 1))
        with tf.GradientTape() as tape:
            real_loss = loss_fn(real_labels, discriminator(real_images))
            fake_loss = loss_fn(fake_labels, discriminator(fake_images))
            total_loss = real_loss + fake_loss
        grads = tape.gradient(total_loss, discriminator.trainable_weights)
        discriminator_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))
        
        # 训练生成器
        noise = tf.random.normal(shape=(32, latent_dim))
        with tf.GradientTape() as tape:
            fake_images = generator(noise)
            fake_loss = loss_fn(real_labels, discriminator(fake_images))
        grads = tape.gradient(fake_loss, generator.trainable_weights)
        generator_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值