分类模型在新生儿肺部超声中的应用:ResNet152与EfficientNet-B7

分类模型在新生儿肺部超声中的应用:ResNet152与EfficientNet-B7

如果有疑问,想要指定内容或者获取更多内容,关注是小杜吖公众号留言即可

在现代医学影像中,超声技术因其无创性和实时性,成为了新生儿肺部疾病诊断的重要工具。然而,手动分析超声图像既费时又容易受主观因素影响。因此,利用深度学习技术进行自动分类成为了一种趋势。本文将介绍两种常用的卷积神经网络模型——ResNet152和EfficientNet-B7在新生儿肺部超声图像分类中的应用。

背景介绍
新生儿肺部超声

新生儿肺部超声是一种有效的影像学检查手段,可用于诊断新生儿呼吸系统疾病,如新生儿呼吸窘迫综合征、肺炎和气胸等。与传统的胸部X光相比,超声检查具有无辐射、实时性好、可床旁操作等优点,特别适合新生儿的检查需求。

深度学习在医学影像中的应用

深度学习,尤其是卷积神经网络(CNN),在医学影像分析中取得了显著的进展。通过自动提取图像特征,CNN能够在多种医学影像分类任务中达到甚至超过人类专家的水平。

ResNet152和EfficientNet-B7简介
ResNet152

ResNet(Residual Network)是由何凯明等人提出的一种深度卷积神经网络结构。ResNet152是该系列中的一种变体,具有152层深度。其核心思想是引入残差连接(residual connection),通过这种方式解决了深层网络训练中的梯度消失问题。ResNet152在ImageNet等大规模图像分类任务中表现优异,被广泛应用于各类图像处理任务。

EfficientNet-B7

EfficientNet是由谷歌提出的一种高效卷积神经网络架构。其通过一种复合缩放方法(compound scaling)系统性地调整网络的深度、宽度和分辨率,从而在计算效率和准确率之间取得平衡。EfficientNet-B7是该系列中的一个较大模型,在多个图像分类基准上达到了领先的性能。

实验设置与结果
数据集

采用新生儿肺部超声数据集,由于数据量较少,所以本次实验将数据同时用在训练集,测试集,验证集

模型训练

我们在ResNet152和EfficientNet-B7模型上进行了迁移学习,即先在ImageNet预训练,再在新生儿肺部超声图像数据集上微调。为了提高模型的泛化能力,我们采用了数据增强技术,包括旋转、平移和翻转等。

废话不多说,上完整代码(带注释)

以下是整理和注释后的完整代码,已经去除了多余的内容,并添加了注释:

#!/usr/bin/env python
# coding: utf-8

# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import cv2
import warnings
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Conv2D, BatchNormalization, Dense, MaxPool2D, Flatten, GlobalMaxPooling2D, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import Sequential
from tensorflow.keras.applications import ResNet152, EfficientNetB7
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

# Ignore warnings
warnings.filterwarnings("ignore")

# Set parameters
learning_rate = 0.0001
input_size = (224, 224)
batch_size = 8
monitor = "val_loss"
list_model = ["ResNet152", "EfficientNetB7"]
list_model_acc = []
list_model_loss = []

# Define function to plot loss and accuracy
def plot_loss_acc(history):
    fig, ax = plt.subplots(1, 2, figsize=(12, 5))
    ax[0].plot(history.history['accuracy'])
    ax[0].plot(history.history['val_accuracy'])
    ax[0].set_title('Model Accuracy')
    ax[0].set_xlabel('Epochs')
    ax[0].set_ylabel('Accuracy')
    ax[0].legend(['train', 'val'], loc='upper left')

    ax[1].plot(history.history['loss'])
    ax[1].plot(history.history['val_loss'])
    ax[1].set_title('Model Loss')
    ax[1].set_xlabel('Epochs')
    ax[1].set_ylabel('Loss')
    ax[1].legend(['train', 'val'], loc='upper left')
    plt.tight_layout()
    plt.show()

# Define function to evaluate model on test data
def eval_test(model, checkPoint_path):
    model.load_weights(checkPoint_path)
    test_evaluate = model.evaluate(test_data)
    list_model_acc.append(test_evaluate[1])
    list_model_loss.append(test_evaluate[0])
    y_pred = model.predict(test_data)
    y_pred = np.argmax(y_pred, axis=1)
    print(classification_report(y_pred, test_data.classes))
    cm = confusion_matrix(y_pred, test_data.classes)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    ax = disp.plot(cmap='Blues').ax_
    ax.set_title("Confusion matrix in test data")
    plt.show()

# Initialize early stopping callback
early_stop = EarlyStopping(monitor=monitor, patience=300)

# Load data from directory
train_data = ImageDataGenerator().flow_from_directory(
    'C:/code/lung/train',
    shuffle=True,
    batch_size=batch_size,
    target_size=input_size
)

valid_data = ImageDataGenerator().flow_from_directory(
    'C:/code/lung/val',
    shuffle=False,
    batch_size=batch_size,
    target_size=input_size
)

test_data = ImageDataGenerator().flow_from_directory(
    'C:/code/lung/test',
    shuffle=False,
    batch_size=batch_size,
    target_size=input_size
)

# Plot some training images
class_names = list(test_data.class_indices.keys())
plt.figure(figsize=(10, 10))
for i, (images, labels) in enumerate(train_data):
    if i == 15:
        break
    plt.subplots_adjust(hspace=0.5)
    plt.subplot(3, 5, i+1)
    plt.imshow(images[0].astype("uint8"))
    class_index = list(labels[0]).index(1)
    class_name = class_names[class_index]
    plt.title(class_name)
    plt.axis("off")
plt.tight_layout()
plt.show()

# ResNet152 model setup
resnet152_base = ResNet152(include_top=False, input_shape=(224, 224, 3))
resnet152_base.trainable = False

resnet152_model = Sequential([
    resnet152_base,
    BatchNormalization(),
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.3),
    Dense(256, activation='relu'),
    Dropout(0.3),
    Dense(4, activation='softmax')
])

checkpoint_resnet152 = ModelCheckpoint(
    'resnet152_best_weight.h5',
    monitor=monitor,
    save_best_only=True,
    save_weights_only=True
)

resnet152_model.compile(optimizer=Adam(learning_rate), loss='categorical_crossentropy', metrics=['accuracy'])

history_resnet152 = resnet152_model.fit(
    train_data,
    validation_data=valid_data,
    epochs=10,
    callbacks=[checkpoint_resnet152],
    batch_size=batch_size,
    verbose=2
)

plot_loss_acc(history_resnet152)
eval_test(resnet152_model, "resnet152_best_weight.h5")

# EfficientNetB7 model setup
efficientnetB7_base = EfficientNetB7(include_top=False, input_shape=(224, 224, 3))
efficientnetB7_base.trainable = False

efficientnetB7_model = Sequential([
    efficientnetB7_base,
    BatchNormalization(),
    Flatten(),
    Dense(256, activation='relu'),
    Dense(4, activation='softmax')
])

checkpoint_efficientnetB7 = ModelCheckpoint(
    'EfficientNetB7_best_weight.h5',
    monitor=monitor,
    save_best_only=True,
    save_weights_only=True
)

efficientnetB7_model.compile(optimizer=Adam(learning_rate), loss='categorical_crossentropy', metrics=['accuracy'])

history_efficientnetB7 = efficientnetB7_model.fit(
    train_data,
    validation_data=valid_data,
    epochs=10,
    callbacks=[checkpoint_efficientnetB7],
    batch_size=batch_size,
    verbose=2
)

plot_loss_acc(history_efficientnetB7)
eval_test(efficientnetB7_model, "EfficientNetB7_best_weight.h5")

# Conclusion: Compare model performance
test_eval_general = pd.DataFrame({
    "Model": list_model,
    "Test Loss": list_model_loss,
    "Test Accuracy": list_model_acc
})

test_eval_general["Test Accuracy"] = pd.Categorical(test_eval_general["Test Accuracy"])
test_eval_general = pd.melt(test_eval_general, id_vars="Model")

plt.figure(figsize=(8, 5))
sns.lineplot(data=test_eval_general, x="Model", y="value", hue="variable", marker='o', markersize=10)
plt.title("Line Plot of Model's Loss and Accuracy in Test Data")
plt.xlabel("Models")
plt.ylabel("Values")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

以上代码按以下顺序进行:

  1. 导入必要的库。
  2. 设置参数。
  3. 定义绘图函数 plot_loss_acc 和评估函数 eval_test
  4. 初始化早期停止回调。
  5. 加载数据集。
  6. 可视化部分训练图像。
  7. 构建和训练 ResNet152 模型。
  8. 构建和训练 EfficientNetB7 模型。
  9. 比较模型性能并绘制结果图表。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值