联邦学习框架Flower-快速入门(PyTorch)

本文介绍了一个使用Flower和PyTorch在CIFAR10数据集上训练卷积神经网络的教程。首先创建虚拟环境,然后通过安装Flower和PyTorch来设置分布式训练环境。教程详细说明了如何定义网络、损失和优化器,以及如何在Flower的客户端和服务器之间协调训练过程。最后,展示了如何启动服务器和客户端以进行联合学习,并解释了训练过程中的日志输出。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在这里插入图片描述

在本教程中,我们将学习如何使用 Flower 和 PyTorch 在 CIFAR10 上训练卷积神经网络

首先,建议创建一个虚拟环境并在其中运行所有内容。

我们的示例由一台服务器和两个具有相同模型的客户端组成。

客户端负责根据其本地数据集为模型生成单独的权重更新。然后将这些更新发送到服务器,服务器将聚合它们以生成更好的模型。最后,服务器将此改进版本的模型发送回每个客户端。权重更新的完整周期称为回合。

现在我们对正在发生的事情有了大致的了解,让我们开始吧。我们首先需要安装 Flower。您可以通过运行以下命令来执行此操作:

$ pip install flwr

既然我们想使用 PyTorch 来解决计算机视觉任务,让我们继续安装 PyTorch 和 torchvision 库:

$ pip install torch torchvision

Flower客户端

现在我们已经安装了所有依赖项,让我们使用两个客户端和一个服务器运行一个简单的分布式训练。我们的训练进程和网络架构基于 PyTorch 的深度学习和 PyTorch。
在名为 client.py 的文档中,导入 Flower 和 PyTorch 相关包:

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import flwr as fl

此外,我们在 PyTorch 中定义设备分配:

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

我们使用 PyTorch 加载 CIFAR10,这是一个流行的用于机器学习的彩色图像分类数据集。PyTorch DataLoader() 下载训练和测试数据,然后对其进行规范化。

def load_data():
    """Load CIFAR-10 (training and test set)."""
    transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10(".", train=True, download=True, transform=transform)
    testset = CIFAR10(".", train=False, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
    testloader = DataLoader(testset, batch_size=32)
    num_examples = {"trainset" : len(trainset), "testset" : len(testset)}
    return trainloader, testloader, num_examples

使用 PyTorch 定义损失和优化器。数据集的训练是通过循环访问数据集来完成的,测量相应的损失并对其进行优化。

def train(net, trainloader, epochs):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()

然后定义机器学习网络的验证。我们遍历测试集并测量测试集的损失和准确性。

def test(net, testloader):
    """Validate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy

在定义了 PyTorch 机器学习模型的训练和测试后,我们将函数用于 Flower 客户端。

Flower客户将使用改编自“PyTorch:60分钟闪电战”的简单CNN:

class Net(nn.Module):
  def __init__(self) -> None:
      super(Net, self).__init__()
      self.conv1 = nn.Conv2d(3, 6, 5)
      self.pool = nn.MaxPool2d(2, 2)
      self.conv2 = nn.Conv2d(6, 16, 5)
      self.fc1 = nn.Linear(16 * 5 * 5, 120)
      self.fc2 = nn.Linear(120, 84)
      self.fc3 = nn.Linear(84, 10)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
      x = self.pool(F.relu(self.conv1(x)))
      x = self.pool(F.relu(self.conv2(x)))
      x = x.view(-1, 16 * 5 * 5)
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
      x = self.fc3(x)
      return x
#Load model and data
net = Net().to(DEVICE)
trainloader, testloader, num_examples = load_data()

使用 load_data() 加载数据集后,我们定义 Flower 接口。

Flower 服务器通过名为 Client 的接口与客户端交互。当服务器选择特定客户端进行训练时,它会通过网络发送训练指令。客户端接收这些指令并调用 Client 方法之一来运行代码(即,训练我们之前定义的神经网络)。

Flower提供了一个名为NumPyClient的方便类,当您的工作负载使用PyTorch时,它可以更轻松地实现 Client 接口。实现 NumPyClient 通常意味着定义以下方法(set_parameters是可选的):

  1. get_parameters 将模型权重作为 NumPy ndarray 的列表返回
  2. set_parameters (optional) 使用从服务器接收的参数更新本地模型权重
  3. fit 设置本地模型权重 训练本地模型 接收更新的本地模型权重
  4. evaluate 测试本地模型
    可以通过以下方式实现:
class CifarClient(fl.client.NumPyClient):
   def get_parameters(self, config):
       return [val.cpu().numpy() for _, val in net.state_dict().items()]

   def set_parameters(self, parameters):
       params_dict = zip(net.state_dict().keys(), parameters)
       state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
       net.load_state_dict(state_dict, strict=True)

   def fit(self, parameters, config):
       self.set_parameters(parameters)
       train(net, trainloader, epochs=1)
       return self.get_parameters(config={}), num_examples["trainset"], {}

   def evaluate(self, parameters, config):
       self.set_parameters(parameters)
       loss, accuracy = test(net, testloader)
       return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}

现在,我们可以创建类 CifarClient 的实例,并添加一行来实际运行此客户端:

fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient())

这就是对客户来说。我们只需要实现客户端或 NumPyClient 并调用 fl.client.start_client()fl.client.start_numpy_client()。字符串“[::]:8080”告诉客户端要连接到哪个服务器。在我们的例子中,我们可以在同一台机器上运行服务器和客户端,因此我们使用“[::]:8080”。如果我们运行一个真正的联合工作负载,服务器和客户端在不同的机器上运行,那幺需要更改的只是我们指向客户端server_address

Flower服务器端

对于简单的工作负载,我们可以启动 Flower 服务器并将所有配置可能性保留为其默认值。在名为 server.py 的文档中,导入 Flower 并启动服务器:

import flwr as fl

fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3))

训练模型

客户端和服务器都准备好了,我们现在可以运行所有内容并查看联邦学习的实际效果。 FL 系统通常有一个服务器和多个客户端。因此,我们必须首先启动服务器:

$ python server.py

服务器运行后,我们可以在不同的终端启动客户端。打开一个新终端并启动第一个客户端:

$ python client.py

打开另一个终端,启动第二个客户端:

$ python client.py

每个客户端都有自己的数据集。现在,您应该看到训练在第一个终端(启动服务器的终端)中是如何完成的:

INFO flower 2021-02-25 14:00:27,227 | app.py:76 | Flower server running (insecure, 3 rounds)
INFO flower 2021-02-25 14:00:27,227 | server.py:72 | Getting initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:74 | Evaluating initial parameters
INFO flower 2021-02-25 14:01:15,881 | server.py:87 | [TIME] FL starting
DEBUG flower 2021-02-25 14:01:41,310 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:00,256 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:00,262 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:03,047 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:03,049 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:02:23,908 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:23,915 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:02:27,120 | server.py:149 | evaluate received 2 results and 0 failures
DEBUG flower 2021-02-25 14:02:27,122 | server.py:165 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-02-25 14:03:04,660 | server.py:177 | fit_round received 2 results and 0 failures
DEBUG flower 2021-02-25 14:03:04,671 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:09,273 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:09,273 | server.py:122 | [TIME] FL finished in 113.39180790000046
INFO flower 2021-02-25 14:03:09,274 | app.py:109 | app_fit: losses_distributed [(1, 650.9747924804688), (2, 526.2535400390625), (3, 473.76959228515625)]
INFO flower 2021-02-25 14:03:09,274 | app.py:110 | app_fit: accuracies_distributed []
INFO flower 2021-02-25 14:03:09,274 | app.py:111 | app_fit: losses_centralized []
INFO flower 2021-02-25 14:03:09,274 | app.py:112 | app_fit: accuracies_centralized []
DEBUG flower 2021-02-25 14:03:09,276 | server.py:139 | evaluate: strategy sampled 2 clients
DEBUG flower 2021-02-25 14:03:11,852 | server.py:149 | evaluate received 2 results and 0 failures
INFO flower 2021-02-25 14:03:11,852 | app.py:121 | app_evaluate: federated loss: 473.76959228515625
INFO flower 2021-02-25 14:03:11,852 | app.py:122 | app_evaluate: results [('ipv6:[::1]:36602', EvaluateRes(loss=351.4906005859375, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6067})), ('ipv6:[::1]:36604', EvaluateRes(loss=353.92742919921875, num_examples=10000, accuracy=0.0, metrics={'accuracy': 0.6005}))]
INFO flower 2021-02-25 14:03:27,514 | app.py:127 | app_evaluate: failures []

祝贺!您已成功构建并运行了您的第一个联邦学习系统。此示例的完整源代码可以在 examples/quickstart_pytorch 中找到。

### 配置和使用 Flower 框架 #### 在 PyCharm 中安装依赖项 为了在 PyCharm 中配置并使用 Flower 框架,首先需要确保项目环境中已正确安装所需库。可以通过命令行工具来完成此操作: ```bash pip install flower tensorflow datasets ``` 这一步骤会下载并安装运行 Flower 所需的核心组件[^3]。 #### 创建虚拟环境 建议创建一个新的 Python 虚拟环境用于隔离不同项目的依赖关系。可以在终端中通过以下命令实现: ```bash python -m venv my_flower_env source my_flower_env/bin/activate # Linux 或 macOS 下激活环境 my_flower_env\Scripts\activate # Windows 下激活环境 ``` 一旦激活了新的虚拟环境,在该环境下安装的所有包都将独立于其他项目之外[^4]。 #### 导入数据集 对于机器学习模型训练而言,准备合适的数据至关重要。可以利用 TensorFlow 提供的功能轻松加载预处理过的图像数据集: ```python import tensorflow as tf from tensorflow.keras.preprocessing import image_dataset_from_directory data_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" dataset_dir = tf.keras.utils.get_file( origin=data_url, untar=True, cache_subdir='datasets', extract=True ) flowers_dataset = image_dataset_from_directory(dataset_dir + '/flower_photos') ``` 上述代码片段展示了如何获取并解析花朵图片集合以备后续实验之用。 #### 启动 Flower Server 和 Client Flower 是一种分布式深度学习框架,允许多个客户端共同参与联邦学习过程而不必共享原始数据样本。启动服务端与客户端的方式如下所示: 服务器端: ```python # server.py 文件内容 import flwr as fl fl.server.start_server(config=fl.server.ServerConfig(num_rounds=3)) ``` 客户端侧则可能看起来像这样: ```python # client.py 文件内容 import flwr as fl class FlowerClient(fl.client.NumPyClient): def get_parameters(self): pass def fit(self, parameters, config): pass def evaluate(self, parameters, config): pass if __name__ == "__main__": fl.client.start_numpy_client("[::]:8080", client=FlowerClient()) ``` 请注意实际应用时还需要定义具体的 `get_parameters`、`fit` 及 `evaluate` 方法逻辑[^1]。 #### 运行脚本 最后,在 PyCharm 的 Terminal 窗口中分别执行两个Python文件即可开始一次完整的联合训练周期: ```bash python server.py # 新建一个终端标签页作为Server节点 python client.py # 复制多份client实例模拟集群计算场景 ``` 以上就是在 PyCharm IDE 内部设置好基于 Flower 架构的开发流程概览[^2]。
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值