python打卡day53@浙大疏锦行

知识点回顾:

  1. 对抗生成网络的思想:关注损失从何而来
  2. 生成器、判别器
  3. nn.sequential容器:适合于按顺序运算的情况,简化前向传播写法
  4. leakyReLU介绍:避免relu的神经元失活现象

ps;如果你学有余力对于gan的损失函数的理解,建议去找找视频看看,如果只是用,没必要学

作业:对于心脏病数据集,对于病人这个不平衡的样本用GAN来学习并生成病人样本,观察不用GAN和用GAN的F1分数差异。

一、数据预处理(修改 src/data/preprocessing.py )

def split_minority_class(data_df):
    # 提取少数类(病人样本)
    minority = data_df[data_df.target == 1]
    return minority.drop('target', axis=1).values

二、GAN网络定义(新增 src/models/gan.py )

class Generator(nn.Sequential):
    def __init__(self, input_dim, output_dim):
        super().__init__(
            nn.Linear(input_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, output_dim),
            nn.Tanh()
        )

class Discriminator(nn.Sequential):
    def __init__(self, input_dim):
        super().__init__(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

三、训练流程(修改 src/models/train.py )

# GAN训练循环
for epoch in range(epochs):
    for real_data in minority_loader:
        # 生成假数据
        z = torch.randn(batch_size, latent_dim)
        fake_data = generator(z)
        
        # 判别器训练
        d_loss_real = criterion(discriminator(real_data), real_labels)
        d_loss_fake = criterion(discriminator(fake_data.detach()), fake_labels)
        d_loss = (d_loss_real + d_loss_fake) / 2
        
        # 生成器训练
        g_loss = criterion(discriminator(fake_data), real_labels)

四、评估对比(新增 src/visualization/evaluate.py )

def compare_f1(original_f1, gan_f1):
    plt.figure(figsize=(10,6))
    plt.bar(['Original', 'GAN Augmented'], [original_f1, gan_f1])
    plt.title('F1 Score Comparison')
    plt.savefig('reports/figures/f1_comparison.png')

执行流程

1.安装依赖

pip install imbalanced-learn

2.训练GAN生成样本
3.分别训练基线模型和增强模型
4.生成对比报告

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值