ATM Source Code Analysis

example/example.py

from atm import ATM

atm = ATM()

results = atm.run(train_path="/home/tqc/PycharmProjects/automl/ATM/demos/pollution_1.csv")
results.describe()

atm.worker.Worker#select_hyperpartition

The debug output matches the description in the paper: a hyperpartition represents one root-to-leaf path in the conditional parameter tree (CPT).

>>> pprint(hyperpartitions)
[<dt: [('criterion', 'entropy')]>,
 <dt: [('criterion', 'gini')]>,
 <knn: [('weights', 'uniform'), ('algorithm', 'ball_tree'), ('metric', 'minkowski')]>,
 <knn: [('weights', 'uniform'), ('algorithm', 'ball_tree'), ('metric', 'euclidean')]>,...]

Looking at this output more closely, we find:

>>> hyperpartitions[0].categoricals
[('criterion', 'entropy')]
>>> pprint(hyperpartitions[0].tunables)
[('max_features',
  <btb.hyper_parameter.FloatHyperParameter object at 0x7fd946ae83c8>),
 ('max_depth',
  <btb.hyper_parameter.IntHyperParameter object at 0x7fd946ae82e8>),
 ('min_samples_split',
  <btb.hyper_parameter.IntHyperParameter object at 0x7fd946ae8e80>),
 ('min_samples_leaf',
  <btb.hyper_parameter.IntHyperParameter object at 0x7fd946ae8f28>)]

The job of a hyperpartition is to carve a continuous N-dimensional subspace out of the otherwise fragmented, structured search space, so that a GP (Gaussian process) can operate inside it.
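As a concrete illustration: inside one hyperpartition the categoricals are frozen, so a candidate configuration is fully described by its tunables and becomes a point in an N-dimensional continuous space that a GP regressor can model. A minimal conceptual sketch with plain scikit-learn (this is not ATM's tuner code; ATM delegates this to BTB tuners, and the GP-based ones play exactly this role):

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor

# Within a knn hyperpartition, weights/algorithm/metric are fixed, so a
# candidate is just a vector of the remaining tunables
# (here an illustrative subset: n_neighbors, leaf_size).

# (tunable vector, cv score) pairs from previously evaluated classifiers.
X = np.array([[5, 30], [13, 38], [16, 20]], dtype=float)
y = np.array([0.81, 0.86, 0.90])

gp = GaussianProcessRegressor().fit(X, y)

# Predict scores for new candidates in the same continuous space; a real
# tuner would pick the candidate maximizing an acquisition function.
candidates = np.array([[3, 25], [11, 40], [19, 33]], dtype=float)
mean, std = gp.predict(candidates, return_std=True)
best = candidates[np.argmax(mean + 1.96 * std)]  # simple UCB-style choice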

btb.selection.uniform.Uniform#select
atm.worker.Worker#select_hyperpartition
atm.worker.Worker#run_classifier

hyperpartition = self.select_hyperpartition()

Randomly picks a hyperpartition. With the default uniform selector that is all it does, though the selector interface is framed as a multi-armed bandit (MAB) over hyperpartitions.
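A hedged sketch of what the uniform selector amounts to: each hyperpartition is an arm of the bandit, and the uniform policy ignores the reward history entirely (function and argument names below are illustrative, not ATM's internals):

import random

def uniform_select(hyperpartition_ids, scores_by_id):
    # A smarter selector (e.g. UCB1) would use scores_by_id, a mapping of
    # hyperpartition id -> list of past cv scores, to balance exploration
    # and exploitation; the uniform policy just draws an arm at random.
    return random.choice(hyperpartition_ids)

chosen = uniform_select([1, 2, 3, 4], {1: [0.80], 2: [], 3: [0.90, 0.91], 4: []})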

>>> pprint(params)
{'_scale': True,
 'algorithm': 'kd_tree',
 'leaf_size': 38,
 'metric': 'chebyshev',
 'n_neighbors': 13,
 'weights': 'uniform'}
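Note the _scale entry: it is not a KNN hyperparameter. Judging by the fitted pipeline printed further below, keys with a leading underscore are meta-parameters, and _scale=True prepends a StandardScaler step. A rough sketch of that mapping (the splitting logic here is illustrative, not ATM's actual code):

from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

params = {'_scale': True, 'algorithm': 'kd_tree', 'leaf_size': 38,
          'metric': 'chebyshev', 'n_neighbors': 13, 'weights': 'uniform'}

# Split meta-parameters (leading underscore) from estimator hyperparameters.
meta = {k: v for k, v in params.items() if k.startswith('_')}
est_params = {k: v for k, v in params.items() if not k.startswith('_')}

steps = []
if meta.get('_scale'):
    steps.append(('standard_scale', StandardScaler()))
steps.append(('knn', KNeighborsClassifier(**est_params)))

pipeline = Pipeline(steps)  # mirrors the model.pipeline shown later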

atm.database.Database#start_classifier
Persists the chosen hyperparameters as a Classifier record:

        classifier = self.Classifier(hyperpartition_id=hyperpartition_id,
                                     datarun_id=datarun_id,
                                     host=host,
                                     hyperparameter_values=hyperparameter_values,
                                     start_time=datetime.now(),
                                     status=ClassifierStatus.RUNNING)

Another stretch of rather opaque code, at atm/database.py:382. By the look of it, the database is being driven through an ORM.
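For comparison, a hedged SQLAlchemy-style sketch of what such an ORM insert looks like; the model below is a hypothetical stand-in that only mirrors the keyword arguments above, not ATM's real Classifier model in atm/database.py:

from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Classifier(Base):  # hypothetical stand-in for ATM's model
    __tablename__ = 'classifiers'
    id = Column(Integer, primary_key=True)
    hyperpartition_id = Column(Integer)
    datarun_id = Column(Integer)
    host = Column(String(100))
    hyperparameter_values = Column(String)  # serialized params
    start_time = Column(DateTime)
    status = Column(String(20))

engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

# The equivalent of start_classifier: insert a RUNNING row for this trial.
session.add(Classifier(hyperpartition_id=1, datarun_id=1, host='localhost',
                       hyperparameter_values='{"n_neighbors": 13}',
                       start_time=datetime.now(), status='RUNNING'))
session.commit()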

model, metrics = self.test_classifier(hyperpartition.method, params)
>>> model.pipeline
Pipeline(memory=None,
         steps=[('standard_scale',
                 StandardScaler(copy=True, with_mean=True, with_std=True)),
                ('knn',
                 KNeighborsClassifier(algorithm='ball_tree', leaf_size=20,
                                      metric='euclidean', metric_params=None,
                                      n_jobs=None, n_neighbors=16, p=2,
                                      weights='distance'))],
         verbose=False)
>>> metrics
{'cv': [{'accuracy': 1.0, 'cohen_kappa': 1.0, 'f1': 1.0, 'mcc': 1.0, 'roc_auc': 1.0, 'ap': 1.0}, ...
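The cv entry holds per-fold cross-validation scores. A minimal sketch of how metrics like these can be produced for such a pipeline with scikit-learn (this reproduces the idea, not ATM's exact test_classifier code; the dataset is a stand-in):

from sklearn.datasets import make_classification
from sklearn.metrics import cohen_kappa_score, make_scorer, matthews_corrcoef
from sklearn.model_selection import cross_validate
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=200, random_state=0)  # stand-in data

pipeline = Pipeline([('standard_scale', StandardScaler()),
                     ('knn', KNeighborsClassifier(n_neighbors=16,
                                                  weights='distance'))])

scoring = {
    'accuracy': 'accuracy',
    'f1': 'f1',
    'roc_auc': 'roc_auc',
    'ap': 'average_precision',
    'cohen_kappa': make_scorer(cohen_kappa_score),
    'mcc': make_scorer(matthews_corrcoef),
}

# One array of per-fold test scores per metric, much like the cv list above.
cv_results = cross_validate(pipeline, X, y, cv=5, scoring=scoring)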

That is essentially the whole loop: pick a hyperpartition, propose hyperparameters, train and cross-validate a classifier, and record the results.

Both selector and tuner default to 'uniform', i.e. plain random search:

            selector (str):
                Type of selector to use. Optional. Defaults to ``'uniform'``.
            tuner (str):
                Type of tuner to use. Optional. Defaults to ``'uniform'``.
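
So the run at the top of this note was doing pure random search on both levels. To get the bandit + GP behaviour described in the paper, different selector and tuner names have to be passed in. Assuming the docstring above belongs to ATM.run and that 'gp' and 'ucb1' are among the accepted names (they match the tuner/selector families in the ATM paper, but treat them as assumptions), the call would look roughly like:

from atm import ATM

atm = ATM()

# tuner/selector values are assumptions based on the docstring excerpt above
# and the tuner/selector families described in the ATM paper.
results = atm.run(train_path="/home/tqc/PycharmProjects/automl/ATM/demos/pollution_1.csv",
                  tuner='gp',
                  selector='ucb1')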