pytorch医学成像膝关节炎检测CT图像分类识别+pyqt5界面设计

1 数据说明

本项目的数据集用于膝关节炎的检测和分类,包含高分辨率的医学成像和专家注释(查找开放数据集和机器学习项目 |卡格尔 (kaggle.com))。数据集分为五个阶段的图像:正常、可疑、轻度、中度和重度关节炎。具体数据集目录结构如下:

./Training
    /0Normal
    /1Doubtful
    /2Mild
    /3Moderate
    /4Severe

每个子目录分别存放不同阶段的膝关节CT图像。数据集提供了膝关节不同病变程度的详细图像,这些图像是由专业医生注释和分类的,确保了数据的可靠性和准确性。通过对这些图像进行分类训练,可以帮助构建出一个高效的膝关节炎检测模型,以便在临床中进行辅助诊断。数据集中包含的大量图像样本为模型的训练和验证提供了充足的数据支持,保证了模型在实际应用中的泛化能力和鲁棒性。 

2 模型构建训练

2.1 DenseNet模型

DenseNet,即密集连接卷积网络,是由Huang等人在2017年提出的一种深度卷积神经网络。DenseNet通过引入密集连接来缓解深层网络训练中的梯度消失问题,并实现了更高的参数效率。具体来说,DenseNet中的每一层都直接连接到其后面的每一层,这样可以确保特征和梯度的直接流动,避免信息的丢失。此外,DenseNet通过复用前层的特征来减少参数数量,提高计算效率。该模型在多个图像分类任务中取得了优异的性能。

模型构建实现:在构建DenseNet模型时,我们选择了DenseNet-121这一变体,其具有较少的参数量和较高的计算效率。我们对模型的最后一个全连接层进行了调整,使其输出为五个类别,分别对应正常、可疑、轻度、中度和重度膝关节炎。模型使用PyTorch框架实现,加载预训练权重以加速训练过程。

import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.optim import lr_scheduler

# 使用GPU进行训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义DenseNet模型
densenet_model = models.densenet121(pretrained=True)
num_ftrs = densenet_model.classifier.in_features
densenet_model.classifier = nn.Linear(num_ftrs, len(classes))
densenet_model = densenet_model.to(device)

# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(densenet_model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 训练DenseNet模型
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    train_loss = []
    train_acc = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        scheduler.step()
        epoch_loss = running_loss / train_size
        epoch_acc = running_corrects.double() / train_size
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc.cpu().numpy()
        print(f'Epoch {epoch}/{num_epochs - 1}, Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}')
    return model, train_loss, train_acc

densenet_model, densenet_train_loss, densenet_train_acc = train_model(densenet_model, criterion, optimizer, scheduler)

DenseNet模型在训练过程中表现出稳定的性能提升。训练初期(第0至第4轮次),损失迅速下降,准确率逐步上升。特别是在第8轮次后,损失大幅降低并稳定在较低水平,同时准确率持续上升至95%以上。最终在第24轮次,DenseNet模型的损失降至0.0688,准确率达到97.95%。这一结果表明,DenseNet模型在处理膝关节炎检测任务时表现出色,具备较高的分类精度和稳定性。

2.2 ResNet模型

ResNet,即残差网络,是由He等人在2015年提出的一种深度卷积神经网络。ResNet通过引入残差模块来解决深层网络训练中的退化问题。具体来说,ResNet在每个残差模块中引入了捷径连接(skip connection),直接将输入传递到输出,从而使得网络更容易训练,并且可以训练更深的网络结构。ResNet在多个图像分类任务中取得了显著的效果,并赢得了2015年ImageNet大赛的冠军。

模型构建实现:在构建ResNet模型时,我们选择了ResNet-18这一变体,其具有较少的参数量和较高的计算效率。我们对模型的最后一个全连接层进行了调整,使其输出为五个类别,分别对应正常、可疑、轻度、中度和重度膝关节炎。模型使用PyTorch框架实现,加载预训练权重以加速训练过程。

# 定义ResNet模型
resnet_model = models.resnet18(pretrained=True)
num_ftrs = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(num_ftrs, len(classes))
resnet_model = resnet_model.to(device)

# 优化器和调度器与DenseNet一致
optimizer = optim.Adam(resnet_model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 训练ResNet模型
resnet_model, resnet_train_loss, resnet_train_acc = train_model(resnet_model, criterion, optimizer, scheduler)

 

ResNet模型的训练过程同样表现出显著的性能提升。在前几轮训练中(第0至第4轮次),损失迅速下降,准确率明显提高。第10轮次后,损失下降趋于平稳,准确率持续上升至96%以上。在第24轮次,ResNet模型的损失降至0.0550,准确率达到98.33%。这表明ResNet模型能够有效地学习和捕捉膝关节炎的特征,具有较高的分类能力。

2.3 MobileNetV2模型

MobileNetV2是一种轻量级深度卷积神经网络,由Sandler等人在2018年提出。MobileNetV2主要设计用于移动和嵌入式设备,其核心思想是使用深度可分离卷积(depthwise separable convolution)和反向残差模块(inverted residuals)来减少模型参数量和计算量,同时保持较高的准确率。深度可分离卷积将标准卷积分解为深度卷积和逐点卷积,大幅减少计算复杂度;反向残差模块则通过添加捷径连接进一步提升了模型的表达能力和训练效率。

模型构建实现:在构建MobileNetV2模型时,我们对模型的最后一个全连接层进行了调整,使其输出为五个类别,分别对应正常、可疑、轻度、中度和重度膝关节炎。模型使用PyTorch框架实现,加载预训练权重以加速训练过程。

# 定义MobileNet模型
mobilenet_model = models.mobilenet_v2(pretrained=True)
num_ftrs = mobilenet_model.classifier[1].in_features
mobilenet_model.classifier[1] = nn.Linear(num_ftrs, len(classes))
mobilenet_model = mobilenet_model.to(device)

# 优化器和调度器与DenseNet一致
optimizer = optim.Adam(mobilenet_model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 训练MobileNet模型
mobilenet_model, mobilenet_train_loss, mobilenet_train_acc = train_model(mobilenet_model, criterion, optimizer, scheduler)

 

MobileNetV2模型在训练过程中表现出较好的收敛性。训练初期(第0至第4轮次),损失快速下降,准确率逐步提升。第7轮次后,损失明显降低并趋于平稳,准确率稳定在97%以上。在第24轮次,MobileNetV2模型的损失降至0.0736,准确率达到97.73%。这一结果表明,MobileNetV2模型在膝关节炎检测任务中表现良好,具备较高的分类精度和可靠性。

2.4 EfficientNet模型

EfficientNet是由Tan和Le在2019年提出的一种高效卷积神经网络架构。EfficientNet通过结合模型缩放技术(compound scaling)来系统地调整网络的深度、宽度和分辨率,从而在减少参数量和计算量的同时,保持甚至提高模型的准确率。EfficientNet在多个图像分类任务中取得了显著的效果,并展示了其在各种计算资源受限的应用场景中的优势。

模型构建实现:在构建EfficientNet模型时,我们选择了EfficientNet-B0这一变体,其具有较少的参数量和较高的计算效率。我们对模型的最后一个全连接层进行了调整,使其输出为五个类别,分别对应正常、可疑、轻度、中度和重度膝关节炎。模型使用PyTorch框架实现,加载预训练权重以加速训练过程。

# 定义EfficientNet模型
efficientnet_model = models.efficientnet_b0(pretrained=True)
num_ftrs = efficientnet_model.classifier[1].in_features
efficientnet_model.classifier[1] = nn.Linear(num_ftrs, len(classes))
efficientnet_model = efficientnet_model.to(device)

# 优化器和调度器与DenseNet一致
optimizer = optim.Adam(efficientnet_model.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 训练EfficientNet模型
efficientnet_model, efficientnet_train_loss, efficientnet_train_acc = train_model(efficientnet_model, criterion, optimizer, scheduler)

 

EfficientNet模型在训练过程中展现出强大的性能提升能力。训练初期(第0至第4轮次),损失迅速下降,准确率显著提高。第10轮次后,损失进一步降低并趋于稳定,准确率稳步提升至98%以上。在第24轮次,EfficientNet模型的损失降至0.0439,准确率达到98.86%。这一结果表明,EfficientNet模型在膝关节炎检测任务中表现卓越,具备较高的分类精度和稳定性。

3 模型评估

在膝关节炎检测任务中,四种模型的性能表现均较为出色。DenseNet模型的准确率为79.70%,Precision为80.48%,Recall为79.70%,F1为79.47%;ResNet模型的准确率为74.55%,Precision为75.69%,Recall为74.55%,F1为74.06%;MobileNetV2模型的准确率为78.48%,Precision为78.94%,Recall为78.48%,F1为78.43%;EfficientNet模型的准确率为79.70%,Precision为80.53%,Recall为79.70%,F1为79.36%。

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns

# 定义评估函数
def evaluate_model(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    return accuracy, precision, recall, f1, all_labels, all_preds

# 评估DenseNet模型
densenet_acc, densenet_prec, densenet_rec, densenet_f1, densenet_labels, densenet_preds = evaluate_model(densenet_model, test_loader)
print(f'DenseNet - Accuracy: {densenet_acc:.4f}, Precision: {densenet_prec:.4f}, Recall: {densenet_rec:.4f}, F1: {densenet_f1:.4f}')

# 评估ResNet模型
resnet_acc, resnet_prec, resnet_rec, resnet_f1, resnet_labels, resnet_preds = evaluate_model(resnet_model, test_loader)
print(f'ResNet - Accuracy: {resnet_acc:.4f}, Precision: {resnet_prec:.4f}, Recall: {resnet_rec:.4f}, F1: {resnet_f1:.4f}')

# 评估MobileNet模型
mobilenet_acc, mobilenet_prec, mobilenet_rec, mobilenet_f1, mobilenet_labels, mobilenet_preds = evaluate_model(mobilenet_model, test_loader)
print(f'MobileNet - Accuracy: {mobilenet_acc:.4f}, Precision: {mobilenet_prec:.4f}, Recall: {mobilenet_rec:.4f}, F1: {mobilenet_f1:.4f}')

# 评估EfficientNet模型
efficientnet_acc, efficientnet_prec, efficientnet_rec, efficientnet_f1, efficientnet_labels, efficientnet_preds = evaluate_model(efficientnet_model, test_loader)
print(f'EfficientNet - Accuracy: {efficientnet_acc:.4f}, Precision: {efficientnet_prec:.4f}, Recall: {efficientnet_rec:.4f}, F1: {efficientnet_f1:.4f}')

从结果可以看出,EfficientNet模型在准确率、Precision、Recall和F1值上均表现最佳,DenseNet紧随其后。ResNet和MobileNetV2的性能略逊一筹,但仍具备较高的分类能力。

在混淆矩阵分析中,DenseNet模型的分类结果显示其在某些类别上的分类效果较为准确,但在类别间存在一定的误分类现象。例如,类别1和类别2之间的误分类较为显著,但总体分类结果依然较为可靠。通过进一步优化模型参数和增加训练数据,DenseNet模型有望在未来的应用中取得更优异的分类性能。

4 pyqt5界面实现

为了更好地展示膝关节炎检测系统的效果,使用PyQt5设计了一个用户友好的界面。界面包括标题、选择图像、选择模型、展示结果和退出按钮等功能。用户可以通过选择不同的模型和图像,实时查看分类结果。

首先,使用PyQt5的QMainWindow类创建了主窗口,并在其中添加了各个功能按钮和显示区域。我们设计了一个标题栏用于显示系统名称,并在主窗口中部添加了一个QLabel用于显示选中的膝关节CT图像。用户可以通过点击“选择图像”按钮打开文件对话框,并选择要检测的图像。选择完成后,图像会在QLabel中显示。

其次,在主窗口的下部添加了一个模型选择下拉菜单(QComboBox),用户可以从中选择DenseNet、ResNet、MobileNetV2或EfficientNet模型。下拉菜单旁边的“开始检测”按钮用于触发模型检测过程。点击“开始检测”按钮后,系统会调用相应的模型对选中的图像进行分类,并在右侧的QLabel中显示分类结果。

import sys
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QPushButton, QVBoxLayout, QHBoxLayout, QWidget, \
    QFileDialog, QComboBox
from PyQt5.QtGui import QPixmap, QFont
from PyQt5.QtCore import Qt
import torch
from torchvision import models, transforms
from PIL import Image


class KneeOsteoarthritisApp(QMainWindow):
    def __init__(self):
        super().__init__()

        self.setWindowTitle('膝关节炎检测系统')
        self.setGeometry(100, 100, 1200, 800)
        self.setStyleSheet("background-color: #f7f7f7;")

        self.image_path = None
        self.model_name = 'DenseNet'
        self.model = load_model(self.model_name)

        self.initUI()

    def initUI(self):
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(20, 20, 20, 20)

        title = QLabel('膝关节炎CT图像检测识别系统')
        title.setAlignment(Qt.AlignCenter)
        title.setStyleSheet("font-size: 52px; font-weight: bold; color: #333;")
        main_layout.addWidget(title)

        image_layout = QHBoxLayout()

        self.image_label = QLabel()
        self.image_label.setFixedSize(650, 600)
        self.image_label.setStyleSheet("border: 5px solid #ddd; margin: 20px;")
        image_layout.addWidget(self.image_label, alignment=Qt.AlignCenter)

        self.result_label = QLabel('识别结果:')
        self.result_label.setAlignment(Qt.AlignCenter)
        self.result_label.setFont(QFont('Arial', 16))
        self.result_label.setStyleSheet("margin: 20px; color: #555;")
        image_layout.addWidget(self.result_label, alignment=Qt.AlignCenter)

        main_layout.addLayout(image_layout)

        button_layout = QHBoxLayout()

        self.choose_image_button = QPushButton('选择图像')
        self.choose_image_button.clicked.connect(self.choose_image)
        self.choose_image_button.setFixedHeight(60)
        self.choose_image_button.setStyleSheet("""
            QPushButton {
                background-color: #4CAF50;
                color: white;
                font-size: 30px;
                padding: 10px;
                border: none;
                border-radius: 5px;
            }
            QPushButton:hover {
                background-color: #45a049;
            }
        """)
        button_layout.addWidget(self.choose_image_button)

        self.choose_model_combo = QComboBox()
        self.choose_model_combo.addItems(['DenseNet', 'ResNet', 'MobileNet', 'EfficientNet'])
        self.choose_model_combo.currentTextChanged.connect(self.choose_model)
        self.choose_model_combo.setFixedHeight(60)
        self.choose_model_combo.setStyleSheet("""
            QComboBox {
                font-size: 30px;
                padding: 10px;
                border: 1px solid #ccc;
                border-radius: 5px;
            }
        """)
        button_layout.addWidget(self.choose_model_combo)

        self.recognize_button = QPushButton('膝关节炎识别')
        self.recognize_button.clicked.connect(self.recognize_image)
        self.recognize_button.setFixedHeight(60)
        self.recognize_button.setStyleSheet("""
            QPushButton {
                background-color: #008CBA;
                color: white;
                font-size: 30px;
                padding: 10px;
                border: none;
                border-radius: 5px;
            }
            QPushButton:hover {
                background-color: #007bb5;
            }
        """)
        button_layout.addWidget(self.recognize_button)

        main_layout.addLayout(button_layout)

        container = QWidget()
        container.setLayout(main_layout)
        self.setCentralWidget(container)

    def choose_image(self):
        options = QFileDialog.Options()
        options |= QFileDialog.ReadOnly
        file_path, _ = QFileDialog.getOpenFileName(self, "选择图像", "",
                                                   "Image Files (*.png *.jpg *.bmp);;All Files (*)", options=options)
        if file_path:
            self.image_path = file_path
            pixmap = QPixmap(self.image_path)
            pixmap = pixmap.scaled(self.image_label.width(), self.image_label.height(), Qt.KeepAspectRatio)
            self.image_label.setPixmap(pixmap)
            self.result_label.setText('识别结果:')

    def choose_model(self, model_name):
        self.model_name = model_name
        self.model = load_model(self.model_name)

    def recognize_image(self):
        if self.image_path is None:
            self.result_label.setText('请选择一张图像!')
            return

        img = Image.open(self.image_path).convert('L')
        img = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = self.model(img)
            _, preds = torch.max(outputs, 1)

        self.result_label.setText(f'识别结果:{classes[preds.item()]}')


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = KneeOsteoarthritisApp()
    ex.show()
    sys.exit(app.exec_())

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值