许多人可能感到在新时代中有些不适应。面对新颖算法和高实验条件,似乎难以适应这个新世界。然而,有一种方法可以在“一无所有”的状态下开始探索。更新思维,勇敢成为“调包侠”,利用Hugging Face的Transformers等现代包,让人能够轻松上手。

相信很多人和我一样都是从“古典深度学习”时代一路走来的,面对当今“LLM才是AI“的时代,有着很多的不适应症状。看着那么多的新论文里的算法,想从头开始撸一遍发现要学习一大堆有的没的的东西,然后 paper 里提到的实验条件还吓人的高,似乎这个喧嚣的新世界和自己有点遥不可及。

那么,有没有什么办法能在“一无所有”的状态下做点啥呢?

首先,进入新时代了,要更新好自己的思想,放心大胆的当“调包侠”。而现在的确有很多很科学的包!

其中以 hugging face 的 transformers 为集大成者。实际上,从自娱自乐的 toy example,到一系列还挺有影响力的项目比如 Vicuna,LLAVA,翻开代码库,都能看到那行金光闪闪的

import transformers
  • 1.

甚至,翻遍整个repo,都找不到“网络结构”写在哪里,只有一行

from transformers import LlamaForCausalLM
  • 1.

在等着你。

而在 2024 年了,这些包的安装也没那么“阴间”,直接 pip install 回来的 transformers,accelerate,就是亲测能用的(当然,假设你已经把 pytorch、CUDA 的安装和修bug搞定了)。反正我在玩的时候真就是装了就能用,没啥玄学。

对于模型训练 ,官方文档并没有一个“最小集”的样本,不过对着文档琢磨一下还是很容易写出来的:

Transformer~66_数据

嗯,是的,一共就16行,配好两个对象,然后 Trainer.train() 就成了…… 我第一次用的时候没配 save_steps 导致跑完了不知道模型存哪了,查了一下文档才搞明白。

其中,dataset对象 就是 torch.utils.data.Dataset,要实现 __ len __ 和 __ getitem __ 的接口,这个自己搞搞就好了。

而要使用一个训练好的模型,直接

import transformers
model = transformers.GPT2LMHeadModel.from_pretrained(ifname)
  • 1.
  • 2.

都不用自己去手工维护模型的超参数列表,这个库在 checkpoint 文件夹里已经自己按约定存好了。

然后肯定有人会提出质疑了,这么直接搞真的严肃么?
不过,让我们回想一下为啥 CV 类的模型的训练代码都那么“复杂”,然后就能发现现在的这种搞法的“科学”之处了:

  • 数据集。CV 类模型往往要使用大量的压缩(jpg/h264)后的数据并进行大量的增广来进行训练,训练过程中总的 IO 吞吐是恐怖的。于是乎几乎人人都要会写分布式 data provider 框架,甚至还要自己搞专门的数据提供的加速办法(例如,多个实验共享data provider)。然而,在 NLP 领域,就算是大如 GPT level 的模型,也就几 T tokens 就能训明白了,弄个 SSD 本地存一下挺香的。。。最多是 tokenizer 之类的预处理部分折腾一下。因为每 token 的训练 flo 的巨大区别,在这个地方花费的精力是可以极大的节省的。
  • 调参。CNN 类的模型,好多地方调一调参还是挺有用的。而来到了 2024 年的 transformers,模型一方面很“简单”,大量的 transformer block 的堆叠,想乱改涨点还不好涨;另一方面又很“复杂 ”,面对着跑一次要花很多 $ 的基础模型调参实验(比如你想把gelu换回relu?),穷人也就不用去想着折腾了。开源世界调出来的那一组“默认参数”,其实还挺 work 的,于是乎就成了一个“非必要别乱调”的格局。
  • 表示。搞 Vision 的模型“如何表示输入输出”是一门大学问,改一点点影响巨大。而搞序列建模的时候……输入输出就是 tokens 的序列,torch.LongTensor,没了。。管你的 tokens 里是画了一棵树,建了一张图,还是存了一大段 json ,在 NN 这边看来,都(至少在观念上)是一样一样的。
  • 部署。“古典”AI模型的一大主题是如何在端上跑,所以就产生了一大堆的模型打包、静态化、计算图转化等等的 concern。不过在“新时代”,还远没有卷到这个地步,能 python 跑就还算可以了,这也就减少了大量为“训推一体”付出的精力。

transformers 里的 Trainer 集成了一些很科学的默认行为,例如定期存 checkpoint 放到实验名的文件夹下面,把各种曲线信息同步到 tensorboard、clearml、wandb 等监控软件,啥参数都不传也是可以接受的选择。

而如果你就是要调参,在 TrainingArguments、Trainer、GPT2Config 等地方一共有 138 个可以传的参数,以及大量通过调方法来填的参数,能满足不少的需要。比如,可以传一个 fp16 = True 来“一键”加速训练,而它背后是 apex.amp 这种库在支撑。

所以来说,把“靠谱”的库“集成”在一起,暴露出一组带有合理默认参数的接口,这个方法论在 2024 年来看的确还是可行的。当层层调用的这些库不“炸”的时候,使用体验相当的丝滑。

然而,如果“炸”了,或者想搞一些比较深度的 hack (比如定制某个算子在训练的时候触发一个神奇的行为),就会发现,在一行看似无辜的网络定义之上,还摞着层层叠叠的 加速库、优化库、分布式库、混合精度库等等,想去“一层层debug”就成了一件恐怖的事情。

好不容易(?)学会了序列建模这个技能(又称:transformers 库的安装与使用),不整点啥活似乎没意思。

于是,我用我的笔记本电脑(带3060显卡)训练了一个 GPT2-small 级别的模型来做 C++ 代码的补全。

我从 github 上找了一些经典的 c/c++ 的 repo(如,linux kernel,gcc,cpython,等),收集了 1G 多的 .zip,从里面找出所有的 c类源文件,形成了一个 2.2 GB 的训练集。

不做tokenization,23万步,长度 512,批大小 4,训练大概需要不到一天。
然后搞了一个贪心的补全推荐的方法,试了试,好像也不是完全不能用:

视频发不了...

训练出来的语言模型压缩率大概是 0.7 bits / byte,比 xz -9 还是要强一点的。

考虑到实现出这样的效果,从搜索库名字到实现出来只需要一天,这还是挺惊人的了。
当然,一旦想从这个 baseline 开始提升效果,那么 language model 里各种考量的的“洪流”就要来了:

  • 用 tokenizer 进行文本的预压缩,提升 token 的效率(window size是一种昂贵的资源),同时让网络不用去学习各种“显然”的东西。各种 tokenizer 以及训练加速的方法就可以选一选了
  • 清洗数据,减少类似训练数据被 linux kernel 里的 #define 淹没带来的副作用,提升“高质量”数据的占比
  • 更大的模型!找更多更好的卡!另外还要搞预训练
  • 最优的批大小、并行数、learning rate调整,对着 loss 曲线“围炉炼丹”
  • 把 suffix 预测改为 infix 预测,加入文件名、函数名等的各种 hint
  • 考虑用户反馈数据的利用,考虑加入用户问答,考虑针对用户的 finetune
  • 等等等等……

这就是一个很深很深的坑了,而 copilot 之类就是“登峰造极”之后的产物。想往这方面卷,就会一步步走入“去哪融资,去哪买卡,国产半导体何时崛起”的无尽焦虑中。

不过好的一面在于,如果我们不把自己当做一个 “训 LLM” 的人,而是当一个 “用 transfomers 库进行序列建模” 的,就会发现事情也没那么坏。虽然几十M的“小”(对 vision 来说其实不小了)模型并不会表现出“大”模型的一些独特能力,但是也已经可以做一些很有意义的事情了(我还真没试过用 CNN 硬怼上面的这个 demo 是什么效果……)。

刨开训练 transformer 类模型的一些新的技巧,当前这个时代其实最关键的问题就变成了:手头有啥好的问题,可以表达成一个序列建模吗?

如果有,或者原来的某些“老”问题可以这么表示,那不妨用 transformer 们来试一试,也许就有全新的可能。 

新的时代,就应该去拥抱新的方法。

最后,附代码补全demo的全套代码:训练

train.py

import transformers
import numpy as np
import torch
import os
import random
import hashlib

class CodeDataset(torch.utils.data.Dataset):
 def __init__(self, name, ctxlen, totalnum):
  super().__init__()
  self.name = name
  fname = './data/codes.txt'
  self.content_len = os.path.getsize(fname)
  self.fin = open(fname, 'rb')
  self.ctxlen = ctxlen
  self.length = totalnum

 def __len__(self):
  return self.length
 def __getitem__(self, index):
  key = int(hashlib.md5((self.name + str(index)).encode('utf8')).hexdigest()[-16:], 16)
  rnd = random.Random(key)
  while True:
   idx = rnd.randrange(self.content_len - self.ctxlen + 256)
   self.fin.seek(idx)
   buf = self.fin.read(self.ctxlen + 256)
   j = 1
   while j < 256 and buf[j - 1] != (b'\n')[0] and buf[j-1] != (b'\xff')[0]:
    j += 1
   if j >= 256:
    continue
   buf = buf[j:j + self.ctxlen - 1]
   if b'\xff' in buf[:-1]:
    continue
   break
  vec = np.frombuffer(b'\x00' + buf, dtype = 'uint8').astype('int64')
  vec = torch.from_numpy(vec)
  return {'input_ids' : vec, 'labels' : vec}

if __name__ == '__main__':
 import argparse
 parser = argparse.ArgumentParser(
  prog = 'train',
  description = 'train on dataset',
 )
 parser.add_argument('--n_embd', dest = 'n_embd', type = int, default = 768)
 parser.add_argument('--n_layer', dest = 'n_layer', type = int, default = 12)
 parser.add_argument('--n_head', dest = 'n_head', type = int, default = 12)
 parser.add_argument('--bsize', dest = 'bsize', type = int, default = 4)
 parser.add_argument('--ctxlen', dest = 'ctxlen', type = int, default = 512)
 parser.add_argument('--accstep', dest = 'accstep', type = int, default = 1)
 parser.add_argument('--compute', dest = 'compute', type = int, default = 10000)
 args = parser.parse_args()

 cfg = transformers.GPT2Config(vocab_size = 256, n_positions = args.ctxlen, n_ctx = args.ctxlen, n_embd = args.n_embd, n_layer = args.n_layer, n_head = args.n_head)
 model = transformers.GPT2LMHeadModel(cfg)
 nparam = sum([np.prod(i.shape) for i in model.parameters()])
 print('nparam', nparam)
 model = model.cuda()

 name = 'code_%dM_%dT_c%d_l%d_e%d_h%d_b%d_a%d'%(nparam // 1000000, args.compute, args.ctxlen, args.n_layer, args.n_embd, args.n_head, args.bsize, args.accstep)
 print(name)

 total_train_ops = 10 ** 12 * args.compute
 total_train_steps = total_train_ops // (nparam * args.bsize * args.accstep * args.ctxlen)
 print('steps', total_train_steps)

 train_ds = CodeDataset('train', args.ctxlen, 1048576)
 eval_ds = CodeDataset('valid', args.ctxlen, 1024)

 train_cfg = transformers.TrainingArguments(
  output_dir = name,
  num_train_epochs = total_train_steps * args.bsize * args.accstep / len(train_ds),
  save_steps = total_train_steps // 10,
  gradient_accumulation_steps = args.accstep,
  per_device_train_batch_size = args.bsize,
  per_device_eval_batch_size = args.bsize,
  evaluation_strategy = 'steps',
  eval_steps = (total_train_steps // 10),
  logging_strategy = 'steps',
  logging_steps = (total_train_steps // 100),
  report_to = 'none',
 )
 trainer = transformers.Trainer(
  model = model,
  args = train_cfg,
  train_dataset = train_ds,
  eval_dataset = eval_ds,
 )
 trainer.train()
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.
  • 87.
  • 88.
  • 89.
  • 90.

演示
demo.py

import transformers
import hashlib
import random
import torch
import numpy as np

def sugguest_from_model(model, prefix):
 with torch.no_grad():
  predicts = ''
  past_key_values = None
  cum_prob = 1
  best_len = 0
  best_sugguest = ''
  while len(predicts) < 32:
   if len(predicts) == 0:
    last_choice = np.int64([0] + [ord(i) for i in prefix])[None, :]
   ret = model.forward(
    input_ids = torch.from_numpy(last_choice),
    past_key_values = past_key_values,
    use_cache = True
   )
   past_key_values = ret.past_key_values
   logits = ret.logits.cpu().numpy()[0, -1]
   prob = logits - logits.max(axis=0, keepdims = True)
   prob = np.exp(prob)
   prob = prob / np.sum(prob, axis = 0, keepdims = True)
   cur_predict = prob.argmax()
   if cur_predict == 10:
    break
   last_choice = np.int64([[cur_predict]])
   cur_prob = prob[cur_predict]
   predicts = predicts + chr(cur_predict)
   cum_prob *= cur_prob
   if len(predicts) * cum_prob > best_len:
    best_len = len(predicts) * cum_prob
    best_sugguest = predicts
   if min(len(predicts) + 8, 32) * cum_prob < best_len:
    break

 return best_sugguest

if __name__ == '__main__':
 import getch
 import termios

 ifname = './code_85M_40000T_c512_l12_e768_h12_b4_a1/checkpoint-228040/'
 model = transformers.GPT2LMHeadModel.from_pretrained(ifname)

 def demo_type():
  print('loaded')
  attr = termios.tcgetattr(0)
  old_attr = attr[:]
  attr[3] = attr[3] & ~ (termios.ECHO | termios.ICANON)

  termios.tcsetattr(0, termios.TCSANOW, attr)

  lines = []
  cur_line = ''
  sugguest = ''
  try:
   while True:
    c = getch.getch()
    if len(sugguest) != 0:
     print(' '*len(sugguest) + '\b' * len(sugguest), end = '', flush = True)

    if ord(c) == 27:
     break
    if ord(c) == 127:
     if len(cur_line) > 0:
      if cur_line[-1] == '\t':
       print('\b'*8, end = '', flush = True)
      else:
       print('\b  \b', end = '', flush = True)
      cur_line = cur_line[:-1]
    elif c == '\n':
     print()
     lines.append(cur_line)
     cur_line = ''
    elif c == '\t' and (cur_line.replace('\t','') != ''):
     cur_line = cur_line + sugguest
     print(sugguest, end = '', flush = True)
    else:
     print(c, end = '', flush = True)
     cur_line += c

    def get_suggest():
     ctx = cur_line
     i = len(lines) - 1
     while i >= 0 and len(ctx) + len(lines[i]) + 1 < 200:
      ctx = lines[i] + '\n' + ctx
      i -= 1
     return sugguest_from_model(model, ctx)

    sugguest = get_suggest()
    if cur_line.replace('\t','') != '':
     pass
    elif cur_line == '' and ord(c) != 127 and len(lines) > 0 and lines[-1].startswith('\t'):
     tabs = ''
     for j in range(len(sugguest)):
      if sugguest[j] == '\t':
       tabs = tabs + sugguest[j]
      else:
       break
     cur_line = tabs
     print(tabs, end = '', flush = True)
     if len(tabs):
      sugguest = get_suggest()
     else:
      sugguest = ''
    else:
     sugguest = ''
    print('\033[0;38;5;248m\033[0;48;5;223m' + sugguest + '\b'*len(sugguest) + '\033[0;0m', end = '', flush = True)
  except KeyboardInterrupt:
   pass
  termios.tcsetattr(0, termios.TCSANOW, old_attr)
  if len(cur_line):
   lines.append(cur_line)
  print()
  for line in lines:
   print(line)
 demo_type()
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.
  • 87.
  • 88.
  • 89.
  • 90.
  • 91.
  • 92.
  • 93.
  • 94.
  • 95.
  • 96.
  • 97.
  • 98.
  • 99.
  • 100.
  • 101.
  • 102.
  • 103.
  • 104.
  • 105.
  • 106.
  • 107.
  • 108.
  • 109.
  • 110.
  • 111.
  • 112.
  • 113.
  • 114.
  • 115.
  • 116.
  • 117.
  • 118.
  • 119.
  • 120.
  • 121.