联邦学习笔记(六)实现自己的联邦学习算法

我的联邦学习相关笔记Github

联邦平均算法(Federated Averaging,通常简写为 FedAvg,本文代码中记作 FAVG)

TFF 平台上手难度比较大,光是熟悉它的 API 就要花不少功夫。所以我在借鉴这位老哥代码的基础上,改出了下面这份不依赖 TFF 训练流程的代码。

横向联邦学习-IID联邦平均算法

代码架构如下所示:
在这里插入图片描述

数据集预处理

将mnist数据集处理成IID类型

import random
import numpy as np
from termcolor import colored
from tensorflow.keras.datasets import mnist
import tensorflow as tf
from mnist_model import mnist_model
from utils import *
def generate_clients_data( num_expamples_list_in_clients:list , num_clients = 10, IsIID=True, batch_size=100 ,tt_rate = 0.3):
    """Split MNIST into per-client datasets plus a held-out server test set.

    The images are shuffled, flattened to 784 floats in [0, 1] and the labels
    are one-hot encoded (depth 10) before being partitioned.

    Args:
        num_expamples_list_in_clients: number of training examples per client.
            A single-element list is expanded to ``num_clients`` entries.
            NOTE: the expansion happens in place on the caller's list (this
            side effect is kept from the original implementation, callers
            may rely on the expanded length afterwards).
        num_clients: number of clients, only used to expand a single-element
            size list.
        IsIID: when truthy, every client receives a contiguous slice of the
            shuffled data; otherwise every client receives examples of a
            single digit class.
        batch_size: batch size of the server dataset and of every client
            dataset.
        tt_rate: test/train ratio; each client gets ``int(size * tt_rate)``
            test examples.

    Returns:
        IID: ``(train_data_list, test_data_list, server_dataset)``.
        Non-IID: ``(clients_dataset_list, server_dataset)`` with one client
        per digit class, limited to as many classes as there are entries in
        the size list.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, y_train = shuffle_dataset(x_train, y_train)
    x_test, y_test = shuffle_dataset(x_test, y_test)

    x_train = x_train.astype('float32').reshape(-1, 28 * 28) / 255.0
    x_test = x_test.astype('float32').reshape(-1, 28 * 28) / 255.0

    # Keep the integer class ids: the non-IID split has to compare against
    # them, which is impossible once the labels are one-hot encoded.
    train_class_ids = np.array(y_train)
    y_test = tf.one_hot(y_test, depth=10)
    y_train = tf.one_hot(y_train, depth=10)

    if len(num_expamples_list_in_clients) == 1:
        # In-place on purpose, see the docstring.
        num_expamples_list_in_clients *= num_clients

    # The first `client_dataset_test_size` test examples are reserved for the
    # clients; the server is evaluated on the slice that follows them.
    # Slices past the end of the test set are silently truncated.
    client_dataset_test_size = int(sum(x * tt_rate for x in num_expamples_list_in_clients))
    dataset_server_size = int(client_dataset_test_size * 0.3)
    server_stop = client_dataset_test_size + dataset_server_size

    server_test_x = x_test[client_dataset_test_size:server_stop]
    server_test_y = y_test[client_dataset_test_size:server_stop]
    server_dataset = tf.data.Dataset.from_tensor_slices(
        (server_test_x, server_test_y)).batch(batch_size)

    x_test = x_test[:client_dataset_test_size]
    y_test = y_test[:client_dataset_test_size]

    if IsIID:
        print(colored('---------- IID = True ----------', 'green'))

        # Training data: consecutive, non-overlapping slices. Slicing first
        # avoids materialising the whole dataset as a list for every client.
        train_data_list = []
        start_train = 0
        for size in num_expamples_list_in_clients:
            stop_train = start_train + size
            client_i_train_dataset = list(zip(x_train[start_train:stop_train],
                                              y_train[start_train:stop_train]))
            train_data_list.append(
                preprocess_client_data(client_i_train_dataset, bs=batch_size))
            start_train = stop_train

        # Test data: same scheme, scaled by tt_rate.
        test_data_list = []
        start_test = 0
        for size in num_expamples_list_in_clients:
            stop_test = start_test + int(size * tt_rate)
            client_i_test_dataset = list(zip(x_test[start_test:stop_test],
                                             y_test[start_test:stop_test]))
            test_data_list.append(
                preprocess_client_data(client_i_test_dataset, bs=batch_size))
            start_test = stop_test

        return train_data_list, test_data_list, server_dataset

    # Non-IID: client k only sees examples of the k-th digit class.
    print(colored('---------- IID = False ----------', 'green'))

    unique_labels = sorted(np.unique(train_class_ids))
    y_train_onehot = np.array(y_train)

    clients_dataset_list = []
    for (item, num_examples_client) in zip(unique_labels, num_expamples_list_in_clients):
        class_mask = train_class_ids == item
        class_images = x_train[class_mask][:num_examples_client]
        class_labels = y_train_onehot[class_mask][:num_examples_client]
        clients_dataset_list.append(
            preprocess_client_data(list(zip(class_images, class_labels)), bs=batch_size))

    return clients_dataset_list, server_dataset

模型

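这里用的是 TFF 自带的 mnist 模型。如果要换成自己的模型,记得让它提供后面代码用到的接口:`mnist_model()` 需要返回一个未编译的 Keras 模型(客户端和服务器会调用它的 `compile`、`fit`、`evaluate`、`get_weights`、`set_weights`)。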

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 
import tensorflow as tf
from tensorflow_federated.python.simulation.models import mnist

# Build the MNIST neural-network model.
def mnist_model(comp_model = False):
    """Return the Keras MNIST model shipped with TFF's simulation models.

    Args:
        comp_model: forwarded as ``compile_model`` to
            ``mnist.create_keras_model``.
    """
    keras_model = mnist.create_keras_model(compile_model=comp_model)
    return keras_model

客户端

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 
# import numpy as np
import tensorflow as tf
from termcolor import colored
import random
from tensorflow_federated.python.simulation.models import mnist
from utils import *
class client(object):
    """A simulated federated-learning client holding local data and a model.

    Args:
        local_dataset: dict with the keys ``'train'`` and ``'test'``, each an
            iterable of batches (the other modules build them as batched
            ``tf.data.Dataset`` objects).
        client_name: identifier reported by ``get_local_info``.
        local_model: Keras model trained by this client. When omitted, a
            fresh uncompiled model is created for this instance, so clients
            never share one model object through the default argument.
    """

    def __init__(self ,
                local_dataset:dict,
                client_name = 0 ,
                local_model = None
                 ):
        self.local_dataset = local_dataset
        self.client_name = client_name
        if local_model is None:
            local_model = mnist.create_keras_model(compile_model=False)
        self.local_model = local_model
        self.dataset_size = self._count_local_examples()
        self.val_acc_list = []
        self.val_loss_list = []
        self.local_model_size_list = [ get_model_size(self.local_model) ]

    def _count_local_examples(self):
        """Return the combined size of the local train and test datasets."""
        return (get_datasize(self.local_dataset['train'])
                + get_datasize(self.local_dataset['test']))

    def set_client_name(self , name):
        """Replace the client identifier."""
        self.client_name = name

    def set_model_weights(self , model : tf.keras.Model):
        """Copy the weights of ``model`` into the local model."""
        self.local_model.set_weights(model.get_weights())

    def set_local_dataset(self , dataset):
        """Replace the local dataset and refresh ``dataset_size``."""
        self.local_dataset = dataset
        # dataset_size is what the server weights this client by, so it must
        # follow the data actually held.
        self.dataset_size = self._count_local_examples()

    def client_train(self ,
                    client_epochs = 10 ,
                    model_loss = None ,
                    model_optimizer = None ,
                    model_metrics = None
                    ):
        """Train the local model and record validation loss/accuracy.

        Args:
            client_epochs: number of local epochs; validation runs once, on
                the last epoch.
            model_loss: Keras loss; defaults to
                ``CategoricalCrossentropy(from_logits=True)``.
            model_optimizer: Keras optimizer; defaults to a new
                ``SGD(learning_rate=0.1)`` per call.
            model_metrics: Keras metrics; defaults to ``['accuracy']``. The
                recorded history is read from the ``'val_accuracy'`` key, so
                custom metrics must still include accuracy.
        """
        # Defaults are built per call: objects created in the signature would
        # be shared by every client and every round.
        if model_loss is None:
            # NOTE(review): from_logits=True assumes the model outputs raw
            # logits; confirm against the model's last layer.
            model_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        if model_optimizer is None:
            model_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        if model_metrics is None:
            model_metrics = ['accuracy']

        self.local_model.compile(
                                optimizer=model_optimizer ,
                                loss=model_loss ,
                                metrics=model_metrics)
        # NOTE(review): workers/use_multiprocessing are kept from the original
        # call; verify they are still accepted by the Keras version in use.
        client_train_history = self.local_model.fit(self.local_dataset['train']
                                                    , epochs = client_epochs
                                                    , validation_data=self.local_dataset['test']
                                                    , validation_freq= client_epochs
                                                    , verbose=0
                                                    , workers= 4
                                                    , use_multiprocessing=True
                                                    )
        # validation_freq == client_epochs yields a single validation entry.
        self.val_acc_list.append(client_train_history.history['val_accuracy'][-1])
        self.val_loss_list.append(client_train_history.history['val_loss'][-1])
        self.local_model_size_list.append(get_model_size(self.local_model))

    def get_local_info(self):
        """Return a summary dict of this client.

        ``current_local_acc`` and ``current_local_loss`` are ``None`` until
        ``client_train`` has been called at least once.
        """
        current_acc = self.val_acc_list[-1] if self.val_acc_list else None
        current_loss = self.val_loss_list[-1] if self.val_loss_list else None
        return {'client_name': self.client_name ,
                'local_dataset_size': self.dataset_size ,
                'client_model_size_history': self.local_model_size_list ,
                'client_val_acc_history':self.val_acc_list ,
                'client_val_loss_history': self.val_loss_list ,
                'current_local_model_size': self.local_model_size_list[-1] ,
                'current_local_acc': current_acc ,
                'current_local_loss': current_loss
                }

服务器类

import tensorflow as tf
# import numpy as np
from tensorflow_federated.python.simulation.models import mnist
from client import client
from typing import List
from utils import *


# Alias used in the annotations below.
type_client_list = List[client]

class server(object):
    """Central server that averages client models and evaluates the result.

    Args:
        server_name: identifier reported by ``get_server_info``.
        test_dataset: batched dataset used by ``server_model_test``.
        server_model: global Keras model. When omitted, a fresh uncompiled
            model is created for this instance instead of sharing one object
            through the default argument.
    """

    def __init__(self ,
                server_name = 0 ,
                test_dataset = None ,
                server_model = None ):
        self.server_name = server_name
        self.test_dataset = test_dataset
        if server_model is None:
            server_model = mnist.create_keras_model(compile_model=False)
        self.server_model = server_model
        self.ave_acc_list = []
        self.ave_loss_list = []

    def get_server_info(self):
        """Return a summary dict of the server (dataset size is 0 if unset)."""
        if self.test_dataset is None:
            test_size = 0
        else:
            test_size = get_datasize(self.test_dataset)
        return {
            'server name' : self.server_name ,
            'server dataset size' : test_size ,
            'server model size' : get_model_size(self.server_model) ,
            'server acc history' : self.ave_acc_list ,
            'server loss history' : self.ave_loss_list
        }

    # calculate server model by clients list
    def calculate_server_model( self, client_list : type_client_list):
        """Set the server weights to the data-size-weighted client average.

        Each client contributes with the factor
        ``dataset_size / sum(dataset_size)``, so the factors sum to 1.

        Raises:
            ValueError: if ``client_list`` is empty or the clients hold no
                data at all.
        """
        if not client_list:
            raise ValueError('client_list must contain at least one client')

        total_examples = sum(member.dataset_size for member in client_list)
        if total_examples <= 0:
            raise ValueError('clients hold no data, cannot weight their models')

        # Scale every layer of every client by that client's share of the data.
        scaled_weights_per_client = []
        for member in client_list:
            factor = member.dataset_size / total_examples
            scaled_weights_per_client.append(
                [factor * layer for layer in member.local_model.get_weights()])

        # Sum the scaled layers position-wise across clients.
        averaged_weights = [
            tf.reduce_sum(layer_group, axis=0)
            for layer_group in zip(*scaled_weights_per_client)
        ]
        self.server_model.set_weights(averaged_weights)

    # broadcast server model to clients list
    def broadcast_server_model( self, client_list : type_client_list ):
        """Copy the current server weights into every client model."""
        for member in client_list:
            member.set_model_weights(self.server_model)

    # server test
    def server_model_test(self ):
        """Evaluate the server model and append loss/accuracy to the history.

        Raises:
            ValueError: if no test dataset was provided.
        """
        if self.test_dataset is None:
            raise ValueError('server has no test_dataset to evaluate on')

        # The server model is never trained directly, so it may never have
        # been compiled; evaluate() needs a loss and exactly one metric for
        # the two-value unpacking below. The settings mirror the defaults of
        # client.client_train so both sides report comparable numbers.
        if getattr(self.server_model, 'optimizer', None) is None:
            self.server_model.compile(
                loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

        server_loss , server_acc = self.server_model.evaluate(
            self.test_dataset , verbose=1 , workers=4 , use_multiprocessing=True )
        self.ave_loss_list.append(server_loss)
        self.ave_acc_list.append(server_acc)

其他要使用到的功能函数

import numpy as np
import tensorflow as tf

def preprocess_client_data(data, bs=100):
    """Turn an iterable of ``(x, y)`` pairs into a batched tf.data.Dataset.

    Args:
        data: iterable of two-element pairs; must not be empty.
        bs: batch size.
    """
    features, targets = zip(*data)
    slices = (list(features), list(targets))
    dataset = tf.data.Dataset.from_tensor_slices(slices)
    return dataset.batch(bs)

def shuffle_dataset(datas , labels):
    """Shuffle ``datas`` and ``labels`` with one shared random permutation.

    Both arguments are indexed with the same index array, so example/label
    pairs stay aligned. Returns the shuffled ``(datas, labels)`` pair.
    """
    sample_count = len(datas)
    order = np.random.permutation(np.arange(sample_count))
    shuffled_datas = datas[order]
    shuffled_labels = labels[order]
    return shuffled_datas , shuffled_labels
def get_model_size(model):
    """Return the size of a model's weights in MB.

    The size is the parameter count times 4 bytes (float32 is assumed)
    divided by 1024 twice.
    """
    para_num = 0
    for weight in model.get_weights():
        para_num += np.prod(weight.shape)
    total_bytes = para_num * 4
    return total_bytes / 1024 / 1024
def get_datasize(dataset):
    """Return the number of examples in a batched dataset.

    Args:
        dataset: iterable of batches. A batch may be a single array-like, or
            a tuple / dict of components such as ``(features, labels)``.

    Returns:
        The summed batch lengths. For tuple or dict batches the length of the
        first component is used; taking ``len`` of the tuple itself would
        count components (always 2 for ``(x, y)``) instead of examples.
    """
    dataset_size = 0
    for batch in dataset:
        component = batch
        # Descend into structured batches until an array-like is reached.
        while isinstance(component, (tuple, dict)) and component:
            if isinstance(component, tuple):
                component = component[0]
            else:
                component = next(iter(component.values()))
        dataset_size += len(component)

    return dataset_size

def summary_acc_loss(logdir , name ,loss:list ,acc:list ):
    """Write per-round loss/accuracy as TensorBoard scalars.

    Args:
        logdir: directory the summary file writer logs to.
        name: prefix of the scalar tags (``'<name>_acc'``, ``'<name>_loss'``).
        loss: loss per round; its length decides how many steps are written.
        acc: accuracy per round, indexed in step with ``loss``.
    """
    summary_writer = tf.summary.create_file_writer(logdir)
    # tf.summary.scalar only records when a default writer is active.
    with summary_writer.as_default():
        for rnd, round_loss in enumerate(loss):
            tf.summary.scalar('{}_acc'.format(name) , acc[rnd] , step=rnd)
            tf.summary.scalar('{}_loss'.format(name) , round_loss , step=rnd)
    # One flush after all steps instead of one per step.
    summary_writer.flush()

主函数

from mnist import generate_clients_data
from client import client
from server import server
from mnist_model import mnist_model
from typing import List
import matplotlib.pyplot as plt
from utils import *
import numpy as np
def draw(epoch_sumloss , epoch_acc):
    """Plot loss (left axis, red) and accuracy (right axis, blue) per epoch.

    The x values are the indices of ``epoch_sumloss``; the figure is shown
    with ``plt.show()``.
    """
    epochs = list(range(len(epoch_sumloss)))
    loss_color = 'red'
    acc_color = 'blue'

    # Left y-axis: loss.
    fig , loss_axis = plt.subplots()
    loss_axis.set_xlabel('epoch')
    loss_axis.set_ylabel('loss' , color=loss_color)
    loss_axis.plot(epochs , epoch_sumloss , color=loss_color)
    loss_axis.tick_params(axis='y', labelcolor=loss_color)

    # Right y-axis sharing the same x-axis: accuracy.
    acc_axis = loss_axis.twinx()
    acc_axis.set_ylabel('acc', color=acc_color)
    acc_axis.plot(epochs , epoch_acc , color=acc_color)
    acc_axis.tick_params(axis='y' , labelcolor=acc_color)

    fig.tight_layout()
    plt.show()

def FAVG_init( ):
    """Build the server and the ten IID clients used by the experiment.

    Returns:
        ``(server_0, clients_list)``: the server holding the server test
        dataset, and one client per entry of the per-client size list.
    """
    # Training examples per client.
    num_expamples_list_in_clients = [1000 , 2000 , 1500 , 500 , 3000 , 1000 , 1000,2000,1500 , 2000]
    split = generate_clients_data( num_expamples_list_in_clients ,
                                   num_clients = 10,
                                   IsIID=True,
                                   batch_size=100,
                                   tt_rate = 0.3)
    client_train_data_list , client_test_data_list , server_test_dataset = split

    # The server gets its own model instance and the held-out test data.
    server_0 = server(test_dataset=server_test_dataset , server_model=mnist_model(comp_model=False) )

    # One client per (train, test) dataset pair, each with its own model.
    clients_list = []
    paired_datasets = zip(client_train_data_list , client_test_data_list)
    for index , (train_data , test_data) in enumerate(paired_datasets):
        new_client = client(
            local_dataset={'train': train_data , 'test' : test_data },
            client_name='client_{}'.format(index) ,
            local_model=mnist_model(comp_model=False)
        )
        clients_list.append(new_client)
    return server_0 , clients_list
# train
def FAVG_train(server:server , clients_list: List[client] , server_round : int , client_enpochs:int):
    """Run federated averaging for ``server_round`` communication rounds.

    Every round: each client trains locally for ``client_enpochs`` epochs,
    the server aggregates the client models, broadcasts the result back and
    evaluates it on its test dataset.

    Args:
        server: the aggregating server.
        clients_list: the participating clients.
        server_round: number of communication rounds.
        client_enpochs: local epochs per client and round.

    Returns:
        The same ``(server, clients_list)`` objects after training.
    """
    # Federated averaging assumes all clients start from the same weights.
    # Averaging models trained from unrelated initialisations is not
    # meaningful, so synchronise once before the first round. On later calls
    # this is a no-op because every round ends with a broadcast.
    server.broadcast_server_model(clients_list)

    for _ in range(server_round):
        for client_ in clients_list:
            client_.client_train(client_epochs=client_enpochs)
        server.calculate_server_model(clients_list)
        server.broadcast_server_model(clients_list)
        server.server_model_test()
    return server , clients_list

if __name__ == "__main__":
    # 200 communication rounds with 10 local epochs per client and round.
    server_0 , clients_list = FAVG_init()
    server_0 , clients_list = FAVG_train(server_0 , clients_list , 200 , 10)

    # TensorBoard scalars. The tag prefix must be a string: passing the
    # server object itself would put its repr (memory address included) into
    # the tag names.
    log_dir = "/tmp/logs/scalars/FAVG/"
    summary_acc_loss(logdir=log_dir , name='server' , loss=server_0.ave_loss_list , acc=server_0.ave_acc_list)

    # Persist the server and per-client histories as .npz files.
    np.savez('server_acc' , server_0.ave_acc_list)
    np.savez('server_loss' , server_0.ave_loss_list)
    for _client in clients_list:
        np.savez('{}_acc'.format(_client.client_name) , _client.val_acc_list )
        np.savez('{}_loss'.format(_client.client_name) , _client.val_loss_list )
  • 21
    点赞
  • 129
    收藏
    觉得还不错? 一键收藏
  • 30
    评论
评论 30
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值