笔者也建立的自己的公众号啦,平时会分享一些编程知识,欢迎各位大佬支持~
扫码或微信搜索北风IT之路关注
生产者-消费者模式是一个经典的多线程设计模式,它为多线程间的协作提供了良好的解决方案。也经常有面试官会让手写一个生产者消费者,从代码细节可以看出你对多线程编程的熟练程度,今天我们来详细看一下如何写出一个生产者消费者模式,并且逐步对其优化争取做到高性能。
结构剖析
在生产者-消费者模式中,通常有两类线程,一类是生产者线程一类是消费者线程。生产者线程负责提交用户请求,消费者线程则负责处理生产者提交的任务。
最简单粗暴的做法就是生产者每提交一个任务,消费者就立即处理,直到所有任务处理完。但是这样直接通信很容易出现性能上的问题,消费者必须等待它的生产者提交到任务才能执行,就不能达到真正的并行。同时生产者和消费者之间存在依赖关系,在设计上耦合度非常高,这是不可取的。那么最好的做法就是加一个中间层作为通信桥梁。
生产者和消费者之间通过共享内存缓存区进行通信。多个生产者线程将任务提交给共享内存缓存区,消费者线程并不直接与生产者线程通信,而在共享内存缓冲区获取任务,并行地处理。其中内存缓冲区的主要功能是数据再多线程间的共享,同时还可以通过该缓存区,缓解生产者和消费者间的性能差。它是生产者消费者模式的核心组件,既能作为通信的桥梁,又能避免两者直接通信,从而将生产者和消费者进行解耦。生产者不需要消费者的存在,消费者也不需要知道生产者的存在。
由于内存缓冲区的存在,允许生产者和消费者在执行速度上有差异,无论哪一方速度超过另一方,缓冲区都可以得到缓解,确保系统正常运行。
生产者-消费者模式UML图如下:
实测
下面是我用IDEA的工具生成的UML图
其中Queue的选择最为关键,一般情况下应该会选择阻塞式队列,也就是BlockingQueue,但是它的并发效果并不好,所以我选择了ConcurrentLinkedQueue,笔者经过实测,生产和消费100万个数据,使用LinkedBlockingQueue耗时是38秒,并且在处理了防超量生产的情况下仍然产生了超量生产,其原因就是其并发效果不如后者,后者使用的是CAS无锁实现,而使用ConcurrentLinkedQueue耗时是27秒,无超量生产情况产生。其中ConcurrentLinkedQueue运行时占用内存最大时为546MB(jconsole测得)。
代码实现
Data类
package thread.producerConsumer;
/**
* @author beifengtz
* <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
* <p>location: thread.producerConsumer.javase_learning</p>
* Created in 19:11 2019/5/19
*/
public class Data {
private final int data;
public Data(int data) {
this.data = data;
}
public int getData() {
return data;
}
@Override
public String toString() {
return "Data{" +
"data=" + data +
'}';
}
}
生产者Producer
package thread.producerConsumer;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
/**
* @author beifengtz
* <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
* <p>location: thread.producerConsumer.javase_learning</p>
* Created in 19:11 2019/5/19
*/
public class Producer implements Runnable {
// 缓冲队列
private ConcurrentLinkedQueue<Data> queue;
// 每个生产者生产数据的数量
private int produceNum;
// 任务计数,同时模拟数据生产
private static AtomicInteger count = new AtomicInteger(0);
// 需要生产的总任务数量,出处有耦合性,实际开发时应当处理。
private static AtomicInteger total = new AtomicInteger(Main.NUMS);
public Producer(ConcurrentLinkedQueue<Data> queue, int produceNum) {
this.queue = queue;
this.produceNum = produceNum;
}
@Override
public void run() {
Data data = null;
// 限定生产,防止超量生产
if (total.get() <= 0) {
Thread.currentThread().interrupt();
}
for (int i = 0; i < produceNum; i++) {
// 再次检测,防止超量生产
if (total.get() > 0) {
// 构造数据
data = new Data(count.incrementAndGet());
// 向缓冲队列中提交
if (!queue.offer(data)) {
System.out.println("Filed to put data:" + data);
} else {
System.out.println("Successfully produced data:" + data);
}
// 需要生产的总量-1
total.decrementAndGet();
} else {
Thread.currentThread().interrupt();
break;
}
}
}
}
消费者Consumer
package thread.producerConsumer;
import java.text.MessageFormat;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
/**
* @author beifengtz
* <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
* <p>location: thread.producerConsumer.javase_learning</p>
* Created in 19:11 2019/5/19
*/
public class Consumer implements Runnable {
private ConcurrentLinkedQueue<Data> queue;
private CountDownLatch countDown;
public Consumer(ConcurrentLinkedQueue<Data> queue, CountDownLatch countDown) {
this.queue = queue;
this.countDown = countDown;
}
@Override
public void run() {
Data data = queue.poll();
if (data != null) {
// 计算+1
int res = data.getData() + 1;
System.out.println("Data processing: " + MessageFormat.format("{0}+1={1}", data.getData(), res));
// 记录该线程任务已结束
countDown.countDown();
}
}
}
Main运行主类
package thread.producerConsumer;
import java.util.concurrent.*;
/**
* @author beifengtz
* <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
* <p>location: thread.producerConsumer.javase_learning</p>
* Created in 19:10 2019/5/19
*/
public class Main {
public static final int NUMS = 1000000;
public static void main(String[] args) {
// 定义缓冲队列
ConcurrentLinkedQueue<Data> queue = new ConcurrentLinkedQueue<>();
// 创建生产者和消费者线程池,为防止冲爆内存,这里限定线程数
ExecutorService producerEs = Executors.newFixedThreadPool(100);
ExecutorService consumerEs = Executors.newFixedThreadPool(100);
// 定义计时器,等待消费者全部消费完
CountDownLatch countDown = new CountDownLatch(NUMS);
long start = System.currentTimeMillis();
long end = System.currentTimeMillis();
// 向线程池中提交 NUMS 个生产任务
for (int i = 0; i < NUMS; i++) {
// 每个生产者以3倍的量生产,模拟生产者和消费者的速度差
producerEs.execute(new Producer(queue,3));
consumerEs.execute(new Consumer(queue, countDown));
}
try {
countDown.await();
end = System.currentTimeMillis();
producerEs.shutdown();
consumerEs.shutdown();
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println("Total time:" + (end - start));
}
}
Disruptor
Disruptor是一个高效的无锁内存队列,它使用无锁的方式实现了一个有界环形队列,非常适合生产者-消费者模式。它相比于BlockingQueue具有更多的特点:
- 同一个“事件”可以有多个消费者,消费者之间既可以并行处理,也可以相互依赖形成处理的先后次序(形成一个依赖图);
- 预分配用于存储事件内容的内存空间;
- 针对极高的性能目标而实现的极度优化和无锁的设计。
这里简单引用另一篇文章(https://www.jianshu.com/p/8473bbb556af)中的一部分Disruptor介绍
- Ring Buffer
如其名,环形的缓冲区。曾经 RingBuffer 是 Disruptor 中的最主要的对象,但从3.0版本开始,其职责被简化为仅仅负责对通过 Disruptor 进行交换的数据(事件)进行存储和更新。在一些更高级的应用场景中,Ring Buffer 可以由用户的自定义实现来完全替代。
- Sequence Disruptor
通过顺序递增的序号来编号管理通过其进行交换的数据(事件),对数据(事件)的处理过程总是沿着序号逐个递增处理。一个 Sequence 用于跟踪标识某个特定的事件处理者( RingBuffer/Consumer )的处理进度。虽然一个 AtomicLong 也可以用于标识进度,但定义 Sequence 来负责该问题还有另一个目的,那就是防止不同的 Sequence 之间的CPU缓存伪共享(Flase Sharing)问题。
(注:这是 Disruptor 实现高性能的关键点之一,网上关于伪共享问题的介绍已经汗牛充栋,在此不再赘述)。
- Sequencer
Sequencer 是 Disruptor 的真正核心。此接口有两个实现类 SingleProducerSequencer、MultiProducerSequencer ,它们定义在生产者和消费者之间快速、正确地传递数据的并发算法。
- Sequence Barrier
用于保持对RingBuffer的 main published Sequence 和Consumer依赖的其它Consumer的 Sequence 的引用。 Sequence Barrier 还定义了决定 Consumer 是否还有可处理的事件的逻辑。
- Wait Strategy
定义 Consumer 如何进行等待下一个事件的策略。 (注:Disruptor 定义了多种不同的策略,针对不同的场景,提供了不一样的性能表现)
- Event
在 Disruptor 的语义中,生产者和消费者之间进行交换的数据被称为事件(Event)。它不是一个被 Disruptor 定义的特定类型,而是由 Disruptor 的使用者定义并指定。
- EventProcessor
EventProcessor 持有特定消费者(Consumer)的 Sequence,并提供用于调用事件处理实现的事件循环(Event Loop)。
- EventHandler
Disruptor 定义的事件处理接口,由用户实现,用于处理事件,是 Consumer 的真正实现。
- Producer
即生产者,只是泛指调用 Disruptor 发布事件的用户代码,Disruptor 没有定义特定接口或类型。
Disruptor代码实现
笔者简单地写了一个用Disruptor实现的程序,结果惊人!同样是100万个数据,时间消耗只用了22秒,时间上和前面ConcurrentLinkedQueue差别不大,但有所提升,但是它的内存消耗居然最高时仅使用了60MB!!!在内存上节省了9倍的空间,可以看到这种内存复用的策略很有效果,在实际开发中很有帮助。
提示:Disruptor是第三方包,需要下载jar包导入项目或者使用Maven导入
数据类Data
package thread.producerConsumer.disruptor;
/**
* @author beifengtz
* <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
* <p>location: thread.producerConsumer.disruptor.javase_learning</p>
* Created in 21:57 2019/5/19
*/
public class Data {
private long value;
public long getValue() {
return value;
}
public void setValue(long value) {
this.value = value;
}
}
数据生产工厂
package thread.producerConsumer.disruptor;
import com.lmax.disruptor.EventFactory;
/**
* @author beifengtz
* <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
* <p>location: thread.producerConsumer.disruptor.javase_learning</p>
* Created in 21:56 2019/5/19
*/
public class DataFactory implements EventFactory<Data> {
@Override
public Data newInstance() {
return new Data();
}
}
生产者Prodocer
package thread.producerConsumer.disruptor;
import com.lmax.disruptor.RingBuffer;
import java.nio.ByteBuffer;
/**
* @author beifengtz
* <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
* <p>location: thread.producerConsumer.disruptor.javase_learning</p>
* Created in 21:59 2019/5/19
*/
public class Producer {
private final RingBuffer<Data> ringBuffer;
public Producer(RingBuffer<Data> ringBuffer) {
this.ringBuffer = ringBuffer;
}
public void pushData(ByteBuffer bb){
long sequence = ringBuffer.next();
try{
Data even = ringBuffer.get(sequence);
even.setValue(bb.getLong(0));
}finally {
ringBuffer.publish(sequence);
}
}
}
消费者Consumer
package thread.producerConsumer.disruptor;
import com.lmax.disruptor.WorkHandler;
import java.text.MessageFormat;
/**
* @author beifengtz
* <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
* <p>location: thread.producerConsumer.disruptor.javase_learning</p>
* Created in 21:51 2019/5/19
*/
public class Consumer implements WorkHandler<Data> {
@Override
public void onEvent(Data data) throws Exception {
long res = data.getValue() + 1;
System.out.println(MessageFormat.format("Data processing: {0} + 1 = {1}",data.getValue(),res));
}
}
运行主类Main
package thread.producerConsumer.disruptor;
import com.lmax.disruptor.RingBuffer;
import com.lmax.disruptor.dsl.Disruptor;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadFactory;
/**
* @author beifengtz
* <a href='http://www.beifengtz.com'>www.beifengtz.com</a>
* <p>location: thread.producerConsumer.disruptor.javase_learning</p>
* Created in 21:50 2019/5/19
*/
public class Main {
private static final int NUMS = 10;
private static final int SUM = 1000000;
public static void main(String[] args) {
try {
Thread.sleep(10000);
} catch (InterruptedException e) {
e.printStackTrace();
}
long start = System.currentTimeMillis();
DataFactory factory = new DataFactory();
int bufferSize = 1024;
Disruptor<Data> disruptor = new Disruptor<>(factory, bufferSize, new ThreadFactory() {
@Override
public Thread newThread(Runnable r) {
return new Thread(r);
}
});
// 构造100个消费者
Consumer[] consumers = new Consumer[NUMS];
for (int i = 0; i < NUMS; i++) {
consumers[i] = new Consumer();
}
disruptor.handleEventsWithWorkerPool(consumers);
disruptor.start();
RingBuffer<Data> ringBuffer = disruptor.getRingBuffer();
Producer producer = new Producer(ringBuffer);
ByteBuffer bb = ByteBuffer.allocate(8);
for (long l = 0; l < SUM; l++) {
bb.putLong(0, l);
producer.pushData(bb);
System.out.println("Successfully produced data: " + l);
}
long end = System.currentTimeMillis();
disruptor.shutdown();
System.out.println("Total time: " + (end - start));
}
}