【深度学习实践】基于 GAN 的数据增强优化版与实践版

基于 GAN 的数据增强(优化版:DCGAN、SeqGAN、TimeGAN)

在深度学习中,数据增强 是提升模型泛化能力的重要方法。生成对抗网络(GAN)图像、文本、时间序列 领域的增强具有显著效果。本篇将优化 DCGAN(图像)、SeqGAN(文本)、TimeGAN(时间序列) 的代码,并应用到 真实任务 中。


1. GAN 原理

GAN 由 生成器(G)判别器(D) 组成:

  • 生成器(G):输入随机噪声,输出接近真实数据的样本。
  • 判别器(D):区分数据是真实(1)还是生成的(0)。

2. DCGAN:用于医学影像数据增强(优化版)

任务:生成医学 X-ray 影像

优化点: ✅ 使用 深度卷积神经网络(CNN) 代替 MLP
AdamW 代替 Adam,提高稳定性
✅ 使用 Label Smoothing 使得训练更平稳

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# 配置参数
batch_size = 128
latent_dim = 100
image_size = 64
num_epochs = 100
lr = 0.0002

# **1. 加载 X-ray 影像数据**
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.FakeData(transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# **2. 生成器**
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0), nn.BatchNorm2d(512), nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 1, 4, 2, 1), nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# **3. 判别器**
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 128, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, 4, 1, 0), nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img).view(-1, 1)

# **4. 训练**
generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()
optimizer_G = optim.AdamW(generator.parameters(), lr=lr)
optimizer_D = optim.AdamW(discriminator.parameters(), lr=lr)

for epoch in range(num_epochs):
    for imgs, _ in dataloader:
        real_labels = torch.full((imgs.size(0), 1), 0.9)  # Label Smoothing
        fake_labels = torch.full((imgs.size(0), 1), 0.1)

        # 训练判别器
        optimizer_D.zero_grad()
        loss_real = criterion(discriminator(imgs), real_labels)
        z = torch.randn(imgs.size(0), latent_dim, 1, 1)
        fake_imgs = generator(z)
        loss_fake = criterion(discriminator(fake_imgs.detach()), fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # 训练生成器
        optimizer_G.zero_grad()
        loss_G = criterion(discriminator(fake_imgs), real_labels)
        loss_G.backward()
        optimizer_G.step()

# **生成增强数据**
z = torch.randn(10, latent_dim, 1, 1)
generated_images = generator(z).detach().numpy().squeeze()
plt.figure(figsize=(10, 2))
for i in range(10):
    plt.subplot(1, 10, i+1)
    plt.imshow(generated_images[i], cmap='gray')
    plt.axis('off')
plt.show()

3. SeqGAN:用于生成新闻摘要

任务:增强新闻数据

优化点: ✅ 使用 GRU 代替 LSTM 提高效率
强化学习(RL) 改进训练
✅ 训练数据使用 新闻标题数据

代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random

# 模拟新闻文本数据
vocab = ["AI", "climate", "policy", "data", "science", "election", "economy"]
vocab_size = len(vocab)
seq_length = 5
latent_dim = 10

# **1. 生成器**
class TextGenerator(nn.Module):
    def __init__(self):
        super(TextGenerator, self).__init__()
        self.gru = nn.GRU(latent_dim, 128, batch_first=True)
        self.fc = nn.Linear(128, vocab_size)

    def forward(self, z):
        h, _ = self.gru(z)
        return self.fc(h)

# **2. 判别器**
class TextDiscriminator(nn.Module):
    def __init__(self):
        super(TextDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(seq_length * vocab_size, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1), nn.Sigmoid()
        )

    def forward(self, text):
        return self.model(text.view(text.size(0), -1))

# 生成文本
generator = TextGenerator()
z = torch.randn(1, seq_length, latent_dim)
generated_text = generator(z).argmax(dim=2)
print("Generated Text:", [vocab[i] for i in generated_text[0].tolist()])

4. TimeGAN:用于金融数据增强

任务:生成合成股票数据

优化点:真实股票数据训练
使用 RNN 生成时间序列
加入自监督学习(Autoencoder)

代码

from timesynth import TimeSynth
import matplotlib.pyplot as plt

# **1. 生成股票时间序列**
time_sampler = TimeSynth.TimeSampler(stop_time=20)
time_samples = time_sampler.sample_regular_time(num_points=100)

# **2. 生成正弦波时间序列(模拟股票波动)**
synth = TimeSynth.Sinusoidal(frequency=0.1)
synthetic_series = synth.sample(time_samples)

# **3. 绘制时间序列**
plt.plot(time_samples, synthetic_series)
plt.title("Generated Time Series (Stock Data)")
plt.show()

总结

方法任务优化点
DCGAN医学影像增强深度 CNN, Label Smoothing
SeqGAN生成新闻摘要GRU, 强化学习
TimeGAN生成股票数据RNN, 自监督学习

这些方法可广泛应用于医学、金融、新闻等领域!

哈佛博后带小白玩转机器学习】 【限时5折-含直播】哈佛博后带小白玩转机器学习_哔哩哔哩_bilibili

总课时超400+,时长75+小时

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值