在小数据集上finetune ViT

代码整合自 Hugging Face 的博文《Fine-Tune ViT for Image Classification with 🤗 Transformers》,主要功能是在 beans 数据集上微调(finetune)谷歌开源的 ViT 模型 google/vit-base-patch16-224-in21k。

import numpy as np
import torch
from datasets import load_metric
from transformers import Trainer
from transformers import TrainingArguments
from transformers import ViTForImageClassification
from transformers import ViTImageProcessor

# Accuracy metric object used by compute_metrics via `metric.compute(...)`.
# NOTE(review): `datasets.load_metric` is deprecated (removed in datasets 3.x;
# the replacement is the separate `evaluate` library) — confirm the pinned version.
metric = load_metric(path="accuracy")


def create_transform(processor):
    """Build a batch transform suitable for ``Dataset.with_transform``.

    Args:
        processor: Callable invoked as ``processor(images, return_tensors='pt')``;
            presumably a ``ViTImageProcessor`` (that is what train_vit passes).

    Returns:
        A function that takes a batch dict with ``'image'`` and ``'labels'``
        entries and returns the processor's output with ``'labels'`` added.
    """

    def transform(example_batch):
        # Materialise the batch's images (e.g. PIL images) into a plain list.
        images = list(example_batch['image'])
        encoded = processor(images, return_tensors='pt')
        # The Trainer needs the targets next to the pixel values.
        encoded['labels'] = example_batch['labels']
        return encoded

    return transform


def collate_fn(batch):
    """Merge a list of transformed examples into one model-ready batch.

    Args:
        batch: Sequence of dicts, each holding a ``'pixel_values'`` tensor
            (all of the same shape) and a ``'labels'`` value.

    Returns:
        Dict with ``'pixel_values'`` stacked along a new leading dimension and
        ``'labels'`` as a tensor built from the per-example labels.
    """
    pixel_values = torch.stack([sample['pixel_values'] for sample in batch])
    labels = torch.tensor([sample['labels'] for sample in batch])
    return {'pixel_values': pixel_values, 'labels': labels}


def compute_metrics(p):
    """Compute classification accuracy for the Hugging Face ``Trainer``.

    Args:
        p: ``EvalPrediction``-like object with ``predictions`` (logits whose
            class scores lie along axis 1) and ``label_ids``.

    Returns:
        ``{'accuracy': float}`` — the same key/value shape produced by the
        ``datasets`` accuracy metric.
    """
    logits = p.predictions
    # The Trainer may hand over a tuple when the model returns extra outputs;
    # the logits are the first element in that case.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=1)
    references = np.asarray(p.label_ids)
    # Computed directly with NumPy so this does not depend on the deprecated
    # `datasets.load_metric` API.
    return {'accuracy': float(np.mean(predictions == references))}


from datasets import load_dataset


def train_vit(pretrained_model_name, dataset_name, output_dir="./vit-base-beans"):
    """Fine-tune a ViT image classifier on a Hugging Face dataset and evaluate it.

    The dataset must provide ``train``, ``validation`` and ``test`` splits, an
    ``image`` column consumed by the transform, and a ``labels`` ClassLabel
    feature (the ``beans`` dataset has this layout).

    Args:
        pretrained_model_name: Hub id or local path of the pretrained ViT
            checkpoint, e.g. ``'google/vit-base-patch16-224-in21k'``.
        dataset_name: Name or path passed to ``datasets.load_dataset``.
        output_dir: Directory for checkpoints, the final model, metrics and
            trainer state. Defaults to ``"./vit-base-beans"``.
    """
    ds = load_dataset(dataset_name)
    # Load from the function argument so the function also works when it is
    # imported and called from other code (not only as a script).
    processor = ViTImageProcessor.from_pretrained(pretrained_model_name)
    prepared_ds = ds.with_transform(create_transform(processor))
    labels = ds['train'].features['labels'].names

    model = ViTForImageClassification.from_pretrained(
        pretrained_model_name,
        num_labels=len(labels),
        id2label={str(i): c for i, c in enumerate(labels)},
        label2id={c: str(i) for i, c in enumerate(labels)},
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,
        # NOTE(review): renamed to `eval_strategy` in newer transformers
        # releases — confirm against the pinned version.
        evaluation_strategy="steps",
        num_train_epochs=4,
        # fp16 mixed precision requires a CUDA device; requesting it without
        # one makes TrainingArguments raise, so enable it only on GPU.
        fp16=torch.cuda.is_available(),
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        # Keep the raw 'image' column: the on-the-fly transform reads it, and
        # the Trainer would otherwise drop columns the model does not accept.
        remove_unused_columns=False,
        push_to_hub=False,
        report_to='tensorboard',
        load_best_model_at_end=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_ds["train"],
        eval_dataset=prepared_ds["validation"],
        # Passing the processor makes save_model() store it with the weights.
        tokenizer=processor,
    )

    train_results = trainer.train()
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    trainer.save_state()

    # Final evaluation on the held-out test split (logged under "eval").
    metrics = trainer.evaluate(prepared_ds['test'])
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)


if __name__ == '__main__':
    # Script entry point: fine-tune the ImageNet-21k ViT checkpoint on beans.
    # These stay module-level names; keep `model_name_or_path` unchanged in
    # case other code in this module reads it as a global.
    model_name_or_path, dataset_name_or_path = 'google/vit-base-patch16-224-in21k', 'beans'
    train_vit(
        pretrained_model_name=model_name_or_path,
        dataset_name=dataset_name_or_path,
    )
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值