openai联合创始人用1000行纯C语言手搓的GPT-2训练代码

不久之前,特斯拉前 AI 总监、OpenAI 联合创始人 Andrej Karpathy 在社交媒体平台 X 上高调宣布,他友好地从 OpenAI 离职,未来将专注于“个人项目”,而后又口口声声说要尝试性地戒掉上网两周,去了 Bhutan(不丹王国)休假。然后就手搓如下代码,我们一起来欣赏.

这段代码是一个GPT-2模型的训练实现,包括前向传播、后向传播、参数更新等过程。代码中包含了注释,解释了每个函数和代码块的作用。此外,还有一些辅助函数和结构体,用于数据加载、随机数生成和模型的初始化与释放。这个训练循环还包含了验证损失的计算和文本生成的示例,以便在训练过程中检查模型的性能。

/*
这个文件训练了GPT-2模型。
这个版本是干净、最小、参考版本。因此:
- 它在CPU上运行。
- 它不会使代码过于复杂;它是可读的。
- 它不使用任何特定于处理器的指令、内联等。
- 它确实使用了一些OpenMP指令,因为这是在非常低的成本下获得巨大加速的一种方式
将会有其他版本的代码专门化并使其快速。
*/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <string.h>
#include <unistd.h>
#ifdef OMP
#include <omp.h>
#endif

// ------------------------------------------------------------------------------
// 所有单独层的前向和后向传递

// 编码器前向传递
void encoder_forward(float* out, // 输出
                   int* inp, float* wte, float* wpe, // 输入和位置编码权重
                   int B, int T, int C) {
    // 批次大小、序列长度、通道数
    for (int b = 0; b < B; b++) {
   
        for (int t = 0; t < T; t++) {
   
            // 寻找输出位置
            float* out_bt = out + b * T * C + t * C;
            // 获取输入令牌的索引
            int ix = inp[b * T + t];
            // 寻找对应于令牌的wte中的位置
            float* wte_ix = wte + ix * C;
            // 寻找对应于位置的wpe中的位置
            float* wpe_t = wpe + t * C;
            // 将两个向量相加并将结果存储在out[b,t,:]
            for (int i = 0; i < C; i++) {
   
                out_bt[i] = wte_ix[i] + wpe_t[i];
            }
        }
    }
}

// 编码器后向传递
void encoder_backward(float* dwte, float* dwpe, // 编码器权重的梯度
                      float* dout, int* inp, // 梯度输出和输入
                      int B, int T, int C) {
    // 批次大小、序列长度、通道数
    for (int b = 0; b < B; b++) {
   
        for (int t = 0; t < T; t++) {
   
            float* dout_bt = dout + b * T * C + t * C;
            int ix = inp[b * T + t];
            float* dwte_ix = dwte + ix * C;
            float* dwpe_t = dwpe + t * C;
            for (int i = 0; i < C; i++) {
   
                float d = dout_bt[i];
                dwte_ix[i] += d;
                dwpe_t[i] += d;
            }
        }
    }
}

// 层归一化前向传递
void layernorm_forward(float* out, float* mean, float* rstd, // 输出、均值和标准差
                       float* inp, float* weight, float* bias, // 输入、权重和偏置
                       int B, int T, int C) {
    // 批次大小、序列长度、通道数
    float eps = 1e-5f;
    for (int b = 0; b < B; b++) {
   
        for (int t = 0; t < T; t++) {
   
            // 寻找输入位置
            float* x = inp + b * T * C + t * C;
            // 计算均值
            float m = 0.0f;
            for (int i = 0; i < C; i++) {
   
                m += x[i];
            }
            m = m / C;
            // 计算方差(不带任何偏置校正)
            float v = 0.0f;
            for (int i = 0; i < C; i++) {
   
                float xshift = x[i] - m;
                v += xshift * xshift;
            }
            v = v / C;
            // 计算标准差
            float s = 1.0f / sqrtf(v + eps);
            // 寻找输出位置
            float* out_bt = out + b * T * C + t * C;
            for (int i = 0; i < C; i++) {
   
                float n = (s * (x[i] - m)); // 归一化输出
                float o = n * weight[i] + bias[i]; // 缩放和平移它
                out_bt[i] = o; // 写入
            }
            // 缓存均值和标准差以供后向传递使用
            mean[b * T + t] = m;
            rstd[b * T + t] = s;
        }
    }
}

// 层归一化后向传递
void layernorm_backward(float* dinp, float* dweight, float* dbias, // 输入、权重和偏置的梯度
                        float* dout, float* inp, float* weight, float* mean, float* rstd, // 梯度输出、输入、权重、均值和标准差
                        int B, int T, int C) {
    // 批次大小、序列长度、通道数
    for (int b = 0; b < B; b++) {
   
        for (int t = 0; t < T; t++) {
   
            float* dout_bt = dout + b * T * C + t * C;
            float* inp_bt = inp + b * T * C + t * C;
            float* dinp_bt = dinp + b * T * C + t * C;
            float mean_bt = mean[b * T + t];
            float rstd_bt = rstd[b * T + t];

            // 首先:两个reduce操作
            float dnorm_mean = 0.0f;
            float dnorm_norm_mean = 0.0f;
            for (int i = 0; i < C; i++) {
   
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
                float dnorm_i = weight[i] * dout_bt[i];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
            dnorm_mean = dnorm_mean / C;
            dnorm_norm_mean = dnorm_norm_mean / C;

            // 现在再次迭代并累积所有梯度
            for (int i = 0; i < C; i++) {
   
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
                float dnorm_i = weight[i] * dout_bt[i];
                // 梯度对偏置的贡献
                dbias[i] += dout_bt[i];
                // 梯度对权重的贡献
                dweight[i] += norm_bti * dout_bt[i];
                // 梯度对输入的贡献
                float dval = 0.0f;
                dval += dnorm_i; // 项1
                dval -= dnorm_mean; // 项2
                dval -= norm_bti * dnorm_norm_mean; // 项3
                dval *= rstd_bt; // 最终缩放
                dinp_bt[i] += dval;
            }
        }
    }
}

// 矩阵乘法前向传递
void matmul_forward(float* out, // 输出
                    float* inp, float* weight, float* bias, // 输入、权重和偏置
                    int B, int T, int C, int OC) {
    // 批次大小、序列长度、通道数、输出通道数
    // 大部分运行时间都在这里和matmul_backward中度过
    // OC是"输出通道"的缩写
    // inp是(B,T,C),weight是(OC, C),bias是(OC)
    // out将是(B,T,OC)
    #pragma omp parallel for collapse(2)
    for (int b = 0; b < B; b++) {
   
        for (int t = 0; t < T; t++) {
   
            float* out_bt = out + b * T * OC + t * OC;
            float* inp_bt = inp + b * T * C + t * C;
            for (int o = 0; o < OC; o++) {
   
                float val = (bias != NULL) ? bias[o] : 0.0f;
                float* wrow = weight + o*C;
                for (int i = 0; i < C; i++) {
   
                    val += inp_bt[i] * wrow[i];
                }
                out_bt[o] = val;
            }
        }
    }
}

// 矩阵乘法后向传递
void matmul_backward(float* dinp, float* dweight, float* dbias, // 输入、权重和偏置的梯度
                     float* dout, float* inp, float* weight, // 梯度输出、输入、权重
                     int B, int T, int C, int OC) {
    // 批次大小、序列长度、通道数、输出通道数
    // 大部分运行时间都在这里和matmul_forward中度过
   // 这个后向可以通过一个"轮次"的循环来完成
    // 但这样的并行化策略并不高效

    // 首先向后传递到inp,按B,T并行化
    #pragma omp parallel for collapse(2)
    for (int b = 0; b < B; b++) {
   
        for (int t = 0; t < T; t++) {
   
            float* dout_bt = dout + b * T * OC + t * OC;
            float* dinp_bt = dinp + b * T * C + t * C;
            for (int o = 0; o < OC; o++) {
   
                float* wrow = weight + o*C;
                float d = dout_bt[o];
                for (int i = 0; i < C; i++) {
   
                    dinp_bt[i] += wrow[i] * d;
                }
            }
        }
    }
    // 向后传递到weight/bias,按输出通道OC并行化
    #pragma omp parallel for
    for (int o = 0; o < OC; o++) {
   
        for (int b = 0; b < B; b++) {
   
            for (int t = 0; t < T; t++) {
   
                float* dout_bt = dout + b * T * OC + t * OC;
                float* inp_bt = inp + b * T * C + t * C;
                float* dwrow = dweight + o*C;
                float d = dout_bt[o];
                if (dbias != NULL) {
    dbias[o] += d; }
                for (int i = 0; i < C; i++) {
   
                    dwrow[i] += inp_bt[i] * d;
                }
            }
        }
    }
}

// 注意力前向传递
void attention_forward(float* out, float* preatt, float* att, // 输出、前馈注意力和注意力矩阵
                        float* inp,
                        int B, int T, int C, int NH) {
    // 批次大小、序列长度、通道数、头数
    // 输入是(B, T, 3C) Q,K,V
    // preatt, att是(B, NH, T, T)
    // 输出是(B, T, C)
    int C3 = C*3;
    int hs = C / NH; // 头大小
    float scale = 1.0 / sqrtf(hs);

    #pragma omp parallel for collapse(3)
    for (int b = 0; b < B; b++) {
   
        for (int t = 0; t < T; t++) {
   
            for (int h = 0; h < NH; h++) {
   
                float* query_t = inp + b * T * C3 + t * C3 + h * hs;
                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
                float* att_bth = att + b*NH*T*T + h*T*T + t*T;

                // 步骤1:计算query dot key并找到最大值
                float maxval = -10000.0f; // TODO 使用更好的方法
                for (int t2 = 0; t2 <= t; t2++) {
   
                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C因为是key

                    // (query_t) dot (key_t2)
                    float val = 0.0f;
                    for (int i = 0; i < hs; i++) {
   
                        val += query_t[i] * key_t2[i];
                    }
                    val *= scale;
                    if (val > maxval) {
   
                        maxval = val;
                    }

                    preatt_bth[t2] = val;
                }

                // 步骤2:计算exp并跟踪总和
                float expsum = 0.0f;
                for (int t2 = 0; t2 <= t; t2++) {
   
                    float expv = expf(preatt_bth[t2] - maxval);
                    expsum += expv;
                    att_bth[t2] = expv;
                }
                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

                // 步骤3:归一化得到softmax
                for (int t2 = 0; t2 < T; t2++) {
   
                    if (t2 <= t) {
   
                        att_bth[t2] *= expsum_inv;
                    } else {
   
                        // 因果注意力掩码。这里设置为零并不严格必要
                        // 只是为了调试和检查PyTorch
                        att_bth[t2] = 0.0f;
                    }
                }

                // 步骤4:累积加权值到注意力的输出
                float* out_bth = out + b * T * C + t * C + h * hs;
                for (int i = 0; i < hs; i++) {
    out_bth[i] = 0.0f; }
                for (int t2 = 0; t2 <= t; t2++) {
   
                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2因为是value
                    float att_btht2 = att_bth[t2];
                    for (int i = 0; i < hs; i++) {
   
  
  • 11
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

openwin_top

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值