GAN学习历程(2)

DCGAN (Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks)

对卷积神经网络的结构做了一些改变,以提高样本的质量和收敛的速度。

  1. 取消所有pooling G中使用反卷积上采样 D中用加入stride卷积 代替 pooling
  2. 去掉FC层 全卷积网络
  3. G用ReLU,最后一层用tanh
  4. D使用LReLU
  5. 在G和D上都是用batchnorm:解决初始化差的问题;帮助梯度传播到每一层;防止G把所有样本都收敛到一个点;在G的输出层和D的输入层不采用BN为了防止样本震荡和模型不稳定。
  6. 使用Adam(最优的优化器)beat1 = 0.5
    由简单的GAN可以改出DCGAN,只需要更改model内容即可
#导入数据
(train_image, train_labels),(_,_) = tf.keras.dataset.minst.loaddata()#不需要测试数据,用占位符
# train_image.shape  (60000,28,28)
#train_images.dtype  dtype(‘uint8’) 数据类型要利用归一化落到0的周围,因为激活函数
#能在0的周围发挥最大的作用
train_image = train_image.dtype(‘float32’)#将图片类型uint8转化为float型
(train_image – 127.5)/127.5 #使训练图像取值范围落到[-1,1]
BATCH_SIZE = 256
BUFFER_SIZE = 60000 #一次取多少张/*
#输入管道
datasets = tf.data.Dataset.from_tensor_slices(train_image) # 从rain_image 中创建数据集
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 
定义生成器model
def generator_model():
model = tf.keras.Sequential()#创建顺序模型
#第一层
model.add(layers.Dense(7*7*256,input_shape=(100,), use_bias=False))#长度为100的向量
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())#使用内置激活函数
model.add(layers.Reshape((7,7,256)))  #到这里图片大小为7*7*256
#上采样和反卷积不一样
model.add(layer,conv2DTranspose(128,(5,5),strides = (1,1) , padding = ‘same’,use_bias=False))
model.add(layers.batchnormalization())
model.add(layers.LeakyReLU())  #到这里图片大小为7*7*128 因为stride为1*1所以图像没有变大
model.add(layer,conv2DTranspose(64,(5,5),stride = (2,2) , padding = ‘same’,use_bias=False))
model.add(layers.batchnormalization())
model.add(layers.LeakyReLU())  #到这里图片大小为14*14*128 stride 为2*2,图像扩大*2
#卷积是channel变厚变小,反卷积是变薄变大
model.add(layer,conv2DTranspose(
1,(5,5),stride = (2,2) , padding = ‘same’,use_bias=False,activation = ‘tanh’))#28*28*1
Return model

定义辨别器model
def discriminator_model():
model = keras.Sequential()

model.add(layers.conv2D(64,(5,5),stride=(2,2),padding = ‘same’,input_shape(28,28,1))) 
model.add(layers.LeakyReLU())#14*14*64
Model.add(layer.Dropout(0.3)) #通用:创建GAN小技巧:辨别器不需要这么好

model.add(layers.conv2D(128,(5,5),stride=(2,2),padding = ‘same’)) 
model.add(layers.LeakyReLU())#7*7*128
Model.add(layer.Dropout(0.3)) 

model.add(layers.conv2D(256,(5,5),stride=(2,2),padding = ‘same’)) 
model.add(layers.LeakyReLU())#3*3*256

model.add(layers.Flatten())

model.add(layers.Dense(1)) 判断是真是假
return model

Loss函数
cross_entropy= tf.keras.losses.BinaryCrossentropy(from_logits=True)#使用内置损失函数
判别器Loss
def discriminator_loss(real_output,fake_ output):
real_loss = Cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = Cross_entropy(tf.zeros_like(fake_output), fake_output)
return Real_loss + fake_loss
生成器Loss
def generator_loss(fake_output):#生成器只接收fake output 输入的不是一张图片
return Cross_entropy(tf.ones_like(fake_output), fake_output) #希望我们的图片都能为真
#优化器定义
generator_opt = tf.keras.optimizers.Adam(1e-4)#学习率为0.0001
discriminator_opt = tf.keras.optimizers.Adam(1e-4)
#参数设置
EPOCHS = 100
noise_dim = 100
num_exp_to_generate = 16      #观察生成过程 用种子
Seed = tf.random.normal ([num_exp_to_generate, noise_dim])  

discriminator = discriminator_model()
generator = generator_model()#执行这两个函数,返回两个model

def train_step(images):
noise = tf.random.normal([BTACH_SIZE, noise_dim])  
with tf.gradientTape() as gen_tape, With tf.gradientTape() as disc_tape,      #记录梯度信息
real_out = discriminator(images, training = True) 
gen_image = generator(noise,training = true)#输入noise 产生一张图片
fake_out = discriminator(images, training = True) 

gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output,fake_output)
#计算梯度
gradient_gen = gen_tape_gradient(gen_loss,generator.trainable_variables)
gradient_gen = disc_tape_gradient(disc_loss,discriminator.trainable_variables)
#优化函数
generator_opt.apply_gradient(zip(gradient_gen, generator.trainable_variables))
discriminator_opt.apply_gradient(zip(gradient_disc, discriminator.trainable_variables))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值