线程池简单结构图:
下面我来实现自定义线程池,有4个部分
-
拒绝策略:当线程池线程全部运行,且阻塞队列满时,改使用那种策略
/** * 拒绝策略 * @param <T> 泛型-任务 */ @FunctionalInterface public interface RejectStrategy<T> { void reject(BlockingQueue<T> queue, T task); }
-
阻塞队列:当线程池线程全部运行,需要把任务放入阻塞队列
- 线程(任务执行)
import lombok.extern.slf4j.Slf4j; import java.util.ArrayDeque; import java.util.Deque; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; /** * 阻塞队列 * @param <T> 泛型-任务 */ @Slf4j(topic = "c.BlockingQueue") public class BlockingQueue<T> { /** 任务队列 */ private final Deque<T> queue = new ArrayDeque<>(); /** 锁 */ private final ReentrantLock lock = new ReentrantLock(); /** 条件变量-任务队列满时生产者 */ private final Condition fullWaitSet = lock.newCondition(); /** 条件变量-任务队列为空时消费者 */ private final Condition emptyWaitSet = lock.newCondition(); /** 容量 */ private int capacity; public BlockingQueue(int capacity) { this.capacity = capacity; } /** * 带超时时间的阻塞获取 * @param timeout 超时时间 * @param unit 超时单位 * @return 任务 */ public T pool(long timeout, TimeUnit unit) { lock.lock(); // 超时时间统一转换为纳秒 long nanos = unit.toNanos(timeout); try { while (queue.isEmpty()) { if (nanos <= 0) { return null; } nanos = emptyWaitSet.awaitNanos(nanos); } T t = queue.removeFirst(); fullWaitSet.signal(); return t; } catch (InterruptedException e) { e.printStackTrace(); } finally { lock.unlock(); } return null; } /** * 阻塞获取 * @return 任务 */ public T take() { lock.lock(); T t = null; try { while (queue.isEmpty()) { emptyWaitSet.await(); } t = queue.removeFirst(); fullWaitSet.signal(); } catch (InterruptedException e) { e.printStackTrace(); } finally { lock.unlock(); } return t; } /** * 超时阻塞添加 * @param t 任务 * @param timeout 超时时间 * @param unit 超时单位 * @return 添加成功true,false添加失败 */ public boolean offer(T t, long timeout, TimeUnit unit) { lock.lock(); // 超时时间统一转换为纳秒 long nanos = unit.toNanos(timeout); try { while (queue.size() == capacity) { if (nanos <= 0) { return false; } log.debug("等待加入任务队列 {}", t); nanos = fullWaitSet.awaitNanos(nanos); } queue.addLast(t); log.debug("加入任务队列 {}", t); emptyWaitSet.signal(); return true; } catch (InterruptedException e) { e.printStackTrace(); } finally { lock.unlock(); } return false; } /** * 阻塞添加 * @param t 任务 */ public void put(T t) { lock.lock(); try { while (queue.size() == capacity) { log.debug("等待加入任务队列 {}", t); fullWaitSet.await(); } queue.addLast(t); log.debug("加入任务队列 {}", t); emptyWaitSet.signal(); } catch (InterruptedException e) { e.printStackTrace(); } finally { lock.unlock(); } } /** * 获取容量 * @return 容量 */ public int size() { lock.lock(); try { return queue.size(); } finally { lock.unlock(); } } /** * 尝试添加 * @param rejectStrategy 决绝策略 * @param t 任务 */ public void tryPut(RejectStrategy<T> rejectStrategy, T t) { lock.lock(); try { // 判断队列是否已满 if (queue.size() == capacity) { rejectStrategy.reject(this, t); } else { queue.addLast(t); log.debug("加入任务队列 {}", t); emptyWaitSet.signal(); } } finally { lock.unlock(); } } }
-
线程池:存放线程
import lombok.extern.slf4j.Slf4j; import java.util.HashSet; import java.util.concurrent.TimeUnit; /** * 自定义线程池 */ @Slf4j(topic = "c.ThreadPool") public class ThreadPool { /** 阻塞队列 */ private final BlockingQueue<Runnable> taskQueue; /** 线程集合 */ private final HashSet<Worker> workers = new HashSet(); /** 核心线程数 */ private final int coreSize; /** 超时时间 */ private long timeout; /** 超时单位 */ private TimeUnit unit; /** 决绝测试 */ private RejectStrategy<Runnable> rejectStrategy; class Worker extends Thread{ private Runnable task; public Worker(Runnable task) { this.task = task; } @Override public void run() { // 执行任务 // 1、当task不为空时,执行任务 // 2、当task为空时,尝试从任务队列获取任务执行 // while (task != null || (task = taskQueue.take()) != null) { while (task != null || (task = taskQueue.pool(timeout, unit)) != null) { try { log.debug("正在执行 {}", task); task.run(); } catch (Exception e) { e.printStackTrace(); } finally { task = null; } } synchronized (workers) { log.debug("移除 worker {}", this); workers.remove(this); } } } public ThreadPool(int coreSize, long timeout, TimeUnit unit, int queueCapacity, RejectStrategy<Runnable> strategy) { this.taskQueue = new BlockingQueue<>(queueCapacity); this.coreSize = coreSize; this.timeout = timeout; this.unit = unit; this.rejectStrategy = strategy; } /** * 执行任务 * @param task 任务 */ public void execute(Runnable task) { // 当任务数没有超过核心线程数时,直接交给worker对象执行 // 否则,暂存入任务队列 synchronized (this) { if (workers.size() < coreSize) { Worker worker = new Worker(task); workers.add(worker); log.debug("新增 worker {}", worker); worker.start(); } else { taskQueue.tryPut(rejectStrategy, task); } } } }
-
测试
import lombok.extern.slf4j.Slf4j; import java.util.concurrent.TimeUnit; @Slf4j(topic = "c.TestThreadPool") public class TestThreadPool { public static void main(String[] args) { ThreadPool pool = new ThreadPool(1, 1000, TimeUnit.MILLISECONDS, 1, (queue, task) -> { // 1、死等 // queue.put(task); // 2、超时等待 // queue.offer(task, 500, TimeUnit.MILLISECONDS); // 3、放弃执行 // log.debug("放弃 {}", task); // 4、抛出异常 // throw new RuntimeException("任务执行失败" + task); // 5、调用者执行 }); for (int i = 0; i < 4; i++) { int j = i; pool.execute(() -> { try { TimeUnit.SECONDS.sleep(10); } catch (InterruptedException e) { e.printStackTrace(); } log.debug("{}", j); }); } } }
这里我们只是模拟线程池实现,方便我们后续学习其他已实现的成熟的线程池,实际应用中还是要使用成熟的线程池。
QQ:806797785
仓库:https://gitee.com/gaogzhen/concurrent