T4:猴痘病识别图片

T4周:猴痘病识别

⛽ 我的环境

  • 语言环境:Python3.10.12
  • 编译器:Google Colab
  • 深度学习环境:
    • TensorFlow2.15.0

一、前期工作

1.设置GPU,导入所需库
#os提供了一些与操作系统交互的功能,比如文件和目录操作
import os
#提供图像处理的功能,包括打开和显示、保存、裁剪等
import PIL
#pathlib提供了一个面向对象的接口来处理文件系统路径。路径被表示为Path对象,可以调用方法来进行各种文件和目录操作。
import pathlib

#用于绘制图形和可视化数据
import tensorflow as tf
import matplotlib.pyplot as plt
#用于数值计算的库,提供支持多维数组和矩阵运算
import numpy as np
#keras作为高层神经网络API,已被集成进tensorflow,使得训练更方便简单
from tensorflow import keras
#layers提供了神经网络的基本构建块,比如全连接层、卷积层、池化层等
#提供了构建和训练神经网络模型的功能,包括顺序模型(Sequential)和函数式模型(Functional API)
from tensorflow.keras import layers, models
# 获取所有可用的GPU设备列表,储存在变量gpus中
gpus = tf.config.list_physical_devices("GPU")

# 如果有GPU,即列表不为空
if gpus:
  # 获取第一个 GPU 设备
  gpu0 = gpus[0]
  # 设置 GPU 内存增长策略。开启这个选项可以让tf按需分配gpu内存,而不是一次性分配所有可用内存。
  tf.config.experimental.set_memory_growth(gpu0, True)
  #设置tf只使用指定的gpu(gpu[0])
  tf.config.set_visible_devices([gpu0],"GPU")

gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
2.导入数据
#挂载google drive
from google.colab import drive
drive.mount("/content/drive/")
Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
#更改工作目录
%cd "/content/drive/MyDrive/Colab Notebooks/jupyter notebook/data/"
/content/drive/Othercomputers/My laptop/jupyter notebook/data
#这里使用相对路径
data_dir = './4/'
#将路径转换成pathlib.Path对象
data_dir = pathlib.Path(data_dir)
3.查看数据
# 使用glob方法获取当前目录的子目录里所有以'.jpg'为结尾的文件
# '*/*.jpg' 是一個通配符模式
# 第一个星号表示当前目录
# 第二个星号表示子目录
image_count = len (list(data_dir.glob("*/*.jpg")))

print("图片总数:", image_count)
图片总数: 2142
#Monkeypox/*.jpg表示Monkeypox子文件夹中以.jpg结尾的文件
Monkeypox = list(data_dir.glob("Monkeypox/*.jpg"))
PIL.Image.open(str(Monkeypox[69]))
#PIL.Image.open用于打开Monkeypox目录列表中的指定图像
#Monkeypox列表元素是文件路径(Path对象);
#将该路径对象转换为字符串类型,作为PIL.Image.open()方法的参数传入,然后方法会读取该路径指向的图像文件,并返回一个PIL.Image.Image对象。

在这里插入图片描述

二、数据预处理

1.加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

验证集v.s.数据集:

  1. 验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
  2. 但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。
  3. 因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集
#设置批量大小,即每次训练模型时输入图像数量
#每次训练迭代时,模型需处理32张图像
batch_size = 32
#图像的高度,加载图像数据时,将所有的图像调整为相同的高度
img_height = 224
#图像的宽度,加载图像数据时,将所有的图像调整为相同的宽度
img_width = 224
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
tr_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,#指定数据集中分割出多少比例数据当作验证集,0.2表示20%数据会被用来当验证集
    subset="training",#指定是用于训练还是验证的数据子集,这里设定为training
    #用于设置随机数种子,以确保数据集划分的可重复性和一致性。
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 2142 files belonging to 2 classes.
Using 1714 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split = 0.2,
    subset = "validation",
    seed = 123,
    image_size=(img_height,img_width),
    batch_size=batch_size
)
Found 2142 files belonging to 2 classes.
Using 428 files for validation.
class_names = tr_ds.class_names
# 可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称
class_names
['Monkeypox', 'Others']

随机数种子相关可参考:https://blog.csdn.net/weixin_51390582/article/details/124246873

2.可视化数据
#创建一个图形对象,设置图形大小宽度20英寸,高度10英寸
plt.figure(figsize=(20, 10))
#tr_ds.take(1)从训练数据集tr_ds中获取一个批次的数据。take(1)返回一个包含一批数据的Dataset对象
#images是这一个批次的图片
#labels是这一个批次的标签

for images, labels in tr_ds.take(1):
  #选择当前批次的前20张图
    for i in range(20):
      #将图形分成5行10列,子图索引为i+1
      ax = plt.subplot(5, 10, i + 1)
      #显示第i张图
      #images[i]是一个张量,使用.numpy()将其转化为Numpy数组,数据类型为float32,(0-255)数据大小,此处直接'/255.0'也可
      #.astype("uint8")将图像数据类型转换为无符号8位整数(uint8),介于0-1之间
      plt.imshow(images[i].numpy().astype("uint8"))
      #设置当前子图标题为当前图片的类别名称
      #labels[i]是当前图片对应的标签,通过该标签检索图片类别
      plt.title(class_names[labels[i]])
      #关闭坐标轴显示
      plt.axis("off")
      #print(images[1].numpy().dtype)--检查图片的数据类型,输出为float32
      #print(f"Min value: {images[1].numpy().min()}, Max value: {images[1].numpy().max()}")--检查像素值范围-(0-255)
#显示图片
plt.show()

在这里插入图片描述

for image_batch, labels_batch in tr_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break

#`(32, 224, 224, 3)`--最后一维指的是彩色通道RGB
#`label_batch`是形状(32,)的张量,这些标签对应32张图片
(32, 224, 224, 3)
(32,)
3.配置数据集
#自动调整数据管道性能
AUTOTUNE = tf.data.AUTOTUNE
# 使用 tf.data.AUTOTUNE 具体的好处包括:
#自动调整并行度:自动决定并行处理数据的最佳线程数,以最大化数据吞吐量。
#减少等待时间:通过优化数据加载和预处理,减少模型训练时等待数据的时间。
#提升性能:自动优化数据管道的各个环节,使整个训练过程更高效。
#简化代码:不需要手动调整参数,代码更简洁且易于维护。

#使用cache()方法将训练集缓存到内存中,这样加快数据加载速度
#当多次迭代训练数据时,可以重复使用已经加载到内存的数据而不必重新从磁盘加载
#使用shuffle()对训练数据集进行洗牌操作,打乱数据集中的样本顺序
#参数1000指缓冲区大小,即每次从数据集中随机选择的样本数量
#prefetch()预取数据,节约在训练过程中数据加载时间
tr_ds = tr_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、构建CNN网络模型

卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入形状是 (224, 224, 3)。我们需要在声明第一层时将形状赋值给参数input_shape。

CNN的输入张量表示图像的结构和颜色信息。每个像素点都被表示为具有color_channels个数值的向量,在训练时,通过一系列卷积层、池化层和全连接层等操作提取和处理图像特征。

num_classes = 2

"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995

layers.Dropout() 作用是防止过拟合,提高模型的泛化能力。

关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""
#创建序列模型,一种线性堆叠模型,各层按照他们被添加到模型中的顺序来堆叠
model = models.Sequential([
    # 数据预处理层:将像素值从 [0, 255] 缩放到 [0, 1]即归一化,输入(224, 224 ,3),
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    #层输入可以手动添加batch size信息,model.summary()中每一层输出shape中None将为指定批次大小:
    #layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3),batch_size = 32),

    # 卷积层1:16 个 3x3 的卷积核,使用 ReLU 激活函数,输出 (222, 222, 16),
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),

    # 池化层1:2x2 的平均池化,输出(111,111,16)
    layers.AveragePooling2D((2, 2)),  #池化层1,2*2采样

    # 卷积层2:32 个 3x3 的卷积核,使用 ReLU 激活函数,(109,109,32)
    layers.Conv2D(32, (3, 3), activation='relu'), #卷积层2,卷积核3*3

    # 池化层2:2x2 的平均池化,(54,54,32)
    layers.AveragePooling2D((2, 2)),  #池化层2,2*2采样

    # Dropout层:随机停止30%的神经元工作,防止过拟合
    layers.Dropout(0.3),

    # 卷积层3:64 个 3x3 的卷积核,使用 ReLU 激活函数 (52,52,64)
    layers.Conv2D(64, (3, 3), activation='relu'),

    # Dropout层:随机停止30%的神经元工作,防止过拟合
    layers.Dropout(0.3),

    # Flatten层:将多维特征图展平为一维,连接卷积层与全连接层
    layers.Flatten(),

    # 全连接层:128 个神经元,使用 ReLU 激活函数
    layers.Dense(128, activation='relu'),

    # 输出层:输出预期结果,神经元数量为 num_classes
    layers.Dense(num_classes)
])
model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 rescaling (Rescaling)       (None, 224, 224, 3)       0         
                                                                 
 conv2d (Conv2D)             (None, 222, 222, 16)      448       
                                                                 
 average_pooling2d (Average  (None, 111, 111, 16)      0         
 Pooling2D)                                                      
                                                                 
 conv2d_1 (Conv2D)           (None, 109, 109, 32)      4640      
                                                                 
 average_pooling2d_1 (Avera  (None, 54, 54, 32)        0         
 gePooling2D)                                                    
                                                                 
 dropout (Dropout)           (None, 54, 54, 32)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 52, 52, 64)        18496     
                                                                 
 dropout_1 (Dropout)         (None, 52, 52, 64)        0         
                                                                 
 flatten (Flatten)           (None, 173056)            0         
                                                                 
 dense (Dense)               (None, 128)               22151296  
                                                                 
 dense_1 (Dense)             (None, 2)                 258       
                                                                 
=================================================================
Total params: 22175138 (84.59 MB)
Trainable params: 22175138 (84.59 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

四、编译模型

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(optimizer=opt,
       loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
       metrics=['accuracy'])
#Adam优化器是一种常用的梯度下降优化算法,用于更新模型的权重以最小化训练过程中的损失函数
#对模型的输出logits进行稀疏分类交叉熵损失计算
#from_logits=True说明模型输出是未经softmax函数转换的原始分数或概率值;
#False则会对经过softmax处理后的输出进行计算

五、训练模型

ModelCheckpoint函数

参考CSDN博客:https://blog.csdn.net/Marryvivien/article/details/126954192?spm=1001.2014.3001.5502

🥑函数原型https://keras.io/api/callbacks/model_checkpoint/

tf.keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
    initial_value_threshold=None
)

ModelCheckpoint callback与使用model.fit()进行的训练结合使用,可在一定时间间隔内保存模型或权重(在检查点文件中),以便以后加载模型或权重,从保存的状态继续训练。

提供的一些选项包括:

  • 是否仅保留迄今为止实现“最佳性能”的模型,或者是否无论性能如何都在每个时期结束时保存模型。
  • “最佳”的定义;监控哪个数量以及是否应最大化或最小化。
  • 保存频率。目前,回调支持在每个时期结束时或在固定数量的训练批次后进行保存。
  • 是否仅保存权重,还是保存整个模型。

🥑函数参数

  • filepath:字符串或PathLike,保存模型文件的路径。filepath可以是格式化的字符串,里面的占位符将会被epoch值和传入on_epoch_end的logs关键字所填入。

    save_weights_only=True时,文件路径名需以.weights.h5结尾;当检查点保存整个模型时(默认),文件路径名应以.keras结尾。


    例如:如果文件路径为{epoch:02d}-{val_loss:.2f}.keras,那么在保存模型检查点时,文件名中将包含对应epoch和验证损失。文件路径的目录不应被任何其他回调重复使用,以避免冲突。
  • monitor:需监测的指标. 通常是由Model.compile方法设定的指标. 注意:
    • 在名称前加上 "val_"前缀来监测验证指标.
    • 使用"loss""val_loss" 监控模型总损失.
    • 如果将指标指定为字符串, 如 "accuracy", 则传递相同的字符串 (带或不带 "val_" 前缀).
    • 如果传递 metrics.Metric对象, monitor 应设置为 metric.name
    • 如果不确定指标名称,可检查history=model.fit()返回的history.history字典内容
    • 多输出模型在指标名称上设置额外前缀
  • verbose:显示模式,0或1,0代表静默,1时回调时显示信息
  • save_best_only:只在模型被认为是目前最好时保存。如果filepath不包含格式化选项,例如{epoch},则新保存的更好模型将覆盖之前保存的模型。
  • mode{"auto", "min", "max"} 之一.如果save_best_only=True,则根据监测指标的最大化或最小化来决定是否覆盖保存文件。对val_acc 应为 max ,对val_loss应为 min.在 auto 模式,如果监控的指标为 acc 或以 ‘fmeasure’ 开头,则模式为 max,对余下的则为 min
  • save_weights_only:True表示只保存模型的权重(model.save_weights(filepath))否则整个模型被保存(model.save(filepath)
  • save_freq'epoch' 或 integer。当使用 epoch时,callback 在每个 epoch 后保存模型。当使用 integer,则在这些 batch 后保存模型。如果 Model 使用 steps_per_execution= N 选项进行编译,则每 Nth batch 检查保存条件。注意,如果保存和 epoch 没对齐,则监控指标可能不可靠(它可能只反映一个 batch,因为指标在每个 epoch 结束会重置)。默认 ‘epoch’。
  • optionssave_weights_only 为 True 时可选的 tf.train.CheckpointOptions 对象 或 save_weights_only 为 False 时可选的 tf.saved_model.SaveOptions 对象。
  • initial_value_threshold:要监控的指标的初始 "最佳 "浮点值。仅适用于 save_best_value=True。只有当当前模型的性能优于该值时,才会覆盖已保存的模型权重。

🥑示例代码

model.compile(loss=..., optimizer=...,
              metrics=['accuracy'])

EPOCHS = 10
checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model is saved at the end of every epoch, if it's the best seen so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model (that are considered the best) can be loaded as -
keras.models.load_model(checkpoint_filepath)

# Alternatively, one could checkpoint just the model weights as -
checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model weights (that are considered the best) can be loaded as -
model.load_weights(checkpoint_filepath)

训练模型

from tensorflow.keras.callbacks import ModelCheckpoint

epochs = 50

checkpointer = ModelCheckpoint(
   filepath = '../best4model.weights.h5',
   monitor = 'val_accuracy',
   verbose = 1,
   mode = 'max',
   save_best_only=True,
   save_weights_only=True)

history = model.fit(tr_ds,
           validation_data=val_ds,
           epochs=epochs,
           callbacks=[checkpointer])
Epoch 1/50
54/54 [==============================] - ETA: 0s - loss: 0.7378 - accuracy: 0.5193
Epoch 1: val_accuracy improved from -inf to 0.53738, saving model to ../best4model.weights.h5
54/54 [==============================] - 164s 964ms/step - loss: 0.7378 - accuracy: 0.5193 - val_loss: 0.6784 - val_accuracy: 0.5374
Epoch 2/50
54/54 [==============================] - ETA: 0s - loss: 0.6542 - accuracy: 0.6144
Epoch 2: val_accuracy improved from 0.53738 to 0.67757, saving model to ../best4model.weights.h5
54/54 [==============================] - 3s 65ms/step - loss: 0.6542 - accuracy: 0.6144 - val_loss: 0.6231 - val_accuracy: 0.6776
Epoch 3/50
53/54 [============================>.] - ETA: 0s - loss: 0.6152 - accuracy: 0.6611
Epoch 3: val_accuracy improved from 0.67757 to 0.69393, saving model to ../best4model.weights.h5
54/54 [==============================] - 4s 76ms/step - loss: 0.6156 - accuracy: 0.6587 - val_loss: 0.5928 - val_accuracy: 0.6939
Epoch 4/50
54/54 [==============================] - ETA: 0s - loss: 0.5798 - accuracy: 0.6902
Epoch 4: val_accuracy improved from 0.69393 to 0.71729, saving model to ../best4model.weights.h5
54/54 [==============================] - 4s 66ms/step - loss: 0.5798 - accuracy: 0.6902 - val_loss: 0.5660 - val_accuracy: 0.7173
Epoch 5/50
54/54 [==============================] - ETA: 0s - loss: 0.5547 - accuracy: 0.7188
Epoch 5: val_accuracy improved from 0.71729 to 0.73364, saving model to ../best4model.weights.h5
54/54 [==============================] - 4s 65ms/step - loss: 0.5547 - accuracy: 0.7188 - val_loss: 0.5300 - val_accuracy: 0.7336
... ...
Epoch 46/50
53/54 [============================>.] - ETA: 0s - loss: 0.0515 - accuracy: 0.9869
Epoch 46: val_accuracy did not improve from 0.88785
54/54 [==============================] - 2s 45ms/step - loss: 0.0509 - accuracy: 0.9872 - val_loss: 0.4417 - val_accuracy: 0.8832
Epoch 47/50
53/54 [============================>.] - ETA: 0s - loss: 0.0687 - accuracy: 0.9750
Epoch 47: val_accuracy did not improve from 0.88785
54/54 [==============================] - 2s 44ms/step - loss: 0.0695 - accuracy: 0.9749 - val_loss: 0.4696 - val_accuracy: 0.8785
Epoch 48/50
53/54 [============================>.] - ETA: 0s - loss: 0.0570 - accuracy: 0.9845
Epoch 48: val_accuracy did not improve from 0.88785
54/54 [==============================] - 2s 44ms/step - loss: 0.0574 - accuracy: 0.9842 - val_loss: 0.4795 - val_accuracy: 0.8762
Epoch 49/50
53/54 [============================>.] - ETA: 0s - loss: 0.0880 - accuracy: 0.9673
Epoch 49: val_accuracy did not improve from 0.88785
54/54 [==============================] - 3s 47ms/step - loss: 0.0874 - accuracy: 0.9679 - val_loss: 0.4378 - val_accuracy: 0.8762
Epoch 50/50
53/54 [============================>.] - ETA: 0s - loss: 0.0483 - accuracy: 0.9851
Epoch 50: val_accuracy did not improve from 0.88785
54/54 [==============================] - 2s 46ms/step - loss: 0.0478 - accuracy: 0.9854 - val_loss: 0.4397 - val_accuracy: 0.8855

六、模型评估

#从history中获取准确率和损失

acc = history.history['accuracy'] #训练集正确率
val_acc = history.history['val_accuracy'] #验证集正确率

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

# 创建新图像,指定大小
plt.figure(figsize=(12, 4))
# 创建子图,绘制1行2列的第一个
plt.subplot(1, 2, 1)
# 训练集正确率随迭代次数变化图
plt.plot(epochs_range, acc, label='Training Accuracy')
# 验证集正确率随迭代次数变化图
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
# 图例位于右下角
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

#绘制第二个子图
plt.subplot(1, 2, 2)
# 训练集损失随迭代次数的变化图
plt.plot(epochs_range, loss, label='Training Loss')
# 验证集损失随迭代次数的变化图
plt.plot(epochs_range, val_loss, label='Validation Loss')
# 图例位于右上角
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

在这里插入图片描述

七、预测

注释参考:https://blog.csdn.net/qq_45735298/article/details/130056861?spm=1001.2014.3001.5502

model.load_weights('../best4model.weights.h5')
#这段代码用于加载之前训练中保存的最佳模型权重。使用之前保存的模型权重文件路径和名称。
#这样可以避免从头开始训练模型,直接使用已经训练好的最佳模型进行预测的工作。

from PIL import Image
import numpy as np

img = Image.open("./4/Monkeypox/M10_02_13.jpg") #使用 PIL 库中的 Image.open() 方法打开一张待预测的图片。
image = tf.image.resize(img, [img_height, img_width])
#这个函数调整输入图像的大小以符合模型的要求。
#在这个例子中,使用 TensorFlow 的 tf.image.resize() 函数将图像缩放为指定大小,其中 img_height 和 img_width 是指定的图像高度和宽度。

img_array = tf.expand_dims(image, 0)
'''
这个函数将输入图像转换为形状为 (1, height, width, channels) 的四维数组,
其中 height 和 width 是图像的高度和宽度,channels 是图像的通道数(例如 RGB 图像有 3 个通道)。
这里使用 TensorFlow 的 tf.expand_dims() 函数来扩展图像数组的维度,以匹配模型的输入格式。

具体来说:
image 是一个二维图片张量,它的形状是 (height, width, channels)。其中 height 和 width 分别为图片的高度和宽度,channels 为图片的颜色通道数。

0 是一个整数值,它指定在哪个维度上扩展此张量,这里表示在最前面(第一个)的维度上扩展。
因此,函数的作用是将输入张量 image 在最前面添加一个额外的维度(batch_size),生成一个四维张量。

tf.expand_dims(input, axis)
其中 input 表示要扩展的输入张量,axis 表示要在哪个维度上进行扩展。在这个例子中,input 是变量 image,axis 是 0。
'''

pre = model.predict(img_array)
#这个函数用于对输入图像进行分类预测。它使用已经训练好的模型来对输入数据进行推断,并输出每个类别的概率分布。
print("预测结果为:", class_names[np.argmax(pre)])


#将模型输出的概率分布转换为最终预测结果。
#具体来说,使用 np.argmax() 函数找到概率最大的类别索引,然后使用该索引在 class_names 列表中查找相应的类别名称,并输出预测结果。

1/1 [==============================] - 0s 18ms/step
预测结果为: Monkeypox

八、总结

  • 学习掌握callbacks当中的ModelCheckpoint函数
  • 跑通一遍代码,模型val-acc > 88%
  • 19
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值