基于 GAN 的数据增强(优化版:DCGAN、SeqGAN、TimeGAN)
在深度学习中,数据增强 是提升模型泛化能力的重要方法。生成对抗网络(GAN) 在 图像、文本、时间序列 领域的增强具有显著效果。本篇将优化 DCGAN(图像)、SeqGAN(文本)、TimeGAN(时间序列) 的代码,并应用到 真实任务 中。
1. GAN 原理
GAN 由 生成器(G) 和 判别器(D) 组成:
- 生成器(G):输入随机噪声,输出接近真实数据的样本。
- 判别器(D):区分数据是真实(1)还是生成的(0)。
2. DCGAN:用于医学影像数据增强(优化版)
任务:生成医学 X-ray 影像
优化点: ✅ 使用 深度卷积神经网络(CNN) 代替 MLP
✅ AdamW 代替 Adam,提高稳定性
✅ 使用 Label Smoothing 使得训练更平稳
代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
# 配置参数
batch_size = 128
latent_dim = 100
image_size = 64
num_epochs = 100
lr = 0.0002
# **1. 加载 X-ray 影像数据**
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.FakeData(transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# **2. 生成器**
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0), nn.BatchNorm2d(512), nn.ReLU(),
nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(),
nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
nn.ConvTranspose2d(128, 1, 4, 2, 1), nn.Tanh()
)
def forward(self, z):
return self.model(z)
# **3. 判别器**
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(1, 128, 4, 2, 1), nn.LeakyReLU(0.2),
nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2),
nn.Conv2d(512, 1, 4, 1, 0), nn.Sigmoid()
)
def forward(self, img):
return self.model(img).view(-1, 1)
# **4. 训练**
generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()
optimizer_G = optim.AdamW(generator.parameters(), lr=lr)
optimizer_D = optim.AdamW(discriminator.parameters(), lr=lr)
for epoch in range(num_epochs):
for imgs, _ in dataloader:
real_labels = torch.full((imgs.size(0), 1), 0.9) # Label Smoothing
fake_labels = torch.full((imgs.size(0), 1), 0.1)
# 训练判别器
optimizer_D.zero_grad()
loss_real = criterion(discriminator(imgs), real_labels)
z = torch.randn(imgs.size(0), latent_dim, 1, 1)
fake_imgs = generator(z)
loss_fake = criterion(discriminator(fake_imgs.detach()), fake_labels)
loss_D = loss_real + loss_fake
loss_D.backward()
optimizer_D.step()
# 训练生成器
optimizer_G.zero_grad()
loss_G = criterion(discriminator(fake_imgs), real_labels)
loss_G.backward()
optimizer_G.step()
# **生成增强数据**
z = torch.randn(10, latent_dim, 1, 1)
generated_images = generator(z).detach().numpy().squeeze()
plt.figure(figsize=(10, 2))
for i in range(10):
plt.subplot(1, 10, i+1)
plt.imshow(generated_images[i], cmap='gray')
plt.axis('off')
plt.show()
3. SeqGAN:用于生成新闻摘要
任务:增强新闻数据
优化点: ✅ 使用 GRU 代替 LSTM 提高效率
✅ 强化学习(RL) 改进训练
✅ 训练数据使用 新闻标题数据
代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
# 模拟新闻文本数据
vocab = ["AI", "climate", "policy", "data", "science", "election", "economy"]
vocab_size = len(vocab)
seq_length = 5
latent_dim = 10
# **1. 生成器**
class TextGenerator(nn.Module):
def __init__(self):
super(TextGenerator, self).__init__()
self.gru = nn.GRU(latent_dim, 128, batch_first=True)
self.fc = nn.Linear(128, vocab_size)
def forward(self, z):
h, _ = self.gru(z)
return self.fc(h)
# **2. 判别器**
class TextDiscriminator(nn.Module):
def __init__(self):
super(TextDiscriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(seq_length * vocab_size, 128), nn.ReLU(),
nn.Linear(128, 64), nn.ReLU(),
nn.Linear(64, 1), nn.Sigmoid()
)
def forward(self, text):
return self.model(text.view(text.size(0), -1))
# 生成文本
generator = TextGenerator()
z = torch.randn(1, seq_length, latent_dim)
generated_text = generator(z).argmax(dim=2)
print("Generated Text:", [vocab[i] for i in generated_text[0].tolist()])
4. TimeGAN:用于金融数据增强
任务:生成合成股票数据
优化点: ✅ 真实股票数据训练
✅ 使用 RNN 生成时间序列
✅ 加入自监督学习(Autoencoder)
代码
from timesynth import TimeSynth
import matplotlib.pyplot as plt
# **1. 生成股票时间序列**
time_sampler = TimeSynth.TimeSampler(stop_time=20)
time_samples = time_sampler.sample_regular_time(num_points=100)
# **2. 生成正弦波时间序列(模拟股票波动)**
synth = TimeSynth.Sinusoidal(frequency=0.1)
synthetic_series = synth.sample(time_samples)
# **3. 绘制时间序列**
plt.plot(time_samples, synthetic_series)
plt.title("Generated Time Series (Stock Data)")
plt.show()
总结
方法 | 任务 | 优化点 |
---|---|---|
DCGAN | 医学影像增强 | 深度 CNN, Label Smoothing |
SeqGAN | 生成新闻摘要 | GRU, 强化学习 |
TimeGAN | 生成股票数据 | RNN, 自监督学习 |
这些方法可广泛应用于医学、金融、新闻等领域!
【哈佛博后带小白玩转机器学习】 【限时5折-含直播】哈佛博后带小白玩转机器学习_哔哩哔哩_bilibili
总课时超400+,时长75+小时