I needed to profile a model trained through the transformers Trainer API, but the Trainer is heavily encapsulated and awkward to modify directly. Instead, the PyTorch profiler can be hooked in cleanly through a callback:
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

class MyCallback(TrainerCallback):
    """A callback that advances the profiler after each training step."""

    def __init__(self, prof):
        self.prof = prof

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training")

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Tell the profiler one training step has finished so its
        # schedule (wait/warmup/active) can advance.
        self.prof.step()
Then wrap training in a profiler context and register the callback on the trainer instance:
import torch

with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(skip_first=3, wait=1, warmup=1, active=2, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('hf-training-trainer'),
        profile_memory=True,
        with_stack=True,
        record_shapes=True) as prof:
    trainer.add_callback(MyCallback(prof=prof))
    trainer.train()
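The schedule only records traces during its "active" windows: with skip_first=3, wait=1, warmup=1, active=2, repeat=2, steps 0-2 are skipped, then two wait → warmup → active cycles run, so traces are captured at steps 5-6 and 9-10 and training must run at least 11 steps to finish both cycles. A pure-Python sketch of that phase arithmetic (illustrative only, not the real torch.profiler.schedule implementation):

```python
def schedule_phase(step, skip_first=3, wait=1, warmup=1, active=2, repeat=2):
    """Return which phase a given step falls in, mimicking the schedule above."""
    if step < skip_first:
        return "skip"
    step -= skip_first
    cycle = wait + warmup + active
    # After `repeat` full cycles, the profiler stops recording.
    if repeat and step >= repeat * cycle:
        return "skip"
    pos = step % cycle
    if pos < wait:
        return "wait"
    if pos < wait + warmup:
        return "warmup"
    return "active"

# Walk the first 12 steps to see where traces would be recorded.
for step in range(12):
    print(step, schedule_phase(step))
```

Only the "active" steps contribute to the trace files that `tensorboard_trace_handler` writes out.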
References
Is there a pytorch profiler integration with huggingface trainer?