Pytorch实现联邦学习中遇到的坑

1. 服务器分发模型时记得深拷贝数据或者整个模型

如果不这么做,你会发现其实就是一个模型在不断地被训练。这里有两种办法可以实现:
在PyTorch中,可以使用torch.nn.Module的state_dict()和load_state_dict()方法来拷贝模型。 state_dict()方法返回模型参数的字典,而load_state_dict()方法接受一个字典,并将其加载到模型中。
以下是一个简单的示例,展示如何使用这些方法拷贝一个模型:

import torch.nn as nn

# 创建一个模型
model1 = nn.Linear(10, 1)

# 拷贝模型
model2 = nn.Linear(10, 1)
model2.load_state_dict(model1.state_dict())

在上面的示例中,我们首先创建了一个名为model1的模型,它是一个线性层,输入维度为10,输出维度为1。然后,我们创建了一个名为model2的新模型,具有相同的结构和参数。通过调用model1.state_dict(),我们获取model1的参数字典,并通过调用model2.load_state_dict()将该字典加载到model2中。
这种方法只会拷贝模型的参数,而不会拷贝其他属性,例如模型的名称、优化器状态等。如果您需要拷贝整个模型,包括这些属性,您可以使用Python的copy模块,例如:

# 拷贝整个模型
model2 = copy.deepcopy(model1)

在这种情况下,deepcopy()方法会创建一个全新的模型对象,并复制模型的所有属性和参数。

2. 如何获得模型里的参数?

for name, param in self.global_model.state_dict().items():
	c.local_model.state_dict()[name].copy_(param.clone())

模型里有一个方法state_dict()是将对象的数据变为字典展示,因此我们只需要像字典一样处理就可以,调用items()遍历整个字典,name是模型对应的层数,param是该层的tensor,使用state_dict()[name]的时候是会修改到模型参数的。
所以这里的c.local_model.state_dict()[name].copy_(param.clone())其实可以直接写成c.local_model.state_dict()[name] = param.clone()
这里也有个坑点,如果你想要复制一个tensor,都必须要使用clone(),不然默认是传递引用。

3. 一群Subset中的dataset其实是同一个

subset_dataset = Subset(train_datasets, subset_indices)
subset_dataset.dataset

曾经以为subset获得的数据就是一个一个的子集了,那我subset_dataset.dataset[0]是不是就可以获得每一个子集的第一个数据?答案是错误的,所有的子集的dataset成员都是一样的,基本大部分pytorch的data或dataset成员都保存着原始数据,获取子集的行为包含在函数里。

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

volcanical

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值