import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Define the generator network
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 784)
        self.relu = nn.ReLU()
        # Tanh keeps generated pixels in [-1, 1], matching the normalized real images
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.tanh(self.fc3(x))
        return x

# Define the discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = x.view(-1, 784)  # flatten 28x28 images into 784-dimensional vectors
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x

# Define the loss function and optimizers
criterion = nn.BCELoss()
generator = Generator(latent_dim=100)
discriminator = Discriminator()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Load the MNIST dataset, scaling pixels from [0, 1] to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Make sure the output directory exists before saving samples
os.makedirs('./generated_images', exist_ok=True)

# Train the generator and discriminator networks
num_epochs = 100
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        # Train the discriminator on real images (label 1) and fake images (label 0)
        discriminator.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1)
        fake_labels = torch.zeros(real_images.size(0), 1)
        real_outputs = discriminator(real_images)
        loss_real = criterion(real_outputs, real_labels)
        z = torch.randn(real_images.size(0), 100)
        fake_images = generator(z)
        # detach() stops discriminator gradients from flowing into the generator
        fake_outputs = discriminator(fake_images.detach())
        loss_fake = criterion(fake_outputs, fake_labels)
        loss_d = loss_real + loss_fake
        loss_d.backward()
        optimizer_d.step()

        # Train the generator to make the discriminator output 1 on its fakes
        generator.zero_grad()
        z = torch.randn(real_images.size(0), 100)
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images)
        loss_g = criterion(fake_outputs, real_labels)
        loss_g.backward()
        optimizer_g.step()

        # Print the loss at each iteration
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
              % (epoch + 1, num_epochs, i + 1, len(dataloader),
                 loss_d.item(), loss_g.item()))

    # Save a grid of generated images at the end of each epoch
    with torch.no_grad():
        z = torch.randn(64, 100)
        fake_images = generator(z)
        fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
        fake_images = (fake_images + 1) / 2  # map from [-1, 1] back to [0, 1]
        save_image(fake_images, './generated_images/epoch%d.png' % (epoch + 1))
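The script above discards the trained weights when it finishes. A minimal sketch of checkpointing and later sampling, assuming the Generator class and latent dimension of 100 from above (the file names generator.pth and samples.png are arbitrary choices, not part of the original script):

# Save the trained generator weights (file name is an assumption)
torch.save(generator.state_dict(), 'generator.pth')

# Later: rebuild the generator, reload the weights, and sample from it
generator = Generator(latent_dim=100)
generator.load_state_dict(torch.load('generator.pth'))
generator.eval()  # inference mode; no effect here but good habit

with torch.no_grad():
    z = torch.randn(16, 100)                    # 16 latent vectors
    samples = generator(z).view(-1, 1, 28, 28)  # reshape to image tensors
    samples = (samples + 1) / 2                 # map from [-1, 1] to [0, 1]
    save_image(samples, 'samples.png', nrow=4)  # save as a 4x4 grid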