基于EDDFS眼底数据集的疾病诊断模型实战:CoAtt Net复现与性能分析

一、背景介绍

1.1 眼底影像诊断的意义

  糖尿病视网膜病变、年龄相关性黄斑变性、青光眼等常见眼底疾病,是导致不可逆视力损伤甚至失明的主要原因。例如,糖尿病视网膜病变因高血糖引发血管渗漏和视网膜脱离,AMD直接破坏中心视力,青光眼则因视神经损伤逐渐剥夺视野。此外,病理性近视、视网膜静脉阻塞等疾病与全身健康(如心血管疾病、高血压)密切相关,可能进一步加剧视力风险。LASIK手术痕迹虽非疾病,但可能干扰后续眼部评估。

  AI辅助诊断通过高效分析眼底影像,显著提升早期病变(如微动脉瘤、黄斑水肿)的检出率,并量化评估病情进展(如视网膜厚度、视野缺损)。其标准化分级能力减少人为误差,尤其在基层医疗中弥补专业资源不足。结合便携设备,AI可推动偏远地区筛查,实现疾病动态监测与风险预警,为医生提供决策支持,优化诊疗效率。AI技术的应用不仅降低致盲风险,更通过早筛早治改善患者预后,成为眼科数字化发展的重要推动力。

1.2 EDDFS数据集亮点

  • 规模优势:28,877张高分辨率眼底图(含8类疾病+健康样本)

  • 稀缺性:包含LASIK术后斑等特殊病例

  • 科研价值:支持多标签分类与疾病分级任务

  • 获取方式:[Google Drive/Mega/百度云](附密码i3vk

二、环境搭建与代码解析

2.1 依赖环境配置

# 关键依赖版本(需CUDA 11.8)
torch==2.1.1+cu118
torchvision==0.16.1+cu118
scikit-learn==1.4.2  # 用于评估指标

CUDA的版本安装可参考全网最详细的安装pytorch GPU方法,一次安装成功!!包括安装失败后的处理方法!-CSDN博客

2.2 代码结构解析 

project-root/
├── models/           # 模型定义(核心:coattnet_v2_withWeighted_tiny)
├── datasets/         # 数据加载逻辑(支持多任务配置)
└── config/_data/     # 数据集路径管理

2.3 核心模型架构

CoAtt Net核心设计

  1. 加权注意力机制:针对不同疾病特征动态调整关注区域

  2. 多尺度特征融合:结合全局上下文与局部细节

  3. 轻量化设计:在448x448高分辨率下保持高效推理

三、实验复现全流程

3.1 数据预处理配置

修改config/_data/datasetConf.py指定本地数据路径:

# 示例:糖尿病视网膜病变任务配置
EDDFS_dr_conf = {
    "IMG_ROOT": "/your_path/EDDFS/DR",
    "LABEL_DIR": "Annotation/dr_labels.csv"
}

 3.2 训练启动命令

python main.py \
  --useGPU 0 \               # 使用第一张显卡
  --dataset EDDFS_dr \       # 指定糖尿病视网膜病变任务
  --preprocess 7 \           # 标准化+直方图均衡化
  --net coattnet_v2_withWeighted_tiny \  
  --batchsize 32 \           # 显存占用约12GB
  --lr 9e-5 \                # 初始学习率
  --lossfun focalloss        # 缓解类别不平衡

四、实验结果深度分析

4.1 整体性能表现

指标说明
准确率91.23%整体分类准确率
F1-score51.51%加权平均(受类别不平衡影响)
AUC0.8999显示良好鲁棒性

4.2 各类别表现对比

疾病类别样本量PrecisionRecallF1AUC
糖尿病视网膜病变 (0)66471.43%2.26%4.38%0.768
年龄相关性黄斑变性 (3)76083.29%77.37%80.22%0.978
青光眼 (5)52383.03%34.61%48.85%0.894

关键发现

  1. 样本量大的类别(如2838例的类别1)表现出稳定的召回率(78.86%)

  2. 小样本类别(如111例的类别4)存在严重的漏检问题

  3. LASIK术后斑(类别7)检测失败,需针对性优化 

 4.3 混淆矩阵可视化

1.对模型训练指标进行整理

class_info = [
    # [样本量, Precision, Recall, F1]
    [664, 0.7143, 0.0226, 0.0438],   # 类别0: DR
    [2838, 0.7189, 0.7886, 0.7521],  # 类别1: AMD
    [326, 0.0, 0.0, 0.0],            # 类别2: Glaucoma
    [760, 0.8329, 0.7737, 0.8022],   # 类别3: Myopia
    [111, 0.0, 0.0, 0.0],            # 类别4: Hypertension
    [523, 0.8303, 0.3461, 0.4885],   # 类别5: RVO
    [288, 0.0, 0.0, 0.0],            # 类别6: LASIK
    [370, 0.0, 0.0, 0.0],            # 类别7: Other
]

2. 计算混淆矩阵元素

  • True Positive (TP)TP = Recall * 样本量

  • False Negative (FN)FN = 样本量 - TP

  • False Positive (FP)FP = TP / Precision - TP(当Precision > 0时)

  • True Negative (TN):通过全局统计间接计算。

3. 生成模拟标签

通过以下代码生成近似的 y_true 和 y_pred

import numpy as np

# 初始化真实标签和预测标签
y_true = []
y_pred = []

# 为每个类别生成样本
for class_idx in range(8):
    total = class_info[class_idx][0]
    recall = class_info[class_idx][2]
    
    # 计算TP和FN
    tp = int(recall * total)
    fn = total - tp
    
    # 添加TP样本
    y_true.extend([class_idx] * tp)
    y_pred.extend([class_idx] * tp)
    
    # 添加FN样本(随机分配到其他类别)
    other_classes = [i for i in range(8) if i != class_idx]
    if len(other_classes) > 0:
        y_true.extend([class_idx] * fn)
        y_pred.extend(np.random.choice(other_classes, fn).tolist())

# 转换为NumPy数组
y_true = np.array(y_true)
y_pred = np.array(y_pred)

4.生成混淆矩阵

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# 生成混淆矩阵
cm = confusion_matrix(y_true, y_pred)

# 可视化
plt.figure(figsize=(12, 10))
sns.heatmap(
    cm, annot=True, fmt="d", cmap="Blues",
    xticklabels=['DR', 'AMD', 'Glaucoma', 'Myopia', 'Hypertension', 'RVO', 'LASIK', 'Other'],
    yticklabels=['DR', 'AMD', 'Glaucoma', 'Myopia', 'Hypertension', 'RVO', 'LASIK', 'Other']
)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("EDDFS Dataset Confusion Matrix")
plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")

 

关键观察

  1. AMD分类表现卓越

    • 正确识别率最高:年龄相关性黄斑变性(AMD)类别的正确预测样本数为2238,占总样本的显著比例(2838例),召回率高达78.86%,显示模型对AMD的特征捕捉能力极强。

    • 误判分散:误判样本(600例)均匀分布到其他类别(如DR、Glaucoma等),未出现明显偏向性。

  2. 严重的小样本类别失效问题

    • 完全失效的类别:青光眼(Glaucoma)、LASIK术后斑、其他(Other)类别的正确预测数均为0,表明模型完全无法识别这些疾病。

    • 潜在原因

      • 样本量过少(Glaucoma仅326例,LASIK仅288例)。

      • 特征学习不足,可能需针对性增强数据增强或迁移学习。

  3. 糖尿病视网膜病变(DR)的敏感性问题

    • 低召回率:DR类别的召回率仅2.26%(正确预测15例/总样本664例),漏检率高达97.74%。

    • 误判集中:误判样本主要流向AMD(100例)和Glaucoma(105例),反映模型对血管病变特征的混淆。

  4. 病理性近视(Myopia)的稳定性与局限性

    • 高精确率:Myopia的精确率达83.29%(正确预测588例/总预测706例),显示模型对病灶区域(如视盘变形)的敏感度较高。

    • 误判方向:误判样本分散至高血压性视网膜病变(30例)和RVO(22例),可能与血管形态变化的相似性有关。

  5. 系统性改进方向

    • 数据层面

      • 对小样本类别(如LASIK、Glaucoma)进行过采样或合成数据增强。

      • 引入病灶标注信息,辅助模型区分相似特征(如DR与AMD的微血管差异)。

    • 模型层面

      • 尝试多任务学习框架,联合优化疾病分类与病灶定位。

      • 针对低召回类别,采用Focal Loss或代价敏感学习。

4.4 注意力热力图对比

使用Grad-CAM生成病灶区域可视化:

import cv2
import numpy as np
import torch
from torchvision import models, transforms
from PIL import Image

# 1. 加载模型并注册钩子
model = models.resnet18(pretrained=True)
model.eval()

# 存储特征图和梯度
features = {}
gradients = {}

def forward_hook(module, input, output):
    features['layer4'] = output.detach()

def backward_hook(module, grad_input, grad_output):
    gradients['layer4'] = grad_output[0].detach()

# 注册正向和反向钩子
model.layer4.register_forward_hook(forward_hook)
model.layer4.register_backward_hook(backward_hook)

# 2. 预处理图像
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

original_img = cv2.imread('fundus_sample.jpg')#替换图片路径
img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
img_pil = Image.fromarray(img)
input_tensor = preprocess(img_pil).unsqueeze(0).requires_grad_(True)

# 3. 前向传播并获取目标类别输出
output = model(input_tensor)
target_class = output.argmax().item()  # 假设关注预测概率最高的类别

# 4. 反向传播计算梯度
model.zero_grad()
output[0, target_class].backward()

# 5. 计算Grad-CAM权重
grads = gradients['layer4']  # 梯度形状: [1, 512, 7, 7]
pooled_grads = torch.mean(grads, dim=[0, 2, 3])  # 全局平均池化,形状: [512]

# 6. 加权特征图
activations = features['layer4']  # 特征图形状: [1, 512, 7, 7]
for i in range(activations.shape[1]):
    activations[:, i, :, :] *= pooled_grads[i]

heatmap = torch.mean(activations, dim=1).squeeze().cpu().numpy()
heatmap = np.maximum(heatmap, 0)  # ReLU激活

# 7. 生成热力图
heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

superimposed_img = cv2.addWeighted(original_img, 0.5, heatmap, 0.5, 0)
cv2.imwrite('grad_cam_result.jpg', superimposed_img)

以数据集中0001(NFMK1).jpg为例

原图像如下:

训练预处理后图像:

注意力热力图:

五、优化建议与踩坑记录

5.1 改进方向

  • 数据增强:针对小样本疾病添加病灶区域仿射变换

  • 损失函数:尝试Dice Loss优化边界敏感任务

  • 模型融合:结合EfficientNet-V2的特征提取能力

5.2 常见报错解决 

# 报错:CUDA内存不足
解决方案:
1. 降低batchsize至16
2. 启用混合精度训练(添加--fp16参数)

# 报错:标签文件加载失败
检查点:
1. CSV文件编码需为UTF-8
2. 图像路径中的反斜杠需转换为正斜杠

六、学术价值与展望

  1. 多任务学习潜力:同时支持疾病分类与病灶定位

  2. 临床部署挑战:需进一步优化推理速度(当前约15 FPS)

  3. 伦理考量:医疗数据隐私保护与模型可解释性需求

(文末附完整训练模型下载链接) 

https://github.com/xia-xx-cv/EDDFS_dataset

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

TJUTCM-烛龙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值