使用PySyft进行一次简单的联邦学习模型训练
import torch
from torch import nn
from torch import optim
import syft as sy
# 扩展pytorch功能使其满足联邦学习训练
hook = sy.TorchHook(torch)
# 创建两个工作机
Li = sy.VirtualWorker(hook, id='li')
Zhang = sy.VirtualWorker(hook, id='zhang')
# 定义简易模型
data = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1.]], requires_grad=True)
target = torch.tensor([[0], [0], [0], [1.]], requires_grad=True)
model = nn.Linear(2, 1)
# 将训练数据分成两部分,分别发送给两个工作机
dataLi = data[0:2]
targetLi = target[0:2]
dataZhang = data[2:]
targetZhang = target[2:]
dataLi = dataLi.send(Li)
dataZhang = dataZhang.send(Zhang)
targetLi = targetLi.send(Li)
targetZhang = targetLi.send(Zhang)
# 存储张量指针
datasets = [(dataLi, targetLi), (dataZhang, targetZhang)]
# 定义训练函数
def train():
# 优化器
opt = optim.SGD(params=model.parameters(), lr=0.1)
for item in range(50):
# 遍历每个工作机的数据集
for data, target in datasets:
# 将模型发送给对应的工作机
model.send(data.location)
# 消除之前的梯度
opt.zero_grad()
# 预测
pre = model(data)
# 计算损失
loss = ((pre - target) ** 2).sum()
# 回传损失
loss.backward()
# 更新参数
opt.step()
# 获取模型
model.get()
# 打印进程
print('epoch', item, loss.values)
pass
pass
# 开始训练
train()
参考文献
- 王健宗,李泽远,何安珣. 《深入浅出联邦学习:原理与实践》. 机械工业出版社. 2021年5月