Note: C++ 17里已经引进了读写锁 std::shared_mutex , 其lock()即以写方式加锁, 其lock_shared()即以读方式加锁。
https://en.cppreference.com/w/cpp/thread/shared_mutex
实现一个读写锁类, 可以有两种方式获取锁,读方式,写方式。 允许多个"读线程"同时进入临界区,但是同一时刻只允许一个"写线程"进入临界区。
当有写线程进入临界区时,不允许任何其他读或写线程同时进入。 写线程优先。
方法:
在类中增加成员, 记录当前正在临界区的“读线程”,"写线程"数目, 等待进入临界区的“读线程”,"写线程"数目。
增加2个信号量成员 用于"写线程"的等待和唤醒, 用于“读线程”的等待和唤醒。
常用场景:
多个读线程频繁访问临界区,偶尔有一个写线程进入临界区。
C++11 实现:
#pragma once
#include <iostream>
#include <thread>
#include <mutex>
#include <list>
#include <cstdlib>
#include <vector>
using namespace std;
class cppReadWriteLock
{
public:
cppReadWriteLock():
mWaitReadThreadNum(0),
mWaitWriteTrheadNum(0),
mReadingThreadNum(0),
mWritingThreadNum(0){
}
~cppReadWriteLock() {};
void getReadLock() {
unique_lock<mutex> uniLock(mMyMutex);
if (mWritingThreadNum || mWaitWriteTrheadNum) { //写优先,只要有线程在等待写,则不能让读线程得到机会。
++mWaitReadThreadNum;
while (mWritingThreadNum || mWaitWriteTrheadNum) {
mReadThreadCV.wait(uniLock);
}
--mWaitReadThreadNum;
}
++mReadingThreadNum;
}
void getWriteLock() {
unique_lock<mutex> uniLock(mMyMutex);
if (mWritingThreadNum || mReadingThreadNum) {
++mWaitWriteTrheadNum;
while (mWritingThreadNum || mReadingThreadNum) {
mWriteThreadCV.wait(uniLock);
}
--mWaitWriteTrheadNum;
}
++mWritingThreadNum;
}
void releaseReadLock() {
unique_lock<mutex> uniLock(mMyMutex);
--mReadingThreadNum;
if (mWaitWriteTrheadNum) {//有写线程在等待的话,直接尝试唤醒一个写线程,即使还有其他线程在读。写优先!
mWriteThreadCV.notify_one();
}
}
void releaseWriteLock() {
unique_lock<mutex> uniLock(mMyMutex);
--mWritingThreadNum;
if (mWaitWriteTrheadNum) {//写优先
mWriteThreadCV.notify_one();
}
else {
mReadThreadCV.notify_all();//通知所有被阻塞的read线程
}
}
private:
int mWaitReadThreadNum, mReadingThreadNum;
int mWaitWriteTrheadNum, mWritingThreadNum;
mutex mMyMutex;
condition_variable mReadThreadCV;//用于“读线程”的等待和唤醒。
condition_variable mWriteThreadCV;//用于"写线程"的等待和唤醒
};
测试代码:
生产者线程5个, 产生0-49数字, 将产生的数字存到全局变量list<int>尾部。第一个线程产生0~9.
消费者线程10个,用于从全局变量list<int>头部get数据,并打印,被get到的数据从list剔除;
观察者线程2个, 用于打印当前list中的元素。
#include "pch.h"
#include "cppReadWriteLock.h"
//生产者线程5个, 产生0-49数字, 将产生的数字存到全局变量list<int>尾部。第一个线程产生0~9.
//消费者线程10个,用于从全局变量list<int>头部get数据,并打印,被get到的数据从list剔除;
//观察者线程2个, 用于打印当前list中的元素。
const int produceThreadNum = 5;
const int consumeThreadNum = 10;
const int watchThreadNum = 2;
list<int> listCache;
int totalTargetNum = 50;//所有的生产者的目标是总共生产50个数字。
int currentProducedNum = 0;
int currentConsumedNum = 0;
cppReadWriteLock gWrLock;
void produceThread(int stIdx, int num) {
for (int i = stIdx; i < stIdx+num; i++) {
gWrLock.getWriteLock();
listCache.push_back(i);
currentProducedNum++;
cout << "Produce " << i << endl;
gWrLock.releaseWriteLock();
//sleep:
std::this_thread::sleep_for(std::chrono::milliseconds(rand() % 15 + 1));
}
}
void consumeThread() {
bool bStop = false;
while (true) {
gWrLock.getWriteLock();
if (!listCache.empty()) {
int topNumber = listCache.front();
listCache.pop_front();
currentConsumedNum++;
cout << "Consume " << topNumber << endl;
}
if (currentConsumedNum >= totalTargetNum) {
bStop = true;
}
gWrLock.releaseWriteLock();
if (bStop)
{
break;
}
//sleep:
std::this_thread::sleep_for(std::chrono::milliseconds(rand() % 15 + 1));
}
}
void watchThread() {
bool bIshouldStop = false;
while (true) {
gWrLock.getReadLock();
if (!listCache.empty()) {
cout << "Watch: ";
for (const auto& it : listCache) {
cout << it << "--";
}
cout << endl;
}
if (currentConsumedNum >= totalTargetNum) {
bIshouldStop = true;
}
gWrLock.releaseReadLock();
if (bIshouldStop) {
break;
}
//sleep:
std::this_thread::sleep_for(std::chrono::milliseconds(rand() % 15 + 1));
}
}
void main()
{
vector<thread> watchThreadsVec;
vector<thread> consumeThreadsVec;
vector<thread> produceThreadsVec;
for (int i = 0; i < watchThreadNum; ++i) {
watchThreadsVec.push_back(thread(watchThread));
}
for (int i = 0; i < consumeThreadNum; ++i) {
consumeThreadsVec.push_back(thread(consumeThread));
}
for (int i = 0; i < produceThreadNum; ++i) {
produceThreadsVec.push_back(thread(produceThread, i*(totalTargetNum / produceThreadNum), totalTargetNum / produceThreadNum));
}
for (auto& it : watchThreadsVec) {
it.join();
}
for (auto& it : produceThreadsVec) {
it.join();
}
for (auto& it : consumeThreadsVec) {
it.join();
}
}
Ref:
https://github.com/bo-yang/read_write_lock