llama2.c (4): forward, sample, decode

This article walks through how the forward function of the Transformer works in llama2.c, covering input handling, the attention mechanism, multi-head attention, and RoPE positional encoding, and then shows how the sample step draws the next token from the predicted probability distribution.

1. forward

float* logits = forward(transformer, token, pos);
Given the transformer, the current token, and its position pos, forward computes the prediction scores for the next token (the whole Transformer is built out of matrix multiplications and elementwise arithmetic).
The logits buffer it returns is allocated and filled like this:
s->logits = calloc(p->vocab_size, sizeof(float));
matmul(s->logits, &s->xq, w->wcls, dim, p->vocab_size);
From these two lines and the definition of the matmul function, the output s->logits has shape [1, p->vocab_size].
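
For reference, the float-weight matmul in run.c is shown below; the quantized matmul in runq.c used here follows the same W (d,n) @ x (n,) -> xout (d,) shape convention, so with n = dim and d = p->vocab_size the result is a single row of vocab_size scores.

void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most time is spent inside this little function
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        float val = 0.0f;
        for (int j = 0; j < n; j++) {
            val += w[i * n + j] * x[j];
        }
        xout[i] = val;
    }
}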

Each entry is an unnormalized score for the corresponding string in the vocabulary:

(gdb) p *logits@1000
$15 = {-0.283571005, 3.44877911, -0.578277588, -3.24091816, -1.85795152, 2.61188054, -0.770998061, 0.366253316, -0.637891531, 0.122880608, 2.0521276, 
  0.259968579, 0.553953588, 1.23023224, -1.90220821, 0.791390121, -0.279410094, -2.03433132, 0.736696005, -2.83315516, 0.430814654, -0.45484668, -0.296925813, 
  -0.776587725, -0.373722374, -1.41853309, 0.44897157, 0.298399687, -2.28996897, -0.504646838, -0.219529897, 0.334682822, 0.359610289, 1.333992, -0.0392727256, 
  -0.277485281, -0.281440586, -0.278330177, -0.279631168, -0.275823981, -0.273261875, -0.281633765, -0.280521065, -0.279279858, -0.277830899, -0.275540143, 
  -0.278773159, -0.285891086, -0.275212795, -0.27603671, -0.276746958, -0.281391174, -0.27630195, -0.278620541, -0.281585068, -0.277181506, -0.279754519, 
  -0.276037633, -0.278509229, -0.278621584, -0.271104455, -0.280266523, -0.279526323, -0.280170411, -0.277653664, -0.28433004, -0.275049627, -0.280639797, 
  -0.27556017, -0.279702693, -0.286844194, -0.277686894, -0.278450489, -0.28413251, -0.279598236, -0.273824662, -0.276941836, -0.279240847, -0.281096309, 
  -0.275031894, -0.282162875, -0.282587916, -0.279308707, -0.279815942, -0.280733585, -0.278700113, -0.275241196, -0.273779333, -0.280413181, -0.277753592, 
--Type <RET> for more, q to quit, c to continue without paging--

The full definition of forward:
float* forward(Transformer* transformer, int token, int pos) {

    // a few convenience variables
    Config* p = &transformer->config;
    TransformerWeights* w = &transformer->weights;
    RunState* s = &transformer->state;
    float *x = s->x;
    int dim = p->dim;
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
    int hidden_dim =  p->hidden_dim;
    int head_size = dim / p->n_heads;

    // copy the token embedding into x
    memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));

    // forward all the layers
    for(int l = 0; l < p->n_layers; l++) {

        // attention rmsnorm
        rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

        // qkv matmuls for this position
        quantize(&s->xq, s->xb, dim);
        matmul(s->q, &s->xq, w->wq + l, dim, dim);
        matmul(s->k, &s->xq, w->wk + l, dim, kv_dim);
        matmul(s->v, &s->xq, w->wv + l, dim, kv_dim);

        // RoPE relative positional encoding: complex-valued rotate q and k in each head
        for (int i = 0; i < dim; i+=2) {
            int head_dim = i % head_size;
            float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
            float val = pos * freq;
            float fcr = cosf(val);
            float fci = sinf(val);
            int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
            for (int v = 0; v < rotn; v++) {
                float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
                float v0 = vec[i];
                float v1 = vec[i+1];
                vec[i]   = v0 * fcr - v1 * fci;
                vec[i+1] = v0 * fci + v1 * fcr;
            }
        }

        // save key,value at this time step (pos) to our kv cache
        int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
        float* key_cache_row = s->key_cache + loff + pos * kv_dim;
        float* value_cache_row = s->value_cache + loff + pos * kv_dim;
        memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
        memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));

        // multihead attention. iterate over all heads
        int h;
        #pragma omp parallel for private(h)
        for (h = 0; h < p->n_heads; h++) {
            // get the query vector for this head
            float* q = s->q + h * head_size;
            // attention scores for this head
            float* att = s->att + h * p->seq_len;
            // iterate over all timesteps, including the current one
            for (int t = 0; t <= pos; t++) {
                // get the key vector for this head and at this timestep
                float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
                // calculate the attention score as the dot product of q and k
                float score = 0.0f;
                for (int i = 0; i < head_size; i++) {
                    score += q[i] * k[i];
                }
                score /= sqrtf(head_size);
                // save the score to the attention buffer
                att[t] = score;
            }

            // softmax the scores to get attention weights, from 0..pos inclusively
            softmax(att, pos + 1);

            // weighted sum of the values, store back into xb
            float* xb = s->xb + h * head_size;
            memset(xb, 0, head_size * sizeof(float));
            for (int t = 0; t <= pos; t++) {
                // get the value vector for this head and at this timestep
                float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
                // get the attention weight for this timestep
                float a = att[t];
                // accumulate the weighted value into xb
                for (int i = 0; i < head_size; i++) {
                    xb[i] += a * v[i];
                }
            }
        }

        // final matmul to get the output of the attention
        quantize(&s->xq, s->xb, dim);
        matmul(s->xb2, &s->xq, w->wo + l, dim, dim);

        // residual connection back into x
        for (int i = 0; i < dim; i++) {
            x[i] += s->xb2[i];
        }

        // ffn rmsnorm
        rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);

        // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
        // first calculate self.w1(x) and self.w3(x)
        quantize(&s->xq, s->xb, dim);
        matmul(s->hb, &s->xq, w->w1 + l, dim, hidden_dim);
        matmul(s->hb2, &s->xq, w->w3 + l, dim, hidden_dim);

        // SwiGLU non-linearity
        for (int i = 0; i < hidden_dim; i++) {
            float val = s->hb[i];
            // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
            val *= (1.0f / (1.0f + expf(-val)));
            // elementwise multiply with w3(x)
            val *= s->hb2[i];
            s->hb[i] = val;
        }

        // final matmul to get the output of the ffn
        quantize(&s->hq, s->hb, hidden_dim);
        matmul(s->xb, &s->hq, w->w2 + l, hidden_dim, dim);

        // residual connection
        for (int i = 0; i < dim; i++) {
            x[i] += s->xb[i];
        }
    }

    // final rmsnorm
    rmsnorm(x, x, w->rms_final_weight, dim);

    // classifier into logits
    quantize(&s->xq, x, dim);
    matmul(s->logits, &s->xq, w->wcls, dim, p->vocab_size);
    return s->logits;
}
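
forward leans on two small float helpers, rmsnorm and softmax. For reference, their definitions as in run.c (runq.c uses the same float versions):

void rmsnorm(float* o, float* x, float* weight, int size) {
    // calculate sum of squares
    float ss = 0.0f;
    for (int j = 0; j < size; j++) {
        ss += x[j] * x[j];
    }
    ss /= size;
    ss += 1e-5f;
    ss = 1.0f / sqrtf(ss);
    // normalize and scale
    for (int j = 0; j < size; j++) {
        o[j] = weight[j] * (ss * x[j]);
    }
}

void softmax(float* x, int size) {
    // find max value (for numerical stability)
    float max_val = x[0];
    for (int i = 1; i < size; i++) {
        if (x[i] > max_val) { max_val = x[i]; }
    }
    // exp and sum
    float sum = 0.0f;
    for (int i = 0; i < size; i++) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // normalize
    for (int i = 0; i < size; i++) {
        x[i] /= sum;
    }
}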

1) token_embedding_table

// copy the token embedding into x
memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));

token_embedding_table holds dequantized (float) values, and it is the only weight that gets dequantized; all other weights stay in int8.

// dequantize token embedding table
w->token_embedding_table = malloc(p->vocab_size * p->dim * sizeof(float));
dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim);
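
The dequantization itself is one multiply per element: every int8 value is scaled back by its group's scaling factor. A sketch matching runq.c, where GS is the global quantization group size:

void dequantize(QuantizedTensor *qx, float* x, int n) {
    for (int i = 0; i < n; i++) {
        x[i] = qx->q[i] * qx->s[i / GS];
    }
}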

2) quantize

// attention rmsnorm
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
// qkv matmuls for this position
quantize(&s->xq, s->xb, dim);
matmul(s->q, &s->xq, w->wq + l, dim, dim);
(gdb) ptype s->xb
type = float *

Quantization ensures the input has the same data type as the weights (int8): the float32 input s->xb is quantized into the int8 s->xq by the quantize function, and everything downstream of this point operates on int8.
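
For context, quantize in runq.c does group-wise symmetric int8 quantization: for each group of GS values it finds the maximum magnitude, derives a scale so the group fits in [-127, 127], and rounds. Reproduced here as a sketch; consult runq.c for the authoritative version:

void quantize(QuantizedTensor *qx, float* x, int n) {
    int num_groups = n / GS;
    float Q_MAX = 127.0f;
    for (int group = 0; group < num_groups; group++) {
        // find the max absolute value in the current group
        float wmax = 0.0f;
        for (int i = 0; i < GS; i++) {
            float val = fabsf(x[group * GS + i]);
            if (val > wmax) { wmax = val; }
        }
        // calculate and write the scaling factor
        float scale = wmax / Q_MAX;
        qx->s[group] = scale;
        // calculate and write the quantized values
        for (int i = 0; i < GS; i++) {
            float quant_value = x[group * GS + i] / scale;
            qx->q[group * GS + i] = (int8_t) round(quant_value);
        }
    }
}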

2. sample

2.1 sample not entered (still consuming the prompt)

if (pos < num_prompt_tokens - 1) {
    // if we are still processing the input prompt, force the next prompt token
    next = prompt_tokens[pos + 1];
} else {
    // otherwise sample the next token from the logits
    next = sample(sampler, logits);
}

**Determining next:** while we are still inside the input prompt, the next prompt token is taken as next; only after the prompt is consumed is sample used to produce next. So here the first branch executes:

next = prompt_tokens[pos + 1];

(gdb) p pos
$10 = 0
(gdb) p next
$11 = 15043  //Hello

2.2 sample entered

After the prompt, sample draws the next token from the logits according to the sampler's parameters.

Definition:
int sample(Sampler* sampler, float* logits)
(gdb) p *logits
$20 = 0.657589614
(gdb) p *sampler
$1 = {vocab_size = 32000, probindex = 0x7f12efe3b010, temperature = 1, topp = 0.899999976, rng_state = 1710049046}
`temperature`: controls the randomness of generation. 0.0 is fully deterministic (always pick the highest-scoring token), 1.0 keeps the model's raw distribution, and higher values make the output more diverse but more likely to drift away from what the training data suggests.

`topp`: the nucleus-sampling threshold that bounds the candidate set; at 0.9, only the smallest group of tokens whose cumulative probability reaches 0.9 is considered. Lower topp values tend to produce more coherent, higher-quality text, though nucleus sampling costs a bit more compute.

`rng_seed`: seed for the random number generator, defaulting to the current time so each run differs. Fixing the seed reproduces the same random sequence, which is useful for repeatable output. (In the Sampler struct printed above it lives on as rng_state.)
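
For reference, rng_state feeds llama2.c's xorshift generator, and random_f32 maps it into [0, 1) (definitions as in run.c):

unsigned int random_u32(unsigned long long *state) {
    // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
    *state ^= *state >> 12;
    *state ^= *state << 25;
    *state ^= *state >> 27;
    return (*state * 0x2545F4914F6CDD1Dull) >> 32;
}

float random_f32(unsigned long long *state) { // random float32 in [0,1)
    return (random_u32(state) >> 8) / 16777216.0f;
}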

Walkthrough of the relevant branches:

1. temperature == 0.0
sampler->temperature == 0.0f
next = sample_argmax(logits, sampler->vocab_size);

sample_argmax returns the index with the highest score, i.e. greedy decoding.
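
Its definition in run.c is a plain linear scan for the maximum:

int sample_argmax(float* probabilities, int n) {
    // return the index that has the highest probability
    int max_i = 0;
    float max_p = probabilities[0];
    for (int i = 1; i < n; i++) {
        if (probabilities[i] > max_p) {
            max_i = i;
            max_p = probabilities[i];
        }
    }
    return max_i;
}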

2. temperature != 0.0
Each logits[q] is divided by sampler->temperature, and the softmax function then turns the scaled logits into a probability distribution that reflects the temperature setting.
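
The corresponding lines inside sample (as in run.c):

// apply the temperature to the logits
for (int q = 0; q < sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(logits, sampler->vocab_size);
// flip a (float) coin (this is our source of entropy for sampling)
float coin = random_f32(&sampler->rng_state);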
2)-1 When sampler->topp <= 0 or sampler->topp >= 1, sample_mult is used.

Call:

next = sample_mult(logits, sampler->vocab_size, coin);

Definition:

int sample_mult(float* probabilities, int n, float coin) {
    // sample index from probabilities (they must sum to 1!)
    // coin is a random number in [0, 1), usually from random_f32()
    float cdf = 0.0f;
    for (int i = 0; i < n; i++) {
        cdf += probabilities[i];
        if (coin < cdf) {    // walk the CDF; the first time it exceeds coin, return i
            return i;
        }
    }
    return n - 1; // in case of rounding errors, fall back to the last index
}
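
A tiny worked example with hypothetical numbers, to make the CDF walk concrete:

// probabilities = {0.5, 0.3, 0.2}, coin = 0.6
// i = 0: cdf = 0.5, 0.6 >= 0.5, keep going
// i = 1: cdf = 0.8, 0.6 <  0.8, return 1
float probs[3] = {0.5f, 0.3f, 0.2f};
int idx = sample_mult(probs, 3, 0.6f); // yields 1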

2)-2 Otherwise, the top-p (nucleus) strategy.
Call:

next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);

Parameter meanings:

float topp: the sampling threshold, usually in (0, 1); only the smallest set of tokens whose cumulative probability exceeds this threshold is considered.
ProbIndex* probindex: an array of structs used to store the filtered indices together with their probabilities.
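
For reference, ProbIndex and the comparator handed to qsort are defined in run.c as:

typedef struct {
    float prob;
    int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling

int compare(const void* a, const void* b) {
    ProbIndex* a_ = (ProbIndex*) a;
    ProbIndex* b_ = (ProbIndex*) b;
    if (a_->prob > b_->prob) return -1; // larger probabilities sort first
    if (a_->prob < b_->prob) return 1;
    return 0;
}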

Definition:

int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin)

**S1:** keep only the tokens whose probability is at least (1 - topp) / (n - 1), storing each surviving index and probability into the probindex array, then sort it in descending order of probability. (A token below this cutoff can never make it into the nucleus, so it is safe to drop before the sort.)

int n0 = 0;
const float cutoff = (1.0f - topp) / (n - 1);
for (int i = 0; i < n; i++) {
    if (probabilities[i] >= cutoff) {
        probindex[n0].index = i;
        probindex[n0].prob = probabilities[i];
        n0++;
    }
}
qsort(probindex, n0, sizeof(ProbIndex), compare);

**S2:** in the same spirit as sample_mult above, accumulate the probabilities of the filtered probindex entries; as soon as the running sum exceeds topp, record last_idx and truncate the list there.

// truncate the list where cumulative probability exceeds topp
float cumulative_prob = 0.0f;
int last_idx = n0 - 1; // in case of rounding errors consider all elements
for (int i = 0; i < n0; i++) {
    cumulative_prob += probindex[i].prob;
    if (cumulative_prob > topp) {
        last_idx = i;
        break; // we've exceeded topp by including last_idx
    }
}

**S3:** use coin together with the truncated cumulative probability to pick the sampled token, returning probindex[i].index.

// sample from the truncated list
float r = coin * cumulative_prob;
float cdf = 0.0f;
for (int i = 0; i <= last_idx; i++) {
    cdf += probindex[i].prob;
    if (r < cdf) {
        return probindex[i].index;
    }
}
return probindex[last_idx].index; // in case of rounding errors
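
Putting S1-S3 together, a toy walk-through with hypothetical numbers:

// n = 4, topp = 0.9, probabilities = {0.60, 0.25, 0.10, 0.05}
// S1: cutoff = (1 - 0.9) / 3 ≈ 0.033, so all four survive; already sorted desc
// S2: cumulative 0.60 -> 0.85 -> 0.95 > 0.9, so last_idx = 2, cumulative_prob = 0.95
// S3: coin = 0.7 gives r = 0.665; cdf = 0.60 (continue), then 0.85 > 0.665,
//     so the 0.25 token (probindex[1].index) is returned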

3. decode

token = 1 (the BOS token), next = 15043 (which decodes to "Hello")

Call:
char* piece = decode(tokenizer, token, next);
Definition:
char* decode(Tokenizer* t, int prev_token, int token)
{
    char *piece = t->vocab[token];   //Hello
    // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
    if (prev_token == 1 && piece[0] == ' ') { piece++; }
    // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
    // parse this and convert and return the actual byte
    unsigned char byte_val;
    if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
        piece = (char*)t->byte_pieces + byte_val * 2;
    }
    return piece;
}
(gdb) p piece
$17 = 0x55ae4f286661 "Hello"
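
The byte_pieces lookup used in the raw-byte branch is filled in build_tokenizer (run.c): each of the 256 possible byte values is stored as its own two-byte string, so a token such as '<0x41>' decodes to the single byte 'A'.

for (int i = 0; i < 256; i++) {
    t->byte_pieces[i * 2] = (unsigned char)i;
    t->byte_pieces[i * 2 + 1] = '\0';
}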