The code in llama.c


1. build_transformer

void build_transformer(Transformer *t, char* checkpoint_path) {
    // read in the Config and the Weights from the checkpoint
    read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
    // allocate the RunState buffers
    malloc_run_state(&t->state, &t->config);
}
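For context, a minimal call sketch (my own illustration; the checkpoint filename is hypothetical, not taken from this post):

Transformer transformer;
build_transformer(&transformer, "stories15M.bin"); // hypothetical checkpoint path
// transformer.config, transformer.weights, and transformer.state are now populated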

1.1 read_checkpoint

void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
                     int* fd, float** data, ssize_t* file_size) {
    FILE *file = fopen(checkpoint, "rb");
    if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
    // read in the config header. fread returns the number of complete items read,
    // so anything other than 1 (e.g. a short read at end-of-file) signals an error.
    if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
    // negative vocab size is hacky way of signaling unshared weights. bit yikes.
    // vocab_size > 0 means the classifier weights are shared with the embedding
    // table (shared_weights = 1); vocab_size <= 0 means they are separate.
    int shared_weights = config->vocab_size > 0 ? 1 : 0;
    config->vocab_size = abs(config->vocab_size); // undo the sign hack; here vocab_size becomes 32000
    // figure out the file size
    fseek(file, 0, SEEK_END); // move file pointer to end of file
    *file_size = ftell(file); // get the file size, in bytes
    fclose(file);
    // memory map the Transformer weights into the data pointer
    *fd = open(checkpoint, O_RDONLY); // open in read only mode
    if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
    // map the file into the process's virtual address space, so its contents can be
    // accessed like ordinary memory without explicit read() calls; afterwards *data
    // points to the start of the mapped region
    *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
    if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); } // mmap returns MAP_FAILED, i.e. (void *) -1, on error
    // compute where the weights start inside the mapped file: *data is a float*, so
    // advancing it by sizeof(Config)/sizeof(float) elements skips the Config header,
    // leaving weights_ptr pointing at the first weight
    float* weights_ptr = *data + sizeof(Config)/sizeof(float);
    memory_map_weights(weights, config, weights_ptr, shared_weights);
}
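The pointer arithmetic for weights_ptr is worth spelling out. A standalone sketch (assuming Config is the plain struct of seven ints that the gdb dump in section 1.2 suggests):

#include <stdio.h>

// assumed layout, matching the gdb dump of *p in section 1.2
typedef struct {
    int dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len;
} Config;

int main(void) {
    // 28 bytes / 4 bytes per float = 7 floats skipped past the header
    printf("header skip = %zu floats\n", sizeof(Config) / sizeof(float));
    return 0;
}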

1.2 memory_map_weights

//(gdb) print *p  $15 = {dim = 288, hidden_dim = 768, n_layers = 6, n_heads = 6, n_kv_heads = 6, vocab_size = 32000, seq_len = 256}
//the purpose of this function is to point every member of the TransformerWeights struct at the correct region of the mapped memory, covering the complete set of model weights
/*
+-----------------------+
| token_embedding_table |
+-----------------------+
| rms_att_weight        |
+-----------------------+
| wq (per layer)        |
+-----------------------+
| wk (per layer)        |
+-----------------------+
| wv (per layer)        |
+-----------------------+
| wo (per layer)        |
+-----------------------+
| rms_ffn_weight        |
+-----------------------+
| w1 (per layer)        |
+-----------------------+
| w2 (per layer)        |
+-----------------------+
| w3 (per layer)        |
+-----------------------+
| rms_final_weight      |
+-----------------------+
[... skipped RoPE-related memory ...]
+-----------------------+
| wcls (possibly shared)|
+-----------------------+
*/
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
    int head_size = p->dim / p->n_heads; // compute the per-head dimension (head_size)
    // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
    unsigned long long n_layers = p->n_layers;
    w->token_embedding_table = ptr;  // the first region of the mapped weights is the token embedding table
    ptr += p->vocab_size * p->dim;   // advance past it: vocab_size * dim floats
    w->rms_att_weight = ptr;
    ptr += n_layers * p->dim;
    w->wq = ptr;
    ptr += n_layers * p->dim * (p->n_heads * head_size);
    w->wk = ptr;
    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
    w->wv = ptr;
    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
    w->wo = ptr;
    ptr += n_layers * (p->n_heads * head_size) * p->dim;
    w->rms_ffn_weight = ptr;
    ptr += n_layers * p->dim;
    w->w1 = ptr;
    ptr += n_layers * p->dim * p->hidden_dim;
    w->w2 = ptr;
    ptr += n_layers * p->hidden_dim * p->dim;
    w->w3 = ptr;
    ptr += n_layers * p->dim * p->hidden_dim;
    w->rms_final_weight = ptr;
    ptr += p->dim;
    ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
    ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
    w->wcls = shared_weights ? w->token_embedding_table : ptr;
}
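As a cross-check (my own sketch, not part of llama.c), the regions walked above should add up to exactly (*file_size - sizeof(Config)) / sizeof(float) floats, with wcls contributing only when the weights are not shared:

unsigned long long count_weights(const Config *p, int shared_weights) {
    unsigned long long n_layers = p->n_layers;
    int head_size = p->dim / p->n_heads;
    unsigned long long n = 0;
    n += (unsigned long long)p->vocab_size * p->dim;           // token_embedding_table
    n += n_layers * p->dim;                                    // rms_att_weight
    n += n_layers * p->dim * (p->n_heads * head_size);         // wq
    n += 2 * n_layers * p->dim * (p->n_kv_heads * head_size);  // wk, wv
    n += n_layers * (p->n_heads * head_size) * p->dim;         // wo
    n += n_layers * p->dim;                                    // rms_ffn_weight
    n += 3 * n_layers * p->dim * p->hidden_dim;                // w1, w2, w3
    n += p->dim;                                               // rms_final_weight
    n += (unsigned long long)p->seq_len * head_size;           // the two skipped RoPE tables
    if (!shared_weights) n += (unsigned long long)p->vocab_size * p->dim; // wcls
    return n;
}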

1.3 malloc_run_state

// allocate the RunState buffers (the scratch space used while running the model)
void malloc_run_state(RunState* s, Config* p) {
    // we calloc instead of malloc to keep valgrind happy
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    s->x = calloc(p->dim, sizeof(float));
    s->xb = calloc(p->dim, sizeof(float));
    s->xb2 = calloc(p->dim, sizeof(float));
    s->hb = calloc(p->hidden_dim, sizeof(float));
    s->hb2 = calloc(p->hidden_dim, sizeof(float));
    s->q = calloc(p->dim, sizeof(float));
    s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
    s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
    s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
    s->logits = calloc(p->vocab_size, sizeof(float));
    // ensure all mallocs went fine
    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
     || !s->key_cache || !s->value_cache || !s->att || !s->logits) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
}
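Plugging in the config printed earlier (dim=288, hidden_dim=768, n_layers=6, n_heads=6, n_kv_heads=6, seq_len=256), a quick worked example of the buffer sizes:

#include <stdio.h>

int main(void) {
    // values from the gdb dump of *p above
    int dim = 288, n_layers = 6, n_heads = 6, n_kv_heads = 6, seq_len = 256;
    int kv_dim = (dim * n_kv_heads) / n_heads;       // (288 * 6) / 6 = 288
    long cache = (long)n_layers * seq_len * kv_dim;  // 6 * 256 * 288 = 442368 floats
    printf("kv_dim = %d, key/value cache = %ld floats (%.1f MB each)\n",
           kv_dim, cache, cache * 4.0 / 1e6);        // ~1.8 MB per cache
    return 0;
}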

2. generate(&transformer, &tokenizer, &sampler, prompt, steps);

2.1 encode

 encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
(gdb) print *t
$1 = {vocab = 0x7fe657a28010, vocab_scores = 0x560eb70a5b00, sorted_vocab = 0x0, vocab_size = 32000, max_token_length = 27, 
  byte_pieces = "\000\000\001\000\002\000\003\000\004\000\005\000\006\000\a\000\b\000\t\000\n\000\v\000\f\000\r\000\016\000\017\000\020\000\021\000\022\000\023\000\024\000\025\000\026\000\027\000\030\000\031\000\032\000\033\000\034\000\035\000\036\000\037\000 \000!\000\"\000#\000$\000%\000&\000'\000(\000)\000*\000+\000,\000-\000.\000/\000\060\000\061\000\062\000\063\000\064\000\065\000\066\000\067\000\070\000\071\000:\000;\000<\000=\000>\000?\000@\000A\000B\000C\000D\000E\000F\000G\000H\000I\000J\000K\000L\000M\000N\000O\000P\000Q\000R\000S\000T\000U\000V\000W\000X\000Y\000Z\000[\000\\\000]\000^\000_\000`\000a\000b\000c\000"...}
  vocab: a pointer to the array of vocabulary strings; the dump above shows the Tokenizer before the vocabulary is sorted.
(gdb)  p  *text
$2 = 104 'h'

(gdb) print bos
$14 = 1 '\001'

(gdb) print eos
$15 = 0 '\000'

(gdb) p *tokens
$5 = 1537102592

(gdb) p  *n_tokens
$6 = 0
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
    // encode the string text (input) into an upper-bound preallocated tokens[] array
    // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
    if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
    //if the vocabulary hasn't been sorted yet, lazily allocate and build t->sorted_vocab
    if (t->sorted_vocab == NULL) {
        // lazily malloc and sort the vocabulary: copy every entry of the unsorted
        // t->vocab into sorted_vocab together with its id, then qsort the array
        t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
        for (int i = 0; i < t->vocab_size; i++) {
            t->sorted_vocab[i].str = t->vocab[i];
            t->sorted_vocab[i].id = i;
        }
        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
    }

Explained below:

(gdb) p t->vocab[0]
$5 = 0x561b5bfbbf20 "<unk>"
(gdb) p t->vocab[1]
$6 = 0x561b5bfbbf40 "\n<s>\n"
(gdb) p t->vocab[2]
$7 = 0x561b5bfbbf60 "\n</s>\n"
(gdb) p t->vocab[100]
$8 = 0x561b5bfbcba0 "<0x61>"
(gdb) p t->vocab[1100]
$9 = 0x561b5bfc48a0 "son"
(gdb) p t->vocab[1101]
$10 = 0x561b5bfc48c0 " follow"

After sorting: address, string, id

(gdb) print t->sorted_vocab[0]
$3 = {str = 0x5564f5456f60 "\n</s>\n", id = 2}
(gdb) print t->sorted_vocab[1000]
$4 = {str = 0x5564f54a5880 " Chicago", id = 10059}
(gdb) print t->sorted_vocab[1001]
$5 = {str = 0x5564f54c8960 " Chief", id = 14546}
(gdb) print t->sorted_vocab[1002]
$6 = {str = 0x5564f5526660 " Chiesa", id = 26553}
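The sort and the lookups below both go through two small helpers, compare_tokens and str_lookup. A sketch consistent with how they are used here (reconstructed, not quoted verbatim from the source):

#include <stdlib.h>
#include <string.h>

// assumed shape of TokenIndex, consistent with the gdb prints of sorted_vocab
typedef struct { char *str; int id; } TokenIndex;

int compare_tokens(const void *a, const void *b) {
    // order TokenIndex entries by their string, so bsearch can work on the array
    return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
    // binary-search the sorted vocabulary; return the token id, or -1 if absent
    TokenIndex tok = { .str = str };
    TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
    return res != NULL ? res->id : -1;
}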
 // create a temporary buffer that will store merge candidates of always two consecutive tokens
    // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
    char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
    size_t str_len = 0;

Explained below. The allocation size t->max_token_length * 2 + 1 + 2 breaks down as:

t->max_token_length * 2: a merge candidate is the concatenation of two tokens, so its length can be at most the sum of their lengths, i.e. twice the maximum token length.
+1: one extra byte for the terminating '\0' (null terminator) that marks the end of the string.
+2: a single UTF-8 character can occupy several bytes (up to 4), so if max_token_length were 1, two extra bytes keep a multi-byte UTF-8 character from overrunning the buffer.

For the tokenizer dumped above (max_token_length = 27), that is 27 * 2 + 1 + 2 = 57 bytes.
    // start at 0 tokens
    *n_tokens = 0;

    // add optional BOS (=1) token, if desired
    if (bos) tokens[(*n_tokens)++] = 1;

    // add_dummy_prefix is true by default
    // so prepend a dummy prefix token to the input string, but only if text != ""
    // TODO: pretty sure this isn't correct in the general case but I don't have the
    // energy to read more of the sentencepiece code to figure out what it's doing
    if (text[0] != '\0') {
        int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size); // look up the token id of a single space
        tokens[(*n_tokens)++] = dummy_prefix;
    }

Explained below: this initializes the tokens array before encoding the input text. (Note that *tokens prints as garbage here because the array is still uninitialized.)

(gdb) print *tokens
$7 = 1815610112
(gdb) p text[0]
$10 = 104 'h'
(gdb) p bos
$9 = 1 '\001'
If a beginning-of-sentence marker is requested (BOS, bos != 0), the BOS token ID (here 1) is written into tokens and *n_tokens is incremented.
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size); returns the index (token ID) of the single-space string in the vocabulary.
Syntax: the expression tokens[(*n_tokens)++] = some_value; performs two operations in C/C++:

Array assignment: tokens[(*n_tokens)] = some_value; stores the value at the index given by the integer that n_tokens points to.

Post-increment: (*n_tokens)++ uses the current value of *n_tokens as the index first, and only then increments it. Each execution therefore appends one element, and the next append lands in the following slot.

Together, this idiom fills an array dynamically, with the index advancing automatically on every append; here it is used to append token IDs one by one to the encoded token list (the tokens array).
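A toy illustration of the idiom, using the values observed in this gdb session:

int tokens[8];
int n = 0;
int *n_tokens = &n;
tokens[(*n_tokens)++] = 1;     // tokens[0] = 1 (BOS); n becomes 1
tokens[(*n_tokens)++] = 29871; // tokens[1] = 29871 (dummy prefix " "); n becomes 2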

    // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
    // Code point ↔ UTF-8 conversion
    // First code point	Last code point	Byte 1	Byte 2	Byte 3	Byte 4
    // U+0000	U+007F	    0xxxxxxx
    // U+0080	U+07FF	    110xxxxx	10xxxxxx
    // U+0800	U+FFFF	    1110xxxx	10xxxxxx	10xxxxxx
    // U+10000	U+10FFFF    11110xxx	10xxxxxx	10xxxxxx	10xxxxxx

    // process the raw (UTF-8) byte sequence of the input string
    for (char *c = text; *c != '\0'; c++) {

        // reset buffer if the current byte is ASCII or a leading byte
        // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
        // 0x80 is 10000000
        // in UTF-8, all continuation bytes start with "10" in first two bits
        // so in English this is: "if this byte is not a continuation byte"
        if ((*c & 0xC0) != 0x80) {
            // this byte must be either a leading byte (11...) or an ASCII char (0x...)
            // => reset our location, as we're starting a new UTF-8 codepoint
            str_len = 0;
        }

        // append the current byte to the buffer
        str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
        str_buffer[str_len] = '\0';

        // while the next character is a continuation byte, continue appending
        // but if there are too many of them, just stop to avoid overrunning str_buffer size.
        if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
            continue;
        }

        // ok c+1 is not a continuation byte, so we've read in a full codepoint
        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);

        if (id != -1) {
            // we found this codepoint in vocab, add it as a token
            tokens[(*n_tokens)++] = id;
        } else {
            // byte_fallback encoding: just encode each byte as a token
            // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
            // so the individual bytes only start at index 3
            for (int i=0; i < str_len; i++) {
                tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
            }
        }
        str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
    }

Explanation of the code:

This loop walks the raw UTF-8 byte sequence of the input string and converts each character into its token ID from the vocabulary. Per the Wikipedia UTF-8 table quoted above, a character occupies 1 to 4 bytes depending on its Unicode code point. For example:

Characters U+0000 through U+007F take a single byte.
Characters U+0080 through U+07FF take two bytes; the first byte starts with 110, the second with 10.
Higher code points need more bytes.
(gdb) print *c    // the input is "hello"
$11 = 104 'h'

Suppose we have the following UTF-8 byte sequence:

63 61 C3 A9

Here C3 A9 is the two-byte UTF-8 encoding of the French character 'é'. In this example:

When the pointer c points at C3, (*c & 0xC0) == 0xC0, so this is not a continuation byte but the leading byte of a multi-byte character.
When the pointer c points at A9, (*c & 0xC0) == 0x80, meaning it is a continuation byte.
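The same classification can be checked with a tiny standalone program over that example sequence (my own illustration):

#include <stdio.h>

int main(void) {
    // the example sequence from the text: 'c' (0x63), 'a' (0x61), then 'é' (0xC3 0xA9)
    const unsigned char seq[] = {0x63, 0x61, 0xC3, 0xA9};
    for (int i = 0; i < 4; i++) {
        int continuation = (seq[i] & 0xC0) == 0x80;
        printf("0x%02X -> %s\n", seq[i],
               continuation ? "continuation byte" : "ASCII or leading byte");
    }
    return 0;
}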
 if ((*c & 0xC0) != 0x80) {
            // this byte must be either a leading byte (11...) or an ASCII char (0x...)
            // => reset our location, as we're starting a new UTF-8 codepoint
            str_len = 0;
        }

        // append the current byte to the buffer
        str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
        str_buffer[str_len] = '\0';
For 'h', (*c & 0xC0) != 0x80, so str_len = 0 executes and then str_buffer[0] = 'h':
(gdb) print str_buffer[0]
$5 = 104 'h'
(gdb) print str_buffer[1]
$6 = 0 '\000'


 if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
            continue;
        }
(gdb) print *(c+1)
$25 = 101 'e'


int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
str_lookup checks whether the code point accumulated in str_buffer has an entry in the sorted vocabulary (t->sorted_vocab) and returns its index/ID in the vocabulary, or -1 if it is not found.
(gdb) print id
$1 = 29882

 if (id != -1) {
            // we found this codepoint in vocab, add it as a token
            tokens[(*n_tokens)++] = id;
(gdb) print (*n_tokens)++
$8 = 3

(gdb) print tokens[0] 
$9 = 1
source: if (bos) tokens[(*n_tokens)++] = 1;

(gdb) print tokens[1]
$10 = 29871
source:
(gdb) print dummy_prefix
$1 = 29871

(gdb) print tokens[2]
$11 = 29882

In other words: t->sorted_vocab is the sorted vocabulary, and str_buffer holds one character of the prompt at a time. Each character is looked up in t->sorted_vocab to get its id, which is then appended to tokens.

 while (1) {
        float best_score = -1e10;
        int best_id = -1;
        int best_idx = -1;

        for (int i=0; i < (*n_tokens-1); i++) {
            // check if we can merge the pair (tokens[i], tokens[i+1])
            sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
            if (id != -1 && t->vocab_scores[id] > best_score) {
                // this merge pair exists in vocab! record its score and position
                best_score = t->vocab_scores[id];
                best_id = id;
                best_idx = i;
            }
        }

Explanation of the code:

(gdb) print *n_tokens
$10 = 7
source: BOS (1), dummy_prefix, and the letters h, e, l, l, o
(gdb) p *tokens@7
$12 = {1, 29871, 29882, 29872, 29880, 29880, 29877}
(gdb) p *t->sorted_vocab@12
$17 = {{str = 0x5591f27cdf60 "\n</s>\n", id = 2}, {str = 0x5591f27cdf40 "\n<s>\n", id = 1}, {str = 0x5591f28b85d0 "\r", id = 30004}, {
    str = 0x5591f28b7530 " ", id = 29871}, {str = 0x5591f2802ba0 " \r", id = 6756}, {str = 0x5591f27cff80 "  ", id = 259}, {str = 0x5591f27db0e0 "   ", 
    id = 1678}, {str = 0x5591f27d00a0 "    ", id = 268}, {str = 0x5591f27d1360 "     ", id = 418}, {str = 0x5591f27d2280 "      ", id = 539}, {
    str = 0x5591f27f2b60 "       ", id = 4706}, {str = 0x5591f27d05a0 "        ", id = 308}}

(gdb) print *t->vocab@8
$19 = {0x5591f27cdf20 "<unk>", 0x5591f27cdf40 "\n<s>\n", 0x5591f27cdf60 "\n</s>\n", 0x5591f27cdf80 "<0x00>", 0x5591f27cdfa0 "<0x01>", 0x5591f27cdfc0 "<0x02>", 0x5591f27cdfe0 "<0x03>", 
  0x5591f27ce000 "<0x04>"}

 sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
This line fetches the strings of two consecutive tokens (at indices i and i+1 of the tokens array) from the vocabulary, concatenates them, and stores the result in str_buffer.

str_lookup is then called with the concatenated string to check whether the sorted vocabulary (sorted_vocab) contains a matching ID. If it does, and that ID's score in vocab_scores beats the best score found so far, the best score, best ID, and best index (the position i of the current token pair) are updated.

 int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
  if (id != -1 && t->vocab_scores[id] > best_score) {
                // this merge pair exists in vocab! record its score and position
                best_score = t->vocab_scores[id];
                best_id = id;
                best_idx = i;
            }
The point of this loop is to find the best way to merge a pair of adjacent tokens: consecutive tokens whose concatenation is itself a valid vocabulary entry can be fused into a single token, ranked by the merged entry's score in the vocabulary. Repeating this greedy merge produces a more compact encoding that conforms to the predefined vocabulary.
Why fetch the two adjacent tokens' strings from the unsorted t->vocab?
Qwen's explanation: although encoding performs its lookups in the sorted vocabulary, the string content for a given token id has to come from the original unsorted t->vocab, because t->vocab is indexed directly by token id. The sorted t->sorted_vocab is ordered by string (to allow fast binary search and comparison), not by id.
        if (best_idx == -1) {
            break; // we couldn't find any more pairs to merge, so we're done
        }

        // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
        tokens[best_idx] = best_id;
        // delete token at position best_idx+1, shift the entire sequence back 1
        for (int i = best_idx+1; i < (*n_tokens-1); i++) {
            tokens[i] = tokens[i+1];
        }
        (*n_tokens)--; // token length decreased
    }

    // add optional EOS (=2) token, if desired
    if (eos) tokens[(*n_tokens)++] = 2;

    free(str_buffer);
}
Once a mergeable token pair is found, the pair is merged and the token list is updated in place.
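To tie it together, a usage sketch mirroring the call at the top of this section (the buffer sizing is an assumption: one slot per input byte plus three spare slots for BOS, the dummy prefix, and an optional EOS):

char *prompt = "hello";
int num_prompt_tokens = 0;
int *prompt_tokens = malloc((strlen(prompt) + 3) * sizeof(int)); // assumed upper bound
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
// per the gdb session above this starts as {1, 29871, 29882, 29872, 29880, 29880, 29877}
// and then shrinks as BPE merges are applied
free(prompt_tokens);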