利用u-net进行眼底图像与糖尿病视网膜病变特征分割(采用IDRID 数据集)

如题:

数据集提取:链接: https://pan.baidu.com/s/1yhWbYlWKK3eXtpXMQww-TA?pwd=tdte 提取码: tdte

框架代码:

Version:0.9 StartHTML:0000000105 EndHTML:0000021995 StartFragment:0000000141 EndFragment:0000021955

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D,
concatenate
from tensorflow.keras.models import Model
# 数据预处理函数
def load_images_and_labels(image_folder, label_folder, img_size=(256, 256),
num_classes=5):
image_files = sorted(os.listdir(image_folder))
label_files = sorted(os.listdir(label_folder))
images = []
labels = []
for img_file, lbl_file in zip(image_files, label_files):
img_path = os.path.join(image_folder, img_file)
lbl_path = os.path.join(label_folder, lbl_file) image
=
load_img(img_path,
target_size=img_size,
color_mode='grayscale')
label = load_img(lbl_path, target_size=img_size, color_mode='grayscale')
image = img_to_array(image) / 255.0
label = img_to_array(label).astype(np.uint8)
images.append(image)
labels.append(label)
images = np.array(images)
labels = np.array(labels).squeeze() # 移除单一通道维度
# 检查标签中的唯一值
unique_labels = np.unique(labels)
print("Unique labels in the dataset:", unique_labels)
# 如果标签中有不在 [0, num_classes-1] 范围内的值,进行处理
if np.any(unique_labels >= num_classes):
labels[labels >= num_classes] = num_classes - 1 # 将超出范围的标签值
映射到 num_classes-1
# 再次检查标签中的唯一值
unique_labels = np.unique(labels)
print("Unique labels after processing:", unique_labels)
# 将标签转换为 one-hot 编码
labels = to_categorical(labels, num_classes=num_classes)
return images, labels # 加载训练数据
train_images, train_labels = load_images_and_labels(
r'C:\Users\PycharmProjects\pythonProject3\IDRiD\train\image',
r'C:\Users\PycharmProjects\pythonProject3\IDRiD\train\label',
num_classes=5 # 设置类别数为 5 ,以包含所有标签值
)
# 加载测试数据
test_images, test_labels = load_images_and_labels(
r'C:\Users\LSY\PycharmProjects\pythonProject3\IDRiD\test\image',
r'C:\Users\LSY\PycharmProjects\pythonProject3\IDRiD\test\label',
num_classes=5 # 设置类别数为 5 ,以包含所有标签值
)
# 构建 UNet 模型
def unet_multiclass(input_size=(256, 256, 1), num_classes=5):
inputs = Input(input_size)
# Encoder
conv1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
conv1 = Conv2D(64, 3, activation='relu', padding='same')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = Conv2D(128, 3, activation='relu', padding='same')(pool1)
conv2 = Conv2D(128, 3, activation='relu', padding='same')(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
conv3 = Conv2D(256, 3, activation='relu', padding='same')(pool2) conv3 = Conv2D(256, 3, activation='relu', padding='same')(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
conv4 = Conv2D(512, 3, activation='relu', padding='same')(pool3)
conv4 = Conv2D(512, 3, activation='relu', padding='same')(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
# Bottleneck
conv5 = Conv2D(1024, 3, activation='relu', padding='same')(pool4)
conv5 = Conv2D(1024, 3, activation='relu', padding='same')(conv5)
# Decoder
up6
=
Conv2D(512,
2,
activation='relu',
padding='same')(UpSampling2D(size=(2, 2))(conv5))
merge6 = concatenate([conv4, up6], axis=3)
conv6 = Conv2D(512, 3, activation='relu', padding='same')(merge6)
conv6 = Conv2D(512, 3, activation='relu', padding='same')(conv6)
up7
=
Conv2D(256,
2,
activation='relu',
padding='same')(UpSampling2D(size=(2, 2))(conv6))
merge7 = concatenate([conv3, up7], axis=3)
conv7 = Conv2D(256, 3, activation='relu', padding='same')(merge7)
conv7 = Conv2D(256, 3, activation='relu', padding='same')(conv7)
up8
=
Conv2D(128,
2,
activation='relu',
padding='same')(UpSampling2D(size=(2, 2))(conv7))
merge8 = concatenate([conv2, up8], axis=3)
conv8 = Conv2D(128, 3, activation='relu', padding='same')(merge8)
conv8 = Conv2D(128, 3, activation='relu', padding='same')(conv8) up9
=
Conv2D(64,
2,
activation='relu',
padding='same')(UpSampling2D(size=(2, 2))(conv8))
merge9 = concatenate([conv1, up9], axis=3)
conv9 = Conv2D(64, 3, activation='relu', padding='same')(merge9)
conv9 = Conv2D(64, 3, activation='relu', padding='same')(conv9)
conv9 = Conv2D(64, 3, activation='relu', padding='same')(conv9)
conv10 = Conv2D(num_classes, 1, activation='softmax')(conv9)
model = Model(inputs=inputs, outputs=conv10)
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
return model
# 创建模型实例
model = unet_multiclass(num_classes=5)
model.summary()
# 训练模型
history = model.fit(train_images, train_labels,
epochs=5,
batch_size=8,
validation_split=0.2,
verbose=1)
# 评估模型
test_loss, test_accuracy = model.evaluate(test_images, test_
  • 13
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值