基于 POCO 框架的 TCP 连接分流程序
发表于2013/11/25 19:18:27 11819人阅读
分类: Development
介绍
下面的程序实现了对 TCP 连接的分流,即将一个 TCP 连接的流量分布到多个 TCP 连接上进行传输。
本程序的主要作用是在特定网络环境下提升通过 TCP 连接的 OpenVPN 服务的速率,使之充分利用带宽。
程序的主要复杂之处在于单生产者—多消费者、多生产者—单消费者两种同步方式的实现。
代码
// multisocks, copyright (c) 2013 coolypf #include <stdio.h> #include <stdlib.h> #include <string.h> #include <string> #include <map> #include <vector> #include <Poco/Thread.h> #include <Poco/Event.h> #include <Poco/ErrorHandler.h> #include <Poco/FIFOBuffer.h> #include <Poco/Net/StreamSocket.h> #include <Poco/Net/ServerSocket.h> using namespace std; using namespace Poco; using namespace Poco::Net; class error_handler : public ErrorHandler { public: virtual void exception(const Exception& exc) { printf("[!] Thread %u: Caught exception: %s\n", (unsigned int)Thread::currentTid(), exc.displayText().c_str()); fflush(stdout); exit(1); } virtual void exception(const std::exception& exc) { printf("[!] Thread %u: Caught exception: %s\n", (unsigned int)Thread::currentTid(), exc.what()); fflush(stdout); exit(2); } virtual void exception() { printf("[!] Thread %u: Caught unknown exception\n", (unsigned int)Thread::currentTid()); fflush(stdout); exit(3); } }; class { map<string, string> config; string empty; int read_file(const char *filename, vector<char> &out) { FILE *fp = fopen(filename, "rb"); if (!fp) { printf("[!] Fail to load: %s\n", filename); return 1; } char buf[4096]; size_t sz; while ((sz = fread(buf, 1, 4096, fp))) out.insert(out.end(), buf, buf + sz); fclose(fp); for (size_t i = 0; i < out.size(); ++i) if (!out[i]) out[i] = ' '; return 0; } void split_lines(const vector<char> &in, vector<string> &out) { size_t pos = 0; for (size_t i = 0; i < in.size(); ++i) { if (in[i] != '\r' && in[i] != '\n') continue; out.push_back(string(in.begin() + pos, in.begin() + i)); if (in[i] == '\r' && i + 1 < in.size() && in[i + 1] == '\n') i++; pos = i + 1; } if (pos < in.size()) out.push_back(string(in.begin() + pos, in.end())); } void trim_string(string &str) { const char *spaces = " \r\n\t\xb\xc"; int first = 0, last = (int)str.size() - 1; while (first <= last && strchr(spaces, str[first])) ++first; while (last >= first && strchr(spaces, str[last])) --last; str.resize(last + 1); str.erase(0, first); } void split_string(const string &in, vector<string> &out, const char *delims) { size_t pos = 0; for (size_t i = 0; i < in.size(); ++i) { if (!strchr(delims, in[i])) continue; string part(in.begin() + pos, in.begin() + i); trim_string(part); out.push_back(part); pos = i + 1; } { string part(in.begin() + pos, in.end()); trim_string(part); out.push_back(part); } } public: int load(const char *filename) { vector<char> content; if (read_file(filename, content)) return 1; vector<string> lines; split_lines(content, lines); for (size_t i = 0; i < lines.size(); ++i) { if (lines[i].empty() || lines[i][0] == '#') continue; vector<string> parts; split_string(lines[i], parts, "="); if (parts.size() != 2 || parts[0].empty()) { printf("[!] Invalid config: %s (%d)\n", filename, (int)i + 1); continue; } if (config.find(parts[0]) != config.end()) printf("[!] Override config: %s (%d)\n", filename, (int)i + 1); config[parts[0]] = parts[1]; } return 0; } int i(const char *key) const { map<string, string>::const_iterator iter = config.find(string(key)); if (iter == config.end()) { printf("[!] No config: %s\n", key); return 0; } int ret = 0; if (sscanf(iter->second.c_str(), "%d", &ret) != 1) printf("[!] Invalid int: %s\n", key); return ret; } const string & s(const char *key) const { map<string, string>::const_iterator iter = config.find(string(key)); if (iter == config.end()) { printf("[!] No config: %s\n", key); return empty; } return iter->second; } } config; // Read from out buffer and send to remote class reader : public Runnable { int index; StreamSocket socket; Event &readable, &writable; FIFOBuffer &buffer; bool &self_stopped, &peer_stopped; public: reader(int i, StreamSocket s, Event &r, Event &w, FIFOBuffer &b, bool &self, bool &peer) : index(i), socket(s), readable(r), writable(w), buffer(b), self_stopped(self), peer_stopped(peer) { } virtual void run() { printf("[.] Reader %d, tid = %u\n", index, (unsigned int)Thread::currentTid()); fflush(stdout); char buf[8192]; bool closed = false; while (true) { if (buffer.isReadable()) { int sz = (int)buffer.read(buf, 8192); writable.set(); int sent = 0; while (sent < sz) { int sz2; try { sz2 = socket.sendBytes(buf + sent, sz - sent); } catch (Exception &exc) { printf("[!] Reader %d: %s\n", index, exc.displayText().c_str()); fflush(stdout); sz2 = 0; } if (sz2 <= 0) { closed = true; break; } sent += sz2; } if (closed) break; } else { if (peer_stopped) break; readable.wait(); } } try { socket.shutdownSend(); } catch (...) {} self_stopped = true; writable.set(); } }; // Receive from remote and write to in buffer class writer : public Runnable { int index; StreamSocket socket; Event &readable, &writable; FIFOBuffer &buffer; bool &self_stopped, &peer_stopped; public: writer(int i, StreamSocket s, Event &r, Event &w, FIFOBuffer &b, bool &self, bool &peer) : index(i), socket(s), readable(r), writable(w), buffer(b), self_stopped(self), peer_stopped(peer) { } virtual void run() { printf("[.] Writer %d, tid = %u\n", index, (unsigned int)Thread::currentTid()); fflush(stdout); char buf[8192]; while (true) { int sz; try { sz = socket.receiveBytes(buf, 8192); } catch (Exception &exc) { printf("[!] Writer %d: %s\n", index, exc.displayText().c_str()); fflush(stdout); sz = 0; } if (sz <= 0) break; int written = 0; while (written < sz) { if (buffer.isWritable()) { written += buffer.write(buf + written, sz - written); readable.set(); } else { if (peer_stopped) break; writable.wait(); } } if (peer_stopped) break; } try { socket.shutdownReceive(); } catch (...) {} self_stopped = true; readable.set(); } }; // Receive from local and write to out buffers class divider : public Runnable { int nr_conn; StreamSocket socket; vector<Event *> &readables; Event &writable; vector<FIFOBuffer *> buffers; bool &self_stopped, *peers_stopped; public: divider(int n, StreamSocket s, vector<Event *> &vr, Event &w, vector<FIFOBuffer *> &b, bool &self, bool *peers) : nr_conn(n), socket(s), readables(vr), writable(w), buffers(b), self_stopped(self), peers_stopped(peers) { } virtual void run() { printf("[.] Divider, tid = %u\n", (unsigned int)Thread::currentTid()); fflush(stdout); int sz0 = 8192 * nr_conn; char *buf0 = new char[sz0]; char **bufv = new char *[nr_conn]; for (int i = 0; i < nr_conn; ++i) bufv[i] = new char[8192]; int *szv = new int[nr_conn]; int *writtenv = new int[nr_conn]; long long total = 0; bool peer_stopped = false; while (true) { int sz; try { sz = socket.receiveBytes(buf0, sz0); } catch (Exception &exc) { printf("[!] Divider: %s\n", exc.displayText().c_str()); fflush(stdout); sz = 0; } if (sz <= 0) break; for (int i = 0; i < nr_conn; ++i) { szv[i] = 0; writtenv[i] = 0; } for (int i = 0; i < sz; ++i) bufv[(total + i) % nr_conn][szv[(total + i) % nr_conn]++] = buf0[i]; total += sz; while (true) { bool done = true; for (int i = 0; i < nr_conn; ++i) if (writtenv[i] < szv[i]) done = false; if (done) break; bool written = false; for (int i = 0; i < nr_conn; ++i) { if (buffers[i]->isWritable()) { writtenv[i] += buffers[i]->write(bufv[i] + writtenv[i], szv[i] - writtenv[i]); readables[i]->set(); written = true; } } if (!written) { for (int i = 0; i < nr_conn; ++i) if (peers_stopped[i]) peer_stopped = true; if (peer_stopped) break; writable.wait(); writable.reset(); } if (peer_stopped) break; } } try { socket.shutdownReceive(); } catch (...) {} self_stopped = true; for (int i = 0; i < nr_conn; ++i) readables[i]->set(); } }; // Read from in buffers and send to local class combiner : public Runnable { int nr_conn; StreamSocket socket; Event &readable; vector<Event *> &writables; vector<FIFOBuffer *> buffers; bool &self_stopped, *peers_stopped; public: combiner(int n, StreamSocket s, Event &r, vector<Event *> &vw, vector<FIFOBuffer *> &b, bool &self, bool *peers) : nr_conn(n), socket(s), readable(r), writables(vw), buffers(b), self_stopped(self), peers_stopped(peers) { } virtual void run() { printf("[.] Combiner, tid = %u\n", (unsigned int)Thread::currentTid()); fflush(stdout); char *buf0 = new char[8192 * nr_conn]; char **bufv = new char *[nr_conn]; for (int i = 0; i < nr_conn; ++i) bufv[i] = new char[8192]; int *szv = new int[nr_conn]; int *readv = new int[nr_conn]; long long total = 0; bool peer_stopped = false, closed = false; while (true) { for (int i = 0; i < nr_conn; ++i) { szv[i] = buffers[i]->peek(bufv[i], 8192); readv[i] = 0; } int sz = 0; while (true) { int i = (total + sz) % nr_conn; if (readv[i] >= szv[i]) break; buf0[sz++] = bufv[i][readv[i]++]; } total += sz; if (sz > 0) { for (int i = 0; i < nr_conn; ++i) { if (readv[i] > 0) { buffers[i]->read(bufv[i], readv[i]); writables[i]->set(); } } int sent = 0; while (sent < sz) { int sz1; try { sz1 = socket.sendBytes(buf0 + sent, sz - sent); } catch (Exception &exc) { printf("[!] Combiner: %s\n", exc.displayText().c_str()); fflush(stdout); sz1 = 0; } if (sz1 <= 0) { closed = true; break; } sent += sz1; } if (closed) break; } else { for (int i = 0; i < nr_conn; ++i) if (peers_stopped[i]) peer_stopped = true; if (peer_stopped) break; readable.wait(); readable.reset(); } } try { socket.shutdownSend(); } catch (...) {} self_stopped = true; for (int i = 0; i < nr_conn; ++i) writables[i]->set(); } }; int main(int argc, char **argv) { config.load("multisocks.txt"); for (int i = 1; i < argc; ++i) if (config.load(argv[i])) return 1; if (!config.s("log").empty()) freopen(config.s("log").c_str(), "w", stdout); ErrorHandler::set(new error_handler); fflush(stdout); try { int nr_conn = config.i("nr_conn"); bool divider_stopped = false, combiner_stopped = false; bool *readers_stopped = new bool[nr_conn](), *writers_stopped = new bool[nr_conn](); Event in_readable(false), out_writable(false); in_readable.reset(); out_writable.set(); vector<Event *> in_writables, out_readables; for (int i = 0; i < nr_conn; ++i) { in_writables.push_back(new Event); in_writables.back()->set(); out_readables.push_back(new Event); out_readables.back()->reset(); } vector<StreamSocket> remotes; if (config.s("remote") == "listen") { ServerSocket server; if (config.s("remote.protocol") == "ipv6") server.bind6(config.i("remote.port"), true, true); else server.bind(config.i("remote.port"), true); server.listen(); for (int i = 0; i < nr_conn; ++i) { SocketAddress client_addr; remotes.push_back(server.acceptConnection(client_addr)); remotes.back().setNoDelay(true); printf("[.] Remote incoming: %s\n", client_addr.toString().c_str()); fflush(stdout); if (client_addr.host() != remotes.front().peerAddress().host()) throw Exception(string("client addresses mismatch")); } server.close(); } else { SocketAddress server_addr(config.s("remote.host"), config.i("remote.port")); for (int i = 0; i < nr_conn; ++i) { StreamSocket remote; if (!config.s("remote.bind").empty()) remote.impl()->bind(SocketAddress(config.s("remote.bind"), 0), true); remote.connect(server_addr); remote.setNoDelay(true); remotes.push_back(remote); printf("[.] Connected to remote from: %s\n", remote.address().toString().c_str()); fflush(stdout); } } vector<FIFOBuffer *> in_buffers, out_buffers; vector<reader *> readers; vector<writer *> writers; Thread *reader_threads = new Thread[nr_conn], *writer_threads = new Thread[nr_conn]; for (int i = 0; i < nr_conn; ++i) { in_buffers.push_back(new FIFOBuffer(config.i("buffer_size"))); out_buffers.push_back(new FIFOBuffer(config.i("buffer_size"))); readers.push_back(new reader(i, remotes[i], *out_readables[i], out_writable, *out_buffers.back(), readers_stopped[i], divider_stopped)); writers.push_back(new writer(i, remotes[i], in_readable, *in_writables[i], *in_buffers.back(), writers_stopped[i], combiner_stopped)); reader_threads[i].setStackSize(65536); writer_threads[i].setStackSize(65536); reader_threads[i].start(*readers.back()); writer_threads[i].start(*writers.back()); } StreamSocket local; if (config.s("local") == "listen") { ServerSocket server; if (config.s("local.protocol") == "ipv6") server.bind6(config.i("local.port"), true, true); else server.bind(config.i("local.port"), true); server.listen(); local = server.acceptConnection(); printf("[.] Local incoming: %s\n", local.peerAddress().toString().c_str()); fflush(stdout); local.setNoDelay(true); server.close(); } else { SocketAddress server_addr(config.s("local.host"), config.i("local.port")); local.connect(server_addr); local.setNoDelay(true); printf("[.] Connected to local from: %s\n", local.address().toString().c_str()); fflush(stdout); } divider *pdivider = new divider(nr_conn, local, out_readables, out_writable, out_buffers, divider_stopped, readers_stopped); combiner *pcombiner = new combiner(nr_conn, local, in_readable, in_writables, in_buffers, combiner_stopped, writers_stopped); Thread divider_thread, combiner_thread; divider_thread.setStackSize(65536); combiner_thread.setStackSize(65536); divider_thread.start(*pdivider); combiner_thread.start(*pcombiner); printf("[.] Connection established\n"); fflush(stdout); divider_thread.join(); combiner_thread.join(); local.close(); for (int i = 0; i < nr_conn; ++i) { reader_threads[i].join(); writer_threads[i].join(); remotes[i].close(); } printf("[.] Connection closed\n"); } catch (Exception &exc) { ErrorHandler::handle(exc); } catch (exception &exc) { ErrorHandler::handle(exc); } catch (...) { ErrorHandler::handle(); } return 0; }
测试
客户端网络环境为 CERNET ,服务器是位于达拉斯的 BurstNET VPS ,连接方式是 TCP/IPv6 。
直接连接, OpenVPN 速率为 341 KB/s 。使用 5 个连接进行分流, OpenVPN 速率可达 1.4 MB/s 。