知识点回顾:
- 对抗生成网络的思想:关注损失从何而来
- 生成器、判别器
- nn.sequential容器:适合于按顺序运算的情况,简化前向传播写法
- leakyReLU介绍:避免relu的神经元失活现象
ps;如果你学有余力,对于gan的损失函数的理解,建议去找找视频看看,如果只是用,没必要学
作业:对于心脏病数据集,对于病人这个不平衡的样本用GAN来学习并生成病人样本,观察不用GAN和用GAN的F1分数差异。
一、数据预处理(修改 src/data/preprocessing.py )
def split_minority_class(data_df):
# 提取少数类(病人样本)
minority = data_df[data_df.target == 1]
return minority.drop('target', axis=1).values
二、GAN网络定义(新增 src/models/gan.py )
class Generator(nn.Sequential):
def __init__(self, input_dim, output_dim):
super().__init__(
nn.Linear(input_dim, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, output_dim),
nn.Tanh()
)
class Discriminator(nn.Sequential):
def __init__(self, input_dim):
super().__init__(
nn.Linear(input_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 1),
nn.Sigmoid()
)
三、训练流程(修改 src/models/train.py )
# GAN训练循环
for epoch in range(epochs):
for real_data in minority_loader:
# 生成假数据
z = torch.randn(batch_size, latent_dim)
fake_data = generator(z)
# 判别器训练
d_loss_real = criterion(discriminator(real_data), real_labels)
d_loss_fake = criterion(discriminator(fake_data.detach()), fake_labels)
d_loss = (d_loss_real + d_loss_fake) / 2
# 生成器训练
g_loss = criterion(discriminator(fake_data), real_labels)
四、评估对比(新增 src/visualization/evaluate.py )
def compare_f1(original_f1, gan_f1):
plt.figure(figsize=(10,6))
plt.bar(['Original', 'GAN Augmented'], [original_f1, gan_f1])
plt.title('F1 Score Comparison')
plt.savefig('reports/figures/f1_comparison.png')
执行流程
1.安装依赖
pip install imbalanced-learn
2.训练GAN生成样本
3.分别训练基线模型和增强模型
4.生成对比报告