C++20 线程安全的 Map
粗粒度(锁住整个bucket_data)
/**
 * Thread-safe lookup table (hash map) using coarse-grained, per-bucket locking.
 *
 * Keys are distributed over a fixed number of buckets chosen at construction.
 * Each bucket owns a std::list of (key, value) pairs guarded by its own
 * std::shared_mutex: lookups take a shared lock, modifications take an
 * exclusive lock. Operations on different buckets do not contend.
 *
 * Requires C++17 (std::shared_mutex).
 *
 * @tparam Key   key type; must be equality-comparable and hashable by Hash.
 * @tparam Value mapped type; must be copyable (values are returned by copy).
 * @tparam Hash  hash functor, defaults to std::hash<Key>.
 */
template <typename Key, typename Value, typename Hash = std::hash<Key>>
class threadsafe_lookup_table
{
private:
    class bucket_type
    {
        // get_map() in the enclosing table must lock every bucket and walk its
        // data directly. An enclosing class has no implicit access to a nested
        // class's private members, so the access is granted explicitly.
        friend class threadsafe_lookup_table;

    private:
        using bucket_value = std::pair<Key, Value>;
        using bucket_data = std::list<bucket_value>;
        using bucket_iterator = typename bucket_data::iterator;
        using bucket_const_iterator = typename bucket_data::const_iterator;

        bucket_data data;
        mutable std::shared_mutex mutex;

        // Read-only search; the caller must already hold `mutex` (shared or exclusive).
        bucket_const_iterator find_entry_for(Key const& key) const
        {
            return std::find_if(data.begin(), data.end(),
                                [&](bucket_value const& item) { return item.first == key; });
        }

        // Mutable search; the caller must already hold `mutex` exclusively.
        bucket_iterator find_entry_for(Key const& key)
        {
            return std::find_if(data.begin(), data.end(),
                                [&](bucket_value const& item) { return item.first == key; });
        }

    public:
        /**
         * Returns a copy of the value stored for `key`, or `default_value`
         * when the key is absent. Takes a shared lock, so several threads
         * may read the same bucket concurrently.
         */
        Value value_for(Key const& key, Value const& default_value) const
        {
            std::shared_lock<std::shared_mutex> lock(mutex);
            auto const found = find_entry_for(key);
            return found == data.end() ? default_value : found->second;
        }

        /**
         * Inserts (key, value), or overwrites the value if `key` already
         * exists. Takes an exclusive lock on the bucket.
         */
        void add_or_update_mapping(Key const& key, Value const& value)
        {
            std::unique_lock<std::shared_mutex> lock(mutex);
            auto const found = find_entry_for(key);
            if (found == data.end())
            {
                data.emplace_back(key, value);
            }
            else
            {
                found->second = value;
            }
        }

        /**
         * Removes the entry for `key` if present; does nothing otherwise.
         * Takes an exclusive lock on the bucket.
         */
        void remove_mapping(Key const& key)
        {
            std::unique_lock<std::shared_mutex> lock(mutex);
            auto const found = find_entry_for(key);
            if (found != data.end())
            {
                data.erase(found);
            }
        }
    };

    // Buckets are heap-allocated because bucket_type holds a mutex and is
    // therefore neither copyable nor movable. The vector is never resized
    // after construction, so it can be indexed without a lock.
    std::vector<std::unique_ptr<bucket_type>> buckets;
    Hash hasher;

    // Maps a key to its bucket. buckets.size() is guaranteed non-zero by the constructor.
    bucket_type& get_bucket(Key const& key) const
    {
        std::size_t const bucket_index = hasher(key) % buckets.size();
        return *buckets[bucket_index];
    }

public:
    using key_type = Key;
    using mapped_type = Value;
    using hash_type = Hash;

    /**
     * @param num_buckets number of buckets (a prime number spreads keys best).
     *                    A value of 0 is treated as 1 so that the modulo in
     *                    get_bucket() never divides by zero.
     * @param hasher      hash functor instance to use.
     */
    threadsafe_lookup_table(std::size_t num_buckets = 23, Hash const& hasher = Hash())
        : buckets(num_buckets == 0 ? 1 : num_buckets), hasher(hasher)
    {
        for (auto& bucket : buckets)
        {
            bucket = std::make_unique<bucket_type>();
        }
    }

    threadsafe_lookup_table(threadsafe_lookup_table const& other) = delete;
    threadsafe_lookup_table& operator=(threadsafe_lookup_table const& other) = delete;

    /** Returns a copy of the value mapped to `key`, or `default_value` if absent. */
    Value value_for(Key const& key, Value const& default_value = Value()) const
    {
        return get_bucket(key).value_for(key, default_value);
    }

    /** Inserts the mapping, or replaces the value if `key` is already present. */
    void add_or_update_mapping(Key const& key, Value const& value)
    {
        get_bucket(key).add_or_update_mapping(key, value);
    }

    /** Removes the mapping for `key`, if any. */
    void remove_mapping(Key const& key)
    {
        get_bucket(key).remove_mapping(key);
    }

    /**
     * Returns a consistent snapshot of the whole table as a std::map.
     *
     * Every bucket is locked before any is read, always in index order, so
     * concurrent snapshots acquire locks in the same order. The snapshot only
     * reads, so shared locks are enough: writers are blocked for the duration
     * while other readers may proceed.
     */
    std::map<Key, Value> get_map() const
    {
        std::vector<std::shared_lock<std::shared_mutex>> locks;
        locks.reserve(buckets.size());
        for (auto const& bucket : buckets)
        {
            locks.emplace_back(bucket->mutex);
        }

        std::map<Key, Value> res;
        for (auto const& bucket : buckets)
        {
            for (auto const& item : bucket->data)
            {
                res.emplace(item.first, item.second);
            }
        }
        return res;
    }
};
细粒度(锁住每个node)
/**
 * Thread-safe singly linked list using fine-grained, per-node locking.
 *
 * Every node (including the sentinel `head`) carries its own mutex.
 * Traversals lock hand-over-hand: the next node's mutex is acquired before
 * the current node's mutex is released, so threads walking the list cannot
 * overtake each other but can work on different parts of it at once.
 *
 * @tparam T element type; must be copy-constructible (push_front copies).
 */
template <typename T>
class threadsafe_list
{
private:
    struct node
    {
        std::mutex m;
        std::shared_ptr<T> data;      // null only for the sentinel head node
        std::unique_ptr<node> next;

        node() = default;

        // explicit: a T must never silently convert into a temporary node
        // (e.g. when a T is passed to a callable that expects a node).
        explicit node(T const& value) : data(std::make_shared<T>(value))
        {
        }
    };

    node head;  // sentinel; holds no data, head.next is the first element

public:
    threadsafe_list() = default;

    /**
     * Unlinks and destroys the nodes one at a time via remove_if rather than
     * leaving it to the chain of unique_ptr destructors. The predicate takes
     * the element type (remove_if passes `*next->data`), so no temporary node
     * and no element copy is created during destruction.
     */
    ~threadsafe_list()
    {
        remove_if([](T const&) { return true; });
    }

    threadsafe_list(threadsafe_list const& other) = delete;
    threadsafe_list& operator=(threadsafe_list const& other) = delete;

    /** Inserts a copy of `value` at the front of the list. */
    void push_front(T const& value)
    {
        // Allocate before taking the lock so the critical section is only
        // the two pointer updates.
        auto new_node = std::make_unique<node>(value);
        std::lock_guard<std::mutex> lock(head.m);
        new_node->next = std::move(head.next);
        head.next = std::move(new_node);
    }

    /**
     * Calls f(element) for every element, front to back. `f` receives a
     * non-const reference and runs while that element's node is locked.
     */
    template <typename Function>
    void for_each(Function f)
    {
        node* current = &head;
        std::unique_lock<std::mutex> lock(head.m);
        while (node* const next = current->next.get())
        {
            // Hand-over-hand: lock the next node, then release the previous one.
            std::unique_lock<std::mutex> next_lk(next->m);
            lock.unlock();
            f(*next->data);
            current = next;
            lock = std::move(next_lk);
        }
    }

    /**
     * Returns a shared_ptr to the first element for which p(element) is
     * true, or an empty shared_ptr when no element matches.
     */
    template <typename Predicate>
    std::shared_ptr<T> find_first_if(Predicate p)
    {
        node* current = &head;
        std::unique_lock<std::mutex> lock(head.m);
        while (node* const next = current->next.get())
        {
            std::unique_lock<std::mutex> next_lk(next->m);
            lock.unlock();
            if (p(*next->data))
            {
                return next->data;
            }
            current = next;
            lock = std::move(next_lk);
        }
        return std::shared_ptr<T>();
    }

    /** Removes every element for which p(element) is true. */
    template <typename Predicate>
    void remove_if(Predicate p)
    {
        node* current = &head;
        std::unique_lock<std::mutex> lk(head.m);
        while (node* const next = current->next.get())
        {
            std::unique_lock<std::mutex> next_lk(next->m);
            if (p(*next->data))
            {
                // Keep the lock on `current` (its next pointer is being
                // rewritten) and do not advance: the new current->next must
                // be examined as well.
                std::unique_ptr<node> old_next = std::move(current->next);
                current->next = std::move(next->next);
                // Release the victim's mutex before old_next destroys the
                // node at the end of this scope.
                next_lk.unlock();
            }
            else
            {
                lk.unlock();
                current = next;
                lk = std::move(next_lk);
            }
        }
    }
};
参考《C++并发编程实战（第2版）》，稍有改动。