1. The dataset
A plain txt file into which I casually typed two lines of verse (don't mind the typos):
孤山是被家庭西,睡眠出停运较低,几处早莺争暖树,谁家新燕啄春泥,乱花渐欲迷人眼,浅草才能没马蹄,最爱湖东行不足,绿杨阴厉白沙堤。
北风卷地白草折,胡天八月即飞雪;忽如一夜春风来,千树万树梨花开,散入珠帘是落幕,狐裘不暖锦琪娜
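Loaded with the datasets "text" loader, each line of the file becomes one training sample. A quick sanity check (a minimal sketch; the path ./data/poem.txt matches the training code below):

from datasets import load_dataset

ds = load_dataset("text", data_files="./data/poem.txt")
print(ds["train"].num_rows)         # 2 -- one sample per line
print(ds["train"][0]["text"][:14])  # start of the first poem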
2. Training code
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
import numpy as np
cache_dir = './huggingface'
checkpoint = "uer/gpt2-chinese-cluecorpussmall"
model = AutoModelForCausalLM.from_pretrained(checkpoint, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir)
dataset = load_dataset("text", data_files="./data/poem.txt")
def process_func(data):
    # Calling the tokenizer directly replaces the deprecated batch_encode_plus.
    data = tokenizer(data["text"], padding=True, truncation=True, max_length=512, return_tensors="pt")
    # Causal LM labels are just the input ids; the collator below rebuilds
    # them anyway and masks padded positions with -100.
    data["labels"] = data["input_ids"].clone()
    return data

# batched=True with batch_size=2 puts both poems in one call, so padding=True
# pads them to a common length.
dataset = dataset.map(process_func, batched=True, batch_size=2, remove_columns="text")
def eval_func(data):
    # data is an EvalPrediction: (logits, labels) as numpy arrays.
    preds, labels = data
    preds = preds.argmax(axis=-1)
    # Causal shift: the logit at position i predicts token i+1 (toy check below).
    acc = np.sum(preds[:, :-1] == labels[:, 1:]) / labels.size
    return {'acc': acc}
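# Toy check of the shift logic (a hedged sketch with synthetic arrays):
# position i of preds is the model's guess for token i+1 of labels.
# labels = np.array([[101, 2, 3, 4]])  # e.g. [CLS] followed by three tokens
# preds = np.array([[2, 3, 4, 0]])     # argmax logits: next-token guesses
# np.sum(preds[:, :-1] == labels[:, 1:])  # 3 -- all comparable positions match
# Dividing by labels.size (4) as eval_func does gives 0.75, not 1.0: the
# denominator also counts the final position, which has no next token.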
# mlm=False makes this a causal-LM collator: it copies input_ids to labels
# and replaces pad positions with -100 so the loss ignores them.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
args = TrainingArguments(
    output_dir="output",
    per_device_eval_batch_size=1,
    per_device_train_batch_size=1,
    num_train_epochs=200  # deliberately overfit the two samples
)
trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["train"],  # evaluate on the training set itself
    compute_metrics=eval_func,
    data_collator=data_collator
)
trainer.train()
{'train_runtime': 73.2511, 'train_samples_per_second': 5.461, 'train_steps_per_second': 5.461, 'train_loss': 0.10257271766662597, 'epoch': 200.0}
trainer.evaluate()
{'eval_loss': 0.01241392083466053,
'eval_acc': 0.8484848484848485,
'eval_runtime': 0.0858,
'eval_samples_per_second': 23.318,
'eval_steps_per_second': 23.318,
'epoch': 200.0}
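Since eval_loss is the mean per-token cross-entropy, it maps directly to perplexity. A quick check, using only the value from the log above:

import math
print(math.exp(0.01241392083466053))  # ≈ 1.0125: near-total memorization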
3. Inference test
model.eval()
text = "孤山"
tokenized_text = tokenizer(text, return_tensors="pt")
# The BERT-style tokenizer appends a trailing [SEP]; drop it so generation
# continues from the prompt rather than from an end-of-sequence marker.
tokenized_text = {k: v[:, :-1] for k, v in tokenized_text.items()}
# trainer.train() left the model on the GPU, so move the inputs there too.
tokenized_text = {k: v.cuda() for k, v in tokenized_text.items()}
# Greedy decoding. Note: 50256 is the English GPT-2 eos id and is out of
# range for this ~21128-token vocabulary; use the tokenizer's own pad id.
out = model.generate(**tokenized_text, max_new_tokens=50, num_beams=1, do_sample=False, pad_token_id=tokenizer.pad_token_id)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
['孤 山 是 被 家 庭 西 , 睡 眠 出 停 运 较 低 , 几 处 早 莺 争 暖 树 , 谁 家 新 燕 啄 春 泥 , 乱 花 渐 欲 迷 人 眼 , 浅 草 才 能 没 马 蹄 , 最 爱 湖 东']
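The model replays the first training line almost verbatim, confirming the overfit. The spaces between characters come from the BERT-style tokenizer decoding character by character; they are easy to strip for readable output (a small sketch reusing the out tensor from above):

clean = tokenizer.batch_decode(out, skip_special_tokens=True)[0].replace(" ", "")
print(clean)  # 孤山是被家庭西,睡眠出停运较低,...

With do_sample=True the continuation would vary between runs instead of deterministically replaying the memorized line.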