PyTorch Learning 11: GAN

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 32
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,),   # map pixels from [0, 1] to [-1, 1]
                         std=(0.5,))
])
mnist_data = torchvision.datasets.MNIST("./mnist_data", train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset=mnist_data,
                                         batch_size=batch_size,
                                         shuffle=True)
image_size = 784   # 28 x 28 images flattened into vectors
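As a quick sanity check (a sketch of my own, not from the original post), you can pull one batch and confirm the shapes and the value range produced by Normalize:

images, labels = next(iter(dataloader))
print(images.shape)                              # torch.Size([32, 1, 28, 28])
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0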

hidden_size = 256
# Discriminator: flattened 784-dim image -> probability that the image is real
D = nn.Sequential(
    nn.Linear(image_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size,1),
    nn.Sigmoid()
)
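A quick shape check (my own sketch, using a random tensor as a stand-in for a batch of flattened images): the discriminator maps each 784-dim input to a single probability in (0, 1).

with torch.no_grad():
    probs = D(torch.randn(4, image_size))
print(probs.shape)   # torch.Size([4, 1])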

latent_size = 64
# Generator: latent noise vector -> flattened 784-dim image
G = nn.Sequential(
    nn.Linear(latent_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh()
)
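Note that G ends in Tanh, so its outputs lie in [-1, 1], matching the Normalize(mean=(0.5,), std=(0.5,)) transform applied to the real images above. A minimal sketch for mapping generated pixels back to [0, 1] when you want to display them:

with torch.no_grad():
    fake = G(torch.randn(1, latent_size))   # values in [-1, 1] because of Tanh
display = (fake + 1) / 2                    # rescale to [0, 1] for plotting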

D = D.to(device)
G = G.to(device)

loss_fn = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
def reset_grad():
    # Zero the gradients of both optimizers before each backward pass
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()
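For reference (standard GAN background, summarized here rather than taken from the original post): with Sigmoid outputs and BCELoss, the training loop below optimizes the classic GAN minimax objective

\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]

The discriminator's loss splits into a real term and a fake term, and the generator uses the usual non-saturating variant, labeling its fakes as real instead of minimizing \log(1 - D(G(z))):

d\_loss = -\log D(x) - \log(1 - D(G(z))), \qquad g\_loss = -\log D(G(z))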

total_step = len(dataloader)   # number of batches per epoch
print(total_step)
num_epochs = 200
for epoch in range(num_epochs):
    for i,(images,_) in enumerate(dataloader):
        batch_size = images.size(0)   # the last batch may be smaller
        images = images.reshape(batch_size, image_size).to(device)   # flatten 28x28 -> 784
        
        real_labels = torch.ones(batch_size,1).to(device)
        fake_labels = torch.zeros(batch_size,1).to(device)
        
        # Discriminator loss on real images (target label 1)
        outputs = D(images)
        d_loss_real = loss_fn(outputs, real_labels)
        real_score = outputs
        
        # Generate fake images from random noise z
        z = torch.randn(batch_size,latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        # detach() cuts fake_images out of G's computation graph, so the backward
        # pass of d_loss_fake only produces gradients for D, not for G.
        # Simply put, detach truncates the gradient flow; G gets its own update below.
        d_loss_fake = loss_fn(outputs,fake_labels)
        fake_score = outputs
        
        # Update the discriminator on the combined real + fake loss
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # Update the generator: make D classify fresh fakes as real (non-saturating loss)
        z = torch.randn(batch_size,latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = loss_fn(outputs,real_labels)
        
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if i % 1000 == 0:
            print("Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}"
                 .format(epoch, num_epochs, i, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))
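The detach() call above is easy to verify with a standalone sketch (my own check, reusing the setup and the last batch's fake_labels): after backpropagating a loss on detached fake images, no gradient reaches the generator, while the discriminator does receive one.

z = torch.randn(batch_size, latent_size).to(device)
loss = loss_fn(D(G(z).detach()), fake_labels)
reset_grad()                                  # clear leftover gradients first
loss.backward()
g_grad = G[0].weight.grad
print(g_grad is None or g_grad.abs().sum().item() == 0.0)  # True: G untouched
print(D[0].weight.grad.abs().sum().item() > 0.0)           # True: D got gradients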

(Figure: training log output)

# Sample one image from the trained generator and display it
z = torch.randn(1, latent_size).to(device)
fake_images = G(z).view(28, 28).detach().cpu().numpy()
plt.imshow(fake_images, cmap='gray')
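To look at more than one sample at a time (a sketch of my own; make_grid comes from torchvision.utils), denormalize a batch of fakes from [-1, 1] back to [0, 1] and plot them as a grid:

from torchvision.utils import make_grid

z = torch.randn(64, latent_size).to(device)
with torch.no_grad():
    samples = G(z).view(-1, 1, 28, 28)
samples = (samples + 1) / 2                 # undo the Tanh/Normalize scaling
grid = make_grid(samples, nrow=8).cpu()
plt.imshow(grid.permute(1, 2, 0))
plt.show()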

(Figure: a generated MNIST digit)
