Gan的全称是Generative Adveratial Nets,生成对抗网络。
Generator采用随机数生成有意义的数据,Discriminator学习判定哪些是真实数据哪些是生成数据,并反向传递到Generator。
生成对抗网络接收一些信息,生成有意义的物体。
下面是示例代码:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
# 超参数
BATCH_SIZE = 64
LR_G = 0.0001 # learning rate for generator
LR_D = 0.0001 # learning rate for discriminator
N_IDEAS = 5 # think of this as numbe