pytorch实现多个模型的weights平均和修改weights

1. 操作说明

有3个结构相同但是weights不同的model组成一个list,models=[model1,model2,model3],还有一个中心模型fl_model,这四个模型的结构和超参数都相同。

需要进行这样一种操作:平均models里面三个模型的weights,把平均之后的weights"赋值"给fl_model的weights。

2.代码

在tensorflow里可以直接用model.get_weights()和model.set_weights()来做,比较直观和方便。感觉pytorch里面稍微复杂一些。进行上述操作的代码如下:

worker_state_dict = [x.state_dict() for x in models]
weight_keys = list(worker_state_dict[0].keys())
fed_state_dict = collections.OrderedDict()
for key in weight_keys:
    key_sum = 0
    for i in range(len(models)):
        key_sum = key_sum + worker_state_dict[i][key]
    fed_state_dict[key] = key_sum / len(models)
#### update fed weights to fl model
fl_model.load_state_dict(fed_state_dict)
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值