#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <future>
#include <iostream>
#include <thread>
#include <vector>

using namespace std;

// Multi-producer / multi-consumer intrusive stack with a recycled node pool.
//
// push() installs a node with a compare-exchange retry loop and pop_all()
// detaches the whole chain with a single exchange. Because push() can retry
// its CAS an unbounded number of times, the structure is not wait-free
// (the class name is historical).
//
// Ownership contract:
//   * pop_all() hands the caller an entire chain of nodes.
//   * The caller must give that chain back with free() when done; the nodes
//     are then reused by later push() calls instead of being deleted.
template<typename T>
class waitfree_queue {
public:
    // Singly-linked list node. After pop_all(), the most recently pushed
    // element is first and `next` walks toward older elements.
    struct node {
        T data;
        node* next;
    };

    // Pre-allocates `prealloc` nodes into the free list so that early push()
    // calls do not hit `new`. The default matches the previously hard-coded
    // pool size. If an allocation throws, the nodes allocated so far are
    // released before rethrowing.
    explicit waitfree_queue(std::size_t prealloc = 10000)
        : datalist_(nullptr), freelist_(nullptr) {
        node* head = nullptr;
        try {
            for (std::size_t i = 0; i < prealloc; ++i) {
                node* item = new node;
                item->next = head;
                head = item;
            }
        } catch (...) {
            delete_chain(head);
            throw;
        }
        freelist_.store(head, memory_order_relaxed);
    }

    // The queue owns its nodes through raw pointers; copying would lead to a
    // double delete.
    waitfree_queue(const waitfree_queue&) = delete;
    waitfree_queue& operator=(const waitfree_queue&) = delete;

    // Deletes every node the queue still owns: pooled nodes and nodes that
    // were pushed but never popped. Chains obtained from pop_all() and never
    // returned via free() are not reachable from here and are not deleted.
    ~waitfree_queue() {
        delete_chain(freelist_.load(memory_order_acquire));
        delete_chain(datalist_.load(memory_order_acquire));
    }

    // Copies `data` into a node taken from the pool (or freshly allocated if
    // the pool is empty) and links it in front of the data list.
    void push(const T& data) {
        node* n = alloc();
        n->data = data;

        node* stale_head = datalist_.load(memory_order_relaxed);
        do {
            n->next = stale_head;
            // release: publishes n->data and n->next to the consumer whose
            // acquire exchange in pop_all() observes this node.
        } while (!datalist_.compare_exchange_weak(stale_head, n,
                                                   memory_order_release,
                                                   memory_order_relaxed));
    }

    // Detaches and returns every pushed node (nullptr if none). Later pushes
    // come first in the returned chain. The caller owns the chain and should
    // hand it back with free().
    node* pop_all() {
        // acquire (instead of the discouraged memory_order_consume) pairs with
        // the release CAS in push().
        return datalist_.exchange(nullptr, memory_order_acquire);
    }

    // Returns a whole chain (e.g. the result of pop_all()) to the pool.
    // Passing nullptr is a no-op.
    void free(node* n) {
        if (n == nullptr) {
            return;
        }
        node* tail = n;
        while (tail->next) {
            tail = tail->next;
        }

        node* stale_head = freelist_.load(memory_order_relaxed);
        do {
            tail->next = stale_head;
            // release: publishes the relinked `next` pointers to alloc().
        } while (!freelist_.compare_exchange_weak(stale_head, n,
                                                   memory_order_release,
                                                   memory_order_relaxed));
    }

private:
    // Pops one node from the pool, or allocates a new one if the pool is empty.
    //
    // NOTE(review): this is a Treiber-stack pop and is exposed to the ABA
    // problem. Between reading `stale_head->next` and the CAS, other threads
    // may pop `stale_head`, pop its successor, and push `stale_head` back.
    // The CAS then succeeds and installs a successor that is already in use.
    // With several producers calling push() concurrently (as main() does),
    // this can corrupt the pool. The plain read of `stale_head->next` can also
    // race with another push() writing `next`. A versioned/tagged head or
    // hazard pointers would be needed to close this. TODO.
    node* alloc() {
        // acquire: we dereference stale_head->next, written by free().
        node* stale_head = freelist_.load(memory_order_acquire);
        node* new_head = nullptr;
        do {
            if (stale_head == nullptr) {
                node* n = new node;
                n->next = nullptr;
                return n;
            }
            new_head = stale_head->next;
            // acquire on failure too: the reloaded stale_head is dereferenced
            // on the next iteration.
        } while (!freelist_.compare_exchange_weak(stale_head, new_head,
                                                   memory_order_acquire,
                                                   memory_order_acquire));
        // CAS succeeded: stale_head is now exclusively ours.
        return stale_head;
    }

    // Deletes every node in a (possibly empty) chain.
    static void delete_chain(node* n) {
        while (n) {
            node* next = n->next;
            delete n;
            n = next;
        }
    }

    atomic<node*> datalist_;  // pushed data waiting for pop_all()
    atomic<node*> freelist_;  // pool of recyclable nodes
};

waitfree_queue<int> g_queue;
atomic<int> done;

// Drains the queue once, adding the number of popped elements to `count`.
// Sleeps 1 ms when the queue is empty to avoid a hot spin.
void foo_imp(int& count) {
    waitfree_queue<int>::node* head = g_queue.pop_all();
    if (head == nullptr) {
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
        return;
    }
    for (waitfree_queue<int>::node* ite = head; ite != nullptr; ite = ite->next) {
        count += 1;
    }
    // Must be the queue's free(), not the C library ::free: the nodes were
    // created with `new` and belong to the pool, and the whole chain has to
    // go back, not just its first node.
    g_queue.free(head);
}

// Consumer: drains until `done` is observed, then drains once more so that
// nothing pushed before `done` was set is missed.
int foo() {
    int count = 0;
    do {
        foo_imp(count);
    } while (!done.load(memory_order_acquire));
    // The acquire load above synchronizes with main's release store, which
    // happens after both producers finished. So this final drain sees every
    // push.
    foo_imp(count);
    return count;
}

// Producer: pushes 1000 batches of 2000 elements, pausing 1 ms between batches.
int bar() {
    for (int j = 0; j < 1000; ++j) {
        for (int i = 0; i < 2000; ++i) {
            g_queue.push(1);
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    return 0;
}

int main() {
    done = 0;
    auto f1 = async(std::launch::async, &foo);
    auto f2 = async(std::launch::async, &foo);
    auto f3 = async(std::launch::async, &foo);
    auto b1 = async(std::launch::async, &bar);
    auto b2 = async(std::launch::async, &bar);
    b1.get();
    b2.get();
    // release: pairs with the acquire load in foo() so consumers that see
    // done == 1 also see every push made by the producers.
    done.store(1, memory_order_release);
    int count1 = f1.get();
    int count2 = f2.get();
    int count3 = f3.get();
    cout << "OK " << count1 + count2 + count3 << endl;
    return 0;
}
学习笔记:改进的waitfree_queue,排除一些bug
最新推荐文章于 2023-03-16 01:17:47 发布