ps-lite Usage Steps
The steps for using ps-lite are:
1. Initialize a KVServer.
2. Set the function that handles workers' push and pull requests to the server, via set_request_handle. This function must be implemented by the user and must be able to handle both push and pull requests coming from workers.
3. Choose how parameters are updated: synchronously or asynchronously.
4. Call ps::Start() to initialize the node, the network, and so on.
5. Initialize the worker: it first pulls all parameters from the server, and after updating them locally it pushes them back to the server.
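The following minimal program sketches these five steps. It follows the pattern of ps-lite's test programs; the keys and values are illustrative, and KVServerDefaultHandle is the library's built-in request handler (step 3, sync vs. async, is discussed later):

#include <vector>
#include "ps/ps.h"
using namespace ps;

void StartServer() {
  if (!IsServer()) return;
  // step 1: create the KVServer
  auto* server = new KVServer<float>(0);
  // step 2: register the request handle (here the built-in default handle,
  // which accumulates pushed values and answers pulls from its local store)
  server->set_request_handle(KVServerDefaultHandle<float>());
  RegisterExitCallback([server]() { delete server; });
}

void RunWorker() {
  if (!IsWorker()) return;
  KVWorker<float> kv(0);
  // step 5: push some key/value pairs, then pull them back
  std::vector<Key> keys = {1, 3, 5};
  std::vector<float> vals = {1.0, 2.0, 3.0};
  kv.Wait(kv.Push(keys, vals));
  std::vector<float> recv_vals;
  kv.Wait(kv.Pull(keys, &recv_vals));
}

int main(int argc, char* argv[]) {
  StartServer();   // steps 1-2 (only effective on server nodes)
  Start();         // step 4: initialize the node, the network, ...
  RunWorker();     // step 5 (only effective on worker nodes)
  Finalize();      // shut down
  return 0;
}

Now let's go through these steps against the source code.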
1. Initialize the KVServer
KVServer inherits from SimpleApp.
Purpose: initialize the server side, which stores the key-value data.
  explicit KVServer(int app_id) : SimpleApp() {
    using namespace std::placeholders;
    obj_ = new Customer(app_id, std::bind(&KVServer<Val>::Process, this, _1));
  }
For comparison, the KVWorker constructor:
  explicit KVWorker(int app_id) : SimpleApp() {
    using namespace std::placeholders;
    slicer_ = std::bind(&KVWorker<Val>::DefaultSlicer, this, _1, _2, _3);
    obj_ = new Customer(app_id, std::bind(&KVWorker<Val>::Process, this, _1));
  }
As you can see, compared with KVServer, KVWorker additionally sets a slicer_ function, and when constructing the Customer the bound callbacks are KVWorker::Process and KVServer::Process respectively.
Let's look at the Process function first. (The code below is SimpleApp::Process; KVServer::Process handles key-value messages analogously, building a KVMeta/KVPairs from the message and invoking request_handle_.)
inline void SimpleApp::Process(const Message& msg) {
  SimpleData recv;
  recv.sender = msg.meta.sender;
  recv.head = msg.meta.head;
  recv.body = msg.meta.body;
  recv.timestamp = msg.meta.timestamp;
  if (msg.meta.request) {
    CHECK(request_handle_);
    request_handle_(recv, this);
  } else {
    CHECK(response_handle_);
    response_handle_(recv, this);
  }
}
From this function we can see that Process calls request_handle_ when the message is a request, and response_handle_ when it is a response. Both request_handle_ and response_handle_ are of type Handle:
using Handle = std::function<void(const SimpleData& recved, SimpleApp* app)>;
This is the definition of Handle. std::function here is a general callable wrapper: the handler returns void and takes a SimpleData (the received request or response) and a SimpleApp pointer. This is the request/response handler that, as noted above, the user has to implement.
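For example, a user-supplied Handle could be a lambda like the one below (a hypothetical handler; what head and body mean is entirely up to the application):

// assumes "ps/ps.h" and <iostream> are included and ps::Start() has been called
SimpleApp simple_app(0);
simple_app.set_request_handle([](const SimpleData& req, SimpleApp* app) {
  // inspect the request, then send a (possibly empty) response back to the sender
  std::cout << "request from " << req.sender
            << " head=" << req.head << " body=" << req.body << std::endl;
  app->Response(req);
});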
Back in the constructors above, a Customer object is created with two arguments: the app_id, and the Process function bound with std::bind.
Customer::Customer(int id, const Customer::RecvHandle& recv_handle)
    : id_(id), recv_handle_(recv_handle) {
  Postoffice::Get()->AddCustomer(this);//1
  recv_thread_ = std::unique_ptr<std::thread>(new std::thread(&Customer::Receiving, this));//2
}
This is the Customer constructor. At //1 it calls Postoffice::AddCustomer, shown below:
void Postoffice::AddCustomer(Customer* customer) {
  std::lock_guard<std::mutex> lk(mu_);//1
  int id = CHECK_NOTNULL(customer)->id();
  CHECK_EQ(customers_.count(id), (size_t)0) << "id " << id << " already exists";
  customers_[id] = customer;//2
}
At //1 a lock_guard is used: it locks the mutex on construction and unlocks it on destruction, which avoids the problem of a manually taken lock not being released when an exception is thrown.
At //2, customers_ is an unordered_map that maps the customer (app) id to the Customer pointer.
recv_thread_ = std::unique_ptr<std::thread>(new std::thread(&Customer::Receiving, this));//2
At //2 a unique_ptr is used to own the newly created thread that runs Customer::Receiving.
void Customer::Receiving() {
  while (true) {
    Message recv;
    recv_queue_.WaitAndPop(&recv);
    if (!recv.meta.control.empty() &&
        recv.meta.control.cmd == Control::TERMINATE) {
      break;
    }
    recv_handle_(recv);
    if (!recv.meta.request) {
      std::lock_guard<std::mutex> lk(tracker_mu_);
      tracker_[recv.meta.timestamp].second++;
      tracker_cond_.notify_all();
    }
  }
}
This thread keeps popping messages from a thread-safe queue and hands each one to recv_handle_; for responses it additionally increments the corresponding tracker_ counter and notifies all waiting threads.
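For context, the tracker_ that gets incremented here is the other half of ps-lite's request accounting: NewRequest records how many responses a request expects (one per receiver node), and WaitRequest blocks until that many have arrived. In this version of Customer they look roughly like this:

int Customer::NewRequest(int recver) {
  std::lock_guard<std::mutex> lk(tracker_mu_);
  // one response is expected from every node in the receiver group
  int num = Postoffice::Get()->GetNodeIDs(recver).size();
  tracker_.push_back(std::make_pair(num, 0));  // (expected, received)
  return tracker_.size() - 1;                  // used as the timestamp
}

void Customer::WaitRequest(int timestamp) {
  std::unique_lock<std::mutex> lk(tracker_mu_);
  tracker_cond_.wait(lk, [this, timestamp] {
    return tracker_[timestamp].first == tracker_[timestamp].second;
  });
}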
Having covered the Customer class, let's go back to KVWorker and look at its slicer:
slicer_ = std::bind(&KVWorker<Val>::DefaultSlicer, this, _1, _2, _3);
void KVWorker<Val>::DefaultSlicer(
const KVPairs<Val>& send, const std::vector<Range>& ranges,
typename KVWorker<Val>::SlicedKVs* sliced) {
sliced->resize(ranges.size());
// find the positions in msg.key
size_t n = ranges.size();
std::vector<size_t> pos(n+1);
const Key* begin = send.keys.begin();
const Key* end = send.keys.end();
for (size_t i = 0; i < n; ++i) {
if (i == 0) {
pos[0] = std::lower_bound(begin, end, ranges[0].begin()) - begin;
begin += pos[0];
} else {
CHECK_EQ(ranges[i-1].end(), ranges[i].begin());
}
size_t len = std::lower_bound(begin, end, ranges[i].end()) - begin;
begin += len;
pos[i+1] = pos[i] + len;
// don't send it to severs for empty kv
sliced->at(i).first = (len != 0);
}
CHECK_EQ(pos[n], send.keys.size());
if (send.keys.empty()) return;
// the length of value
size_t k = 0, val_begin = 0, val_end = 0;
if (send.lens.empty()) {
k = send.vals.size() / send.keys.size();
CHECK_EQ(k * send.keys.size(), send.vals.size());
} else {
CHECK_EQ(send.keys.size(), send.lens.size());
}
// slice
for (size_t i = 0; i < n; ++i) {
if (pos[i+1] == pos[i]) {
sliced->at(i).first = false;
continue;
}
sliced->at(i).first = true;
auto& kv = sliced->at(i).second;
kv.keys = send.keys.segment(pos[i], pos[i+1]);
if (send.lens.size()) {
kv.lens = send.lens.segment(pos[i], pos[i+1]);
for (int l : kv.lens) val_end += l;
kv.vals = send.vals.segment(val_begin, val_end);
val_begin = val_end;
} else {
kv.vals = send.vals.segment(pos[i]*k, pos[i+1]*k);
}
}
}
The main job of this code is to slice the keys and values of a push/pull request according to each server's key range, so that each slice can be sent to the server that owns that range.
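A small worked example (illustrative numbers, not code from the library): suppose there are two servers with key ranges [0, 8) and [8, 16), and a worker pushes keys {1, 3, 9} with one value per key.

// send.keys = {1, 3, 9}, send.vals = {v1, v3, v9}, send.lens empty => k = 1
//
// pos = {0, 2, 3}: keys 1 and 3 fall into range [0, 8), key 9 into [8, 16)
//
// sliced[0] = { true, { keys = {1, 3}, vals = {v1, v3} } }  -> sent to server 0
// sliced[1] = { true, { keys = {9},    vals = {v9}     } }  -> sent to server 1
//
// If a range receives no keys, its slice is marked first = false and skipped.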
2. Set the function that handles workers' push and pull requests
  void set_request_handle(const ReqHandle& request_handle) {
    CHECK(request_handle) << "invalid request handle";
    request_handle_ = request_handle;
  }
Here request_handle is the user-implemented handler that actually processes push and pull requests on the server.
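For KVServer the handler type is ReqHandle = std::function<void(const KVMeta&, const KVPairs<Val>&, KVServer<Val>*)>. ps-lite ships a simple built-in handler, KVServerDefaultHandle, which is a good template for writing your own; it looks roughly like this:

template <typename Val>
struct KVServerDefaultHandle {
  void operator()(const KVMeta& req_meta, const KVPairs<Val>& req_data,
                  KVServer<Val>* server) {
    size_t n = req_data.keys.size();
    KVPairs<Val> res;
    if (req_meta.push) {
      CHECK_EQ(n, req_data.vals.size());
    } else {                         // a pull: reply with the requested keys
      res.keys = req_data.keys;
      res.vals.resize(n);
    }
    for (size_t i = 0; i < n; ++i) {
      Key key = req_data.keys[i];
      if (req_meta.push) {
        store[key] += req_data.vals[i];   // accumulate the pushed value
      } else {
        res.vals[i] = store[key];         // fill in the pulled value
      }
    }
    server->Response(req_meta, res);      // always acknowledge the request
  }
  std::unordered_map<Key, Val> store;     // the server-side parameters
};

A custom handle can do anything here, for example apply an optimizer update instead of the plain +=.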
3. Choose how parameters are updated: synchronously or asynchronously. ps-lite itself does not enforce either mode; the choice is implemented inside the user's request handle, as sketched below.
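A minimal sketch of a synchronous handle, assuming every worker pushes the same key set each round (SyncHandle and its members are hypothetical names, not part of ps-lite): pushes are merged immediately but not acknowledged until all workers have pushed, and since each worker blocks in kv.Wait(kv.Push(...)), delaying the response is what synchronizes the round.

#include <unordered_map>
#include <vector>
#include "ps/ps.h"
using namespace ps;

template <typename Val>
struct SyncHandle {
  void operator()(const KVMeta& req_meta, const KVPairs<Val>& req_data,
                  KVServer<Val>* server) {
    size_t n = req_data.keys.size();
    if (req_meta.push) {
      // merge this worker's update and remember the request
      for (size_t i = 0; i < n; ++i) store[req_data.keys[i]] += req_data.vals[i];
      pending.push_back(req_meta);
      // acknowledge all buffered pushes only once every worker has pushed
      if (pending.size() == static_cast<size_t>(NumWorkers())) {
        for (const auto& m : pending) server->Response(m);
        pending.clear();
      }
    } else {
      // pulls are answered immediately from the current state
      KVPairs<Val> res;
      res.keys = req_data.keys;
      res.vals.resize(n);
      for (size_t i = 0; i < n; ++i) res.vals[i] = store[req_data.keys[i]];
      server->Response(req_meta, res);
    }
  }
  std::unordered_map<Key, Val> store;
  std::vector<KVMeta> pending;
};

An asynchronous handle, by contrast, responds to every push right away, as KVServerDefaultHandle above does.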
4. Call ps::Start() to initialize the node, the network, and so on.
inline void Start(const char* argv0 = nullptr) {
  Postoffice::Get()->Start(argv0, true);
}
This simply calls Postoffice::Start():
void Postoffice::Start(const char* argv0, const bool do_barrier) {
// init glog
if (argv0) {
dmlc::InitLogging(argv0);
} else {
dmlc::InitLogging("ps-lite\0");
}
// init node info.
for (int i = 0; i < num_workers_; ++i) {
int id = WorkerRankToID(i);
for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
kWorkerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
for (int i = 0; i < num_servers_; ++i) {
int id = ServerRankToID(i);
for (int g : {id, kServerGroup, kWorkerGroup + kServerGroup,
kServerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
node_ids_[g].push_back(kScheduler);
}
// start van
van_->Start();
// record start time
start_time_ = time(NULL);
// do a barrier here
if (do_barrier) Barrier(kWorkerGroup + kServerGroup + kScheduler);
}
This code:
1. Initializes logging.
2. Iterates over all workers, servers, and node groups, and records their node ids in node_ids_ (the id encoding is sketched below).
3. Calls Van::Start().
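To make node_ids_ concrete: in this version of ps-lite the group ids are kScheduler = 1, kServerGroup = 2, kWorkerGroup = 4 (they can be summed to address several groups at once), and individual node ids are derived from ranks as follows:

// From Postoffice: single-node ids start at 8; servers get even ids,
// workers get odd ids, so role can be recovered from the id alone.
static inline int ServerRankToID(int rank) { return rank * 2 + 8; }
static inline int WorkerRankToID(int rank) { return rank * 2 + 9; }
static inline int IDtoRank(int id) { return std::max((id - 8) / 2, 0); }

// Example with 1 server and 2 workers:
//   server 0 -> id 8, worker 0 -> id 9, worker 1 -> id 11
//   node_ids_[kWorkerGroup]                = {9, 11}
//   node_ids_[kServerGroup]                = {8}
//   node_ids_[kWorkerGroup + kServerGroup] = {9, 11, 8}
// so GetNodeIDs(kWorkerGroup + kServerGroup) returns every worker and server.

Next, Van::Start():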
void Van::Start() {
// get scheduler info
scheduler_.hostname = std::string(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_URI")));
scheduler_.port = atoi(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_PORT")));
scheduler_.role = Node::SCHEDULER;
scheduler_.id = kScheduler;
is_scheduler_ = Postoffice::Get()->is_scheduler();
// get my node info
if (is_scheduler_) {
my_node_ = scheduler_;
} else {
auto role = is_scheduler_ ? Node::SCHEDULER :
(Postoffice::Get()->is_worker() ? Node::WORKER : Node::SERVER);
const char* nhost = Environment::Get()->find("DMLC_NODE_HOST");
std::string ip;
if (nhost) ip = std::string(nhost);
if (ip.empty()) {
const char* itf = Environment::Get()->find("DMLC_INTERFACE");
std::string interface;
if (itf) interface = std::string(itf);
if (interface.size()) {
GetIP(interface, &ip);
} else {
GetAvailableInterfaceAndIP(&interface, &ip);
}
CHECK(!interface.empty()) << "failed to get the interface";
}
int port = GetAvailablePort();
const char* pstr = Environment::Get()->find("PORT");
if (pstr) port = atoi(pstr);
CHECK(!ip.empty()) << "failed to get ip";
CHECK(port) << "failed to get a port";
my_node_.hostname = ip;
my_node_.role = role;
my_node_.port = port;
// cannot determine my id now, the scheduler will assign it later
// set it explicitly to make re-register within a same process possible
my_node_.id = Node::kEmpty;
}
// bind.
my_node_.port = Bind(my_node_, is_scheduler_ ? 0 : 40);
PS_VLOG(1) << "Bind to " << my_node_.DebugString();
CHECK_NE(my_node_.port, -1) << "bind failed";
// connect to the scheduler
Connect(scheduler_);
// for debug use
if (Environment::Get()->find("PS_DROP_MSG")) {
drop_rate_ = atoi(Environment::Get()->find("PS_DROP_MSG"));
}
// start receiver
receiver_thread_ = std::unique_ptr<std::thread>(
new std::thread(&Van::Receiving, this));
if (!is_scheduler_) {
// let the scheduler know myself
Message msg;
msg.meta.recver = kScheduler;
msg.meta.control.cmd = Control::ADD_NODE;
msg.meta.control.node.push_back(my_node_);
msg.meta.timestamp = timestamp_++;
Send(msg);
}
// wait until ready
while (!ready_) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
// resender
if (Environment::Get()->find("PS_RESEND") && atoi(Environment::Get()->find("PS_RESEND")) != 0) {
int timeout = 1000;
if (Environment::Get()->find("PS_RESEND_TIMEOUT")) {
timeout = atoi(Environment::Get()->find("PS_RESEND_TIMEOUT"));
}
resender_ = new Resender(timeout, 10, this);
}
if (!is_scheduler_) {
// start heartbeat thread
heartbeat_thread_ = std::unique_ptr<std::thread>(
new std::thread(&Van::Heartbeat, this));
}
}
1) Set the scheduler's hostname, port, role, and id, read from the DMLC_PS_ROOT_URI / DMLC_PS_ROOT_PORT environment variables (see the sketch below).
2) If the current node is not the scheduler, set its own hostname, port, and role; its id is left as Node::kEmpty and will be assigned by the scheduler later.
3) Every node connects to the scheduler.
4) Start the receiving thread, which runs Van::Receiving (listed after the sketch below).
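All of this is driven by environment variables; every process (scheduler, servers, workers) is launched with settings like the following (the values are illustrative):

// Expected by Postoffice/Van (same values on every node except DMLC_ROLE):
//   DMLC_PS_ROOT_URI=127.0.0.1    // the scheduler's IP  (scheduler_.hostname above)
//   DMLC_PS_ROOT_PORT=8000        // the scheduler's port (scheduler_.port above)
//   DMLC_NUM_SERVER=1             // number of server nodes
//   DMLC_NUM_WORKER=2             // number of worker nodes
//   DMLC_ROLE=scheduler|server|worker   // this process's role
// Optional (most appear in the code in this article): DMLC_NODE_HOST,
// DMLC_INTERFACE, PORT, PS_DROP_MSG, PS_RESEND, PS_RESEND_TIMEOUT,
// PS_HEARTBEAT_TIMEOUT.

The Van::Receiving loop: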
void Van::Receiving() {
const char* heartbeat_timeout_val = Environment::Get()->find("PS_HEARTBEAT_TIMEOUT");
const int heartbeat_timeout = heartbeat_timeout_val ? atoi(heartbeat_timeout_val) : 5;
Meta nodes; // for scheduler usage
while (true) {
Message msg;
int recv_bytes = RecvMsg(&msg);
// For debug, drop received message
if (ready_ && drop_rate_ > 0) {
unsigned seed = time(NULL) + my_node_.id;
if (rand_r(&seed) % 100 < drop_rate_) {
LOG(WARNING) << "Drop message " << msg.DebugString();
continue;
}
}
CHECK_NE(recv_bytes, -1);
recv_bytes_ += recv_bytes;
if (Postoffice::Get()->verbose() >= 2) {
PS_VLOG(2) << msg.DebugString();
}
// duplicated message
if (resender_ && resender_->AddIncomming(msg)) continue;
if (!msg.meta.control.empty()) {
// do some management
auto& ctrl = msg.meta.control;
if (ctrl.cmd == Control::TERMINATE) {
PS_VLOG(1) << my_node_.ShortDebugString() << " is stopped";
ready_ = false;
break;
} else if (ctrl.cmd == Control::ADD_NODE) {
size_t num_nodes = Postoffice::Get()->num_servers() +
Postoffice::Get()->num_workers();
auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout);
std::unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());
Meta recovery_nodes; // store recovery nodes
recovery_nodes.control.cmd = Control::ADD_NODE;
// assign an id
if (msg.meta.sender == Meta::kEmpty) {
CHECK(is_scheduler_);
CHECK_EQ(ctrl.node.size(), 1);
if (nodes.control.node.size() < num_nodes) {
nodes.control.node.push_back(ctrl.node[0]);
} else {
// some node dies and restarts
CHECK(ready_);
for (size_t i = 0; i < nodes.control.node.size() - 1; ++i) {
const auto& node = nodes.control.node[i];
if (dead_set.find(node.id) != dead_set.end() && node.role == ctrl.node[0].role) {
auto& recovery_node = ctrl.node[0];
// assign previous node id
recovery_node.id = node.id;
recovery_node.is_recovery = true;
PS_VLOG(1) << "replace dead node " << node.DebugString()
<< " by node " << recovery_node.DebugString();
nodes.control.node[i] = recovery_node;
recovery_nodes.control.node.push_back(recovery_node);
break;
}
}
}
}
// update my id
for (size_t i = 0; i < ctrl.node.size(); ++i) {
const auto& node = ctrl.node[i];
if (my_node_.hostname == node.hostname &&
my_node_.port == node.port) {
my_node_ = node;
std::string rank = std::to_string(Postoffice::IDtoRank(node.id));
#ifdef _MSC_VER
_putenv_s("DMLC_RANK", rank.c_str());
#else
setenv("DMLC_RANK", rank.c_str(), true);
#endif
}
}
if (is_scheduler_) {
time_t t = time(NULL);
if (nodes.control.node.size() == num_nodes) {
// sort the nodes according their ip and port,
std::sort(nodes.control.node.begin(), nodes.control.node.end(),
[](const Node& a, const Node& b) {
return (a.hostname.compare(b.hostname) | (a.port < b.port)) > 0;
});
// assign node rank
for (auto& node : nodes.control.node) {
CHECK_EQ(node.id, Node::kEmpty);
int id = node.role == Node::SERVER ?
Postoffice::ServerRankToID(num_servers_) :
Postoffice::WorkerRankToID(num_workers_);
PS_VLOG(1) << "assign rank=" << id << " to node " << node.DebugString();
node.id = id;
Connect(node);
if (node.role == Node::SERVER) ++num_servers_;
if (node.role == Node::WORKER) ++num_workers_;
Postoffice::Get()->UpdateHeartbeat(node.id, t);
}
nodes.control.node.push_back(my_node_);
nodes.control.cmd = Control::ADD_NODE;
Message back; back.meta = nodes;
for (int r : Postoffice::Get()->GetNodeIDs(
kWorkerGroup + kServerGroup)) {
back.meta.recver = r;
back.meta.timestamp = timestamp_++;
Send(back);
}
PS_VLOG(1) << "the scheduler is connected to "
<< num_workers_ << " workers and " << num_servers_ << " servers";
ready_ = true;
} else if (recovery_nodes.control.node.size() > 0) {
// send back the recovery node
CHECK_EQ(recovery_nodes.control.node.size(), 1);
Connect(recovery_nodes.control.node[0]);
Postoffice::Get()->UpdateHeartbeat(recovery_nodes.control.node[0].id, t);
Message back;
for (int r : Postoffice::Get()->GetNodeIDs(
kWorkerGroup + kServerGroup)) {
if (r != recovery_nodes.control.node[0].id
&& dead_set.find(r) != dead_set.end()) {
// do not try to send anything to dead node
continue;
}
// only send recovery_node to nodes already exist
// but send all nodes to the recovery_node
back.meta = (r == recovery_nodes.control.node[0].id) ? nodes : recovery_nodes;
back.meta.recver = r;
back.meta.timestamp = timestamp_++;
Send(back);
}
}
} else {
for (const auto& node : ctrl.node) {
Connect(node);
if (!node.is_recovery && node.role == Node::SERVER) ++num_servers_;
if (!node.is_recovery && node.role == Node::WORKER) ++num_workers_;
}
PS_VLOG(1) << my_node_.ShortDebugString() << " is connected to others";
ready_ = true;
}
} else if (ctrl.cmd == Control::BARRIER) {
if (msg.meta.request) {
if (barrier_count_.empty()) {
barrier_count_.resize(8, 0);
}
int group = ctrl.barrier_group;
++barrier_count_[group];
PS_VLOG(1) << "Barrier count for " << group << " : " << barrier_count_[group];
if (barrier_count_[group] ==
static_cast<int>(Postoffice::Get()->GetNodeIDs(group).size())) {
barrier_count_[group] = 0;
Message res;
res.meta.request = false;
res.meta.control.cmd = Control::BARRIER;
for (int r : Postoffice::Get()->GetNodeIDs(group)) {
res.meta.recver = r;
res.meta.timestamp = timestamp_++;
CHECK_GT(Send(res), 0);
}
}
} else {
Postoffice::Get()->Manage(msg);
}
} else if (ctrl.cmd == Control::HEARTBEAT) {
time_t t = time(NULL);
for (auto &node : ctrl.node) {
Postoffice::Get()->UpdateHeartbeat(node.id, t);
if (is_scheduler_) {
Message heartbeat_ack;
heartbeat_ack.meta.recver = node.id;
heartbeat_ack.meta.control.cmd = Control::HEARTBEAT;
heartbeat_ack.meta.control.node.push_back(my_node_);
heartbeat_ack.meta.timestamp = timestamp_++;
// send back heartbeat
Send(heartbeat_ack);
}
}
}
} else {
CHECK_NE(msg.meta.sender, Meta::kEmpty);
CHECK_NE(msg.meta.recver, Meta::kEmpty);
CHECK_NE(msg.meta.customer_id, Meta::kEmpty);
int id = msg.meta.customer_id;
auto* obj = Postoffice::Get()->GetCustomer(id, 5);
CHECK(obj) << "timeout (5 sec) to wait App " << id << " ready";
obj->Accept(msg);
}
}
}
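To summarize Van::Receiving: TERMINATE ends the loop; ADD_NODE is how the scheduler assigns node ids and broadcasts the full node list (and how the other nodes learn about and connect to each other); BARRIER counts the requests from a group and releases everyone once the whole group has arrived; HEARTBEAT refreshes liveness information. Any non-control message is a data (push/pull) message and is handed to the Customer identified by customer_id; Customer::Accept simply enqueues it into the thread-safe queue consumed by the Customer::Receiving loop shown earlier:

// From Customer in this version of ps-lite (sketch): hand the received
// message over to the queue that Customer::Receiving pops from.
void Accept(const Message& recved) { recv_queue_.Push(recved); }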