AI写诗--基于GPT2预训练模型

目录

 1  处理数据

 1.1  加载预训练的分词器¶

2  自定义创建数据集 

 2.1  创建dataset

2.2  自定义collate_fn(数据批量输出的方法) 

 2.3  创建数据加载器 

3  创建模型 

4  训练过程代码 

 5  保存训练好的模型

 6  加载保存好的模型

 7  测试预测阶段代码


 

#目前,NLP与CV主要使用transformers库

#框架:主要使用PyTorch

#NLP任务的大体流程:
#处理数据: 中文字符 ---> 数字
#创建数据集。 把处理好的数据变成PyTorch的数据集
#生成模型, 一般使用transformers库,不需要自己建模
#训练预测过程

#配置代理
# import os

# os.environ['http_proxy'] = '127.0.0.1:10809'
# os.environ['https_proxy'] = '127.0.0.1:10809'
#这里是本地加载预训练模型,不需要

 1  处理数据

 1.1  加载预训练的分词器

from transformers import AutoTokenizer   #AutoTokenizer分词器 可以使中文字符转变成数字


#我这里是本地加载的模型文件
tokenizer = AutoTokenizer.from_pretrained('../data/model/gpt2-chinese-cluecorpussmall/')
print(tokenizer)

 

BertTokenizerFast(name_or_path='../data/model/gpt2-chinese-cluecorpussmall/', vocab_size=21128, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
#编码分词试算
text = [ '明朝驿使发,一夜絮征袍.素手抽针冷,那堪把剪刀.裁缝寄远道,几日到临洮.',
         '长安一片月,万户捣衣声.秋风吹不尽,总是玉关情.何日平胡虏,良人罢远征.']
#输出结果为一个字典,包含'input_ids'、'token_type_ids'、'attention_mask'
tokenizer.batch_encode_plus(text)
{'input_ids': [[101, 3209, 3308, 7731, 886, 1355, 117, 671, 1915, 5185, 2519, 6151, 119, 5162, 2797, 2853, 7151, 1107, 117, 6929, 1838, 2828, 1198, 1143, 119, 6161, 5361, 2164, 6823, 6887, 117, 1126, 3189, 1168, 707, 3826, 119, 102], [101, 7270, 2128, 671, 4275, 3299, 117, 674, 2787, 2941, 6132, 1898, 119, 4904, 7599, 1430, 679, 2226, 117, 2600, 3221, 4373, 1068, 2658, 119, 862, 3189, 2398, 5529, 5989, 117, 5679, 782, 5387, 6823, 2519, 119, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

2  自定义创建数据集 

 2.1  创建dataset

import torch


class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        #从本地读取数据
        with open('../data/datasets/chinese_poems.txt', encoding='utf-8') as f:
            lines = f.readlines()  #读取的每一行数据都会以一个字符串的形式 依次添加到一个列表中
            
        #split()函数可以根据指定的分隔符将字符串拆分成多个子字符串,并将这些子字符串存储在一个列表中。
        #strip()函数默认移除字符串两端的空白字符(包括空格、制表符、换行符等)
        
        lines = [line.strip() for line in lines]  #输出的lines是一个一维列表,里面的每一行诗都是一个字符串
        #self.的变量在类里面可以调用
        self.lines = lines   #self.lines是一个列表,里面的元素都是一个个字符串
        
        
    def __len__(self):
        return len(self.lines)
    
    def __getitem__(self, i):
        """可以向列表一样通过索引来获取数据"""
        return self.lines[i]

#试跑一下
dataset = Dataset()        
len(dataset), dataset[0]        
        

 

(304752, '欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月.')

dataset数据集只能一条一条数据的输出,不能一批批数据传输,
需要将datatset变成pytorch中dataloader的数据形式,将数据可以批量输出 

2.2  自定义collate_fn(数据批量输出的方法) 

def collate_fn(batch):
    #使用分词器 把中文编码成数字
    #tokenizer分词器的输出结果data是一个字典,包含'input_ids'、'token_type_ids'、'attention_mask'
    data = tokenizer.batch_encode_plus(batch, 
                                padding=True,
                                truncation=True,
                                max_length=512,
                                return_tensors='pt')
    #向字典data中添加数据标签目标值labels, 用data原数据中的['input_ids']诗句文字编码来赋值,
    #克隆一份对原数据无影响
    data['labels'] = data['input_ids'].clone()
    return data

 2.3  创建数据加载器 

loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=4, 
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)
#dataloader不能直接访问数据,需要for循环来获取数据
#查看第一批数据
for i, data in enumerate(loader):
    break  #只循环一次
i
data  #data是一个字典, 包含'input_ids'、'token_type_ids'、'attention_mask'、'labels'

 

{'input_ids': tensor([[ 101, 2708, 4324, 2406,  782, 1777, 1905, 3918,  117, 7345, 5125, 7346,
         7790, 6387, 4685, 2192,  119,  738, 4761, 5632, 3300, 1921, 1045, 1762,
          117, 6475,  955,  865, 6778, 4212, 2769, 1412,  119,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101, 1921, 4495,  671, 4954,  117, 5966, 1434, 3369, 7755,  119, 7755,
         3323, 2768, 1759,  117, 1759, 5543, 4495, 4289,  119, 5310,  702, 5872,
         5701,  117, 2899, 6627, 2336, 1880,  119, 3719, 5564, 6762, 1726,  117,
         6631,  676,  686,  867,  119,  102,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101,  753, 2399, 3736,  677, 6224, 3217, 2495,  117, 7564, 2682, 7028,
         3341, 2769, 3313, 1726,  119, 3922, 6862,  686, 7313, 6443, 3160, 2533,
          117, 4856, 2418, 4685, 6878,  684, 6124, 3344,  119,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101, 3926, 7599,  711, 2769, 6843, 2495, 5670,  117, 3144, 5108, 7471,
         3351, 6629, 5946, 4170,  119, 2359, 2512, 2661, 7607, 4904, 3717,  100,
          117, 3587, 1898, 3009, 3171, 1911, 7345, 6068,  119, 1126,  782, 2157,
         1762, 3983, 1928, 2279,  117,  671, 4275,  756, 4495, 3717, 2419, 1921,
          119, 4007, 4706, 5679, 3301, 3187, 1962, 6983,  117, 3634, 2552, 2347,
         2899,  736, 3736, 6804,  119,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ 101, 2708, 4324, 2406,  782, 1777, 1905, 3918,  117, 7345, 5125, 7346,
         7790, 6387, 4685, 2192,  119,  738, 4761, 5632, 3300, 1921, 1045, 1762,
          117, 6475,  955,  865, 6778, 4212, 2769, 1412,  119,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101, 1921, 4495,  671, 4954,  117, 5966, 1434, 3369, 7755,  119, 7755,
         3323, 2768, 1759,  117, 1759, 5543, 4495, 4289,  119, 5310,  702, 5872,
         5701,  117, 2899, 6627, 2336, 1880,  119, 3719, 5564, 6762, 1726,  117,
         6631,  676,  686,  867,  119,  102,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101,  753, 2399, 3736,  677, 6224, 3217, 2495,  117, 7564, 2682, 7028,
         3341, 2769, 3313, 1726,  119, 3922, 6862,  686, 7313, 6443, 3160, 2533,
          117, 4856, 2418, 4685, 6878,  684, 6124, 3344,  119,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0],
        [ 101, 3926, 7599,  711, 2769, 6843, 2495, 5670,  117, 3144, 5108, 7471,
         3351, 6629, 5946, 4170,  119, 2359, 2512, 2661, 7607, 4904, 3717,  100,
          117, 3587, 1898, 3009, 3171, 1911, 7345, 6068,  119, 1126,  782, 2157,
         1762, 3983, 1928, 2279,  117,  671, 4275,  756, 4495, 3717, 2419, 1921,
          119, 4007, 4706, 5679, 3301, 3187, 1962, 6983,  117, 3634, 2552, 2347,
         2899,  736, 3736, 6804,  119,  102]])}

3  创建模型 

#LM:语言模型
#AutoModelForCausalLM 语言模型的加载器
# from transformers import AutoModelForCausalLM, GPT2Model
from transformers import AutoModelForCausalLM
#加载模型
model = AutoModelForCausalLM.from_pretrained('../data/model/gpt2-chinese-cluecorpussmall/')


#查看加载的预训练模型的参数量
print(sum(p.numel() for p in model.parameters()))
102068736
#试算预测一下
with torch.no_grad():  #模型预测时,参数不需要梯度下降
    #outs是一个元组,包含'loss'(损失)和'logits'(概率)
    outs = model(**data)   


outs['logits'].shape
#4:batch_size
#197:每个句子的序列长度
#21128:每个词对应的21128(vocab_size)个词概率

 

torch.Size([4, 66, 21128])
outs['loss'], outs['logits']

 

(tensor(8.5514),
 tensor([[[ -9.9143,  -9.7647,  -9.8217,  ...,  -9.6961,  -9.7799,  -9.6771],
          [ -7.4731,  -8.7423,  -8.4802,  ...,  -8.2767,  -8.6411,  -9.1488],
          [ -8.7324,  -9.3639,  -9.3685,  ...,  -9.7467,  -9.2594,  -9.9237],
          ...,
          [ -3.6951,  -3.9939,  -4.2000,  ...,  -4.2021,  -4.6660,  -4.4627],
          [ -3.7271,  -4.0562,  -4.2753,  ...,  -4.2301,  -4.7670,  -4.5282],
          [ -3.6152,  -3.9949,  -4.1994,  ...,  -4.1643,  -4.6812,  -4.4797]],
 
         [[ -9.9143,  -9.7647,  -9.8217,  ...,  -9.6961,  -9.7799,  -9.6771],
          [ -8.5889,  -9.2279,  -9.2168,  ...,  -8.6957,  -8.1567,  -8.5526],
          [ -8.8908,  -8.8825,  -8.7488,  ...,  -9.8976,  -9.4964, -10.1446],
          ...,
          [ -3.8280,  -3.7346,  -4.4447,  ...,  -3.8380,  -4.3585,  -4.2275],
          [ -4.0099,  -3.8985,  -4.6581,  ...,  -3.9868,  -4.5486,  -4.3698],
          [ -3.8161,  -3.7165,  -4.4473,  ...,  -3.8579,  -4.3969,  -4.2764]],
 
         [[ -9.9143,  -9.7647,  -9.8217,  ...,  -9.6961,  -9.7799,  -9.6771],
          [ -7.7595,  -8.7731,  -8.8029,  ...,  -9.2167,  -8.4741,  -8.4485],
          [ -9.1754,  -8.8637,  -9.1363,  ...,  -8.7321,  -8.7189,  -8.9582],
          ...,
          [ -3.7426,  -4.1014,  -4.2192,  ...,  -4.3925,  -4.5313,  -4.6184],
          [ -3.8279,  -4.2058,  -4.3173,  ...,  -4.4447,  -4.6614,  -4.6665],
          [ -3.7733,  -4.1448,  -4.2570,  ...,  -4.4249,  -4.6132,  -4.6397]],
 
         [[ -9.9143,  -9.7647,  -9.8217,  ...,  -9.6961,  -9.7799,  -9.6771],
          [ -6.8225,  -7.6599,  -7.4913,  ...,  -7.5897,  -7.4440,  -7.5681],
          [ -7.3068,  -7.6038,  -7.2369,  ...,  -7.8313,  -8.0071,  -7.8388],
          ...,
          [ -5.6309,  -5.6956,  -5.5271,  ...,  -5.4339,  -5.3443,  -5.6756],
          [ -6.4130,  -6.3038,  -6.4816,  ...,  -6.5781,  -6.2063,  -6.4139],
          [ -3.6458,  -4.0801,  -3.7062,  ...,  -4.2418,  -3.9411,  -4.0311]]]))

4  训练过程代码 

from transformers import AdamW
from transformers.optimization import get_scheduler  #学习率的衰减策略


#训练
def train():
    #model是在此函数外部创建的,在此函数内调用前,需要声明model是全局变量
    global model
    #设置设备
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    #将模型传到设备上
    model = model.to(device)
    
    #创建梯度下降的优化器
    optimizer = AdamW(model.parameters(), lr=5e-5)      #lr=0.00005,  -5表示有5位小数
    #创建学习率衰减计划
    scheduler = get_scheduler(name='linear',   #线性的
                              num_warmup_steps=0,  #学习率从一开始就开始衰减,没有预热缓冲期
                              num_training_steps=len(loader),  #loader中有多少批数据就训练多少次
                              optimizer=optimizer)
    model.train()
    for i, data in enumerate(loader):
        for k in data.key():
            #将字典data中每个key所对应的value都传到设备上,再赋值给data[k],相当于把data传到了设备上
            data[k] = data[k].to(device)
        #将设备上的data传入模型中,获取输出结果outs(一个字典,包含loss和logits(概率分布))
        outs = model(**data)  #data是一个字典, **data将字典解包成关键字参数传入
        
        #从outs中获取损失,在训练过程中观察loss是不是在下降,不下降就是不正常
        loss = outs['loss']
        #反向传播
        loss.backward()
        #为了梯度下降的稳定性,防止梯度太大,进行梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters, 1.0)  #公式中的c最大值就是1
        #梯度更新
        optimizer.step()
        scheduler.step()
        #梯度清零
        optimizer.zero_grad()
        model.zero_grad()
        
        
        if i % 1000 == 0:  #每1000个步数,就输出打印内容
            #下一句诗句是上一句的预测目标真实值,有一个偏移量
            labels = data['labels'][:, 1:]
            #预测值
            outs = outs['logits'].argmax(dim=2)[:, :-1]
            
            #筛选条件
            select = labels != 0   #0是补得pad没有意义,需要筛掉
            #分别对labels和outs进行筛选
            labels = labels[select]
            outs = outs[select]
            del select  #后面这个变量没有用了, 删除防止占用过多内存
            
            #计算准确率
            #labels.numel()  求labels内元素的总个数
            #.item() 在pytorch中,取出tensor标量的数值
            cccuracy = (labels == outs).sum().item() / labels.numel()  
            #取出学习率
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), lr, accuracy)



train()
    

 5  保存训练好的模型

#保存训练好的模型
# model = model.to('cpu')  #将模型传到设备上
# torch.save(model, 'model.pt')

 6  加载保存好的模型

# model_2 = torch.laod('../data/model/AI-Poem-save.model')

 7  测试预测阶段代码

def generate(text, row, col, model):
    """
    text:传入的数据
    row, col:预测的诗句是几行几列的
    model:使用的是哪个模型来预测
    """
    def generate_loop(data):
        """循环来预测"""
        #模型预测时,不需要求导来反向传播
        with torch.no_grad():
            outs = model(**data)
            
        #从outs中获取分类概率,   输出形状与输入形状一致,所以batch_size在后面
        # outs形状 [5(五言诗,序列长度), batch_size, vocab_size]
        outs = outs['logits']
        #outs形状 [5(五言诗,序列长度), vocab_size]
        #只取一个元素会把对应的维度降调
        outs = outs[:, -1]  #最后一个是预测值

        #写诗:预测概率最高的词不一定是最合适的
        #取出概率较高的前50个
        #[5, vocab_size]  --> [5, 50]
        topk_value = torch.topk(outs, 50).values   #按从小到大排序的
        #取最后一个就是概率最大的那一个
        #[5, 50] --> [5] ,升维--> [5, 1]
        topk_value = topk_value[:, -1].unsqueeze(dim=1)

        #赋值    # -float('inf')负无穷大 ,表示没有意义
        outs = outs.masked_fill(outs < topk_value, -float('inf')) 

        #不允许写特殊字符, 将其赋值为负无穷大
        outs[:, tokenizer.sep_token_id] = -float('inf')  #分隔符
        outs[:, tokenizer.unk_token_id] = -float('inf')  #未知字符
        outs[:, tokenizer.pad_token_id] = -float('inf')  #填充pad

        for i in ',。':
            outs[:, tokenizer.get_vocab()[i]] = -float('inf')

        #根据概率做一个无放回的采样:不会出现重复的数据
        #[5, vocab_size] ---> [5, 1]
        outs = outs.softmax(dim=1)
        outs = outs.multinomial(num_sample=1)  #从中取一个

        #强制添加标点
        c = data['input_ids'].shape[1] / (col + 1)
        #若c为整数
        if c % 1 == 0:
            #若为偶数行
            if c % 2 == 0:
                outs[:, 0] = tokenizer.get_vocab()['。']
            else:
                outs[:, 0] = tokenizer.get_vocab()[',']

        #将原始的输入数据和预测的结果拼到一起, 当做下一次预测的输入, 依次循环
        data['input_ids'] = torch.cat([data['input_ids'], outs], dim=1)
        data['attention_mask'] = torch.ones_like(data['input_ids'])
        data['token_type_ids'] = torch.zeros_like(data['input_ids'])
        data['labels'] = data['input_ids'].clone()

        # row * col + 1   : 总字数+标点符号
        if data['input_ids'].shape[1] >= row * col + 1:
            return data
        return generate_loop(data)

    #重复三遍:一次生成三首,一次生成的效果可能不太好
    data = tokenizer.batch_encode_plus([text]*3, return_tensors='pt')
    data['input_ids'] = data['input_ids'][:, :-1]  #最后一个不要
    data['attention_mask'] = torch.ones_like(data['input_ids'])
    data['token_type_ids'] = torch.zeros_like(data['input_ids'])
    data['labels'] = data['input_ids'].clone()

    data = generate_loop(data)

    for i in range(3):
        #一次生成三首,按索引打印输出其中一首
        print(i, tokenizer.decode(data['input_ids'][i]))
model_2 = torch.load('../data//model/AI-Poem-save.model')

generate('秋高气爽', row=4, col=7, model=model_2)

0 [CLS] 秋 高 气 爽 雁 初 飞 , 云 树 高 峰 落 叶 稀 。 人 尽 夜 归 山 外 宿 , 鸡 鸣 霜 月 下 寒 衣 。
1 [CLS] 秋 高 气 爽 木 生 秋 , 何 处 仙 方 未 可 求 。 莫 遣 夜 猿 催 老 去 , 东 风 吹 老 上 林 丘 。
2 [CLS] 秋 高 气 爽 早 蝉 喧 , 清 籁 无 声 响 自 喧 。 野 望 岂 容 云 梦 见 , 江 涵 应 属 月 华 昏 。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值