#! pip install datasets transformers 
# -i https://pypi.tuna.tsinghua.edu.cn/simple




在当前jupyter笔记本中,我们将说明如何通过微调任意🤗Transformers 模型来完成多项选择任务,即在给定的多个答案中选出最合理的一个。我们使用的数据集是SWAG,当然你也可以将预处理过程用于其他多选数据集或者你自己的数据。SWAG是一个关于常识推理的数据集,每个样本描述一种情况,然后给出四个可能的选项。

这个jupyter笔记本可以运行在model Hub中的任何模型上,只要该模型具有一个多选择头的版本。根据你的模型和你使用的GPU,你可能需要调整批大小,以避免显存不足的错误。设置好这两个参数之后,jupyter笔记本的其余部分就可以顺利运行了:

# Name of the pretrained checkpoint; it is used below to load both the
# tokenizer and the multiple-choice model.
model_checkpoint = "bert-base-uncased"
# Batch size setting; lower it if you run into out-of-memory errors.
# NOTE(review): the place where this value is consumed is not visible in this section.
batch_size = 16



from datasets import load_dataset, load_metric

load_dataset 将缓存数据集以避免下次运行时再次下载它。

# Load the "regular" configuration of the SWAG dataset by name.
datasets = load_dataset("swag", "regular")
import os

# Alternative: load the splits from local CSV files under `data_path`.
# NOTE(review): this rebinds `datasets`, discarding the result of the call
# above, and the validation split is keyed "val" here — confirm that any
# downstream code uses the same key.
data_path = './datasets/swag/'
cache_dir = os.path.join(data_path, 'cache')
data_files = {'train': os.path.join(data_path, 'train.csv'), 'val': os.path.join(data_path, 'val.csv'), 'test': os.path.join(data_path, 'test.csv')}
datasets = load_dataset(data_path, 'regular', data_files=data_files, cache_dir=cache_dir)
    train: Dataset({
        features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],
        num_rows: 73546
    validation: Dataset({
        features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],
        num_rows: 20006
    test: Dataset({
        features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],
        num_rows: 20005

To access an actual element, you need to select a split first, then give an index:

{'ending0': 'passes by walking down the street playing their instruments.',
 'ending1': 'has heard approaching them.',
 'ending2': "arrives and they're outside dancing and asleep.",
 'ending3': 'turns the lead singer watches the performance.',
 'fold-ind': '3416',
 'gold-source': 'gold',
 'label': 0,
 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',
 'sent2': 'A drum line',
 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',
 'video-id': 'anetv_jkn6uvmqwh4'}


from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Display a random sample of distinct rows from ``dataset`` as an HTML table.

    Args:
        dataset: Dataset to sample from. It must support ``len()``, indexing
            with a list of row indices, and expose a ``features`` mapping of
            column name to feature type.
        num_examples: Number of distinct rows to show.

    Raises:
        AssertionError: If ``num_examples`` is larger than the dataset.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Sample without replacement so that every displayed row is distinct.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            # Show the class names rather than the integer class ids.
            df[column] = df[column].transform(lambda i: typ.names[i])
    # Render the sampled rows in the notebook output.
    display(HTML(df.to_html()))
0are seated on a field.are skiing down the slope.are in a lift.are pouring out in a man.16668gold1A man is wiping the skiboard.Group of peopleA man is wiping the skiboard. Group of peopleanetv_JmL6BiuXr_g
1performs stunts inside a gym.shows several shopping in the water.continues his skateboard while talking.is putting a black bike close.11424gold0The credits of the video are shown.A ladyThe credits of the video are shown. A ladyanetv_dWyE0o2NetQ
2is emerging into the hospital.are strewn under water at some wreckage.tosses the wand together and saunters into the marketplace.swats him upside down.15023gen1Through his binoculars, someone watches a handful of surfers being rolled up into the wave.SomeoneThrough his binoculars, someone watches a handful of surfers being rolled up into the wave. Someonelsmdc3016_CHASING_MAVERICKS-6791
3spies someone sitting below.opens the fridge and checks out the photo.puts a little sheepishly.staggers up to him.5475gold3He tips it upside down, and its little umbrella falls to the floor.Back inside, someoneHe tips it upside down, and its little umbrella falls to the floor. Back inside, someonelsmdc1008_Spider-Man2-75503
4carries her to the grave.laughs as someone styles her hair.sets down his glass.stares after her then trudges back up into the street.6904gen1Someone kisses her smiling daughter on the cheek and beams back at the camera.SomeoneSomeone kisses her smiling daughter on the cheek and beams back at the camera. Someonelsmdc1028_No_Reservations-83242
5stops someone and sweeps all the way back from the lower deck to join them.is being dragged towards the monstrous animation.beats out many events at the touch of the sword, crawling it.reaches into a pocket and yanks open the door.14089gen1But before he can use his wand, he accidentally rams it up the troll's nostril.The angry trollBut before he can use his wand, he accidentally rams it up the troll's nostril. The angry trolllsmdc1053_Harry_Potter_and_the_philosophers_stone-95867
6sees someone's name in the photo.gives a surprised look.kneels down and touches his ripped specs.spies on someone's clock.8407gen1Someone keeps his tired eyes on the road.Glancing over, heSomeone keeps his tired eyes on the road. Glancing over, helsmdc1024_Identity_Thief-82693
7stops as someone speaks into the camera.notices how blue his eyes are.is flung out of the door and knocks the boy over.flies through the air, its a fireball.4523gold1Both people are knocked back a few steps from the force of the collision.SheBoth people are knocked back a few steps from the force of the collision. Shelsmdc0043_Thelma_and_Luise-68271
8sits close to the river.have pet's supplies and pets.pops parked outside the dirt facility, sending up a car highway to catch control.displays all kinds of power tools and website.8112gold1A guy waits in the waiting room with his pet.A pet store and its vanA guy waits in the waiting room with his pet. A pet store and its vananetv_9VWoQpg9wqE
9the slender someone, someone turns on the light., someone gives them to her boss then dumps some alcohol into dough.liquids from a bowl, she slams them drunk.wags his tail as someone returns to the hotel room.10867gold3Inside a convenience store, she opens a freezer case.DolceInside a convenience store, she opens a freezer case. Dolcelsmdc3090_YOUNG_ADULT-43871


def show_one(example):
    """Print one example: its context, the four lettered options, and the answer.

    Args:
        example: Mapping with the keys ``sent1`` (context), ``sent2`` (start of
            every option), ``ending0``..``ending3`` (option endings) and
            ``label`` (index of the correct option).
    """
    option_letters = ("A", "B", "C", "D")
    print("Context: {}".format(example["sent1"]))
    for position, letter in enumerate(option_letters):
        # Every option is the shared sentence start followed by its own ending.
        start = example["sent2"]
        ending = example["ending{}".format(position)]
        print("  {} - {} {}".format(letter, start, ending))
    truth = option_letters[example["label"]]
    print("\nGround truth: option {}".format(truth))
Context: Members of the procession walk down the street holding small horn brass instruments.
  A - A drum line passes by walking down the street playing their instruments.
  B - A drum line has heard approaching them.
  C - A drum line arrives and they're outside dancing and asleep.
  D - A drum line turns the lead singer watches the performance.

Ground truth: option A
Context: Now it's someone's turn to rain blades on his opponent.
  A - Someone pats his shoulder and spins wildly.
  B - Someone lunges forward through the window.
  C - Someone falls to the ground.
  D - Someone rolls up his fast run from the water and tosses in the sky.

Ground truth: option C






from transformers import AutoTokenizer
# Load the tokenizer that matches `model_checkpoint`, requesting the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)



# Encode a pair of sentences in a single call; the printed result contains
# input_ids, token_type_ids and attention_mask for the joined pair.
tokenizer("Hello, this one sentence!", "And this sentence goes with it.")
{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}





ending_names = ["ending0", "ending1", "ending2", "ending3"]

def preprocess_function(examples):
    """Tokenize a batch of examples as (context, candidate) sentence pairs.

    Args:
        examples: Batch in column format. ``sent1``, ``sent2`` and every name
            in ``ending_names`` map to lists of equal length.

    Returns:
        A dict mapping each key produced by the tokenizer to a list with one
        entry per example; each entry is a list with one item per choice.
    """
    num_choices = len(ending_names)
    # Repeat each context once per choice. The flat lists are built directly
    # instead of flattening nested lists with sum(..., []), which is quadratic.
    first_sentences = [
        context for context in examples["sent1"] for _ in range(num_choices)
    ]
    # Pair every sentence start with each of its candidate endings.
    second_sentences = [
        f"{header} {examples[end][i]}"
        for i, header in enumerate(examples["sent2"])
        for end in ending_names
    ]
    # Tokenize all pairs at once; no padding is applied at this stage.
    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)
    # Regroup the flat outputs into chunks of `num_choices` per example.
    return {
        k: [v[i:i + num_choices] for i in range(0, len(v), num_choices)]
        for k, v in tokenized_examples.items()
    }

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists of lists for each key: a list of all examples (here 5), then a list of all choices (4) and a list of input IDs (length varying here since we did not apply any padding):


# Sanity-check the preprocessing on the first five training examples.
examples = datasets["train"][:5]
features = preprocess_function(examples)
# Print the number of examples, the number of choices of the first example,
# and the token count of each of those choices.
print(len(features["input_ids"]), len(features["input_ids"][0]), [len(x) for x in features["input_ids"][0]])
5 4 [30, 25, 30, 28]


# Decode the four tokenized choices of one preprocessed example back to text
# so they can be compared with the raw example.
idx = 3
[tokenizer.decode(features["input_ids"][idx][i]) for i in range(4)]
['[CLS] a drum line passes by walking down the street playing their instruments. [SEP] members of the procession are playing ping pong and celebrating one left each in quick. [SEP]',
 '[CLS] a drum line passes by walking down the street playing their instruments. [SEP] members of the procession wait slowly towards the cadets. [SEP]',
 '[CLS] a drum line passes by walking down the street playing their instruments. [SEP] members of the procession makes a square call and ends by jumping down into snowy streets where fans begin to take their positions. [SEP]',
 '[CLS] a drum line passes by walking down the street playing their instruments. [SEP] members of the procession play and go back and forth hitting the drums while the audience claps for them. [SEP]']

我们可以将它和之前生成的ground truth进行比较:

Context: A drum line passes by walking down the street playing their instruments.
  A - Members of the procession are playing ping pong and celebrating one left each in quick.
  B - Members of the procession wait slowly towards the cadets.
  C - Members of the procession makes a square call and ends by jumping down into snowy streets where fans begin to take their positions.
  D - Members of the procession play and go back and forth hitting the drums while the audience claps for them.

Ground truth: option D


# Run preprocess_function over `datasets` in batched mode.
encoded_datasets = datasets.map(preprocess_function, batched=True)
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

# Load the checkpoint through the multiple-choice auto class.
model = AutoModelForMultipleChoice.from_pretrained(model_checkpoint)
# NOTE(review): the remainder of this TrainingArguments call (the other
# arguments and the closing parenthesis) is missing from this section —
# restore it before running.
args = TrainingArguments(
    evaluation_strategy = "epoch",


然后,我们需要告诉我们的Trainer如何从预处理的输入数据中构造批数据。我们还没有做任何填充,因为我们将填充每个批到批内的最大长度(而不是使用整个数据集的最大长度)。这将是data collator的工作。它接受示例的列表,并将它们转换为一个批(在我们的示例中,通过应用填充)。由于在库中没有data collator来处理我们的特定问题,这里我们根据DataCollatorWithPadding自行改编一个:

from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch

@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.

    The fields below are forwarded to ``tokenizer.pad`` on every call.
    """

    # Tokenizer whose ``pad`` method is used to pad each batch.
    tokenizer: PreTrainedTokenizerBase
    # Padding strategy handed to ``tokenizer.pad``.
    padding: Union[bool, str, PaddingStrategy] = True
    # Optional maximum length handed to ``tokenizer.pad``.
    max_length: Optional[int] = None
    # Optional length multiple handed to ``tokenizer.pad``.
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate a list of preprocessed examples into one padded batch.

        Args:
            features: List of dicts, one per example. Each dict holds a
                ``label`` (or ``labels``) entry plus per-choice lists such as
                ``input_ids``. The dicts are not modified.

        Returns:
            A dict of tensors viewed as ``(batch_size, num_choices, -1)``,
            plus a ``labels`` tensor of dtype ``torch.int64``.
        """
        label_name = "label" if "label" in features[0].keys() else "labels"
        # Read the labels without popping so the caller's dicts stay intact.
        labels = [feature[label_name] for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        # One flat entry per (example, choice) pair so the tokenizer can pad
        # all of them to a common length in a single call.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]
        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Un-flatten
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        # Add back labels
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch

当传入一个示例的列表时,它会将大列表中的所有输入/注意力掩码等都压平,并传递给tokenizer.pad方法。这将返回一个带有大张量的字典(其大小为(batch_size * 4) x seq_length),然后我们将其展开。

我们可以在特征列表上检查data collator是否正常工作,在这里,我们只需要确保删除所有不被我们的模型接受的输入特征(这是Trainer自动为我们做的):

# Only these keys are kept when building the inputs for the collator check.
accepted_keys = ["input_ids", "attention_mask", "label"]
# Build filtered feature dicts for the first 10 encoded training examples.
features = [{k: v for k, v in encoded_datasets["train"][i].items() if k in accepted_keys} for i in range(10)]
# Collate the 10 examples into a single batch.
batch = DataCollatorForMultipleChoice(tokenizer)(features)


# Decode the four choices of the example at batch index 8 to inspect the result.
[tokenizer.decode(batch["input_ids"][8][i].tolist()) for i in range(4)]
['[CLS] someone walks over to the radio. [SEP] someone hands her another phone. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] someone walks over to the radio. [SEP] someone takes the drink, then holds it. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] someone walks over to the radio. [SEP] someone looks off then looks at someone. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] someone walks over to the radio. [SEP] someone stares blearily down at the floor. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']
Context: Someone walks over to the radio.
  A - Someone hands her another phone.
  B - Someone takes the drink, then holds it.
  C - Someone looks off then looks at someone.
  D - Someone stares blearily down at the floor.

Ground truth: option D



import numpy as np

def compute_metrics(eval_predictions):
    """Compute accuracy from a ``(predictions, label_ids)`` pair.

    Args:
        eval_predictions: Two-item sequence. The first item holds per-choice
            scores with the choices along axis 1; the second holds the
            reference label ids.

    Returns:
        A dict with the single key ``"accuracy"`` mapped to a Python float.
    """
    scores, gold_labels = eval_predictions
    # The predicted choice is the highest-scoring one along the choice axis.
    chosen = np.argmax(scores, axis=1)
    hit_mask = chosen == gold_labels
    hits_as_float = hit_mask.astype(np.float32)
    accuracy = hits_as_float.mean().item()
    return {"accuracy": accuracy}


# NOTE(review): the arguments and the closing parenthesis of this Trainer call
# are missing from this section — restore them before running.
trainer = Trainer(


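<table border="1" class="dataframe">
  <thead>
    <tr><th>Epoch</th><th>Training Loss</th><th>Validation Loss</th><th>Accuracy</th></tr>
  </thead>
  <tbody>
    <tr><td>1</td><td>0.154598</td><td>0.828017</td><td>0.766520</td></tr>
    <tr><td>2</td><td>0.296633</td><td>0.667454</td><td>0.786814</td></tr>
    <tr><td>3</td><td>0.111786</td><td>0.994927</td><td>0.789363</td></tr>
  </tbody>
</table>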

TrainOutput(global_step=6897, training_loss=0.19714653808275168)

最后,不要忘记将你的模型上传到🤗 模型中心。





