一、背景介绍
1.1 眼底影像诊断的意义
糖尿病视网膜病变、年龄相关性黄斑变性、青光眼等常见眼底疾病,是导致不可逆视力损伤甚至失明的主要原因。例如,糖尿病视网膜病变因高血糖引发血管渗漏和视网膜脱离,AMD直接破坏中心视力,青光眼则因视神经损伤逐渐剥夺视野。此外,病理性近视、视网膜静脉阻塞等疾病与全身健康(如心血管疾病、高血压)密切相关,可能进一步加剧视力风险。LASIK手术痕迹虽非疾病,但可能干扰后续眼部评估。
AI辅助诊断通过高效分析眼底影像,显著提升早期病变(如微动脉瘤、黄斑水肿)的检出率,并量化评估病情进展(如视网膜厚度、视野缺损)。其标准化分级能力减少人为误差,尤其在基层医疗中弥补专业资源不足。结合便携设备,AI可推动偏远地区筛查,实现疾病动态监测与风险预警,为医生提供决策支持,优化诊疗效率。AI技术的应用不仅降低致盲风险,更通过早筛早治改善患者预后,成为眼科数字化发展的重要推动力。
1.2 EDDFS数据集亮点
-
规模优势:28,877张高分辨率眼底图(含8类疾病+健康样本)
-
稀缺性:包含LASIK术后斑等特殊病例
-
科研价值:支持多标签分类与疾病分级任务
-
获取方式:[Google Drive/Mega/百度云](附密码
i3vk
)
二、环境搭建与代码解析
2.1 依赖环境配置
# 关键依赖版本(需CUDA 11.8)
torch==2.1.1+cu118
torchvision==0.16.1+cu118
scikit-learn==1.4.2 # 用于评估指标
CUDA的版本安装可参考全网最详细的安装pytorch GPU方法,一次安装成功!!包括安装失败后的处理方法!-CSDN博客
2.2 代码结构解析
project-root/
├── models/ # 模型定义(核心:coattnet_v2_withWeighted_tiny)
├── datasets/ # 数据加载逻辑(支持多任务配置)
└── config/_data/ # 数据集路径管理
2.3 核心模型架构
CoAtt Net核心设计:
-
加权注意力机制:针对不同疾病特征动态调整关注区域
-
多尺度特征融合:结合全局上下文与局部细节
-
轻量化设计:在448x448高分辨率下保持高效推理
三、实验复现全流程
3.1 数据预处理配置
修改config/_data/datasetConf.py
指定本地数据路径:
# 示例:糖尿病视网膜病变任务配置
EDDFS_dr_conf = {
"IMG_ROOT": "/your_path/EDDFS/DR",
"LABEL_DIR": "Annotation/dr_labels.csv"
}
3.2 训练启动命令
python main.py \
--useGPU 0 \ # 使用第一张显卡
--dataset EDDFS_dr \ # 指定糖尿病视网膜病变任务
--preprocess 7 \ # 标准化+直方图均衡化
--net coattnet_v2_withWeighted_tiny \
--batchsize 32 \ # 显存占用约12GB
--lr 9e-5 \ # 初始学习率
--lossfun focalloss # 缓解类别不平衡
四、实验结果深度分析
4.1 整体性能表现
指标 | 值 | 说明 |
---|---|---|
准确率 | 91.23% | 整体分类准确率 |
F1-score | 51.51% | 加权平均(受类别不平衡影响) |
AUC | 0.8999 | 显示良好鲁棒性 |
4.2 各类别表现对比
疾病类别 | 样本量 | Precision | Recall | F1 | AUC |
---|---|---|---|---|---|
糖尿病视网膜病变 (0) | 664 | 71.43% | 2.26% | 4.38% | 0.768 |
年龄相关性黄斑变性 (3) | 760 | 83.29% | 77.37% | 80.22% | 0.978 |
青光眼 (5) | 523 | 83.03% | 34.61% | 48.85% | 0.894 |
关键发现:
-
样本量大的类别(如2838例的类别1)表现出稳定的召回率(78.86%)
-
小样本类别(如111例的类别4)存在严重的漏检问题
-
LASIK术后斑(类别7)检测失败,需针对性优化
4.3 混淆矩阵可视化
1.对模型训练指标进行整理
class_info = [
# [样本量, Precision, Recall, F1]
[664, 0.7143, 0.0226, 0.0438], # 类别0: DR
[2838, 0.7189, 0.7886, 0.7521], # 类别1: AMD
[326, 0.0, 0.0, 0.0], # 类别2: Glaucoma
[760, 0.8329, 0.7737, 0.8022], # 类别3: Myopia
[111, 0.0, 0.0, 0.0], # 类别4: Hypertension
[523, 0.8303, 0.3461, 0.4885], # 类别5: RVO
[288, 0.0, 0.0, 0.0], # 类别6: LASIK
[370, 0.0, 0.0, 0.0], # 类别7: Other
]
2. 计算混淆矩阵元素
-
True Positive (TP):
TP = Recall * 样本量
-
False Negative (FN):
FN = 样本量 - TP
-
False Positive (FP):
FP = TP / Precision - TP
(当Precision > 0时) -
True Negative (TN):通过全局统计间接计算。
3. 生成模拟标签
通过以下代码生成近似的 y_true
和 y_pred
:
import numpy as np
# 初始化真实标签和预测标签
y_true = []
y_pred = []
# 为每个类别生成样本
for class_idx in range(8):
total = class_info[class_idx][0]
recall = class_info[class_idx][2]
# 计算TP和FN
tp = int(recall * total)
fn = total - tp
# 添加TP样本
y_true.extend([class_idx] * tp)
y_pred.extend([class_idx] * tp)
# 添加FN样本(随机分配到其他类别)
other_classes = [i for i in range(8) if i != class_idx]
if len(other_classes) > 0:
y_true.extend([class_idx] * fn)
y_pred.extend(np.random.choice(other_classes, fn).tolist())
# 转换为NumPy数组
y_true = np.array(y_true)
y_pred = np.array(y_pred)
4.生成混淆矩阵
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
# 生成混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 可视化
plt.figure(figsize=(12, 10))
sns.heatmap(
cm, annot=True, fmt="d", cmap="Blues",
xticklabels=['DR', 'AMD', 'Glaucoma', 'Myopia', 'Hypertension', 'RVO', 'LASIK', 'Other'],
yticklabels=['DR', 'AMD', 'Glaucoma', 'Myopia', 'Hypertension', 'RVO', 'LASIK', 'Other']
)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("EDDFS Dataset Confusion Matrix")
plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")
关键观察:
-
AMD分类表现卓越
-
正确识别率最高:年龄相关性黄斑变性(AMD)类别的正确预测样本数为2238,占总样本的显著比例(2838例),召回率高达78.86%,显示模型对AMD的特征捕捉能力极强。
-
误判分散:误判样本(600例)均匀分布到其他类别(如DR、Glaucoma等),未出现明显偏向性。
-
-
严重的小样本类别失效问题
-
完全失效的类别:青光眼(Glaucoma)、LASIK术后斑、其他(Other)类别的正确预测数均为0,表明模型完全无法识别这些疾病。
-
潜在原因:
-
样本量过少(Glaucoma仅326例,LASIK仅288例)。
-
特征学习不足,可能需针对性增强数据增强或迁移学习。
-
-
-
糖尿病视网膜病变(DR)的敏感性问题
-
低召回率:DR类别的召回率仅2.26%(正确预测15例/总样本664例),漏检率高达97.74%。
-
误判集中:误判样本主要流向AMD(100例)和Glaucoma(105例),反映模型对血管病变特征的混淆。
-
-
病理性近视(Myopia)的稳定性与局限性
-
高精确率:Myopia的精确率达83.29%(正确预测588例/总预测706例),显示模型对病灶区域(如视盘变形)的敏感度较高。
-
误判方向:误判样本分散至高血压性视网膜病变(30例)和RVO(22例),可能与血管形态变化的相似性有关。
-
-
系统性改进方向
-
数据层面:
-
对小样本类别(如LASIK、Glaucoma)进行过采样或合成数据增强。
-
引入病灶标注信息,辅助模型区分相似特征(如DR与AMD的微血管差异)。
-
-
模型层面:
-
尝试多任务学习框架,联合优化疾病分类与病灶定位。
-
针对低召回类别,采用Focal Loss或代价敏感学习。
-
-
4.4 注意力热力图对比
使用Grad-CAM生成病灶区域可视化:
import cv2
import numpy as np
import torch
from torchvision import models, transforms
from PIL import Image
# 1. 加载模型并注册钩子
model = models.resnet18(pretrained=True)
model.eval()
# 存储特征图和梯度
features = {}
gradients = {}
def forward_hook(module, input, output):
features['layer4'] = output.detach()
def backward_hook(module, grad_input, grad_output):
gradients['layer4'] = grad_output[0].detach()
# 注册正向和反向钩子
model.layer4.register_forward_hook(forward_hook)
model.layer4.register_backward_hook(backward_hook)
# 2. 预处理图像
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
original_img = cv2.imread('fundus_sample.jpg')#替换图片路径
img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
img_pil = Image.fromarray(img)
input_tensor = preprocess(img_pil).unsqueeze(0).requires_grad_(True)
# 3. 前向传播并获取目标类别输出
output = model(input_tensor)
target_class = output.argmax().item() # 假设关注预测概率最高的类别
# 4. 反向传播计算梯度
model.zero_grad()
output[0, target_class].backward()
# 5. 计算Grad-CAM权重
grads = gradients['layer4'] # 梯度形状: [1, 512, 7, 7]
pooled_grads = torch.mean(grads, dim=[0, 2, 3]) # 全局平均池化,形状: [512]
# 6. 加权特征图
activations = features['layer4'] # 特征图形状: [1, 512, 7, 7]
for i in range(activations.shape[1]):
activations[:, i, :, :] *= pooled_grads[i]
heatmap = torch.mean(activations, dim=1).squeeze().cpu().numpy()
heatmap = np.maximum(heatmap, 0) # ReLU激活
# 7. 生成热力图
heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
superimposed_img = cv2.addWeighted(original_img, 0.5, heatmap, 0.5, 0)
cv2.imwrite('grad_cam_result.jpg', superimposed_img)
以数据集中0001(NFMK1).jpg为例
原图像如下:
训练预处理后图像:
注意力热力图:
五、优化建议与踩坑记录
5.1 改进方向
-
数据增强:针对小样本疾病添加病灶区域仿射变换
-
损失函数:尝试Dice Loss优化边界敏感任务
-
模型融合:结合EfficientNet-V2的特征提取能力
5.2 常见报错解决
# 报错:CUDA内存不足
解决方案:
1. 降低batchsize至16
2. 启用混合精度训练(添加--fp16参数)
# 报错:标签文件加载失败
检查点:
1. CSV文件编码需为UTF-8
2. 图像路径中的反斜杠需转换为正斜杠
六、学术价值与展望
-
多任务学习潜力:同时支持疾病分类与病灶定位
-
临床部署挑战:需进一步优化推理速度(当前约15 FPS)
-
伦理考量:医疗数据隐私保护与模型可解释性需求
(文末附完整训练模型下载链接)