Communication-Efficient Learning of Deep Networks from Decentralized Data: Reproducing the Paper's Code

Preface

Over the past couple of days I read the paper *Communication-Efficient Learning of Deep Networks from Decentralized Data* (see my previous blog post if you are interested), and then wrote the code below based on the paper, with some help from GPT. Corrections are welcome if you spot any problems.
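For reference, the update that all of the scripts below try to reproduce: in each communication round the server selects a fraction $C$ of the $K$ clients, each selected client $k$ (holding $n_k$ samples) runs $E$ local epochs of SGD starting from the current global weights and returns $w_{t+1}^{k}$, and the server forms the new global model as the sample-size-weighted average

$$w_{t+1} \leftarrow \sum_{k \in S_t} \frac{n_k}{\sum_{j \in S_t} n_j}\, w_{t+1}^{k}$$

where $S_t$ is the set of clients selected in round $t$.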

MNIST 2NN and CNN v1

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms
import numpy as np
import random
import matplotlib.pyplot as plt


# Define the MLP (2NN) model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 1024)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Data loading and preprocessing
def load_data_iid(num_clients):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    mnist_train = datasets.MNIST('../dataset', train=True, download=True, transform=transform)

    # IID data split
    data_size = len(mnist_train)
    indices = np.arange(data_size)
    np.random.shuffle(indices)
    client_indices = np.array_split(indices, num_clients)

    client_datasets = []
    for client_idx in client_indices:
        client_datasets.append(data.Subset(mnist_train, client_idx))

    return client_datasets


def load_data_noniid(num_clients):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    mnist_train = datasets.MNIST('../dataset', train=True, download=True, transform=transform)

    # Non-IID data split: 200 shards of 300 samples, 2 shards per client
    shard_size = 300
    shards_per_client = 2
    num_shards = num_clients * shards_per_client
    indices_by_digit = [np.where(np.array(mnist_train.targets) == i)[0] for i in range(10)]

    # Concatenate indices grouped by digit label (digits with fewer than 6,000 training samples make the tail shards slightly short)
    sorted_indices = np.concatenate([idx[:num_shards * shard_size // 10] for idx in indices_by_digit])
    client_indices = []

    for i in range(num_clients):
        client_data = np.concatenate(
            [sorted_indices[j * shard_size:(j + 1) * shard_size] for j in
             range(i * shards_per_client, (i + 1) * shards_per_client)]
        )
        client_indices.append(client_data)

    client_datasets = [data.Subset(mnist_train, idxs) for idxs in client_indices]

    return client_datasets


# Client-side local training
def train_local(client_data, model, optimizer, criterion, epochs=1, batch_size=32):
    model.train()
    loader = data.DataLoader(client_data, batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        for x, y in loader:
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

    return model.state_dict(), loss.item()


# Federated averaging
def federated_avg(global_model, clients_data, num_clients, epochs_per_client=1, C=0.1):
    num_selected = max(int(C * num_clients), 1)  # number of clients selected per round
    selected_clients = random.sample(clients_data, num_selected)
    global_model_state = global_model.state_dict()

    aggregated_weights = None
    for client_data in selected_clients:
        local_model = type(global_model)()  # instantiate the same architecture as the global model (MLP or CNN)
        local_model.load_state_dict(global_model_state)
        optimizer = optim.SGD(local_model.parameters(), lr=0.1)
        criterion = nn.CrossEntropyLoss()

        local_weights, local_loss = train_local(client_data, local_model, optimizer, criterion, epochs_per_client)

        # Accumulate each client's weights into the running (uniform) average
        if aggregated_weights is None:
            aggregated_weights = {key: local_weights[key] / num_selected for key in local_weights}
        else:
            for key in aggregated_weights:
                aggregated_weights[key] += local_weights[key] / num_selected

    global_model.load_state_dict(aggregated_weights)
    return global_model


# Plot the training-loss curve
def plot_loss_curve(losses):
    plt.plot(range(len(losses)), losses)
    plt.xlabel("Rounds")
    plt.ylabel("Loss")
    plt.title("Training Loss Curve")
    plt.show()


# Main program
num_clients = 100
num_rounds = 10
epochs_per_client = 1
iid = True

# Load data (IID or Non-IID)
if iid:
    clients_data = load_data_iid(num_clients)
else:
    clients_data = load_data_noniid(num_clients)

# Define the global model
global_model = MLP()  # or CNN()

# Run the federated learning rounds
losses = []
for r in range(num_rounds):
    global_model = federated_avg(global_model, clients_data, num_clients, epochs_per_client)
    # After each round, evaluate the global model's loss on the test set
    test_loss = 0
    global_model.eval()
    with torch.no_grad():
        test_loader = data.DataLoader(
            datasets.MNIST('../dataset', train=False, download=True,
                           transform=transforms.Compose([transforms.ToTensor(),
                                                         transforms.Normalize((0.1307,), (0.3081,))])),
            batch_size=1000)  # same normalization as training
        for x, y in test_loader:
            output = global_model(x)
            loss = nn.CrossEntropyLoss()(output, y)
            test_loss += loss.item()

    losses.append(test_loss)

# Plot how the loss changes across rounds
plot_loss_curve(losses)
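One detail worth flagging: federated_avg above gives every selected client the same weight (1/num_selected), while the paper's FedAvg weights client k by its sample count n_k. With the balanced IID split here the two coincide, but for unbalanced partitions you would want something like this minimal sketch (same helpers as above, with len(client_data) standing in for n_k):

def federated_avg_weighted(global_model, clients_data, num_clients, epochs_per_client=1, C=0.1):
    num_selected = max(int(C * num_clients), 1)
    selected_clients = random.sample(clients_data, num_selected)
    total_samples = sum(len(c) for c in selected_clients)  # n = sum of n_k over the selected clients
    global_model_state = global_model.state_dict()

    aggregated_weights = None
    for client_data in selected_clients:
        local_model = type(global_model)()
        local_model.load_state_dict(global_model_state)
        optimizer = optim.SGD(local_model.parameters(), lr=0.1)
        criterion = nn.CrossEntropyLoss()
        local_weights, _ = train_local(client_data, local_model, optimizer, criterion, epochs_per_client)

        weight = len(client_data) / total_samples  # n_k / n
        if aggregated_weights is None:
            aggregated_weights = {key: value * weight for key, value in local_weights.items()}
        else:
            for key in aggregated_weights:
                aggregated_weights[key] += local_weights[key] * weight

    global_model.load_state_dict(aggregated_weights)
    return global_model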

MNIST 2NN and CNN v2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28*28, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 1024)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the MNIST dataset
train_dataset = datasets.MNIST('../pytorch_Liuer/dataset', train=True, download=True, transform=transform)

# IID data partition
def iid_partition(dataset, num_clients):
    client_data_loaders = []
    num_examples_per_client = len(dataset) // num_clients
    for i in range(num_clients):
        # give each client its own disjoint slice of the dataset
        client_data = torch.utils.data.Subset(dataset, range(i * num_examples_per_client, (i + 1) * num_examples_per_client))
        client_data_loader = DataLoader(client_data, batch_size=32, shuffle=True)
        client_data_loaders.append(client_data_loader)
    return client_data_loaders

# Non-IID data partition (the "pathological" partition from the paper: 2 shards per client)
def non_iid_partition(dataset, num_clients):
    client_data_loaders = []
    shard_size = len(dataset) // 200  # each shard holds 300 samples
    for client_id in range(num_clients):
        shard_indices = [i for i in range(client_id * 2, (client_id + 1) * 2)]  # each client gets 2 shards
        samples = [j for i in shard_indices for j in range(i * shard_size, (i + 1) * shard_size)]
        client_data = torch.utils.data.Subset(dataset, samples)
        client_data_loader = DataLoader(client_data, batch_size=32, shuffle=True)
        client_data_loaders.append(client_data_loader)
    return client_data_loaders
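# Note: the partition above slices shards out of the dataset in its original order. The paper's
# "pathological" non-IID split first sorts the training samples by digit label, so most clients end
# up seeing only two digits. A minimal sketch of that variant (same shard size, 2 shards per client):
def non_iid_partition_sorted(dataset, num_clients):
    labels = np.array(dataset.targets)
    sorted_indices = np.argsort(labels)               # indices ordered by digit label
    shard_size = len(dataset) // (num_clients * 2)
    client_data_loaders = []
    for client_id in range(num_clients):
        shard_ids = [client_id * 2, client_id * 2 + 1]
        samples = np.concatenate([sorted_indices[s * shard_size:(s + 1) * shard_size] for s in shard_ids])
        client_subset = torch.utils.data.Subset(dataset, samples.tolist())
        client_data_loaders.append(DataLoader(client_subset, batch_size=32, shuffle=True))
    return client_data_loaders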

# Choose the partition scheme
num_clients = 100
# client_data_loaders = iid_partition(train_dataset, num_clients)
client_data_loaders = non_iid_partition(train_dataset, num_clients)

# Note: this federated_averaging variant simply fine-tunes one shared model sequentially over every
# client's loader each epoch; it does not train separate local copies and average their weights.
def federated_averaging(client_data_loaders, model, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        for client_data_loader in client_data_loaders:
            for data, target in client_data_loader:
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
    return model

def evaluate(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return correct / total

def main():
    # Initialize the model and optimizer
    model = CNN()  # choose MLP or CNN
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Federated training
    model = federated_averaging(client_data_loaders, model, criterion, optimizer, num_epochs=10)

    # Evaluate the model
    test_dataset = datasets.MNIST('../pytorch_Liuer/dataset', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    accuracy = evaluate(model, test_loader)
    print(f'Test accuracy: {accuracy}')

if __name__ == '__main__':
    main()
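If you want v2 to follow the paper's FedAvg more closely (train a separate copy per selected client each round, then average the weights), a rough sketch in the style of v1 could look like the following; model_fn here is just whichever constructor you chose above (e.g. CNN):

import random

def fedavg_round(global_model, client_data_loaders, model_fn, num_selected=10, local_epochs=1, lr=0.01):
    selected = random.sample(client_data_loaders, num_selected)
    criterion = nn.CrossEntropyLoss()
    local_states = []
    for loader in selected:
        # start every client from the current global weights
        local_model = model_fn()
        local_model.load_state_dict(global_model.state_dict())
        optimizer = optim.SGD(local_model.parameters(), lr=lr)
        local_model.train()
        for _ in range(local_epochs):
            for data, target in loader:
                optimizer.zero_grad()
                loss = criterion(local_model(data), target)
                loss.backward()
                optimizer.step()
        local_states.append(local_model.state_dict())
    # uniform average of the local weights (all clients hold the same number of samples here)
    avg_state = {k: torch.stack([s[k].float() for s in local_states], dim=0).mean(dim=0)
                 for k in local_states[0]}
    global_model.load_state_dict(avg_state)
    return global_model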

SHAKESPEARE LSTM

First, go to https://www.gutenberg.org/files/100/100-0.txt and save the whole text into the code directory as shakespeare.txt.
The text then has to be read and split into training and test sets; the full code is below:
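If you would rather fetch the file in code than copy it by hand, a small convenience snippet along these lines should work:

import urllib.request

urllib.request.urlretrieve("https://www.gutenberg.org/files/100/100-0.txt", "shakespeare.txt")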

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import random
import numpy as np
import os
import re
from collections import defaultdict


# 1. Dataset loading and processing
class ShakespeareDataset(Dataset):
    def __init__(self, data, vocab, seq_length=80):
        self.data = data
        self.seq_length = seq_length
        self.vocab = vocab

    def __len__(self):
        # make sure the length is never negative
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, index):
        x = self.data[index:index + self.seq_length]
        y = self.data[index + 1:index + self.seq_length + 1]
        # map characters to vocabulary indices, falling back to 'UNK' for unknown characters
        x_indices = [self.vocab[char] if char in self.vocab else self.vocab['UNK'] for char in x]
        y_indices = [self.vocab[char] if char in self.vocab else self.vocab['UNK'] for char in y]
        return torch.tensor(x_indices), torch.tensor(y_indices)



# def load_shakespeare_data():
#     # Assumes the Shakespeare text has already been downloaded and preprocessed
#     # Data format: {character name: dialogue text}
#     data = defaultdict(str)
#
#     with open('shakespeare.txt', 'r') as f:
#         for line in f:
#             # Each line has the format  character name: dialogue
#             name, text = line.split(':', 1)
#             data[name.strip()] += text.strip()
#
#     return data

def parse_shakespeare(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # use a regular expression to extract character names and their lines
    character_dialogues = defaultdict(str)
    current_character = None
    lines = text.splitlines()

    for line in lines:
        # look for a character name (a line indented by two or more spaces)
        match = re.match(r"^\s{2,}(.*)$", line)
        if match:
            current_character = match.group(1).strip().upper()
        elif current_character:
            # otherwise treat the line as dialogue for the current character
            character_dialogues[current_character] += line.strip() + " "

    # drop characters with little or no dialogue
    filtered_dialogues = {char: dialog for char, dialog in character_dialogues.items() if len(dialog) > 50}

    return filtered_dialogues


def split_data(data):
    """按80%训练集,20%测试集划分"""
    train_data = {}
    test_data = {}

    for character, text in data.items():
        split_point = int(len(text) * 0.8)
        train_data[character] = text[:split_point]
        test_data[character] = text[split_point:]

    return train_data, test_data


# 2. LSTM model definition
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=8, hidden_size=256, num_layers=2):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        x = self.embedding(x)
        output, hidden = self.lstm(x, hidden)
        output = self.fc(output)
        return output, hidden


# 3. Federated learning helpers
def local_train(model, train_loader, criterion, optimizer, epochs=1):
    model.train()
    for _ in range(epochs):
        for x, y in train_loader:
            # print(f"Batch max index: {torch.max(x)}, Batch min index: {torch.min(x)}")
            assert torch.max(x) < vocab_size, f"Index out of vocabulary range: max index {torch.max(x)}, vocab size {vocab_size}"

            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output, _ = model(x)
            loss = criterion(output.view(-1, vocab_size), y.view(-1))
            loss.backward()
            optimizer.step()


def federated_avg(models):
    """Federated averaging: element-wise mean of the local models' weights."""
    global_model = models[0]
    avg_state = {}
    with torch.no_grad():
        for k in global_model.state_dict().keys():
            avg_state[k] = torch.stack([model.state_dict()[k].float() for model in models], dim=0).mean(dim=0)
    # Assigning into the dict returned by state_dict() does not write back into the module,
    # so build a new state dict and load it explicitly.
    global_model.load_state_dict(avg_state)
    return global_model


# 4. Training loop
def train_federated(train_data, test_data, vocab_size, num_clients=10, local_epochs=5, communication_rounds=10):
    global_model = LSTMModel(vocab_size=vocab_size).to(device)
    criterion = nn.CrossEntropyLoss()

    client_ids = list(train_data.keys())

    for round in range(communication_rounds):
        selected_clients = random.sample(client_ids, num_clients)
        local_models = []

        for client in selected_clients:
            # If a client's text is shorter than 81 characters (seq_length + 1), ShakespeareDataset
            # cannot build a complete (input, target) pair, so skip that client.
            if len(train_data[client]) < 81:
                print(f"Skipping client {client} due to insufficient data length.")
                continue

            local_model = LSTMModel(vocab_size).to(device)
            local_model.load_state_dict(global_model.state_dict())
            optimizer = optim.SGD(local_model.parameters(), lr=0.1)
            
            # build this client's dataset
            train_dataset = ShakespeareDataset(train_data[client], vocab)
            
            # make sure the dataset is non-empty
            if len(train_dataset) == 0:
                print(f"Skipping client {client} because dataset length is 0.")
                continue
            
            train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

            local_train(local_model, train_loader, criterion, optimizer, local_epochs)
            local_models.append(local_model)

        if local_models:  # only aggregate if at least one client trained this round
            global_model = federated_avg(local_models)
            print(f"Round {round + 1}/{communication_rounds} completed.")

    return global_model


# 5. Evaluation
def evaluate_model(model, test_data, vocab_size):
    model.eval()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for client, text in test_data.items():
            test_dataset = ShakespeareDataset(text, vocab)
            test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                output, _ = model(x)
                loss = criterion(output.view(-1, vocab_size), y.view(-1))
                total_loss += loss.item()
    print(f"Test Loss: {total_loss / len(test_data)}")

def build_vocab(text):
    """Build the vocabulary: a mapping from character to index."""
    vocab = defaultdict(int)
    vocab['UNK'] = 0  # reserve index 0 for unknown characters
    idx = 1
    for char in set(text):  # iterate over every distinct character
        vocab[char] = idx
        idx += 1
    return vocab


# Map a character's dialogue to vocabulary indices
def text_to_indices(text, vocab):
    return [vocab[char] if char in vocab else vocab['UNK'] for char in text]


# 6. Main
if __name__ == "__main__":
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

    # Load the Shakespeare dataset
    # data = load_shakespeare_data()

    # Parse the complete works into {character: dialogue}
    character_dialogues = parse_shakespeare('shakespeare.txt')
    # Build the vocabulary
    full_text = ''.join(character_dialogues.values())  # concatenate all characters' dialogue into one text
    vocab = build_vocab(full_text)
    # Vocabulary size
    vocab_size = len(vocab)
    print(f"Vocabulary size: {vocab_size}")
    # Convert all dialogue to index form (computed here, though the dataset class maps characters itself)
    character_dialogue_indices = {char: text_to_indices(dialogue, vocab) for char, dialogue in character_dialogues.items()}

    # Split into training and test sets
    train_data, test_data = split_data(character_dialogues)

    # # Build the character set
    # vocab = sorted(list(set("".join(train_data.values()))))
    # vocab_size = len(vocab)

    print(f"Input sample: {train_data[list(train_data.keys())[0]][:10]}")
    print(f"Vocab size: {vocab_size}")

    # Federated training
    global_model = train_federated(train_data, test_data, vocab_size, num_clients=10, local_epochs=5,
                                   communication_rounds=10)

    # Evaluate the model
    print('before eval')
    evaluate_model(global_model, test_data, vocab_size)
    print('after eval')
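The paper reports next-character prediction accuracy for the Shakespeare LSTM rather than just the loss. If you want that number too, a small variant of evaluate_model along these lines should work (it reuses vocab, device and ShakespeareDataset from above):

def evaluate_accuracy(model, test_data, vocab):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for client, text in test_data.items():
            test_loader = DataLoader(ShakespeareDataset(text, vocab), batch_size=32, shuffle=False)
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                output, _ = model(x)                  # (batch, seq_len, vocab_size)
                predicted = output.argmax(dim=-1)     # most likely next character at each position
                correct += (predicted == y).sum().item()
                total += y.numel()
    print(f"Next-character accuracy: {correct / max(total, 1):.4f}")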



Training results (screenshots omitted):
Why did the loss still go up after 100 rounds of training?

CIFAR CNN v1

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import random

# Select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model (similar to the model in the TensorFlow CIFAR-10 tutorial)
class CIFARModel(nn.Module):
    def __init__(self):
        super(CIFARModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(64 * 6 * 6, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Dataset transforms (random crop/flip for training)
transform = transforms.Compose([
    transforms.RandomCrop(24),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# deterministic preprocessing for the test set (no random crop/flip), keeping the 24x24 input size
test_transform = transforms.Compose([transforms.CenterCrop(24), transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Partition the data across 100 clients
def partition_data(dataset, num_clients):
    client_data = {}
    data_size = len(dataset) // num_clients
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    for i in range(num_clients):
        client_indices = indices[i * data_size:(i + 1) * data_size]
        client_data[i] = Subset(dataset, client_indices)
    return client_data

num_clients = 100
clients_data = partition_data(train_dataset, num_clients)

# Define the loss function (optimizers are created per client)
criterion = nn.CrossEntropyLoss()

def local_train(model, data_loader, epochs, lr):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    for epoch in range(epochs):
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

# FedAvg aggregation
def federated_avg(models):
    global_model = models[0]
    avg_state = {}
    with torch.no_grad():
        for k in global_model.state_dict().keys():
            avg_state[k] = torch.stack([model.state_dict()[k].float() for model in models], dim=0).mean(dim=0)
    # Assigning into the dict returned by state_dict() does not update the module,
    # so load the averaged weights explicitly.
    global_model.load_state_dict(avg_state)
    return global_model

def train_federated(clients_data, num_clients, global_model, epochs=5, lr=0.01, communication_rounds=100):
    global_losses = []
    for round in range(communication_rounds):
        local_models = []
        selected_clients = random.sample(range(num_clients), 10)  # select 10 clients per round
        for client_id in selected_clients:
            local_model = CIFARModel().to(device)
            local_model.load_state_dict(global_model.state_dict())
            data_loader = DataLoader(clients_data[client_id], batch_size=50, shuffle=True)
            local_train(local_model, data_loader, epochs, lr)
            local_models.append(local_model)

        # Average the local weights
        global_model = federated_avg(local_models)
        
        # Evaluate the global model on the test set
        test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
        global_loss = evaluate_model(global_model, test_loader)
        global_losses.append(global_loss)
        print(f"Round {round + 1}/{communication_rounds}, Loss: {global_loss:.4f}")

    return global_model, global_losses

# Model evaluation (test loss)
def evaluate_model(model, test_loader):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
    return total_loss / len(test_loader)

# Main
if __name__ == "__main__":
    global_model = CIFARModel().to(device)
    global_model, global_losses = train_federated(clients_data, num_clients, global_model, epochs=5, lr=0.01, communication_rounds=100)
    
    # Print how the loss changes across communication rounds
    print(global_losses)
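The v1 script only tracks the test cross-entropy loss; the paper states its CIFAR-10 results as test accuracy. A small helper in the same style could report it as well (reusing device and test_dataset from above):

def evaluate_accuracy(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# e.g. after training:
# test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
# print(f"Test accuracy: {evaluate_accuracy(global_model, test_loader):.4f}")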


CIFAR CNN v2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import random
import matplotlib.pyplot as plt
# Select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the model (similar to the model in the TensorFlow CIFAR-10 tutorial)
class CIFARModel(nn.Module):
    def __init__(self):
        super(CIFARModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(64 * 6 * 6, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Dataset transforms (random crop/flip for training)
transform = transforms.Compose([
    transforms.RandomCrop(24),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
# deterministic preprocessing for the test set (no random crop/flip), keeping the 24x24 input size
test_transform = transforms.Compose([transforms.CenterCrop(24), transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=test_transform)


# Partition the data across 100 clients
def partition_data(dataset, num_clients):
    client_data = {}
    data_size = len(dataset) // num_clients
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    for i in range(num_clients):
        client_indices = indices[i * data_size:(i + 1) * data_size]
        client_data[i] = Subset(dataset, client_indices)
    return client_data


num_clients = 100
clients_data = partition_data(train_dataset, num_clients)

# Define the loss function (optimizers are created per client)
criterion = nn.CrossEntropyLoss()



# Local training on one client
def local_train(model, train_loader, criterion, optimizer, epochs=1):
    model.train()
    total_loss = 0
    num_batches = 0
    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

    return total_loss / max(num_batches, 1)


# FedAvg aggregation
def federated_avg(models):
    global_model = models[0]
    avg_state = {}
    with torch.no_grad():
        for k in global_model.state_dict().keys():
            avg_state[k] = torch.stack([model.state_dict()[k].float() for model in models], dim=0).mean(dim=0)
    # Assigning into the dict returned by state_dict() does not update the module,
    # so load the averaged weights explicitly.
    global_model.load_state_dict(avg_state)
    return global_model


# Federated training loop
def train_federated(train_data, num_clients, local_epochs=5, communication_rounds=100, lr=0.1, lr_decay=0.99):
    global_model = CIFARModel().to(device)
    criterion = nn.CrossEntropyLoss()

    # losses collected for plotting
    losses_per_round = []

    client_ids = list(train_data.keys())
    for round in range(communication_rounds):
        selected_clients = random.sample(client_ids, num_clients)
        local_models = []
        round_loss = 0

        for client in selected_clients:
            local_model = CIFARModel().to(device)
            local_model.load_state_dict(global_model.state_dict())
            optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=0.9)  # use the current (decayed) learning rate
            train_loader = DataLoader(train_data[client], batch_size=50, shuffle=True)

            # Local training
            loss = local_train(local_model, train_loader, criterion, optimizer, local_epochs)
            local_models.append(local_model)
            round_loss += loss

        global_model = federated_avg(local_models)
        avg_loss = round_loss / num_clients
        losses_per_round.append(avg_loss)

        print(f"Round {round + 1}/{communication_rounds}, Loss: {avg_loss}")
        # decay the learning rate after each round
        lr *= lr_decay

    # Plot the loss curve
    plt.plot(range(communication_rounds), losses_per_round, label='Loss')
    plt.xlabel('Communication Rounds')
    plt.ylabel('Loss')
    plt.title('Loss over Communication Rounds')
    plt.legend()
    plt.show()

    return global_model


# Model evaluation (test accuracy)
def evaluate_model(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    print(f'Test Accuracy: {accuracy * 100:.2f}%')


# Main
if __name__ == "__main__":
    global_model = train_federated(clients_data, num_clients, local_epochs=5, communication_rounds=100, lr=0.1, lr_decay=0.99)

    # evaluate the final global model on the held-out test set
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
    evaluate_model(global_model, test_loader)
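Note that train_federated above passes num_clients (here 100) to random.sample, so every client participates in every round. The paper instead samples a fraction C of the clients each round (C = 0.1 in many of its experiments). If you want that behaviour, a helper like the sketch below could replace the selection line inside train_federated (and round_loss would then be divided by the number of selected clients):

def select_clients(client_ids, C=0.1):
    # choose a C-fraction of clients per round, at least one
    num_selected = max(int(C * len(client_ids)), 1)
    return random.sample(client_ids, num_selected)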

