PySyft Learning Notes, 06: MNIST + CNN

This post shows how to use the PySyft library in a PyTorch 1.0 environment to quickly implement federated learning: about 10 lines of changes are enough to train a model on the MNIST dataset, highlighting the importance of data privacy. The tutorial walks step by step through distributing the data across decentralized workers and training on it jointly.

Upgrading to Federated Learning in 10 Lines of Code with PyTorch + PySyft

Background

Federated learning is a very exciting and uplifting machine learning technique whose goal is to build systems that learn on decentralized data. The idea is that the data stays in the hands of its producer (also known as the _worker_), which helps improve privacy and ownership, while the model is shared between workers. One immediate application is predicting the next word on your phone as you type: you do not want the data used for training, i.e. your text messages, to be sent to a central server.

The rise of federated learning is therefore tightly connected to the spread of data-privacy awareness, and the EU GDPR, which has enforced data protection since May 2018, acted as a catalyst. To comply with the regulation, large players such as Apple and Google have started investing heavily in this technology, especially to protect the privacy of mobile users, but they have not made their tools available. At OpenMined, we believe that anyone willing to run a machine learning project should be able to implement privacy-preserving tools with very little effort. We have built tools for encrypting data in a single line, as described in our blog post, and we now release a framework that leverages the new PyTorch 1.0 release to provide an intuitive interface for building secure and scalable models.

In this tutorial we start directly from the canonical example of training a CNN on MNIST using PyTorch, and show how simple it is to upgrade it to federated learning with the PySyft library. We will go through each part of the example and highlight the code that changes.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import syft as sy  # <-- NEW: import the Pysyft library
hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch ie add extra functionalities to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob")  # <-- NEW: define remote worker bob
alice = sy.VirtualWorker(hook, id="alice")  # <-- NEW: and alice


# We define the settings of the learning task
class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = 10
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 30
        self.save_model = False
        self.lr = 0.1  # added by me; the original tutorial does not define a learning rate

args = Arguments()

#use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} #if use_cuda else {}

'''Load the data and send it to the workers
We first load the data and transform the training dataset into a federated dataset split across the workers, using the .federate method.
This federated dataset is now given to a FederatedDataLoader. The test dataset stays unchanged.'''

federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)
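
To see what .federate actually produced, here is a quick sanity check (a minimal sketch, assuming the PySyft 0.2.x API used in this tutorial): each batch yielded by the FederatedDataLoader is a pointer to a tensor that physically lives on one of the virtual workers.

# Peek at one batch: the tensors are pointers, and .location (also used in
# train() below) identifies the worker that actually holds the data.
data, target = next(iter(federated_train_loader))
print(data.location.id)  # prints "bob" or "alice"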

# CNN specification
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))   # 1x28x28 -> 20x24x24
        x = F.max_pool2d(x, 2, 2)   # 20x24x24 -> 20x12x12
        x = F.relu(self.conv2(x))   # 20x12x12 -> 50x8x8
        x = F.max_pool2d(x, 2, 2)   # 50x8x8 -> 50x4x4
        x = x.view(-1, 4*4*50)      # flatten; hence the 4*4*50 in fc1
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
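
As a quick shape check (a sketch; the 4*4*50 in fc1 follows from 28 -> 24 -> 12 -> 8 -> 4 after two 5x5 convolutions and two 2x2 poolings):

# Feed a dummy MNIST-shaped batch through the untrained network.
dummy = torch.zeros(1, 1, 28, 28)  # (batch, channels, height, width)
out = Net()(dummy)
print(out.shape)  # torch.Size([1, 10]): one log-probability per digit class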


'''Define the train and test functions
For the train function, because the data batches are distributed across alice and bob, you need to send the model to the right location for each batch.
Then you perform all the operations remotely with the same syntax as local PyTorch. When you are done, you get back the updated model and the loss value to look for improvement.'''

def train(args, model, device, federated_train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now it is a distributed dataset
        model.send(data.location) # <-- NEW: send the model to the right location
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get() # <-- NEW: get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get() # <-- NEW: get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
                100. * batch_idx / len(federated_train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


#%%time
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment

for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

# if (args.save_model):
#     torch.save(model.state_dict(), "mnist_cnn.pt")

At first this raised AttributeError: 'Arguments' object has no attribute 'lr': the learning-task settings were missing a learning rate, which is why self.lr = 0.1 was added to Arguments above.
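
The fix is already reflected in the listing above; equivalently, the missing attribute can be patched on the instance (the value 0.1 is my own choice, not from the original tutorial):

# Give Arguments a learning rate so optim.SGD(model.parameters(), lr=args.lr) works.
args.lr = 0.1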

A warning also appeared (screenshot omitted).

After running for a while, the script crashed with:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 114, in _main
    prepare(preparation_data)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 225, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 277, in _fixup_main_from_path
    run_name="__mp_main__")
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 123, in <module>
    test(args, model, device, test_loader)
  File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 103, in test
    for data, target in test_loader:
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 279, in __iter__
    return _MultiProcessingDataLoaderIter(self)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 719, in __init__
    w.start()
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\process.py", line 112, in start
    self._popen = self._Popen(self)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\popen_spawn_win32.py", line 46, in __init__
    prep_data = spawn.get_preparation_data(process_obj._name)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 143, in get_preparation_data
    _check_not_importing_main()
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 136, in _check_not_importing_main
    is not going to be frozen to produce an executable.''')
RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
Traceback (most recent call last):
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 761, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\queue.py", line 178, in get
    raise Empty
_queue.Empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 123, in <module>
    test(args, model, device, test_loader)
  File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 103, in test
    for data, target in test_loader:
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
    data = self._next_data()
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 841, in _next_data
    idx, data = self._get_data()
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 798, in _get_data
    success, data = self._try_get_data()
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 774, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
RuntimeError: DataLoader worker (pid(s) 20604) exited unexpectedly

Fixing problem 1: the RuntimeError quoted above ("An attempt has been made to start a new process before the current process has finished its bootstrapping phase"). On Windows, multiprocessing uses the spawn start method, which re-imports the main module in every DataLoader worker process; any code at module level runs again in each child, which is why the error asks for the if __name__ == '__main__': idiom.

The first attempt was to modify the data loading code:

'''Load the data and send it to the workers
We first load the data and transform the training dataset into a federated dataset split across the workers, using the .federate method.
This federated dataset is now given to a FederatedDataLoader. The test dataset stays unchanged.'''

if __name__ == '__main__':
    import multiprocessing
    multiprocessing.freeze_support()

    federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))
        .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
        batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

This raised a new error:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 114, in _main
    prepare(preparation_data)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 225, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\multiprocessing\spawn.py", line 277, in _fixup_main_from_path
    run_name="__mp_main__")
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 128, in <module>
    train(args, model, device, federated_train_loader, optimizer, epoch)
NameError: name 'federated_train_loader' is not defined
Traceback (most recent call last):
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 761, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\queue.py", line 178, in get
    raise Empty
_queue.Empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 129, in <module>
    test(args, model, device, test_loader)
  File "C:\Users\PH34R\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\Pysyftlearn4Minst+CNN.py", line 109, in test
    for data, target in test_loader:
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
    data = self._next_data()
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 841, in _next_data
    idx, data = self._get_data()
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 798, in _get_data
    success, data = self._try_get_data()
  File "D:\School\tech\env\ananconda\envs\pysyft24\lib\site-packages\torch\utils\data\dataloader.py", line 774, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
RuntimeError: DataLoader worker (pid(s) 13080) exited unexpectedly

The fix

The loaders are now created inside the if __name__ == '__main__': block, but the training loop further down is still at module level, so when a spawned DataLoader worker re-imports the module it hits NameError: name 'federated_train_loader' is not defined. Rather than restructuring the whole script, the simplest way out is to disable DataLoader multiprocessing entirely by changing

kwargs = {'num_workers': 0, 'pin_memory': False} #if use_cuda else {}
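
With num_workers=0 the DataLoader never spawns child processes, so neither multiprocessing error can occur. For completeness, here is an alternative, spawn-safe layout that keeps worker processes (a minimal sketch reusing the Arguments, Net, train and test definitions above; whether FederatedDataLoader forwards num_workers to an inner loader is version-dependent, so it is called without the extra kwargs here):

def main():
    args = Arguments()
    torch.manual_seed(args.seed)
    device = torch.device("cpu")
    kwargs = {'num_workers': 1, 'pin_memory': False}  # only reached in the parent process

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    federated_train_loader = sy.FederatedDataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transform).federate((bob, alice)),
        batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transform),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, federated_train_loader, optimizer, epoch)
        test(args, model, device, test_loader)

if __name__ == '__main__':
    main()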

Running results: the script now trains without errors (output screenshot omitted).
