Introduction
This phase covers writing and debugging the training code.
Distributed Data Loading
To train in parallel across multiple GPUs, we rewrote the data-loading logic in a custom CustomTrainer class, which builds the train and eval DataLoaders on top of DistributedSampler, as follows:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import Trainer


class CustomTrainer(Trainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_train_dataloader(self) -> DataLoader:
        # Shard the training set across ranks; DistributedSampler shuffles
        # by default, giving each rank a distinct slice per epoch.
        dataset = self.train_dataset
        sampler = DistributedSampler(dataset)
        return DataLoader(
            dataset,
            batch_size=self._train_batch_size,
            sampler=sampler,
            num_workers=self.args.dataloader_num_workers,
            drop_last=True,
            pin_memory=True,
            collate_fn=batchify,  # project-specific collate function
        )

    def get_eval_dataloader(self, eval_dataset=None) -> DataLoader:
        # Prefer an explicitly passed dataset; fall back to self.eval_dataset.
        # Evaluation is sharded across ranks but not shuffled.
        dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        sampler = DistributedSampler(dataset, shuffle=False)
        return DataLoader(
            dataset,
            batch_size=self._train_batch_size,
            sampler=sampler,
            num_workers=self.args.dataloader_num_workers,
            drop_last=True,
            pin_memory=True,
            collate_fn=batchify,  # project-specific collate function
        )
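As a sanity check on how DistributedSampler partitions the data, the following self-contained sketch (a hypothetical toy dataset, run in one process by passing num_replicas and rank explicitly) shows each rank receiving a disjoint shard; set_epoch reseeds the shuffle and should be called with the epoch number once per epoch, which the Trainer loop handles for distributed samplers.

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 10 items, sharded across 2 simulated ranks.
dataset = TensorDataset(torch.arange(10))
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=True)
    sampler.set_epoch(0)  # reseed shuffling for this epoch
    print(f"rank {rank} indices: {list(sampler)}")  # disjoint index shards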
Video Keyframe Processing
We extract features from the video's keyframes and compress them into tokens, each of which encodes high-level semantic information about its keyframe. These tokens are then concatenated with the text tokens and fed into the large language model for training.
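Before the model code below runs, the keyframes themselves must be sampled from the raw video. Here is a minimal sketch of that preprocessing step, assuming uniform temporal sampling with OpenCV (the actual selection strategy is not shown in this section; sample_keyframes and num_frames are hypothetical names):

import cv2
import numpy as np

def sample_keyframes(video_path: str, num_frames: int = 8) -> np.ndarray:
    """Uniformly sample num_frames RGB frames from a video file."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    indices = np.linspace(0, max(total - 1, 0), num_frames).astype(int)
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame = cap.read()
        if ok:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # BGR -> RGB
    cap.release()
    return np.stack(frames)  # (t, h, w, 3), ready for the image processor

The sampled frames are then preprocessed into video_pixel_values and encoded as follows: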
if video_pixel_values is not None:
    # Encode the sampled keyframes with the vision backbone;
    # video_embeds has shape (batch, frames, patches, hidden).
    video_embeds = self.vision_model(
        video_pixel_values, return_dict=True
    ).last_hidden_state
    video_attention_mask = torch.ones(
        video_embeds.size()[:-1], dtype=torch.long, device=video_embeds.device
    )
    # Flatten the frame and patch axes so the abstractor attends over all
    # keyframe patches at once: (b, t, n) -> (b, t*n).
    video_attention_mask = einops.rearrange(
        video_attention_mask, 'b t n -> b (t n)'
    )
    # Learnable query tokens compress the visual features into a fixed
    # number of tokens carrying high-level semantics per keyframe.
    query_tokens = self.query_tokens.expand(video_embeds.shape[0], -1, -1)
    temporal_query_tokens = self.temporal_query_tokens.expand(
        video_embeds.shape[0], -1, -1
    )
    video_query_features = self.abstractor(
        query_embeds=query_tokens,
        temporal_query_embeds=temporal_query_tokens,
        encoder_hidden_states=video_embeds,
        encoder_attention_mask=video_attention_mask,
    )["last_hidden_state"]
    vid_seq_length = video_query_features.shape[1]

# Splice the visual query tokens into the text embedding sequence at the
# media placeholder positions (query_features and img_seq_length come from
# the analogous image branch, which is not shown in this excerpt).
num_images_per_sample = num_images.long().cpu().tolist()
num_videos_per_sample = num_videos.long().cpu().tolist()
text_chunk_embeds = []
img_idx = 0
for b in range(batch_size):
    start = 0
    result = []
    if len(media_token_indices[b]) > 0:
        for i, pos in enumerate(media_token_indices[b]):
            if pos > start:
                result.append(text_embeds[b, start:pos])  # text before media
            result.append(query_features[img_idx + i])    # media query tokens
            start = pos + img_seq_length
    if start < text_embeds.shape[1]:
        result.append(text_embeds[b, start:])             # trailing text
    img_idx += num_images_per_sample[b]
    text_chunk_embeds.append(torch.cat(result, dim=0))
input_embeds = torch.stack(text_chunk_embeds, dim=0)
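To make the splicing loop concrete, here is a toy, self-contained illustration (hypothetical shapes and values, one sample, img_seq_length of 1) of how text embeddings and media query features interleave:

import torch

# 6 text positions (dim 1), media placeholders at positions 1 and 4.
text_embeds = torch.arange(6).float().unsqueeze(-1)  # (6, 1)
query_features = [torch.full((1, 1), 100.0), torch.full((1, 1), 200.0)]
media_token_indices = [1, 4]
img_seq_length = 1

result, start = [], 0
for i, pos in enumerate(media_token_indices):
    if pos > start:
        result.append(text_embeds[start:pos])  # plain text chunk
    result.append(query_features[i])           # media tokens replace placeholder
    start = pos + img_seq_length
if start < text_embeds.shape[0]:
    result.append(text_embeds[start:])         # trailing text
spliced = torch.cat(result, dim=0)
print(spliced.squeeze(-1))  # tensor([  0., 100.,   2.,   3., 200.,   5.])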
Training Pipeline
Custom Trainer Class
First, we need a custom trainer class so that data loading goes through distributed samplers, ensuring the training and evaluation datasets are sampled and loaded correctly in a distributed setting. We use the CustomTrainer class defined in the Distributed Data Loading section above.
Main Function
In the main function, we parse the command-line arguments, load the configuration, and initialize the model and tokenizer. We then build the training and evaluation datasets from the configuration and arguments, and finally create and launch the trainer.
import sys
from functools import partial

import torch
from transformers import AutoTokenizer, TrainingArguments

# parser, Config, set_args, MplugOwlForConditionalGeneration and
# train_valid_test_datasets_provider are project-level definitions.

def main():
    args, left_argv = parser.parse_known_args()
    config = Config(args.mm_config)
    set_args(args)

    model = MplugOwlForConditionalGeneration.from_pretrained(
        args.pretrained_ckpt,
        torch_dtype=torch.bfloat16 if args.bf16 else torch.half,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_ckpt)

    if args.gradient_checkpointing:
        # With gradient checkpointing, the input embeddings must emit
        # tensors that require grad, or no gradients flow back through
        # the checkpointed layers.
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.language_model.get_input_embeddings().register_forward_hook(
            make_inputs_require_grad)
        model.gradient_checkpointing_enable()

    # Train only the language model; freeze all other parameters.
    for name, param in model.named_parameters():
        param.requires_grad = 'language_model' in name

    if args.gradient_checkpointing:
        model.language_model.apply(
            partial(model.language_model._set_gradient_checkpointing, value=True))

    model.train()

    train_data, valid_data = train_valid_test_datasets_provider(
        config.data_files, config=config,
        tokenizer=tokenizer, seq_length=args.seq_length
    )

    # Compile before handing the model to the trainer so the compiled
    # forward is the one actually trained (PyTorch 2+, not on Windows).
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer = CustomTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=TrainingArguments(
            learning_rate=args.lr,
            warmup_steps=args.num_warmup_steps,
            do_train=args.do_train,
            num_train_epochs=args.train_epochs,
            output_dir=args.save_path,
            save_strategy='steps',
            save_steps=args.save_interval,
            evaluation_strategy='steps',
            eval_steps=args.eval_iters,
            per_device_train_batch_size=args.micro_batch_size,
            max_grad_norm=args.clip_grad,
            weight_decay=args.weight_decay,
            bf16=args.bf16,
            fp16=not args.bf16,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            gradient_checkpointing=args.gradient_checkpointing,
            logging_steps=max(args.eval_iters // 4, 1),  # never 0
            logging_nan_inf_filter=args.logging_nan_inf_filter,
            ddp_find_unused_parameters=args.ddp_find_unused_parameters,
        ),
    )

    trainer.train()
    model.save_pretrained(args.save_path)
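Finally, a sketch of the entry point (the launcher invocation and script name are assumptions, not taken from the source): the script would be started with a distributed launcher such as torchrun, so that every process receives the rank and world size that DistributedSampler relies on.

# Hypothetical entry point; launch with e.g.
#   torchrun --nproc_per_node=<num_gpus> train.py --pretrained_ckpt <path> ...
if __name__ == "__main__":
    main()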