[源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC

[源码解析] PyTorch 分布式(16) — 使用异步执行实现批处理 RPC

0x00 摘要

在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,接下来我们通过几篇文章来看看如何把这些模块应用到实践之中,顺便把PyTorch分布式逻辑整体梳理一下。本文介绍如何使用异步执行操作来实现批处理 RPC,大家可以学习到PyTorch对参数服务器一个新的实现方式。

本文以IMPLEMENTING BATCH RPC PROCESSING USING ASYNCHRONOUS EXECUTIONS的翻译为基础,加入了自己的理解。

PyTorch分布式其他文章如下:

[ 源码解析] PyTorch 分布式(1)------历史和概述

[ 源码解析] PyTorch 如何使用GPU

源码解析] PyTorch 分布式(2) ----- DataParallel(上)

[ 源码解析] PyTorch 分布式(3) ----- DataParallel(下)

[ 源码解析] PyTorch 分布式(4)------分布式应用基础概念

源码解析] PyTorch 分布式(5) ------ DistributedDataParallel 总述&如何使用

[ 源码解析] PyTorch分布式(6) —DistributedDataParallel – 初始化&store

[ 源码解析] PyTorch 分布式(7) ----- DistributedDataParallel 之进程组

[源码解析] PyTorch 分布式(8) -------- DistributedDataParallel之论文篇

[ 源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

[源码解析] PyTorch 分布式(10)------DistributedDataParallel之Reducer静态架构

[ 源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作

源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播

[源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播

源码解析] PyTorch 分布式 Autograd (1) ---- 设计

源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础

源码解析] PyTorch 分布式 Autograd (3) ---- 上下文相关

[源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

[源码解析] PyTorch 分布式 Autograd (5) ---- 引擎(上)

[源码解析] PyTorch 分布式 Autograd (6) ---- 引擎(下)

[源码解析] PyTorch分布式优化器(1)----基石篇

[源码解析] PyTorch分布式优化器(2)----数据并行优化器

[源码解析] PyTorch分布式优化器(3)---- 模型并行

源码解析] PyTorch 分布式(14) --使用 Distributed Autograd 和 Distributed Optimizer

[源码解析] PyTorch 分布式(15) — 使用分布式 RPC 框架实现参数服务器

注:本文没有完全按照原文顺序进行翻译,而是按照自己理解的思路重新组织了文章。

0x01 前言

1.1 先决条件

本文的先决条件如下:

本教程演示了如何使用@rpc.functions.async_execution 装饰器构建批处理 RPC 应用程序,这有助于通过减少被阻塞的 RPC 线程的数量,并且在被调用方整合 CUDA 操作来加快训练速度。这与使用 TorchServer 进行批量推理的想法相同。Batch RPC 有助于将动作整合到较少的 CUDA 操作中,从而摊销开销。

注意:本教程需要 PyTorch v1.6.0 或更高版本。

1.2 基础知识

之前的教程已经展示了使用torch.distributed.rpc构建分布式训练应用程序的步骤,但他们没有详细说明在处理 RPC 请求时被调用方会发生什么。从 PyTorch v1.5 开始,针对每个 RPC 请求,被调用者都会启动一个线程来执行该请求中的函数,该线程会阻塞直到该函数返回。这适用于许多用例,但有一个问题:如果用户函数在 IO 上阻塞,例如使用嵌套的 RPC 调用或信号(例如等待不同的 RPC 请求来解除阻塞),则被调用者上的 RPC 线程将不得不空闲等待,直到 IO 完成或信号(signal)事件发生。因此,RPC 被调用者使用的线程可能会使用比实际需要更多。造成这个问题的原因是RPC把用户函数当成黑盒,对函数中发生的事情知之甚少。为了让用户函数能够让出和释放 RPC 线程,需要向 RPC 系统提供更多的提示。

从 v1.6.0 开始,PyTorch 通过引入两个新概念来解决这个问题:

  • torch.futures.Future 封装了一个异步执行,同时也支持安装回调函数。
  • @rpc.functions.async_execution 装饰器,它允许应用程序告诉被调用者,本目标函数将返回一个future,并且可以在执行过程中多次暂停和yield。

使用这两个工具,应用程序代码可以将用户函数分解为多个较小的函数,将它们链接在一起作为Future 对象的回调方法,并返回包含最终结果的 Future给调用者。在被调用方,在获取Future对象时,它也会安装后续的 RPC 响应处理作为回调方法,这些回调会在最终结果准备好时被触发。这样,被调用者不再需要阻塞一个线程,只是等待最终返回值准备好就行。 简单的例子请参考@rpc.functions.async_execution的API文档 。

除了减少被调用者的空闲线程数量外,这些工具还使批处理 RPC 处理更容易、更快。本教程演示了如何使用@rpc.functions.async_execution 装饰器构建分布式批量更新参数服务器和批量处理强化学习应用程序 。

注:我们不考虑强化学习的领域,那样会影响我们的思路,牵扯精力

1.3 代码

因为原文主要是强化学习代码讲解,而我们只关注普通分布式批量更新参数服务器,所以需要看原始代码。

代码位于 https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py。先全部摘录如下:

import os
import threading
from datetime import datetime

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch import optim

import torchvision


batch_size = 20
image_w = 64
image_h = 64
num_classes = 30
batch_update_size = 5
num_batches = 6

def timed_log(text):
    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")

class BatchUpdateParameterServer(object):

    def __init__(self, batch_update_size=batch_update_size):
        self.model = torchvision.models.resnet50(num_classes=num_classes)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution
    def update_and_fetch_model(ps_rref, grads):
        self = ps_rref.local_value()
        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")
        for p, g in zip(self.model.parameters(), grads):
            p.grad += g
        with self.lock:
            self.curr_update_size += 1
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step()
                self.optimizer.zero_grad()
                fut.set_result(self.model)
                timed_log("PS updated model")
                self.future_model = torch.futures.Future()

        return fut


class Trainer(object):

    def __init__(self, ps_rref):
        self.ps_rref = ps_rref
        self.loss_fn = nn.MSELoss()
        self.one_hot_indices = torch.LongTensor(batch_size) \
                                    .random_(0, num_classes) \
                                    .view(batch_size, 1)

    def get_next_batch(self):
        for _ in range(num_batches):
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                        .scatter_(1, self.one_hot_indices, 1)
            yield inputs.cuda(), labels.cuda()

    def train(self):
        name = rpc.get_worker_info().name
        m = self.ps_rref.rpc_sync().get_model().cuda()
        for inputs, labels in self.get_next_batch():
            timed_log(f"{name} processing one batch")
            self.loss_fn(m(inputs), labels).backward()
            timed_log(f"{name} reporting grads")
            m = rpc.rpc_sync(
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            ).cuda()
            timed_log(f"{name} got updated model")


def run_trainer(ps_rref):
    trainer = Trainer(ps_rref)
    trainer.train()


def run_ps(trainers):
    timed_log("Start training")
    ps_rref = rpc.RRef(BatchUpdateParameterServer())
    futs = []
    for trainer in trainers:
        futs.append(
            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,))
        )

    torch.futures.wait_all(futs)
    timed_log("Finish training")


def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options=rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        rpc_timeout=0  # infinite timeout
     )
    if rank != 0:
        rpc.init_rpc(
            f"trainer{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # trainer passively waiting for ps to kick off training iterations
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_ps([f"trainer{r}" for r in range(1, world_size)])

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = batch_update_size + 1
    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)

0x02 启动

我们首先看看如何启动。

2.1 总体启动

我们假设有一个master(rank 0),一个worker。Master 之上运行的是参数服务器,worker 之上是训练代码。

def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options=rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        rpc_timeout=0  # infinite timeout
     )
    if rank != 0:
        rpc.init_rpc( # 训练代码
            f"trainer{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # trainer passively waiting for ps to kick off training iterations
    else:
        rpc.init_rpc( # 参数服务器
            "ps", 
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_ps([f"trainer{r}" for r in range(1, world_size)])

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = batch_update_size + 1
    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)

逻辑如下图:

             torch.multiprocessing.spawn
                        +
                        |
                        |
           +------------+-------------------------------------------------
           |                                                             |
           |                                                             |
           v                                                             v
+----------+----------------------------------------------+ +------------+----------------+
| "ps"                                           rank = 0 | | f"trainer{rank}"   rank = 1 |
|                                                         | |                             |
|                                                         | |                             |
|                     rpc.init_rpc                        | |         rpc.init_rpc        |
|                                                         | |                             |
|                                                         | |                             |
|  run_ps([f"trainer{r}" for r in range(1, world_size)])  | |                             |
|                                                         | |                             |
|                                                         | |                             |
+---------------------------------------------------------+ +-----------------------------+

2.2 启动参数服务器

run_ps 启动了参数服务器和trainer。注意,这里在参数服务器之中启动 trainer,即,master 不仅仅有一个参数服务器,还负责通过 rpc 来驱动trainer上的训练循环。

def run_ps(trainers):
    timed_log("Start training")
    ps_rref = rpc.RRef(BatchUpdateParameterServer())
    futs = []
    for trainer in trainers: # trainer 是字符串,比如"trainer1"
        futs.append(
            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) # 运行run_trainer
        )

    torch.futures.wait_all(futs)
    timed_log("Finish training")
    
def run_trainer(ps_rref):
    trainer = Trainer(ps_rref)
    trainer.train() # 调用 Trainer 的方法   

具体拓展如下:

这里没有给出参数服务器和trainer的逻辑,我们会在后续分析之后陆续给出。trainer 也只给出了一个。

             torch.multiprocessing.spawn
                        +
                        |
                        |
           +------------+------------------------------------------------+
           |                                                             |
           |                                                             |
           v                                                             v
+----------+----------------------------------------------+ +------------+----------------+
| "ps"                                           rank = 0 | | f"trainer{rank}"   rank = 1 |
|                                                         | |                             |
|                                                         | |                             |
|                     rpc.init_rpc                        | |         rpc.init_rpc        |
|                                                         | |                             |
|                                                         | |        +-----------------+  |
|  run_ps([f"trainer{r}" for r in range(1, world_size)])  | |        |   Trainer       |  |
|                         +                               | |        |                 |  |
|                         |                               | |    +-------->  train()   |  |
|                         |                               | |    |   |                 |  |
|                         v                               | |    |   +-----------------+  |
|   +---------------------+---------------------------+   | |    |                        |
|   | run_ps                                          |   | +-----------------------------+
|   |                                                 |   |      |
|   |                                                 |   |      |
|   | ps_rref = rpc.RRef(BatchUpdateParameterServer())|   |      |
|   | for trainer in trainers:                        |   |      |
|   |     futs.append(                                |   |      |
|   |         rpc.rpc_async(trainer, run_trainer,+---------------+
|   |                       args=(ps_rref,))          |   |
|   |     )                                           |   |
|   +-------------------------------------------------+   |
+---------------------------------------------------------+

手机上如下:

0x03 参数服务器

上面图中没有给出具体参数服务器代码,我们接下来就分析一下。

这里考虑具有一个参数服务器 (PS) 和多个trainer的同步训练应用程序。在这个应用中,PS 持有参数并等待所有训练器报告梯度。在每次迭代中,它等待直到从所有训练器接收梯度,然后一次性更新所有参数。

下面的代码显示了 PS 类的实现。

  • PS初始化时候生成了常规SGB优化器,不是分布式优化器,而且优化器是在PS之上
  • update_and_fetch_model方法被 @rpc.functions.async_execution所装饰,将由trainer调用。
  • 每次调用都会返回一个Future对象,该对象将被用来处理更新后的模型。
  • 大多数训练器发起的调用只是累积梯度到 .grad成员变量 ,然后立即返回,并在 PS 上产生 RPC 线程。
  • 最后到达的训练器将触发优化器步骤并消耗所有先前上报的梯度。然后它使用更新后的模型来设置future_model,这是依靠通过Future对象来依次通知来自其他训练者的先前请求,并将更新后的模型发送给所有训练者。

具体代码如下:

batch_size = 20
image_w = 64
image_h = 64
num_classes = 30
batch_update_size = 5
num_batches = 6

def timed_log(text):
    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")

class BatchUpdateParameterServer(object):

    def __init__(self, batch_update_size=batch_update_size):
        self.model = torchvision.models.resnet50(num_classes=num_classes)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        # 重点:这里是常规SGB优化器,不是分布式优化器
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution # trainer会直接调用
    def update_and_fetch_model(ps_rref, grads):
        self = ps_rref.local_value()
        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")
        for p, g in zip(self.model.parameters(), grads): # 得到
            p.grad += g # 累积梯度
        with self.lock:
            self.curr_update_size += 1
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                # 最后到达的训练器将触发优化器步骤并消耗所有先前上报的梯度。
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step() # 更新模型
                self.optimizer.zero_grad()
                fut.set_result(self.model) # 将更新后的模型发送给所有训练者
                timed_log("PS updated model")
                self.future_model = torch.futures.Future() # 使用更新后的模型来设置future_model

        return fut # 该对象将被用来处理更新后的模型

逻辑拓展如下,这里省略了参数服务器生成trainer的步骤:

                            torch.multiprocessing.spawn
                                       +
                                       |
                                       |
                          +------------+--------------------------------------------------------------------------------+
                          |                                                                                             |
                          |                                                                                             |
                          v                                                                                             v
+-------------------------+-----------------------------------------------------------------------------+  +------------+----------------+
|  "ps"                                                                                        rank = 0 |  | f"trainer{rank}"   rank = 1 |
|                                                        +-------------------------------------------+  |  |                             |
|                                                        | BatchUpdateParameterServer                |  |  |                             |
|  rpc.init_rpc                                          |                                           |  |  |         rpc.init_rpc        |
|                                                        |                                           |  |  |                             |
|  run_ps([f"trainer{r}" for r in range(1, world_size)]) |                                           |  |  |  +-----------------------+  |
|                        +                               | model = resnet50(num_classes)             |  |  |  | Trainer               |  |
|                        |                               |                                           |  |  |  |                       |  |
|                        |                               | future_model = Future()                   |  |  |  |                       |  |
|                        v                               |                                           |  |  |  |  +----> train()       |  |
|  +---------------------+---------------------------+   | optimizer = optim.SGD(model.parameters()) |  |  |  |  |                    |  |
|  | run_ps                                          |   |                                           |  |  |  |  |                    |  |
|  |                                                 |   +-------------------------------------------+  |  |  +-----------------------+  |
|  |                                                 |                                                  |  |     |                       |
|  | ps_rref = rpc.RRef(BatchUpdateParameterServer())|                                                  |  |     |                       |
|  | for trainer in trainers:                        |                                                  |  +-----------------------------+
|  |     futs.append(                                |                                                  |        |
|  |         rpc.rpc_async(trainer, run_trainer,+----------------------------------------------------------------+
|  |                       args=(ps_rref,))          |                                                  |
|  |     )                                           |                                                  |
|  +-------------------------------------------------+                                                  |
|                                                                                                       |
+-------------------------------------------------------------------------------------------------------+

手机如下:

0x04 Trainer

对于训练器,它们都使用来自 PS 的相同参数集进行初始化。在每次迭代中执行如下操作:

  • 每个训练器首先运行前向和后向传播以在本地生成梯度。
  • 然后,每个训练器使用 RPC 向 PS 报告其梯度,并通过同一 RPC 请求的返回值取回更新后的参数。

在训练器的实现中,目标函数是否被标记 @rpc.functions.async_execution是没有区别的。训练器只需使用 rpc_sync 调用update_and_fetch_model,其将阻塞训练器,直到返回更新的模型。

可以看到,参数服务器存储模型,模型可以返回到trainer。

class Trainer(object):

    def __init__(self, ps_rref):
        self.ps_rref = ps_rref
        self.loss_fn = nn.MSELoss()
        self.one_hot_indices = torch.LongTensor(batch_size) \
                                    .random_(0, num_classes) \
                                    .view(batch_size, 1)

    def get_next_batch(self):
        for _ in range(num_batches):
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                        .scatter_(1, self.one_hot_indices, 1)
            yield inputs.cuda(), labels.cuda()

    def train(self):
        name = rpc.get_worker_info().name
        # 从参数服务器获取模型
        m = self.ps_rref.rpc_sync().get_model().cuda()
        for inputs, labels in self.get_next_batch():
            timed_log(f"{name} processing one batch")
            # 利用模型来前向传播/反向传播
            self.loss_fn(m(inputs), labels).backward()
            timed_log(f"{name} reporting grads")
            # 调用参数服务器的函数来提交梯度
            m = rpc.rpc_sync( # rpc_sync 操作完成之后,m就是最新模型了
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            ).cuda()
            timed_log(f"{name} got updated model")

拓展逻辑如下:

  1. 参数服务器的run_trainer 方法会直接调用 trainer.train() 方法来执行一步step。
  2. train 方法之中,会调用 self.ps_rref.rpc_sync().get_model().cuda() 从参数服务器获得模型,放到本地设备之上(图上是双向箭头,表示这是一个get/return动作,需要把模型存储在worker本地)。
  3. 调用 self.loss_fn(m(inputs), labels).backward() 来进行前向传播/反向传播。
  4. 调用参数服务器的 update_and_fetch_model 函数来提交梯度,这里使用了异步RPC
  5. 参数服务器的 update_and_fetch_model 之中,进行梯度累积,模型更新是通过PS之上常规SGD优化器完成,最后调用 fut.set_result(self.model) 来发布新模型给trainer。在trainer 之中,就是 m = rpc.rpc_sync(…) 这个赋值之后,m 是最新模型了。
                            torch.multiprocessing.spawn
                                       +
                                       |
                                       |
                          +------------+----------------------------------------------------------------------------------------+
                          |                                                                                                     |
                          |                                                                                                     |
                          v                                                                                                     v
+-------------------------+--------------------------------------------------------------------------------+     +--------------+-------------------------------------+
|  "ps"                                                                                        rank = 0    |     | f"trainer{rank}"                         rank = 1  |
|                                                                                                          |     |                                                    |
|                                                                                                          |     | rpc.init_rpc                                       |
|  rpc.init_rpc                                          +----------------------------------------------+  |     |                                                    |
|                                                        | BatchUpdateParameterServer                   |  |     | +------------------------------------------------+ |
|  run_ps([f"trainer{r}" for r in range(1, world_size)]) |                                              |  |     | | Trainer                                        | |
|                        +                               |    model = resnet50(num_classes)             |  |     | |                                                | |
|                        |                               |                                              |  |     | | +-------------------------------------------+  | |
|                        |                               |    future_model = Future()                   |  |     | | | train                                     |  | |
|                        v                               |                                              |  |     | | |                                           |  | |
|  +---------------------+---------------------------+   |    optimizer = optim.SGD(model.parameters()) |  |     | | |                                           |  | |
|  | run_ps                                          |   |                                              |  |  2  | | |     m = ps_rref.rpc_sync()                |  | |
|  |                                                 |   |    def get_model(self):   <------------------------------------>           .get_model()               |  | |
|  |                                                 |   |        return self.model                     |  |     | | |                .cuda()                    |  | |
|  | ps_rref = rpc.RRef(BatchUpdateParameterServer())|   |                                              |  |     | | |                                           |  | |
|  | for trainer in trainers:                        |   |                                              |  |     | | | 3   loss_fn(m(inputs), labels).backward() |  | |
|  |     futs.append(                                |   |                                              |  |  4  | | |                                           |  | |
|  |         rpc.rpc_async(trainer, run_trainer,     |   |    update_and_fe+ch_model <-----------------------------------> BatchUpdateParameterServer            |  | |
|  |                       args=(ps_rref,))  +       |   |                 |                            |  |     | | |           .update_and_fetch_model()       |  | |
|  |     )                                   |       |   +----------------------------------------------+  |     | | |                                           |  | |
|  |                                         |       |                     |                               |     | | |                                           |  | |
|  +-------------------------------------------------+                     +---------------------------------------------> m = rpc.rpc_sync(...).cuda()          |  | |
|                                            |                                                             |  5  | | |                                           |  | |
|                                            |                                                             |     | | |                                           |  | |
+----------------------------------------------------------------------------------------------------------+     | | +----------------+--------------------------+  | |
                                             |                                                                   | |                  |                             | |
                                             |                                                                   | +------------------------------------------------+ |
                                             |                                                                   +----------------------------------------------------+
                                             |                              1                                                         |
                                             +----------------------------------------------------------------------------------------+


手机如下:

0x05 对比

前文结尾,我们对比参数服务器的经典实现 ps-lite 和 前两篇实现的参数服务器。

  • ps-lite 是类似传统服务器实现,有自己主动的业务循环,可以响应用户的显式请求,也有自己明确的逻辑,本地也有自己的KV存储。
  • PyTorch 前两篇官方文档(本系列前两篇文章)之中,参数服务器则是另外一种思路:
    • 参数服务器上没有主动的循环,没有KV存储,没有服务器逻辑,而是可以直接存储业务模型,ps 会把业务模型需要优化的参数返回给trainer 之上的 DistributedOptimizer。
    • 业务驱动由trainer完成:train loop代码在trainer 之中,DistributedOptimizer 在trainer 之中,DistributedOptimizer 负责进行分布式优化。
  • 本文又与上面不同,看起来更像是ps-lite,但是又糅合了RPC实现:
    • ps进程会启动trainer的训练循环
    • 每个迭代之中,trainer 会从参数服务器获取最新模型,前向操作/后向传播都在trainer 完成。
    • trainer 会通过异步RPC把梯度提交给参数服务器。
    • 模型更新是通过**PS之上常规SGD优化器完成**。
    • 模型更新之后通过异步RPC把模型再次分发给trainer。

不得不说,官方这几篇文章快把各种实现方式玩出花来了,大家可以依据自己业务特点来参考实现。

0xEE 个人信息

★★★★★★关于生活和技术的思考★★★★★★

微信公众账号:罗西的思考

如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。

在这里插入图片描述

0xFF 参考

IMPLEMENTING BATCH RPC PROCESSING USING ASYNCHRONOUS EXECUTIONS

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值