使用 Ray 用 15 行 Python 代码实现一个参数服务器
参数服务器是很多机器学习应用的核心部分。其核心作用是存放机器学习模型的参数(如,神经网络的权重)和提供服务将参数传给客户端(客户端通常是处理数据和计算参数更新的 workers)
参数服务器(如同数据库)是正常构建并 shipped 像一个单一系统。这个文章讲解如何使用 Ray 来用几行代码实现参数服务器。
通过将参数服务器从一个“系统”调整为一个“应用”,这个方法将量级的 orders 变得更加简单来部署一个参数服务器应用。类似地,通过让应用和库实现自身的参数服务器,这个方法让参数服务器的行为更加可配置和灵活(因为这个应用可以轻松地修改实现)
什么是 Ray? Ray 是一个用于并行和分布式的通用框架。Ray 提供了一个统一的任务并行和actor抽象,并且通过共享内存、零复制序列化和分布式调度达到了高的性能。Ray 也包含了针对人工智能应用(如超参数调优和强化学习)的高性能库。
什么是一个参数服务器?
一个参数服务器是一个用来在集群上训练机器学习模型的键值对。其值(values)是机器学习模型的参数(如一个神经网络)。其键(keys)索引了模型参数。
例如,在一个电影的推荐系统中,可能会针对每个用户、每个电影都有相应的键。对每个用户和电影,有对应的以用户特属和以电影特属的参数。在语言建模的应用中,词可能会作为键而其嵌入则可能为值。在最简单的形式中,参数服务器可能会隐式地有一个单个键,允许你所有的参数被获取并一次性更新。我们展示了如何作为一个 Ray 的 actor 实现一个参数服务器。
import numpy as np
import ray
@ray.remote
class ParameterServer(object):
def __init__(self, dim):
# params 可以是一个将键映射到数组的字典
self.params = np.zeros(dim)
def get_params(self):
return self.params
def update_params(self, grad):
self.params += grad
@ray.remote 装饰器定义了一个服务。以类 ParameterServer 为‘输入’并使之作为一个远程服务或者 actor 被实例化。
这里,我们假设更新是一个梯度,这个被加到参数的向量上。这仅仅是最简单可能例子,可以有很多不同的选择。
参数服务器一般作为远程进程或者服务存在 并通过远程过程调用来和客户端交互。为了实例化参数服务器为一个远程 actor,我们可以这样:
ray.init()