import logging
import os
import torch.distributed as dist
from torch.distributed import init_process_group
def setup_logger(log_file):
    """Configure the root logger to write INFO+ records to *log_file* and WARNING+ to the console.

    Parameters
    ----------
    log_file : str
        Path of the per-process log file (e.g. ``logs/log_rank_0.log``).

    Returns
    -------
    logging.Logger
        The configured root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Remove any handlers left over from a previous call. Without this, calling
    # setup_logger() twice in one process (re-init, notebook re-entry) stacks
    # duplicate handlers and every record is emitted multiple times.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # File handler: full INFO-level log on disk, one file per rank.
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    # Console handler: only WARNING and above, to keep terminal output quiet.
    ch = logging.StreamHandler()
    ch.setLevel(logging.WARNING)
    ch.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger
if __name__ == "__main__":
# Example log file: logs/log_rank_0.log, logs/log_rank_1.log, etc.
# python -m torch.distributed.launch --nproc_per_node=4 test_data.py
init_process_group(backend="nccl")
rank = dist.get_rank()
os.makedirs("logs", exist_ok=True)
log_file = os.path.join("logs", f"log_rank_{rank}.log")
logger = setup_logger(log_file)
logger.info(f"Logging from rank {rank}")
(37 numbered list items omitted — the bullets in the original listing carried no content.)
Then run it from the command line:
Result display:
The log contents for one of the GPUs are: