Rust 实现线程安全的 Lock Free 计数器

完整代码:https://github.com/chiehw/hello_rust/blob/main/crates/counter/src/lib.rs

定义 Trait

Trait 可以看作是一种能力的抽象,和接口有点类似。Trait 还能作为泛型约束条件,作为参数的限制条件。

pub trait AtomicCounter: Send + Sync {
  type PrimitiveType;

  fn get(&self) -> Self::PrimitiveType;	// 获取当前计数器的值。
  fn increase(&self) -> Self::PrimitiveType;	// 自增,并返回上一次的值
  fn add(&self, count: Self::PrimitiveType) -> Self::PrimitiveType;	// 添加一个数,并返回上一次的值
  fn reset(&self) -> Self::PrimitiveType;	// 重置计数器
  fn into_inner(self) -> Self::PrimitiveType;	// 获取内部值
}

简单的测试用例TDD

使用测试驱动开发可以让目标更明确,这里先写个简单的测试案例。

#[cfg(test)]
mod tests {
  use super::*;

  fn test_simple<Counter>(counter: Counter)
  where
    Counter: AtomicCounter<PrimitiveType = usize>,	// 使用 Trait 作为泛型约束条件
  {
    counter.reset();
    assert_eq!(0, counter.add(5));
    assert_eq!(5, counter.increase());
    assert_eq!(6, counter.get())
  }

  #[test]
  fn it_works() {
    test_simple(RelaxedCounter::new(10));
  }
}

亿点细节

直接封装 AtomicUsize

#[derive(Default, Debug)]
pub struct ConsistentCounter(AtomicUsize);

impl ConsistentCounter {
  pub fn new(init_num: usize) -> ConsistentCounter {
    ConsistentCounter(AtomicUsize::new(init_num))
  }
}

impl AtomicCounter for ConsistentCounter {
  type PrimitiveType = usize;

  fn get(&self) -> Self::PrimitiveType {
    self.0.load(Ordering::SeqCst)
  }

  fn increase(&self) -> Self::PrimitiveType {
    self.add(1)
  }

  fn add(&self, count: Self::PrimitiveType) -> Self::PrimitiveType {
    self.0.fetch_add(count, Ordering::SeqCst)
  }

  fn reset(&self) -> Self::PrimitiveType {
    self.0.swap(0, Ordering::SeqCst)
  }

  fn into_inner(self) -> Self::PrimitiveType {
    self.0.into_inner()
  }
}

增加测试用例

使用多线程同时对计数器进行操作,然后判断计数的结果是否正确。更多的测试案例请查看【完整代码】

fn test_increase<Counter>(counter: Arc<Counter>)
  where
    Counter: AtomicCounter<PrimitiveType = usize> + Debug + 'static,
  {
    println!("[+] test_increase: Spawning {} thread, each with {}", NUM_THREADS, NUM_ITERATIONS);
    let mut join_handles = Vec::new();
    // 创建 NUM_THREADS 个线程,同时使用 increase 函数
    for _ in 0..NUM_THREADS {
      let counter_ref = counter.clone();
      join_handles.push(thread::spawn(move || {
        let counter: &Counter = counter_ref.deref();
        for _ in 0..NUM_ITERATIONS {
          counter.increase();
        }
      }));
    }
    // 等待线程完成
    for handle in join_handles {
      handle.join().unwrap();
    }
    let count = Arc::try_unwrap(counter).unwrap().into_inner();
    let excepted_num = NUM_ITERATIONS * NUM_THREADS;
    println!("[+] test_increase: get count {}, excepted num is {}", count, excepted_num);
    // 确定 count 正确
    assert_eq!(count, excepted_num)
  }

参考教程:

  • 谈谈 C++ 中的内存顺序 (Memory Order):https://luyuhuang.tech/2022/06/25/cpp-memory-order.html#happens-before
  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值