5. 生产者-消费者(Producer-Consumer)模式
为了应对生产速率与消费速率之间的差异,在生产者和消费者之间引入一个通道(Channel),从而将二者解耦。
下面的示例中,消费者(indexingThread)是基于两阶段终止(Two-phase Termination)模式的AbstractTerminatedThread实现的:
Channel.java
package comsumerproducer;
/**
 * Channel role of the Producer-Consumer pattern: the conduit through which
 * producers hand products over to consumers, decoupling the two sides.
 *
 * @param <P> type of the products carried by the channel
 */
public interface Channel<P> {

    /**
     * Stores one product in the channel.
     *
     * @param product the product to hand over
     * @throws InterruptedException if the calling thread is interrupted
     *         (implementations may block while waiting for room)
     */
    void put(P product) throws InterruptedException;

    /**
     * Removes one product from the channel.
     *
     * @return the product that was removed
     * @throws InterruptedException if the calling thread is interrupted
     *         (implementations may block while waiting for a product)
     */
    P take() throws InterruptedException;
}
BlockingQueueChannel implements Channel
package comsumerproducer;
import java.util.concurrent.BlockingQueue;
/**
 * {@link Channel} implementation that forwards every call to a {@link BlockingQueue}.
 * Capacity and blocking behaviour are exactly those of the queue supplied by the caller.
 *
 * @param <P> type of the products carried by the channel
 */
public class BlockingQueueChannel<P> implements Channel<P> {
    /** Queue that actually holds the products. */
    private final BlockingQueue<P> backingQueue;

    /**
     * @param queue queue to delegate to; used as-is, not copied
     */
    public BlockingQueueChannel(BlockingQueue<P> queue) {
        backingQueue = queue;
    }

    /** Forwards to {@link BlockingQueue#put(Object)}. */
    @Override
    public void put(P product) throws InterruptedException {
        backingQueue.put(product);
    }

    /** Forwards to {@link BlockingQueue#take()}. */
    @Override
    public P take() throws InterruptedException {
        return backingQueue.take();
    }
}
AttachmentProcessor(模式角色:Producer;其内部的indexingThread承担Consumer角色)
package comsumerproducer;
import twophased.AbstractTerminatedThread;
import java.io.*;
import java.text.Normalizer;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
/**
 * Stores uploaded attachments on disk and indexes them on a background thread.
 *
 * <p>Pattern roles (Producer-Consumer): {@link #saveAttachment} is the Producer,
 * {@code channel} is the Channel and {@code indexingThread} is the Consumer.
 */
public class AttachmentProcessor {
    /** Root directory; each document's attachments go into a sub-directory named after its id. */
    private static final String ATTACHMENT_STORE_BASE_DIR =
        "/home/viscent/tmp/attachments/";

    /**
     * Pattern role: Channel. Backed by an ArrayBlockingQueue with capacity 200, so
     * {@link #saveAttachment} blocks while that many files are waiting to be indexed.
     */
    private final Channel<File> channel =
        new BlockingQueueChannel<File>(new ArrayBlockingQueue<File>(200));

    /**
     * Pattern role: Consumer. Takes one file at a time from the channel and indexes it.
     * Every file placed in the channel is counted in {@code terminationToken.reservations}
     * by the producer; the count is released here once the file has been handled,
     * whether or not indexing succeeded.
     */
    private final AbstractTerminatedThread indexingThread = new AbstractTerminatedThread() {
        /** Source of the simulated indexing delay; only used from doRun(). */
        private final Random random = new Random();

        @Override
        protected void doRun() throws Exception {
            File file = channel.take();
            try {
                indexFile(file);
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                terminationToken.reservations.decrementAndGet();
            }
        }

        /** Indexes one file. The real logic is omitted; a short sleep simulates its cost. */
        private void indexFile(File file) throws Exception {
            try {
                Thread.sleep(random.nextInt(100));
            } catch (InterruptedException e) {
                // Restore the interrupt status instead of discarding it.
                Thread.currentThread().interrupt();
            }
        }
    };

    /** Starts the indexing (consumer) thread. */
    public void init() {
        indexingThread.start();
    }

    /** Asks the indexing thread to terminate. */
    public void shutdown() {
        indexingThread.terminate();
    }

    /**
     * Stores an attachment on disk and queues it for indexing.
     *
     * @param in               attachment content; closed by this method once copied
     * @param documentId       id of the owning document, used as the sub-directory name
     * @param originalFileName file name supplied by the uploader
     * @throws IOException          if the file cannot be written
     * @throws SecurityException    if documentId or originalFileName would place the file
     *                              outside the store directory
     * @throws InterruptedException if interrupted while waiting for room in the channel;
     *                              the file is then stored but not queued for indexing
     */
    public void saveAttachment(InputStream in, String documentId,
                               String originalFileName) throws Exception {
        File file = saveAsFile(in, documentId, originalFileName);
        // Count the file BEFORE publishing it. Otherwise the consumer could take it and
        // decrement first, so the counter would briefly under-report pending work.
        // NOTE(review): assumes the two-phase termination logic treats
        // reservations <= 0 as "nothing pending".
        indexingThread.terminationToken.reservations.incrementAndGet();
        try {
            channel.put(file);
        } catch (InterruptedException e) {
            // The file never reached the channel, so release its reservation.
            indexingThread.terminationToken.reservations.decrementAndGet();
            throw e;
        }
    }

    /**
     * Copies {@code in} to {@code <base dir>/<documentId>/<originalFileName>}.
     * Both path components are validated against their canonical form BEFORE anything
     * is created on disk, so ".." segments or symlinks cannot escape the store directory.
     */
    private File saveAsFile(InputStream in, String documentId,
                            String originalFileName) throws IOException {
        File baseDir = new File(ATTACHMENT_STORE_BASE_DIR).getCanonicalFile();
        File dir = new File(ATTACHMENT_STORE_BASE_DIR + documentId).getCanonicalFile();
        // Prevent directory traversal through documentId.
        if (!dir.getPath().startsWith(baseDir.getPath() + File.separator)) {
            throw new SecurityException("Invalid documentId:" + documentId);
        }
        File file = new File(dir,
            Normalizer.normalize(originalFileName, Normalizer.Form.NFC)).getCanonicalFile();
        // Prevent directory traversal through originalFileName: the file must sit directly in dir.
        if (!dir.equals(file.getParentFile())) {
            throw new SecurityException("Invalid originalFileName:" + originalFileName);
        }
        dir.mkdirs();

        BufferedInputStream bis = new BufferedInputStream(in);
        BufferedOutputStream bos = null;
        byte[] buf = new byte[2048];
        int len;
        try {
            bos = new BufferedOutputStream(new FileOutputStream(file));
            while ((len = bis.read(buf)) != -1) {
                bos.write(buf, 0, len);
            }
            bos.flush();
        } finally {
            closeQuietly(bis);
            closeQuietly(bos);
        }
        return file;
    }

    /** Closes {@code c} if non-null, ignoring failures (best-effort cleanup in finally blocks). */
    private static void closeQuietly(Closeable c) {
        if (c == null) {
            return;
        }
        try {
            c.close();
        } catch (IOException ignored) {
            // best effort: nothing useful can be done here
        }
    }
}
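通道积压
当消费者的处理速度跟不上生产者时,通道中会出现产品积压,需要加以控制。常见的处理方式有以下两种:
1. 使用有界阻塞队列:
例如ArrayBlockingQueue,或指定了容量上限的LinkedBlockingQueue。队列满时生产者的put操作会被阻塞,从而限制积压量。
2. 使用带流量控制的无界阻塞队列:
例如未指定容量上限的LinkedBlockingQueue。由于队列本身不限容量,需要借助流量控制,限制同一时刻能够向通道中存放产品的生产者线程数量。本例使用基于Semaphore的流量控制实现。注意:信号量许可在put返回后立即释放,因此它限制的是并发写入的生产者线程数,而不是队列中积压的产品数量。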
package comsumerproducer;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Semaphore;
/**
 * Channel that throttles producers with a {@link Semaphore}: at most {@code flowLimit}
 * threads can be inside {@link #put} at the same time. Intended for an unbounded queue,
 * which by itself provides no back-pressure.
 *
 * <p>Note: a permit is released as soon as {@code queue.put} returns, so this bounds the
 * number of concurrent writers, not the number of products waiting in the queue.
 *
 * @param <P> type of the products carried by the channel
 */
public class SemaphoreBasedChannel<P> implements Channel<P> {
    private final BlockingQueue<P> queue;
    private final Semaphore semaphore;

    /**
     * Creates a channel whose semaphore is non-fair.
     *
     * @param queue     queue holding the products
     * @param flowLimit maximum number of threads allowed in {@link #put} concurrently; must be positive
     * @throws NullPointerException     if queue is null
     * @throws IllegalArgumentException if flowLimit is not positive (put would block forever)
     */
    public SemaphoreBasedChannel(BlockingQueue<P> queue, int flowLimit) {
        this(queue, flowLimit, false);
    }

    /**
     * @param queue     queue holding the products
     * @param flowLimit maximum number of threads allowed in {@link #put} concurrently; must be positive
     * @param fair      whether waiting producers acquire permits in FIFO order
     * @throws NullPointerException     if queue is null
     * @throws IllegalArgumentException if flowLimit is not positive (put would block forever)
     */
    public SemaphoreBasedChannel(BlockingQueue<P> queue, int flowLimit, boolean fair) {
        if (queue == null) {
            throw new NullPointerException("queue");
        }
        if (flowLimit <= 0) {
            throw new IllegalArgumentException("flowLimit must be positive: " + flowLimit);
        }
        this.queue = queue;
        this.semaphore = new Semaphore(flowLimit, fair);
    }

    @Override
    public P take() throws InterruptedException {
        return queue.take();
    }

    /** Waits for a permit, stores the product, then always gives the permit back. */
    @Override
    public void put(P product) throws InterruptedException {
        semaphore.acquire();
        try {
            queue.put(product);
        } finally {
            semaphore.release();
        }
    }
}
优化:工作窃取(Work Stealing)算法
在生产者-消费者模式中,通道通常基于队列实现,一个通道可以对应一个或多个队列实例。前面的示例只使用了一个ArrayBlockingQueue:如果有多个消费者线程从该队列中取产品,它们会共享同一个队列实例,从而产生锁竞争。
如果一个通道实例对应多个队列实例,多个消费者线程从通道中取产品时就可以各自访问自己的队列实例。
当某个消费者取完了自己队列中的产品后,还可以从其他消费者的队列中取产品来处理:消费者从自己队列的队首取,而“窃取”时从其他队列的队尾取。
package comsumerproducer;
import java.util.concurrent.BlockingDeque;
/**
 * A {@link Channel} whose consumers can name the deque they would like to take from.
 * Implementations may fall back to taking products from other deques (work stealing).
 *
 * @param <P> type of the products carried by the channel
 */
public interface WorkStealingEnableChannel<P> extends Channel<P> {

    /**
     * Takes one product, trying {@code preferredQueue} first.
     *
     * @param preferredQueue the deque owned by the calling consumer; may be {@code null}
     *                       if the implementation allows it
     * @return the product that was taken
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    P take(BlockingDeque<P> preferredQueue) throws InterruptedException;
}
package comsumerproducer;
import java.util.concurrent.BlockingDeque;
/**
 * Work-stealing channel backed by several deques (one per consumer in the example).
 *
 * <p>{@link #put} spreads products across the deques by hash code. A consumer first takes
 * from the head of its own deque. If that is empty, it makes one non-blocking pass over
 * all deques, taking from their tails. If everything is empty, it blocks on the tail of
 * a deque picked from the current time.
 *
 * <p>NOTE(review): while a consumer is blocked on another deque, products arriving on its
 * own deque are only picked up by other consumers.
 *
 * @param <T> type of the products carried by the channel
 */
public class WorkStealingChannel<T> implements WorkStealingEnableChannel<T> {
    /**
     * Managed deques. The array is stored as-is (not copied), so callers may fill its slots
     * after construction, but every slot must be non-null before the first put/take.
     */
    private final BlockingDeque<T>[] managedQueue;

    /**
     * @param managedQueue deques to spread products over; must be non-empty
     * @throws NullPointerException     if managedQueue is null
     * @throws IllegalArgumentException if managedQueue is empty
     */
    public WorkStealingChannel(BlockingDeque<T>[] managedQueue) {
        if (managedQueue == null) {
            throw new NullPointerException("managedQueue");
        }
        if (managedQueue.length == 0) {
            throw new IllegalArgumentException("managedQueue must not be empty");
        }
        this.managedQueue = managedQueue;
    }

    /**
     * Takes a product, preferring the head of {@code preferredQueue} and otherwise stealing
     * from the tails of the managed deques. Blocks until a product is available; never
     * returns {@code null}.
     *
     * @param preferredQueue the caller's own deque, or {@code null} for none
     */
    @Override
    public T take(BlockingDeque<T> preferredQueue) throws InterruptedException {
        // Own work first: head of the preferred deque.
        T product = (preferredQueue == null) ? null : preferredQueue.poll();

        // A single bounded pass over every managed deque, taking from the tail.
        // Bounding the pass avoids the busy-spin when preferredQueue is null or not managed.
        for (int i = 0; product == null && i < managedQueue.length; i++) {
            product = managedQueue[i].pollLast();
        }

        if (product == null) {
            // Nothing available anywhere: block on the tail of a time-chosen deque.
            // takeLast (not pollLast) so a take never hands null to the consumer.
            int queueIndex = (int) (System.currentTimeMillis() % managedQueue.length);
            product = managedQueue[queueIndex].takeLast();
            System.out.println("stealed from " + queueIndex + ":" + product);
        }
        return product;
    }

    /** Takes a product without a preferred deque. */
    @Override
    public T take() throws InterruptedException {
        return take(null);
    }

    /**
     * Stores the product in the deque selected by its hash code.
     *
     * @throws NullPointerException if product is null
     */
    @Override
    public void put(T product) throws InterruptedException {
        // Clear the sign bit: hashCode() may be negative, and a negative remainder is not
        // a valid array index. Non-negative hash codes map to the same deque as before.
        int targetIndex = (product.hashCode() & Integer.MAX_VALUE) % managedQueue.length;
        managedQueue[targetIndex].put(product);
    }
}
package comsumerproducer;
import twophased.AbstractTerminatedThread;
import twophased.TerminationToken;
import java.util.Random;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;
/**
 * Work-stealing demo: several producer threads feed a {@link WorkStealingChannel} whose
 * deques are each owned by one consumer thread.
 *
 * <p>Producers are not throttled, so the backlog left when {@link #shutdown()} is called
 * can be large, and consumers may take a long time to drain it.
 */
public class WorkStealingExample {
    private final WorkStealingEnableChannel<String> channel;
    /** Shared by all consumers; producers count every published product in its reservations. */
    private final TerminationToken token = new TerminationToken();
    private final Producer[] producers;
    private final Consumer[] consumers;

    public WorkStealingExample() {
        int nCPU = Runtime.getRuntime().availableProcessors();
        int consumerCount = nCPU / 2 + 1;

        // Generic arrays cannot be created directly; the raw allocation is safe because
        // only LinkedBlockingDeque<String> instances are stored in it.
        @SuppressWarnings("unchecked")
        BlockingDeque<String>[] managedQueue = new LinkedBlockingDeque[consumerCount];
        consumers = new Consumer[consumerCount];
        for (int i = 0; i < consumerCount; i++) {
            managedQueue[i] = new LinkedBlockingDeque<String>();
            consumers[i] = new Consumer(token, managedQueue[i]);
        }
        // A single channel instance spanning several deques.
        channel = new WorkStealingChannel<String>(managedQueue);

        producers = new Producer[nCPU];
        for (int i = 0; i < nCPU; i++) {
            producers[i] = new Producer();
            producers[i].start();
        }
        for (Consumer consumer : consumers) {
            consumer.start();
        }
    }

    public void doSomething() {
    }

    /**
     * Asks producers, then consumers, to terminate.
     * NOTE(review): relies on AbstractTerminatedThread keeping a consumer running while
     * token.reservations > 0, so that already-published products are still processed.
     */
    public void shutdown() {
        for (Producer producer : producers) {
            producer.terminate();
        }
        for (Consumer consumer : consumers) {
            consumer.terminate();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        WorkStealingExample wse = new WorkStealingExample();
        wse.doSomething();
        Thread.sleep(3500);
        wse.shutdown();
    }

    /** Publishes an increasing sequence of numbers (as strings) until terminated. */
    private class Producer extends AbstractTerminatedThread {
        private int i = 0;

        @Override
        protected void doRun() throws Exception {
            // Count the product before publishing it, so a consumer can never release
            // a reservation that has not been recorded yet.
            token.reservations.incrementAndGet();
            try {
                channel.put(String.valueOf(i++));
            } catch (InterruptedException e) {
                // Nothing was published: undo the reservation.
                token.reservations.decrementAndGet();
                throw e;
            }
        }
    }

    /** Takes products via the work-stealing channel, preferring its own deque. */
    private class Consumer extends AbstractTerminatedThread {
        private final BlockingDeque<String> workQueue;
        /** Source of the simulated processing delay; only used from doRun(). */
        private final Random random = new Random();

        public Consumer(TerminationToken token, BlockingDeque<String> workQueue) {
            super(token);
            this.workQueue = workQueue;
        }

        @Override
        protected void doRun() throws Exception {
            // WorkStealingEnableChannel.take(preferredQueue) implements the work-stealing lookup.
            String produce = channel.take(workQueue);
            System.out.println("processing product:" + produce);
            try {
                // Simulate the time the real processing would take.
                Thread.sleep(random.nextInt(50));
            } catch (InterruptedException e) {
                // Restore the interrupt status instead of discarding it.
                Thread.currentThread().interrupt();
            } finally {
                token.reservations.decrementAndGet();
            }
        }
    }
}