【基础模型】开始构建我们自己的大语言模型3:训练我们的模型(内附免费完整训练资源)

在前两章中,我们介绍了如何准备数据集以及搭建基本的语言模型架构。现在,我们进入到模型的训练阶段(这一章很简单)

训练资源

提供鲁迅作品数据集:数据集文件源URL
和MPTS数据集:数据集文件2源URL
(mpts数据集虽然是json格式,但可以不用预处理,直接将带有json语法的数据集拿去训练也行)

代码概览

本篇博文将通过分析提供的代码片段来介绍如何训练一个语言模型。该代码包含了一些实用的功能,比如损失记录、模型检查点保存、实时绘图展示训练过程中的损失变化等。

准备工作

确保你已经完成了以下准备工作:

  • 数据集已经被正确地处理和准备好。
  • 模型架构已经被定义好。
  • 导入了必要的库(例如 matplotlibtensorflow)。
训练循环

训练循环是模型训练的核心部分。在我们的代码中,训练循环被组织在一个简单的 for 循环中,它遍历每个 epoch,并在每个 epoch 中遍历数据集的每一个 batch。

for epoch in range(ste, EPOCHS):
    for ii, dd in enumerate(dataset):
        # ...

在这个循环中,我们首先检查是否需要跳过某些 batch,然后获取当前 batch 的数据并预处理为模型所需的输入格式。接下来,我们调用模型的 train_on_batch 方法来进行一次前向传播和反向传播,更新模型权重以最小化损失函数。

input_sequences, target_sequences = preprocess_data_for_training(dd)#这个函数上一张定义了
loss = model.train_on_batch(input_sequences, target_sequences)#训练1个batch的数据
监控训练进度

为了监控训练过程,我们记录每次训练迭代的损失值,并将其可视化显示出来。这有助于我们了解模型的学习情况。

vli.append(loss)
plt.plot(range(1, len(vli) + 1), vli)
plt.show(block=False)
plt.pause(0.05)
模型检查点与保存

在训练过程中,我们通常会定期保存模型的权重,以便在训练中断或完成时能够恢复或评估模型。此外,我们还允许用户通过文件指示来控制模型的保存。

if ii % save_every == 0:
    model.save_weights(ncheckpoint_prefix)
elif 'true' in s or '1' in s:
    model.save_weights(ncheckpoint_prefix)
测试模型

在训练过程中,我们还可以定期测试模型的生成能力,这对于验证模型的有效性非常有帮助。

if 'true' in sp or '1' in sp:
    m, d = test.load(ckpt_dir=ncheckpoint_dir, model_type=mt)
    ret = test.generate_texts_fast(m, d, yw, num_generate=16, ret_ori=False)
    print('原文:', yw, '\n\n测试续写:', repr(ret))
总结

在本篇博文中,我们介绍了如何设置训练循环来训练一个语言模型,包括监控训练进度、保存模型检查点、以及进行中间测试。这些步骤对于训练任何深度学习模型都是非常重要的,希望这篇博文能帮助你更好地理解和实践模型训练的过程。

下一步

接下来,你可以尝试运行这段代码并观察模型的训练过程。调整超参数如学习率、批次大小等,看看它们如何影响模型的表现。此外,还可以尝试使用不同的数据集或模型架构来进一步提高模型的性能(详见下几章)。

基础模型1-3的完整代码(代码了一下,导入库的步骤和文件路径可以自己加),调用train函数就可以直接训练了:

训练函数以及后面的调用模型推理的函数会使用配置变量dic,格式如下:

  • dic={0.01:[int(256420.7),int(1024420.7),512],
    }
    • 解释:{mt(model_type):[embedding_dim, rnn_units]}
      【你可以把它写到conf.pkl里面,每次调用dic的时候用pickle反序列化一下。】

代码有问题可以私信联系我,临时调整的格式。我用的tf版本是2.10.1。

                                            
def train(
    mt=3,
    big_file=False,#是否采用大文件加载策略
    #数据集
    path_to_file = r'en_novel.txt',
    ntype_='_en',#保存为微调模型名称
    #设置vocab版本
    vtype_='_lx',#type_#
    fen=50,#数据量分几份
    fwidx=0,#第几份
    BATCH_SIZE = 64,
    loadmodel=False,
    pass_=-1,
    ste=0,
    ):'''
    多出的参数不必理会,后面会用到
    '''
    global LR,param_data,p_ntype
    p_ntype=ntype_
    


    if ntype_[0]!='_':ntype_='_'+ntype_
    type_=ntype_
    print('path_to_file',path_to_file)
    print('LR',LR)

    

    import os
    #dataset与vocab是配对的!
    if not os.path.exists(r'dataset/vocab'+vtype_+'.txt'):
        raise Exception("can't reading vocab from "+r'E:\小思框架\论文\ganskchat\vocab'+vtype_+'.txt')
    else:
        with open('dataset/vocab'+vtype_+'.txt','r',encoding='utf-8') as f:
            vocab=eval(f.read())

    UNK=0
    unkli=[]
    
    char2idx = {
   u:i for i, u in enumerate(vocab)}
    idx2char = np.array(vocab)

    print('{')
    for char,_ in zip(char2idx, range(20)):
        print('  {:4s}: {:3d},'.format(repr(char), char2idx[char]))
    print('  ...\n}')

    # 设定每个输入句子长度的最大值
    seq_length = dic[mt][2]
    
    def split_input_target(chunk):
        input_text = chunk[:
  • 17
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值