【深度学习实践】GAN技术进行数据增强

基于 GAN 的数据增强(DCGAN、SeqGAN、TimeGAN)

数据增强是深度学习中的关键技术,GAN(生成对抗网络)提供了一种数据生成的方式,特别适用于 图像、文本、时间序列 数据增强。本篇详细介绍 DCGAN(图像)、SeqGAN(文本)、TimeGAN(时间序列) 三种不同类型的 GAN 及其 Python 代码实现。


1. GAN 介绍

GAN(Generative Adversarial Network)由 生成器(G)判别器(D) 组成,二者相互竞争:

  • 生成器(G):输入随机噪声,输出接近真实数据的样本。
  • 判别器(D):判断输入样本是真实数据还是生成数据。

2. DCGAN:用于图像数据增强

原理

DCGAN(Deep Convolutional GAN)是一种专用于 图像数据增强 的 GAN。它使用 卷积神经网络(CNN) 作为生成器和判别器,使得生成的图像更加真实。

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 设定参数
batch_size = 128
latent_dim = 100
image_size = 28
num_epochs = 50
lr = 0.0002

# 载入 MNIST 数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# **1. 生成器**
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, image_size * image_size), nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z).view(-1, 1, image_size, image_size)
        return img

# **2. 判别器**
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size * image_size, 1024), nn.LeakyReLU(0.2),
            nn.Linear(1024, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid()
        )

    def forward(self, img):
        img = img.view(img.size(0), -1)
        return self.model(img)

# **训练 DCGAN**
generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

for epoch in range(num_epochs):
    for imgs, _ in dataloader:
        real_imgs = imgs.to(torch.float32)
        real_labels = torch.ones(imgs.size(0), 1)
        fake_labels = torch.zeros(imgs.size(0), 1)

        # **训练判别器**
        optimizer_D.zero_grad()
        loss_real = criterion(discriminator(real_imgs), real_labels)
        z = torch.randn(imgs.size(0), latent_dim)
        fake_imgs = generator(z)
        loss_fake = criterion(discriminator(fake_imgs.detach()), fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # **训练生成器**
        optimizer_G.zero_grad()
        loss_G = criterion(discriminator(fake_imgs), real_labels)
        loss_G.backward()
        optimizer_G.step()

# **生成增强数据**
z = torch.randn(10, latent_dim)
generated_images = generator(z).detach().numpy().squeeze()

fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for i, img in enumerate(generated_images):
    axes[i].imshow(img, cmap='gray')
    axes[i].axis('off')
plt.show()

3. SeqGAN:用于文本数据增强

原理

SeqGAN 通过 强化学习(RL) 方法生成高质量文本:

  1. 生成器(G) 使用 LSTM 或 Transformer 生成文本序列。
  2. 判别器(D) 判断文本是否来自数据集。
  3. 通过 策略梯度(Policy Gradient) 训练,使生成器生成更接近真实的文本。

代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random

# 模拟文本数据(单词序列)
vocab = ["hello", "world", "deep", "learning", "GAN", "text", "data"]
vocab_size = len(vocab)
seq_length = 5
latent_dim = 10

# **1. 生成器(LSTM)**
class TextGenerator(nn.Module):
    def __init__(self):
        super(TextGenerator, self).__init__()
        self.lstm = nn.LSTM(latent_dim, 128, batch_first=True)
        self.fc = nn.Linear(128, vocab_size)

    def forward(self, z):
        h, _ = self.lstm(z)
        return self.fc(h)

# **2. 判别器**
class TextDiscriminator(nn.Module):
    def __init__(self):
        super(TextDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(seq_length * vocab_size, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1), nn.Sigmoid()
        )

    def forward(self, text):
        return self.model(text.view(text.size(0), -1))

# **训练 SeqGAN(省略完整训练流程)**
generator = TextGenerator()
discriminator = TextDiscriminator()

# 生成随机文本
z = torch.randn(1, seq_length, latent_dim)
generated_text = generator(z).argmax(dim=2)
print("Generated Text:", [vocab[i] for i in generated_text[0].tolist()])

4. TimeGAN:用于时间序列数据增强

原理

TimeGAN 结合 自监督学习(Self-Supervised Learning) 与 GAN 进行时间序列增强:

  • 生成器(G) 生成时间序列数据。
  • 判别器(D) 判别数据是否为真实。
  • 自编码器(Autoencoder) 进行数据特征学习。

代码

from timesynth import TimeSynth
import matplotlib.pyplot as plt

# 生成示例时间序列数据
time_sampler = TimeSynth.TimeSampler(stop_time=20)
time_samples = time_sampler.sample_regular_time(num_points=100)

# 生成正弦波时间序列
synth = TimeSynth.Sinusoidal(frequency=0.1)
synthetic_series = synth.sample(time_samples)

# 绘制时间序列
plt.plot(time_samples, synthetic_series)
plt.title("Generated Time Series (TimeGAN)")
plt.show()

总结

方法适用场景特点
DCGAN图像增强生成逼真的图像数据
SeqGAN文本增强生成自然语言句子
TimeGAN时间序列增强生成时序数据

GAN 技术在数据增强中有广泛应用,适用于 数据量不足的场景。下节进行 更详细的优化技巧或应用到具体任务 🚀

哈佛博后带小白玩转机器学习】 【限时5折-含直播】哈佛博后带小白玩转机器学习_哔哩哔哩_bilibili

总课时超400+,时长75+小时

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值