[源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现

时间:2023-12-11 21:43:38

[源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现

0x00 摘要

本文是参数服务器的第四篇,介绍KVWorker, KVServer。

本系列其他文章是:

[源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOffice

[源码解析] 机器学习参数服务器ps-lite(2) ----- 通信模块Van

[源码解析] 机器学习参数服务器ps-lite 之(3) ----- 代理人Customer

KVWorker, KVServer这两个分别是 Server / Worker 节点的抽象,是被 Van ---> Customer ---> recv_handle_ 来作为引擎的一部分来启动的。

本文会先介绍一些基础支撑类,然后介绍 Server / Worker的基类 SimpleApp,最后介绍 Server / Worker 的具体实现。

总体流程图提前剧透如下:

[源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现

0x01 基础类

我们首先需要介绍一些基础类。

1.1 Range

Range 类作用是:根据这个Range确定要拉取的参数在哪个server上,以及一个server对应的key的range。

Range 类提供如下函数:

  • begin()和end()两个uint64的位置;
  • size() 获得 本 range 的大小,即 end_ - begin_;
class Range {
public:
Range() : Range(0, 0) {}
Range(uint64_t begin, uint64_t end) : begin_(begin), end_(end) { } uint64_t begin() const { return begin_; }
uint64_t end() const { return end_; }
uint64_t size() const { return end_ - begin_; }
private:
uint64_t begin_;
uint64_t end_;
};

1.2 TreadsafeQueue

TreadsafeQueue 是一个可以供多个线程读取的队列,通过锁和条件量合作来达到线程安全,用来做消息队列。

/**
* \brief thread-safe queue allowing push and waited pop
*/
class ThreadsafePQueue {
public:
ThreadsafePQueue() { }
~ThreadsafePQueue() { } /**
* \brief push an value into the end. threadsafe.
* \param new_value the value
*/
void Push(Message new_value) {
mu_.lock();
queue_.push(std::move(new_value));
mu_.unlock();
cond_.notify_all();
} /**
* \brief wait until pop an element from the beginning, threadsafe
* \param value the poped value
*/
void WaitAndPop(Message* value) { // 等待队列不为空,按照优先级pop message
std::unique_lock<std::mutex> lk(mu_);
cond_.wait(lk, [this]{return !queue_.empty();});
*value = std::move(queue_.top());
queue_.pop();
} private:
class Compare {
public:
bool operator()(const Message &l, const Message &r) {
return l.meta.priority <= r.meta.priority;
}
};
mutable std::mutex mu_; //数据同步互斥变量
std::priority_queue<Message, std::vector<Message>, Compare> queue_; // message优先队列
std::condition_variable cond_; //队列不为空条件变量
};

0x02 SimpleApp

2.1 概述

SimpleApp是一个基类,把应用节点功能做了一个统一抽象。

  • 提供了基本发送功能和简单消息处理函数(Request, Wait, Response)。
  • 消息类型为:int型的head和string型的body。
  • 它有2个派生类。KVServer和KVWorker。

2.2 定义

2.2.1 支撑类

SimpleData 定义了 Request 和 Response 的基本格式。

struct SimpleData {
/** \brief the int head */
int head;
/** \brief the string body */
std::string body;
/** \brief sender's node id */
int sender;
/** \brief the associated timestamp */
int timestamp;
/** \brief sender's customer id */
int customer_id;
};

2.2.2 成员变量

SimpleApp 主要有如下成员变量:

  • Customer* obj_ :本 App 的 Customer,控制请求连接;
  • Handle request_handle_ :request 处理函数;
  • Handle response_handle_ :response 处理函数;
  • set_request_handle,set_response_handle:设置成员request_handle_, response_handle_。在客户端调用SimpleApp::Process时,根据message.meta中的指示变量判断是request还是response,调用相应handle处理;
class SimpleApp {
public:
/**
* \brief constructor
* @param app_id the app id, should match with the remote node app with which this app
* @param customer_id the customer_id, should be node-locally unique
* is communicated
*/
explicit SimpleApp(int app_id, int customer_id); /** \brief deconstructor */
virtual ~SimpleApp() { delete obj_; obj_ = nullptr; } /**
* \brief send a request to a remote node
*
* \param req_head request head
* \param req_body request body
* \param recv_id remote node id
*
* @return the timestamp of this request
*/
virtual inline int Request(int req_head, const std::string& req_body, int recv_id); /**
* \brief wait until a request is finished
*
* \param timestamp
*/
virtual inline void Wait(int timestamp) { obj_->WaitRequest(timestamp); } /**
* \brief send back a response for a request
* \param recv_req the received request
* \param the response body
*/
virtual inline void Response(const SimpleData& recv_req, const std::string& res_body = ""); /**
* \brief the handle to proces a received request/respoonse
*
* \param recved the received request or response
* \param app this pointer
*/
using Handle = std::function<void(const SimpleData& recved, SimpleApp* app)>; /**
* \brief set the request handle
* \param request_handle the request handle
*/
virtual inline void set_request_handle(const Handle& request_handle) {
CHECK(request_handle) << "invalid request handle";
request_handle_ = request_handle;
} /**
* \brief set the response handle
* \param response_handle the response handle
*/
virtual inline void set_response_handle(const Handle& response_handle) {
CHECK(response_handle) << "invalid response handle";
response_handle_ = response_handle;
} /**
* \brief returns the customer
*/
virtual inline Customer* get_customer() { return obj_; } protected:
/** \brief empty construct */
inline SimpleApp() : obj_(nullptr) {
request_handle_ = [](const SimpleData& recved, SimpleApp* app) {
app->Response(recved);
};
response_handle_ = [](const SimpleData& recved, SimpleApp* app) { };
} /** \brief process a received message */
virtual inline void Process(const Message& msg); /** \brief ps internal object */
Customer* obj_; private:
/** \brief request handle */
Handle request_handle_;
/** \brief request handle */
Handle response_handle_;
};

2.3 功能函数

三个简单功能函数如下:

Request 就是调用 Van 发送消息。

inline int SimpleApp::Request(int req_head, const std::string& req_body, int recv_id) {
// setup message
Message msg;
msg.meta.head = req_head;
if (req_body.size()) msg.meta.body = req_body;
int ts = obj_->NewRequest(recv_id);
msg.meta.timestamp = ts;
msg.meta.request = true;
msg.meta.simple_app = true;
msg.meta.app_id = obj_->app_id();
msg.meta.customer_id = obj_->customer_id(); // send
for (int r : Postoffice::Get()->GetNodeIDs(recv_id)) {
msg.meta.recver = r;
Postoffice::Get()->van()->Send(msg);
}
return ts;
}

Response 是调用 Van 回复消息。

inline void SimpleApp::Response(const SimpleData& req, const std::string& res_body) {
// setup message
Message msg;
msg.meta.head = req.head;
if (res_body.size()) msg.meta.body = res_body;
msg.meta.timestamp = req.timestamp;
msg.meta.request = false;
msg.meta.simple_app = true;
msg.meta.app_id = obj_->app_id();
msg.meta.customer_id = req.customer_id;
msg.meta.recver = req.sender; // send
Postoffice::Get()->van()->Send(msg);
}

Process 函数根据message.meta中的指示变量判断是request还是response,调用相应handle处理。

inline void SimpleApp::Process(const Message& msg) {
SimpleData recv;
recv.sender = msg.meta.sender;
recv.head = msg.meta.head;
recv.body = msg.meta.body;
recv.timestamp = msg.meta.timestamp;
recv.customer_id = msg.meta.customer_id;
if (msg.meta.request) { // 判断是request还是response,调用相应handle处理
CHECK(request_handle_);
request_handle_(recv, this);
} else {
CHECK(response_handle_);
response_handle_(recv, this);
}
}

0x03 KVServer

KVServer 是 Server 节点的抽象,其作用是 接收信息处理信息返回结果三个步骤,主要功能是:

  • 维护 key-value pairs 数据;
  • 处理 & 应答 客户端的 push & pull 请求;
    • 函数request_handle_ 处理请求:
      • 在调用KVServer::Process时 会调用到 request_handle_
      • request_handle_默认为KVServerDefaultHandle
    • 函数Response用于返回数据;

3.1 定义

request_handle_ 是 request 处理函数,需要自定义。

  • 在该回调函数中使用者则需要实现各种优化器的的模型权重梯度更新算法和模型权重返回操作
  • 可直接参考ps-lite已实现的默认版本KVServerDefaultHandle。
/**
* \brief A server node for maintaining key-value pairs
*/
template <typename Val>
class KVServer : public SimpleApp {
public:
/**
* \brief constructor
* \param app_id the app id, should match with \ref KVWorker's id
*/
explicit KVServer(int app_id) : SimpleApp() {
using namespace std::placeholders;
obj_ = new Customer(app_id, app_id, std::bind(&KVServer<Val>::Process, this, _1));
} /** \brief deconstructor */
virtual ~KVServer() { delete obj_; obj_ = nullptr; } /**
* \brief the handle to process a push/pull request from a worker
* \param req_meta meta-info of this request
* \param req_data kv pairs of this request
* \param server this pointer
*/
using ReqHandle = std::function<void(const KVMeta& req_meta,
const KVPairs<Val>& req_data,
KVServer* server)>;
void set_request_handle(const ReqHandle& request_handle) {
CHECK(request_handle) << "invalid request handle";
request_handle_ = request_handle;
} /**
* \brief response to the push/pull request
* \param req the meta-info of the request
* \param res the kv pairs that will send back to the worker
*/
void Response(const KVMeta& req, const KVPairs<Val>& res = KVPairs<Val>()); private:
/** \brief internal receive handle */
void Process(const Message& msg);
/** \brief request handle */
ReqHandle request_handle_; // 需要用户自己实现
};

3.2 功能函数

3.2.1 Response

Response()就是向调用的worker发送 response 信息。与SimpleApp 比较下,发现 KVServer 这里对于 head 和 body 都有了新的处理。

需要注意的是:Response 函数应该是被用户自定义的 request_handle_ 调用,即 request_handle_ 处理收到的消息,然后调用 Response 对 worker 进行回复应答

template <typename Val>
void KVServer<Val>::Response(const KVMeta& req, const KVPairs<Val>& res) {
Message msg;
msg.meta.app_id = obj_->app_id();
msg.meta.customer_id = req.customer_id;
msg.meta.request = false;
msg.meta.push = req.push;
msg.meta.pull = req.pull;
msg.meta.head = req.cmd;
msg.meta.timestamp = req.timestamp;
msg.meta.recver = req.sender;
if (res.keys.size()) {
msg.AddData(res.keys);
msg.AddData(res.vals);
if (res.lens.size()) {
msg.AddData(res.lens);
}
}
Postoffice::Get()->van()->Send(msg);
}

3.2.2 Process

Process()被注册到Customer对象中,当Customer对象的receiving thread接受到消息时,就调用Process()对数据进行处理。

Process()内部的逻辑是:

  • 提取消息的元信息,构建一个 KVMeta。
  • 可以看到,在 Process 中没有对 KV 数据的维护。
  • Process 调用 用户自行实现的一个request_handle_ (std::function函数对象)对数据进行处理。
  • 在回调函数 request_handle_ 中使用者则需要实现各种优化器的的模型权重梯度更新算法和模型权重返回操作
template <typename Val>
void KVServer<Val>::Process(const Message& msg) {
if (msg.meta.simple_app) {
SimpleApp::Process(msg); return;
}
KVMeta meta;
meta.cmd = msg.meta.head;
meta.push = msg.meta.push;
meta.pull = msg.meta.pull;
meta.sender = msg.meta.sender;
meta.timestamp = msg.meta.timestamp;
meta.customer_id = msg.meta.customer_id;
KVPairs<Val> data;
int n = msg.data.size();
if (n) {
CHECK_GE(n, 2);
data.keys = msg.data[0];
data.vals = msg.data[1];
if (n > 2) {
CHECK_EQ(n, 3);
data.lens = msg.data[2];
CHECK_EQ(data.lens.size(), data.keys.size());
}
}
CHECK(request_handle_);
request_handle_(meta, data, this);
}

3.2.3 例子函数

KVServerDefaultHandle 是 ps-lite 提供的例子,用于演示如何维护 KV,处理消息,返回请求。

这里维护一个哈希表 unordered_map,记录key和value,并对push和pull请求进行响应。

使用std::unordered_map store保存server的参数,当请求为push时,对store参数做更新,请求为pull时对参数进行拉取;

/**
* \brief an example handle adding pushed kv into store
*/
template <typename Val>
struct KVServerDefaultHandle {
void operator()(
const KVMeta& req_meta, const KVPairs<Val>& req_data, KVServer<Val>* server) {
size_t n = req_data.keys.size();
KVPairs<Val> res;
if (!req_meta.pull) {
CHECK_EQ(n, req_data.vals.size());
} else {
res.keys = req_data.keys; res.vals.resize(n);
}
for (size_t i = 0; i < n; ++i) {
Key key = req_data.keys[i];
if (req_meta.push) {
store[key] += req_data.vals[i];
}
if (req_meta.pull) {
res.vals[i] = store[key];
}
}
server->Response(req_meta, res);
}
std::unordered_map<Key, Val> store;
};

3.2.4 流程

我们接着上文继续梳理细化流程。

  • worker节点 或者 server节点 在程序的最开始会执行Postoffice::start()

  • Postoffice::start()会初始化节点信息,并且调用Van::start()

  • 每个节点都监听了本地一个端口;该连接的节点在启动时已经连接。

  • Van::start() 启动一个本地线程专门接收socket的信息,使用Van::Receiving()来持续监听收到的message。

    • receiver_thread_ = std::unique_ptr<std::thread>(new std::thread(&Van::Receiving, this));
  • Van::Receiving()接收后消息之后,根据不同命令执行不同动作。针对数据消息,如果需要下一步处理,会调用 ProcessDataMsg:

    • 依据消息中的app id找到 Customer(每个app 任务会绑定一个custom类),即会根据customer id的不同将message发给不同的customer的recv thread。
    • 将消息传递给Customer::Accept函数。
  • Customer::Accept() 函数将消息添加到一个队列recv_queue_

  • Customer 对象本身也会启动一个接受线程 recv_thread_,使用 Customer::Receiving() :

    • 不断的从recv_queue_队列取消息。
    • 如果 (!recv.meta.request) ,就说明是 response,则tracker_[req.timestamp].second++
    • 调用注册的用户自定义的recv_handle_函数对消息进行处理。
  • 对于worker来说,其注册的recv_handle_KVWorker::Process()函数。因为worker的recv thread接受到的消息主要是从server处pull下来的KV对,因此该Process()主要是接收message中的KV对;

  • 而对于Server来说,其注册的recv_handle_KVServer::Process()函数。

  • 因为我们这里是 KVServer,而且server接受的是worker们push上来的KV对,需要对其进行处理,因此该Process()函数中调用的用户通过KVServer::set_request_handle()传入的函数对象。

  • 在 用户自定义的 request_handle_ 函数中,如果需要发送 response 给 worker,则调用 KVServer::Response。

目前逻辑如下图,在 第 8 步,recv_handle_ 指向 KVServer::Process 或者 KVWorker::Process(本节是server,所以对应的是KVServer::Process)。在第10步,返回 response 给 worker。

            +--------------------------+
| Van |
| |
Request +-----------> Receiving |
| 1 + | +---------------------------+
| | | | Postoffice |
| | 2 | | |
| v | GetCustomer | |
| ProcessDataMsg <------------------> unordered_map customers_|
| + | 3 | |
| | | +---------------------------+
+--------------------------+
|
| 4
|
+------------------------------------+
| Customer | |
| | |
| v |
| Accept |
| + |
| | |
| | 5 |
| v |
| recv_queue_ | +------------------+
| + | |KVWorker |
| | 6 | +--------> | |
| | | | 8 | Process |
| v | | +------------------+
| recv_thread_ +---> Receiving | |
| + | |
| | 7 | |
| | | | +------------------+
| v | | |KVServer |
| recv_handle_+---------+--------> | |
| | 8 | Process |
+------------------------------------+ | + |
+------------------+
|
| 9
v
+-----------+-------+
| request_handle_ |
10 | |
Response <----------------------------------------------------+ Response |
| |
+-------------------+

0x04 KVWorker

4.1 概述

KVWorker用于向server节点push,pull key-value对,就是在算法过程中,需要并行处理的各种参数。

  • Worker中的push和pull操作都是异步返回一个ID,然后使用ID进行wait阻塞等待,即同步操作。
  • 或者异步调用时传入一个Callback进行后续操作。

4.2 定义

KVWorker 主要变量为:

  • std::unordered_map<int, std::vector<KVPairs>> recv_kvs :收到的pull 结果: kv value ;
  • std::unordered_map<int, Callback> callbacks :收到 request 的所有 response 之后执行的回调函数;
  • Slicer slicer_ :默认 slice 函数变量,该函数在调用Send函数时,将KVPairs按照每个server的Range切片;

主要函数为:

  • ZPush 零拷贝push函数

  • ZPull 零拷贝pull函数

  • AddPullCB key重组函数

  • Process 消息处理函数

  • DefaultSlicer 默认的slice 处理函数

  • set_slicer:设置slicer_成员,该函数在调用Send函数时,将KVPairs按照每个server的Range切片;

/**
* \brief A worker node that can \ref Push (\ref Pull) key-value pairs to (from) server
* nodes
*
* \tparam Val the type of value, which should be primitive types such as
* int32_t and float
*/
template<typename Val>
class KVWorker : public SimpleApp {
public:
/** avoid too many this-> */
using SimpleApp::obj_; // Customer 对象
/**
* \brief callback function for \ref Push and \ref Pull
*
* It is called by the data receiving thread of this instance when the push or
* pull is actually finished. Namely the kv pairs have already written into
* servers' data structure or the kv pairs have already pulled back.
*/
using Callback = std::function<void()>; /**
* \brief constructor
*
* \param app_id the app id, should match with \ref KVServer's id
* \param customer_id the customer id which is unique locally
*/
explicit KVWorker(int app_id, int customer_id) : SimpleApp() {
using namespace std::placeholders;
slicer_ = std::bind(&KVWorker<Val>::DefaultSlicer, this, _1, _2, _3);
obj_ = new Customer(app_id, customer_id, std::bind(&KVWorker<Val>::Process, this, _1));
} /** \brief deconstructor */
virtual ~KVWorker() { delete obj_; obj_ = nullptr; } using SlicedKVs = std::vector<std::pair<bool, KVPairs<Val>>>;
/**
* \brief a slicer partitions a key-value list according to the key ranges
* \param send the kv list for partitioning
* \param ranges the key ranges, ranges[i] is the key range of server i
* \param sliced the sliced lists. slices[i] should only contains keys in
* ranges[i] and the according values
*/
using Slicer = std::function<void(
const KVPairs<Val>& send, const std::vector<Range>& ranges,
SlicedKVs* sliced)>; /**
* \brief set a user-defined slicer
*/
void set_slicer(const Slicer& slicer) {
CHECK(slicer); slicer_ = slicer;
} private:
/**
* \brief add a callback for a request. threadsafe.
* @param cb callback
* @param timestamp the timestamp of the request
*/
void AddCallback(int timestamp, const Callback& cb) {
if (!cb) return;
std::lock_guard<std::mutex> lk(mu_);
callbacks_[timestamp] = cb;
} /** \brief data buffer for received kvs for each timestamp */
std::unordered_map<int, std::vector<KVPairs<Val>>> recv_kvs_; // 收到的 kv value
/** \brief callbacks for each timestamp */
std::unordered_map<int, Callback> callbacks_; // 收到 request 的所有 response 之后执行的回调函数
/** \brief lock */
std::mutex mu_;
/** \brief kv list slicer */
Slicer slicer_; // 默认 slice 函数变量
};

4.3 功能函数

4.3.1 Push & ZPush

因为 Push 调用了 ZPush,所以我们放在一起介绍。

Push方法主要就是:

  • 把数据(KV列表)发送到对应的服务器节点;
  • KV列表是依据每个服务器维护的 Key range 来进行分区发送;
  • Push 是异步直接返回,如果想知道返回结果如何,则可以:
    • 使用 Wait 来等待,即利用tracker_来记录发送的请求量和对应的响应请求量,当发送量等于接收量的时候,表示每个请求都成功发送了,以此来达到同步的目的;
    • 使用 callback,这样当结束时候就可以回调到。

ZPush 方法是:

  • 使用obj_(Customer类型)的 NewRequest 方法来记录记录发送的请求量和对应的响应请求量,并且返回一个时间戳;
  • 设置好对应 timestamp 的 callback;
  • 使用传入的参数构造KVPair对象,调用Send送出该对象;
  int Push(const std::vector<Key>& keys,
const std::vector<Val>& vals,
const std::vector<int>& lens = {},
int cmd = 0,
const Callback& cb = nullptr,
int priority = 0) {
return ZPush(
SArray<Key>(keys), SArray<Val>(vals), SArray<int>(lens), cmd, cb,
priority);
} int ZPush(const SArray<Key>& keys,
const SArray<Val>& vals,
const SArray<int>& lens = {},
int cmd = 0,
const Callback& cb = nullptr,
int priority = 0) {
int ts = obj_->NewRequest(kServerGroup);
AddCallback(ts, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
kvs.vals = vals;
kvs.lens = lens;
kvs.priority = priority;
Send(ts, true, false, cmd, kvs);
return ts;
}

如何调用可以参考其注释:

   * Sample usage: the following codes push two KV pairs `{1, (1.1, 1.2)}` and `{3,
* (3.1,3.2)}` to server nodes, where the value is a length-2 float vector
* \code
* KVWorker<float> w;
* std::vector<Key> keys = {1, 3};
* std::vector<float> vals = {1.1, 1.2, 3.1, 3.2};
* w.Push(keys, vals);
* \endcode

4.3.2 Pull

pull方法跟push的逻辑大体类似:

  • 绑定一个回调函数,用于拷贝数据,并且得到一个时间戳。
  • 根据key_vector从Server上拉取val_vector,
  • 最终返回timestamp,
  • 该函数不阻塞,可用worker.Wait(timestamp)等待;
  int Pull(const std::vector<Key>& keys,
std::vector<Val>* vals,
std::vector<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr,
int priority = 0) {
SArray<Key> skeys(keys);
int ts = AddPullCB(skeys, vals, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = skeys;
kvs.priority = priority;
Send(ts, false, true, cmd, kvs);
return ts;
}

4.3.3 ZPull

逻辑与 Pull 一致,只是省略了拷贝到系统这个过程。因此需要保证在ZPull完成前,调用者没有改变key_vector;

  int ZPull(const SArray<Key>& keys,
SArray<Val>* vals,
SArray<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr,
int priority = 0) {
int ts = AddPullCB(keys, vals, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
kvs.priority = priority;
Send(ts, false, true, cmd, kvs);
return ts;
}

4.3.4 Send

Push()Pull()最后都会调用Send()函数,Send()对KVPairs进行切分,因为每个Server只保留一部分参数,因此切分后的SlicedKVpairs就会被发送给不同的Server。

如果是 skipped,则会直接调用 callback。

否则遍历发送。

template <typename Val>
void KVWorker<Val>::Send(int timestamp, bool push, bool pull, int cmd, const KVPairs<Val>& kvs) {
// slice the message
SlicedKVs sliced;
slicer_(kvs, Postoffice::Get()->GetServerKeyRanges(), &sliced); // need to add response first, since it will not always trigger the callback
int skipped = 0;
for (size_t i = 0; i < sliced.size(); ++i) {
if (!sliced[i].first) ++skipped;
}
obj_->AddResponse(timestamp, skipped);
if ((size_t)skipped == sliced.size()) {
RunCallback(timestamp);
} for (size_t i = 0; i < sliced.size(); ++i) {
const auto& s = sliced[i];
if (!s.first) continue;
Message msg;
msg.meta.app_id = obj_->app_id();
msg.meta.customer_id = obj_->customer_id();
msg.meta.request = true;
msg.meta.push = push;
msg.meta.pull = pull;
msg.meta.head = cmd;
msg.meta.timestamp = timestamp;
msg.meta.recver = Postoffice::Get()->ServerRankToID(i);
msg.meta.priority = kvs.priority;
const auto& kvs = s.second;
if (kvs.keys.size()) {
msg.AddData(kvs.keys);
msg.AddData(kvs.vals);
if (kvs.lens.size()) {
msg.AddData(kvs.lens);
}
}
Postoffice::Get()->van()->Send(msg);
}
}

4.3.5 DefaultSlicer

切分函数可以由用户自行重写,默认为DefaultSlicer,每个SlicedKVPairs被包装成Message对象,然后用van::send()发送。

根据std::vector& ranges分片范围信息,将要发送的数据进行分片。目前默认的使用 Postoffice::GetServerKeyRanges来划分分片范围。

template <typename Val>
void KVWorker<Val>::DefaultSlicer(
const KVPairs<Val>& send, const std::vector<Range>& ranges,
typename KVWorker<Val>::SlicedKVs* sliced) {
sliced->resize(ranges.size()); // find the positions in msg.key
size_t n = ranges.size();
std::vector<size_t> pos(n+1);
const Key* begin = send.keys.begin();
const Key* end = send.keys.end();
for (size_t i = 0; i < n; ++i) {
if (i == 0) {
pos[0] = std::lower_bound(begin, end, ranges[0].begin()) - begin;
begin += pos[0];
} else {
CHECK_EQ(ranges[i-1].end(), ranges[i].begin());
}
size_t len = std::lower_bound(begin, end, ranges[i].end()) - begin;
begin += len;
pos[i+1] = pos[i] + len; // don't send it to servers for empty kv
sliced->at(i).first = (len != 0);
}
CHECK_EQ(pos[n], send.keys.size());
if (send.keys.empty()) return; // the length of value
size_t k = 0, val_begin = 0, val_end = 0;
if (send.lens.empty()) {
k = send.vals.size() / send.keys.size();
CHECK_EQ(k * send.keys.size(), send.vals.size());
} else {
CHECK_EQ(send.keys.size(), send.lens.size());
} // slice
for (size_t i = 0; i < n; ++i) {
if (pos[i+1] == pos[i]) {
sliced->at(i).first = false;
continue;
}
sliced->at(i).first = true;
auto& kv = sliced->at(i).second;
kv.keys = send.keys.segment(pos[i], pos[i+1]);
if (send.lens.size()) {
kv.lens = send.lens.segment(pos[i], pos[i+1]);
for (int l : kv.lens) val_end += l;
kv.vals = send.vals.segment(val_begin, val_end);
val_begin = val_end;
} else {
kv.vals = send.vals.segment(pos[i]*k, pos[i+1]*k);
}
}
}

4.3.6 PushPull & ZPushPull

就是把 push,pull 聚合在一起。

  int PushPull(const std::vector<Key>& keys,
const std::vector<Val>& vals,
std::vector<Val>* outs,
std::vector<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr,
int priority = 0) {
CHECK_NOTNULL(outs);
if (outs->empty())
outs->resize(vals.size());
else
CHECK_EQ(vals.size(), outs->size()); SArray<Key> skeys(keys);
SArray<Val> svals(vals);
auto souts = new SArray<Val>(outs->data(), outs->size());
SArray<int>* slens = lens ?
new SArray<int>(lens->data(), lens->size()) : nullptr;
int ts = ZPushPull(skeys, svals, souts, slens, cmd,
[this, cb, souts, slens]() {
delete souts;
delete slens;
if (cb) cb();
}, priority);
return ts;
} int ZPushPull(const SArray<Key>& keys,
const SArray<Val>& vals,
SArray<Val>* outs,
SArray<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr,
int priority = 0) {
int ts = AddPullCB(keys, outs, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
kvs.vals = vals;
kvs.priority = priority;
if (lens)
kvs.lens = *lens;
Send(ts, true, true, cmd, kvs);
re

4.3.7 Callback 相关

前面提到了一些回调函数的设置,下面我们看看如何使用。

4.3.7.1 设置

我们可以看到,针对每个时间戳,设置了一个回调函数,进而构成了一个回调函数列表。

每次发送请求之后,都会往这个列表中注册回调函数。

  using Callback = std::function<void()>;

  /** \brief callbacks for each timestamp */
std::unordered_map<int, Callback> callbacks_; // 回调函数列表 void AddCallback(int timestamp, const Callback& cb) {
if (!cb) return;
std::lock_guard<std::mutex> lk(mu_);
callbacks_[timestamp] = cb; // 添加回调函数
}
4.3.7.2 AddPullCB

这是 pull 之后,得到应答的回调函数,用于拷贝返回的数据。

但是,如果是多个 Server 都应该有返回,应该如何处理?无论是 push 还是 pull,只有在收到了所有的Response之后,才会将从各个server上拉取的value填入本地的vals里。

template <typename Val>
template <typename C, typename D>
int KVWorker<Val>::AddPullCB(
const SArray<Key>& keys, C* vals, D* lens, int cmd,
const Callback& cb) {
int ts = obj_->NewRequest(kServerGroup);
AddCallback(ts, [this, ts, keys, vals, lens, cb]() mutable {
mu_.lock();
auto& kvs = recv_kvs_[ts];
mu_.unlock(); // do check
size_t total_key = 0, total_val = 0;
for (const auto& s : kvs) { // 进行有效性验证
Range range = FindRange(keys, s.keys.front(), s.keys.back()+1);
CHECK_EQ(range.size(), s.keys.size())
<< "unmatched keys size from one server";
if (lens) CHECK_EQ(s.lens.size(), s.keys.size());
total_key += s.keys.size();
total_val += s.vals.size();
}
CHECK_EQ(total_key, keys.size()) << "lost some servers?"; // fill vals and lens
std::sort(kvs.begin(), kvs.end(), [](
const KVPairs<Val>& a, const KVPairs<Val>& b) {
return a.keys.front() < b.keys.front();
});
CHECK_NOTNULL(vals);
if (vals->empty()) {
vals->resize(total_val);
} else {
CHECK_EQ(vals->size(), total_val);
}
Val* p_vals = vals->data();
int *p_lens = nullptr;
if (lens) {
if (lens->empty()) {
lens->resize(keys.size());
} else {
CHECK_EQ(lens->size(), keys.size());
}
p_lens = lens->data();
}
for (const auto& s : kvs) { // 拷贝返回的数据
memcpy(p_vals, s.vals.data(), s.vals.size() * sizeof(Val));
p_vals += s.vals.size();
if (p_lens) {
memcpy(p_lens, s.lens.data(), s.lens.size() * sizeof(int));
p_lens += s.lens.size();
}
} mu_.lock();
recv_kvs_.erase(ts);
mu_.unlock();
if (cb) cb();
}); return ts;
}
4.3.7.3 运行

就是依据时间戳找到回调函数,运行,然后删除。

何时调用,就是在 Process 之中会调用,我们马上介绍。

template <typename Val>
void KVWorker<Val>::RunCallback(int timestamp) {
mu_.lock();
auto it = callbacks_.find(timestamp);
if (it != callbacks_.end()) {
mu_.unlock(); CHECK(it->second);
it->second(); mu_.lock();
callbacks_.erase(it);
}
mu_.unlock();
}

4.3.8 Process

如果是 Pull 的 response, 在每次收到的Response返回的values,会先保存recv_kvs_里,recv_kvs_[ts].push_back(kvs);

无论是 push 还是 pull,只有在收到了所有的Response之后,才会将从各个server上拉取的value填入本地的vals里。

template <typename Val>
void KVWorker<Val>::Process(const Message& msg) {
if (msg.meta.simple_app) {
SimpleApp::Process(msg); return;
}
// store the data for pulling
int ts = msg.meta.timestamp;
if (msg.meta.pull) {
CHECK_GE(msg.data.size(), (size_t)2);
KVPairs<Val> kvs;
kvs.keys = msg.data[0];
kvs.vals = msg.data[1];
if (msg.data.size() > (size_t)2) {
kvs.lens = msg.data[2];
}
mu_.lock();
recv_kvs_[ts].push_back(kvs);
mu_.unlock();
} // finished, run callbacks,只有在收到了所有的Response之后
if (obj_->NumResponse(ts) == Postoffice::Get()->num_servers() - 1) {
RunCallback(ts); // 在这里调用了 RunCallback。
}
}

0x05 总结

最后我们用一个消息传递流程做一下总结,看看各个部分在其中如何使用。总体流程图如下:

  1. worker节点 要发送消息,所以调用了 Send 方法。
  2. Send 方法会调用到了 Customer的 NewRequest,来建立一个新请求。
  3. Postoffice::start()会初始化节点信息,并且调用Van::start()
  4. Send方法会调用 Van 的 send 方法来进行网络交互。
  5. 经过网络传递之后,流程来到了 Server 处,对于 Server 来说,这是一个 Request,调用到了 Van 的 Receiving。Van::Receiving()接收后消息之后,根据不同命令执行不同动作。针对数据消息,如果需要下一步处理,会调用 ProcessDataMsg。
  6. 继续调用到 Van 的 ProcessDataMsg,然后调用 GetCustomer。
  7. GetCustomer 会调用到Postoffice,对于 customers_ 进行相应处理。
  8. Customer 会使用 Accept 来处理消息。
  9. Customer::Accept() 函数将消息添加到一个队列recv_queue_
  10. Customer 对象本身也会启动一个接受线程 recv_thread_,使用 Customer::Receiving() :
    1. 不断的从recv_queue_队列取消息。
    2. 如果 (!recv.meta.request) ,就说明是 response,则tracker_[req.timestamp].second++
    3. 调用注册的用户自定义的recv_handle_函数对消息进行处理。
  11. Van::Receiving() 调用注册的用户自定义的recv_handle_函数对消息进行处理。
  12. 对于Server来说,其注册的recv_handle_KVServer::Process()函数。
  13. Process 函数调用 request_handle_ 继续处理,生成 Response,返回给 Worker。
  14. Response 经过网络传递给 Worker。
  15. 运行回到了 Worker,来到了 Worker 的 Van。对于 worker 来说,这是一个 Request,调用到了 Van 的 Receiving。(以下操作序列和 Server 类似
  16. Van::Receiving()接收后消息之后,根据不同命令执行不同动作。针对数据消息,如果需要下一步处理,会调用 ProcessDataMsg。
  17. Customer 会使用 Accept 来处理消息。
  18. Customer::Accept() 函数将消息添加到一个队列recv_queue_
  19. 这里有个解耦合,由一个新线程 recv_thread_处理。
  20. Customer 对象本身已经启动一个新线程 recv_thread_,使用 Customer::Receiving() 。
  21. 对于Worker来说,其注册的recv_handle_KVWorker::Process()函数。
  22. 调用到KVWorker::Process()函数处理响应消息Response。
+---------------------+       +------------------------+   Worker   +  Server            +--------------------------+
| KVWorker | 1 | Van | 3 | | Van |
| Send +--------+---------------> send +-----------------+-----> Request +-----------> Receiving |
| | | | | | + |
| | | | Receiving <---------+ | 4 | | | +---------------------------+
| | | | + | | | | | | | Postoffice |
| Process | | | | 16 | | | | | 5 | | |
| ^ | | | v | | 15 | | v | GetCustomer | |
| | | | | ProcessDataMsg | | | | ProcessDataMsg <------------------> unordered_map customers_|
| | | | | + | | | | + | 6 | |
| | | | | | | | | | | | +---------------------------+
+---------------------+ | +------------------------+ | | +--------------------------+
| | | | | |
| |2 | 17 | | | 7
| | | | | |
| +---------------------------------------+ | | +------------------------------------+
| | Customer | | | | | | Customer | |
| | | v | | | | | |
| | v | | | | v |
| | NewRequest Accept | | | | Accept |
| | + | | | | + |
| | | 18 | | | | | |
| | | | | | | | 8 |
| | v | | | | v |
| | revc_queue_ | | | | recv_queue_ |
| | + | | | | + |
22 | | | 19 | | | | | 9 |
| | | | | | | | |
| | 20 v | | | | 10 v |
| | recv_thread_ +-------> Receving | | | | recv_thread_ +---> Receiving |
| | | | | | | + |
| | | 21 | | | | | 11 |
| | | | | | | | | +------------------+
| | v | | | | v | |KVServer |
+---------------------------+ recv_handle | | | | recv_handle_+------------------> | |
| | | | | | 12 | Process |
+---------------------------------------+ | | +------------------------------------+ | + |
| | +------------------+
| | |
| | | 13
| | v
| | +-----------+-------+
| | | request_handle_ |
| | 14 | |
+<-----------+ Response <----------------------------------------------------+ Response |
| | |
| +-------------------+
+

手机如下:

[源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现

0xEE 个人信息

★★★★★★关于生活和技术的思考★★★★★★

微信公众账号:罗西的思考

如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。

[源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现

0xFF 参考

史上最全面的ps-lite理解

从零实现机器学习参数服务器框架(二)