用 Transformers微调ViT图像分类

【翻译 Nate Raw 的 Fine-Tune ViT for Image Classification with 🤗 Transformers

正如基于 Transformers 的模型彻底改变了NLP一样,我们现在看到将它们应用于各种其他领域的论文激增。其中最具革命性的是 Vision Transformer (ViT),它由 Google Brain 的一组研究人员于 2021 年六月推出。

本文探讨了如何标记图像,就像标记句子一样,以便将它们传递给 transformer 模型进行训练。这是一个非常简单的概念,真的…

  1. 将映像拆分为子映像修补程序的网格
  2. 使用线性投影嵌入每个图面
  3. 每个嵌入式修补程序都将成为标记,生成的嵌入式修补程序序列就是传递给模型的序列。

事实证明,一旦你完成了上述操作,你就可以像习惯NLP任务一样对转换器进行预训练和微调。很贴心😎。

在这篇博文中,我们将介绍如何利用 🤗 下载的数据集和处理图像分类数据集,然后使用它们来微调带有 🤗 transformers 的预训练 ViT

首先,让我们先安装这两个软件包。

pip install datasets transformers
  • 1

补充建议4.28.x, 原因后面再提

加载一个数据集

让我们首先加载一个小的图像分类数据集并查看其结构。

我们将使用豆数据集,这是健康和不健康豆叶图片的集合。🍃

    from datasets import load_dataset

    ds = load_dataset('beans')
    ds
  • 1
  • 2
  • 3
  • 4

让我们看一下从 由’train’拆分的豆荚数据集的第 400 个示例。您会注意到数据集中的每个示例都有 3 个特征:

  1. image:PIL 图像

  2. image_file_path:str 路径指向需要加载的图像文件

  3. labels:数据集。类标签功能,它是标签的整数表示形式。(稍后您将看到如何获取字符串类名,别担心!)

     ex = ds['train'][400]
     ex
    
     {
         'image': <PIL.JpegImagePlugin ...>,
         'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
         'labels': 1
     }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

让我们先看一下图像👀

    image = ex['image']
    image
  • 1
  • 2

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Q081sdpR-1685285046313)(null)]

那绝对是一片叶子!但是什么样的呢?😅

由于此数据集的特征是 datasets.features.ClassLabel,我们可以使用它来查找此示例的标签 ID 的相应名称。

首先,让我们访问labels的特征定义

    labels = ds['train'].features['labels']
    labels
  • 1
  • 2

    ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)
  • 1

现在,让我们打印出示例的类标签。你可以通过使用 int2str 函数来做到这一点,顾名思义,它允许传递类的整数表示形式来查找字符串 label.ClassLabel

    labels.int2str(ex['labels'])
  • 1

    'bean_rust'
  • 1

事实证明,上面显示的叶子感染了豆锈病,豆锈病是豆类植物中的一种严重疾病。😢

让我们编写一个函数,该函数将显示每个类的示例网格,以便更好地了解您正在使用的内容。

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

数据集中每个类的几个示例的网格

依我所见:

  1. 角叶斑:有不规则的棕色斑块
  2. 豆锈:有圆形棕色斑点,周围环绕着白色黄色的环
  3. 健康:。。。看起来很健康。🤷‍♂️
加载 ViT 特征提取器

现在我们知道了我们的图像是什么样子的,并更好地了解了我们试图解决的问题。让我们看看如何为我们的模型准备这些图像!

训练 ViT 模型时,会将特定转换应用于馈入其中的图像。在图像上使用错误的转换,模型将无法理解它所看到的内容!🖼 ➡️ 🔢

为了确保我们应用正确的转换,我们将使用一个 ViTFeatureExtractor 初始化,该配置与我们计划使用的预训练模型一起保存。在我们的例子中,我们将使用google/vit-base-patch16-224-in21k模型,所以让我们从Hugging Face Hub加载它的特征提取器。

    from transformers import ViTFeatureExtractor

    model_name_or_path = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
  • 1
  • 2
  • 3
  • 4

您可以通过打印来查看特征提取器配置。

ViTFeatureExtractor {
    "do_normalize": true,
    "do_resize": true,
    "feature_extractor_type": "ViTFeatureExtractor",
    "image_mean": [
        0.5,
        0.5,
        0.5
    ],
    "image_std": [
        0.5,
        0.5,
        0.5
    ],
    "resample": 2,
    "size": 224
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

要处理图像,只需将其传递给特征提取器的调用函数即可。这将返回一个字典容器,这是要传递给model.pixel values的数字表示形式。

默认情况下,你会得到一个 NumPy 数组,但如果你添加参数 return_tensors=‘pt’,你会得到 torch 张量。

feature_extractor(image, return_tensors='pt')
  • 1

应该给你一些类似的东西…

{
'pixel_values': tensor([[[[ 0.2706,  0.3255,  0.3804,  ...]]]])
}
  • 1
  • 2
  • 3

…其中张量的形状为 。(1, 3, 224, 224)

处理数据集

现在您知道如何读取图像并将其转换为输入,让我们编写一个函数,将这两件事放在一起以处理数据集中的单个示例。

def process_example(example):
    inputs = feature_extractor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs
  • 1
  • 2
  • 3
  • 4

process_example(ds['train'][0])
  • 1

{
    'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
    'labels': 0
}
  • 1
  • 2
  • 3
  • 4

虽然您可以一次调用并将其应用于每个示例,但这可能会非常慢,尤其是在使用较大的数据集时。相反,您可以对数据集应用转换。转换仅在为示例编制索引时应用于示例 ds.map

但是,首先,您需要更新最后一个函数以接受一批数据,因为这是预期的。ds.with_transform

ds = load_dataset('beans')

def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

您可以使用 ds.with_transform(transform) 将其直接应用于数据集。

prepared_ds = ds.with_transform(transform)
  • 1

现在,每当您从数据集中获取示例时,转换将是 实时应用(在样品和切片上,如下所示)

prepared_ds['train'][0:2]
  • 1

这一次,得到的张量将具有形状。pixel_values(2, 3, 224, 224)

{
    'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
    'labels': [0, 0]
}
  • 1
  • 2
  • 3
  • 4
培训和评估

数据已处理完毕,即可开始设置训练管道。这篇博文使用 🤗 的 Trainer,但这需要我们先做几件事:

  • 定义排序规则函数。

  • 定义评估指标。在训练期间,应评估模型的预测准确性。您应该相应地定义一个函数。compute_metrics

  • 加载预训练的检查点。您需要加载预训练的检查点并正确配置它以进行训练。

  • 定义训练配置。

微调模型后,您将根据评估数据正确评估模型,并验证它是否确实学会了正确分类图像。

定义我们的数据整理器

批处理以字典列表的形式出现,因此您可以将它们解压缩+堆叠到批处理张量中。

由于 将返回批处理字典,因此您可以稍后将输入到模型中。✨collate_fn**unpack

import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
定义评估指标

来自的准确度指标可以轻松用于将预测与标签进行比较。下面,您可以看到如何使用datasetscompute_metricsTrainer

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

让我们加载预训练模型。我们将添加 init,以便模型创建具有正确单位数的分类头。我们还将在 Hub 微件中包含 和 映射以具有人类可读的标签(如果您选择 )。num_labelsid2labellabel2idpush_to_hub

from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

快准备好训练了!在此之前需要做的最后一件事是通过定义 TrainingArguments 来设置训练配置。

其中大多数都是不言自明的,但这里非常重要的一点是。这将删除模型调用函数未使用的任何功能。默认情况下,这是因为通常最好删除未使用的特征列,从而更轻松地将输入解压缩到模型的调用函数中。但是,在我们的例子中,我们需要未使用的功能(特别是“图像”)来创建“pixel_values”.remove_unused_columns=FalseTrue

我想说的是,如果你忘记设置.remove_unused_columns=False,你会过得很糟糕

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vit-base-beans",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

此处我一直遇到“NameError: name ‘PartialState’ is not defined.”,查阅
资料新的transformers用的accelerate有用到PartialState,所以需要控制它的版本,并且pip install git+https://github.com/huggingface/accelerate安装dev版本或没有用过multi-GPUs (such as in Colab)的使用 pip install accelerate -U

另外, 根据提示安装缺失的pypi包,解决raw.githubusercontent.com无法访问的问题

现在,所有实例都可以传递给训练师,我们准备开始训练了!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
训练🚀
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
  • 1
  • 2
  • 3
  • 4
  • 5
评价📊
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
  • 1
  • 2
  • 3

这是我的评估结果 - 酷豆!对不起,不得不说。

***** eval metrics *****
epoch                   =        4.0
eval_accuracy           =      0.985
eval_loss               =     0.0637
eval_runtime            = 0:00:02.13
eval_samples_per_second =     62.356
eval_steps_per_second   =       7.97
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

最后,如果需要,可以将模型推送到hub。在这里,如果您在训练配置中指定,我们将其向上推送。请注意,为了推送到 hub,您必须安装 git-lfs 并登录到您的 Hugging Face 帐户(可以通过 ).push_to_hub=Truehuggingface-cli 登录来完成

kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

由此产生的模型已共享给nateraw/vit-base-beans。我假设你没有豆叶的图片,所以我添加了一些例子让你试一试!

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值