DDP下的model.parameters()和model.module.parameters()

在pytorch中使用torch.nn.parallel.DistributedDataParallel(简称DDP)封装一个模型时,原始模型被封装在DDP实例的.module属性下。下面将简单记录下model.parameters()和model.module.parameters()之间的区别:

1、model.module.parameters()

这个调用直接返回原始模型(即封装前的model模型)的参数。 这里的.module就是指向原来传入DDP的那个模型实例。 如果需要访问或者操作模型参数的原始状态,比如保存或加载模型权重,通常会使用这种方式。

2、model.parameters()

当对DDP对象调用.parameters()时,它实际上是在调用model.module.parameters(),因为DDP类的.parameters()方法被设计为委托给它的.module属性。

也就是说,model.parameters()和model.module.parameters()在功能上是等价的,都会返回原始模型的参数。

代码示例如下,

import os, torch, warnings
from torch import distributed
import torchvision.models as models
warnings.filterwarnings("ignore")


try:
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    distributed.init_process_group("nccl")
except KeyError:
    world_size = 1
    rank = 0
    local_rank = 0
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
        rank=rank,
        world_size=world_size,
    )


if __name__ == "__main__":
    torch.cuda.set_device(local_rank)
    
    model = models.resnet50(pretrained=False)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    
    params1 = list(model.parameters())
    params2 = list(model.module.parameters())
    
    if rank == 0:
        print(params1 == params2)

输出结果如下,

[2024-09-11 20:03:24,934] torch.distributed.run: [WARNING] 
[2024-09-11 20:03:24,934] torch.distributed.run: [WARNING] *****************************************
[2024-09-11 20:03:24,934] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-09-11 20:03:24,934] torch.distributed.run: [WARNING] *****************************************
True
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

chen_znn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值