GAN_loss的构建

GAN_loss的构建.

tensorflow:参考链接:https://blog.csdn.net/qiu931110/article/details/80181212

pytorch参考链接:https://blog.csdn.net/u011961856/article/details/78697863

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
def train_gan(generator, discriminator, gan, dataset, latent_dim, epochs): notes = get_notes() # 得到所有不重复的音调数目 num_pitch = len(set(notes)) network_input, network_output = prepare_sequences(notes, num_pitch) model = build_gan(network_input, num_pitch) # 输入,音符的数量,训练后的参数文件(训练的时候不用写) filepath = "03weights-{epoch:02d}-{loss:.4f}.hdf5" checkpoint = tf.keras.callbacks.ModelCheckpoint( filepath, # 保存参数文件的路径 monitor='loss', # 衡量的标准 verbose=0, # 不用冗余模式 save_best_only=True, # 最近出现的用monitor衡量的最好的参数不会被覆盖 mode='min' # 关注的是loss的最小值 ) for epoch in range(epochs): for real_images in dataset: # 训练判别器 noise = tf.random.normal((real_images.shape[0], latent_dim)) fake_images = generator(noise) with tf.GradientTape() as tape: real_pred = discriminator(real_images) fake_pred = discriminator(fake_images) real_loss = loss_fn(tf.ones_like(real_pred), real_pred) fake_loss = loss_fn(tf.zeros_like(fake_pred), fake_pred) discriminator_loss = real_loss + fake_loss gradients = tape.gradient(discriminator_loss, discriminator.trainable_weights) discriminator_optimizer.apply_gradients(zip(gradients, discriminator.trainable_weights)) # 训练生成器 noise = tf.random.normal((real_images.shape[0], latent_dim)) with tf.GradientTape() as tape: fake_images = generator(noise) fake_pred = discriminator(fake_images) generator_loss = loss_fn(tf.ones_like(fake_pred), fake_pred) gradients = tape.gradient(generator_loss, generator.trainable_weights) generator_optimizer.apply_gradients(zip(gradients, generator.trainable_weights)) gan.fit(network_input, np.ones((network_input.shape[0], 1)), epochs=100, batch_size=64) # 每 10 个 epoch 打印一次损失函数值 if (epoch + 1) % 10 == 0: print("Epoch:", epoch + 1, "Generator Loss:", generator_loss.numpy(), "Discriminator Loss:", discriminator_loss.numpy())
最新发布
06-01

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值