【联邦学习框架FLGo学习】2.2 运行配置-初始化和选项(option)

首先根据前面的几个例子总结一下直接使用FLGo运行其中算法的流程

  1. 定义联邦任务:指定任务路径、使用的数据集、数据集的划分方式
  2. 创建联邦i任务:使用flgo.gen_task函数根据config创建不同的联邦任务
  3. 选择算法,进行联邦训练:使用flgo.init函数和run函数进行训练
  4. 提取训练结果进行分析展示

之前创建不同的联邦任务时学习的主要是熟悉第一、二部分,本篇将简单介绍第三部分选择算法,进行联邦训练;在这一节中只介绍了在调用flgo.init和run函数iaangz时的一些参数设置,至于这两个函数具体的实现流程,在之后的教程中也会进行介绍。

flgo.init函数介绍

def init(task: str, algorithm, option = {}, model=None, Logger: flgo.experiment.logger.BasicLogger = flgo.experiment.logger.simple_logger.SimpleLogger, Simulator: BasicSimulator=flgo.system_simulator.DefaultSimulator, scene='horizontal'):
    r"""
    Initialize a runner in FLGo, which is to optimize a model on a specific task (i.e. IID-mnist-of-100-clients) by the selected federated algorithm.
    :param
        task (str): the dictionary of the federated task
        algorithm (module || class): the algorithm will be used to optimize the model in federated manner, which must contain pre-defined attributions (e.g. algorithm.Server and algorithm.Client for horizontal federated learning)
        option (dict || str): the configurations of training, environment, algorithm, logger and simulator
        model (module || class): the model module that contains two methods: model.init_local_module(object) and model.init_global_module(object)
        Logger (class): the class of the logger inherited from flgo.experiment.logger.BasicLogger
        Simulator (class): the class of the simulator inherited from flgo.system_simulator.BasicSimulator
        scene (str): 'horizontal' or 'vertical' in current version of FLGo
    :return
        runner: the object instance that has the method runner.run()
    """
    ...

调用示例

 runner = flgo.init(task, fedavg, {'log_file':True, 'num_epochs':1})
    runner.run()

可以看到flgo.init返回一个具有run方法的对象,通过调用run方法来开启迭代训练,具体的init函数的输入包括:

  • task: 联邦任务(路径、处理数据集等)由前面的gen_task生成
  • algorithm: ,要求algorithm的类型是class或module,横向联邦中需要其具备algorithm.Server和algorithm.Client两个可访问的属性;
  • option(可选):运行选项,类型为字典,包含运行时的各类参数;
  • model(可选):待优化的模型模块,要求model的类型为类或module,横向联邦中需要其具备model.init_global_module和model.init_local_module两个可访问的方法,为实体初始化module;
  • Logger(可选):日志记录器类,要求父类为flgo.experiment.logger.BasicLogger
  • Simulator(可选):系统模拟器类,要求父类为flgo.system_simulator.BasicSimulator
  • scene(可选):类型为字符串,联邦场景(e.g. 横向 or 纵向)

每个参数应该如何设置,来达到自己的实验目的在之后的教程中也将会依次介绍。

本节要介绍一下其中option参数的设置

option参数的设置

在flgo中,每个runner的运行时参数由传入的字典option来指定,这些参数主要分为4类(训练参数、用户选项、日志选项、模拟器选项),为了能够在写代码时能够查看option包含哪些关键字,可以使用flgo.option_helper( )函数来查看。该函数输出结果如下
在这里插入图片描述

+---------------------+-----------+------------------------------------------------------------------------------------------------+---------------+---------------------------------------------------------------------------+
|         Name        |    Type   |                                          Description                                           | Default Value |                                  Comment                                  |
+---------------------+-----------+------------------------------------------------------------------------------------------------+---------------+---------------------------------------------------------------------------+
|      num_rounds     |    int    |                                 number of communication rounds                                 |       20      |                                                                           |
|      proportion     |   float   |                            proportion of clients sampled per round                             |      0.2      |                                                                           |
| learning_rate_decay |   float   |                          learning rate decay for the training process                          |     0.998     |                        effective if lr_scheduler>-1                       |
|     lr_scheduler    |    int    |                           type of the global learning rate scheduler                           |       -1      |                        effective if larger than -1                        |
|      early_stop     |    int    |        stop training if there is no improvement  for no smaller than the maximum rounds        |       -1      |                        effective if larger than -1                        |
|      num_epochs     |    int    |                               number of epochs of local training                               |       5       |                                                                           |
|      num_steps      |    int    |                               number of steps of local training                                |       -1      |                    dominates num_epochs if larger than 0                  |
|    learning_rate    |   float   |                                learning rate of local training                                 |      0.1      |                                                                           |
|      batch_size     | int\float |                                  batch size of local training                                  |       64      | -1 means full batch and float value  means the ratio of the full datasets |
|      optimizer      |    str    |                           to select the optimizer of local training                            |     'sgd'     |                      'sgd'|'adam'|'rmsprop'|'adagrad'                     |
|      clip_grad      |   float   |           clipping gradients if the max norm of  gradients \|\|g\|\| > clip_norm > 0           |      0.0      |                        effective if larger than 0.0                       |
|       momentum      |   float   |                                   momentum of local training                                   |      0.0      |                                                                           |
|     weight_decay    |   float   |                                 weight decay of local training                                 |      0.0      |                                                                           |
|   num_edge_rounds   |    int    |                                number of edge rounds in hierFL                                 |       5       |                    effective if scene is 'hierarchical'                   |
|      algo_para      |  int\list |                              algorithm-specific hyper-parameters                               |       []      |               the order should be consistent with  the claim              |
|        sample       |    str    |                                    to select sampling form                                     |    'uniform'  |               'uniform'|'md'| 'full'| x+'_with_availability'              |
|      aggregate      |    str    |                                   to select aggregation form                                   |     'other'   |             'uniform'|'weighted_com'|'weighted_scale'|'other']            |
|    train_holdout    |   float   |      the rate of holding out the validation  dataset from all the local training datasets      |      0.1      |                                                                           |
|     test_holdout    |   float   | the rate of holding out the validation  dataset from the testing datasets owned by  the server |      0.0      |              effective if the server has  no validation data              |
|      local_test     |    bool   |  the local validation data will be equally  split into validation and testing parts  if True   |     False     |                                                                           |
|         seed        |    int    |                                seed for all the random modules                                 |       0       |                                                                           |
|       dataseed      |    int    |               seed for all the random modules for data train/val/test partition                |       0       |                                                                           |
|         gpu         |  int\list |                            GPU IDs and empty input means using CPU                             |       []      |                                                                           |
|   server_with_cpu   |    bool   |                   the model parameters will be stored in  the memory if True                   |     False     |                                                                           |
|    num_parallels    |    int    |                         the number of parallels during communications                          |       1       |                                                                           |
|     num_workers     |    int    |                              the number of workers of DataLoader                               |       0       |                                                                           |
|      pin_memory     |    bool   |               1)pin_memory of DataLoader and 2) load  data directly into memory                |     False     |                                                                           |
|   test_batch_size   |    int    |                              the batch_size used in testing phase                              |      512      |                                                                           |
|     availability    |    str    |                               to select client availability mode                               |     'IDL'     |            'IDL'|'YMF'|'MDF'|'LDF'|'YFF'|'HOMO'|'LN'|'SLN'|'YC'           |
|     connectivity    |    str    |                               to select client connectivity mode                               |     'IDL'     |                                'IDL'|'HOMO'                               |
|     completeness    |    str    |                               to select client completeness mode                               |     'IDL'     |                       'IDL'|'PDU'|'FSU'|'ADU'|'ASU'                       |
|    responsiveness   |    str    |                              to select client responsiveness mode                              |     'IDL'     |                              'IDL'|'LN'|'UNI'                             |
|      log_level      |    str    |                                      the level of logger                                       |     'INFO'    |                               'INFO'|'DEBUG'                              |
|       log_file      |    bool   |                        whether log to file and default  value is False                         |     False     |                                                                           |
|    no_log_console   |    bool   |                        whether log to screen and default  value is True                        |      True     |                                                                           |
|     no_overwrite    |    bool   |                              whether to overwrite the old result                               |     False     |                                                                           |
|    eval_interval    |    int    |                                   evaluate every __ rounds;                                    |       1       |                                                                           |
+---------------------+-----------+------------------------------------------------------------------------------------------------+---------------+---------------------------------------------------------------------------+
  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值