使用PyTorch + PySyft 10行代码升级到联邦学习
背景
联邦学习是一种非常令人兴奋的机器学习技术,旨在构建能够在分散数据上进行学习的系统。其核心思想是:数据始终保留在其生产者(也称为 *worker*,即工作机)手中,从而更好地保护隐私和数据所有权,而模型则在各工作机之间共享。例如,一个直接的应用是在手机上输入文字时预测下一个单词:您并不希望把用于训练的数据(即短信内容)发送到中央服务器。
因此,联邦学习的兴起与数据隐私意识的普及密切相关,而自2018年5月起施行的欧盟《通用数据保护条例》(GDPR)更是起到了催化作用。为了满足监管要求,苹果、谷歌等大公司已开始在这项技术上大量投入,尤其是为了保护移动用户的隐私,但他们并未公开自己的工具。在 OpenMined,我们相信任何希望开展机器学习项目的人都应该能够毫不费力地使用隐私保护工具。正如我们的博客文章所述,我们已经构建了只需一行代码即可加密数据的工具;现在,我们发布了基于新版 PyTorch 1.0 的联邦学习框架,为构建安全且可扩展的模型提供直观的接口。
在本教程中,我们直接采用经典示例——使用 PyTorch 在 MNIST 上训练 CNN(the canonical example of training a CNN on MNIST using PyTorch),来展示借助我们的 PySyft 库将其升级为联邦学习是多么简单。我们将逐一讲解示例的每个部分,并标注出新增或修改的代码。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import syft as sy # <-- NEW: import the Pysyft library
# <-- NEW: hook PyTorch, i.e. add the extra functionality needed for Federated Learning
hook = sy.TorchHook(torch)
# <-- NEW: define the two remote (virtual) workers, bob first and then alice
bob, alice = (sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bob", "alice"))
# Settings for the learning task.
class Arguments:
    """Hyperparameters and run settings for the federated MNIST example.

    Every setting can be overridden by keyword, e.g. ``Arguments(epochs=1)``.
    Unknown names raise ``TypeError`` so that a typo (such as a misspelled
    learning rate) fails immediately instead of being silently ignored.
    """

    def __init__(self, **overrides):
        self.batch_size = 64          # training batch size
        self.test_batch_size = 1000   # evaluation batch size
        self.epochs = 10
        self.momentum = 0.5           # kept for reference; SGD in this script is built without momentum
        self.no_cuda = False
        self.seed = 1                 # passed to torch.manual_seed
        self.log_interval = 30        # print the training loss every N batches
        self.save_model = False
        self.lr = 0.1                 # learning rate; absent from the original article, added here
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(
                    "Arguments got an unexpected setting {!r}".format(name))
            setattr(self, name, value)
args = Arguments()
torch.manual_seed(args.seed)  # seed torch's RNG for reproducible init/shuffling
# This example always runs on CPU; args.no_cuda is not consulted.
device = torch.device("cpu")
# Extra options for the torch DataLoader(s):
# - num_workers=0 loads batches in the main process. With num_workers > 0 on
#   Windows, workers are started with the "spawn" method, which re-imports this
#   script; that fails ("An attempt has been made to start a new process before
#   the current process has finished its bootstrapping phase") unless all
#   top-level work sits behind an `if __name__ == '__main__':` guard.
# - pin_memory only speeds up host-to-GPU copies, so it is useless on CPU.
kwargs = {'num_workers': 0, 'pin_memory': False}
'''数据加载并发送给工作机
我们首先加载数据,然后使用 .federate 方法将训练数据集转换为分布在各工作机上的联邦数据集。
现在,该联合数据集已提供给FederatedDataLoader。测试数据集保持不变。'''
# Normalization transform shared by the training and test sets.
_mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# .federate((bob, alice)) splits the MNIST training set between the two
# virtual workers and yields a FederatedDataset, which is iterated with
# sy.FederatedDataLoader instead of torch's DataLoader.
# **kwargs (num_workers / pin_memory) is deliberately NOT forwarded here:
# PySyft's FederatedDataLoader does not support these torch DataLoader options
# and only warns that it ignores them.
federated_train_loader = sy.FederatedDataLoader(  # <-- this is now a FederatedDataLoader
    datasets.MNIST('../data', train=True, download=True, transform=_mnist_transform)
    .federate((bob, alice)),  # <-- NEW: distribute the dataset across all the workers
    batch_size=args.batch_size, shuffle=True)

# The test set stays local and uses a regular torch DataLoader.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=_mnist_transform),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)
# CNN architecture.
class Net(nn.Module):
    """LeNet-style CNN mapping 1x28x28 images to log-probabilities over 10 classes."""

    def __init__(self):
        super().__init__()
        # Two 5x5 convolutions (stride 1), then two fully connected layers.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # Spatial size: 28 -> conv 24 -> pool 12 -> conv 8 -> pool 4.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        hidden = F.relu(self.fc1(x.view(-1, 4 * 4 * 50)))
        return F.log_softmax(self.fc2(hidden), dim=1)
'''定义训练和测试函数
对于训练函数,由于数据批次分布在 alice 和 bob 之间,因此需要把模型发送到每个批次所在的工作机上。
然后,使用与本地 PyTorch 完全相同的语法远程执行所有操作。完成后,需要取回更新后的模型和损失,以便观察训练效果。'''
def train(args, model, device, federated_train_loader, optimizer, epoch):
    """Run one epoch of federated training.

    For each batch the model is sent to the batch's location
    (``data.location``), one optimizer step is taken there, and the model is
    fetched back. Every ``args.log_interval`` batches the remote loss is
    retrieved and a progress line is printed.
    """
    model.train()
    for step, (inputs, labels) in enumerate(federated_train_loader):  # batches live on the workers
        model.send(inputs.location)  # <-- NEW: move the model to where this batch is
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()
        model.get()  # <-- NEW: bring the updated model back
        if step % args.log_interval:
            continue
        fetched = loss.get()  # <-- NEW: bring the loss back
        n_batches = len(federated_train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * args.batch_size, n_batches * args.batch_size,
            100. * step / n_batches, fetched.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
# Guard the driver: on Windows, DataLoader worker processes are started with
# the "spawn" method, which re-imports this script as "__mp_main__". Without
# this guard every worker would re-run the whole training loop.
if __name__ == '__main__':
    model = Net().to(device)
    # TODO: momentum is not supported at the moment, so args.momentum is unused.
    optimizer = optim.SGD(model.parameters(), lr=args.lr)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, federated_train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
运行时报错 AttributeError: 'Arguments' object has no attribute 'lr',原因是学习任务的设置(Arguments 类)中没有定义学习率 lr,需要手动添加(例如 self.lr=0.1)。
同时有警告如下
运行一会后报错
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 105, in spawn_main
exitcode = _main(fd)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 114, in _main
prepare(preparation_data)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 225, in prepare
_fixup_main_from_path(data['init_main_from_path'])
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 277, in _fixup_main_from_path
run_name="__mp_main__")
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 85, in _run_code
exec(code, run_globals)
File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 123, in <module>
test(args, model, device, test_loader)
File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 103, in test
for data, target in test_loader:
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 279, in __iter__
return _MultiProcessingDataLoaderIter(self)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 719, in __init__
w.start()
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\process.py", line 112, in start
self._popen = self._Popen(self)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\context.py", line 223, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\context.py", line 322, in _Popen
return Popen(process_obj)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\popen_spawn_win32.py", line 46, in __init__
prep_data = spawn.get_preparation_data(process_obj._name)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 143, in get_preparation_data
_check_not_importing_main()
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 136, in _check_not_importing_main
is not going to be frozen to produce an executable.''')
RuntimeError:
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.
This probably means that you are not using fork to start your
child processes and you have forgotten to use the proper idiom
in the main module:
if __name__ == '__main__':
freeze_support()
...
The "freeze_support()" line can be omitted if the program
is not going to be frozen to produce an executable.
Traceback (most recent call last):
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 761, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\queue.py", line 178, in get
raise Empty
_queue.Empty
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 123, in <module>
test(args, model, device, test_loader)
File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 103, in test
for data, target in test_loader:
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
data = self._next_data()
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 841, in _next_data
idx, data = self._get_data()
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 798, in _get_data
success, data = self._try_get_data()
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 774, in _try_get_data
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
RuntimeError: DataLoader worker (pid(s) 20604) exited unexpectedly
解决问题1:
RuntimeError:
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.
This probably means that you are not using fork to start your
child processes and you have forgotten to use the proper idiom
in the main module:
if __name__ == '__main__':
freeze_support()
...
The "freeze_support()" line can be omitted if the program
is not going to be frozen to produce an executable.
修改dataloader
'''数据加载并发送给工作机
我们首先加载数据,然后使用 .federate 方法将训练数据集转换为分布在各工作机上的联邦数据集。
现在,该联合数据集已提供给FederatedDataLoader。测试数据集保持不变。'''
# NOTE(review): this guard alone does not fix the spawn error. Only the two
# loader definitions are protected. When a DataLoader worker process re-imports
# this script as "__mp_main__", the loaders are skipped, but the training loop
# further down still runs at module level and fails with
# NameError: name 'federated_train_loader' is not defined (see the traceback
# below). The guard has to wrap the training/evaluation driver instead, or the
# DataLoaders must use num_workers=0.
if __name__ == '__main__':
    import multiprocessing
    # Only has an effect in a frozen Windows executable; for a normal script
    # run it does nothing.
    multiprocessing.freeze_support()
    federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))
        .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
报错
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 105, in spawn_main
exitcode = _main(fd)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 114, in _main
prepare(preparation_data)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 225, in prepare
_fixup_main_from_path(data['init_main_from_path'])
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 277, in _fixup_main_from_path
run_name="__mp_main__")
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 85, in _run_code
exec(code, run_globals)
File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 128, in <module>
train(args, model, device, federated_train_loader, optimizer, epoch)
NameError: name 'federated_train_loader' is not defined
Traceback (most recent call last):
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 761, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\queue.py", line 178, in get
raise Empty
_queue.Empty
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 129, in <module>
test(args, model, device, test_loader)
File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 109, in test
for data, target in test_loader:
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
data = self._next_data()
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 841, in _next_data
idx, data = self._get_data()
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 798, in _get_data
success, data = self._try_get_data()
File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 774, in _try_get_data
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
RuntimeError: DataLoader worker (pid(s) 13080) exited unexpectedly
解决办法
修改
# Load batches in the main process (no worker subprocesses to spawn) and skip
# pinned memory, which only helps when copying batches to a CUDA device.
kwargs = dict(num_workers=0, pin_memory=False)
运行结果: