import torch
import syft as sy
import copy
hook = sy.TorchHook(torch)
from torch import nn, optim
'''
Part 4: Federated Learning with Model Averaging
http://localhost:8888/notebooks/git-home/github/PySyft/examples/tutorials/Part%2004%20-%20Federated%20Learning%20via%20Trusted%20Aggregator.ipynb
'''
"""
本例演示:
A节点运行脚本。B、C两个节点分别有样本集,各自训练一个模型。D节点把B和C节点的模型进行简单平均。
"""
#创建worker
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
secure_worker = sy.VirtualWorker(hook, id="secure_worker")
#这个数据集不适合做模型平均
##数据集
#data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
#target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
##分拆成不同的子数据集,发送给worker
#bobs_data = data[0:2].send(bob)
#bobs_target = target[0:2].send(bob)
#alices_data = data[2:].send(alice)
#alices_