C++20 线程安全的 Map

C++ 20 线程安全的Map

粗粒度：每个桶（bucket）一把读写锁（std::shared_mutex），锁住整个 bucket_data 链表；读操作共享加锁，写操作独占加锁

/**
 * Thread-safe hash table with a fixed number of buckets.
 *
 * Each bucket owns a std::list of (key, value) pairs guarded by its own
 * std::shared_mutex: lookups take the bucket's lock in shared mode (many
 * readers at once), insertions/updates/removals take it exclusively.
 * Operations on different buckets use different mutexes.
 *
 * The bucket count is fixed at construction; the table never rehashes.
 */
template <typename Key, typename Value, typename Hash = std::hash<Key>>
class threadsafe_lookup_table
{
private:
	class bucket_type
	{
	private:
		using bucket_value = std::pair<Key, Value>;
		using bucket_data = std::list<bucket_value>;
		using bucket_iterator = typename bucket_data::iterator;
		using bucket_const_iterator = typename bucket_data::const_iterator;

		bucket_data data;
		mutable std::shared_mutex mutex;

		// An enclosing class has no special access to a nested class's private
		// members; get_map() needs `data` and `mutex` to build a snapshot.
		friend class threadsafe_lookup_table;

		// Mutable lookup. Caller must hold `mutex` exclusively.
		bucket_iterator find_entry_for(Key const& key)
		{
			return std::find_if(data.begin(), data.end(),
				[&](bucket_value const& item) { return item.first == key; });
		}

		// Read-only lookup. Inside a const member `data` only yields
		// const_iterators, so this overload must return one.
		// Caller must hold `mutex` (shared or exclusive).
		bucket_const_iterator find_entry_for(Key const& key) const
		{
			return std::find_if(data.begin(), data.end(),
				[&](bucket_value const& item) { return item.first == key; });
		}

	public:
		/**
		 * Returns a copy of the value mapped to `key`, or `default_value` if absent.
		 * std::shared_lock: multiple threads may read this bucket concurrently.
		 */
		Value value_for(Key const& key, Value const& default_value) const
		{
			std::shared_lock<std::shared_mutex> lock(mutex);
			bucket_const_iterator const found = find_entry_for(key);
			return found == data.end() ? default_value : found->second;
		}

		/**
		 * Inserts (key, value), or overwrites the value already stored for `key`.
		 * std::unique_lock: only one thread may hold this bucket's lock.
		 */
		void add_or_update_mapping(Key const& key, Value const& value)
		{
			std::unique_lock<std::shared_mutex> lock(mutex);
			bucket_iterator const found = find_entry_for(key);
			if (found == data.end())
			{
				data.emplace_back(key, value);
			}
			else
			{
				found->second = value;
			}
		}

		/**
		 * Erases the entry for `key` if present; does nothing otherwise.
		 * std::unique_lock: only one thread may hold this bucket's lock.
		 */
		void remove_mapping(Key const& key)
		{
			std::unique_lock<std::shared_mutex> lock(mutex);
			bucket_iterator const found = find_entry_for(key);
			if (found != data.end())
			{
				data.erase(found);
			}
		}
	};

	std::vector<std::unique_ptr<bucket_type>> buckets;
	Hash hasher;

	// Selects the bucket for `key`. `buckets` is never empty (the constructor
	// guarantees at least one), so the modulo is always well-defined.
	bucket_type& get_bucket(Key const& key) const
	{
		std::size_t const bucket_index = hasher(key) % buckets.size();
		return *buckets[bucket_index];
	}

public:
	using key_type = Key;
	using mapped_type = Value;
	using hash_type = Hash;

	/**
	 * @param num_buckets number of buckets (a prime spreads keys more evenly).
	 *                    0 is treated as 1 to avoid a modulo by zero.
	 * @param hasher      hash function object used to pick a bucket.
	 */
	threadsafe_lookup_table(std::size_t num_buckets = 23, Hash const& hasher = Hash())
		: buckets(num_buckets == 0 ? 1 : num_buckets), hasher(hasher)
	{
		for (auto& bucket : buckets)
		{
			bucket = std::make_unique<bucket_type>();
		}
	}

	threadsafe_lookup_table(threadsafe_lookup_table const& other) = delete;
	threadsafe_lookup_table& operator=(threadsafe_lookup_table const& other) = delete;

	/// Returns the value for `key`, or `default_value` (Value() by default) if absent.
	Value value_for(Key const& key, Value const& default_value = Value()) const
	{
		return get_bucket(key).value_for(key, default_value);
	}

	/// Inserts or overwrites the mapping for `key`.
	void add_or_update_mapping(Key const& key, Value const& value)
	{
		get_bucket(key).add_or_update_mapping(key, value);
	}

	/// Removes the mapping for `key` if it exists.
	void remove_mapping(Key const& key)
	{
		get_bucket(key).remove_mapping(key);
	}

	/**
	 * Snapshot of every entry as an ordered std::map (requires Key to be
	 * comparable with std::less).
	 *
	 * All bucket locks are acquired, in bucket-index order, before any data is
	 * copied, so the snapshot is consistent across buckets. Shared locks are
	 * enough for a read-only copy: concurrent readers proceed, writers wait.
	 */
	std::map<Key, Value> get_map() const
	{
		std::vector<std::shared_lock<std::shared_mutex>> locks;
		locks.reserve(buckets.size());
		for (auto const& bucket : buckets)
		{
			locks.emplace_back(bucket->mutex);
		}

		std::map<Key, Value> res;
		for (auto const& bucket : buckets)
		{
			for (auto const& item : bucket->data)
			{
				res.emplace(item.first, item.second);
			}
		}
		return res;
	}
};

细粒度：每个节点（node）一把互斥锁，遍历时先锁住下一个节点再释放当前节点（hand-over-hand 交替加锁）

/**
 * Singly linked list with one mutex per node (fine-grained locking).
 *
 * Traversals use hand-over-hand locking: the next node's mutex is acquired
 * before the current node's mutex is released, and every traversal walks
 * from the head toward the tail, so locks are always taken in the same order.
 */
template <typename T>
class threadsafe_list
{
private:
	struct node
	{
		std::mutex m;
		std::shared_ptr<T> data;    // empty only for the sentinel head node
		std::unique_ptr<node> next;

		node() = default;

		// explicit: prevents a T argument from silently converting into a
		// temporary node (which would allocate and copy the T).
		explicit node(T const& value) : data(std::make_shared<T>(value))
		{
		}
	};

	node head; // sentinel; head.next is the first real element

public:
	threadsafe_list() = default;

	/// Unlinks every element one node at a time before `head` is destroyed,
	/// instead of relying on the recursive chain of unique_ptr destructors.
	~threadsafe_list()
	{
		// The predicate receives the stored element (T), not a node.
		remove_if([](T const&) { return true; });
	}

	threadsafe_list(threadsafe_list const& other) = delete;
	threadsafe_list& operator=(threadsafe_list const& other) = delete;

	/// Inserts a copy of `value` at the front. The node is allocated before
	/// locking, so only the pointer splice happens under head's mutex.
	void push_front(T const& value)
	{
		auto new_node = std::make_unique<node>(value);
		std::lock_guard<std::mutex> lock(head.m);
		new_node->next = std::move(head.next);
		head.next = std::move(new_node);
	}

	/// Calls f(element) for each element, front to back.
	/// f is invoked while that element's node mutex is held.
	template <typename Function>
	void for_each(Function f)
	{
		node* current = &head;
		std::unique_lock<std::mutex> lk(head.m);
		while (node* const next = current->next.get())
		{
			std::unique_lock<std::mutex> next_lk(next->m);
			lk.unlock(); // safe: we already hold the next node's lock
			f(*next->data);
			current = next;
			lk = std::move(next_lk);
		}
	}

	/// Returns shared ownership of the first element satisfying p,
	/// or an empty shared_ptr if none does.
	template <typename Predicate>
	std::shared_ptr<T> find_first_if(Predicate p)
	{
		node* current = &head;
		std::unique_lock<std::mutex> lk(head.m);
		while (node* const next = current->next.get())
		{
			std::unique_lock<std::mutex> next_lk(next->m);
			lk.unlock();
			if (p(*next->data))
			{
				return next->data;
			}
			current = next;
			lk = std::move(next_lk);
		}
		return std::shared_ptr<T>();
	}

	/// Removes every element for which p returns true.
	template <typename Predicate>
	void remove_if(Predicate p)
	{
		node* current = &head;
		std::unique_lock<std::mutex> lk(head.m);
		while (node* const next = current->next.get())
		{
			std::unique_lock<std::mutex> next_lk(next->m);
			if (p(*next->data))
			{
				// Unlink `next` while holding both locks; release its mutex before
				// `old_next` goes out of scope and destroys the node (and that mutex).
				// `current` stays locked and is re-examined with its new successor.
				std::unique_ptr<node> old_next = std::move(current->next);
				current->next = std::move(next->next);
				next_lk.unlock();
			}
			else
			{
				lk.unlock();
				current = next;
				lk = std::move(next_lk);
			}
		}
	}
};

参考《C++ 并发编程实战（第二版）》，稍有改动。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值