# 首先定义 generator 和 discriminator 模型
import torch
import torch.nn as nn
class Generator(nn.Module):
    """Generator that maps a noise vector to a single-channel 28x28 image.

    The network first projects the noise through two fully connected layers
    up to ``7 * 7 * 128`` features, reshapes that to a ``128 x 7 x 7`` feature
    map, and then upsamples with two stride-2 transposed convolutions
    (7x7 -> 14x14 -> 28x28). The final ``Tanh`` bounds the output to [-1, 1].

    Args:
        noise_size: Number of features in each input noise vector.
    """

    def __init__(self, noise_size):
        super().__init__()
        # Fully connected stem: noise -> 1024 -> 7*7*128 features.
        self.fc = nn.Sequential(
            nn.Linear(noise_size, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128),
        )
        # Upsampling head: each ConvTranspose2d(kernel=4, stride=2, padding=1)
        # doubles the spatial size.
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: Tensor of shape ``(batch, noise_size)``.

        Returns:
            Tensor of shape ``(batch, 1, 28, 28)`` with values in [-1, 1].
        """
        x = self.fc(x)
        # Reshape the flat features to 128 channels of size 7x7.
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x
class Discriminator(nn.Module):
    """Convolutional discriminator / critic producing one raw score per image.

    Two ``Conv2d(kernel=5) -> LeakyReLU -> MaxPool2d(2)`` stages are followed
    by a two-layer fully connected head. The first linear layer expects 1024
    flattened features, which is what the convolutional stack yields for
    28x28 inputs (28 -> 24 -> 12 -> 8 -> 4, times 64 channels).

    ``forward`` always returns the unnormalised score: ``self.sigmoid`` is
    constructed but never applied, and the ``wgan`` flag is only stored.
    Pair the output with a logits-based loss (e.g. ``BCEWithLogitsLoss``) for
    the standard GAN, or use it directly as the critic value for WGAN.

    Args:
        input_size: Number of input image channels.
        wgan: Stored on the instance as ``self.wgan``; not read by ``forward``.
    """

    def __init__(self, input_size=1, wgan=False):
        super().__init__()
        self.wgan = wgan
        self.conv = nn.Sequential(
            nn.Conv2d(input_size, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1),
        )
        # Kept for backward compatibility; intentionally not used in forward().
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor of shape ``(batch, input_size, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding raw (pre-sigmoid) scores.
        """
        x = self.conv(x)
        # Flatten everything except the batch dimension.
        x = x.flatten(1)
        x = self.fc(x)
        return x
# 定义计算 gradient penalty 的函数,以及 label 生成函数
def gen_label(batch_size, device=None):
    """Build the "real" and "fake" target labels for one batch.

    Args:
        batch_size: Number of samples in the batch; 0 yields empty tensors.
        device: Optional device on which to allocate the labels. ``None``
            uses torch's default device.

    Returns:
        A ``(real, fake)`` tuple of float tensors with shape
        ``(batch_size, 1)``; ``real`` is all ones and ``fake`` is all zeros.
    """
    # Allocate directly in the final (batch_size, 1) shape: reshaping an empty
    # 1-D tensor with view(-1, 1) is ambiguous and raises for batch_size == 0.
    real = torch.ones(batch_size, 1, device=device)
    fake = torch.zeros(batch_size, 1, device=device)
    return real, fake
def cal_gradient_penalty(disc_net, device, real, fake):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolates.

    For each sample a coefficient ``alpha ~ U[0, 1)`` is drawn and the critic
    is evaluated at ``alpha * real + (1 - alpha) * fake``. The penalty is the
    batch mean of ``(||grad||_2 - 1) ** 2``, where the L2 norm of the critic's
    gradient with respect to the interpolate is taken per sample over all
    non-batch dimensions.

    Args:
        disc_net: Callable critic mapping a batch to per-sample scores.
        device: Device on which the mixing coefficients are created; should
            match the device of ``real`` and ``fake``.
        real: Batch of real samples, shape ``(batch, ...)``.
        fake: Batch of generated samples, same shape as ``real``.

    Returns:
        Scalar tensor holding the penalty. The gradient is computed with
        ``create_graph=True`` so the penalty can itself be backpropagated
        through the critic's parameters.
    """
    batch_size = real.size(0)
    # One coefficient per sample, broadcast across every remaining dimension.
    alpha_shape = [batch_size] + [1] * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)

    # Detach so the interpolate is a leaf tensor we can differentiate against.
    interpolates = (alpha * real + (1 - alpha) * fake).detach()
    interpolates.requires_grad_(True)
    disc_interpolates = disc_net(interpolates)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten first: taking norm(dim=1) on an image-shaped gradient would only
    # reduce over channels instead of giving one norm per sample.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
# 定义数据集和数据预处理;模型可选 gan、wgan、wgan-gp
# Seed NumPy's and PyTorch's global random generators with the same fixed
# value so that runs are repeatable.
_SEED = 23
np.random.seed(_SEED)
torch.manual_seed(_SEED)
def preprocess_img(x):
    """Convert an input image to a tensor and rescale it around zero.

    The image is passed through ``transforms.ToTensor()`` and the result is
    mapped with ``(x - 0.5) / 0.5``, i.e. 0 -> -1 and 1 -> 1. Assuming
    ``ToTensor`` yields values in [0, 1], the output lies in [-1, 1], which
    matches the range of the generator's ``Tanh`` output.

    Args:
        x: Image in any format accepted by ``transforms.ToTensor``.

    Returns:
        The rescaled image tensor.
    """
    x = transforms.ToTensor()(x)
    return (x - 0.5) / 0.5
def main():
    """Train a GAN, WGAN or WGAN-GP on MNIST and plot progress in visdom.

    The variant is chosen by the local ``model`` string: ``'wgan'`` (weight
    clipping), ``'wgan-gp'`` (gradient penalty); any other value selects the
    standard GAN trained with ``BCEWithLogitsLoss``. Each iteration performs
    one discriminator update followed by one generator update, and appends
    the per-step losses to visdom line plots. Every 100 batches the current
    real and generated images are pushed to visdom as well.

    Both learning rates decay exponentially (gamma 0.9) once per epoch via
    ``scheduler.step()``.
    """
    epochs = 50
    clamp_lower = -0.01
    clamp_upper = 0.01
    gp_weight = 0.1  # coefficient applied to the gradient penalty term
    global_step = 0
    viz = visdom.Visdom()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    batch_size = 64

    # Data pipeline.
    # os.cpu_count() may return None; fall back to 0 (load in the main
    # process) instead of letting min() fail on a None comparison.
    nw = min([os.cpu_count() or 0, batch_size if batch_size > 1 else 0, 8])  # number of workers
    transform = preprocess_img
    mnist_train = datasets.MNIST('./mnist', train=True, transform=transform, download=False)
    train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=nw)
    # NOTE: the test loader is built but not consumed by the loop below.
    mnist_test = datasets.MNIST('./mnist', train=False, transform=transform, download=False)
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=nw)
    train_steps = len(train_loader)

    # Noise.
    noise_size = 96
    # NOTE: fixed noise is created but not used by the loop below.
    fix_noise = torch.empty((batch_size, noise_size), dtype=torch.float32).uniform_(-1, +1).to(device)

    # Model selection (hard-coded here): 'wgan', 'wgan-gp', anything else -> GAN.
    model = 'wgan-gp'
    is_wasserstein = model in ('wgan', 'wgan-gp')
    gen_net = Generator(noise_size).to(device)
    disc_net = Discriminator(input_size=1, wgan=is_wasserstein).to(device)

    # Define the loss function and the optimizers; `viz` is the visdom object
    # used to display the training losses live.
    criterion = nn.BCEWithLogitsLoss()
    optimizer_D = optim.Adam(params=disc_net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    optimizer_G = optim.Adam(params=gen_net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    # Learning rate decay.
    scheduler_D = ExponentialLR(optimizer_D, gamma=0.9)
    scheduler_G = ExponentialLR(optimizer_G, gamma=0.9)

    viz.line([0.], [0], win='real loss', opts=dict(title='real image'))
    viz.line([0.], [0], win='fake loss', opts=dict(title='fake image'))
    viz.line([0.], [0], win='discriminator_loss', opts=dict(title='discriminator loss'))
    viz.line([0.], [0], win='generator_loss', opts=dict(title='generator loss'))

    # Training loop.
    for epoch in range(epochs):
        G_epochloss = 0.0
        D_epochloss = 0.0
        # Neither network is ever switched to eval mode, so setting train
        # mode once per epoch is sufficient.
        gen_net.train()
        disc_net.train()
        train_bar = tqdm(train_loader, file=sys.stdout)  # progress bar over the batches
        for index, (real_img, _) in enumerate(train_bar):
            real_img = real_img.to(device)
            # The last batch may be smaller than the configured batch size.
            cur_batch = real_img.size(0)
            # Real images are labelled one, generated images zero.
            reallabel, fakelabel = gen_label(cur_batch)
            reallabel, fakelabel = reallabel.to(device), fakelabel.to(device)
            # Uniform noise in [-1, 1).
            noise = (torch.rand(cur_batch, noise_size) - 0.5) / 0.5
            noise = noise.to(device)

            # ---- (1) Update D: maximize log(D(x)) + log(1 - D(G(z))) ----
            # WGAN clips the critic's parameters to a fixed range [-c, c].
            if model == 'wgan':
                for p in disc_net.parameters():
                    p.data.clamp_(clamp_lower, clamp_upper)
            disc_net.zero_grad()

            # The discriminator step must not backpropagate into the
            # generator, so the fake batch is produced without a graph.
            with torch.no_grad():
                fake = gen_net(noise)

            # All variants minimise the negated objective (max A <=> min -A).
            if is_wasserstein:
                # The critic has no sigmoid, so the expectations of its raw
                # output are used directly instead of a loss function.
                D_Loss_real = disc_net(real_img).mean()
                D_Loss_fake = disc_net(fake).mean()
                D_Loss = -(D_Loss_real - D_Loss_fake)
                if model == 'wgan-gp':
                    # WGAN-GP adds a gradient penalty on top of the WGAN loss.
                    gradient_penalty = cal_gradient_penalty(disc_net, device, real_img.data, fake.data)
                    D_Loss = D_Loss + gradient_penalty * gp_weight
            else:
                D_Loss_real = criterion(disc_net(real_img), reallabel)
                D_Loss_fake = criterion(disc_net(fake), fakelabel)
                D_Loss = D_Loss_real + D_Loss_fake
            D_Loss.backward()
            D_epochloss += D_Loss.item()
            optimizer_D.step()

            # ---- (2) Update G: the minimisation half of the min-max game ----
            gen_net.zero_grad()
            # Reuse the same noise, this time keeping the graph for G.
            fake = gen_net(noise)
            if is_wasserstein:
                G_Loss = -disc_net(fake).mean()
            else:
                G_Loss = criterion(disc_net(fake), reallabel)
            G_Loss.backward()
            G_epochloss += G_Loss.item()
            optimizer_G.step()

            global_step += 1
            viz.line([D_Loss_real.item()], [global_step], win='real loss', opts=dict(title='real image'), update='append')
            viz.line([D_Loss_fake.item()], [global_step], win='fake loss', opts=dict(title='fake image'), update='append')
            viz.line([D_Loss.item()], [global_step], win='discriminator_loss', opts=dict(title='discriminator loss'), update='append')
            viz.line([G_Loss.item()], [global_step], win='generator_loss', opts=dict(title='generator loss'), update='append')
            if index % 100 == 0:
                viz.images(real_img, nrow=16, win='real_image', opts=dict(title='real image'))
                viz.images(fake.detach(), nrow=16, win='fake image', opts=dict(title='fake image'))

        # Passing the epoch index to step() is deprecated; a plain step()
        # after each epoch applies one decay of gamma.
        scheduler_D.step()
        scheduler_G.step()
        print("%d / %d discriminator loss is %.3f" % (epoch + 1, epochs, D_epochloss / train_steps))
if __name__ == '__main__':
    # Start training only when the file is executed as a script.
    main()