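# 代码整合自 Hugging Face 的博文 Fine-Tune ViT for Image Classification with 🤗 Transformers，主要功能是在数据集 beans 上 finetune 谷歌开源的 ViT 模型 google/vit-base-patch16-224-in21k。
# Adapted from the Hugging Face blog post "Fine-Tune ViT for Image Classification
# with 🤗 Transformers": fine-tunes Google's open-source ViT model
# google/vit-base-patch16-224-in21k on the `beans` dataset.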
import numpy as np
import torch
from datasets import load_metric
from transformers import Trainer
from transformers import TrainingArguments
from transformers import ViTForImageClassification
from transformers import ViTImageProcessor
# Module-level accuracy metric from the `datasets` metric hub.
# NOTE(review): `datasets.load_metric` is deprecated in favour of the separate
# `evaluate` library and fetches the metric script at import time.
metric = load_metric(path="accuracy")
def create_transform(processor):
    """Build a batch transform that turns raw dataset rows into model inputs.

    Args:
        processor: callable image processor (e.g. a ``ViTImageProcessor``)
            invoked as ``processor(images, return_tensors='pt')``; its result
            must support item assignment.

    Returns:
        A function ``transform(example_batch)`` that reads the ``'image'`` and
        ``'labels'`` columns of a batch dict and returns the processor output
        with the labels attached under ``'labels'``.
    """
    def _to_rgb(image):
        # ViT processors normalise with a 3-channel mean/std, so grayscale,
        # RGBA or palette PIL images must be converted first. Objects without a
        # ``mode`` attribute are passed through untouched.
        if getattr(image, "mode", "RGB") != "RGB":
            return image.convert("RGB")
        return image

    def transform(example_batch):
        # Turn the batch of images into pixel values in one processor call.
        inputs = processor(
            [_to_rgb(image) for image in example_batch['image']],
            return_tensors='pt',
        )
        # Carry the labels along so the Trainer can compute the loss.
        inputs['labels'] = example_batch['labels']
        return inputs

    return transform
def collate_fn(batch):
    """Merge a list of per-example dicts into one batch for the Trainer.

    Each example must hold a ``'pixel_values'`` tensor (all of equal shape)
    and an integer ``'labels'`` entry. The pixel tensors are stacked along a
    new leading batch dimension and the labels become a 1-D tensor.
    """
    pixel_tensors = []
    label_values = []
    for example in batch:
        pixel_tensors.append(example['pixel_values'])
        label_values.append(example['labels'])
    stacked_pixels = torch.stack(pixel_tensors)
    label_tensor = torch.tensor(label_values)
    return {'pixel_values': stacked_pixels, 'labels': label_tensor}
def compute_metrics(p):
    """Compute classification accuracy for the Hugging Face ``Trainer``.

    Args:
        p: an ``EvalPrediction``-like object exposing ``predictions`` (per-class
            scores/logits with the class axis last; a tuple whose first element
            holds the logits is also accepted) and ``label_ids`` (integer
            class ids).

    Returns:
        ``{"accuracy": float}``: the fraction of examples whose arg-max
        prediction equals the reference label.
    """
    logits = p.predictions
    # When a model returns extra outputs, the Trainer passes a tuple whose
    # first entry holds the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predicted = np.argmax(logits, axis=-1)
    references = np.asarray(p.label_ids)
    # Computed locally rather than via `datasets.load_metric("accuracy")`,
    # which is deprecated and downloads a metric script at runtime.
    return {"accuracy": float(np.mean(predicted == references))}
from datasets import load_dataset
def _accepts_kwarg(func, name):
    """Return True if *func*'s call signature has a parameter called *name*."""
    import inspect

    try:
        return name in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: assume the older keyword name.
        return False


def train_vit(pretrained_model_name, dataset_name, output_dir="./vit-base-beans"):
    """Fine-tune a ViT image classifier on a Hugging Face image dataset.

    The dataset must provide ``train``, ``validation`` and ``test`` splits with
    an ``image`` column and a ``ClassLabel`` column named ``labels``. Training
    evaluates on ``validation`` every 100 steps, keeps the best checkpoint, and
    finally evaluates on ``test``. Metrics and trainer state are saved to
    *output_dir*, and logs are reported to TensorBoard.

    Args:
        pretrained_model_name: Hub id or local path of the ViT checkpoint whose
            weights and image processor are loaded.
        dataset_name: Hub id or local path passed to ``load_dataset``.
        output_dir: directory for checkpoints, the final model and metric files.

    Returns:
        The metrics dict from evaluating the final model on the ``test`` split.
    """
    ds = load_dataset(dataset_name)
    # Use the function argument, not a module-level global, so this works when
    # the function is imported and called from another module.
    processor = ViTImageProcessor.from_pretrained(pretrained_model_name)
    # Apply preprocessing lazily, batch by batch, when rows are accessed.
    prepared_ds = ds.with_transform(create_transform(processor))

    labels = ds['train'].features['labels'].names
    model = ViTForImageClassification.from_pretrained(
        pretrained_model_name,
        num_labels=len(labels),
        # PretrainedConfig documents id2label as Dict[int, str] and
        # label2id as Dict[str, int].
        id2label={i: c for i, c in enumerate(labels)},
        label2id={c: i for i, c in enumerate(labels)},
    )

    # `evaluation_strategy` was renamed to `eval_strategy` in newer
    # transformers releases; pick whichever this installation accepts.
    eval_strategy_key = (
        "eval_strategy"
        if _accepts_kwarg(TrainingArguments, "eval_strategy")
        else "evaluation_strategy"
    )
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,
        num_train_epochs=4,
        # fp16 mixed precision needs a GPU; enabling it on CPU raises an error.
        fp16=torch.cuda.is_available(),
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        # Keep the raw 'image' column: the on-the-fly transform reads it, but it
        # is not a model forward() argument, so the Trainer would otherwise drop it.
        remove_unused_columns=False,
        push_to_hub=False,
        report_to='tensorboard',
        load_best_model_at_end=True,
        **{eval_strategy_key: "steps"},
    )

    # Newer transformers releases take the processor as `processing_class`;
    # older ones only accept `tokenizer`. It is passed so that it is saved
    # together with the model.
    processor_key = (
        "processing_class" if _accepts_kwarg(Trainer, "processing_class") else "tokenizer"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_ds["train"],
        eval_dataset=prepared_ds["validation"],
        **{processor_key: processor},
    )

    train_results = trainer.train()
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    trainer.save_state()

    # Final evaluation on the held-out test split.
    metrics = trainer.evaluate(prepared_ds['test'])
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    return metrics
if __name__ == '__main__':
    # Checkpoint and dataset used by the blog post this script follows.
    model_name_or_path, dataset_name_or_path = 'google/vit-base-patch16-224-in21k', 'beans'
    train_vit(
        pretrained_model_name=model_name_or_path,
        dataset_name=dataset_name_or_path,
    )