A WGAN Neural Network Based on TensorFlow 2

A vanilla GAN trains poorly when the real and generated distributions have little or no overlap: the Jensen–Shannon divergence that the discriminator implicitly measures saturates, and the generator's gradients vanish. WGAN addresses this by measuring the Wasserstein distance instead, and the WGAN-GP variant enforces the required 1-Lipschitz constraint with a gradient penalty. The script below adds that gradient penalty on top of the usual sigmoid cross-entropy GAN losses.
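
In symbols, the gradient penalty from the WGAN-GP paper (Gulrajani et al., 2017) adds the following term to the discriminator loss, where the penalty points are sampled uniformly along lines between real and generated samples:

$$\mathcal{L}_{GP} = \lambda\,\mathbb{E}_{\hat{x}}\left[\left(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1\right)^2\right],\qquad \hat{x} = t\,x_{\mathrm{real}} + (1-t)\,x_{\mathrm{fake}},\quad t\sim U(0,1)$$

This is exactly what the gradient_penalty function below computes, with the coefficient λ set to 1.0 in this script.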

from GANFolder.GAN import Generator, Discriminator
import tensorflow as tf
import numpy as np
import os
from PIL import Image

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress TensorFlow C++ info/warning logs
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # run on GPU 0


def ParseData(batchData):
    # The feature spec must match the keys used when the TFRecord was written.
    feature = {
        'width': tf.io.FixedLenFeature([], tf.int64),
        'height': tf.io.FixedLenFeature([], tf.int64),
        'wmode': tf.io.FixedLenFeature([], tf.int64),
        'raw_image': tf.io.FixedLenFeature([], tf.string),
    }

    example = tf.io.parse_single_example(batchData, feature)
    # Decode the JPEG bytes and resize to 64x64; tf.image.resize returns
    # float32 values in the range [0, 255].
    raw_image_uint8 = tf.image.decode_jpeg(example['raw_image'], channels=3)
    image_tensor = tf.image.resize(raw_image_uint8, (64, 64))
    return image_tensor
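
# For reference, a record layout matching the feature spec above could have been
# serialized roughly like the sketch below. This is a hedged reconstruction — the
# actual script that wrote Duke.tfrecord is not shown in this post.
def _write_example_sketch(jpeg_bytes, width, height, wmode, writer):
    # writer is a tf.io.TFRecordWriter; the keys must match ParseData's feature dict
    example = tf.train.Example(features=tf.train.Features(feature={
        'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
        'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
        'wmode': tf.train.Feature(int64_list=tf.train.Int64List(value=[wmode])),
        'raw_image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
    }))
    writer.write(example.SerializeToString())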

batchSize = 128
EPOCH = 300000  # number of training steps; each loop iteration below consumes one batch
lr = 0.0002
isTraining = True


dataSets = tf.data.TFRecordDataset(r"E:\PycharmProjects\untitled\GANFolder\Duke.tfrecord")
# Shuffle, batch (dropping the last incomplete batch) and repeat indefinitely.
dataSets = dataSets.map(ParseData).shuffle(10000).batch(batchSize, drop_remainder=True).repeat()
dataSets_iter = iter(dataSets)
z_dim = 100  # dimensionality of the latent vector fed to the generator

generator = Generator()
generator.build(input_shape=[None, z_dim])
generator.summary()

print()

discriminator = Discriminator()
discriminator.build(input_shape=[None, 64, 64, 3])
discriminator.summary()
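
# NOTE: Generator and Discriminator are imported from GANFolder.GAN and their code
# is not shown in this post. For reference only, a minimal DCGAN-style pair that is
# compatible with the shapes used here ([None, z_dim] -> 64x64x3 images, and
# 64x64x3 -> one logit per image) might look like the classes below. These are
# assumptions, not the author's actual models, and they are never instantiated.
class ReferenceGenerator(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Project z to a 4x4x512 feature map, then upsample 4x4 -> 64x64.
        self.fc = tf.keras.layers.Dense(4 * 4 * 512)
        self.deconv1 = tf.keras.layers.Conv2DTranspose(256, 4, 2, 'same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.deconv2 = tf.keras.layers.Conv2DTranspose(128, 4, 2, 'same')
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.deconv3 = tf.keras.layers.Conv2DTranspose(64, 4, 2, 'same')
        self.bn3 = tf.keras.layers.BatchNormalization()
        self.deconv4 = tf.keras.layers.Conv2DTranspose(3, 4, 2, 'same')

    def call(self, z, training=None):
        x = tf.reshape(self.fc(z), [-1, 4, 4, 512])
        x = tf.nn.leaky_relu(self.bn1(self.deconv1(x), training=training))
        x = tf.nn.leaky_relu(self.bn2(self.deconv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.deconv3(x), training=training))
        return self.deconv4(x)  # raw logits; the sigmoid is applied at save time below


class ReferenceDiscriminator(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(64, 4, 2, 'same')
        self.conv2 = tf.keras.layers.Conv2D(128, 4, 2, 'same')
        self.conv3 = tf.keras.layers.Conv2D(256, 4, 2, 'same')
        self.flatten = tf.keras.layers.Flatten()
        self.fc = tf.keras.layers.Dense(1)  # one logit per image

    def call(self, x, training=None):
        x = tf.nn.leaky_relu(self.conv1(x))
        x = tf.nn.leaky_relu(self.conv2(x))
        x = tf.nn.leaky_relu(self.conv3(x))
        return self.fc(self.flatten(x))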


# beta_1 = 0.5 is the usual Adam setting for GAN training.
G_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)
D_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)


# Best-so-far losses, used below to decide when to checkpoint the weights.
G_LOSS = float('inf')
D_LOSS = float('inf')


def real_loss_ones(logits):
    # Binary cross-entropy against all-ones labels: "these samples are real".
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits)))


def fake_loss_zeros(logits):
    # Binary cross-entropy against all-zeros labels: "these samples are fake".
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.zeros_like(logits)))



def gradient_penalty(dis, real, fake):
    # Interpolate between real and fake images: one uniform t per sample,
    # broadcast across height, width and channels.
    t = tf.random.uniform([real.shape[0], 1, 1, 1])
    t = tf.broadcast_to(t, real.shape)
    inter = t * real + (1 - t) * fake

    with tf.GradientTape() as tape:
        tape.watch(inter)
        d_logits = dis(inter)

    # Penalize the deviation of each sample's gradient L2 norm from 1
    # (the 1-Lipschitz constraint of WGAN-GP).
    grad = tape.gradient(d_logits, inter)
    grad = tf.reshape(grad, [grad.shape[0], -1])
    l2_norm = tf.norm(grad, axis=1)
    gp = tf.reduce_mean((l2_norm - 1) ** 2)
    return gp


def d_loss_func(gen, dis, real, z, isTraining):
    fake_image = gen(z, training=isTraining)
    fake_logits = dis(fake_image, training=isTraining)
    real_logits = dis(real, training=isTraining)
    r_loss = real_loss_ones(real_logits)
    f_loss = fake_loss_zeros(fake_logits)

    # Gradient penalty, scaled by a hyperparameter that controls how strongly
    # the Lipschitz constraint is enforced (coefficient 1.0 here).
    GP = gradient_penalty(dis, real, fake_image)
    loss = r_loss + f_loss + 1.0 * GP
    return loss


def g_loss_func(gen, dis, z, isTraining):
    # The generator tries to make the discriminator label its samples as real.
    fake_image = gen(z, training=isTraining)
    logits = dis(fake_image, training=isTraining)
    loss = real_loss_ones(logits)
    return loss
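
# Note: strictly speaking, WGAN-GP (Gulrajani et al., 2017) does not use the sigmoid
# cross-entropy losses above — it optimizes the critic's raw scores directly. For
# comparison only, the Wasserstein losses would look like the sketch below (these
# functions are not called anywhere in this script):
def wgan_d_loss_sketch(real_logits, fake_logits, gp, lam=10.0):
    # The critic minimizes E[D(fake)] - E[D(real)] + lambda * GP.
    return tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits) + lam * gp


def wgan_g_loss_sketch(fake_logits):
    # The generator maximizes E[D(fake)], i.e. minimizes its negation.
    return -tf.reduce_mean(fake_logits)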


for epoch in range(EPOCH):
    real_data = next(dataSets_iter)
    # Sample a batch of latent vectors z ~ U(-1, 1).
    z = tf.random.uniform([batchSize, z_dim], minval=-1., maxval=1.)

    # The generator is updated every step; the discriminator only every 4th step.
    if epoch % 4 == 0:
        with tf.GradientTape() as d_tape:
            d_loss = d_loss_func(generator, discriminator, real_data, z, isTraining)
        grad = d_tape.gradient(d_loss, discriminator.trainable_variables)
        D_optimizer.apply_gradients(zip(grad, discriminator.trainable_variables))

        if d_loss < D_LOSS:
            D_LOSS = d_loss
            discriminator.save_weights("./discirminaterWeight/duwei.ckpt")
            print("Discriminator weights have been saved!")

    with tf.GradientTape() as g_tape:
        g_loss = g_loss_func(generator, discriminator, z, isTraining)
    grad = g_tape.gradient(g_loss, generator.trainable_variables)
    G_optimizer.apply_gradients(zip(grad, generator.trainable_variables))

    if g_loss < G_LOSS:
        G_LOSS = g_loss
        generator.save_weights("./GeneratorWeight/duwei.ckpt")
        print("Generator weights have been saved!")



    if epoch % 100 == 0:
        # Periodically generate 100 samples and report the current losses.
        sample_z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
        fake_Image = generator(sample_z, training=False)
        print("g_loss:", float(g_loss), "d_loss:", float(d_loss))

        for index, single_image in enumerate(fake_Image):
            # Squash the raw generator output to [0, 1] with a sigmoid, then
            # scale to [0, 255] uint8 for saving.
            single_image = np.uint8(tf.nn.sigmoid(single_image) * 255.)
            if index == 99:
                # Only the last of the 100 samples is written to disk here; a
                # sketch for saving all of them as one grid follows the script.
                img = Image.fromarray(single_image, 'RGB')
                img.save("./save_imgs/{a}-{b}.png".format(a=epoch, b=index))
        print("Epoch {a}'s images have been saved!".format(a=epoch))

