莫烦pytorch(14)——GAN网络

1.画出大师作品(构造目标)

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)    # 设置种子
np.random.seed(1)

BATCH_SIZE = 64
LR_G = 0.0001           # (生成网络)
LR_D = 0.0001           # (判别网络)
N_IDEAS = 5             # 认为生成网络有五个灵感构成
ART_COMPONENTS = 15     # 15个部分

PAINT_POINTS=np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])  #水平拼接
print(PAINT_POINTS.shape)      #(64,15)
plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+1,c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+0,c='#FF9359', lw=3, label='lower bound')
plt.legend(loc="upper right")
plt.show()

在这里插入图片描述
PAINT_POINTS=np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])。其中np.vstack(a,b)是水平拼接。并且用了列表生成式,详情请看廖雪峰添加链接描述

2.定义大师作品的函数

def artist_workers():
    a=np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
    paintings=a*np.power(PAINT_POINTS,2)+(a-1)
    paintings=torch.from_numpy(paintings).float()
    return paintings

3.生成对抗网络的构建

G=nn.Sequential(
    nn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
)

D=nn.Sequential(
    nn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # tell the probability that the art work is made by artist
)

opt_D=torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()

4.交叉训练

for step in range(10000):
    artist_paintings=artist_workers()
    G_ideas=torch.rand(BATCH_SIZE,N_IDEAS)
    G_paintings=G(G_ideas)
    prob_artist0=D(artist_paintings)
    prob_artist1=D(G_paintings)
    D_loss=-torch.mean(torch.log(prob_artist0)+torch.log(1-prob_artist1))
    G_loss=torch.mean(torch.log(1-prob_artist1))
    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)
    opt_D.step()
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(),
                 fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((0, 3));
        plt.legend(loc='upper right', fontsize=10);
        plt.draw();
        plt.pause(0.01)

    plt.ioff()
    plt.show()

在这里插入图片描述
下面简述一下交叉训练的过程:先调用函数生成大师的作品,在调用G生成伪造的画,分别对这两个产生的画进行判别,然后用两个计算loss的公式进行各自求解loss,其中D_loss=-(log(D(x)) + log(1-D(G(z))),因为我们希望仿造的画更小,真实的画更大,但是深度学习中只能计算最小值,所以加一个负号。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值