史上最全面的ps-lite理解

概述

ps-lite是一个分布式参数服务器,具体什么是分布式,什么是参数服务器就不在此详述,talk is cheap, show me code。

代码

既然是分布式,那么我们就来看看整个框架有哪几部分。
在这里插入图片描述
可以看到有worker, server, scheduler.在这里我们假设有一个scheduler, 2个server,2个worker,既然是分布式,那么就假设分布在5台电脑上,在每台电脑上肯定要起一个进程,好了,我们首先先来启动一个scheduler。

scheduler node
全局观
  • 先启动一个节点
  • 等待各个work和sever发来的message
  • 一旦收集到4个message,就说明work和server都到齐了,这时给每一个work和server发一个message告诉他们对应的身份id,同时也让work去链接server,server去链接work
  • scheduler的初始化完成
具体代码

环境配置

export DMLC_NUM_SERVER=2
export DMLC_NUM_WORKER=2
export DMLC_PS_ROOT_URI='127.0.0.1'
export DMLC_PS_ROOT_PORT=8000 
export DMLC_ROLE='scheduler'

具体代码

一层代码
Start(0);//Start(int customer_id, const char* argv0 = nullptr) {Postoffice::Get()->Start(customer_id, argv0, true);
Finalize(0, true);//Finalize(int customer_id, const bool do_barrier = true) {Postoffice::Get()->Finalize(customer_id, do_barrier);
二层代码

首先出来了一个Postoffice的类型,每一个节点都有且只有一个Postoffice对象,具体如下:
在这里插入图片描述
用到什么成员变量和函数再说,不做一一介绍,首先看第一部分, 这里我把线程锁和一些无关紧要的逻辑到删除了,可以看到这个函数主要做了这几件事,首先是初始化node_ids_这个成员变量,接下的事情就是初始化变量van_, 最后执行了一个Barrier函数。

void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) 
{
//读取环境变量
  InitEnvironment();//这一行核心内容就是这个:Van::Create(van_type="zmp"),此外还初始化一点成员变量
// init node_ids_
  直接看下面的图
// start van
  van_->Start(customer_id);  
// do a barrier here
  if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
}
三层代码

下面首先来看看node_ids_的初始化流程,代码逻辑非常简单,可以自己去查看,这里我们假设有两个work ,两个server, 具体的结果如下:
在这里插入图片描述
接下来看看van_这个成员变量的初始化, 我们知道van_其实是一个zmp对象,zmp继承于Van这个类,在这个类的基础上加了两个成员变量,分别是:unordered_map<int, void*> senders_ 和变量void *receiver_ = nullptr, senders_是一个集合,就是发送的消息的结合,比如8号节点要给9号节点发消息,那么只要找到(9,socket_9)这个组合就行了,然后调用socket_9.send(message), receiver_就只有一个,因为你节点对外肯定只有一个门户。 这个等用到的时候再说,由于大体上改变的不多,所以也可以对照父类的机构看看,具体如下:
在这里插入图片描述
具体开看看代码:

void Van::Start(int customer_id) 
{
	//初始化scheduler_这个成员变量
	scheduler_.hostname = "DMLC_PS_ROOT_URI";
	scheduler_.port ="DMLC_PS_ROOT_PORT";
	scheduler_.role = Node::SCHEDULER;
	scheduler_.id = kScheduler;
	//确认本节点是scheduler节点
	is_scheduler_ = true;
	//初始化本节点,因为是scheduler,所以直接就是等于赋值就行
	my_node_ = scheduler_;
	//绑定接口,把本节点绑定到ip:port这个socket上,理论来说这个函数就是初始化了receiver_
	Bind(my_node_,  0)
	//连接上scheduler_,由于本节点就是scheduler_,其实就是初始化senders_,由于发送的节点很多,所以这里是一个map<int,void*>
	// 在这里就是senders_[1] = socket_1, socket_1中的body设置一点字符“ps1***”, 注意链接不是sendMsg,这一点一定要闹清楚
	Connect(scheduler_);
	//开启一个接收消息的线程,其实这里就是一直待阻塞了,等到所有的work和server都发发来了消息
	receiver_thread_ =new thread(&Van::Receiving, this);
	//然后就是等着ready_啥时候从false变成true,当是scheduler的时候,必须要有等worker和server节点过来,不然一直都是阻塞在这。
	while (!ready_.load()) {this_thread::sleep_for(std::chrono::milliseconds(100));
	// 如果设置了超时重传,就初始化resender_这个变量
	resender_ = new Resender(timeout, 10, this);
  }
四层代码

接下来再往里面深入看看代码,上可以看到主要就是Bind函数,Connect函数,以及Receiving函数

//这个函数对schedule节点的话,你不需要指定port ,但是对于work和server需要自己查找一个本地可用端口。
int Bind(const Node& node, int max_retry) override
 {
     //在这里可以看到receiver_这个变量被初始化了,
    //是一个socket,下面绑定了具体的ip:port,每次RecvMsg(Message* msg)时候里面都要从这个socket读取。
    receiver_ = zmq_socket(context_, ZMQ_ROUTER);
    string hostname = node.hostname;
    string addr =  "tcp://" + hostname + ":";
    int port = node.port;
    unsigned seed = static_cast<unsigned>(time(NULL) + port);
    for (int i = 0; i < max_retry + 1; ++i) 
    {
      auto address = addr + std::to_string(port);
      if (zmq_bind(receiver_, address.c_str()) == 0) break;
      if (i == max_retry) 
      {
        port = -1;
      } 
      else 
      {
        port = 10000 + rand_r(&seed) % 40000;
      }
    }
    return port;
  }

connect开始初始化senders_,或者在后面的时候就补充

void Connect(const Node& node) override 
{
    int id = node.id;
    auto it = senders_.find(id);
    if (it != senders_.end()) {zmq_close(it->second);}//如果找到了对应socket就关闭socket
    // worker doesn't need to connect to the other workers. same for server
    if ((node.role == my_node_.role) && (node.id != my_node_.id)) {return;}
    void *sender = zmq_socket(context_, ZMQ_DEALER);//建立一个socket
    //我们知道对于scheduler而言,一开始就是知道自己的id,为1,下面这一个if条件就是说把自己的id捆绑到当下socket上
    if (my_node_.id != Node::kEmpty) 
    {
      std::string my_id = "ps" + std::to_string(my_node_.id);
      zmq_setsockopt(sender, ZMQ_IDENTITY, my_id.data(), my_id.size());
      const char* watermark = Environment::Get()->find("DMLC_PS_WATER_MARK");
      if (watermark) {
        const int hwm = atoi(watermark);
        zmq_setsockopt(sender, ZMQ_SNDHWM, &hwm, sizeof(hwm));
      }
    }
    // connect
    string addr = "tcp://" + node.hostname + ":" + to_string(node.port);
    zmq_connect(sender, addr.c_str());//将sender这个socket和目标地址连接
    senders_[id] = sender;//将目标id的socket存放起来后面使用
  }

最后看看receiving这个函数,其实在这里就开始等待work和server节点的接入了,假如现在开始有一个work发来消息,消息是控制信息,具体指令是ADD_NODE.

void Van::Receiving() {
  Meta nodes;
  Meta recovery_nodes;  // store recovery nodes 储存康复的节点
  recovery_nodes.control.cmd = Control::ADD_NODE;// 康复节点的control都设置为add_node
  while (true) 
  {
    Message msg;
    int recv_bytes = RecvMsg(&msg);//利用receiver_这个变量拿到消息
    recv_bytes_ += recv_bytes;//收到的中字节数累加
    // duplicated message
    if (resender_ && resender_->AddIncomming(msg)) continue;//重传机制先不看

    if (!msg.meta.control.empty()) //如果是控制类型的消息
    {
      // control msg
      auto& ctrl = msg.meta.control;
      if (ctrl.cmd == Control::TERMINATE) {
        ProcessTerminateCommand();break;
      } else if (ctrl.cmd == Control::ADD_NODE) {
        ProcessAddNodeCommand(&msg, &nodes, &recovery_nodes);//当执行到这个位置的时候继续跳转
      } else if (ctrl.cmd == Control::BARRIER) {
        ProcessBarrierCommand(&msg);
      } else if (ctrl.cmd == Control::HEARTBEAT) {
        ProcessHearbeat(&msg);
      } 
    } 
    else //非控制类型的消息处理方式
    {
      ProcessDataMsg(&msg);
    }
  }
}

接下来看看scheduler对于控制类型消息的处理:

void Van::ProcessAddNodeCommandAtScheduler(Message* msg, Meta* nodes, Meta* recovery_nodes)
{
  auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);//查出心跳包超时的id
  unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());//又给转存到dead_set里面
  auto& ctrl = msg->meta.control;//拿到收到消息里面的control信息
  //下面这个函数就比较骚,名字叫做更新节点ID,记住当下是在schedule节点,我们先下去看看这个函数。
  UpdateLocalID(msg, &dead_set, nodes, recovery_nodes);
  //上面的函数代码看完后继续往下走
  recovery_nodes->control.cmd = Control::ADD_NODE;//不知道为啥又写一边
  time_t t = time(NULL);
  size_t num_nodes = Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers();
  //根据上面updatelocalId的函数,我们知道当下nodes还是没有收集齐全,一旦收齐后进入if条件中
  if (nodes->control.node.size() == num_nodes) {
    // sort the nodes according their ip and port,这个排序就是不说了,就是根据IP和port给work,server排个序
    std::sort(nodes->control.node.begin(), nodes->control.node.end(), [](const Node& a, const Node& b) {
    return (a.hostname.compare(b.hostname) | (a.port < b.port)) > 0; });
    // assign node rank
    for (auto& node : nodes->control.node) 
    {
	      string node_host_ip = node.hostname + ":" + to_string(node.port);
	      if (connected_nodes_.find(node_host_ip) == connected_nodes_.end()) //如果ip:port不存在van_中的话
	      {
	        CHECK_EQ(node.id, Node::kEmpty);//判断是不是初始化节点
	        int id = node.role == Node::SERVER
	                     ? Postoffice::ServerRankToID(num_servers_)//如果是sever的话,就id产生一个id号,num_servers_初始化为0
	                     : Postoffice::WorkerRankToID(num_workers_);
	        PS_VLOG(1) << "assign rank=" << id << " to node " << node.DebugString();
	        node.id = id;//将这个节点的id赋值为id
	        Connect(node);//链接这个节点, 其实就是建立一个socket, 然后senders_[id] = sender;//将目标id的socket存放起来后面使用
	        Postoffice::Get()->UpdateHeartbeat(node.id, t);//更新心跳包
	        connected_nodes_[node_host_ip] = id;//你work发message来了,我这里要把这个节点作为已经链接的节点
	      } 
	      else 
	      {
	        int id = node.role == Node::SERVER
	                     ? Postoffice::ServerRankToID(num_servers_)
	                     : Postoffice::WorkerRankToID(num_workers_);
	        shared_node_mapping_[id] = connected_nodes_[node_host_ip];
	        node.id = connected_nodes_[node_host_ip];
	      }
	      if (node.role == Node::SERVER) num_servers_++;//更新rank
	      if (node.role == Node::WORKER) num_workers_++;
    }
    nodes->control.node.push_back(my_node_);//要把本节点放到里面
    nodes->control.cmd = Control::ADD_NODE;
    Message back;
    back.meta = *nodes;//消息包装nodes,广播到每一个work,server
    for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) 
    {
      int recver_id = r;
      if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) 
      {
        back.meta.recver = recver_id;
        back.meta.timestamp = timestamp_++;
        Send(back);
      }
    }
    PS_VLOG(1) << "the scheduler is connected to " << num_workers_
               << " workers and " << num_servers_ << " servers";
    ready_ = true;//到这里可以看到scheduler已经显示准备好了,至于其他work和server收没收到啥的,我不管了。
  } else if (!recovery_nodes->control.node.empty()) {
    auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);
    std::unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());
    // send back the recovery node
    CHECK_EQ(recovery_nodes->control.node.size(), 1);
    Connect(recovery_nodes->control.node[0]);
    Postoffice::Get()->UpdateHeartbeat(recovery_nodes->control.node[0].id, t);
    Message back;
    for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) 
    {
      if (r != recovery_nodes->control.node[0].id && dead_set.find(r) != dead_set.end()) {
        // do not try to send anything to dead node
        continue;
      }
      // only send recovery_node to nodes already exist
      // but send all nodes to the recovery_node
      back.meta = (r == recovery_nodes->control.node[0].id) ? *nodes : *recovery_nodes;
      back.meta.recver = r;
      back.meta.timestamp = timestamp_++;
      Send(back);
    }
  }
}

这里面的msg就是一个work发来的消息,deadnodes_set先不用管,nodes是一个meta类型(一个message的数据头,具体数据结构可以看之前的类图或者源码),recovery_nodes也是。

void Van::UpdateLocalID(Message* msg, std::unordered_set<int>* deadnodes_set, Meta* nodes, Meta* recovery_nodes) 
{
  auto& ctrl = msg->meta.control;
  size_t num_nodes = Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers();//num_nodes=4;
  // assign an id
  if (msg->meta.sender == Meta::kEmpty) //因为是work节点发过来的,而work节点初始化时候的id就是KEmpty.
  {
    CHECK(is_scheduler_);
    CHECK_EQ(ctrl.node.size(), 1);//msg中的control命令中的节点集合就是work自己,所以就是1个节点。
    if (nodes->control.node.size() < num_nodes) {nodes->control.node.push_back(ctrl.node[0]);} //因为sizes小于4
    else //如果四个work和server到齐了,就进入else
    {
      // some node dies and restarts
      CHECK(ready_.load());
      for (size_t i = 0; i < nodes->control.node.size() - 1; ++i) 
      {
	        const auto& node = nodes->control.node[i];
	        if (deadnodes_set->find(node.id) != deadnodes_set->end() &&node.role == ctrl.node[0].role) 
	        {
	          auto& recovery_node = ctrl.node[0];
	          // assign previous node id
	          recovery_node.id = node.id;
	          recovery_node.is_recovery = true;
	          nodes->control.node[i] = recovery_node;
	          recovery_nodes->control.node.push_back(recovery_node);
	          break;
            }
      }
    }
  }

  // update my id, 其实对于scheduler的话这个函数没用,因为是work节点刚push进来,但是如果是schedule发给这个work这个几点的消息,如果发现本地的ip和port和消息中的某个一点重合,那么就把本地节点的ID(初始化时候没有ID,只是等于Empty)改为schedule发过来的身份证ID。
  for (size_t i = 0; i < ctrl.node.size(); ++i) 
  {
    const auto& node = ctrl.node[i];
    if (my_node_.hostname == node.hostname && my_node_.port == node.port) 
    {
      if (getenv("DMLC_RANK") == nullptr || my_node_.id == Meta::kEmpty) 
      {
        my_node_ = node;
        string rank = to_string(Postoffice::IDtoRank(node.id));//max((id - 8) / 2, 0)
        setenv("DMLC_RANK", rank.c_str(), true);
      }
    }
  }
}
worker/server

由于是框架,所以上面代码基本都覆盖了,具体看代码的时候就是看一下不同的分支判断就行了。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值