在pytorch中使用torch.nn.parallel.DistributedDataParallel(简称DDP)封装一个模型时,原始模型被封装在DDP实例的.module属性下。下面将简单记录下model.parameters()和model.module.parameters()之间的区别:
1、model.module.parameters()
这个调用直接返回原始模型(即封装前的model模型)的参数。 这里的.module就是指向原来传入DDP的那个模型实例。 如果需要访问或者操作模型参数的原始状态,比如保存或加载模型权重,通常会使用这种方式。
2、model.parameters()
当对DDP对象调用.parameters()时,它实际上是在调用model.module.parameters(),因为DDP类的.parameters()方法被设计为委托给它的.module属性。
也就是说,model.parameters()和model.module.parameters()在功能上是等价的,都会返回原始模型的参数。
代码示例如下,
import os, torch, warnings
from torch import distributed
import torchvision.models as models
warnings.filterwarnings("ignore")
try:
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
distributed.init_process_group("nccl")
except KeyError:
world_size = 1
rank = 0
local_rank = 0
distributed.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:12584",
rank=rank,
world_size=world_size,
)
if __name__ == "__main__":
torch.cuda.set_device(local_rank)
model = models.resnet50(pretrained=False)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
params1 = list(model.parameters())
params2 = list(model.module.parameters())
if rank == 0:
print(params1 == params2)
输出结果如下,
[2024-09-11 20:03:24,934] torch.distributed.run: [WARNING]
[2024-09-11 20:03:24,934] torch.distributed.run: [WARNING] *****************************************
[2024-09-11 20:03:24,934] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2024-09-11 20:03:24,934] torch.distributed.run: [WARNING] *****************************************
True