知识点回顾:
- 对抗生成网络的思想:关注损失从何而来
- 生成器、判别器
- nn.sequential容器:适合于按顺序运算的情况,简化前向传播写法
- leakyReLU介绍:避免relu的神经元失活现象
ps;如果你学有余力,对于gan的损失函数的理解,建议去找找视频看看,如果只是用,没必要学
作业:对于心脏病数据集,对于病人这个不平衡的样本用GAN来学习并生成病人样本,观察不用GAN和用GAN的F1分数差异
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, classification_report, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")
# 设置中文字体(解决中文显示问题)
plt.rcParams['font.sans-serif'] = ['SimHei'] # Windows系统常用黑体字体
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
data = pd.read_csv('data.csv') #读取数据
# 设置随机种子以确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# ---------------------------
# 1. 加载并预处理心脏病数据集
# ---------------------------
def load_heart_disease_data():
"""加载心脏病数据集并进行预处理"""
try:
# 读取CSV文件
df = pd.read_csv("heart.csv")
print(f"成功加载heart.csv,数据形状: {df.shape}")
# 显示数据基本信息
print("\n数据基本信息:")
df.info()
# 显示数据集行数和列数
rows, columns = df.shape
# 数据可视化 - 目标变量分布
plt.figure(figsize=(6, 4))
sns.countplot(x='target', data=df)
plt.title('心脏病患者分布')
plt.xlabel('是否患病')
plt.ylabel('样本数')
plt.xticks([0, 1], ['健康', '患病'])
plt.show()
# 重命名列以便更好理解
df.columns = ['age', 'sex', 'chest_pain_type', 'resting_blood_pressure',
'cholesterol', 'fasting_blood_sugar', 'rest_ecg',
'max_heart_rate_achieved', 'exercise_induced_angina',
'st_depression', 'st_slope', 'num_major_vessels',
'thalassemia', 'target']
# 转换分类特征
categorical_features = ['sex', 'chest_pain_type', 'fasting_blood_sugar',
'rest_ecg', 'exercise_induced_angina', 'st_slope',
'num_major_vessels', 'thalassemia']
for feature in categorical_features:
df[feature] = df[feature].astype('object')
# 独热编码
df = pd.get_dummies(df, drop_first=True)
# 划分特征和目标变量
X = df.drop(columns='target').values
y = df['target'].values
# 打印类别分布
class_counts = pd.Series(y).value_counts()
print(f"\n类别分布: \n{class_counts}")
print(f"不平衡比例: {class_counts[0]/class_counts[1]:.2f}:1")
return X, y
except FileNotFoundError:
print("错误:未找到heart.csv文件,请检查文件路径")
return None, None
except Exception as e:
print(f"数据加载错误: {str(e)}")
return None, None
# 加载数据
X, y = load_heart_disease_data()
if X is None or y is None:
exit()
# 数据缩放
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X)
# 分割数据
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
# 提取少数类样本 (心脏病患者)
X_minority = X_train[y_train == 1]
y_minority = y_train[y_train == 1]
print(f"训练集中少数类样本数量: {len(X_minority)}")
# ---------------------------
# 2. 定义条件GAN模型
# ---------------------------
class ConditionalGenerator(nn.Module):
def __init__(self, input_dim, output_dim, label_dim):
super(ConditionalGenerator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim + label_dim, 64),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(64),
nn.Linear(64, 128),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(128),
nn.Linear(128, output_dim),
nn.Tanh() # 输出范围为[-1, 1],与数据缩放范围一致
)
def forward(self, z, labels):
# 合并噪声和标签
input_tensor = torch.cat([z, labels], dim=1)
return self.model(input_tensor)
class ConditionalDiscriminator(nn.Module):
def __init__(self, input_dim, label_dim):
super(ConditionalDiscriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim + label_dim, 128),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(128, 64),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(64, 1),
nn.Sigmoid() # 输出概率值
)
def forward(self, x, labels):
# 合并输入特征和标签
input_tensor = torch.cat([x, labels], dim=1)
return self.model(input_tensor)
# ---------------------------
# 3. 训练CGAN模型
# ---------------------------
# 模型参数
LATENT_DIM = 10
INPUT_DIM = X_train.shape[1]
LABEL_DIM = 1 # 二分类问题
EPOCHS = 1000
BATCH_SIZE = 32
LR = 0.0002
BETA1 = 0.5
# 创建数据加载器
minority_dataset = TensorDataset(
torch.FloatTensor(X_minority),
torch.FloatTensor(y_minority).view(-1, 1)
)
minority_dataloader = DataLoader(minority_dataset, batch_size=BATCH_SIZE, shuffle=True)
# 实例化模型
generator = ConditionalGenerator(LATENT_DIM, INPUT_DIM, LABEL_DIM).to(device)
discriminator = ConditionalDiscriminator(INPUT_DIM, LABEL_DIM).to(device)
# 定义损失函数和优化器
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=LR, betas=(BETA1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LR, betas=(BETA1, 0.999))
# 训练循环
g_losses, d_losses = [], []
for epoch in range(EPOCHS):
epoch_g_loss, epoch_d_loss = 0, 0
batches_per_epoch = 0
for i, (real_data, real_labels) in enumerate(minority_dataloader):
real_data = real_data.to(device)
real_labels = real_labels.to(device)
batch_size = real_data.size(0)
batches_per_epoch += 1
# 创建真实和虚假标签
real_targets = torch.ones(batch_size, 1).to(device)
fake_targets = torch.zeros(batch_size, 1).to(device)
# ---------------------
# 训练判别器
# ---------------------
d_optimizer.zero_grad()
# 用真实数据训练
real_validity = discriminator(real_data, real_labels)
d_real_loss = criterion(real_validity, real_targets)
# 生成假数据
z = torch.randn(batch_size, LATENT_DIM).to(device)
fake_labels = real_labels # 生成与真实样本相同类别的数据
fake_data = generator(z, fake_labels)
# 用假数据训练
fake_validity = discriminator(fake_data.detach(), fake_labels)
d_fake_loss = criterion(fake_validity, fake_targets)
# 总判别器损失
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
d_optimizer.step()
epoch_d_loss += d_loss.item()
# ---------------------
# 训练生成器
# ---------------------
g_optimizer.zero_grad()
# 生成假数据
fake_data = generator(z, fake_labels)
fake_validity = discriminator(fake_data, fake_labels)
# 生成器损失
g_loss = criterion(fake_validity, real_targets)
g_loss.backward()
g_optimizer.step()
epoch_g_loss += g_loss.item()
# 计算平均损失
avg_g_loss = epoch_g_loss / batches_per_epoch
avg_d_loss = epoch_d_loss / batches_per_epoch
g_losses.append(avg_g_loss)
d_losses.append(avg_d_loss)
# 每100个epoch打印一次损失
if (epoch + 1) % 100 == 0:
print(f"Epoch [{epoch+1}/{EPOCHS}], D_loss: {avg_d_loss:.4f}, G_loss: {avg_g_loss:.4f}")
print("CGAN训练完成!")
# 绘制训练损失曲线
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label='生成器损失')
plt.plot(d_losses, label='判别器损失')
plt.title('训练损失曲线')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend()
plt.show()
# ---------------------------
# 4. 生成合成样本
# ---------------------------
# 设置为评估模式
generator.eval()
# 生成与少数类样本数量相同的合成数据
num_samples_to_generate = len(X_minority)
z = torch.randn(num_samples_to_generate, LATENT_DIM).to(device)
labels = torch.ones(num_samples_to_generate, 1).to(device) # 标签为1,表示心脏病患者
with torch.no_grad():
synthetic_data = generator(z, labels).cpu().numpy()
# 逆缩放合成数据
synthetic_data = scaler.inverse_transform(synthetic_data)
# 为合成数据创建标签
synthetic_labels = np.ones(num_samples_to_generate)
# 逆缩放原始训练数据用于可视化
X_train_original = scaler.inverse_transform(X_train)
# ---------------------------
# 5. 可视化原始数据和合成数据
# ---------------------------
# 可视化特征分布对比
plt.figure(figsize=(14, 10))
feature_names = ['年龄', '血压', '胆固醇', '最大心率']
feature_indices = [0, 3, 4, 7] # 对应数据集中的特征索引
for i, idx in enumerate(feature_indices):
plt.subplot(2, 2, i+1)
# 绘制原始少数类样本的特征分布
sns.kdeplot(X_train_original[y_train == 1, idx], label='原始数据', color='blue')
# 绘制合成样本的特征分布
sns.kdeplot(synthetic_data[:, idx], label='合成数据', color='orange')
plt.title(f'{feature_names[i]}分布对比')
plt.xlabel('特征值')
plt.ylabel('密度')
plt.legend()
plt.tight_layout()
plt.show()
# ---------------------------
# 6. 比较模型性能
# ---------------------------
# 6.1 使用原始数据训练的模型
model_original = RandomForestClassifier(random_state=42)
model_original.fit(X_train, y_train)
y_pred_original = model_original.predict(X_test)
y_pred_prob_original = model_original.predict_proba(X_test)[:, 1]
# 6.2 使用增强数据训练的模型
# 将合成数据添加到训练集中
X_train_augmented = np.vstack([X_train, scaler.transform(synthetic_data)])
y_train_augmented = np.hstack([y_train, synthetic_labels])
model_augmented = RandomForestClassifier(random_state=42)
model_augmented.fit(X_train_augmented, y_train_augmented)
y_pred_augmented = model_augmented.predict(X_test)
y_pred_prob_augmented = model_augmented.predict_proba(X_test)[:, 1]
# 6.3 比较F1分数
f1_original = f1_score(y_test, y_pred_original)
f1_augmented = f1_score(y_test, y_pred_augmented)
print("\n模型性能比较:")
print(f"原始数据训练的模型 F1 分数: {f1_original:.4f}")
print(f"增强数据训练的模型 F1 分数: {f1_augmented:.4f}")
# 打印详细分类报告
print("\n原始数据训练的模型分类报告:")
print(classification_report(y_test, y_pred_original))
print("\n增强数据训练的模型分类报告:")
print(classification_report(y_test, y_pred_augmented))
# 计算混淆矩阵
cm_original = confusion_matrix(y_test, y_pred_original)
cm_augmented = confusion_matrix(y_test, y_pred_augmented)
# 计算ROC曲线
fpr_original, tpr_original, _ = roc_curve(y_test, y_pred_prob_original)
fpr_augmented, tpr_augmented, _ = roc_curve(y_test, y_pred_prob_augmented)
roc_auc_original = auc(fpr_original, tpr_original)
roc_auc_augmented = auc(fpr_augmented, tpr_augmented)
# ---------------------------
# 7. 可视化评估结果
# ---------------------------
# 7.1 混淆矩阵对比
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
sns.heatmap(cm_original, annot=True, fmt='d', cmap='Blues')
plt.title('原始数据模型混淆矩阵')
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.subplot(1, 2, 2)
sns.heatmap(cm_augmented, annot=True, fmt='d', cmap='Greens')
plt.title('增强数据模型混淆矩阵')
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.tight_layout()
plt.show()
# 7.2 ROC曲线对比
plt.figure(figsize=(8, 6))
plt.plot(fpr_original, tpr_original, color='blue', lw=2,
label=f'原始数据 (AUC = {roc_auc_original:.2f})')
plt.plot(fpr_augmented, tpr_augmented, color='green', lw=2,
label=f'增强数据 (AUC = {roc_auc_augmented:.2f})')
plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('假阳性率 (1-特异性)')
plt.ylabel('真阳性率 (敏感性)')
plt.title('ROC曲线比较')
plt.legend(loc="lower right")
plt.show()
# 7.3 F1分数对比
plt.figure(figsize=(8, 5))
plt.bar(['原始数据', '增强数据'], [f1_original, f1_augmented], color=['#636EFA', '#EF553B'])
plt.ylim(0, 1)
plt.title('使用GAN增强前后的模型F1分数对比')
plt.ylabel('F1分数')
for i, v in enumerate([f1_original, f1_augmented]):
plt.text(i, v + 0.02, f'{v:.4f}', ha='center')
plt.show()
使用设备: cuda
成功加载heart.csv,数据形状: (303, 14)
数据基本信息:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 303 entries, 0 to 302
Data columns (total 14 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 age 303 non-null int64
1 sex 303 non-null int64
2 cp 303 non-null int64
3 trestbps 303 non-null int64
4 chol 303 non-null int64
5 fbs 303 non-null int64
6 restecg 303 non-null int64
7 thalach 303 non-null int64
8 exang 303 non-null int64
9 oldpeak 303 non-null float64
10 slope 303 non-null int64
11 ca 303 non-null int64
12 thal 303 non-null int64
13 target 303 non-null int64
dtypes: float64(1), int64(13)
memory usage: 33.3 KB
类别分布:
1 165
0 138
Name: count, dtype: int64
不平衡比例: 0.84:1
训练集中少数类样本数量: 133
Epoch [100/1000], D_loss: 0.5932, G_loss: 0.9958
Epoch [200/1000], D_loss: 0.6242, G_loss: 0.9327
Epoch [300/1000], D_loss: 0.5531, G_loss: 1.0192
Epoch [400/1000], D_loss: 0.5781, G_loss: 0.9459
Epoch [500/1000], D_loss: 0.5800, G_loss: 0.8725
Epoch [600/1000], D_loss: 0.6331, G_loss: 0.9173
Epoch [700/1000], D_loss: 0.6049, G_loss: 0.9303
Epoch [800/1000], D_loss: 0.6015, G_loss: 0.9530
Epoch [900/1000], D_loss: 0.6002, G_loss: 0.9590
Epoch [1000/1000], D_loss: 0.6356, G_loss: 0.9191
CGAN训练完成!
模型性能比较:
原始数据训练的模型 F1 分数: 0.8333
增强数据训练的模型 F1 分数: 0.8333
原始数据训练的模型分类报告:
precision recall f1-score support
0 0.79 0.90 0.84 29
1 0.89 0.78 0.83 32
accuracy 0.84 61
macro avg 0.84 0.84 0.84 61
weighted avg 0.84 0.84 0.84 61
增强数据训练的模型分类报告:
precision recall f1-score support
0 0.79 0.90 0.84 29
1 0.89 0.78 0.83 32
accuracy 0.84 61
macro avg 0.84 0.84 0.84 61
weighted avg 0.84 0.84 0.84 61
总结
一、数据分布分析
- 原始类别分布:
从加载数据的输出可知,心脏病数据集共 303 条记录,其中患病(target=1
)样本 165 条,健康(target=0
)样本 138 条 ,不平衡比例约为0.84:1
,属于轻度不平衡数据集,这也是引入 GAN 做数据增强的初衷,试图补充少数类(健康样本相对患病样本是少数类 )数据,优化模型对少数类的识别能力。 - 合成数据分布:
观察特征分布对比图(如年龄、血压、胆固醇、最大心率分布),合成数据与原始数据的核密度曲线整体趋势较为接近,说明 GAN 生成的合成数据在关键特征维度上,能一定程度模拟原始数据的分布模式,具备补充真实数据分布的潜力 。不过部分特征(如胆固醇分布)的曲线形态仍有差异,反映出 GAN 生成数据与真实数据存在一定 gap,后续可尝试调整 GAN 网络结构(如增加层数、调整隐层维度)、训练参数(如学习率、迭代次数)优化生成效果。
二、GAN 训练过程分析
- 损失曲线:
生成器(Generator)和判别器(Discriminator)的损失曲线呈现出典型的 GAN 训练过程特征。前期判别器损失快速下降,生成器损失快速上升,体现判别器对真假数据的辨别能力快速提升,生成器则努力学习生成更 “逼真” 数据;后期二者损失进入相对稳定的波动状态,说明模型逐渐达到博弈平衡,但损失值未趋近于理想的低水平(如判别器损失未接近 0.5 附近稳定 ),可能是训练轮次不足、网络结构未充分拟合数据分布,或数据集本身特征复杂导致,可考虑延长训练周期、改进网络架构(如引入注意力机制、调整激活函数)进一步优化。
三、模型评估指标分析
- 混淆矩阵:
原始数据与增强数据训练的随机森林模型,混淆矩阵结构几乎一致。以类别 0(健康)和类别 1(患病)为例,真实为健康的样本中,预测正确 26 例、错误 3 例;真实为患病的样本中,预测正确 25 例、错误 7 例 。说明 GAN 增强数据未显著改变模型对不同类别样本的分类正误格局,可能是因为原始数据集不平衡程度相对较低,少量合成数据补充对模型决策边界影响有限;也可能是生成数据质量仍有提升空间,未给模型提供足够有价值的新信息。 - ROC 曲线与 AUC:
两条 ROC 曲线形态接近,AUC 均为 0.93 ,表明模型在原始数据和增强数据上,对正负样本的区分能力相当。AUC 数值较高,说明模型整体预测性能良好,但增强数据未带来 AUC 提升,侧面反映合成数据对模型区分能力的增益不明显 。可能需要生成更多、质量更高的合成数据,或尝试其他数据增强策略(如 SMOTE 结合 GAN ),强化对模型性能的正向影响。 - F1 分数:
原始数据和增强数据训练模型的 F1 分数均为 0.8333 ,F1 分数综合考量精确率(precision)和召回率(recall),该结果说明 GAN 增强后,模型在平衡精确率和召回率方面未体现优势。结合分类报告,模型对两类样本的精确率、召回率也无明显变化,再次验证 GAN 增强在当前实验设置下,未有效改善模型性能,后续需从数据生成质量、模型训练策略(如调整分类器参数、尝试其他分类模型)等方面优化。
整体来看,本次实验中 GAN 虽一定程度模拟了数据分布,但因原始数据集不平衡程度不高、生成数据质量和数量有限等,未显著提升分类模型性能。后续可从优化 GAN 训练(如调整网络、延长训练)、探索多策略数据增强、更换分类模型等方向深入,挖掘 GAN 对不平衡数据分类的价值。