A Fork of Stuffs

It is simple enough thus nothing more to describe.

基于 POCO 框架的 TCP 连接分流程序

介绍


下面的程序实现了对 TCP 连接的分流,即将一个 TCP 连接的流量分布到多个 TCP 连接上进行传输。

本程序的主要作用是在特定网络环境下提升通过 TCP 连接的 OpenVPN 服务的速率,使之充分利用带宽。

程序的主要复杂之处在于单生产者—多消费者、多生产者—单消费者两种同步方式的实现。


代码


// multisocks, copyright (c) 2013 coolypf

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <map>
#include <vector>
#include <Poco/Thread.h>
#include <Poco/Event.h>
#include <Poco/ErrorHandler.h>
#include <Poco/FIFOBuffer.h>
#include <Poco/Net/StreamSocket.h>
#include <Poco/Net/ServerSocket.h>

using namespace std;
using namespace Poco;
using namespace Poco::Net;

class error_handler : public ErrorHandler
{
public:
    virtual void exception(const Exception& exc)
    {
        printf("[!] Thread %u: Caught exception: %s\n", (unsigned int)Thread::currentTid(), exc.displayText().c_str());
        fflush(stdout);
        exit(1);
    }

    virtual void exception(const std::exception& exc)
    {
        printf("[!] Thread %u: Caught exception: %s\n", (unsigned int)Thread::currentTid(), exc.what());
        fflush(stdout);
        exit(2);
    }

    virtual void exception()
    {
        printf("[!] Thread %u: Caught unknown exception\n", (unsigned int)Thread::currentTid());
        fflush(stdout);
        exit(3);
    }
};

class
{
    map<string, string> config;
    string empty;

    int read_file(const char *filename, vector<char> &out)
    {
        FILE *fp = fopen(filename, "rb");
        if (!fp)
        {
            printf("[!] Fail to load: %s\n", filename);
            return 1;
        }
        char buf[4096];
        size_t sz;
        while ((sz = fread(buf, 1, 4096, fp)))
            out.insert(out.end(), buf, buf + sz);
        fclose(fp);
        for (size_t i = 0; i < out.size(); ++i)
            if (!out[i])
                out[i] = ' ';
        return 0;
    }

    void split_lines(const vector<char> &in, vector<string> &out)
    {
        size_t pos = 0;
        for (size_t i = 0; i < in.size(); ++i)
        {
            if (in[i] != '\r' && in[i] != '\n')
                continue;
            out.push_back(string(in.begin() + pos, in.begin() + i));
            if (in[i] == '\r' && i + 1 < in.size() && in[i + 1] == '\n')
                i++;
            pos = i + 1;
        }
        if (pos < in.size())
            out.push_back(string(in.begin() + pos, in.end()));
    }

    void trim_string(string &str)
    {
        const char *spaces = " \r\n\t\xb\xc";
        int first = 0, last = (int)str.size() - 1;
        while (first <= last && strchr(spaces, str[first]))
            ++first;
        while (last >= first && strchr(spaces, str[last]))
            --last;
        str.resize(last + 1);
        str.erase(0, first);
    }

    void split_string(const string &in, vector<string> &out, const char *delims)
    {
        size_t pos = 0;
        for (size_t i = 0; i < in.size(); ++i)
        {
            if (!strchr(delims, in[i]))
                continue;
            string part(in.begin() + pos, in.begin() + i);
            trim_string(part);
            out.push_back(part);
            pos = i + 1;
        }
        {
            string part(in.begin() + pos, in.end());
            trim_string(part);
            out.push_back(part);
        }
    }

public:
    int load(const char *filename)
    {
        vector<char> content;
        if (read_file(filename, content))
            return 1;
        vector<string> lines;
        split_lines(content, lines);
        for (size_t i = 0; i < lines.size(); ++i)
        {
            if (lines[i].empty() || lines[i][0] == '#')
                continue;
            vector<string> parts;
            split_string(lines[i], parts, "=");
            if (parts.size() != 2 || parts[0].empty())
            {
                printf("[!] Invalid config: %s (%d)\n", filename, (int)i + 1);
                continue;
            }
            if (config.find(parts[0]) != config.end())
                printf("[!] Override config: %s (%d)\n", filename, (int)i + 1);
            config[parts[0]] = parts[1];
        }
        return 0;
    }

    int i(const char *key) const
    {
        map<string, string>::const_iterator iter = config.find(string(key));
        if (iter == config.end())
        {
            printf("[!] No config: %s\n", key);
            return 0;
        }
        int ret = 0;
        if (sscanf(iter->second.c_str(), "%d", &ret) != 1)
            printf("[!] Invalid int: %s\n", key);
        return ret;
    }

    const string & s(const char *key) const
    {
        map<string, string>::const_iterator iter = config.find(string(key));
        if (iter == config.end())
        {
            printf("[!] No config: %s\n", key);
            return empty;
        }
        return iter->second;
    }
} config;

// Read from out buffer and send to remote
class reader : public Runnable
{
    int index;
    StreamSocket socket;
    Event &readable, &writable;
    FIFOBuffer &buffer;
    bool &self_stopped, &peer_stopped;
public:
    reader(int i, StreamSocket s, Event &r, Event &w, FIFOBuffer &b, bool &self, bool &peer)
        : index(i), socket(s), readable(r), writable(w), buffer(b), self_stopped(self), peer_stopped(peer)
    {
    }

    virtual void run()
    {
        printf("[.] Reader %d, tid = %u\n", index, (unsigned int)Thread::currentTid());
        fflush(stdout);
        char buf[8192];
        bool closed = false;
        while (true)
        {
            if (buffer.isReadable())
            {
                int sz = (int)buffer.read(buf, 8192);
                writable.set();
                int sent = 0;
                while (sent < sz)
                {
                    int sz2;
                    try
                    {
                        sz2 = socket.sendBytes(buf + sent, sz - sent);
                    }
                    catch (Exception &exc)
                    {
                        printf("[!] Reader %d: %s\n", index, exc.displayText().c_str());
                        fflush(stdout);
                        sz2 = 0;
                    }
                    if (sz2 <= 0)
                    {
                        closed = true;
                        break;
                    }
                    sent += sz2;
                }
                if (closed)
                    break;
            }
            else
            {
                if (peer_stopped)
                    break;
                readable.wait();
            }
        }
        try { socket.shutdownSend(); } catch (...) {}
        self_stopped = true;
        writable.set();
    }
};

// Receive from remote and write to in buffer
class writer : public Runnable
{
    int index;
    StreamSocket socket;
    Event &readable, &writable;
    FIFOBuffer &buffer;
    bool &self_stopped, &peer_stopped;
public:
    writer(int i, StreamSocket s, Event &r, Event &w, FIFOBuffer &b, bool &self, bool &peer)
        : index(i), socket(s), readable(r), writable(w), buffer(b), self_stopped(self), peer_stopped(peer)
    {
    }

    virtual void run()
    {
        printf("[.] Writer %d, tid = %u\n", index, (unsigned int)Thread::currentTid());
        fflush(stdout);
        char buf[8192];
        while (true)
        {
            int sz;
            try
            {
                sz = socket.receiveBytes(buf, 8192);
            }
            catch (Exception &exc)
            {
                printf("[!] Writer %d: %s\n", index, exc.displayText().c_str());
                fflush(stdout);
                sz = 0;
            }
            if (sz <= 0)
                break;
            int written = 0;
            while (written < sz)
            {
                if (buffer.isWritable())
                {
                    written += buffer.write(buf + written, sz - written);
                    readable.set();
                }
                else
                {
                    if (peer_stopped)
                        break;
                    writable.wait();
                }
            }
            if (peer_stopped)
                break;
        }
        try { socket.shutdownReceive(); } catch (...) {}
        self_stopped = true;
        readable.set();
    }
};

// Receive from local and write to out buffers
class divider : public Runnable
{
    int nr_conn;
    StreamSocket socket;
    vector<Event *> &readables;
    Event &writable;
    vector<FIFOBuffer *> buffers;
    bool &self_stopped, *peers_stopped;
public:
    divider(int n, StreamSocket s, vector<Event *> &vr, Event &w, vector<FIFOBuffer *> &b, bool &self, bool *peers)
        : nr_conn(n), socket(s), readables(vr), writable(w), buffers(b), self_stopped(self), peers_stopped(peers)
    {
    }

    virtual void run()
    {
        printf("[.] Divider, tid = %u\n", (unsigned int)Thread::currentTid());
        fflush(stdout);
        int sz0 = 8192 * nr_conn;
        char *buf0 = new char[sz0];
        char **bufv = new char *[nr_conn];
        for (int i = 0; i < nr_conn; ++i)
            bufv[i] = new char[8192];
        int *szv = new int[nr_conn];
        int *writtenv = new int[nr_conn];
        long long total = 0;
        bool peer_stopped = false;
        while (true)
        {
            int sz;
            try
            {
                sz = socket.receiveBytes(buf0, sz0);
            }
            catch (Exception &exc)
            {
                printf("[!] Divider: %s\n", exc.displayText().c_str());
                fflush(stdout);
                sz = 0;
            }
            if (sz <= 0)
                break;
            for (int i = 0; i < nr_conn; ++i)
            {
                szv[i] = 0;
                writtenv[i] = 0;
            }
            for (int i = 0; i < sz; ++i)
                bufv[(total + i) % nr_conn][szv[(total + i) % nr_conn]++] = buf0[i];
            total += sz;
            while (true)
            {
                bool done = true;
                for (int i = 0; i < nr_conn; ++i)
                    if (writtenv[i] < szv[i])
                        done = false;
                if (done)
                    break;
                bool written = false;
                for (int i = 0; i < nr_conn; ++i)
                {
                    if (buffers[i]->isWritable())
                    {
                        writtenv[i] += buffers[i]->write(bufv[i] + writtenv[i], szv[i] - writtenv[i]);
                        readables[i]->set();
                        written = true;
                    }
                }
                if (!written)
                {
                    for (int i = 0; i < nr_conn; ++i)
                        if (peers_stopped[i])
                            peer_stopped = true;
                    if (peer_stopped)
                        break;
                    writable.wait();
                    writable.reset();
                }
                if (peer_stopped)
                    break;
            }
        }
        try { socket.shutdownReceive(); } catch (...) {}
        self_stopped = true;
        for (int i = 0; i < nr_conn; ++i)
            readables[i]->set();
    }
};

// Read from in buffers and send to local
class combiner : public Runnable
{
    int nr_conn;
    StreamSocket socket;
    Event &readable;
    vector<Event *> &writables;
    vector<FIFOBuffer *> buffers;
    bool &self_stopped, *peers_stopped;
public:
    combiner(int n, StreamSocket s, Event &r, vector<Event *> &vw, vector<FIFOBuffer *> &b, bool &self, bool *peers)
        : nr_conn(n), socket(s), readable(r), writables(vw), buffers(b), self_stopped(self), peers_stopped(peers)
    {
    }

    virtual void run()
    {
        printf("[.] Combiner, tid = %u\n", (unsigned int)Thread::currentTid());
        fflush(stdout);
        char *buf0 = new char[8192 * nr_conn];
        char **bufv = new char *[nr_conn];
        for (int i = 0; i < nr_conn; ++i)
            bufv[i] = new char[8192];
        int *szv = new int[nr_conn];
        int *readv = new int[nr_conn];
        long long total = 0;
        bool peer_stopped = false, closed = false;
        while (true)
        {
            for (int i = 0; i < nr_conn; ++i)
            {
                szv[i] = buffers[i]->peek(bufv[i], 8192);
                readv[i] = 0;
            }
            int sz = 0;
            while (true)
            {
                int i = (total + sz) % nr_conn;
                if (readv[i] >= szv[i])
                    break;
                buf0[sz++] = bufv[i][readv[i]++];
            }
            total += sz;
            if (sz > 0)
            {
                for (int i = 0; i < nr_conn; ++i)
                {
                    if (readv[i] > 0)
                    {
                        buffers[i]->read(bufv[i], readv[i]);
                        writables[i]->set();
                    }
                }
                int sent = 0;
                while (sent < sz)
                {
                    int sz1;
                    try
                    {
                        sz1 = socket.sendBytes(buf0 + sent, sz - sent);
                    }
                    catch (Exception &exc)
                    {
                        printf("[!] Combiner: %s\n", exc.displayText().c_str());
                        fflush(stdout);
                        sz1 = 0;
                    }
                    if (sz1 <= 0)
                    {
                        closed = true;
                        break;
                    }
                    sent += sz1;
                }
                if (closed)
                    break;
            }
            else
            {
                for (int i = 0; i < nr_conn; ++i)
                    if (peers_stopped[i])
                        peer_stopped = true;
                if (peer_stopped)
                    break;
                readable.wait();
                readable.reset();
            }
        }
        try { socket.shutdownSend(); } catch (...) {}
        self_stopped = true;
        for (int i = 0; i < nr_conn; ++i)
            writables[i]->set();
    }
};

int main(int argc, char **argv)
{
    config.load("multisocks.txt");
    for (int i = 1; i < argc; ++i)
        if (config.load(argv[i]))
            return 1;
    if (!config.s("log").empty())
        freopen(config.s("log").c_str(), "w", stdout);
    ErrorHandler::set(new error_handler);
    fflush(stdout);
    try
    {
        int nr_conn = config.i("nr_conn");
        bool divider_stopped = false, combiner_stopped = false;
        bool *readers_stopped = new bool[nr_conn](), *writers_stopped = new bool[nr_conn]();
        Event in_readable(false), out_writable(false);
        in_readable.reset();
        out_writable.set();
        vector<Event *> in_writables, out_readables;
        for (int i = 0; i < nr_conn; ++i)
        {
            in_writables.push_back(new Event);
            in_writables.back()->set();
            out_readables.push_back(new Event);
            out_readables.back()->reset();
        }
        vector<StreamSocket> remotes;
        if (config.s("remote") == "listen")
        {
            ServerSocket server;
            if (config.s("remote.protocol") == "ipv6")
                server.bind6(config.i("remote.port"), true, true);
            else
                server.bind(config.i("remote.port"), true);
            server.listen();
            for (int i = 0; i < nr_conn; ++i)
            {
                SocketAddress client_addr;
                remotes.push_back(server.acceptConnection(client_addr));
                remotes.back().setNoDelay(true);
                printf("[.] Remote incoming: %s\n", client_addr.toString().c_str());
                fflush(stdout);
                if (client_addr.host() != remotes.front().peerAddress().host())
                    throw Exception(string("client addresses mismatch"));
            }
            server.close();
        }
        else
        {
            SocketAddress server_addr(config.s("remote.host"), config.i("remote.port"));
            for (int i = 0; i < nr_conn; ++i)
            {
                StreamSocket remote;
                if (!config.s("remote.bind").empty())
                    remote.impl()->bind(SocketAddress(config.s("remote.bind"), 0), true);
                remote.connect(server_addr);
                remote.setNoDelay(true);
                remotes.push_back(remote);
                printf("[.] Connected to remote from: %s\n", remote.address().toString().c_str());
                fflush(stdout);
            }
        }
        vector<FIFOBuffer *> in_buffers, out_buffers;
        vector<reader *> readers;
        vector<writer *> writers;
        Thread *reader_threads = new Thread[nr_conn], *writer_threads = new Thread[nr_conn];
        for (int i = 0; i < nr_conn; ++i)
        {
            in_buffers.push_back(new FIFOBuffer(config.i("buffer_size")));
            out_buffers.push_back(new FIFOBuffer(config.i("buffer_size")));
            readers.push_back(new reader(i, remotes[i], *out_readables[i], out_writable, *out_buffers.back(), readers_stopped[i], divider_stopped));
            writers.push_back(new writer(i, remotes[i], in_readable, *in_writables[i], *in_buffers.back(), writers_stopped[i], combiner_stopped));
            reader_threads[i].setStackSize(65536);
            writer_threads[i].setStackSize(65536);
            reader_threads[i].start(*readers.back());
            writer_threads[i].start(*writers.back());
        }
        StreamSocket local;
        if (config.s("local") == "listen")
        {
            ServerSocket server;
            if (config.s("local.protocol") == "ipv6")
                server.bind6(config.i("local.port"), true, true);
            else
                server.bind(config.i("local.port"), true);
            server.listen();
            local = server.acceptConnection();
            printf("[.] Local incoming: %s\n", local.peerAddress().toString().c_str());
            fflush(stdout);
            local.setNoDelay(true);
            server.close();
        }
        else
        {
            SocketAddress server_addr(config.s("local.host"), config.i("local.port"));
            local.connect(server_addr);
            local.setNoDelay(true);
            printf("[.] Connected to local from: %s\n", local.address().toString().c_str());
            fflush(stdout);
        }
        divider *pdivider = new divider(nr_conn, local, out_readables, out_writable, out_buffers, divider_stopped, readers_stopped);
        combiner *pcombiner = new combiner(nr_conn, local, in_readable, in_writables, in_buffers, combiner_stopped, writers_stopped);
        Thread divider_thread, combiner_thread;
        divider_thread.setStackSize(65536);
        combiner_thread.setStackSize(65536);
        divider_thread.start(*pdivider);
        combiner_thread.start(*pcombiner);
        printf("[.] Connection established\n");
        fflush(stdout);
        divider_thread.join();
        combiner_thread.join();
        local.close();
        for (int i = 0; i < nr_conn; ++i)
        {
            reader_threads[i].join();
            writer_threads[i].join();
            remotes[i].close();
        }
        printf("[.] Connection closed\n");
    }
    catch (Exception &exc)
    {
        ErrorHandler::handle(exc);
    }
    catch (exception &exc)
    {
        ErrorHandler::handle(exc);
    }
    catch (...)
    {
        ErrorHandler::handle();
    }
    return 0;
}

测试


客户端网络环境为 CERNET ,服务器是位于达拉斯的 BurstNET VPS ,连接方式是 TCP/IPv6 。

直接连接, OpenVPN 速率为 341 KB/s 。使用 5 个连接进行分流, OpenVPN 速率可达 1.4 MB/s 。

阅读更多
个人分类: Development
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

加入CSDN,享受更精准的内容推荐,与500万程序员共同成长!
关闭
关闭