使用Flux.jl进行图像分类

本文介绍了作者在Julia中使用Flux.jl重建一个深度学习项目,用于图像分类,特别是植物疾病检测。文章详细讨论了数据处理、模型创建,包括ResNet迁移学习和自定义CNN的孪生网络,以及训练过程中的挑战和解决方案。作者还分享了Flux与PyTorch的对比以及在处理大型数据集时的数据加载策略。
摘要由CSDN通过智能技术生成

在PyTorch从事一个项目,这个项目创建一个深度学习模型,可以检测未知物种的疾病。

最近,决定在Julia中重建这个项目,并将其用作学习Flux.jl[1]的练习,这是Julia最流行的深度学习包(至少在GitHub上按星级排名)。

但在这样做的过程中,遇到了一些挑战,这些挑战在网上或文档中找不到好的例子。因此,决定写这篇文章,作为其他任何想在Flux做类似事情的人的参考资料。

这是给谁的?

因为Flux.jl(以下简称为“Flux”)是一个深度学习包,所以我主要为熟悉深度学习概念(如迁移学习)的读者编写这篇文章。

虽然在写这篇文章时也考虑到了Flux的一个半新手(比如我自己),但其他人可能会觉得这很有价值。只是要知道,写这篇文章并不是对Julia或通量的全面介绍或指导。为此,将分别参考其他资源,如官方的Julia和Flux文档。

最后,对PyTorch做了几个比较。了解本文观点并不需要有PyTorch的经验,但有PyTorch经验的人可能会觉得它特别有趣。

为什么是Julia?为什么选择Flux.jl?

如果你已经使用了Julia和/或Flux,你可能可以跳过本节。此外,许多其他人已经写了很多关于这个问题的帖子,所以我将简短介绍。

归根结底,我喜欢Julia。它在数值计算方面很出色,编程时真的很开心,而且速度很快。原生快速:不需要NumPy或其他底层C++代码的包装器。

至于为什么选择Flux,是因为它是Julia中最流行的深度学习框架,用纯Julia编写,可与Julia生态系统组合。

项目本身

好吧,既然我已经无耻地说服了Julia,现在是时候了解项目本身的信息了。

我使用了三个数据集——PlantVillage[2]、PlantLeaves[3]和PlantaeK[4]——涵盖了许多不同的物种。

我使用PlantVillage作为训练集,其他两个组合作为测试集。这意味着模型必须学习一些可以推广到未知物种的知识,因为测试集将包含未经训练的物种。

了解到这一点,我创建了三个模型:

  1. 使用ResNet迁移学习的基线

  2. 具有自定义CNN架构的孪生(又名暹罗)神经网络

  3. 具有迁移学习的孪生神经网络

本文的大部分内容将详细介绍处理数据、创建和训练模型的一些挑战和痛点。

处理数据

第一个挑战是数据集的格式错误。我不会在这里详细介绍如何对它们进行预处理,但最重要的是我创建了两个图像目录,即训练和测试。

这两个文件都填充了一长串图像,分别命名为img0.jpg、img1.jpg、imm2.jpg等。我还创建了两个CSV,一个用于训练集,一个为测试集,其中一列包含文件名,一列包含标签。

上述结构很关键,因为数据集的总容量超过10 GB,我电脑的内存肯定无法容纳,更不用说GPU的内存了。因此,我们需要使用DataLoader。(如果你曾经使用过PyTorch,你会很熟悉;这里的概念与PyTorch基本相同。)

为了在Flux中实现这一点,我们需要创建一个自定义结构来包装我们的数据集,以允许它批量加载数据。

为了让我们的自定义结构能够构造数据加载器,我们需要做的就是为类型定义两个方法:length和getindex。下面是我们将用于数据集的实现:

using Flux
using Images
using FileIO
using DataFrames
using Pipe

"""
    ImageDataContainer(labels_df, img_dir)
Implements the functions `length` and `getindex`, which are required to use ImageDataContainer
as an argument in a DataLoader for Flux.
"""
struct ImageDataContainer
    labels::AbstractVector
    filenames::AbstractVector{String}
    function ImageDataContainer(labels_df::DataFrame, img_dir::AbstractString)
        filenames = img_dir .* labels_df[!, 1] # first column should be the filenames
        labels = labels_df[!, 2] # second column should be the labels
        return new(labels, filenames)
    end
end

"Gets the number of observations for a given dataset."
function Base.length(dataset::ImageDataContainer)
    return length(dataset.labels)
end

"Gets the i-th to j-th observations (including labels) for a given dataset."
function Base.getindex(dataset::ImageDataContainer, idxs::Union{UnitRange,Vector})
    batch_imgs = map(idx -> load(dataset.filenames[idx]), idxs)
    batch_labels = map(idx -> dataset.labels[idx], idxs)

    "Applies necessary transforms and reshapings to batches and loads them onto GPU to be fed into a model."
    function transform_batch(imgs, labels)
        # convert imgs to 256×256×3×64 array (Height×Width×Color×Number) of floats (values between 0.0 and 1.0)
        # arrays need to be sent to gpu inside training loop for garbage collector to work properly
        batch_X = @pipe hcat(imgs...) |> reshape(_, (HEIGHT, WIDTH, length(labels))) |> channelview |> permutedims(_, (2, 3, 1, 4))
        batch_y = @pipe labels |> reshape(_, (1, length(la
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值