#pragma once
#include <mutex>
#include <condition_variable>
#include <algorithm>
/// Counting semaphore that can be "closed" to release blocked waiters.
///
/// Permits are added with post() and consumed with wait(). While the
/// semaphore is closed, wait() throws closed_exception instead of blocking or
/// consuming a permit -- even if permits are still available. open()
/// re-enables normal operation. Every public member locks an internal mutex.
class semaphore
{
public:
    /// Thrown by wait() when the semaphore is closed on entry or while waiting.
    struct closed_exception {};

public:
    /// @param cnt initial number of available permits (default 0).
    explicit semaphore(size_t cnt = 0)
        : m_cnt(cnt)
    {}

    /// Marks the semaphore as open again. Does not notify any waiter.
    void open()
    {
        std::lock_guard<std::mutex> _(m_mtx);
        m_opened = true;
    }

    /// Marks the semaphore as closed and wakes every blocked waiter so that
    /// it re-checks its wait condition (and throws closed_exception if the
    /// semaphore is still closed when it does).
    void close()
    {
        std::lock_guard<std::mutex> _(m_mtx);
        m_opened = false;
        m_evt.notify_all();
    }

    /// Blocks until a permit is available, then consumes one.
    /// @throws closed_exception if the semaphore is closed. The closed state
    ///         is checked before the permit count, so a closed semaphore
    ///         throws even when permits remain.
    void wait()
    {
        std::unique_lock<std::mutex> lck(m_mtx);
        m_evt.wait(lck, [this]
        {
            if (!m_opened)
            {
                throw closed_exception();
            }
            return m_cnt > 0;
        });
        --m_cnt;
    }

    /// Adds @p n permits and wakes at most @p n blocked waiters.
    void post(size_t n = 1)
    {
        std::unique_lock<std::mutex> lck(m_mtx);
        m_cnt += n;
        m_evt.notify(lck, n);
    }

protected:
    /// Scope guard that counts the current thread as a waiter for its
    /// lifetime. The counter is a plain size_t, so the guard must be created
    /// and destroyed while the mutex protecting that counter is held.
    class guard_
    {
    public:
        explicit guard_(size_t& waiters)
            : waiters_(waiters)
        {
            ++waiters_;
        }
        ~guard_()
        {
            --waiters_;
        }
        // A copy would decrement the shared counter a second time.
        guard_(const guard_&) = delete;
        guard_& operator=(const guard_&) = delete;
    private:
        size_t& waiters_;
    };

    /// Condition-variable wrapper that tracks how many threads are inside
    /// wait(), so notify() can avoid issuing more signals than there are
    /// waiters.
    class event_
    {
    public:
        /// Blocks once on the condition variable; @p lck must own the mutex.
        void wait(std::unique_lock<std::mutex>& lck)
        {
            guard_ _(m_waiters);
            m_cnd.wait(lck);
        }

        /// Blocks until @p f returns true; @p f is re-evaluated with @p lck
        /// held after every wake-up. An exception thrown by @p f propagates
        /// to the caller (with @p lck locked) after the waiter count is
        /// restored by the guard.
        template <typename F>
        void wait(std::unique_lock<std::mutex>& lck, F f)
        {
            guard_ _(m_waiters);
            m_cnd.wait(lck, f);
        }

        /// Wakes min(@p n, counted waiters) threads. The lock parameter is
        /// not used; it documents that the caller must hold the mutex that
        /// guards the waiter count. The count may include threads already
        /// signalled but not yet resumed, so this never under-notifies.
        void notify(std::unique_lock<std::mutex>& /*lck*/, size_t n = 1)
        {
            const size_t times = std::min(n, m_waiters);
            for (size_t i = 0; i < times; ++i)
            {
                m_cnd.notify_one();
            }
        }

        /// Wakes every thread currently blocked on the condition variable.
        void notify_all()
        {
            m_cnd.notify_all();
        }

    private:
        std::condition_variable m_cnd;
        size_t m_waiters{ 0 };  // threads inside wait(); guarded by caller's mutex
    };

private:
    std::mutex m_mtx;         // guards every member below
    event_ m_evt;
    size_t m_cnt{ 0 };        // available permits
    bool m_opened{ true };    // false between close() and the next open()
};
// Test code -- a separate source file that includes "semaphore.h":
#include <atomic>
#include <chrono>
#include <iostream>
#include <mutex>
#include <thread>

#include "semaphore.h"
/// Demo: two consumer threads block on a semaphore while the main thread
/// posts ten permits one second apart, then closes the semaphore so both
/// consumers leave their loops through closed_exception.
int main()
{
    constexpr size_t kThreads = 2;
    constexpr size_t kPosts = 10;

    semaphore sem;
    // Both consumers may run concurrently once wait() returns, so the shared
    // counter must be atomic (a plain size_t here was a data race).
    std::atomic<size_t> cnt{ 0 };
    // Serializes whole lines on std::cout so the consumers' output cannot
    // interleave (e.g. both "closed" messages are printed at the same time).
    std::mutex out_mtx;

    std::thread thds[kThreads];
    for (auto& thd : thds)
    {
        thd = std::thread([&]
        {
            try
            {
                for (;;)
                {
                    sem.wait();
                    const size_t n = cnt++;
                    std::lock_guard<std::mutex> lock(out_mtx);
                    std::cout << "thread:" << std::this_thread::get_id() << ", semaphore post: " << n << std::endl;
                }
            }
            catch (const semaphore::closed_exception&)
            {
                std::lock_guard<std::mutex> lock(out_mtx);
                std::cout << "thread:" << std::this_thread::get_id() << ", semaphore closed" << std::endl;
            }
        });
    }

    for (size_t i = 0; i < kPosts; ++i)
    {
        std::this_thread::sleep_for(std::chrono::seconds(1));
        sem.post();
    }
    sem.close();

    for (auto& thd : thds)
    {
        thd.join();
    }
    // 0 signals success; the original returned 1 (a failure status).
    return 0;
}