TensorFlow 11——ch08-GAN和DCGAN入门

在这里插入图片描述
代码:https://github.com/MONI-JUAN/Tensorflow_Study/ch08-GAN和DCGAN入门

一、基本概念

GAN 的全称为 Generative Adversarial Networks,意为对抗生成网络。

DCGAN 将 GAN 的概念扩展到卷积神经网络中,可以生成质量较高的图片样本 。

1.GAN 的原理

有两个网络,一个是生成网络G(Generator),一个是判别网络D(Discriminator)

  • G:通过噪声z生成图片,记作 G(z) ;
  • D:判断图片是不是”真实的“,输入的x,输出 D(x) 代表是真实图片的概率

训练过程:G尽量生成真实图片去欺骗D,D尽量区分G生成的图片和真实图片。

2.交叉熵损失

V ( D , G ) = E x ∼ P data  ( x ) [ ln ⁡ D ( x ) ] + E z ∼ p z ( z ) [ ln ⁡ ( 1 − D ( G ( z ) ) ) ] V(D, G)=E_{x \sim P_{\text {data }}(x)}[\ln D(x)]+E_{z \sim p_{z}(z)}[\ln (1-D(G(z)))] V(D,G)=ExPdata (x)[lnD(x)]+Ezpz(z)[ln(1D(G(z)))]

  • 左边x部分代表真实图片,右边G(z)是生成的图片;
  • D(x) 和 D(G(z)) 都是判断的概率;
  • 生成网络 G 希望 D(G(z)) 变大,V(D, G)越大越好;
  • 判别网络 D 希望 D(x) 变大,V(D, G)越小越好;

3.DCGAN的原理

DCGAN 的全称是 Deep Convolutional Generative Adversarial Networks,即深度卷积对抗生成网络。从名字上来看,是在 GAN 的基础上增加深度卷积网络结构,专门生成图像样本。

事实上,GAN 并没再对D、 G 的具体结构做出任何限制 。DCGAN 中的 D、 G 的含义以及损失都和原始 GAN 中完全一,但是它在 D 和 G 中采用了较为特殊的结构,以便对图片进行高效建模。

DCGAN 中 G 的网络结构:

在这里插入图片描述

  • 不采用池化层,D中用补偿(stride)的卷积代替池化;
  • 在 G、 D 中均使用 Batch Normalization 帮助模型收敛。
  • 在 G 中,激活函数除了最后一层都使用 ReLU 函数,而最后一层使用 tanh 函数。
  • 在 D 中,激活函数都使用 Leaky ReLU 作为激活函数。

请添加图片描述

请添加图片描述
请添加图片描述

二、生成MNIST图像

1.下载数据集

用脚本下载(可能会下载失败,我也不知道为什么每次都失败)

python download.py mnist

或者百度云

链接:https://pan.baidu.com/s/1l-IHrXYvt4M8kj_C-Blklw
提取码:kgrw

这个数据集和chapter 01 的一样:https://blog.csdn.net/qq_34451909/article/details/108264641

在这里插入图片描述

2.训练

python main.py --dataset mnist --input_height=28 --output_height=28 --train

在这里插入图片描述
在这里插入图片描述

3.训练结果

每过100步会保存一张当前训练情况的图

在这里插入图片描述

对比一下 0_99 和 1_106,才训练了一千步左右,已经很有数字的样子了。

在这里插入图片描述

看一下书中25个epoch,也就是2.5w步之后的图像:

在这里插入图片描述

三、使用自己的数据集训练

1.下载数据集

faces.zip

链接:https://pan.baidu.com/s/1l-IHrXYvt4M8kj_C-Blklw
提取码:kgrw

解压faces.zip ,把 anime放进 data 目录

2.训练模型

python main.py --input_height 96 --input_width 96 \ # 截取中心96*96
  --output_height 48 --output_width 48 \ # 缩放到48*48
  --dataset anime --crop -–train \ # 需要执行训练
  --epoch 300 --input_fname_pattern "*.jpg" # 找出所有.jpg训练
python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset anime --crop -–train --epoch 300 --input_fname_pattern "*.jpg"

这已经是训练3.7小时候的结果了,电脑太渣了

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

对比训练模型:

  • 如果是 mnist 数据集:

    if config.dataset == 'mnist':
    # Update D network
    _, summary_str = self.sess.run([d_optim, self.d_sum],
    feed_dict={ 
      self.inputs: batch_images,
      self.z: batch_z,
      self.y:batch_labels,
    })
    self.writer.add_summary(summary_str, counter)
    
    # Update G network
    _, summary_str = self.sess.run([g_optim, self.g_sum],
    feed_dict={
      self.z: batch_z, 
      self.y:batch_labels,
    })
    self.writer.add_summary(summary_str, counter)
    
    # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
    _, summary_str = self.sess.run([g_optim, self.g_sum],
    feed_dict={ self.z: batch_z, self.y:batch_labels })
    self.writer.add_summary(summary_str, counter)
    
    errD_fake = self.d_loss_fake.eval({
      self.z: batch_z, 
      self.y:batch_labels
    })
    errD_real = self.d_loss_real.eval({
      self.inputs: batch_images,
      self.y:batch_labels
    })
    errG = self.g_loss.eval({
      self.z: batch_z,
      self.y: batch_labels
    })
    
  • 如果是其他数据:

    else:
      # Update D network
      _, summary_str = self.sess.run([d_optim, self.d_sum],
    	feed_dict={ self.inputs: batch_images, self.z: batch_z })
      self.writer.add_summary(summary_str, counter)
    
      # Update G network
      _, summary_str = self.sess.run([g_optim, self.g_sum],
    	feed_dict={ self.z: batch_z })
      self.writer.add_summary(summary_str, counter)
    
      # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
      _, summary_str = self.sess.run([g_optim, self.g_sum],
    	feed_dict={ self.z: batch_z })
      self.writer.add_summary(summary_str, counter)
      
      errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
      errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
      errG = self.g_loss.eval({self.z: batch_z})
    

3.测试模型

python main.py --input_height 96 --input_width 96 \
 --output_height 48 --output_width 48 \
 --dataset anime --crop
python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset anime --crop

main.py中的 OPTION 可以设置 0-4 ,在 utils.py 中的函数 visualize() 中可以看到不同的可视化选项,可以自己设置这个OPTION

# Below is codes for visualization
OPTION = 0
visualize(sess, dcgan, FLAGS, OPTION)

4.测试效果

因为默认都是生成到samples这个文件夹,比较乱,我改了一下路径,生成到五个文件夹。

又因为模型训练的程度不够,才一千多不就已经训练了两个半小时了,只能凑合看看。

OPTION = 0:用模型生成一张10*10的图片
OPTION = 1:生成100张10*10的图片,都差不多样子
OPTION = 2:生成100张10*10的图片,都差不多样子
OPTION = 3:生成100张10*10的图片组成的动画
OPTION = 4:生成100张10*10的图片组成的动画,最后汇合到一个gif
  • OPTION = 0

在这里插入图片描述

  • OPTION = 1
    在这里插入图片描述

  • OPTION = 2

    在这里插入图片描述

  • OPTION = 3

    在这里插入图片描述

  • OPTION = 4

在这里插入图片描述

随便放了个gif上来看

在这里插入图片描述

好吧,训练的太少看不出效果

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,这个问题可以用汇编语言来解决。下面是一个可能的解决方案,仅供参考。 首先,我们需要定义一个数组,用于存储每个月份的缩写。可以用如下方式定义: ``` months db 'Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec' ``` 接下来,我们需要获取用户输入的数字,可以使用 DOS 中断 21h 中的功能来实现。具体步骤如下: 1. 调用中断 21h 的 01h 功能,显示提示信息 "Please enter a number (1-12): ",并等待用户输入。 2. 调用中断 21h 的 0Ah 功能,将用户输入的字符串读入到一个缓冲区中。 3. 将缓冲区中的字符串转换为数字,并保存到一个寄存器中。 代码如下: ``` mov ah, 09h ; 01h 功能:显示字符串 lea dx, prompt ; 提示信息 int 21h mov ah, 0Ah ; 0Ah 功能:输入字符串 lea dx, buffer ; 缓冲区 int 21h mov al, buffer[2] ; 第三个字符是用户输入的数字 sub al, '0' ; 将 ASCII 码转换为数字 ``` 接下来,我们需要检查用户输入的数字是否在 1-12 的范围内。如果不是,则显示错误信息并退出程序。否则,将对应月份的缩写输出到屏幕上。 代码如下: ``` cmp al, 1 jl error cmp al, 12 jg error mov dl, months[al-1] ; 获取对应月份的缩写 mov ah, 02h ; 02h 功能:显示字符 int 21h mov dl, months[al] int 21h mov dl, months[al+1] int 21h jmp exit error: mov ah, 09h lea dx, errmsg int 21h exit: mov ah, 4Ch ; 4Ch 功能:退出程序 int 21h prompt db 'Please enter a number (1-12): $' buffer db 3, 0 errmsg db 'Invalid input!$' ``` 完整代码如下: ``` .model tiny .code org 100h start: mov ah, 09h ; 01h 功能:显示字符串 lea dx, prompt ; 提示信息 int 21h mov ah, 0Ah ; 0Ah 功能:输入字符串 lea dx, buffer ; 缓冲区 int 21h mov al, buffer[2] ; 第三个字符是用户输入的数字 sub al, '0' ; 将 ASCII 码转换为数字 cmp al, 1 jl error cmp al, 12 jg error mov dl, months[al-1] ; 获取对应月份的缩写 mov ah, 02h ; 02h 功能:显示字符 int 21h mov dl, months[al] int 21h mov dl, months[al+1] int 21h jmp exit error: mov ah, 09h lea dx, errmsg int 21h exit: mov ah, 4Ch ; 4Ch 功能:退出程序 int 21h prompt db 'Please enter a number (1-12): $' buffer db 3, 0 errmsg db 'Invalid input!$' months db 'Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec' end start ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值