在前两章中,我们介绍了如何准备数据集以及搭建基本的语言模型架构。现在,我们进入到模型的训练阶段(这一章很简单)
训练资源
提供鲁迅作品数据集:数据集文件源URL
和MPTS数据集:数据集文件2源URL
(mpts数据集虽然是json格式,但可以不用预处理,直接将带有json语法的数据集拿去训练也行)
代码概览
本篇博文将通过分析提供的代码片段来介绍如何训练一个语言模型。该代码包含了一些实用的功能,比如损失记录、模型检查点保存、实时绘图展示训练过程中的损失变化等。
准备工作
确保你已经完成了以下准备工作:
- 数据集已经被正确地处理和准备好。
- 模型架构已经被定义好。
- 导入了必要的库(例如
matplotlib
和tensorflow
)。
训练循环
训练循环是模型训练的核心部分。在我们的代码中,训练循环被组织在一个简单的 for
循环中,它遍历每个 epoch,并在每个 epoch 中遍历数据集的每一个 batch。
for epoch in range(ste, EPOCHS):
for ii, dd in enumerate(dataset):
# ...
在这个循环中,我们首先检查是否需要跳过某些 batch,然后获取当前 batch 的数据并预处理为模型所需的输入格式。接下来,我们调用模型的 train_on_batch
方法来进行一次前向传播和反向传播,更新模型权重以最小化损失函数。
input_sequences, target_sequences = preprocess_data_for_training(dd)#这个函数上一张定义了
loss = model.train_on_batch(input_sequences, target_sequences)#训练1个batch的数据
监控训练进度
为了监控训练过程,我们记录每次训练迭代的损失值,并将其可视化显示出来。这有助于我们了解模型的学习情况。
vli.append(loss)
plt.plot(range(1, len(vli) + 1), vli)
plt.show(block=False)
plt.pause(0.05)
模型检查点与保存
在训练过程中,我们通常会定期保存模型的权重,以便在训练中断或完成时能够恢复或评估模型。此外,我们还允许用户通过文件指示来控制模型的保存。
if ii % save_every == 0:
model.save_weights(ncheckpoint_prefix)
elif 'true' in s or '1' in s:
model.save_weights(ncheckpoint_prefix)
测试模型
在训练过程中,我们还可以定期测试模型的生成能力,这对于验证模型的有效性非常有帮助。
if 'true' in sp or '1' in sp:
m, d = test.load(ckpt_dir=ncheckpoint_dir, model_type=mt)
ret = test.generate_texts_fast(m, d, yw, num_generate=16, ret_ori=False)
print('原文:', yw, '\n\n测试续写:', repr(ret))
总结
在本篇博文中,我们介绍了如何设置训练循环来训练一个语言模型,包括监控训练进度、保存模型检查点、以及进行中间测试。这些步骤对于训练任何深度学习模型都是非常重要的,希望这篇博文能帮助你更好地理解和实践模型训练的过程。
下一步
接下来,你可以尝试运行这段代码并观察模型的训练过程。调整超参数如学习率、批次大小等,看看它们如何影响模型的表现。此外,还可以尝试使用不同的数据集或模型架构来进一步提高模型的性能(详见下几章)。
基础模型1-3的完整代码(代码了一下,导入库的步骤和文件路径可以自己加),调用train函数就可以直接训练了:
训练函数以及后面的调用模型推理的函数会使用配置变量dic,格式如下:
- dic={0.01:[int(256420.7),int(1024420.7),512],
}- 解释:{mt(model_type):[embedding_dim, rnn_units]}
【你可以把它写到conf.pkl里面,每次调用dic的时候用pickle反序列化一下。】
- 解释:{mt(model_type):[embedding_dim, rnn_units]}
代码有问题可以私信联系我,临时调整的格式。我用的tf版本是2.10.1。
def train(
mt=3,
big_file=False,#是否采用大文件加载策略
#数据集
path_to_file = r'en_novel.txt',
ntype_='_en',#保存为微调模型名称
#设置vocab版本
vtype_='_lx',#type_#
fen=50,#数据量分几份
fwidx=0,#第几份
BATCH_SIZE = 64,
loadmodel=False,
pass_=-1,
ste=0,
):'''
多出的参数不必理会,后面会用到
'''
global LR,param_data,p_ntype
p_ntype=ntype_
if ntype_[0]!='_':ntype_='_'+ntype_
type_=ntype_
print('path_to_file',path_to_file)
print('LR',LR)
import os
#dataset与vocab是配对的!
if not os.path.exists(r'dataset/vocab'+vtype_+'.txt'):
raise Exception("can't reading vocab from "+r'E:\小思框架\论文\ganskchat\vocab'+vtype_+'.txt')
else:
with open('dataset/vocab'+vtype_+'.txt','r',encoding='utf-8') as f:
vocab=eval(f.read())
UNK=0
unkli=[]
char2idx = {
u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
print('{')
for char,_ in zip(char2idx, range(20)):
print(' {:4s}: {:3d},'.format(repr(char), char2idx[char]))
print(' ...\n}')
# 设定每个输入句子长度的最大值
seq_length = dic[mt][2]
def split_input_target(chunk):
input_text = chunk[: