GLM (General Language Model) Code Analysis

This post mainly walks through the pretraining pipeline (pretrain_glm.py).
First, the argument setup:

# Arguments.
    args = get_args()
    args.mem_length = args.mem_length if args.transformer_xl else 0
    if args.load and not args.new_save_directory:
        args.experiment_name = os.path.basename(os.path.normpath(args.load))
    else:
        args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
    # PyTorch distributed.
    # Initialize the distributed environment.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

Next, the tokenizer is prepared. Three options are provided: BERT word pieces, GPT-2 BPE, and Chinese SentencePiece, corresponding to BertWordPieceTokenizer, GPT2BPETokenizer, and ChineseSPTokenizer respectively.
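As a rough sketch of the selection logic, picking a tokenizer boils down to a dispatch on its type. Note that only the three class names come from the repo; the import path and constructor arguments below are assumptions and may differ from the actual code.

# Hypothetical sketch of the tokenizer dispatch -- the import path and the
# constructor arguments are assumptions; only the class names appear in the repo.
from data_utils.tokenization import (BertWordPieceTokenizer, GPT2BPETokenizer,
                                     ChineseSPTokenizer)

TOKENIZER_CLASSES = {
    'BertWordPieceTokenizer': BertWordPieceTokenizer,
    'GPT2BPETokenizer': GPT2BPETokenizer,
    'ChineseSPTokenizer': ChineseSPTokenizer,
}

def build_tokenizer(tokenizer_type, **kwargs):
    if tokenizer_type not in TOKENIZER_CLASSES:
        raise NotImplementedError(f'tokenizer {tokenizer_type} is not supported')
    return TOKENIZER_CLASSES[tokenizer_type](**kwargs)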

Then the training, validation, and test sets are obtained via get_train_val_test_data().

def get_train_val_test_data(args, tokenizer):
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""

    (train_data, val_data, test_data) = (None, None, None)
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        if args.block_lm:
            data_set_type = "Block"
        elif args.transformer_xl:
            data_set_type = "GPT-XL"
        else:
            data_set_type = "GPT2"
        data_config.set_defaults(data_set_type=data_set_type, transpose=False)
        train_data, val_data, test_data = data_config.apply(args, tokenizer)

        data_counts = torch.cuda.LongTensor([int(args.do_train), int(args.do_valid), int(args.do_test)])
    else:
        data_counts = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train/do_valid/do_test flags.
    torch.distributed.broadcast(data_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = data_counts[0].item()
    args.do_valid = data_counts[1].item()
    args.do_test = data_counts[2].item()

    return train_data, val_data, test_data
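Only model-parallel rank 0 actually builds the data loaders; the other ranks learn whether train/valid/test data exist from the broadcast of the three flags. The toy script below reproduces that pattern on CPU with the gloo backend (the real code broadcasts a CUDA tensor inside each model-parallel group):

# Self-contained toy version of the flag broadcast above: rank 0 decides the
# do_train / do_valid / do_test flags and every other rank receives them in place.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    if rank == 0:
        flags = torch.tensor([1, 1, 0], dtype=torch.long)  # do_train, do_valid, do_test
    else:
        flags = torch.zeros(3, dtype=torch.long)            # placeholder, overwritten below
    dist.broadcast(flags, src=0)                             # in-place on non-source ranks
    print(f'rank {rank}: do_train={flags[0].item()} do_valid={flags[1].item()} do_test={flags[2].item()}')
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(run, args=(2,), nprocs=2)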

Dataset construction here relies mainly on data_config.apply(), which internally calls the make_loaders() function:

def make_loaders(args, tokenizer):
    """makes training/val/test"""

    if args.use_tfrecords:
        return make_tfrecord_loaders(args)
    world_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
    if args.loader_scatter is not None:
        assert world_size % args.loader_scatter == 0
    # Batch size and sequence length settings
    batch_size = args.batch_size * world_size
    eval_batch_size = batch_size
    if args.eval_batch_size is not None:
        eval_batch_size = args.eval_batch_size * world_size
    seq_length = args.seq_length
    if seq_length < 0:
        seq_length = seq_length * world_size
    eval_seq_length = args.eval_seq_length
    if eval_seq_length is not None and eval_seq_length < 0:
        eval_seq_length = eval_seq_length * world_size
    split = get_split(args)
    # Dataset argument settings
    data_set_args = {
        'path': args.train_data,
        'seq_length': seq_length,
        'mem_length': args.mem_length,
        'delim': args.delim,
        'text_key': args.text_key,
        'label_key': 'label',
        'ds_type': args.data_set_type,
        'split': split,
        'loose': args.loose_json,
        'max_preds_per_seq': args.max_preds_per_seq,
        'presplit_sentences': args.presplit_sentences,
        'sample_one_document': args.sample_one_document,
        'filter_english': args.filter_english,
        'pre_tokenize': not args.no_pre_tokenize,
        'tokenizer': tokenizer,
        'save_splits': args.save_splits,
        'load_splits': args.load_splits,
        'save_test_data': args.save_test_data,
        'no_lazy_loader': args.no_lazy_loader,
        'loader_scatter': args.loader_scatter,
        'data_parallel_rank': mpu.get_data_parallel_rank(),
        "non_sentence_start": args.non_sentence_start,
        "half_lazy_loader": args.half_lazy_loader
    }

    eval_set_args = copy.copy(data_set_args)
    eval_set_args['split'] = [1.]
    # if optional eval args were set then replace their
    # equivalent values in the arg dict
    if eval_seq_length:
        eval_set_args['seq_length'] = eval_seq_length
    if args.eval_max_preds_per_seq:
        eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
    if args.eval_text_key is not None:
        eval_set_args['text_key'] = args.eval_text_key

    # make datasets splits and tokenizer
    train, valid, test = None, None, None

    if args.train_data is not None:
        # Build the training dataset
        train = data_utils.make_dataset(**data_set_args)
        if data_utils.should_split(split):
            train, valid, test = train
        eval_set_args['tokenizer'] = tokenizer

    # make validation and test datasets if necessary
    if valid is None and args.valid_data is not None:
        eval_set_args['path'] = args.valid_data
        valid = data_utils.make_dataset(**eval_set_args)
        eval_set_args['tokenizer'] = tokenizer
    if test is None and args.test_data is not None:
        eval_set_args['path'] = args.test_data
        test = data_utils.make_dataset(**eval_set_args)

    # wrap datasets with data loader
    use_block = args.block_lm or args.encoder_decoder

    if train is not None and args.batch_size > 0:
        train = make_data_loader(train, tokenizer, batch_size, args.train_iters, args, shuffle=args.shuffle,
                                 block_collate=use_block)
        args.do_train = True
    else:
        args.do_train = False
    eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
    if valid is not None:
        valid = make_data_loader(valid, tokenizer, eval_batch_size, args.train_iters, args, shuffle=args.shuffle,
                                 block_collate=use_block)
        args.do_valid = True
    else:
        args.do_valid = False
    if test is not None:
        test = make_data_loader(test, tokenizer, eval_batch_size, len(test) // eval_batch_size + 1, args,
                                shuffle=args.shuffle, block_collate=use_block)
        args.do_test = True
    else:
        args.do_test = False

    return train, valid, test
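Note that batch_size here is the global batch size (args.batch_size times the data-parallel world size); make_data_loader then wraps the dataset so that each data-parallel rank only sees its own slice of every global batch. The following is not the repo's make_data_loader, just a generic sketch of the same idea using PyTorch's DistributedSampler:

# Generic sketch of slicing a global batch across data-parallel ranks --
# not the repo's make_data_loader, which uses its own distributed batch sampler.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def make_simple_loader(dataset, global_batch_size, rank, world_size, shuffle=True):
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
    per_rank_batch = global_batch_size // world_size
    return DataLoader(dataset, batch_size=per_rank_batch, sampler=sampler, drop_last=True)


dataset = TensorDataset(torch.arange(1000))
loader = make_simple_loader(dataset, global_batch_size=64, rank=0, world_size=4)
print(len(loader), next(iter(loader))[0].shape)   # 15 batches of 16 samples on this rank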

Inside make_loaders(), data_utils.make_dataset() is called, which in turn calls get_dataset(); this is where the MASK and span design comes into play, with the main idea illustrated in the figure below.
[Figure: illustration of the MASK and span design]
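The blank-infilling idea itself can be summarized with a toy example: sample spans from the text, collapse each span into a single [MASK] in Part A, and append the spans as Part B, where they are generated autoregressively behind a start token and carry two-level position ids. The sketch below only illustrates that idea with a single span and a simplified sampling rule; the repo's actual construction lives in the "Block" dataset type and the block collate path seen above.

# Toy illustration of GLM blank infilling -- not the repo's implementation.
# Special-token strings and the span-sampling scheme here are simplified.
import random

MASK, SOP, EOP = '[MASK]', '<|startofpiece|>', '<|endofpiece|>'


def make_block_example(tokens, mask_ratio=0.15, rng=random):
    n = len(tokens)
    # Sample a single contiguous span for simplicity; GLM samples several,
    # with span lengths drawn from a Poisson distribution.
    span_len = max(1, int(n * mask_ratio))
    start = rng.randrange(0, n - span_len + 1)
    span = tokens[start:start + span_len]

    # Part A: source text with the span collapsed into a single [MASK].
    part_a = tokens[:start] + [MASK] + tokens[start + span_len:]
    mask_pos = start                       # index of [MASK] inside Part A

    # Part B: the span, generated autoregressively after a start token.
    part_b_input = [SOP] + span
    part_b_target = span + [EOP]

    # Two-level positions: every Part B token reuses the [MASK] position
    # (first level) and counts its offset inside the span (second level).
    position_ids = list(range(len(part_a))) + [mask_pos] * len(part_b_input)
    block_position_ids = [0] * len(part_a) + list(range(1, len(part_b_input) + 1))

    return {
        'input': part_a + part_b_input,
        'target': part_b_target,
        'position_ids': position_ids,
        'block_position_ids': block_position_ids,
        'sep': len(part_a),                # attention is bidirectional before sep
    }


example = make_block_example('the quick brown fox jumps over the lazy dog'.split(),
                             rng=random.Random(0))
for key, value in example.items():
    print(key, value)

With that picture in mind, get_dataset() itself is mainly about building and caching the corpus: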

def get_dataset(name, tokenizer, pre_tokenize, data_parallel_rank, loader_scatter=None, no_lazy_loader=False,
                half_lazy_loader=False):
    """gets dataset object based on keyword args and file at `path`"""
    global_rank = torch.distributed.get_rank()
    if not supported_corpus(name):
        raise NotImplementedError('dataset %s is not supported' % name)
    dataset = corpora.NAMED_CORPORA[name]
    path = dataset.PATH
    if issubclass(dataset, corpora.PromptReader):
        if not (exists_lazy(path, data_type='prompt') and exists_lazy(path, data_type='text')) and not (
                loader_scatter is not None and exists_scatter(path, data_type='prompt',
                                                              scatter_num=loader_scatter) and exists_scatter(path,
                                                                                                             data_type='text',
                                                                                                             scatter_num=loader_scatter)):
            # create cached version of dataset for lazy loading if it doesn't exist
            if global_rank == 0:
                print(f"Creating lazy loader for dataset {name}")
                prompt_writer = LazyWriter(path, data_type='prompt', is_array=pre_tokenize)
                text_writer = LazyWriter(path, data_type='text', is_array=pre_tokenize)
                writers = {'prompt': prompt_writer, 'text': text_writer}
                reader = dataset(writers=writers, tokenizer=tokenizer, tokenize=pre_tokenize)
                reader.process()
                prompt_writer.close()
                text_writer.close()
            else:
                while not os.path.exists(LazyWriter.get_len_path(path, data_type='prompt')):
                    time.sleep(1)
        map_fn = (lambda x: x.tolist()) if pre_tokenize else None
        if loader_scatter is not None:
            if not (exists_scatter(path, data_type='prompt', scatter_num=loader_scatter) and exists_scatter(path,
                                                                                                            data_type='text',
                                                                                                            scatter_num=loader_scatter)):
                if global_rank == 0:
                    print(f"Creating scatter loader for dataset {name}")
                    prompts = LazyLoader(path, data_type='prompt', map_fn=map_fn, mem_map=True,
                                         is_array=pre_tokenize)
                    texts = LazyLoader(path, data_type='text', map_fn=map_fn, mem_map=True,
                                       is_array=pre_tokenize)
                    indices = list(range(len(texts)))
                    random.shuffle(indices)
                    segment_length = (len(indices) - 1) // loader_scatter + 1
                    for i in range(loader_scatter):
                        scatter_path = get_scatter_path(path, scatter_rank=i)
                        prompt_writer = LazyWriter(scatter_path, data_type='prompt', is_array=pre_tokenize)
                        text_writer = LazyWriter(scatter_path, data_type='text', is_array=pre_tokenize)
                        for idx in indices[i * segment_length: (i + 1) * segment_length]:
                            prompt_writer.write(prompts[idx])
                            text_writer.write(texts[idx])
                        prompt_writer.close()
                        text_writer.close()
                else:
                    while not (
                            exists_scatter(path, data_type='prompt', scatter_num=loader_scatter) and exists_scatter(
                        path, data_type='text', scatter_num=loader_scatter)):
                        time.sleep(1)
            scatter_path = get_scatter_path(path, scatter_rank=data_parallel_rank % loader_scatter)
            print(f"Rank {global_rank} is using scatter from {scatter_path}")
            prompts = LazyLoader(scatter_path, data_type='prompt', map_fn=map_fn, mem_map=True,
                                 is_array=pre_tokenize, load_memory=no_lazy_loader, half_load=half_lazy_loader)
            texts = LazyLoader(scatter_path, data_type='text', map_fn=map_fn, mem_map=True,
                               is_array=pre_tokenize, load_memory=no_lazy_loader, half_load=half_lazy_loader)
        else:
            prompts = LazyLoader(path, data_type='prompt', map_fn=map_fn, mem_map=True,
                                 is_array=pre_tokenize, load_memory=no_lazy_loader, half_load=half_lazy_loader)
            texts = LazyLoader(path, data_type='text', map_fn=map_fn, mem_map=True,
                               is_array=pre_tokenize, load_memory=no_lazy_loader, half_load=half_lazy_loader)
        text = corpora.PromptDataset(prompt_loader=prompts, text_loader=texts, tokenizer=tokenizer,
                                     to_tokenize=not pre_tokenize)
        if loader_scatter is None:
            if global_rank == 0:
                print(f"Create dataset {name} with {len(text)} documents")
                for i in range(10):
                    rand_id = i if i < 5 else random.randrange(len(text))
                    sample_tokens = text[rand_id]['tokens'][:1024]
                    print(sample_tokens)
                    print(tokenizer.DecodeIds(sample_tokens).encode('utf-8'))
        else:
            for scatter_id in range(loader_scatter):
                if data_parallel_rank % loader_scatter == scatter_id and data_parallel_rank // loader_scatter == 0:
                    print(f"Create dataset {name} at scatter {scatter_id} with {len(text)} documents")
                    for i in range(10):
                        sample_tokens = text[i]['tokens'][:1024]
                        print(sample_tokens)
                        print(tokenizer.DecodeIds(sample_tokens))
                torch.distributed.barrier()
        return text
    elif issubclass(dataset, corpora.KeyReader):
        if not (exists_lazy(path, data_type='text') and exists_lazy(path, data_type='mask')):
            # create cached version of dataset for lazy loading if it doesn't exist
            if global_rank == 0:
                text_writer = LazyWriter(path, data_type='text', is_array=pre_tokenize)
                mask_writer = LazyWriter(path, data_type='mask', is_array=True)
                writers = {'mask': mask_writer, 'text': text_writer}
                dataset(writers=writers, tokenizer=tokenizer, tokenize=pre_tokenize)
                mask_writer.close()
                text_writer.close()
            else:
                while not os.path.exists(LazyWriter.get_len_path(path, data_type='mask')):
                    time.sleep(1)
        map_fn = (lambda x: x.tolist()) if pre_tokenize else None
        masks = LazyLoader(path, data_type='mask', map_fn=map_fn, mem_map=True, is_array=True)
        texts = LazyLoader(path, data_type='text', map_fn=map_fn, mem_map=True, is_array=pre_tokenize)
        text = corpora.KeyDataset(mask_loader=masks, text_loader=texts, tokenizer=tokenizer,
                                  to_tokenize=not pre_tokenize)
        return text
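The LazyWriter / LazyLoader pair caches the (optionally pre-tokenized) corpus on disk once, after which every process memory-maps it instead of holding the whole corpus in RAM; with loader_scatter, the cache is further split into loader_scatter shards and each data-parallel rank only maps the shard assigned to it (data_parallel_rank % loader_scatter). As a rough analogy of the lazy-loading idea (not the repo's file format), the same effect can be achieved with a flat token file plus an index of document offsets:

# Rough analogy for the LazyWriter / LazyLoader pair (not the repo's format):
# write all token ids into one flat binary file plus an offsets index, then
# memory-map the flat file and slice documents out on demand.
import numpy as np


def write_lazy(path, documents):
    offsets, flat = [0], []
    for doc in documents:                     # doc: list[int] of token ids
        flat.extend(doc)
        offsets.append(len(flat))
    np.array(flat, dtype=np.int32).tofile(path + '.bin')
    np.save(path + '.idx.npy', np.array(offsets, dtype=np.int64))


class LazyDocs:
    def __init__(self, path):
        self.offsets = np.load(path + '.idx.npy')
        self.data = np.memmap(path + '.bin', dtype=np.int32, mode='r')

    def __len__(self):
        return len(self.offsets) - 1

    def __getitem__(self, i):
        return self.data[self.offsets[i]:self.offsets[i + 1]]


write_lazy('corpus', [[1, 2, 3], [4, 5], [6, 7, 8, 9]])
docs = LazyDocs('corpus')
print(len(docs), docs[1].tolist())            # -> 3 [4, 5]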

With that, the datasets are built. If multi-task training is enabled, the multi-task datasets are constructed as well:

# Build the multi-task datasets
    if args.multi_task_ratio > 0.0:
        multi_train_data, multi_val_data = build_multi_task_dataset(args, tokenizer)

Next, the code checks whether a pretrained checkpoint needs to be loaded, and sets up log and output saving:

# Whether to load a pretrained checkpoint
    if args.load is not None:
        with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"), timeout=-1):
            args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
    else:
        args.iteration = 0
    torch.distributed.barrier()
    if args.switch_linear:
        lr_scheduler.switch_linear(args)

    summary_writer = None
    if torch.distributed.get_rank() == 0:
        print('Pretrain GPT2 model')
        args.log_dir = None
        # Save logs and outputs
        if args.train_iters > 0:
            args.log_dir = get_log_dir(base=args.summary_dir, name=args.experiment_name)
            summary_writer = get_sample_writer(log_dir=args.log_dir, iteration=args.iteration)
        print_and_save_args(args, verbose=True, log_dir=args.log_dir)
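The FileLock around load_checkpoint (timeout=-1 means wait indefinitely) presumably serializes checkpoint loading across processes sharing the same home directory. Stripped of the repo's checkpoint layout and DeepSpeed integration, a load/save pair boils down to plain state_dicts plus the iteration counter, roughly like this:

# Generic sketch of what a save_checkpoint/load_checkpoint pair boils down to
# (plain PyTorch state_dicts plus the iteration counter); the repo's actual
# checkpoint layout and DeepSpeed integration differ.
import torch


def save_ckpt(path, iteration, model, optimizer, lr_scheduler):
    torch.save({
        'iteration': iteration,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
    }, path)


def load_ckpt(path, model, optimizer, lr_scheduler):
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    lr_scheduler.load_state_dict(state['lr_scheduler'])
    return state['iteration']                 # becomes args.iteration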

If the loaded model has already been trained for some iterations, the data loaders are resumed for continued training; the main difference from the dataset construction above is the handling of the starting iteration (a toy numeric example follows the code):

# Resume data loader if necessary.
    if args.resume_dataloader:
        print_rank_0("Resume dataloader")
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % len(train_data)
        if val_data is not None:
            start_iter_val = (args.iteration // args.eval_interval) * args.eval_iters
            val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
        if multi_train_data is not None:
            multi_train_data.batch_sampler.start_iter = int(args.iteration * args.multi_task_ratio) % len(
                multi_train_data)
        if multi_val_data is not None:
            start_iter_val = (args.iteration // args.eval_interval) * args.eval_iters * args.multi_task_ratio
            multi_val_data.batch_sampler.start_iter = start_iter_val % len(multi_val_data)
    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if multi_train_data is not None:
        multi_train_iterator = iter(multi_train_data)
    else:
        multi_train_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None
    if multi_val_data is not None:
        multi_val_iterator = iter(multi_val_data)
    else:
        multi_val_iterator = None
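To make the resume offsets concrete, here is the arithmetic above with toy numbers:

# Toy numbers for the resume arithmetic: after 1,000 completed iterations with
# a 300-batch training loader, training resumes 100 batches into the current
# pass; the validation loader is advanced by the number of evaluation batches
# already consumed (eval_iters per eval_interval training iterations).
iteration, len_train = 1000, 300
eval_interval, eval_iters, len_val = 500, 10, 80

train_start = iteration % len_train                                  # 100
val_start = (iteration // eval_interval) * eval_iters % len_val      # 20
print(train_start, val_start)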

Training then runs through the train() function:

iteration, skipped = train(model, optimizer, lr_scheduler,
                           (train_data_iterator, multi_train_iterator),
                           (val_data_iterator, multi_val_iterator),
                           timers, args, summary_writer=summary_writer)
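Stripped of timers, gradient accumulation, fp16 loss scaling and DeepSpeed, each train_step() inside this loop roughly amounts to one forward pass, one backward pass and one optimizer/scheduler step:

# Sketch only -- the real train_step adds timers, gradient accumulation,
# fp16 loss scaling, gradient clipping and DeepSpeed integration.
def train_step_sketch(compute_loss, model, optimizer, lr_scheduler):
    optimizer.zero_grad()
    loss = compute_loss(model)      # e.g. a closure around forward_step below
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    return loss.item()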

The core of each train_step() is the forward_step() function, which draws a batch (optionally from the multi-task iterator with probability multi_task_ratio) and computes the loss:

def forward_step(data_iterator, model, args, timers, mems):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    timers('data loader').start()
    rand = random.Random(args.iteration * mpu.get_data_parallel_world_size() + mpu.get_data_parallel_rank())
    if data_iterator[1] and rand.random() < args.multi_task_ratio:
        data = next(data_iterator[1]) if data_iterator[1] else None
        data["mode"] = "multi-task"
    else:
        data = next(data_iterator[0]) if data_iterator[0] else None
    # print_rank_0("data iterator")
    timers('data loader').stop()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data, args)
    timers('batch generator').stop()

    # print_rank_0("get batch")

    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        sep = attention_mask.item() if torch.numel(attention_mask) == 1 else attention_mask[batch_id].item()
        text, last_segment = "", []
        for i, token_id in enumerate(tokens[batch_id, :sep].tolist()):
            token = tokenizer.IdToToken(token_id)
            if token.startswith('[MASK') or token.endswith('MASK]'):
                if last_segment:
                    text += tokenizer.DecodeIds(last_segment)
                    last_segment = []
                text += f" [{position_ids_[batch_id, i].item()}, {token}]"
            else:
                last_segment.append(token_id)
        if last_segment:
            text += tokenizer.DecodeIds(last_segment)
        print(text.encode('utf-8'))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if tokenizer.IdToToken(tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(tokenizer.DecodeIds(tokens[batch_id, last_index: i].tolist()).encode('utf-8'), "|",
                          tokenizer.DecodeIds(labels[batch_id, last_index: i].tolist()).encode('utf-8'),
                          position_ids_[batch_id, last_index: i].tolist(),
                          block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(tokenizer.DecodeIds(tokens[batch_id, last_index:].tolist()).encode('utf-8'), "|",
                  tokenizer.DecodeIds(labels[batch_id, last_index:].tolist()).encode('utf-8'),
                  position_ids_[batch_id, last_index:].tolist(), block_position_ids[batch_id, last_index:].tolist())

    if data is not None and "mode" in data:
        mode = data['mode']
    else:
        mode = 'bert'

    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    # The loss is a (vocab-parallel) cross entropy
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask)
    if loss_mask.sum().item() > 0:
        loss = loss / loss_mask.sum()

    return loss, mems, mode
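mpu.vocab_parallel_cross_entropy is the model-parallel counterpart of an ordinary per-token cross entropy (the vocabulary dimension is split across model-parallel ranks); the masked averaging at the end is equivalent to the single-process version below:

# Single-process equivalent of the loss computation above: per-token cross
# entropy, zeroed where loss_mask == 0, averaged over the unmasked tokens.
import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 8, 100
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))
loss_mask = (torch.rand(batch, seq) > 0.5).float()   # 1 = token contributes to the loss

per_token = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction='none')
loss = (per_token * loss_mask.view(-1)).sum() / loss_mask.sum().clamp(min=1)
print(loss)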

Finally, inside the training loop, checkpoints are saved and evaluation is run at regular intervals:

 # Checkpointing
        if args.save and args.save_interval and args.iteration % args.save_interval == 0:
            save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)

        # Evaluation
        if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
            prefix = 'iteration {}'.format(args.iteration)
            evaluate_and_print_results(
                prefix, val_data_iterator, model, args, timers, verbose=False, step=args.iteration,
                summary_writer=summary_writer, forward_step_func=forward_step)