在PyTorch从事一个项目,这个项目创建一个深度学习模型,可以检测未知物种的疾病。
最近,决定在Julia中重建这个项目,并将其用作学习Flux.jl[1]的练习,这是Julia最流行的深度学习包(至少在GitHub上按星级排名)。
但在这样做的过程中,遇到了一些挑战,这些挑战在网上或文档中找不到好的例子。因此,决定写这篇文章,作为其他任何想在Flux做类似事情的人的参考资料。
这是给谁的?
因为Flux.jl(以下简称为“Flux”)是一个深度学习包,所以我主要为熟悉深度学习概念(如迁移学习)的读者编写这篇文章。
虽然在写这篇文章时也考虑到了Flux的一个半新手(比如我自己),但其他人可能会觉得这很有价值。只是要知道,写这篇文章并不是对Julia或通量的全面介绍或指导。为此,将分别参考其他资源,如官方的Julia和Flux文档。
最后,对PyTorch做了几个比较。了解本文观点并不需要有PyTorch的经验,但有PyTorch经验的人可能会觉得它特别有趣。
为什么是Julia?为什么选择Flux.jl?
如果你已经使用了Julia和/或Flux,你可能可以跳过本节。此外,许多其他人已经写了很多关于这个问题的帖子,所以我将简短介绍。
归根结底,我喜欢Julia。它在数值计算方面很出色,编程时真的很开心,而且速度很快。原生快速:不需要NumPy或其他底层C++代码的包装器。
至于为什么选择Flux,是因为它是Julia中最流行的深度学习框架,用纯Julia编写,可与Julia生态系统组合。
项目本身
好吧,既然我已经无耻地说服了Julia,现在是时候了解项目本身的信息了。
我使用了三个数据集——PlantVillage[2]、PlantLeaves[3]和PlantaeK[4]——涵盖了许多不同的物种。
我使用PlantVillage作为训练集,其他两个组合作为测试集。这意味着模型必须学习一些可以推广到未知物种的知识,因为测试集将包含未经训练的物种。
了解到这一点,我创建了三个模型:
使用ResNet迁移学习的基线
具有自定义CNN架构的孪生(又名暹罗)神经网络
具有迁移学习的孪生神经网络
本文的大部分内容将详细介绍处理数据、创建和训练模型的一些挑战和痛点。
处理数据
第一个挑战是数据集的格式错误。我不会在这里详细介绍如何对它们进行预处理,但最重要的是我创建了两个图像目录,即训练和测试。
这两个文件都填充了一长串图像,分别命名为img0.jpg、img1.jpg、imm2.jpg等。我还创建了两个CSV,一个用于训练集,一个为测试集,其中一列包含文件名,一列包含标签。
上述结构很关键,因为数据集的总容量超过10 GB,我电脑的内存肯定无法容纳,更不用说GPU的内存了。因此,我们需要使用DataLoader。(如果你曾经使用过PyTorch,你会很熟悉;这里的概念与PyTorch基本相同。)
为了在Flux中实现这一点,我们需要创建一个自定义结构来包装我们的数据集,以允许它批量加载数据。
为了让我们的自定义结构能够构造数据加载器,我们需要做的就是为类型定义两个方法:length和getindex。下面是我们将用于数据集的实现:
using Flux
using Images
using FileIO
using DataFrames
using Pipe
"""
ImageDataContainer(labels_df, img_dir)
Implements the functions `length` and `getindex`, which are required to use ImageDataContainer
as an argument in a DataLoader for Flux.
"""
struct ImageDataContainer
labels::AbstractVector
filenames::AbstractVector{String}
function ImageDataContainer(labels_df::DataFrame, img_dir::AbstractString)
filenames = img_dir .* labels_df[!, 1] # first column should be the filenames
labels = labels_df[!, 2] # second column should be the labels
return new(labels, filenames)
end
end
"Gets the number of observations for a given dataset."
function Base.length(dataset::ImageDataContainer)
return length(dataset.labels)
end
"Gets the i-th to j-th observations (including labels) for a given dataset."
function Base.getindex(dataset::ImageDataContainer, idxs::Union{UnitRange,Vector})
batch_imgs = map(idx -> load(dataset.filenames[idx]), idxs)
batch_labels = map(idx -> dataset.labels[idx], idxs)
"Applies necessary transforms and reshapings to batches and loads them onto GPU to be fed into a model."
function transform_batch(imgs, labels)
# convert imgs to 256×256×3×64 array (Height×Width×Color×Number) of floats (values between 0.0 and 1.0)
# arrays need to be sent to gpu inside training loop for garbage collector to work properly
batch_X = @pipe hcat(imgs...) |> reshape(_, (HEIGHT, WIDTH, length(labels))) |> channelview |> permutedims(_, (2, 3, 1, 4))
batch_y = @pipe labels |> reshape(_, (1, length(la