本文素材来源于《C++并发编程实战（第2版）》。撰写本文的原因是：书中该线程安全查找表的示例代码存在不少问题，至少在笔者使用的 Visual Studio 下无法直接编译通过（也可能只是笔者的编译环境问题）。于是笔者把遇到的问题逐一修复了，使代码至少能够编译运行，并且在多线程下相对安全。
废话就不说了,直接上代码:
#include <chrono>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>
/// Hash map that is safe to use from multiple threads.
///
/// The table has a fixed number of buckets chosen at construction. Each bucket
/// is a std::list of (key, value) pairs guarded by its own std::shared_mutex:
/// lookups take a shared lock, while inserts/updates/removals take an exclusive
/// lock. Operations on different buckets therefore never block each other, and
/// concurrent lookups in the same bucket do not block each other either.
///
/// Based on the lookup table from "C++ Concurrency in Action" (2nd ed.).
template<class Key, class Value, class Hash = std::hash<Key>>
class threadsafe_lookup_table
{
private:
    using bucket_value = std::pair<Key, Value>;
    using bucket_data = std::list<bucket_value>;
    using bucket_iterator = typename bucket_data::iterator;
    using key_type = Key;
    using value_type = Value;
    using hash_type = Hash;

    // One hash bucket: a list of entries protected by a reader/writer lock.
    class bucket_type
    {
    private:
        // mutable so that const member functions can lock the mutex and obtain
        // non-const iterators (find_entry_for must return bucket_iterator,
        // not const_iterator).
        mutable bucket_data data;
        mutable std::shared_mutex mtx;

        // Linear search for `key`; returns data.end() when absent.
        // The caller must already hold mtx (shared or exclusive).
        bucket_iterator find_entry_for(const Key& key) const
        {
            for (bucket_iterator it = data.begin(); it != data.end(); ++it)
            {
                if (it->first == key)
                    return it;
            }
            return data.end();
        }

    public:
        // Returns the value mapped to `key`, or `default_value` if absent.
        // Read-only, so a shared lock is enough: readers can run concurrently.
        Value value_for(const Key& key, const Value& default_value) const
        {
            std::shared_lock<std::shared_mutex> lck(mtx);
            const bucket_iterator found_entry = find_entry_for(key);
            return (found_entry == data.end()) ? default_value : found_entry->second;
        }

        // Inserts (key, value), or overwrites the value if key already exists.
        void add_or_update_mapping(const Key& key, const Value& value)
        {
            std::unique_lock<std::shared_mutex> lck(mtx);
            const bucket_iterator found_entry = find_entry_for(key);
            if (found_entry == data.end())
                data.emplace_back(key, value);
            else
                found_entry->second = value;
        }

        // Removes the entry for `key`; does nothing if it is not present.
        void remove_mapping(const Key& key)
        {
            std::unique_lock<std::shared_mutex> lck(mtx);
            const bucket_iterator found_entry = find_entry_for(key);
            if (found_entry != data.end())
                data.erase(found_entry);
        }

        // Unsynchronized access for threadsafe_lookup_table::get_map, which
        // locks every bucket itself. The caller must hold get_mutex().
        const bucket_data& get_data() const
        {
            return data;
        }

        std::shared_mutex& get_mutex() const
        {
            return mtx;
        }
    };

    // Buckets are heap-allocated because bucket_type (it holds a mutex) is
    // neither copyable nor movable, so it cannot be stored in a vector directly.
    std::vector<std::unique_ptr<bucket_type>> buckets;
    Hash hasher;

    // Maps a key to its bucket. The bucket count never changes after
    // construction, so this needs no locking.
    bucket_type& get_bucket(const Key& key) const
    {
        const std::size_t bucket_index = hasher(key) % buckets.size();
        return *buckets[bucket_index];
    }

public:
    /// @param num_buckets number of buckets (fixed for the table's lifetime).
    /// @param hasher_     hash functor used to pick a bucket for each key.
    /// @throws std::invalid_argument if num_buckets is 0 (every lookup would
    ///         otherwise compute `hash % 0`, which is undefined behaviour).
    threadsafe_lookup_table(unsigned num_buckets = 19,
                            const Hash& hasher_ = Hash())
        : buckets(num_buckets), hasher(hasher_)
    {
        if (num_buckets == 0)
            throw std::invalid_argument("threadsafe_lookup_table: num_buckets must be greater than 0");
        for (auto& bucket : buckets)
            bucket = std::make_unique<bucket_type>();
    }

    threadsafe_lookup_table(const threadsafe_lookup_table& other) = delete;
    threadsafe_lookup_table& operator=(const threadsafe_lookup_table& other) = delete;

    /// Returns the value for `key`, or `default_value` if the key is absent.
    Value value_for(const Key& key, const Value& default_value = Value()) const
    {
        return get_bucket(key).value_for(key, default_value);
    }

    /// Inserts the mapping, or replaces the value if `key` is already present.
    void add_or_update_mapping(const Key& key, const Value& value)
    {
        get_bucket(key).add_or_update_mapping(key, value);
    }

    /// Removes the mapping for `key`, if any.
    void remove_mapping(const Key& key)
    {
        get_bucket(key).remove_mapping(key);
    }

    /// Returns a consistent snapshot of the whole table as an ordered std::map.
    ///
    /// All buckets are shared-locked at once, so no writer can modify any
    /// bucket while the snapshot is built. Locks are always acquired in bucket
    /// index order, and every other operation holds at most one bucket lock,
    /// so this cannot deadlock.
    std::map<Key, Value> get_map() const
    {
        std::vector<std::shared_lock<std::shared_mutex>> locks;
        locks.reserve(buckets.size());
        for (const auto& bucket : buckets)
            locks.emplace_back(bucket->get_mutex());

        std::map<Key, Value> res;
        for (const auto& bucket : buckets)
        {
            const bucket_data& entries = bucket->get_data();
            res.insert(entries.begin(), entries.end());
        }
        return res;
    }
};
// Demo reader thread: takes 8 snapshots of the table, 1 ms apart, and prints
// every key/value pair contained in each snapshot.
void get_currentmap(const threadsafe_lookup_table<int, int>& mp)
{
    const unsigned maxn = 8;
    const auto dur = std::chrono::milliseconds(1);
    for (unsigned i = 0; i < maxn; ++i)
    {
        // get_map() returns the snapshot by value, so no bucket lock is held
        // while printing.
        const std::map<int, int> cur_mp = mp.get_map();
        // Bind by const reference: the original `const auto tmp` copied every element.
        for (const auto& [key, value] : cur_mp)
            std::cout << "key :" << key
                      << " value : " << value << std::endl;
        std::this_thread::sleep_for(dur);
    }
}
// Demo writer thread: for each key in 0..14, inserts key -> key + 19, then
// overwrites it with key + 38, then reads the key back and prints each step.
// The read-back may print 0 (the default Value()) if func_remove, running
// concurrently, deletes the key in between.
void func_add_or_update(threadsafe_lookup_table<int, int>& mp)
{
    constexpr unsigned kKeyCount = 15;
    unsigned key = 0;
    while (key < kKeyCount)
    {
        const unsigned first_value = key + 19;
        const unsigned second_value = key + 38;

        mp.add_or_update_mapping(key, first_value);
        std::cout << "add_or_update key :" << key << " value :" << first_value << std::endl;

        mp.add_or_update_mapping(key, second_value);
        std::cout << "add_or_update key :" << key << " value :" << second_value << std::endl;

        const int current = mp.value_for(key);
        std::cout << "current value for key :" << key << " is " << current << std::endl;

        ++key;
    }
}
// Demo remover thread: calls remove_mapping for keys 0..4 and logs each key
// afterwards, whether or not it was present at the time.
void func_remove(threadsafe_lookup_table<int, int>& mp)
{
    constexpr unsigned kKeysToRemove = 5;
    for (unsigned key = 0; key != kKeysToRemove; ++key)
    {
        mp.remove_mapping(key);
        std::cout << "remove from map on key :" << key << std::endl;
    }
}
// Runs a writer, a remover and a snapshot reader concurrently against one
// shared table, then waits for all three to finish.
int main()
{
    threadsafe_lookup_table<int, int> mp;
    // Array elements are initialized in order, so the threads start in the
    // same order as they are listed here.
    std::thread workers[] = {
        std::thread(func_add_or_update, std::ref(mp)),
        std::thread(func_remove, std::ref(mp)),
        std::thread(get_currentmap, std::ref(mp)),
    };
    for (std::thread& worker : workers)
        worker.join();
    return 0;
}
多线程程序的输出取决于线程调度，每次运行的结果都可能不同，因此这里就不贴运行结果了。相对原书代码，上述代码修正的问题如下：
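1、原文在 bucket_type 类的 find_entry_for 中使用了库函数 std::find_if，笔者的编译器报错：无法将其返回的迭代器类型（报错信息中的 _Init）转换为 bucket_data 的 iterator。原因在于 find_entry_for 被声明为 const 成员函数，在其中 data 被视为 const，data.begin() 返回的是 const_iterator，而 find_if 返回的 const_iterator 无法转换为非 const 的 bucket_iterator。为此，笔者将 data 声明为 mutable，并手动实现了查找循环。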
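2、原文在 get_map 中获取数据快照时，没有考虑到嵌套类的访问权限：在 C++ 中，外层类对嵌套类的私有成员没有特殊访问权限，因此外层的 threadsafe_lookup_table 无法直接访问 bucket_type 的私有成员 data 和 mtx。笔者稍微调整了代码顺序，并为 bucket_type 增加了 get_data() 和 get_mutex() 两个访问接口。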
3、其他若干零散的 const 限定修正和指针访问修正。