联邦学习-fedavg_trainer

FedAVGTrainer

import torch
import torch as t
import tqdm
import numpy as np
from torch.utils.data import DataLoader
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient as SecureAggClient
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorServer as SecureAggServer
from federatedml.nn.dataset.base import Dataset
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from federatedml.util import LOGGER, consts
from federatedml.optim.convergence import converge_func_factory


class FedAVGTrainer(TrainerBase):
    """
    Parameters
    ----------
    epochs: int >0, epochs to train
    batch_size: int, -1 means full batch
    secure_aggregate: bool, default is True, whether to use secure aggregation. if enabled, will add random number
                            mask to local models. These random number masks will eventually cancel out to get 0.
    weighted_aggregation: bool, whether add weight to each local model when doing aggregation.
                         if True, According to origin paper, weight of a client is: n_local / n_global, where n_local
                         is the sample number locally and n_global is the sample number of all clients.
                         if False, simply averaging these models.

    early_stop: None, 'diff' or 'abs'. if None, disable early stop; if 'diff', use the loss difference between
                two epochs as early stop condition, if differences < tol, stop training ; if 'abs', if loss < tol,
                stop training
    tol: float, tol value for early stop

    aggregate_every_n_epoch: None or int. if None, aggregate model on the end of every epoch, if int, aggregate
                             every n epochs.
    cuda: bool, use cuda or not
    pin_memory: bool, for pytorch DataLoader
    shuffle: bool, for pytorch DataLoader
    data_loader_worker: int, for pytorch DataLoader, number of workers when loading data
    validation_freqs: None or int. if int, validate your model and send validate results to fate-board every n epoch.
                      if is binary classification task, will use metrics 'auc', 'ks', 'gain', 'lift', 'precision'
                      if is multi classification task, will use metrics 'precision', 'recall', 'accuracy'
                      if is regression task, will use metrics 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'
    checkpoint_save_freqs: save model every n epoch, if None, will not save checkpoint.
    task_type: str, 'auto', 'binary', 'multi', 'regression'
               this option decides the return format of this trainer, and the evaluation type when running validation.
               if auto, will automatically infer your task type from labels and predict results.
    """

    def __init__(self, epochs=10, batch_size=512,  # training parameter
                 early_stop=None, tol=0.0001,  # early stop parameters
                 secure_aggregate=True, weighted_aggregation=True, aggregate_every_n_epoch=None,  # federation
                 cuda=False, pin_memory=True, shuffle=True, data_loader_worker=0,  # GPU & dataloader
                 validation_freqs=None,  # validation configuration
                 checkpoint_save_freqs=None,  # checkpoint configuration
                 task_type='auto'
                 ):

        super(FedAVGTrainer, self).__init__()

        # training parameters
        self.epochs = epochs
        self.tol = tol
        self.validation_freq = validation_freqs
        self.save_freq = checkpoint_save_freqs

        self.task_type = task_type
        task_type_allow = [
            consts.BINARY,
            consts.REGRESSION,
            consts.MULTY,
            'auto']
        assert self.task_type in task_type_allow, 'task type must in {}'.format(
            task_type_allow)

        # aggregation param
        self.secure_aggregate = secure_aggregate
        self.weighted_aggregation = weighted_aggregation
        self.aggregate_every_n_epoch = aggregate_every_n_epoch

        # GPU
        self.cuda = cuda
        if not torch.cuda.is_available() and self.cuda:
            raise ValueError('Cuda is not available on this machine')

        # data loader
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.data_loader_worker = data_loader_worker

        self.early_stop = early_stop
        early_stop_type = ['diff', 'abs']
        if early_stop is not None:
            assert early_stop in early_stop_type, 'early stop type must be in {}, bug got {}' \
                .format(early_stop_type, early_stop)

        # communicate suffix
        self.comm_suffix = 'fedavg'

参数

----------

epochs:int>0,要训练的时间段

batch_size: int,-1表示完整批次

secure_aggregate: bool,默认值为True,是否使用安全聚合。如果启用,将添加随机数掩模到本地模型。这些随机数掩码最终将抵消,得到0。

weighted_gaggregation: bool,在进行聚合时是否向每个局部模型添加权重。如果为True,则根据原始纸张,客户端的权重为:n_local/n_global,其中n_local是本地的样本数,n_global是所有客户端的样本数。如果为False,则简单地对这些模型求平均值。

early_stop:无,“diff”或“abs”。如果无,则禁用提前停止;如果“diff”,则使用两个时期作为早期停止条件,如果差异<tol,停止训练;如果“abs”,如果损失<tol,停止训练

tol:浮动,提前停止的tol值

aggregate_every_n_epoch:无或int。如果无,则在每个epoch结束时聚合模型,如果int,则聚合每n个epochs。

cuda:bool,用不用cuda

pin_memory: bool,用于pytorch DataLoader

shuffle: bool,用于pytorch DataLoader

data_loader_worker: int,对于pytorch DataLoader,加载数据时的工作者数

validation_freqs: None或int。如果为int,则验证模型并每n个epoch将验证结果发送给命运板。

  • 如果是二进制分类任务,将使用度量“auc”、“ks”、“gain”、“lift”、“precision”

  • 如果是多分类任务,将使用度量“精度”、“召回”、“准确度”

  • 如果是回归任务,将使用度量“毫秒”、“mae”、“rmse”、“解释方差”、“r2_score”

checkpoint_save_freqs: 每n个历元保存一次模型,如果无,则不会保存检查点。

task_type: str,'auto','binary','multi','regression'该选项决定了该培训师的返回格式,以及运行验证时的评估类型。如果自动,将自动从标签推断任务类型并预测结果。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值