import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Hyperparameters.
latent_dim = 20
batch_size = 64
num_epochs = 10
learning_rate = 0.001

# Data preprocessing. ToTensor() alone scales pixels to [0, 1], which matches
# the sigmoid output of the decoder and is REQUIRED by binary_cross_entropy in
# loss_function. The previous Normalize((0.5,), (0.5,)) mapped pixels to
# [-1, 1], and PyTorch's BCE raises "all elements of target should be between
# 0 and 1" on such targets.
transform = transforms.ToTensor()

# Load the MNIST training set (downloaded to ./data on first run).
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Encoder definition.
class Encoder(nn.Module):
    """Map a flattened 28x28 image to the parameters of a diagonal Gaussian
    posterior q(z|x): a mean vector and a log-variance vector.

    Args:
        input_dim: Size of the flattened input (default 784 = 28*28, MNIST).
        hidden_dim: Width of the single hidden layer (default 400).
        latent_dim: Dimensionality of the latent code (default 20, matching
            the module-level ``latent_dim`` hyperparameter).
    """

    def __init__(self, input_dim: int = 784, hidden_dim: int = 400,
                 latent_dim: int = 20) -> None:
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # Two separate heads; predicting log-variance (not variance) keeps
        # the output unconstrained and numerically stable.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (mu, logvar) of q(z|x) for a batch of flattened images."""
        h1 = torch.relu(self.fc1(x))
        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar
# Decoder definition.
class Decoder(nn.Module):
    """Map a latent code z back to a flattened 28x28 image with pixel values
    in [0, 1] (sigmoid output, suitable for a Bernoulli/BCE likelihood).

    Args:
        latent_dim: Dimensionality of the latent code (default 20, matching
            the module-level ``latent_dim`` hyperparameter).
        hidden_dim: Width of the single hidden layer (default 400).
        output_dim: Size of the flattened output (default 784 = 28*28, MNIST).
    """

    def __init__(self, latent_dim: int = 20, hidden_dim: int = 400,
                 output_dim: int = 784) -> None:
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Return a batch of reconstructed flattened images in [0, 1]."""
        h1 = torch.relu(self.fc1(z))
        # Sigmoid bounds each pixel to [0, 1] so BCE reconstruction is valid.
        return torch.sigmoid(self.fc2(h1))
# VAE definition.
class VAE(nn.Module):
    """Variational autoencoder: Encoder -> reparameterized sample -> Decoder.

    forward() returns (reconstruction, mu, logvar) so the caller can compute
    both the reconstruction term and the KL term of the ELBO.
    """

    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z ~ N(mu, sigma^2) differentiably: z = mu + sigma * eps.

        The randomness lives in eps ~ N(0, I), so gradients flow through
        mu and logvar (the reparameterization trick).
        """
        std = torch.exp(0.5 * logvar)  # logvar -> standard deviation
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor):
        """Return (reconstruction, mu, logvar) for a batch of images."""
        # Flatten whatever image shape arrives into (N, 784) for the encoder.
        mu, logvar = self.encoder(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar
# Instantiate the model and its Adam optimizer with the configured learning rate.
vae = VAE()
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
# Loss function.
def loss_function(recon_x: torch.Tensor, x: torch.Tensor,
                  mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Negative ELBO: summed reconstruction BCE + KL divergence to N(0, I).

    Args:
        recon_x: Decoder output, shape (N, 784), values in [0, 1].
        x: Target images; flattened internally to (N, 784). Values must lie
           in [0, 1] for binary_cross_entropy.
        mu, logvar: Posterior parameters from the encoder, shape (N, latent_dim).

    Returns:
        Scalar tensor: total loss summed over the batch (both terms use a
        sum reduction so they are on the same scale).
    """
    # Reconstruction term: per-pixel Bernoulli negative log-likelihood.
    bce = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # Closed-form KL( N(mu, sigma^2) || N(0, I) ), Kingma & Welling Eq. 10.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld
# Train the VAE.
for epoch in range(num_epochs):
    vae.train()
    train_loss = 0.0
    for imgs, _ in train_loader:  # class labels are unused (unsupervised)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(imgs)
        loss = loss_function(recon_batch, imgs, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    # Report the average per-sample loss over the whole training set.
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {train_loss / len(train_loader.dataset):.4f}")
# Generate some samples from the prior and display them.
import matplotlib.pyplot as plt
import torchvision.utils as vutils

vae.eval()
with torch.no_grad():
    # Sample latent codes from the standard-normal prior and decode them.
    z = torch.randn(16, latent_dim)
    sample = vae.decoder(z).cpu()
    sample = sample.view(16, 1, 28, 28)
    # Arrange the 16 digits in a 4x4 grid for display.
    grid_img = vutils.make_grid(sample, nrow=4, normalize=True)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.
- 98.
- 99.
- 100.
- 101.