使用元学习进行心脏病分类的完整示例,包括数据准备、模型训练(使用MAML框架)

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import roc_curve, auc, brier_score_loss, calibration_curve
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import learn2learn as l2l

# 数据准备
data = pd.read_csv('heart_disease_data.csv')

# 特征和目标
X = data.drop('target', axis=1)
y = data['target']

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 定义模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(13, 64)  # 假设特征维度是13
        self.fc2 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

# 初始化模型、优化器和损失函数
model = SimpleModel()
meta_optimizer = l2l.optimizers.MAML(model, lr=0.01)
criterion = nn.BCELoss()

# 创建数据集和数据加载器
dataset = TensorDataset(torch.tensor(X_train.values, dtype=torch.float32), torch.tensor(y_train.values, dtype=torch.float32))
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 训练函数
def train_meta(model, data_loader, meta_optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        for X_batch, y_batch in data_loader:
            X_batch, y_batch = torch.tensor(X_batch, dtype=torch.float32), torch.tensor(y_batch, dtype=torch.float32)
            
            # 内循环
            meta_optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = criterion(y_pred.squeeze(), y_batch)
            loss.backward()
            meta_optimizer.step()
            
            # 外循环
            meta_optimizer.step()
        
        print(f'Epoch {epoch+1}/{num_epochs} completed')

# 训练模型
train_meta(model, data_loader, meta_optimizer)

# 测试模型
model.eval()
with torch.no_grad():
    X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
    y_prob = model(X_test_tensor).squeeze().numpy()
    y_pred = (y_prob > 0.5).astype(int)

# 计算指标
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
roc_auc = roc_auc_score(y_test, y_prob)

print(f'Accuracy: {accuracy:.2f}')
print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')
print(f'F1 Score: {f1:.2f}')
print(f'AUC-ROC: {roc_auc:.2f}')

# 绘制评估曲线
# ROC曲线
fpr, tpr, _ = roc_curve(y_test, y_prob)
roc_auc = auc(fpr, tpr)

# DCA曲线
def dca_curve(y_true, y_prob, thresholds=np.linspace(0, 1, 10)):
    dca_results = []
    for threshold in thresholds:
        y_pred = (y_prob >= threshold).astype(int)
        tp = np.sum((y_pred == 1) & (y_true == 1))
        fp = np.sum((y_pred == 1) & (y_true == 0))
        tn = np.sum((y_pred == 0) & (y_true == 0))
        fn = np.sum((y_pred == 0) & (y_true == 1))
        net_benefit = (tp - fp * (threshold / (1 - threshold))) / len(y_true)
        dca_results.append(net_benefit)
    return dca_results

thresholds = np.linspace(0, 1, 10)
dca_results = dca_curve(y_test, y_prob, thresholds)

# 校准曲线
prob_true, prob_pred = calibration_curve(y_test, y_prob, n_bins=10, strategy='uniform')

# 绘制曲线
plt.figure(figsize=(15, 5))

# ROC曲线
plt.subplot(1, 3, 1)
plt.plot(fpr, tpr, color='blue', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='grey', linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc='lower right')

# DCA曲线
plt.subplot(1, 3, 2)
plt.plot(thresholds, dca_results, marker='o')
plt.xlabel('Threshold')
plt.ylabel('Net Benefit')
plt.title('DCA Curve')

# 校准曲线
plt.subplot(1, 3, 3)
plt.plot(prob_pred, prob_true, marker='o', label='Calibration curve')
plt.plot([0, 1], [0, 1], color='grey', linestyle='--')
plt.xlabel('Mean Predicted Probability')
plt.ylabel('Fraction of Positives')
plt.title('Calibration Curve')

plt.tight_layout()
plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序员奇奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值