import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
# Hyperparameters
latent_dim = 100
image_size = 64
channels = 3
batch_size = 128
num_epochs = 200
lr = 0.0002
b1 = 0.5
b2 = 0.999
# Data preprocessing
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * channels, [0.5] * channels),
])
# Load the dataset
dataset = datasets.ImageFolder(root='path_to_your_dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
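# ImageFolder expects images arranged as root/<class_name>/<image>.<ext>;
# the class labels it returns are ignored during GAN training.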
# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.init_size = image_size // 4  # spatial size before the two 2x upsamplings
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        # Project the noise vector to a feature map, then upsample to image size
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            # Stride-2 convolution halves the spatial resolution at each block
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        ds_size = image_size // 2 ** 4  # spatial size after four stride-2 convolutions
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
# Initialize the models
generator = Generator()
discriminator = Discriminator()
# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# Loss function
adversarial_loss = nn.BCELoss()
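# Standard GAN objective with binary cross-entropy: the discriminator is pushed
# toward 1 on real images and 0 on generated ones, while the generator is pushed
# to make the discriminator output 1 on its samples.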
# 训练模型
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(dataloader):
# 真实图像标签为1,生成图像标签为0
valid = torch.ones(imgs.size(0), 1, requires_grad=False)
fake = torch.zeros(imgs.size(0), 1, requires_grad=False)
# 训练生成器
optimizer_G.zero_grad()
z = torch.randn(imgs.shape[0], latent_dim)
gen_imgs = generator(z)
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
# 训练判别器
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
# 保存生成的图像
if epoch % 10 == 0:
save_image(gen_imgs.data[:25], f"images/{epoch}.png", nrow=5, normalize=True)
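# A minimal follow-up sketch: persist the trained generator and sample a grid of
# new images from random noise (the checkpoint filename "generator.pth" and the
# output filename are illustrative choices, not fixed by the script above).
torch.save(generator.state_dict(), "generator.pth")

generator.eval()
with torch.no_grad():
    z = torch.randn(25, latent_dim)
    samples = generator(z)
save_image(samples, "images/final_samples.png", nrow=5, normalize=True)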