FATE联邦学习开源平台V2.1.0学习笔记(三):launcher用法

0. 引言

本文将介绍FATE开源框架中launcher的功能,以及如何使用launcher来测试单个组件,最后将教程如果一步步编写自己组件的launcher代码并运行,本文以coordinate_lr组件为例,最后给出完成代码以供参考。

FATE官网https://fate.readthedocs.io/en/latest/
FATE-Githubhttps://github.com/FederatedAI/FATE

具体请参考github中的doc\2.0\fate\ml\run_launchers.md此处

1. Launcher介绍

根据FATE官方介绍,FATE2.0版本引入了launchers用于本地运行机器学习模块,这是一种一种轻量级的方式来执行本地实验而不依赖于FATE-Flow服务,launchers不通过FATE_Client直接调用算法模块。launchers可以通过简单的命令行参数直接启动,可以理解为本地开发和测试机器学习组件的一种便捷工具。

接下来将首先分析Examples中的launchers实例,然后给出自己编写的launcher思路。

2. 对examples中launcher实例代码解析(SSHE_LR)

本文将逐步解析sshe_lr_launcher.py代码

2.1 导包

导入相关的包,主要是有关launchers的文件(通过launch函数运行一个launcher)

from fate.arch.launchers.multiprocess_launcher import launch

以及需要测试的机器学习代码组件
代码组件在fata\ml文件夹下

from fate.ml.glm.hetero.sshe import SSHELogisticRegression
from fate.arch import dataframe

2.2 定义数据类

定义需要用到的数据类SSHEArguments并用@dataclass修饰

@dataclass
class SSHEArguments:
    lr: float = field(default=0.15)
    guest_data: str = field(default=None)
    host_data: str = field(default=None)

2.3 定义运行函数run_sshe_lr

首先初始化mpc实例(安全多方计算)

ctx.mpc.init()          # ctx中的mpc方法,返回MPC类对象,在调用MPC中的init方法,通过初始化默认提供程序来初始化MPCSensor模块以及设置RNG生成器

加载guest和host数据(直接用examples里的data,或者在命令行中添加data参数),优于是纵向联邦学习,因此对guest和host的处理稍有区别

guest_data = './hetero_breast_guest.csv'
host_data = './hetero_breast_host.csv'
if ctx.is_on_guest:
    kwargs = {
        "sample_id_name": None,
        "match_id_name": "id",
        "delimiter": ",",
        "label_name": "y",
        "label_type": "int32",
        "dtype": "float32",
    }
    input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, guest_data)
else:
    kwargs = {
        "sample_id_name": None,
        "match_id_name": "id",
        "delimiter": ",",
        "dtype": "float32",
    }
    input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, host_data)

如需在命令行中输入传入参数,需要创建HfArgumentParser类对象并调用parse_args_into_datackasses方法解析命令行参数。

from fate.arch.launchers.argparser import HfArgumentParser
args, _ = HfArgumentParser(SSHEArguments).parse_args_into_dataclasses(return_remaining_strings=True)

2.4 定义一个sshe_lr模块的对象

创建fata\ml中相关组件(\glm\hetero\sshe)的类对象:

inst = SSHELogisticRegression(
        epochs=5,
        batch_size=300,
        tol=0.01,
        early_stop="diff",
        learning_rate=args.lr,
        init_param={"method": "random_uniform", "fit_intercept": True, "random_state": 1},
        reveal_every_epoch=False,
        reveal_loss_freq=2,
        threshold=0.5,
    )

2.5 将输入数据集输入到该模块对象中训练并打印结果

调用模块对象的fit成员方法,传入数据并进行训练。

inst.fit(ctx, train_data=input_data)
logger.info(f"model: {pprint.pformat(inst.get_model())}")

2.6 调用launch函数接口,并从终端运行launcher

确保使用fata.arch中的launch作为程序入口调用函数

if __name__ == "__main__":
    launch(run_sshe_lr, extra_args_desc=[SSHEArguments])

终端运行Launcher:

python sshe_lr_launcher.py --parties guest:9999 host:10000 --log_level INFO --guest_data ../data/breast_hetero_guest.csv --host_data ../data/breast_hetero_host.csv

运行结果

在这里插入图片描述

3. 自己编写代码实现有协调方的逻辑回归的launcher

要编写一个launcher,首先设定想要实现的FATE模块运行的案例(在fate/ml内),并将其包装成一个函数。

根据run_launchers.md文档,按照说明文档编写具有协调方的逻辑回归launcher代码(coordinated_lr_launcher.py),其中FATE模块文件在fate/ml/glm/hetero/coordinated_lr中。

模仿examples里的sshe_lr_launcher代码结构修改并编写自己的coordinated_lr_launcher代码

导包修改

将SSHE相应的包修改为coor相应的组件包,即

# from fate.ml.glm.hetero.sshe import SSHELogisticRegression
from fate.ml.glm.hetero.coordinated_lr import CoordinatedLRModuleArbiter, CoordinatedLRModuleGuest, \
        CoordinatedLRModuleHost

因为coordinate的组件包含三个参与方(arbiter, host, guest)的文件,因此都需要导包

数据类修改

class CoorArguments:
    lr: float = field(default=0.15)
    guest_data: str = field(default=None)
    host_data: str = field(default=None)

数据集通过命令行参数输入,也使用examples中的数据集

运行函数修改

# def run_sshe_lr(ctx: "Context"):
def run_coor_linr(ctx: "Context"):
	···

HfArgumentParser类对象的参数更改为coor的参数类

# args, _ = HfArgumentParser(SSHEArguments).parse_args_into_dataclasses(return_remaining_strings=True)
args, _ = HfArgumentParser(CoorArguments).parse_args_into_dataclasses(return_remaining_strings=True)

修改模块的对象

因为coor模块的三个参与方具有单独的类对象,因此每个参与方都需要实例化其类对象
这里的参数可以参考组件中的实例化参数\python\fate\components\components\coordinated_lr.py或者examplescorrdinated_lr的参数

guest:

    if ctx.is_on_guest:
        inst = CoordinatedLRModuleGuest(
            epochs=5,
            batch_size=300,
            optimizer_param={"method": "sgd", "optimizer_params": {"lr": 0.1}, "penalty": "l2", "alpha": 0.001},
            learning_rate_param={"method": "linear", "scheduler_params": {"start_factor": 0.7,
                                                                          "total_iters": 100}},
            init_param={"fit_intercept": True, "method": "zeros"}
        )
        kwargs = {
            "sample_id_name": None,
            "match_id_name": "id",
            "delimiter": ",",
            "label_name": "y",
            "label_type": "int32",
            "dtype": "float32",
        }
        input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, args.guest_data)
        inst.fit(ctx, train_data=input_data)
        print(f"guest_model: {inst.get_model()}")

host:

    elif ctx.is_on_host:
        inst = CoordinatedLRModuleHost(
            epochs=5,
            batch_size=300,
            optimizer_param={"method": "sgd", "optimizer_params": {"lr": 0.1}, "penalty": "l2", "alpha": 0.001},
            learning_rate_param={"method": "linear", "scheduler_params": {"start_factor": 0.7,
                                                                          "total_iters": 100}},
            init_param={"fit_intercept": True, "method": "zeros"}

        )
        kwargs = {
            "sample_id_name": None,
            "match_id_name": "id",
            "delimiter": ",",
            "dtype": "float32",
        }
        input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, args.host_data)
        inst.fit(ctx, train_data=input_data)
        print(f"host_model: {inst.get_model()}")

arbiter: arbiter没有训练数据,因此不需要加载数据

    elif ctx.is_on_arbiter:
        inst = CoordinatedLRModuleArbiter(
            epochs=5,
            batch_size=300,
            early_stop="diff",
            tol=1e-4,
            optimizer_param={"method": "sgd", "optimizer_params": {"lr": 0.1}, "penalty": "l2", "alpha": 0.001},
            learning_rate_param={"method": "linear", "scheduler_params": {"start_factor": 0.7,
                                                                          "total_iters": 100}}
        )
        inst.fit(ctx)
        print(f"model: {inst.get_model()}")

调用launch函数接口,并从终端运行launcher

if __name__ == "__main__":
    # launch(run_sshe_lr, extra_args_desc=[SSHEArguments])
    launch(run_coor_linr, extra_args_desc=[CoorArguments])

终端运行Launcher:注意在parties参数里添加arbiter:10000

# python sshe_lr_launcher.py --parties guest:9999 host:10000 --log_level INFO --guest_data ../data/breast_hetero_guest.csv --host_data ../data/breast_hetero_host.csv
python coor_lr_launcher.py --parties guest:9999 host:10000 arbiter:10000 --log_level INFO --guest_data ../data/breast_hetero_guest.csv --host_data ../data/breast_hetero_host.csv    

运行结果

在这里插入图片描述

完整代码

import logging
import typing
from dataclasses import dataclass, field

from fate.arch.launchers.argparser import HfArgumentParser
from fate.arch.launchers.multiprocess_launcher import launch

if typing.TYPE_CHECKING:
    from fate.arch import Context

logger = logging.getLogger(__name__)


@dataclass
class CoorArguments:
    lr: float = field(default=0.05)
    guest_data: str = field(default=None)
    host_data: str = field(default=None)


def run_coor_linr(ctx: "Context"):
    from fate.ml.glm.hetero.coordinated_lr import CoordinatedLRModuleArbiter, CoordinatedLRModuleGuest, \
        CoordinatedLRModuleHost
    from fate.arch import dataframe

    ctx.mpc.init()
    args, _ = HfArgumentParser(CoorArguments).parse_args_into_dataclasses(return_remaining_strings=True)

    if ctx.is_on_guest:
        inst = CoordinatedLRModuleGuest(
            epochs=5,
            batch_size=300,
            optimizer_param={"method": "sgd", "optimizer_params": {"lr": 0.1}, "penalty": "l2", "alpha": 0.001},
            learning_rate_param={"method": "linear", "scheduler_params": {"start_factor": 0.7,
                                                                          "total_iters": 100}},
            init_param={"fit_intercept": True, "method": "zeros"}
        )
        kwargs = {
            "sample_id_name": None,
            "match_id_name": "id",
            "delimiter": ",",
            "label_name": "y",
            "label_type": "int32",
            "dtype": "float32",
        }
        input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, args.guest_data)
        inst.fit(ctx, train_data=input_data)
        print(f"guest_model: {inst.get_model()}")
    elif ctx.is_on_host:
        inst = CoordinatedLRModuleHost(
            epochs=5,
            batch_size=300,
            optimizer_param={"method": "sgd", "optimizer_params": {"lr": 0.1}, "penalty": "l2", "alpha": 0.001},
            learning_rate_param={"method": "linear", "scheduler_params": {"start_factor": 0.7,
                                                                          "total_iters": 100}},
            init_param={"fit_intercept": True, "method": "zeros"}

        )
        kwargs = {
            "sample_id_name": None,
            "match_id_name": "id",
            "delimiter": ",",
            "dtype": "float32",
        }
        input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, args.host_data)
        inst.fit(ctx, train_data=input_data)
        print(f"host_model: {inst.get_model()}")
    elif ctx.is_on_arbiter:
        inst = CoordinatedLRModuleArbiter(
            epochs=5,
            batch_size=300,
            early_stop="diff",
            tol=1e-4,

            optimizer_param={"method": "sgd", "optimizer_params": {"lr": 0.1}, "penalty": "l2", "alpha": 0.001},
            learning_rate_param={"method": "linear", "scheduler_params": {"start_factor": 0.7,
                                                                          "total_iters": 100}}
        )
        inst.fit(ctx)
        print(f"model: {inst.get_model()}")


if __name__ == "__main__":
    launch(run_coor_linr, extra_args_desc=[CoorArguments])

  • 24
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值