Shandong University School of Software Innovation Training: 角色疆界 - An Intelligent Movie Character Role-Playing Dialogue LLM (Part 4)

Introduction

The task for this phase is writing and debugging the training code.

Distributed Data Loading

To train on multiple GPUs in parallel, we wrote a CustomTrainer class that rebuilds both dataloaders around a DistributedSampler:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import Trainer

class CustomTrainer(Trainer):
    def get_train_dataloader(self) -> DataLoader:
        dataset = self.train_dataset
        # DistributedSampler shards the dataset across ranks; it shuffles
        # by default, which is what we want for training.
        sampler = DistributedSampler(dataset)
        return DataLoader(
            dataset,
            batch_size=self._train_batch_size,
            sampler=sampler,
            num_workers=self.args.dataloader_num_workers,
            drop_last=True,
            pin_memory=True,
            collate_fn=batchify,  # project-specific collate function
        )

    def get_eval_dataloader(self, eval_dataset=None) -> DataLoader:
        # Honor an explicitly passed eval dataset, falling back to the one
        # the Trainer was constructed with.
        dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        sampler = DistributedSampler(dataset, shuffle=False)  # deterministic order for eval
        return DataLoader(
            dataset,
            batch_size=self._train_batch_size,
            sampler=sampler,
            num_workers=self.args.dataloader_num_workers,
            drop_last=True,
            pin_memory=True,
            collate_fn=batchify,
        )
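A quick standalone check of what DistributedSampler does with a dataset: each rank receives a disjoint, equally sized shard. Here num_replicas and rank are passed explicitly so no process group is needed; during real training the sampler infers both from the initialized process group.

from torch.utils.data.distributed import DistributedSampler

data = list(range(10))
for rank in range(2):
    # shuffle=False so the round-robin sharding is easy to see.
    sampler = DistributedSampler(data, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))
# 0 [0, 2, 4, 6, 8]
# 1 [1, 3, 5, 7, 9]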

Video Keyframe Processing

We extract features from the video keyframes and turn them into tokens, each of which carries the high-level semantics of its keyframe. These tokens are then concatenated with the text tokens and fed into the large language model for training.

if video_pixel_values is not None:
    # Encode the sampled keyframes with the vision tower; the output has one
    # embedding per patch per frame: (batch, frames, patches, hidden).
    video_embeds = self.vision_model(video_pixel_values, return_dict=True).last_hidden_state
    # All visual positions are valid, so the mask is all ones, flattened so
    # the abstractor attends over every frame-patch token as one sequence.
    video_attention_mask = torch.ones(video_embeds.size()[:-1], dtype=torch.long, device=video_embeds.device)
    video_attention_mask = einops.rearrange(
        video_attention_mask, 'b t n -> b (t n)'
    )
    # Learnable query tokens (spatial + temporal) are broadcast to the batch
    # and cross-attend to the frame features inside the abstractor, yielding
    # a fixed number of tokens per video.
    query_tokens = self.query_tokens.expand(video_embeds.shape[0], -1, -1)
    temporal_query_tokens = self.temporal_query_tokens.expand(video_embeds.shape[0], -1, -1)

    video_query_features = self.abstractor(
        query_embeds=query_tokens,
        temporal_query_embeds=temporal_query_tokens,
        encoder_hidden_states=video_embeds,
        encoder_attention_mask=video_attention_mask,
    )["last_hidden_state"]
    vid_seq_length = video_query_features.shape[1]
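The rearrange call above flattens the per-frame patch tokens into a single sequence per video. A quick check with dummy shapes (8 frames and 257 patch tokens are just illustrative values):

import torch
import einops

mask = torch.ones(2, 8, 257, dtype=torch.long)    # (batch, frames, patches_per_frame)
flat = einops.rearrange(mask, 'b t n -> b (t n)')
print(flat.shape)                                 # torch.Size([2, 2056])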



With the per-media query features in hand, they are spliced into the text embedding sequence: every media placeholder in the tokenized text reserves img_seq_length positions, and those positions are replaced by the corresponding query features.

# Per-sample media counts, moved to host memory for the Python loop below.
num_images_per_sample = num_images.long().cpu().tolist()
num_videos_per_sample = num_videos.long().cpu().tolist()

text_chunk_embeds = []
img_idx = 0
for b in range(batch_size):
    start = 0
    result = []
    if len(media_token_indices[b]) > 0:
        for i, pos in enumerate(media_token_indices[b]):
            if pos > start:
                # Text span before this media placeholder.
                result.append(text_embeds[b, start:pos])
            # Visual query tokens for the i-th media item of this sample.
            result.append(query_features[img_idx + i])
            # Skip over the img_seq_length slots the placeholder reserved.
            start = pos + img_seq_length
    if start < text_embeds.shape[1]:
        # Trailing text after the last media placeholder.
        result.append(text_embeds[b, start:])

    img_idx += num_images_per_sample[b]
    text_chunk_embeds.append(torch.cat(result, dim=0))

# (batch, seq, hidden) multimodal inputs fed to the language model.
input_embeds = torch.stack(text_chunk_embeds, dim=0)
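The index bookkeeping in this loop is easy to get wrong, so here is a toy run of the same logic for a single sample: text length 6, one media placeholder at position 2, img_seq_length = 3 (all values illustrative):

import torch

text_embeds = torch.arange(6).float().view(1, 6, 1)   # (batch, seq, dim)
query_features = [torch.full((3, 1), -1.0)]           # one media span of 3 tokens
media_token_indices = [[2]]
img_seq_length = 3

result, start = [], 0
for pos in media_token_indices[0]:
    if pos > start:
        result.append(text_embeds[0, start:pos])
    result.append(query_features[0])
    start = pos + img_seq_length
if start < text_embeds.shape[1]:
    result.append(text_embeds[0, start:])

spliced = torch.cat(result, dim=0)
print(spliced.squeeze(-1))   # tensor([ 0.,  1., -1., -1., -1.,  5.])

The visual tokens land exactly in the three slots the placeholder reserved, so the overall sequence length is unchanged.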

Training Pipeline

Custom Trainer Class

First, we need a custom trainer class so that the training and evaluation datasets are sampled and loaded correctly in a distributed environment. This is the CustomTrainer shown in the Distributed Data Loading section above, so its code is not repeated here; both of its dataloaders hand batching off to the collate function batchify.

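batchify itself is project-specific and its real implementation is not part of this excerpt. As a rough, hypothetical sketch (the field names and zero padding are assumptions), such a collate function pads the variable-length token sequences and stacks the fixed-shape media tensors:

import torch
import torch.nn.functional as F

def batchify(samples):
    # Hypothetical sketch; the project's real batchify is not shown in this post.
    # Assumes each sample is a dict with a 1-D "input_ids" tensor plus optional
    # fixed-shape media tensors such as "video_pixel_values".
    input_ids = [s["input_ids"] for s in samples]
    max_len = max(x.size(0) for x in input_ids)
    batch = {
        # Right-pad every sequence to the longest one (pad id 0 assumed).
        "input_ids": torch.stack([F.pad(x, (0, max_len - x.size(0))) for x in input_ids]),
        # 1 over real tokens, 0 over padding.
        "attention_mask": torch.stack(
            [F.pad(torch.ones_like(x), (0, max_len - x.size(0))) for x in input_ids]
        ),
    }
    if "video_pixel_values" in samples[0]:
        # Media tensors share a fixed shape, so they stack directly.
        batch["video_pixel_values"] = torch.stack([s["video_pixel_values"] for s in samples])
    return batch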

Main Function

In the main function we parse the command-line arguments, load the config, and initialize the model and tokenizer. We then build the training and evaluation datasets from the config and arguments, and finally create and launch the trainer.


import sys
import torch
from transformers import AutoTokenizer, TrainingArguments

# Project-local pieces assumed to be in scope: parser, Config, set_args,
# MplugOwlForConditionalGeneration, train_valid_test_datasets_provider,
# batchify, and the CustomTrainer defined above.

def main():
    args, left_argv = parser.parse_known_args()
    config = Config(args.mm_config)

    set_args(args)

    # Load the pretrained checkpoint in reduced precision (bf16 or fp16).
    model = MplugOwlForConditionalGeneration.from_pretrained(
        args.pretrained_ckpt,
        torch_dtype=torch.bfloat16 if args.bf16 else torch.half,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_ckpt)

    if args.gradient_checkpointing:
        # Gradient checkpointing recomputes activations during backward; the
        # input embeddings must require grad so the checkpointed blocks keep
        # a grad path back to the inputs.
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.language_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
        model.gradient_checkpointing_enable()
    else:
        # Train only the language model; freeze the vision tower and abstractor.
        for name, param in model.named_parameters():
            param.requires_grad = 'language_model' in name

    model.train()

    train_data, valid_data = train_valid_test_datasets_provider(
        config.data_files, config=config, 
        tokenizer=tokenizer, seq_length=args.seq_length
    )

    # Compile the model (PyTorch 2.x) before constructing the Trainer, so the
    # compiled module is the one that actually gets trained.
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer = CustomTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=TrainingArguments(
            learning_rate=args.lr,
            warmup_steps=args.num_warmup_steps,
            do_train=args.do_train,
            num_train_epochs=args.train_epochs,
            output_dir=args.save_path,
            save_strategy='steps',
            save_steps=args.save_interval,
            evaluation_strategy='steps',
            eval_steps=args.eval_iters,
            per_device_train_batch_size=args.micro_batch_size,
            max_grad_norm=args.clip_grad,
            weight_decay=args.weight_decay,
            bf16=args.bf16,
            fp16=not args.bf16,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            gradient_checkpointing=args.gradient_checkpointing,
            logging_steps=args.eval_iters//4,
            logging_nan_inf_filter=args.logging_nan_inf_filter,
            ddp_find_unused_parameters=args.ddp_find_unused_parameters,
        ),
    )

    trainer.train()

    model.save_pretrained(args.save_path)
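In a multi-GPU run this script would typically be launched with torchrun, e.g. torchrun --nproc_per_node=8 train.py --mm_config ... --pretrained_ckpt ... (the entry-point name train.py is an assumption). torchrun sets the RANK, WORLD_SIZE, and LOCAL_RANK environment variables, TrainingArguments initializes the process group from them, and the DistributedSampler inside CustomTrainer then shards each epoch's data across the ranks.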