使用pySyft进行一次简单的联邦平均FedAVG
import torch
from torch import nn
from torch import optim
import syft as sy
# 扩展pytorch功能使其满足联邦学习训练
hook = sy.TorchHook(torch)
# 建立工作机和安全工作机,工作机作为客户端,用来训练模型
# 安全工作机作为服务器,用于数据的聚合和交流
Li = sy.VirtualWorker(hook, id='li')
Zhang = sy.VirtualWorker(hook, id='zhang')
secure_worker = sy.VirtualWorker(hook, id='secure_worker')
data = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1.]], requires_grad=True)
target = torch.tensor([[0], [0], [0], [1.]], requires_grad=True)
dataLi = data[0:2].send(Li)
targetLi = target[0:2].send(Li)
dataZhang = data[2:].send(Zhang)
targetZhang = target[2:].send(Zhang)
# 建立模型
model = nn.Linear(2, 1)
# 训练函数
def train():
# 设置迭代次数
interations = 20
workerInters = 5
for inter in range(interations):
# 将服务器上全局模型发给两个参与方
LiModel = model.copy().send(Li)
ZhangModel = model.copy().send(Zhang)
liOpt = optim.SGD(params=LiModel.parameters(), lr=0.1)
ZhangOpt = optim.SGD(params=ZhangModel.parameters(), lr=0.1)
for wi in range(workerInters):
# li训练一次
# 消除之前的梯度
liOpt.zero_grad()
# 预测
liPre = LiModel(dataLi)
# 计算损失
liLoss = ((liPre - targetLi) ** 2).sum()
# 回传损失
liLoss.backward()
# 更新参数
liOpt.step()
liLoss = liLoss.get().data
# Zhang训练一次
ZhangOpt.zero_grad()
ZhangPre = ZhangModel(dataZhang)
ZhangLoss = ((ZhangPre - targetZhang) ** 2).sum()
ZhangLoss.backward()
ZhangOpt.step()
ZhangLoss = ZhangLoss.get().data
# 将更新的局部模型发送给安全工作机
LiModel.move(secure_worker)
ZhangModel.move(secure_worker)
# 模型平均
with torch.no_grad():
model.weight.set_(((ZhangModel.weight.data + LiModel.weight.data) / 2).get())
model.bias.set_(((ZhangModel.bias.data + LiModel.bias.data) / 2).get())
print('第' + str(inter+1) + '轮')
print('Li: ' + str(liLoss) + ' zhang: ' + str(ZhangLoss))
pass
pass
# 开始训练
train()
# 用全局模型预测训练结果
preSecure = model(data)
loss = ((preSecure-target)**2).sum()
print(target)
print(preSecure)
print(loss.data)
运行结果:
第1轮
Li: tensor(6.9580e-05) zhang: tensor(0.2173)
第2轮
Li: tensor(0.0061) zhang: tensor(0.1492)
第3轮
Li: tensor(0.0168) zhang: tensor(0.1112)
第4轮
Li: tensor(0.0268) zhang: tensor(0.0893)
第5轮
Li: tensor(0.0347) zhang: tensor(0.0762)
第6轮
Li: tensor(0.0406) zhang: tensor(0.0681)
第7轮
Li: tensor(0.0448) zhang: tensor(0.0631)
第8轮
Li: tensor(0.0477) zhang: tensor(0.0599)
第9轮
Li: tensor(0.0498) zhang: tensor(0.0579)
第10轮
Li: tensor(0.0512) zhang: tensor(0.0566)
第11轮
Li: tensor(0.0522) zhang: tensor(0.0557)
第12轮
Li: tensor(0.0529) zhang: tensor(0.0552)
第13轮
Li: tensor(0.0534) zhang: tensor(0.0548)
第14轮
Li: tensor(0.0537) zhang: tensor(0.0546)
第15轮
Li: tensor(0.0540) zhang: tensor(0.0544)
第16轮
Li: tensor(0.0542) zhang: tensor(0.0544)
第17轮
Li: tensor(0.0543) zhang: tensor(0.0543)
第18轮
Li: tensor(0.0545) zhang: tensor(0.0543)
第19轮
Li: tensor(0.0546) zhang: tensor(0.0542)
第20轮
Li: tensor(0.0546) zhang: tensor(0.0542)
tensor([[0.],
[0.],
[0.],
[1.]], requires_grad=True)
tensor([[-0.1793],
[ 0.3207],
[ 0.1649],
[ 0.6649]], grad_fn=<AddmmBackward>)
tensor(0.2745)Process finished with exit code 0
参考文献
- 王健宗,李泽远,何安珣. 《深入浅出联邦学习:原理与实践》. 机械工业出版社. 2021年5月
该文演示了如何利用pySyft库在Python中进行联邦学习的FedAVG算法。通过创建虚拟工作机模拟分布式环境,对数据进行分发,然后在每个客户端上独立训练模型,最后在服务器端聚合模型参数,实现了在保护数据隐私的同时进行模型优化。
1126

被折叠的 条评论
为什么被折叠?



