import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import datasets
# Latent-space note: a discriminator trained as part of a GAN can be useful
# as a defence against adversarial examples, but training one purely for
# that defence is too expensive to be practical.

# Size of the latent vector z fed to the generator.
nz = 10
# Hidden-layer width of the generator.
ngf = 20
# Hidden-layer width of the discriminator.
ndf = 20
# Learning rate shared by both Adam optimizers.
lr = 0.001

# Real data: the iris feature matrix (150 samples x 4 features) as float32.
X = torch.tensor(datasets.load_iris().data).float()
print(X.shape)
class Generator(nn.Module):
    """Map a batch of latent vectors z to fake samples in data space.

    Args:
        latent_dim: size of the input noise vector; defaults to the
            module-level ``nz``.
        hidden_dim: width of the single hidden layer; defaults to the
            module-level ``ngf``.
        out_dim: number of output features; defaults to ``X.shape[1]``
            (the iris feature count).
    """

    def __init__(self, latent_dim=None, hidden_dim=None, out_dim=None):
        super(Generator, self).__init__()
        # Fall back to the module-level hyperparameters so the original
        # zero-argument construction ``Generator()`` keeps working.
        latent_dim = nz if latent_dim is None else latent_dim
        hidden_dim = ngf if hidden_dim is None else hidden_dim
        out_dim = X.shape[1] if out_dim is None else out_dim
        self.main = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, input):
        """Return generated samples for a batch of latent vectors."""
        return self.main(input)
#print(netG)
class Discriminator(nn.Module):
    """Score samples with a probability of being real (sigmoid output in [0, 1]).

    Args:
        in_dim: number of input features; defaults to ``X.shape[1]``.
        hidden_dim: hidden-layer width; defaults to the module-level ``ndf``.
            The original code used ``ngf`` here while ``ndf`` sat unused;
            since ``ngf == ndf == 20`` the behaviour is unchanged and the
            discriminator-specific constant is restored to its intended role.
    """

    def __init__(self, in_dim=None, hidden_dim=None):
        super(Discriminator, self).__init__()
        # Fall back to module-level values so ``Discriminator()`` still works.
        in_dim = X.shape[1] if in_dim is None else in_dim
        hidden_dim = ndf if hidden_dim is None else hidden_dim
        self.main = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return per-sample realness probabilities, shape (batch, 1)."""
        return self.main(input)
netG = Generator()
netD = Discriminator()

# Binary cross-entropy over the discriminator's sigmoid output.
criterion = nn.BCELoss()
real_label = 1
fake_label = 0

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

# BCELoss targets must be float tensors with the same shape as the
# discriminator output, i.e. (N, 1).  The original used
# torch.full((N,), 1) which yields an int64 tensor of shape (N,);
# modern PyTorch rejects both the dtype and the shape mismatch.
# NOTE: the misspelled name `postive_label` is kept because the training
# loop below refers to it; likewise `fake_label` is rebound from the
# scalar 0 to the target tensor, mirroring the original script.
postive_label = torch.full((X.shape[0], 1), float(real_label))
fake_label = torch.full((X.shape[0], 1), float(fake_label))
print(postive_label)
print(fake_label)
# Alternating GAN training: one discriminator step, then one generator step,
# over the full iris batch each iteration.
for i in range(1000):
    # --- Discriminator update: maximize log D(x) + log(1 - D(G(z))) ---
    netD.zero_grad()
    # Real batch: push D's output toward the "real" target.
    output = netD(X)
    errD_real = criterion(output, postive_label)
    errD_real.backward()
    D_x = output.mean().item()  # mean D score on real data (diagnostic)
    # Fake batch: fresh noise through G; detach so no gradient reaches G here.
    noise = torch.randn(X.shape[0], nz)
    fake = netG(noise)
    output = netD(fake.detach())
    errD_fake = criterion(output, fake_label)
    errD_fake.backward()
    D_G_z1 = output.mean().item()  # mean D score on fakes, before D's step
    errD = errD_real + errD_fake  # total discriminator loss (diagnostic)
    optimizerD.step()
    # Trick: reuse `fake` WITHOUT detaching so gradients flow into G,
    # and train G to make D emit the "real" label for its samples.
    netG.zero_grad()
    output = netD(fake)
    errG = criterion(output, postive_label)
    errG.backward()
    D_G_z2 = output.mean().item()  # mean D score on fakes, after D's step
    optimizerG.step()
print(errD_fake)
# Generate one new sample from a single latent vector.
# The original drew a full (X.shape[0], nz) noise batch and kept only
# row 0, wasting the other 149 rows; a single 1-D draw of size nz
# produces output of the same shape.
noise = torch.randn(nz)
print(netG(noise))