基于 GAN 的数据增强(DCGAN、SeqGAN、TimeGAN)
数据增强是深度学习中的关键技术,GAN(生成对抗网络)提供了一种数据生成的方式,特别适用于 图像、文本、时间序列 数据增强。本篇详细介绍 DCGAN(图像)、SeqGAN(文本)、TimeGAN(时间序列) 三种不同类型的 GAN 及其 Python 代码实现。
1. GAN 介绍
GAN(Generative Adversarial Network)由 生成器(G) 和 判别器(D) 组成,二者相互竞争:
- 生成器(G):输入随机噪声,输出接近真实数据的样本。
- 判别器(D):判断输入样本是真实数据还是生成数据。
2. DCGAN:用于图像数据增强
原理
DCGAN(Deep Convolutional GAN)是一种专用于 图像数据增强 的 GAN。它使用 卷积神经网络(CNN) 作为生成器和判别器,使得生成的图像更加真实。
代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 设定参数
batch_size = 128
latent_dim = 100
image_size = 28
num_epochs = 50
lr = 0.0002
# 载入 MNIST 数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# **1. 生成器**
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256), nn.ReLU(),
nn.Linear(256, 512), nn.ReLU(),
nn.Linear(512, 1024), nn.ReLU(),
nn.Linear(1024, image_size * image_size), nn.Tanh()
)
def forward(self, z):
img = self.model(z).view(-1, 1, image_size, image_size)
return img
# **2. 判别器**
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(image_size * image_size, 1024), nn.LeakyReLU(0.2),
nn.Linear(1024, 512), nn.LeakyReLU(0.2),
nn.Linear(512, 256), nn.LeakyReLU(0.2),
nn.Linear(256, 1), nn.Sigmoid()
)
def forward(self, img):
img = img.view(img.size(0), -1)
return self.model(img)
# **训练 DCGAN**
generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
for epoch in range(num_epochs):
for imgs, _ in dataloader:
real_imgs = imgs.to(torch.float32)
real_labels = torch.ones(imgs.size(0), 1)
fake_labels = torch.zeros(imgs.size(0), 1)
# **训练判别器**
optimizer_D.zero_grad()
loss_real = criterion(discriminator(real_imgs), real_labels)
z = torch.randn(imgs.size(0), latent_dim)
fake_imgs = generator(z)
loss_fake = criterion(discriminator(fake_imgs.detach()), fake_labels)
loss_D = loss_real + loss_fake
loss_D.backward()
optimizer_D.step()
# **训练生成器**
optimizer_G.zero_grad()
loss_G = criterion(discriminator(fake_imgs), real_labels)
loss_G.backward()
optimizer_G.step()
# **生成增强数据**
z = torch.randn(10, latent_dim)
generated_images = generator(z).detach().numpy().squeeze()
fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for i, img in enumerate(generated_images):
axes[i].imshow(img, cmap='gray')
axes[i].axis('off')
plt.show()
3. SeqGAN:用于文本数据增强
原理
SeqGAN 通过 强化学习(RL) 方法生成高质量文本:
- 生成器(G) 使用 LSTM 或 Transformer 生成文本序列。
- 判别器(D) 判断文本是否来自数据集。
- 通过 策略梯度(Policy Gradient) 训练,使生成器生成更接近真实的文本。
代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
# 模拟文本数据(单词序列)
vocab = ["hello", "world", "deep", "learning", "GAN", "text", "data"]
vocab_size = len(vocab)
seq_length = 5
latent_dim = 10
# **1. 生成器(LSTM)**
class TextGenerator(nn.Module):
def __init__(self):
super(TextGenerator, self).__init__()
self.lstm = nn.LSTM(latent_dim, 128, batch_first=True)
self.fc = nn.Linear(128, vocab_size)
def forward(self, z):
h, _ = self.lstm(z)
return self.fc(h)
# **2. 判别器**
class TextDiscriminator(nn.Module):
def __init__(self):
super(TextDiscriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(seq_length * vocab_size, 128), nn.ReLU(),
nn.Linear(128, 64), nn.ReLU(),
nn.Linear(64, 1), nn.Sigmoid()
)
def forward(self, text):
return self.model(text.view(text.size(0), -1))
# **训练 SeqGAN(省略完整训练流程)**
generator = TextGenerator()
discriminator = TextDiscriminator()
# 生成随机文本
z = torch.randn(1, seq_length, latent_dim)
generated_text = generator(z).argmax(dim=2)
print("Generated Text:", [vocab[i] for i in generated_text[0].tolist()])
4. TimeGAN:用于时间序列数据增强
原理
TimeGAN 结合 自监督学习(Self-Supervised Learning) 与 GAN 进行时间序列增强:
- 生成器(G) 生成时间序列数据。
- 判别器(D) 判别数据是否为真实。
- 自编码器(Autoencoder) 进行数据特征学习。
代码
from timesynth import TimeSynth
import matplotlib.pyplot as plt
# 生成示例时间序列数据
time_sampler = TimeSynth.TimeSampler(stop_time=20)
time_samples = time_sampler.sample_regular_time(num_points=100)
# 生成正弦波时间序列
synth = TimeSynth.Sinusoidal(frequency=0.1)
synthetic_series = synth.sample(time_samples)
# 绘制时间序列
plt.plot(time_samples, synthetic_series)
plt.title("Generated Time Series (TimeGAN)")
plt.show()
总结
方法 | 适用场景 | 特点 |
---|---|---|
DCGAN | 图像增强 | 生成逼真的图像数据 |
SeqGAN | 文本增强 | 生成自然语言句子 |
TimeGAN | 时间序列增强 | 生成时序数据 |
GAN 技术在数据增强中有广泛应用,适用于 数据量不足的场景。下节进行 更详细的优化技巧或应用到具体任务 🚀
【哈佛博后带小白玩转机器学习】 【限时5折-含直播】哈佛博后带小白玩转机器学习_哔哩哔哩_bilibili
总课时超400+,时长75+小时