在本教程中,我们将学习如何使用 Flower 和 PyTorch 在 CIFAR10 上训练卷积神经网络
首先,建议创建一个虚拟环境并在其中运行所有内容。
我们的示例由一台服务器和两个具有相同模型的客户端组成。
客户端负责根据其本地数据集为模型生成单独的权重更新。然后将这些更新发送到服务器,服务器将聚合它们以生成更好的模型。最后,服务器将此改进版本的模型发送回每个客户端。权重更新的完整周期称为回合。
现在我们对正在发生的事情有了大致的了解,让我们开始吧。我们首先需要安装 Flower。您可以通过运行以下命令来执行此操作:
$ pip install flwr
既然我们想使用 PyTorch 来解决计算机视觉任务,让我们继续安装 PyTorch 和 torchvision 库:
$ pip install torch torchvision
Flower客户端
现在我们已经安装了所有依赖项,让我们使用两个客户端和一个服务器运行一个简单的分布式训练。我们的训练进程和网络架构基于 PyTorch 的深度学习和 PyTorch。
在名为 client.py
的文档中,导入 Flower 和 PyTorch 相关包:
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import flwr as fl
此外,我们在 PyTorch 中定义设备分配:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
我们使用 PyTorch 加载 CIFAR10,这是一个流行的用于机器学习的彩色图像分类数据集。PyTorch DataLoader()
下载训练和测试数据,然后对其进行规范化。
def load_data():
"""Load CIFAR-10 (training and test set)."""
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
trainset = CIFAR10(".", train=True, download=True, transform=transform)
testset = CIFAR10(".", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testloader = DataLoader(testset, batch_size=32)
num_examples = {"trainset" : len(trainset), "testset" : len(testset)}
return trainloader, testloader, num_examples
使用 PyTorch 定义损失和优化器。数据集的训练是通过循环访问数据集来完成的,测量相应的损失并对其进行优化。
def train(net, trainloader, epochs):
"""Train the network on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for images, labels in trainloader:
images, labels = images.to(DEVICE), labels.to(DEVICE)
optimizer.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
optimizer.step()
然后定义机器学习网络的验证。我们遍历测试集并测量测试集的损失和准确性。
def test(net, testloader):
"""Validate the network on the entire test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
return loss, accuracy
在定义了 PyTorch 机器学习模型的训练和测试后,我们将函数用于 Flower 客户端。
Flower客户将使用改编自“PyTorch:60分钟闪电战”的简单CNN:
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
#Load model and data
net = Net().to(DEVICE)
trainloader, testloader, num_examples = load_data()
使用 load_data()
加载数据集后,我们定义 Flower 接口。
Flower 服务器通过名为 Client
的接口与客户端交互。当服务器选择特定客户端进行训练时,它会通过网络发送训练指令。客户端接收这些指令并调用 Client
方法之一来运行代码(即,训练我们之前定义的神经网络)。
Flower提供了一个名为NumPyClient
的方便类,当您的工作负载使用PyTorch时,它可以更轻松地实现 Client
接口。实现 NumPyClient
通常意味着定义以下方法(set_parameters
是可选的):
get_parameters
将模型权重作为 NumPy ndarray 的列表返回set_parameters
(optional) 使用从服务器接收的参数更新本地模型权重fit
设置本地模型权重 训练本地模型 接收更新的本地模型权重evaluate
测试本地模型
可以通过以下方式实现:
class CifarClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]
def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
def fit(self, parameters, config):
self.set_parameters(parameters)
train(net, trainloader, epochs=1)
return self.get_parameters(config={}), num_examples["trainset"], {}
def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(net, testloader)
return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}
现在,我们可以创建类 CifarClient
的实例,并添加一行来实际运行此客户端:
fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient())
这就是对客户来说。我们只需要实现客户端或 NumPyClient
并调用 fl.client.start_client()
或 fl.client.start_numpy_client()
。字符串“[::]:8080”
告诉客户端要连接到哪个服务器。在我们的例子中,我们可以在同一台机器上运行服务器和客户端,因此我们使用“[::]:8080”
。如果我们运行一个真正的联合工作负载,服务器和客户端在不同的机器上运行,那幺需要更改的只是我们指向客户端server_address
。
Flower服务器端
对于简单的工作负载,我们可以启动 Flower 服务器并将所有配置可能性保留为其默认值。在名为 server.py
的文档中,导入 Flower 并启动服务器:
import flwr as fl
fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))
训练模型
客户端和服务器都准备好了,我们现在可以运行所有内容并查看联邦学习的实际效果。 FL 系统通常有一个服务器和多个客户端。因此,我们必须首先启动服务器:
$ python server.py
服务器运行后,我们可以在不同的终端启动客户端。打开一个新终端并启动第一个客户端:
$ python client.py
打开另一个终端,启动第二个客户端:
$ python client.py
每个客户端都有自己的数据集。现在,您应该看到训练在第一个终端(启动服务器的终端)中是如何完成的:
INFO flower 2021-02-25 14:00:27,227 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2021-02-25 14:00:27,227 | server.py:72 | Getting initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:74 | Evaluating initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:87 | [TIME] FL starting
DEBUG flower 2021-02-25 14:01:41,310 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:00,256 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:00,262 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:03,047 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:03,049 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:23,908 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:23,915 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:27,120 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:27,122 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:03:04,660 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:03:04,671 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:09,273 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:09,273 | server.py:122 | [TIME] FL finished in 113.39180790000046
INFO flower 2021-02-25 14:03:09,274 | app.py:109 | app_fit: losses_distributed [(1, 650.9747924804688), (2, 526.2535400390625), (3, 473.76959228515625)]
INFO flower 2021-02-25 14:03:09,274 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2021-02-25 14:03:09,274 | app.py:111 | app_fit: losses_centralized []
INFO flower 2021-02-25 14:03:09,274 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2021-02-25 14:03:09,276 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:11,852 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:11,852 | app.py:121 | app_evaluate: federated loss: 473.76959228515625
INFO flower 2021-02-25 14:03:11,852 | app.py:122 | app_evaluate: results [('ipv6:[::1]:36602', EvaluateRes(loss=351.4906005859375, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6067})), ('ipv6:[::1]:36604', EvaluateRes(loss=353.92742919921875, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6005}))]
INFO flower 2021-02-25 14:03:27,514 | app.py:127 | app_evaluate: failures []
祝贺!您已成功构建并运行了您的第一个联邦学习系统。此示例的完整源代码可以在 examples/quickstart_pytorch
中找到。