基于 Tensorflow 的蘑菇分类

引言

当我们在大自然中行走的时候,经常会碰到各种各样的菌子,这时候我们就有了疑问:我们可以触碰它们吗?它们可以吃吗?如果有一个可以识别菌子的app就很棒了,so,现在让我们来实现吧~

在我们开始之前,让我们理解一些概念。计算机视觉是人工智能的一个有趣分支之一,是教模型在图像中查找信息从而理解视觉内容的艺术。当对人类(猫、狗、汽车……)进行图像分类非常简单时,机器总是很难具有竞争力,这是我们人类从小就学习的东西。计算机视觉已经走过了漫长的道路,现在有了深度学习,它的识别和人类一样好,在特定领域甚至更好。例如,在医学放射学中,可以训练人工智能来检测和分类肿瘤,并且通常比人类有更好的结果。

计算机视觉的第一步是图像检测。图像检测是在给定的图像中找到图像中的特定对象,并返回其坐标或包围盒。

图像分类是当你给出一个物体的图像时,你的模型以概率和置信率返回一个类。因此,我们的模型应该首先检测对象,然后根据它所训练的类型对它们进行分类。为此,我们通常使用 CNN(卷积神经网络)。

图像识别是当您给模型一个图像与多个对象。该模型为图像中的每个物体提供了它的边界框(目标检测)和类的预测,并给出了置信率。

现在我们遇到的是多目标的图像分类问题。

收集数据

为了训练一个模型,你需要好的标记数据,如果这一步出现了错误,后面所有的步骤都将徒劳无功。现在我们用的是 Kaggle 的真菌数据集,这是一个非常好的数据集,有1394个类可以在这里使用。数据集的链接如下:https://www.kaggle.com/c/fungi-challenge-fgvc-2018。

数据处理

Tensorflow 为我们提供了一个很便利的API,即 tf.data.dataset。我们可以很方便的用一行代码创建一个有效的数据集,让我们来看看吧~

 data_dir = '/Mydirectory/images/'
    img_height = 256
    img_width = 256
    batch_size = 32
    
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)


    class_names = train_ds.class_names

以下是我们将使用的 10 个类:

[‘11082_Xerocomellus_chrysenteron’, ‘12919_Cylindrobasidium_laeve’, ‘14064_Fomitopsis_pinicola’, ‘14160_Ganoderma_pfeifferi’, ‘17233_Mycena_galericulata’, ‘20983_Trametes_versicolor’, ‘21143_Tricholoma_scalpturatum’, ‘40392_Armillaria_lutea’, ‘40985_Byssomerulius_corium’, ‘61207_Coprinellus_micaceus’]

让我们设置数据集性能

    #################################################
    # Dataset Performance
    ##################################################
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

迁移学习模式

开始创建自己的 CNN,但效果不佳。我不是那么有耐心去改进它,我选择进行迁移学习。迁移学习是重用在更大数据集上训练的模型的能力,这些模型已经学习了多个特征。为此,我们冻结顶层并使用新类重新训练,权重可重复使用。所以让我们用这些预先训练好的模型来帮助自己。我使用了 MobileNetV2 模型,因为它非常轻巧,在我的 GPU 上运行只需几秒钟。

为了提高准确性,我增加了一个独特的步骤,那就是数据增强。数据增强是: 对于一个标记为图像的输入,您可以缩放或翻转它,并将其作为模型的输入添加。这有助于模型继续识别对象,即使它并不总是处于相同的位置。

#################################################
    # Data Augmentation
    ##################################################
    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
        tf.keras.layers.experimental.preprocessing.RandomZoom(0.1),
    ])
    #################################################
    # CREATE THE MODEL
    ##################################################
    num_classes = 10
    preprocess_input_mobilenet_v2 = tf.keras.applications.mobilenet_v2.preprocess_input


    base_model = tf.keras.applications.MobileNetV2(input_shape=(256, 256, 3),
                                                   include_top=False,
                                                   weights='imagenet')
    
      
    base_model.trainable = False
    
    image_batch, label_batch = next(iter(train_ds))
    feature_batch = base_model(image_batch)


    global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
    feature_batch_average = global_average_layer(feature_batch)
    prediction_layer = tf.keras.layers.Dense(num_classes, kernel_regularizer=tf.keras.regularizers.l2(0.001))
    prediction_batch = prediction_layer(feature_batch_average)
    inputs = tf.keras.Input(shape=(256, 256, 3))
    x = data_augmentation(inputs)
    x = preprocess_input_mobilenet_v2(x)
    x = base_model(x, training=False)
    x = global_average_layer(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    outputs = prediction_layer(x)
    model = tf.keras.Model(inputs, outputs)


    #################################################
    # COMPILE THE MODEL
    ##################################################
    #
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])


    #################################################
    # TRAIN THE MODEL
    ##################################################


    epochs = 10
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs
    )

结果如下

现在让我们来预测下,代码如下所示:

# #################################################
# # LOAD THE MODEL
# ##################################################
model = tf.keras.models.load_model('MobileNetV2_Ep20')


# #################################################
# # Predictions
# ##################################################


img_url = "https://www.mycodb.fr/photos/Xerocomellus_chrysenteron_2014_rp_1.jpg"
img_path = tf.keras.utils.get_file('mushroom_image', origin=img_url)


img = tf.keras.preprocessing.image.load_img(
    img_path, target_size=(256, 256, 3)
)


img_array = tf.keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create a batch


predictions = model.predict(img_array)


predictions_sigmoid = tf.nn.sigmoid(predictions)
score = tf.nn.softmax(predictions[0])


print(
    "This image most likely belongs to {} with a {:.2f} percent confidence."
    .format(class_names[np.argmax(score)], 100 * np.max(score))
)


预测结果:

结果还是相当不错的吧~

总结

在本文中,我们了解了如何使用 tensorflow 训练一个用于分类菌子的模型,下一步我们就可以将它移植到移动端,想想还是很兴奋的呢~

·  END  ·

HAPPY LIFE

  • 3
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
TensorFlow.js 是一个基于 JavaScript 的机器学习库,可以在浏览器环境和 Node.js 环境中运行机器学习模型。而蘑菇分类小程序是一个利用 TensorFlow.js 实现的,用于识别不同种类的蘑菇的应用。 在蘑菇分类小程序中,首先需要准备一个已经经过训练的模型,该模型可以预测蘑菇的类型。这个模型可以通过 TensorFlow 或者其他机器学习库在训练数据集上进行训练得到。 接下来,在小程序的界面上,用户可以通过拍摄或上传一张蘑菇的照片。小程序会将这张照片转换为一个张量(Tensor),然后通过加载预训练模型来进行蘑菇类型的预测。常见的蘑菇类型可以包括有毒和无毒两种。 在预测完成后,小程序会将预测结果展示给用户。用户可以了解到这种蘑菇是有毒的还是无毒的,以便在野外采摘蘑菇时可以进行正确的警惕和保护。 蘑菇分类小程序的核心是利用 TensorFlow.js 进行蘑菇类型的预测。将训练好的模型加载到小程序中,并利用该模型对用户上传的蘑菇照片进行预测。这样,用户可以通过这个小程序来判断野外的蘑菇是否有毒,提高采摘蘑菇的安全性。 通过这个小程序,用户可以方便地识别出蘑菇的类型,避免误食有毒蘑菇对健康造成危害。同时,这个小程序也展现了 TensorFlow.js 在网页和移动应用中的应用潜力,为开发人员提供了一种用 JavaScript 来构建机器学习应用的新方式。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值