不久之前,特斯拉前 AI 总监、OpenAI 联合创始人 Andrej Karpathy 在社交媒体平台 X 上高调宣布,他友好地从 OpenAI 离职,未来将专注于“个人项目”,而后又口口声声说要尝试性地戒掉上网两周,去了 Bhutan(不丹王国)休假。然后就手搓如下代码,我们一起来欣赏.
这段代码是一个GPT-2模型的训练实现,包括前向传播、后向传播、参数更新等过程。代码中包含了注释,解释了每个函数和代码块的作用。此外,还有一些辅助函数和结构体,用于数据加载、随机数生成和模型的初始化与释放。这个训练循环还包含了验证损失的计算和文本生成的示例,以便在训练过程中检查模型的性能。
/*
这个文件训练了GPT-2模型。
这个版本是干净、最小、参考版本。因此:
- 它在CPU上运行。
- 它不会使代码过于复杂;它是可读的。
- 它不使用任何特定于处理器的指令、内联等。
- 它确实使用了一些OpenMP指令,因为这是在非常低的成本下获得巨大加速的一种方式
将会有其他版本的代码专门化并使其快速。
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <string.h>
#include <unistd.h>
#ifdef OMP
#include <omp.h>
#endif
// ------------------------------------------------------------------------------
// 所有单独层的前向和后向传递
// 编码器前向传递
void encoder_forward(float* out, // 输出
int* inp, float* wte, float* wpe, // 输入和位置编码权重
int B, int T, int C) {
// 批次大小、序列长度、通道数
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// 寻找输出位置
float* out_bt = out + b * T * C + t * C;
// 获取输入令牌的索引
int ix = inp[b * T + t];
// 寻找对应于令牌的wte中的位置
float* wte_ix = wte + ix * C;
// 寻找对应于位置的wpe中的位置
float* wpe_t = wpe + t * C;
// 将两个向量相加并将结果存储在out[b,t,:]
for (int i = 0; i < C; i++) {
out_bt[i] = wte_ix[i] + wpe_t[i];
}
}
}
}
// 编码器后向传递
void encoder_backward(float* dwte, float* dwpe, // 编码器权重的梯度
float* dout, int* inp, // 梯度输出和输入
int B, int T, int C) {
// 批次大小、序列长度、通道数
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* dout_bt = dout + b * T * C + t * C;
int ix = inp[b * T + t];
float* dwte_ix = dwte + ix * C;
float* dwpe_t = dwpe + t * C;
for (int i = 0; i < C; i++) {
float d = dout_bt[i];
dwte_ix[i] += d;
dwpe_t[i] += d;
}
}
}
}
// 层归一化前向传递
void layernorm_forward(float* out, float* mean, float* rstd, // 输出、均值和标准差
float* inp, float* weight, float* bias, // 输入、权重和偏置
int B, int T, int C) {
// 批次大小、序列长度、通道数
float eps = 1e-5f;
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// 寻找输入位置
float* x = inp + b * T * C + t * C;
// 计算均值
float m = 0.0f;
for (int i = 0; i < C; i++) {
m += x[i];
}
m = m / C;
// 计算方差(不带任何偏置校正)
float v = 0.0f;
for (int i = 0; i < C; i++) {
float xshift = x[i] - m;
v += xshift * xshift;
}
v = v / C;
// 计算标准差
float s = 1.0f / sqrtf(v + eps);
// 寻找输出位置
float* out_bt = out + b * T * C + t * C;
for (int i = 0; i < C; i++) {
float n = (s * (x[i] - m)); // 归一化输出
float o = n * weight[i] + bias[i]; // 缩放和平移它
out_bt[i] = o; // 写入
}
// 缓存均值和标准差以供后向传递使用
mean[b * T + t] = m;
rstd[b * T + t] = s;
}
}
}
// 层归一化后向传递
void layernorm_backward(float* dinp, float* dweight, float* dbias, // 输入、权重和偏置的梯度
float* dout, float* inp, float* weight, float* mean, float* rstd, // 梯度输出、输入、权重、均值和标准差
int B, int T, int C) {
// 批次大小、序列长度、通道数
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* dout_bt = dout + b * T * C + t * C;
float* inp_bt = inp + b * T * C + t * C;
float* dinp_bt = dinp + b * T * C + t * C;
float mean_bt = mean[b * T + t];
float rstd_bt = rstd[b * T + t];
// 首先:两个reduce操作
float dnorm_mean = 0.0f;
float dnorm_norm_mean = 0.0f;
for (int i = 0; i < C; i++) {
float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
float dnorm_i = weight[i] * dout_bt[i];
dnorm_mean += dnorm_i;
dnorm_norm_mean += dnorm_i * norm_bti;
}
dnorm_mean = dnorm_mean / C;
dnorm_norm_mean = dnorm_norm_mean / C;
// 现在再次迭代并累积所有梯度
for (int i = 0; i < C; i++) {
float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
float dnorm_i = weight[i] * dout_bt[i];
// 梯度对偏置的贡献
dbias[i] += dout_bt[i];
// 梯度对权重的贡献
dweight[i] += norm_bti * dout_bt[i];
// 梯度对输入的贡献
float dval = 0.0f;
dval += dnorm_i; // 项1
dval -= dnorm_mean; // 项2
dval -= norm_bti * dnorm_norm_mean; // 项3
dval *= rstd_bt; // 最终缩放
dinp_bt[i] += dval;
}
}
}
}
// 矩阵乘法前向传递
void matmul_forward(float* out, // 输出
float* inp, float* weight, float* bias, // 输入、权重和偏置
int B, int T, int C, int OC) {
// 批次大小、序列长度、通道数、输出通道数
// 大部分运行时间都在这里和matmul_backward中度过
// OC是"输出通道"的缩写
// inp是(B,T,C),weight是(OC, C),bias是(OC)
// out将是(B,T,OC)
#pragma omp parallel for collapse(2)
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* out_bt = out + b * T * OC + t * OC;
float* inp_bt = inp + b * T * C + t * C;
for (int o = 0; o < OC; o++) {
float val = (bias != NULL) ? bias[o] : 0.0f;
float* wrow = weight + o*C;
for (int i = 0; i < C; i++) {
val += inp_bt[i] * wrow[i];
}
out_bt[o] = val;
}
}
}
}
// 矩阵乘法后向传递
void matmul_backward(float* dinp, float* dweight, float* dbias, // 输入、权重和偏置的梯度
float* dout, float* inp, float* weight, // 梯度输出、输入、权重
int B, int T, int C, int OC) {
// 批次大小、序列长度、通道数、输出通道数
// 大部分运行时间都在这里和matmul_forward中度过
// 这个后向可以通过一个"轮次"的循环来完成
// 但这样的并行化策略并不高效
// 首先向后传递到inp,按B,T并行化
#pragma omp parallel for collapse(2)
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* dout_bt = dout + b * T * OC + t * OC;
float* dinp_bt = dinp + b * T * C + t * C;
for (int o = 0; o < OC; o++) {
float* wrow = weight + o*C;
float d = dout_bt[o];
for (int i = 0; i < C; i++) {
dinp_bt[i] += wrow[i] * d;
}
}
}
}
// 向后传递到weight/bias,按输出通道OC并行化
#pragma omp parallel for
for (int o = 0; o < OC; o++) {
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* dout_bt = dout + b * T * OC + t * OC;
float* inp_bt = inp + b * T * C + t * C;
float* dwrow = dweight + o*C;
float d = dout_bt[o];
if (dbias != NULL) {
dbias[o] += d; }
for (int i = 0; i < C; i++) {
dwrow[i] += inp_bt[i] * d;
}
}
}
}
}
// 注意力前向传递
void attention_forward(float* out, float* preatt, float* att, // 输出、前馈注意力和注意力矩阵
float* inp,
int B, int T, int C, int NH) {
// 批次大小、序列长度、通道数、头数
// 输入是(B, T, 3C) Q,K,V
// preatt, att是(B, NH, T, T)
// 输出是(B, T, C)
int C3 = C*3;
int hs = C / NH; // 头大小
float scale = 1.0 / sqrtf(hs);
#pragma omp parallel for collapse(3)
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
for (int h = 0; h < NH; h++) {
float* query_t = inp + b * T * C3 + t * C3 + h * hs;
float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
float* att_bth = att + b*NH*T*T + h*T*T + t*T;
// 步骤1:计算query dot key并找到最大值
float maxval = -10000.0f; // TODO 使用更好的方法
for (int t2 = 0; t2 <= t; t2++) {
float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C因为是key
// (query_t) dot (key_t2)
float val = 0.0f;
for (int i = 0; i < hs; i++) {
val += query_t[i] * key_t2[i];
}
val *= scale;
if (val > maxval) {
maxval = val;
}
preatt_bth[t2] = val;
}
// 步骤2:计算exp并跟踪总和
float expsum = 0.0f;
for (int t2 = 0; t2 <= t; t2++) {
float expv = expf(preatt_bth[t2] - maxval);
expsum += expv;
att_bth[t2] = expv;
}
float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;
// 步骤3:归一化得到softmax
for (int t2 = 0; t2 < T; t2++) {
if (t2 <= t) {
att_bth[t2] *= expsum_inv;
} else {
// 因果注意力掩码。这里设置为零并不严格必要
// 只是为了调试和检查PyTorch
att_bth[t2] = 0.0f;
}
}
// 步骤4:累积加权值到注意力的输出
float* out_bth = out + b * T * C + t * C + h * hs;
for (int i = 0; i < hs; i++) {
out_bth[i] = 0.0f; }
for (int t2 = 0; t2 <= t; t2++) {
float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2因为是value
float att_btht2 = att_bth[t2];
for (int i = 0; i < hs; i++) {