In the configuration file, expand the host section to ten party ids:
guest:
  - '9999'
host:
  - '10000'
  - '10001'
  - '10002'
  - '10003'
  - '10004'
  - '10005'
  - '10006'
  - '10007'
  - '10008'
  - '10009'
arbiter:
  - '10000'
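With the config in place, the party ids can be read back in Python. A minimal sketch, assuming the YAML above is saved as config.yaml (FATE's own examples load it through helper utilities that expose an object like parties.host instead):

import yaml

# Read the expanded party ids; keys mirror the YAML above.
with open("config.yaml") as f:
    conf = yaml.safe_load(f)

guest = conf["guest"][0]      # '9999'
hosts = conf["host"]          # ten host ids, '10000' .. '10009'
arbiter = conf["arbiter"][0]  # '10000'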
Assuming every host node uses its own copy of the data, a simple for loop can drive the upload of the breast-cancer-diagnosis dataset for horizontal federated learning:
for i in range(10):
    pipeline.transform_local_file_to_dataframe(
        file="/data/projects/fate/examples/data/breast_homo_host.csv",
        meta=meta, head=True, extend_sid=True,
        namespace="experiment",
        name=f"breast_homo_host_{i}",
    )
Then initialize a pipeline in a loop, using a different host party id each time:
for i in range(10):
    host = parties.host[i]
    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
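Putting the two loops together, a sketch that reuses only the calls already shown above (meta is the dict sketched earlier; whether each upload has to run on the owning party depends on your deployment):

from fate_client.pipeline import FateFlowPipeline

# One pipeline per host party; each uploads that party's copy of the dataset.
for i in range(10):
    host = parties.host[i]
    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
    pipeline.transform_local_file_to_dataframe(
        file="/data/projects/fate/examples/data/breast_homo_host.csv",
        meta=meta, head=True, extend_sid=True,
        namespace="experiment",
        name=f"breast_homo_host_{i}",
    )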
Once these jobs run for each host, training is complete.
Here is the source of the method that builds the training configuration, get_config_of_default_runner:
def get_config_of_default_runner(
    algo: str = "fedavg",
    model: Union[TorchModule, Sequential, ModelLoader] = None,
    optimizer: Union[TorchOptimizer, Loader] = None,
    loss: Union[TorchModule, CustFuncLoader] = None,
    training_args: TrainingArguments = None,
    fed_args: FedArguments = None,
    dataset: DatasetLoader = None,
    data_collator: CustFuncLoader = None,
    tokenizer: CustFuncLoader = None,
    task_type: Literal["binary", "multi", "regression", "causal_lm", "others"] = "binary",
):
    if model is not None and not isinstance(
        model, (TorchModule, Sequential, ModelLoader)
    ):
        raise ValueError(
            f"The model is of type {type(model)}, not TorchModule, Sequential, or ModelLoader. Remember to use patched_torch_hook for passing NN Modules or Optimizers."
        )
    if fed_args is not None and not isinstance(fed_args, FedArguments):
        raise ValueError(
            f"Federation arguments are of type {type(fed_args)}, not FedArguments."
        )
    runner_conf = _get_config_of_default_runner(
        optimizer, loss, training_args, dataset, data_collator, tokenizer, task_type
    )
    runner_conf['algo'] = algo
    runner_conf['model_conf'] = model.to_dict() if model is not None else None
    runner_conf['fed_args_conf'] = fed_args.to_dict() if fed_args is not None else None
    return runner_conf
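A hedged usage sketch of this method, modeled on FATE 2.x homo_nn examples; the import paths, FedAVGArguments, and the HomoNN runner_conf parameter are assumptions that may differ across versions (the breast dataset has 30 features):

from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner
from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments
from fate_client.pipeline.components.fate.nn.torch import nn, optim

# Build a FedAVG runner config; nn here is the patched torch module, so
# each layer carries the to_dict() needed for model_conf serialization.
conf = get_config_of_default_runner(
    algo="fedavg",
    model=nn.Sequential(
        nn.Linear(30, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
        nn.Sigmoid(),
    ),
    loss=nn.BCELoss(),
    optimizer=optim.Adam(lr=0.01),
    training_args=TrainingArguments(num_train_epochs=10, per_device_train_batch_size=32),
    fed_args=FedAVGArguments(),
    task_type="binary",
)
homo_nn_0 = HomoNN("homo_nn_0", runner_conf=conf)  # assumed component signature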
The model argument may be an instance of TorchModule, Sequential, or ModelLoader and represents the model to train.
- For the model parameter: if it is not None and is not an instance of TorchModule, Sequential, or ModelLoader, a ValueError is raised.
- For the fed_args parameter: if it is not None and is not an instance of FedArguments, a ValueError is likewise raised.
- The method returns the runner_conf dictionary populated with all the necessary information.
The patched Sequential class below shows how a model becomes model_conf via to_dict:
import json

from torch.nn import Sequential as tSequential


class Sequential(tSequential):
    def to_dict(self):
        """
        get the structure of current sequential
        """
        # Collect each layer's config keyed by its position in the sequence.
        layer_confs = {}
        idx = 0
        for k in self._modules:
            ordered_name = idx
            layer_confs[ordered_name] = self._modules[k].to_dict()
            idx += 1
        # load_seq, defined alongside this class, rebuilds the Sequential
        # from seq_conf on the component side.
        ret_dict = {
            "module_name": "fate.components.components.nn.torch.base",
            "item_name": load_seq.__name__,
            "kwargs": {"seq_conf": layer_confs},
        }
        return ret_dict

    def to_json(self):
        return json.dumps(self.to_dict(), indent=4)
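A quick illustration of the serialized form; a sketch assuming the patched layers from the earlier example, with the output shape read off the code above (the integer keys of layer_confs become strings in JSON):

seq = Sequential(nn.Linear(30, 16), nn.ReLU())
print(seq.to_json())
# Roughly:
# {
#     "module_name": "fate.components.components.nn.torch.base",
#     "item_name": "load_seq",
#     "kwargs": {"seq_conf": {"0": <Linear conf>, "1": <ReLU conf>}}
# }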