>- **🍨 本文为[🔗365天深度学习训练营]中的学习记录博客**
>- **🍖 原作者:[K同学啊]**
本人往期文章可查阅: 深度学习总结
🏡 我的环境:
- 语言环境:Python3.11
- 编译器:PyCharm
- 深度学习环境:Pytorch
-
- torch==2.0.0+cu118
-
- torchvision==0.18.1+cu118
- 显卡:NVIDIA GeForce GTX 1660
一些思路:
当看到这个项目数据的时候,第一反应是数据维度很多,一般是先要做特征提取,降维等操作。对于特征提取:
- 首先想到的是相关性分析,用热力图。但是分析得出与是否患病相关性比较强的只有4个特征,而一些我们日常以为的年龄、日常生活得分这些没有看出有相关性;
- 接着就通过画图来分析特征是否与目标有关系,首先对年龄进行了分析,发现有关系,之后做了一些其他特征的关系,通过绘图发现,种族、功能评估得分、日常生活得分都有关系;
- 由于特征维度多,绘图显然不是一个有效的方式,后面采用了随机森林进行分析,发现效果很好,于是就通过RFE对随机森林的结果进行特征筛选。
1.导入数据
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,TensorDataset
plt.rcParams['font.sans-serif']=['Microsoft YaHei'] # 显示中文
plt.rcParams['axes.unicode_minus']=False # 显示负号
data_df=pd.read_csv("E:/DATABASE/RNN/R8/alzheimers_disease_data.csv")
data_df.head()
输出:
该数据集是2149名被诊断患有阿尔茨海默病或有阿尔茨海默病风险的患者的健康记录的综合集合。数据集中的每个患者都有一个唯一的ID号,范围从4751到6900。该数据集涵盖了广泛的信息,这些信息对于理解与阿尔茨海默病相关的各种因素至关重要。它包括人口统计细节、生活习惯、病史、临床测量、认知和功能评估、症状和诊断信息。
# 标签中文化
data_df.rename(columns={"Age":"年龄","Gender":"性别","Ethnicity":"种族","EducationLevel":"教育水平","BMI":"身体质量指数(BMI)",
"Smoking":"吸烟状况","AlcoholConsumption":"酒精摄入量","PhysicalActivity":"体育活动时间","DietQuality":"饮食质量评分",
"SleepQuality":"睡眠质量评分","FamilyHistoryAlzheimers":"家族阿尔茨海默病史","CardiovascularDisease":"心血管疾病",
"Diabetes":"糖尿病","Depression":"抑郁病史","HeadInjury":"头部受伤","Hypertension":"高血压",
"SystolicBP":"收缩压","DiastolicBP":"舒张压","CholesterolTotal":"胆固醇总量","CholesterolLDL":"低密度脂蛋白胆固醇(LDL)",
"CholesterolHDL":"高密度脂蛋白胆固醇(HDL)","CholesterolTriglycerides":"甘油三酯","MMSE":"简易精神状态检查(MMSE)得分",
"FunctionalAssessment":"功能评估得分","MemoryComplaints":"记忆抱怨","BehavioralProblems":"行为问题",
"ADL":"日常生活活动(ADL)得分","Confusion":"混乱与定向障碍","Disorientation":"迷失方向","PersonalityChanges":"人格变化",
"DifficultyCompletingTasks":"完成任务困难","Forgetfulness":"健忘","Diagnosis":"诊断状态","DoctorInCharge":"主诊医生"},inplace=True)
data_df.columns
输出:
Index(['PatientID', '年龄', '性别', '种族', '教育水平', '身体质量指数(BMI)', '吸烟状况', '酒精摄入量',
'体育活动时间', '饮食质量评分', '睡眠质量评分', '家族阿尔茨海默病史', '心血管疾病', '糖尿病', '抑郁病史',
'头部受伤', '高血压', '收缩压', '舒张压', '胆固醇总量', '低密度脂蛋白胆固醇(LDL)',
'高密度脂蛋白胆固醇(HDL)', '甘油三酯', '简易精神状态检查(MMSE)得分', '功能评估得分', '记忆抱怨', '行为问题',
'日常生活活动(ADL)得分', '混乱与定向障碍', '迷失方向', '人格变化', '完成任务困难', '健忘', '诊断状态',
'主诊医生'],
dtype='object')
2.数据处理
检查是否有空数据:
data_df.isnull().sum()
输出:
PatientID 0
年龄 0
性别 0
种族 0
教育水平 0
身体质量指数(BMI) 0
吸烟状况 0
酒精摄入量 0
体育活动时间 0
饮食质量评分 0
睡眠质量评分 0
家族阿尔茨海默病史 0
心血管疾病 0
糖尿病 0
抑郁病史 0
头部受伤 0
高血压 0
收缩压 0
舒张压 0
胆固醇总量 0
低密度脂蛋白胆固醇(LDL) 0
高密度脂蛋白胆固醇(HDL) 0
甘油三酯 0
简易精神状态检查(MMSE)得分 0
功能评估得分 0
记忆抱怨 0
行为问题 0
日常生活活动(ADL)得分 0
混乱与定向障碍 0
迷失方向 0
人格变化 0
完成任务困难 0
健忘 0
诊断状态 0
主诊医生 0
dtype: int64
from sklearn.preprocessing import LabelEncoder
# 创建 LabelEncoder 实例
label_encoder=LabelEncoder()
# 对非数值型列进行标签编码
data_df['主诊医生']=label_encoder.fit_transform(data_df['主诊医生'])
data_df.head()
输出:
2.1.患病占比
# 计算是否患病,人数
counts=data_df["诊断状态"].value_counts()
# 计算百分比
sizes=counts/counts.sum()*100
# 绘制环形图
fig,ax=plt.subplots()
wedges,textx,autotexts=ax.pie(sizes,labels=sizes.index,autopct='%1.2ff%%',startangle=90,wedgeprops=dict(width=0.3))
plt.title("患病占比(1患病,0没有患病)")
plt.rcParams['figure.dpi']=300 # 分辨率
plt.show()
输出:
由上图可知:不患病人数居多
2.2.相关性分析
plt.figure(figsize=(40,35))
sns.heatmap(data_df.corr(),annot=True,fmt=".2f")
plt.show()
输出:
其中,与患病相关性比较强的有:MMSE得分、功能评估得分、记忆抱怨、行为问题等相关性比较强。其中,MMSE得分、功能评估得分为负相关,记忆抱怨、行为问题为正相关。
2.3.年龄与患病探究
data_df['年龄'].min(),data_df['年龄'].max()
输出:
(60, 90)
# 计算每一个年龄段患病人数
age_bins=range(60,91)
# 分组、聚合函数:sum求和,size总大小
grouped=data_df.groupby('年龄').agg({'诊断状态':['sum','size']})
grouped.columns=['患病','总人数']
# 计算不患病的人数
grouped['不患病']=grouped['总人数']-grouped['患病']
# 设置绘图风格
sns.set(style="whitegrid")
plt.figure(figsize=(12,5))
# 获取x轴标签(即年龄)
x=grouped.index.astype(str) # 将年龄转换为字符串格式便于显示
# 画图
plt.bar(x,grouped["不患病"],0.35,label="不患病",color='skyblue')
plt.bar(x,grouped["患病"],0.35,label="患病",color='salmon')
# 设置标题
plt.title("患病年龄分布",fontproperties='Microsoft YaHei')
plt.xlabel("年龄",fontproperties='Microsoft YaHei')
plt.ylabel("人数",fontproperties='Microsoft YaHei')
# 如果需要对图例也应用相同的字体
plt.legend(prop={'family':'Microsoft YaHei'})
# 展示
plt.tight_layout()
plt.show()
输出:
通过上图研究发现,由于原本数据中不患病的多,所以不患病的在图像中显示多。通过观察发现患病与年龄有关,尤其是年龄大于80岁的,患病与不患病比例高。
3.特征选择
模型采用:决策树特征训练,可以很好的对特征重要性进行排序。
特征选择:采用RFE特征选择方法:
RFE(Recursive Feature Elimination,递归特征消除)和 SelectFromModel 都是 Scikit-learn 中用于特征选择的方法,但它们的工作机制和使用场景有所不同。
SelectFromModel:
- 工作原理: SelectFromModel 是一种基于模型的特征选择方法。它通过一个基础评估器来判断每个特征的重要性,并根据给定的阈值选择那些重要性得分超过该阈值的特征。默认情况下,它会使用基础评估器提供的 feature_importances_ 或者 coef_ 属性来衡量特征的重要性。
- 使用场景:当你希望基于某个预训练模型的特征重要性来进行特征选择时特别有用。它允许你设置一个全局阈值来控制特征选择的标准,但不直接支持指定想要选择的特征数量。
- 优点:简单易用,适合快速进行特征筛选。
- 缺点:不如 RFE 精细,不能直接控制最终选择的特征数量。
RFE(Recursive Feature Elimination):
- 工作原理:RFE 采用了一种递归的方式进行特征选择。首先,它会训练一个模型,并格局模型对每个特征的重要性评分进行排序。然后,它会移除最不重要的特征,并重复这个过程,直到留下指定数量的特征为止。
- 适用场景:当您确切知道想要选择多少个特征时非常有用。它提供了比 SelectFromModel 更细致的控制,因为您可以直接指定要保留的特征数量。
- 优点:可以精确控制最终选择的特征数量,并且在每一轮迭代中都能考虑到所有剩余特征的整体贡献。
- 缺点:计算成本相对较高,因为它需要多次训练模型,特别是当数据集很大或模型复杂度很高时。
总结:
- 如果你的目标是基于某个预定义的重要性阈值来简化模型,那么 SelectFromModel 可能是更合适的选择。
- 如果你希望直接控制最终选择的特征数量,并愿意接受更高的计算成本以获得更精细的控制,那么 RFE 可能更适合您的需求。
两种方法都有其独特的优势和适用场景,选择哪一种取决于您的具有应用需求、数据特性以及性能考虑。
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
data=data_df.copy()
x=data_df.iloc[:,1:-2]
y=data_df.iloc[:,-2]
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=42)
# 标准化
sc=StandardScaler()
x_train=sc.fit_transform(x_train)
x_test=sc.transform(x_test)
# 模型创建
tree=DecisionTreeClassifier()
tree.fit(x_train,y_train)
pred=tree.predict(x_test)
reporter=classification_report(y_test,pred)
print(reporter)
输出:
precision recall f1-score support
0 0.91 0.91 0.91 277
1 0.84 0.83 0.83 153
accuracy 0.88 430
macro avg 0.87 0.87 0.87 430
weighted avg 0.88 0.88 0.88 430
效果不错,进行特征选择。
# 特征展示
feature_importances=tree.feature_importances_
features_rf=pd.DataFrame({'特征':x.columns,'重要度':feature_importances})
features_rf.sort_values(by='重要度',ascending=False,inplace=True)
plt.figure(figsize=(20,10))
sns.barplot(x='重要度',y='特征',data=features_rf)
plt.rcParams['font.sans-serif']=['Microsoft YaHei'] # 显示中文
plt.rcParams['axes.unicode_minus']=False # 显示负号
plt.xlabel('重要度')
plt.ylabel('特征')
plt.title('随机森林特征图')
plt.show()
输出:
从这个可以看出,有些特征没有效果,如性别、高血压等。
下面进行特征选择,选取20个特征。
from sklearn.feature_selection import RFE
# 使用 RFE 来选择特征
rfe_selector=RFE(estimator=tree,n_features_to_select=20) # 选择前20个特征
rfe_selector.fit(x,y)
x_new=rfe_selector.transform(x)
feature_names=np.array(x.columns)
selected_feature_names=feature_names[rfe_selector.support_]
print(selected_feature_names)
输出:
['年龄' '种族' '教育水平' '身体质量指数(BMI)' '酒精摄入量' '体育活动时间' '饮食质量评分' '睡眠质量评分' '心血管疾病'
'收缩压' '舒张压' '胆固醇总量' '低密度脂蛋白胆固醇(LDL)' '高密度脂蛋白胆固醇(HDL)' '甘油三酯'
'简易精神状态检查(MMSE)得分' '功能评估得分' '记忆抱怨' '行为问题' '日常生活活动(ADL)得分']
4.构建数据集
4.1.数据集划分与标准化
feature_selection=['年龄','种族','教育水平','身体质量指数(BMI)', '酒精摄入量', '体育活动时间', '饮食质量评分',
'睡眠质量评分', '心血管疾病', '收缩压', '舒张压', '胆固醇总量', '低密度脂蛋白胆固醇(LDL)',
'高密度脂蛋白胆固醇(HDL)' ,'甘油三酯', '简易精神状态检查(MMSE)得分', '功能评估得分',
'记忆抱怨', '行为问题', '日常生活活动(ADL)得分']
'''feature_selection=['年龄','种族','教育水平','身体质量指数(BMI)','酒精摄入量','体育活动时间','饮食质量评分',
'睡眠质量评分','心血管疾病','收缩压','舒张压','胆固醇总量','低密度脂蛋白胆固醇(LDL)',
'高密度脂蛋白胆固醇(HDL)','甘油三酯','简易精神状态检查(MMSE)得分','功能评估得分',
'记忆抱怨','行为问题','日常生活活动(ADL)得分']'''
x=data_df[feature_selection]
# 标准化,标准化其实对应连续性数据,分类数据不适合,由于特征中只有种族是分类数据,这里偷个“小懒”
sc=StandardScaler()
x=sc.fit_transform(x)
x=torch.tensor(np.array(x),dtype=torch.float32)
y=torch.tensor(np.array(y),dtype=torch.long)
# 再次进行特征选择
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=42)
x_train.shape,y_train.shape
输出:
(torch.Size([1719, 20]), torch.Size([1719]))
4.2.构建加载
batch_size=32
train_dl=DataLoader(TensorDataset(x_train,y_train),batch_size=batch_size,shuffle=True)
test_dl=DataLoader(TensorDataset(x_test,y_test),batch_size=batch_size,shuffle=False)
5.构建模型
class RNN_Model(nn.Module):
def __init__(self):
super().__init__()
# 调用rnn
self.rnn=nn.RNN(input_size=20,hidden_size=200,num_layers=1,batch_first=True)
self.fc1=nn.Linear(200,50)
self.fc2=nn.Linear(50,2)
def forward(self,x):
x,hidden1=self.rnn(x)
x=self.fc1(x)
x=self.fc2(x)
return x
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=RNN_Model().to(device)
model
输出:
RNN_Model(
(rnn): RNN(20, 200, batch_first=True)
(fc1): Linear(in_features=200, out_features=50, bias=True)
(fc2): Linear(in_features=50, out_features=2, bias=True)
)
model(torch.randn(32,20).to(device)).shape
输出:
torch.Size([32, 2])
6.模型训练
6.1.构建训练函数
def train(data,model,loss_fn,opt):
size=len(data.dataset)
batch_num=len(data)
train_loss,train_acc=0.0,0.0
for x,y in data:
x,y=x.to(device),y.to(device)
pred=model(x)
loss=loss_fn(pred,y)
# 反向传播
opt.zero_grad() # 梯度清零
loss.backward() # 求导
opt.step() # 设置梯度
train_loss+=loss.item()
train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
train_loss/=batch_num
train_acc/=size
return train_acc,train_loss
6.2.构建测试函数
def test(data,model,loss_fn):
size=len(data.dataset)
batch_num=len(data)
test_loss,test_acc=0.0,0.0
with torch.no_grad():
for x,y in data:
x,y=x.to(device),y.to(device)
pred=model(x)
loss=loss_fn(pred,y)
test_loss+=loss.item()
test_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
test_loss/=batch_num
test_acc/=size
return test_acc,test_loss
6.3.设置超参数
超参数,开始设置了:
- 1e-3,但是不稳定;
- 1e-4,效果还不错。
loss_fn=nn.CrossEntropyLoss() # 损失函数
learn_lr=1e-4 # 超参数
optimizer=torch.optim.Adam(model.parameters(),lr=learn_lr) # 优化器
7.模型训练
train_acc=[]
train_loss=[]
test_acc=[]
test_loss=[]
epoches=50
for i in range(epoches):
model.train()
epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)
model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)
train_acc.append(epoch_train_acc)
train_loss.append(epoch_train_loss)
test_acc.append(epoch_test_acc)
test_loss.append(epoch_test_loss)
# 输出
template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f}')
print(template.format(i+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss))
print('Done')
输出:
Epoch: 1,Train_acc:62.9%,Train_loss:0.649,Test_acc:67.4%,Test_loss:0.600
Epoch: 2,Train_acc:68.7%,Train_loss:0.579,Test_acc:70.7%,Test_loss:0.542
Epoch: 3,Train_acc:75.7%,Train_loss:0.521,Test_acc:79.1%,Test_loss:0.489
Epoch: 4,Train_acc:80.6%,Train_loss:0.464,Test_acc:84.0%,Test_loss:0.444
Epoch: 5,Train_acc:83.7%,Train_loss:0.428,Test_acc:82.6%,Test_loss:0.417
Epoch: 6,Train_acc:84.4%,Train_loss:0.397,Test_acc:83.7%,Test_loss:0.405
Epoch: 7,Train_acc:85.0%,Train_loss:0.385,Test_acc:82.8%,Test_loss:0.397
Epoch: 8,Train_acc:84.8%,Train_loss:0.378,Test_acc:82.6%,Test_loss:0.398
Epoch: 9,Train_acc:84.9%,Train_loss:0.377,Test_acc:82.8%,Test_loss:0.399
Epoch:10,Train_acc:84.5%,Train_loss:0.374,Test_acc:83.0%,Test_loss:0.403
Epoch:11,Train_acc:84.5%,Train_loss:0.374,Test_acc:83.7%,Test_loss:0.402
Epoch:12,Train_acc:84.6%,Train_loss:0.369,Test_acc:83.5%,Test_loss:0.400
Epoch:13,Train_acc:84.4%,Train_loss:0.374,Test_acc:83.3%,Test_loss:0.404
Epoch:14,Train_acc:84.1%,Train_loss:0.373,Test_acc:83.3%,Test_loss:0.400
Epoch:15,Train_acc:84.2%,Train_loss:0.375,Test_acc:82.8%,Test_loss:0.397
Epoch:16,Train_acc:84.5%,Train_loss:0.374,Test_acc:82.8%,Test_loss:0.396
Epoch:17,Train_acc:84.7%,Train_loss:0.372,Test_acc:83.7%,Test_loss:0.399
Epoch:18,Train_acc:84.7%,Train_loss:0.373,Test_acc:83.7%,Test_loss:0.400
Epoch:19,Train_acc:84.2%,Train_loss:0.372,Test_acc:84.0%,Test_loss:0.402
Epoch:20,Train_acc:85.0%,Train_loss:0.370,Test_acc:83.5%,Test_loss:0.399
Epoch:21,Train_acc:84.7%,Train_loss:0.374,Test_acc:83.3%,Test_loss:0.397
Epoch:22,Train_acc:84.6%,Train_loss:0.370,Test_acc:84.2%,Test_loss:0.398
Epoch:23,Train_acc:84.9%,Train_loss:0.371,Test_acc:84.7%,Test_loss:0.399
Epoch:24,Train_acc:84.8%,Train_loss:0.372,Test_acc:84.7%,Test_loss:0.395
Epoch:25,Train_acc:84.9%,Train_loss:0.368,Test_acc:83.5%,Test_loss:0.395
Epoch:26,Train_acc:84.5%,Train_loss:0.369,Test_acc:84.7%,Test_loss:0.395
Epoch:27,Train_acc:84.1%,Train_loss:0.373,Test_acc:84.4%,Test_loss:0.393
Epoch:28,Train_acc:84.4%,Train_loss:0.371,Test_acc:84.0%,Test_loss:0.396
Epoch:29,Train_acc:84.6%,Train_loss:0.373,Test_acc:84.4%,Test_loss:0.395
Epoch:30,Train_acc:84.7%,Train_loss:0.371,Test_acc:84.2%,Test_loss:0.396
Epoch:31,Train_acc:85.0%,Train_loss:0.369,Test_acc:84.4%,Test_loss:0.396
Epoch:32,Train_acc:84.9%,Train_loss:0.375,Test_acc:84.4%,Test_loss:0.395
Epoch:33,Train_acc:85.0%,Train_loss:0.370,Test_acc:83.7%,Test_loss:0.396
Epoch:34,Train_acc:84.7%,Train_loss:0.371,Test_acc:84.0%,Test_loss:0.396
Epoch:35,Train_acc:84.7%,Train_loss:0.372,Test_acc:84.4%,Test_loss:0.396
Epoch:36,Train_acc:84.5%,Train_loss:0.371,Test_acc:84.0%,Test_loss:0.394
Epoch:37,Train_acc:84.6%,Train_loss:0.372,Test_acc:84.4%,Test_loss:0.396
Epoch:38,Train_acc:85.3%,Train_loss:0.371,Test_acc:84.0%,Test_loss:0.395
Epoch:39,Train_acc:84.6%,Train_loss:0.370,Test_acc:84.2%,Test_loss:0.397
Epoch:40,Train_acc:84.4%,Train_loss:0.372,Test_acc:84.2%,Test_loss:0.399
Epoch:41,Train_acc:85.2%,Train_loss:0.371,Test_acc:83.7%,Test_loss:0.396
Epoch:42,Train_acc:84.9%,Train_loss:0.369,Test_acc:84.2%,Test_loss:0.397
Epoch:43,Train_acc:84.9%,Train_loss:0.371,Test_acc:84.4%,Test_loss:0.398
Epoch:44,Train_acc:85.2%,Train_loss:0.373,Test_acc:84.2%,Test_loss:0.400
Epoch:45,Train_acc:84.8%,Train_loss:0.370,Test_acc:83.3%,Test_loss:0.402
Epoch:46,Train_acc:84.7%,Train_loss:0.370,Test_acc:84.4%,Test_loss:0.402
Epoch:47,Train_acc:85.3%,Train_loss:0.368,Test_acc:83.3%,Test_loss:0.401
Epoch:48,Train_acc:84.4%,Train_loss:0.372,Test_acc:83.5%,Test_loss:0.401
Epoch:49,Train_acc:84.6%,Train_loss:0.371,Test_acc:83.7%,Test_loss:0.401
Epoch:50,Train_acc:84.9%,Train_loss:0.369,Test_acc:83.5%,Test_loss:0.400
Done
8.模型评估
8.1.结果图
import matplotlib.pyplot as plt
# 隐藏警告
import warnings
warnings.filterwarnings("ignore") # 忽略警告信息
from datetime import datetime
current_time=datetime.now() # 获取当前时间
epochs_range=range(epoches)
plt.figure(figsize=(12,3))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')
plt.xlabel(current_time)
plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Training Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training Loss')
plt.show()
输出:
8.2.混淆矩阵
混淆矩阵(Confusion Matrix)是机器学习和数据科学中用于评估分类模型性能的一种表格。它通过展示模型预测结果与实际标签之间的对比,帮助我们理解模型的准确的以及其在不同类别上的表现。
对于一个二分类问题,混淆矩阵通常是一个2×2的表格,包含一下四个指标:
- 真正列(True Positive,TP):模型正确预测为正类的样本数。
- 假正列(False Positive,FP):模型错误地将负类预测为正类的样本数。
- 假负列(False Negative,FN):模型错误地将正类预测为负类的样本数。
- 真负列(True Negative,TN):模型正确预测为负类的样本数。
而对于多分类问题,混淆矩阵会相应地扩展到 N×N 的大小(N 为类别数量),每一行代表实际类别,每一列代表预测类别。
from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay
pred=model(x_test.to(device)).argmax(1).cpu().numpy()
# 计算混淆矩阵
cm=confusion_matrix(y_test,pred)
# 计算
plt.figure(figsize=(6,5))
sns.heatmap(cm,annot=True,fmt="d",cmap="Blues")
# 标题
plt.title("混淆矩阵")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.tight_layout() # 自适应
plt.show()
输出: