多线程编程模式-Producer-consumer

5.生产者消费者Producer-consumer

避免生产和消费速率的差异,引入一个channel,对二者解耦

实现producer的时候用到了两阶段模式的AbstractTerminatedThread:

Channel.java
package comsumerproducer;

/**
 * Channel
 */
public interface Channel<P> {
    /**
     * 取出一个产品
     * @return
     * @throws InterruptedException
     */
    P take() throws InterruptedException;

    /**
     * 生产一个产品放入通道
     * @param product
     * @throws InterruptedException
     */
    void put(P product) throws InterruptedException;
}

BlockingQueueChannel implements Channel
package comsumerproducer;

import java.util.concurrent.BlockingQueue;

/**
 * 基于阻塞队列的通道实现
 */
public class BlockingQueueChannel<P> implements Channel<P> {
    private final BlockingQueue<P> queue;

    public BlockingQueueChannel(BlockingQueue<P> queue) {
        this.queue = queue;
    }

    @Override
    public P take() throws InterruptedException {
        return queue.take();
    }

    @Override
    public void put(P product) throws InterruptedException {
        queue.put(product);
    }
}

AttachmentProcessor 角色Producer
package comsumerproducer;

import twophased.AbstractTerminatedThread;

import java.io.*;
import java.text.Normalizer;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;

/**
 * 模式角色:producer
 */
public class AttachmentProcessor {
    private final String ATTACHMENT_STORE_BASE_DIR =
            "/home/viscent/tmp/attachments/";

    /**
     * 模式角色:Channel
     */
    private final Channel<File> channel =
            new BlockingQueueChannel<File>(new ArrayBlockingQueue<File>(200));

    /**
     * 模式角色:Consumer
     */
    private final AbstractTerminatedThread indexingThread = new AbstractTerminatedThread() {
        @Override
        protected void doRun() throws Exception {
            File file = null;
            file = channel.take();
            try {
                indexFile(file);
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                terminationToken.reservations.decrementAndGet();
            }
        }

        private void indexFile(File file) throws Exception {
            /**
             * 省略其他代码
              */
            /**
             * 模拟产生索引文件的时间消耗
             */
            Random random = new Random();
            try {
                Thread.sleep(random.nextInt(100));
            } catch (InterruptedException e) {
                ;
            }
        }
    };

    public void init() {
        indexingThread.start();
    }

    public void shutdown() {
        indexingThread.terminate();
    }

    public void saveAttachment(InputStream in, String documentId,
                               String originalFileName) throws Exception {
        File file = saveAsFile(in, documentId, originalFileName);
        try {
            channel.put(file);
        } catch (InterruptedException e) {
            ;
        }
        indexingThread.terminationToken.reservations.incrementAndGet();
    }

    private File saveAsFile(InputStream in, String documentId,
                            String originalFileName) throws IOException {
        String dirName = ATTACHMENT_STORE_BASE_DIR + documentId;
        File dir = new File(dirName);
        dir.mkdirs();
        File file = new File(dirName + "/"
                + Normalizer.normalize(originalFileName, Normalizer.Form.NFC));

        // 防止目录跨越攻击
        if (!dirName.equals(file.getCanonicalFile().getParent())) {
            throw new SecurityException("Invalid originalFileName:" + originalFileName);
        }

        BufferedOutputStream bos = null;
        BufferedInputStream bis = new BufferedInputStream(in);
        byte[] buf = new byte[2048];
        int len = -1;
        try {
            bos = new BufferedOutputStream(new FileOutputStream(file));
            while ((len = bis.read(buf)) > 0) {
                bos.write(buf, 0, len);
            }
            bos.flush();
        } finally {
            try {
                bis.close();
            } catch (IOException e) {
                ;
            }
            try {
                if (null != bos) {
                    bos.close();
                }
            } catch (IOException e) {
                ;
            }
        }
        return file;
    }
}

 

通道积压

消费者处理过慢时会出现通道积压,需要进行处理,分以下两种:

1.使用有界阻塞队列:

ArrayBlockingQueue和有容量限制的LinkedBlockingQueue

2.使用带流量控制的无界阻塞队列:

不带容量控制的LinkedBlockingQueue。借助流量控制实现,对同一时间内可有多少个生产者线程往通道中存储产品进行限制。本例使用基于Semaphore的支持流量控制的实现

package comsumerproducer;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Semaphore;

/**
 * 基于Semaphore的支持流量控制的通道实现
 * @param <P>
 */
public class SemaphoreBasedChannel<P> implements Channel<P> {
    private final BlockingQueue<P> queue;
    private final Semaphore semaphore;

    public SemaphoreBasedChannel(BlockingQueue<P> queue, int flowLimit) {
        this.queue = queue;
        this.semaphore = new Semaphore(flowLimit);
    }

    @Override
    public P take() throws InterruptedException {
        return queue.take();
    }

    @Override
    public void put(P product) throws InterruptedException {
        semaphore.acquire();
        try {
            queue.put(product);
        } finally {
            semaphore.release();
        }
    }
}

优化的工作窃取算法思想

producer-consumer中通常channel用queue实现,一个通道可对应一个或多个队列实例。现在本例仅使用一个ArrayBlockingQueue,如果有多个消费者线程从这个queue中获取产品,共享同一个实例。导致锁的竞争。

如果一个通道实例对应多个队列实例,就可以实现多个消费者线程从通道中取产品时候访问各自的队列实例。

如果一个消费者从自己的队列中取完任务,可以继续从其他消费者的队列中取出产品进行处理。

package comsumerproducer;

import java.util.concurrent.BlockingDeque;

public interface WorkStealingEnableChannel<P> extends Channel<P> {
    P take(BlockingDeque<P> preferredQueue) throws InterruptedException;
}

package comsumerproducer;

import java.util.concurrent.BlockingDeque;

public class WorkStealingChannel<T> implements WorkStealingEnableChannel<T> {
    /**
     * 受管队列
     * @param preferredQueue
     * @return
     * @throws InterruptedException
     */
    private final BlockingDeque<T>[] managedQueue;

    public WorkStealingChannel(BlockingDeque<T>[] managedQueue) {
        this.managedQueue = managedQueue;
    }

    @Override
    public T take(BlockingDeque<T> preferredQueue) throws InterruptedException {
        /**
         * 优先从指定的受管队列中取产品
         */
        BlockingDeque<T> targetQueue = preferredQueue;
        T product = null;

        /**
         * 试图从指定的队列队首取"产品"
         */
        if (null != targetQueue) {
            product = targetQueue.poll();
        }

        int queueIndex = -1;
        while (null == product) {
            queueIndex = (queueIndex + 1) % managedQueue.length;
            targetQueue = managedQueue[queueIndex];
            /**
             * 试图从其他受管队列的队尾取
             */
            product = targetQueue.pollLast();
            if (preferredQueue == targetQueue) {
                break;
            }
        }

        if (null == product) {
            /**
             * 随机窃取其他受管队列的产品
             */
            queueIndex = (int) (System.currentTimeMillis() % managedQueue.length);
            targetQueue = managedQueue[queueIndex];
            product = targetQueue.pollLast();
            System.out.println("stealed from " + queueIndex + ":" + product);
        }
        return product;
    }

    @Override
    public T take() throws InterruptedException {
        return take(null);
    }

    @Override
    public void put(T product) throws InterruptedException {
        int targetIndex = (product.hashCode() % managedQueue.length);
        BlockingDeque<T> targetQueue = managedQueue[targetIndex];
        targetQueue.put(product);
    }
}
package comsumerproducer;

import twophased.AbstractTerminatedThread;
import twophased.TerminationToken;

import java.util.Random;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;

/**
 * 工作窃取算法
 */
public class WorkStealingExample {
    private final WorkStealingEnableChannel<String> channel;

    private final TerminationToken token = new TerminationToken();

    public WorkStealingExample() {
        int nCPU = Runtime.getRuntime().availableProcessors();
        int consumerCount = nCPU / 2 + 1;

        BlockingDeque<String>[] managedQueue = new LinkedBlockingDeque[consumerCount];

        /**
         * 该通道实例对应了多个queue
         */
        channel = new WorkStealingChannel(managedQueue);

        Consumer[] consumers = new Consumer[consumerCount];

        for (int i = 0; i < consumerCount; i++) {
            managedQueue[i] = new LinkedBlockingDeque<String>();
            consumers[i] = new Consumer(token, managedQueue[i]);
        }

        for (int i = 0; i < nCPU; i++) {
            new Producer().start();
        }

        for (int i = 0; i < consumerCount; i++) {
            consumers[i].start();
        }
    }

    public void doSomething() {

    }

    public static void main(String[] args) throws InterruptedException {
        WorkStealingExample wse = new WorkStealingExample();

        wse.doSomething();
        Thread.sleep(3500);
    }

    private class Producer extends AbstractTerminatedThread {
        private int i = 0;

        @Override
        protected void doRun() throws Exception {
            channel.put(String.valueOf(i++));
            token.reservations.incrementAndGet();
        }
    }

    private class Consumer extends AbstractTerminatedThread {
        private final BlockingDeque<String> workQueue;

        public Consumer(TerminationToken token, BlockingDeque<String> workQueue) {
            super(token);
            this.workQueue = workQueue;
        }

        @Override
        protected void doRun() throws Exception {
            /**
             * WorkStealingEnableChannel接口的take方法实现了工作窃取算法
             */
            String produce = channel.take(workQueue);
            System.out.println("processing product:" + produce);

            /**
             * 模拟执行真正操作的时间消耗
             */
            try {
                Thread.sleep(new Random().nextInt(50));
            } catch (InterruptedException e) {
                ;
            } finally {
                token.reservations.decrementAndGet();
            }
        }
    }
}

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
生产者-消费者问题是一个经典的多线程并发问题,主要涉及到生产者线程和消费者线程之间的协作与同步。在这个问题中,生产者线程负责生产数据,并将数据存入一个共享的缓冲区中,而消费者线程则从缓冲区中取出数据进行消费。为了避免生产者和消费者线程之间的竞争条件和死锁问题,需要使用线程同步技术。 下面是一个使用 Python 语言实现的生产者-消费者问题的示例代码: ```python import threading import time import random # 缓冲区大小 BUFFER_SIZE = 5 # 共享的缓冲区 buffer = [] # 生产者线程 class ProducerThread(threading.Thread): def run(self): global buffer while True: # 生产一个随机数 item = random.randint(1, 10) print("生产者生产了数据:", item) # 获取锁 lock.acquire() # 如果缓冲区已满,等待消费者线程消费数据 while len(buffer) >= BUFFER_SIZE: print("缓冲区已满,生产者等待...") lock.wait() # 将数据存入缓冲区 buffer.append(item) print("生产者将数据存入缓冲区:", buffer) # 释放锁 lock.release() # 随机等待一段时间 time.sleep(random.randint(1, 3)) # 消费者线程 class ConsumerThread(threading.Thread): def run(self): global buffer while True: # 获取锁 lock.acquire() # 如果缓冲区为空,等待生产者线程生产数据 while len(buffer) == 0: print("缓冲区为空,消费者等待...") lock.wait() # 从缓冲区取出数据进行消费 item = buffer.pop(0) print("消费者消费了数据:", item) # 释放锁 lock.release() # 随机等待一段时间 time.sleep(random.randint(1, 3)) # 创建锁 lock = threading.Condition() # 创建生产者线程和消费者线程 producer_thread = ProducerThread() consumer_thread = ConsumerThread() # 启动线程 producer_thread.start() consumer_thread.start() # 等待线程结束 producer_thread.join() consumer_thread.join() ``` 在这个示例代码中,我们使用了 Python 中的 Condition 类来实现线程同步和协作。在生产者线程中,如果缓冲区已满,则使用 wait() 方法等待消费者线程消费数据;在消费者线程中,如果缓冲区为空,则使用 wait() 方法等待生产者线程生产数据。当生产者线程向缓冲区中添加数据或消费者线程从缓冲区中取出数据时,需要使用 acquire() 方法获取锁,以避免竞争条件的发生。 需要注意的是,在生产者-消费者问题中,线程同步和协作是非常重要的,如果实现不当,将会导致死锁、竞争条件等问题。因此,在实际开发中,需要仔细设计和测试多线程程序,以确保程序的正确性和稳定性。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值