Model Search - 大规模模型架构自动搜索框架

本文翻译整理自:https://github.com/google/model_search


header


一、关于 Model Search

Model Search (MS) 是一个实现 AutoML 算法的大规模模型架构搜索框架。它旨在帮助研究人员加速探索过程,为分类问题(即包含不同类型层的深度神经网络)找到合适的模型架构。

该库支持以下功能:

  • 在数据上开箱即用运行多种 AutoML 算法 - 包括自动搜索合适的模型架构、最佳模型集成方案以及最优蒸馏模型
  • 对比搜索过程中发现的不同模型
  • 创建自定义搜索空间以定制神经网络中的层类型

框架技术细节详见 InterSpeech 论文

虽然该框架理论上可用于回归问题,但当前版本仅支持分类问题。下面我们将通过经典分类问题来展示框架如何自动发现具有竞争力的模型架构。


相关链接资源


关键功能特性

  • 支持多种 AutoML 算法自动搜索
  • 提供模型比较功能
  • 允许自定义神经网络层类型
  • 支持分布式并行搜索
  • 提供图像和非结构化数据处理能力

二、快速开始


1、基础 CSV 数据示例

假设您有一个包含数值特征的 CSV 文件,希望自动寻找最佳模型架构:

import model_search
from model_search import constants
from model_search import single_trainer
from model_search.data import csv_data

trainer = single_trainer.SingleTrainer(
    data=csv_data.Provider(
        label_index=0,
        logits_dimension=2,
        record_defaults=[0, 0, 0, 0],
        filename="model_search/data/testdata/csv_random_data.csv"),
    spec=constants.DEFAULT_DNN)

trainer.try_models(
    number_models=200,
    train_steps=1000,
    eval_steps=100,
    root_dir="/tmp/run_example",
    batch_size=32,
    experiment_name="example",
    experiment_owner="model_search_user")

上述代码将尝试 200 种不同的二分类模型(logits_dimension设为 2)。根目录下会生成所有模型的子目录,均已包含评估结果。您可以使用 TensorBoard 打开目录查看所有模型及其评估指标。

搜索将按照默认规范执行,该规范位于:
model_search/configs/dnn_config.pbtxt

如需了解字段详情或创建自定义规范,请参考:
model_search/proto/phoenix_spec.proto


2、图像数据示例

以下是图像二分类的示例代码:

import model_search
from model_search import constants
from model_search import single_trainer
from model_search.data import image_data

trainer = single_trainer.SingleTrainer(
    data=image_data.Provider(
        input_dir="model_search/data/testdata/images"
        image_height=100,
        image_width=100,
        eval_fraction=0.2),
    spec=constants.DEFAULT_CNN)

trainer.try_models(
    number_models=200,
    train_steps=1000,
    eval_steps=100,
    root_dir="/tmp/run_example",
    batch_size=32,
    experiment_name="example",
    experiment_owner="model_search_user")

此 API 遵循与 tf.keras.preprocessing.image_dataset_from_directory 相同的输入字段。

搜索将按照默认 CNN 规范执行,该规范位于:
model_search/configs/cnn_config.pbtxt


三、非 CSV/图像数据处理

要处理非 CSV 数据,您需要实现继承自 model_search.data.Provider 抽象类的子类。这允许自定义特征列和分类任务(即分类类别数)。

class Provider(object, metaclass=abc.ABCMeta):
  """数据提供者接口
  
  定义三个与 Estimator 相关训练的函数:
    * 返回特征和标签批次张量的训练/测试输入函数
    * 数据集的特征列
    * 问题描述
  """

  def get_input_fn(self, hparams, mode, batch_size: int):
    """返回训练和评估的输入函数"""

  def get_serving_input_fn(self, hparams):
    """返回用于导出 SavedModel 的服务输入函数"""

  @abc.abstractmethod
  def number_of_classes(self) -> int:
    """返回类别数(回归问题为 logits 维度)"""

  def get_feature_columns(
      self
  ) -> List[Union[feature_column._FeatureColumn,                   feature_column_v2.FeatureColumn]]:
    """返回特征列列表"""

实现示例可参考 model_search/data/csv_data.py。完成类实现后,可将其传递给 model_search.single_trainer.SingleTrainer 来读取数据。


四、自定义模型与架构

1、添加自定义块

系统通过 Block 概念进行搜索。您需要实现两个核心方法:

class Block(object, metaclass=abc.ABCMeta):
  """块 API 接口"""

  @abc.abstractmethod
  def build(self, input_tensors, is_training, lengths=None):
    """构建神经网络块"""

  @abc.abstractproperty
  def is_input_order_important(self):
    """判断输入张量中元素的顺序是否重要"""

2、注册自定义块

实现块后,需使用装饰器注册:

@register_block(
    lookup_name='AVERAGE_POOL_2X2', init_args={'kernel_size': 2}, enum_id=8)
@register_block(
    lookup_name='AVERAGE_POOL_4X4', init_args={'kernel_size': 4}, enum_id=9)
class AveragePoolBlock(Block):
  """平均池化层"""

注册后,可通过修改 PhoenixSpec 中的 blocks_to_use 字段(位于 model_search/proto/phoenix_spec.proto)来扩展搜索空间。

注:系统默认将块堆叠形成塔式架构进行集成。您可将配置中的最小/最大深度设为 1,使系统直接搜索最佳单块结构。


五、快速创建训练二进制文件

1、注册数据提供者

@data.register_provider(lookup_name='csv_data_provider', init_args={})
class Provider(data.Provider):
  """A csv data provider."""

  def __init__(self):

2、创建构建规则

model_search_oss_binary(
    name = "csv_data_binary",
    dataset_dep = ":csv_data_for_binary",
)

3、添加集成测试

model_search_oss_test(
    name = "csv_data_for_binary_test",
    dataset_dep = ":csv_data_for_binary",
    problem_type = "dnn",
    extra_args = [
        "--filename=$${TEST_SRCDIR}/model_search/data/testdata/csv_random_data.csv",
    ],
    test_data = [
        "//model_search/data/testdata:csv_random_data",
    ],
)

六、分布式运行

系统支持分布式并行搜索。要实现此功能:

1、在多台机器上运行二进制文件
2、修改 model_search/metadata/ml_metadata_db.py 中的标志指向您的数据库

配置完成后,前文创建的二进制文件将连接到此数据库开始异步搜索。


七、Cloud AutoML 服务

如需更高性能且无需编写代码的 AutoML 解决方案,请尝试:
https://cloud.google.com/automl-tables


伊织 xAI 2025-04-27(日)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

编程乐园

请我喝杯伯爵奶茶~!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值