#pragma once
#include <mutex>
#include <condition_variable>
#include <deque>
#include <map>
#include <algorithm>
#include <memory>
#include <boost/thread/tss.hpp>
/// Bounded, blocking multi-producer/multi-consumer queue with per-key consumer
/// affinity.
///
/// Elements are (key, value) pairs. When a consumer thread pops an element with
/// key K, K becomes associated with that thread's private sub-queue. While the
/// association lasts, other consumers that find K at the front of the shared
/// queue move that element into the owner's sub-queue instead of returning it.
/// The owner drains its sub-queue first on later pop() calls and releases K
/// once it finds the sub-queue empty. So elements with the same key are handed
/// out in push order, and K is not given to a second thread until its current
/// owner calls pop() again with an empty sub-queue.
///
/// Capacity counts every element held, in the shared queue or in any
/// sub-queue. All state is guarded by a single mutex.
///
/// NOTE(review): sub-queues are owned by boost::thread_specific_ptr storage.
/// If a consumer thread's TSS value is destroyed (e.g. on thread exit) while
/// its sub-queue still holds elements, those elements are dropped without
/// decrementing m_count, permanently reducing usable capacity. Confirm whether
/// consumer threads can exit while the queue is in use.
template <typename Kty, typename Vty>
class blocking_key_queue
{
public:
    /// Thrown by push()/pop() when the queue is closed (before or while waiting).
    struct closed_exception{};

    /// @param capacity maximum number of elements held at once. A capacity of 0
    ///        makes push() block until close() is called.
    explicit blocking_key_queue(size_t capacity)
        : m_capacity{ capacity }
        , m_count{ 0 }
        , m_opened{ true }
    {}

    /// Re-opens the queue after close().
    void open()
    {
        std::lock_guard<std::mutex> _(m_mtx);
        m_opened = true;
    }

    /// Closes the queue.
    ///
    /// Wakes every blocked producer and consumer (they throw closed_exception)
    /// and discards all pending elements. This includes elements already routed
    /// into a consumer thread's private sub-queue.
    void close()
    {
        std::lock_guard<std::mutex> _(m_mtx);
        m_opened = false;
        m_not_empty.notify_all();
        m_not_full.notify_all();
        m_queue.clear();
        // Every non-empty live sub-queue is still registered in m_sub_queues
        // (elements are only routed to registered sub-queues, and an owner
        // unregisters only when its sub-queue is empty). Clearing them here
        // keeps m_count consistent with what is actually held. Otherwise, after
        // open(), pop() would drain stale elements and underflow m_count.
        for (auto& entry : m_sub_queues)
        {
            if (auto sp = entry.second.lock())
            {
                sp->clear();
            }
        }
        m_sub_queues.clear();
        m_count = 0;
    }

    /// Appends (key, value) to the shared queue, blocking while the queue is full.
    /// @throws closed_exception if the queue is closed before or while waiting.
    void push(Kty&& key, Vty&& value)
    {
        std::unique_lock<std::mutex> lck(m_mtx);
        m_not_full.wait(lck, [this]
        {
            if (!m_opened)
            {
                throw closed_exception();
            }
            return m_count < m_capacity;
        });
        // Insert before counting so a throwing move/allocation cannot leak capacity.
        m_queue.emplace_back(std::move(key), std::move(value));
        ++m_count;
        m_not_empty.notify(lck);
    }

    /// Copying overload of push() for lvalue arguments.
    /// @throws closed_exception if the queue is closed before or while waiting.
    void push(const Kty& key, const Vty& value)
    {
        push(Kty(key), Vty(value));
    }

    /// Removes one element into key/value, blocking until one is available.
    ///
    /// Elements already routed to this thread's sub-queue are returned first.
    /// Otherwise the front of the shared queue is taken. Fronts whose key is
    /// owned by another live sub-queue are moved there along the way.
    /// @throws closed_exception if the queue is closed.
    void pop(Kty& key, Vty& value)
    {
        std::unique_lock<std::mutex> lck(m_mtx);
        // Serve this thread's private sub-queue first.
        if (m_tss.get() && m_tss->sp_queue)
        {
            auto& q = *m_tss->sp_queue;
            if (!q.empty())
            {
                if (!m_opened)
                {
                    q.clear();
                    throw closed_exception();
                }
                key = std::move(q.front().first);
                value = std::move(q.front().second);
                q.pop_front();
                --m_count;
                m_not_full.notify(lck);
                return;
            }
            // Sub-queue drained: release this thread's key association. Only
            // erase it if it still refers to our own sub-queue. After
            // close()/open() another thread may have claimed the same key, and
            // removing its entry would break per-key serialization.
            auto it = m_sub_queues.find(m_tss->key);
            if (it != m_sub_queues.end() && it->second.lock() == m_tss->sp_queue)
            {
                m_sub_queues.erase(it);
            }
        }
        for (;;)
        {
            m_not_empty.wait(lck, [this]
            {
                if (!m_opened)
                {
                    throw closed_exception();
                }
                return m_count > 0 && !m_queue.empty();
            });
            // Is the front key already associated with a sub-queue?
            auto& front_key = m_queue.front().first;
            auto p = m_sub_queues.emplace(std::piecewise_construct,
                std::forward_as_tuple(front_key),
                std::forward_as_tuple());
            if (p.second)
            {
                break; // key unowned: take the element ourselves
            }
            auto sp = p.first->second.lock();
            if (!sp)
            {
                break; // owning sub-queue no longer exists: take over the key
            }
            // Route the element to the owner's sub-queue to keep same-key
            // elements in order and with a single owner.
            sp->emplace_back(std::move(m_queue.front()));
            m_queue.pop_front();
        }
        // Associate the front key with this thread's sub-queue.
        if (!m_tss.get())
        {
            m_tss.reset(new tss_);
        }
        if (!m_tss->sp_queue)
        {
            m_tss->sp_queue = std::make_shared<queue_t>();
        }
        auto& front = m_queue.front();
        m_tss->key = front.first;
        m_sub_queues[front.first] = m_tss->sp_queue;
        // Hand the element to the caller; count it out only after it has left.
        key = std::move(front.first);
        value = std::move(front.second);
        m_queue.pop_front();
        --m_count;
        m_not_full.notify(lck);
    }

protected:
    /// Scope guard that increments a counter on construction and decrements it
    /// on destruction.
    class guard_
    {
    public:
        explicit guard_(size_t& waiters)
            : waiters_(waiters)
        {
            ++waiters_;
        }
        ~guard_()
        {
            --waiters_;
        }
    private:
        size_t & waiters_;
    };

    /// Condition variable that tracks how many threads are inside wait(), so
    /// notify() can skip notify_one() calls when nobody is waiting.
    /// m_waiters is only read/written while the caller holds the lock passed in.
    class event_
    {
    public:
        /// Waits on the condition variable; lck must be locked.
        void wait(std::unique_lock<std::mutex>& lck)
        {
            guard_ _(m_waiters);
            m_cnd.wait(lck);
        }
        /// Waits until f() returns true; exceptions thrown by f propagate.
        template <typename F>
        void wait(std::unique_lock<std::mutex>& lck, F f)
        {
            guard_ _(m_waiters);
            m_cnd.wait(lck, f);
        }
        /// Wakes up to n waiters (bounded by the current waiter count).
        /// lck is not used; it documents that the caller must hold the lock.
        void notify(std::unique_lock<std::mutex>& lck, size_t n = 1)
        {
            auto times = std::min(n, m_waiters);
            for (size_t i = 0; i < times; i++)
            {
                m_cnd.notify_one();
            }
        }
        /// Wakes every waiter.
        void notify_all()
        {
            m_cnd.notify_all();
        }
    private:
        std::condition_variable m_cnd;
        size_t m_waiters{ 0 };
    };

private:
    using queue_t = std::deque<std::pair<Kty, Vty>>;

    /// Per-thread state: the key this thread last took ownership of, and its
    /// private sub-queue.
    struct tss_
    {
        Kty key;
        std::shared_ptr<queue_t> sp_queue;
    };

    std::mutex m_mtx;          ///< guards every member below
    event_ m_not_empty;        ///< signalled when an element is pushed
    event_ m_not_full;         ///< signalled when an element is popped
    size_t m_capacity{ 0 };
    queue_t m_queue;           ///< shared queue of not-yet-routed elements
    /// key -> sub-queue of the thread that currently owns that key
    std::map<Kty, std::weak_ptr<queue_t>> m_sub_queues;
    boost::thread_specific_ptr<tss_> m_tss;
    size_t m_count{ 0 };       ///< elements in m_queue plus all sub-queues
    bool m_opened{ false };
};
Test code (a separate source file that includes blocking_key_queue.h):
#include <thread>
#include <iostream>
#include <cstdio>
#include "blocking_key_queue.h"
/// Demo/test driver for blocking_key_queue.
///
/// Starts three consumer threads, pushes 80 elements spread over keys 1..4,
/// waits 5 seconds, closes the queue and joins the consumers. Each consumer
/// prints every element it takes and, on closed_exception, how many it consumed.
int main(int argc, char* argv[])
{
    std::mutex mtx; // serializes console output between threads
    blocking_key_queue<size_t, size_t> mq(10);
    static const size_t num_of_thread = 3;
    std::thread thds[num_of_thread];
    size_t cnts[num_of_thread] = {}; // per-thread consume counters, each written by one thread
    for (size_t i = 0; i < num_of_thread; i++)
    {
        thds[i] = std::thread([&, i]
        {
            size_t key = 0;
            size_t value = 0;
            try
            {
                for (;;)
                {
                    mq.pop(key, value);
                    {
                        cnts[i]++;
                        std::lock_guard<std::mutex> _(mtx);
                        std::cout << "[" << std::this_thread::get_id() << "]key = " << key << ", value = " << value << std::endl;
                    }
                    // Uncomment to simulate slow processing:
                    //std::this_thread::sleep_for(std::chrono::seconds(1));
                }
            }
            catch (const blocking_key_queue<size_t, size_t>::closed_exception&)
            {
                std::lock_guard<std::mutex> _(mtx);
                std::cout << "[" << std::this_thread::get_id() << "]queue closed, consume = " << cnts[i] << std::endl;
            }
        });
    }
    // Push keys 1..4 in round-robin order, 20 rounds; value is the round index.
    for (size_t i = 0; i < 20; i++)
    {
        for (size_t k = 1; k <= 4; k++)
        {
            mq.push(size_t{ k }, size_t{ i });
        }
    }
    // Give consumers time to drain; close() discards anything still queued.
    std::this_thread::sleep_for(std::chrono::seconds(5));
    mq.close();
    for (auto& t : thds)
    {
        t.join();
    }
    // Wait for a key press before exiting (presumably to keep a console window open).
    std::getc(stdin);
    return 0;
}