import logging
import os
import torch.distributed as dist
from torch.distributed import init_process_group


def setup_logger(log_file):
    """Configure the root logger to write INFO+ records to *log_file*
    and WARNING+ records to the console, then return it.

    Safe to call more than once in the same process: previously
    installed handlers are removed first, so log lines are not
    duplicated and old file handles are not leaked.

    Args:
        log_file: Path of the log file to append to.

    Returns:
        The configured root ``logging.Logger``.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Detach and close handlers from any earlier call; otherwise each
    # call would stack another (file, console) pair onto the root
    # logger and every record would be emitted multiple times.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # File handler: captures everything at INFO and above.
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)

    # Console handler: quieter, only WARNING and above.
    ch = logging.StreamHandler()
    ch.setLevel(logging.WARNING)

    # One shared formatter for both destinations.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)

    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger

if __name__ == "__main__":
    # Example log file: logs/log_rank_0.log, logs/log_rank_1.log, etc.
    # python -m torch.distributed.launch --nproc_per_node=4 test_data.py
    init_process_group(backend="nccl")
    rank = dist.get_rank()
    os.makedirs("logs", exist_ok=True)
    log_file = os.path.join("logs", f"log_rank_{rank}.log")
    logger = setup_logger(log_file)
    logger.info(f"Logging from rank {rank}")
Then run from the command line:

python -m torch.distributed.launch --nproc_per_node=4 test_data.py
Result: PyTorch writes each rank's log output to its own file
(logs/log_rank_0.log, logs/log_rank_1.log, ...).

The content of one rank's log file looks like:

2024-08-17 10:16:37,947 - root - INFO - Logging from rank 1