Rust__异步mpsc_ channel的基本设计
1.基本概念
由于在多个线程间共享数据结构容易产生线程安全问题,所以在某些场景下在线程间使用消息发送的方式进行通信,更加安全方便。Go语言中的channel便是经典的案例。Effective Go中说道不要通过共享内存进行通信,应该通过通信的方式共享内存。
mpsc代表的含义是Multi producer, Single consumer FIFO queue,也就是多生产者单消费者先入先出队列。异步的mpsc channel结构具有如下的特点:
(1)发送端和接收端存在一个缓冲区,发送者将消息发送后即可返回,不用等待接收端处理消息。
(2)可以同时存在多个发送者,但是只能有一个消费者。
(3)发送的消息类型为泛型参数T,则T应该满足T: Send约束。
(4)发送者使用send()发送,接收者使用recv()接收,返回类型都是Result类型。当发送者全部被销毁且缓冲区无数据时,接收者调用recv()会返回错误;存在发送者且缓冲区为空时,接收者调用recv()会阻塞当前线程,直到收到消息才会继续向下执行;当接收者销毁时,发送者调用send()会返回错误。
2.测试编写
我们采用TDD的方式设计并实现该结构,根据前文对该channel结构的特点分析,确定测试函数。首先根据特点(1),可以编写如下函数:
#[test]
fn channel_should_works() {
let (mut s, mut r) = async_channel();
s.send(1).unwrap();
let msg = r.recv().unwrap();
assert_eq!(1, msg);
}
根据特点(2),编写如下测试函数:
#[test]
fn multi_sender_should_work() {
let (mut s, mut r) = async_channel();
let mut v = vec![];
for i in 0..2 {
let mut ts = s.clone();
thread::spawn(move || {
ts.send(i).unwrap();
})
.join()
.unwrap();
}
s.send(2).unwrap();
//如果不drop,会出现一直存在发送者,缓冲区被全部读取后接收者线程持续阻塞的情况。
drop(s);
while let Ok(res) = r.recv() {
v.push(res);
}
//由于多个线程发送的顺序不确定,需要排序
v.sort();
assert_eq!(v, [0, 1, 2]);
}
根据特点(4)可以编写如下测试函数:
#[test]
//当所有的发送者都销毁时继续接收应当返回错误
fn all_sender_drop_should_error_when_receive() {
let (s, mut r) = async_channel();
let s1 = s.clone();
let senders = [s, s1];
let senders_num = senders.len();
//发送者移动到线程内,线程结束时将会被销毁
for mut sender in senders {
thread::spawn(move || {
sender.send("hello").unwrap();
})
.join()
.unwrap();
}
//在缓冲区不为空时,无论是否存在发送者均可以进行接收
for _ in 0..senders_num {
r.recv().unwrap();
}
//接收完所有在缓冲队列中的数据时,继续接收会报错
assert!(r.recv().is_err());
}
#[test]
//当缓冲区为空时,接收者继续调用recv(),应当阻塞接收者线程
fn receiver_should_be_blocked_when_queue_empty() {
let (s, r) = async_channel();
let mut s1 = s.clone();
let mut s2 = s.clone();
thread::spawn(move || {
for (idx, i) in r.into_iter().enumerate() {
assert_eq!(idx, i);
}
//如果线程阻塞则无法执行到该步骤,注意Iterator trait要通过recv()方法进行的实现。
assert!(false);
});
thread::spawn(move || {
for i in 0..100usize {
s1.send(i).unwrap();
}
});
//留点时间让接收者全部接收
thread::sleep(Duration::from_millis(1));
thread::spawn(move || {
for i in 100..200usize {
s2.send(i).unwrap();
}
});
thread::sleep(Duration::from_millis(1));
//已经接收完所有的消息,任务队列为空
assert_eq!(s.get_queued_items(), 0);
}
#[test]
//当接收者全部销毁时,继续发送会返回错误
fn receiver_drop_should_error_when_send() {
let (mut s, _) = async_channel();
let mut s1 = s.clone();
assert!(s.send(0).is_err());
assert!(s1.send(1).is_err());
}
#[test]
//特别情况!当接收者线程阻塞的时候,此时发送者全部销毁,可能会造成没有线程通知接收者
//导致接收者线程持续阻塞的情况。通过发送者的drop实现可以避免该问题。
fn all_sender_drop_when_receiver_block_should_work() {
let (mut s, mut r) = async_channel();
let mut v = 0;
thread::spawn(move || {
s.send(1).unwrap();
//使得接收者线程在阻塞时drop;
thread::sleep(Duration::from_millis(100));
});
while let Ok(res) = r.recv() {
v = res;
}
assert_eq!(v, 1);
}
}
3.基本框架
根据前文提到的特点进行分析,可以得到基本的数据结构和方法。根据特点(1)和(2),由于存在多个线程可以并发访问的缓冲区,且缓冲区满足FIFO,这里可以考虑使用Mutex<VecDeque<T>>
结构存储消息。
根据特点(4)在缓冲区为空的时候接收者继续调用recv(),需要阻塞接收者线程。可以使用Condvar
条件变量来控制线程的阻塞,以及发送者对接收者的通知唤醒。同时还要考虑到发送者在接收者全部销毁时继续发送会返回错误,接收者在缓冲区为空且发送者全部销毁的情况下继续接收也会返回错误;则需要对发送者和接收者的数量进行计数。
考虑到可能有多个线程会并发的对发送者进行clone或drop,以及对接收者数量进行读取,所以这里发送者和接收者的计数都需要采用原子类型。基本数据结构如下:
pub struct Shared<T> {
queue: Mutex<VecDeque<T>>,
available: Condvar,
senders_num: AtomicUsize,
receivers_num: AtomicUsize,
}
pub struct Sender<T> {
shared: Arc<Shared<T>>,
}
pub struct Receiver<T> {
shared: Arc<Shared<T>>,
}
在基本的方法方面,通过特点分析可知,需要为发送者和接收者实现如下方法:
impl<T> Sender<T> {
pub fn send(&mut self, t: T) -> Result<()> {
todo!();
}
pub fn get_receivers_num(&self) -> usize {
todo!();
}
pub fn get_queued_items(&self) -> usize {
todo!();
}
}
impl<T> Receiver<T> {
pub fn recv(&mut self) -> Result<T> {
todo!();
}
pub fn get_senders_num(&self) -> usize {
todo!();
}
}
还需要为发送者和接收者实现一些trait:
impl<T> Iterator for Receiver<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
todo!();
}
}
//mpsc channel 只有发送者需要实现Clone
impl<T> Clone for Sender<T> {
//克隆方法只需要增加引用计数即可
fn clone(&self) -> Self {
todo!();
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
todo!();
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
todo!();
}
}
实现一些公共方法:
// async_channel的创建
pub fn async_channel<T>() -> (Sender<T>, Receiver<T>) {
todo!();
}
//构造Shared<T>
impl<T> Default for Shared<T> {
fn default() -> Self {
todo!();
}
}
4.完整代码实现(注释讲解)
use anyhow::{anyhow, Result};
use std::{
collections::VecDeque,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Condvar, Mutex,
},
};
pub struct Shared<T> {
queue: Mutex<VecDeque<T>>,
available: Condvar,
senders_num: AtomicUsize,
receivers_num: AtomicUsize,
}
pub struct Sender<T> {
shared: Arc<Shared<T>>,
}
pub struct Receiver<T> {
shared: Arc<Shared<T>>,
}
impl<T> Sender<T> {
pub fn send(&mut self, t: T) -> Result<()> {
//如果没有接收者直接返回错误
if self.get_receivers_num() == 0 {return Err(anyhow!("no receiver"));}
//检查消息队列在push前是否为空,然后再push消息
let was_empty = {
let mut inner = self.shared.queue.lock().unwrap();
let empty = inner.is_empty();
inner.push_back(t);
empty
};
//如果消息队列在push前为空,可能有接收者线程阻塞,使用condvar通知
if was_empty {
self.shared.available.notify_one();
}
Ok(())
}
pub fn get_receivers_num(&self) -> usize {
//Ordering::SeqCst,严格内存序保证多个接收者线程能读到最新值
self.shared.receivers_num.load(Ordering::SeqCst)
}
pub fn get_queued_items(&self) -> usize {
//访问共享数据结构之前先加锁
let inner = self.shared.queue.lock().unwrap();
inner.len()
}
}
impl<T> Receiver<T> {
pub fn recv(&mut self) -> Result<T> {
let mut inner = self.shared.queue.lock().unwrap();
loop {
match inner.pop_front() {
//队列存在消息。直接返回消息
Some(v) => {return Ok(v)},
//队列没有消息且发送者都已经drop,返回错误
None if self.get_senders_num() == 0 => {
return Err(anyhow!("no sender!"));
},
//队列没有消息还存在发送者,阻塞线程
//wait()释放锁并挂起线程,等收到notify再拿回锁,重新初始化inner
None => {
inner = self.shared
.available
.wait(inner)
.map_err(|_| anyhow!("lock error"))?;
}
}
}
}
pub fn get_senders_num(&self) -> usize {
self.shared.senders_num.load(Ordering::SeqCst)
}
}
impl<T> Iterator for Receiver<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
self.recv().ok()
}
}
//mpsc channel 只有发送者需要实现Clone
impl<T> Clone for Sender<T> {
//克隆方法只需要增加引用计数即可
fn clone(&self) -> Self {
self.shared.senders_num.fetch_add(1, Ordering::AcqRel);
Self {
shared: self.shared.clone(),
}
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
let old = self.shared.senders_num.fetch_sub(1, Ordering::AcqRel);
//防止接收者线程阻塞时,发送者线程都结束,无法唤醒接收者线程
if old <= 1 {
self.shared.available.notify_all();
}
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
self.shared.receivers_num.fetch_sub(1, Ordering::SeqCst);
}
}
pub fn async_channel<T>() -> (Sender<T>, Receiver<T>) {
let shared = Arc::new(Shared::default());
(
Sender {shared: shared.clone()},
Receiver {shared},
)
}
const INIT_SIZE: usize = 32;
impl<T> Default for Shared<T> {
fn default() -> Self {
Self {
queue: Mutex::new(VecDeque::with_capacity(INIT_SIZE)),
available: Condvar::new(),
senders_num: AtomicUsize::new(1),
receivers_num: AtomicUsize::new(1),
}
}
}
#[cfg(test)]
mod tests {
use std::{thread, time::Duration};
use super::*;
#[test]
fn channel_should_works() {
let (mut s, mut r) = async_channel();
s.send(1).unwrap();
let msg = r.recv().unwrap();
assert_eq!(1, msg);
}
#[test]
fn multi_sender_should_work() {
let (mut s, mut r) = async_channel();
let mut v = vec![];
for i in 0..2 {
let mut ts = s.clone();
thread::spawn(move || {
ts.send(i).unwrap();
})
.join()
.unwrap();
}
s.send(2).unwrap();
//如果不drop,会出现接收者线程持续阻塞的情况。
drop(s);
while let Ok(res) = r.recv() {
v.push(res);
}
v.sort();
assert_eq!(v, [0, 1, 2]);
}
#[test]
fn all_sender_drop_should_error_when_receive() {
let (s, mut r) = async_channel();
let s1 = s.clone();
let senders = [s, s1];
let senders_num = senders.len();
for mut sender in senders {
thread::spawn(move || {
sender.send("hello").unwrap();
})
.join()
.unwrap();
}
for _ in 0..senders_num {
r.recv().unwrap();
}
//接收完所有在缓冲队列中的数据时,继续接收会报错
assert!(r.recv().is_err());
}
#[test]
fn receiver_should_be_blocked_when_queue_empty() {
let (s, r) = async_channel();
let mut s1 = s.clone();
let mut s2 = s.clone();
thread::spawn(move || {
for (idx, i) in r.into_iter().enumerate() {
assert_eq!(idx, i);
}
//如果线程阻塞则无法执行到该步骤
assert!(false);
});
thread::spawn(move || {
for i in 0..100usize {
s1.send(i).unwrap();
}
});
thread::sleep(Duration::from_millis(1));
thread::spawn(move || {
for i in 100..200usize {
s2.send(i).unwrap();
}
});
thread::sleep(Duration::from_millis(1));
//已经接收完所有的消息,任务队列为空
assert_eq!(s.get_queued_items(), 0);
}
#[test]
fn receiver_drop_should_error_when_send() {
let (mut s, _) = async_channel();
let mut s1 = s.clone();
assert!(s.send(0).is_err());
assert!(s1.send(1).is_err());
}
#[test]
fn all_sender_drop_when_receiver_block_should_work() {
let (mut s, mut r) = async_channel();
let mut v = 0;
thread::spawn(move || {
s.send(1).unwrap();
//使得接收者线程在阻塞时drop;
thread::sleep(Duration::from_millis(100));
});
while let Ok(res) = r.recv() {
v = res;
}
assert_eq!(v, 1);
}
}
5.参考文献
极客时间《Rust第一课》陈天
《深入浅出Rust》范长春