概述
线程池是基于生产者-消费者模型实现的一种对线程的管理工具,其中维护了"消费者"队列和作为"缓冲区"的任务队列
生产者:产生任务的线程,比如客户端的请求
消费者:线程池里维护的线程,负责执行任务
缓冲区:线程池里维护的任务队列
线程池的基本工作过程就是生产者不停的往任务队列中加入任务,消费者队列中的线程不停的从任务队列中取出任务去完成
之所以诞生这样一个设计,原因是线程池可以:
- 降低上下文切换开销:任务到来不需要新建线程去执行,而是直接从消费者队列取出一个线程去执行任务,任务结束也不需要销毁线程
- 提高响应速度:任务到来直接取出线程而不是创造线程,节省了创造线程的时间
- 速度匹配:维护一个缓冲区可以有效缓解生产者生产任务和消费者处理任务的速度差异
- 提高整个系统的可控性:线程是稀缺资源,线程池可以对线程进行统一管理和分配,防止无止境地创建线程
线程池的框架
写自定义类自然是先搭好框架
一个线程池类应该具备的要素应该有:
- 缓冲区: 一个阻塞队列
- 消费者:一个继承Thread的类
- 消费者队列:存储存活的消费者的容器
- 对消费者和缓冲区的管理操作
因此可以搭出如下框架:
线程池类:
public class MyThreadPool {
private TaskQueue<Runnable> taskQueue; // 缓冲区---任务队列
private HashSet<Worker> threadPool = new HashSet<>(); // 消费者线程的容器
class TaskQueue<T> {
}
// 消费者类
class Worker extends Thread {
@Override
public void run() {
// 不停从任务队列获取任务执行
}
}
// 线程池里添加新的消费者线程
public void addWorker (Worker worker) {
}
// 消费者线程执行任务
public void execute(Runnable task) {
}
}
任务队列—缓冲区
在学习操作系统时,在生产者-消费者模型的缓冲区中有两个比较重要的概念叫时同步量与互斥量
- 通过互斥量实现对临界区的互斥访问
- 通过同步量限制生产者和消费者访问临界区的顺序,即缓冲区为空消费者无法访问,缓冲区满生产者无法继续生产
在Java中很容易理解,互斥量的实现就是锁(Lock),同步量的实现就是条件变量(Condition)
缓冲区至少应具备的要素是
- 任务队列
- 控制同步的条件变量和控制互斥的锁
- 获取任务和加入任务操作
class TaskQueue<T> {
private Deque<T> tasks = new ArrayDeque<>(); // 任务队列
private int capacity = 1; // 任务队列存储任务上限(缓冲区容量)
TaskQueue (int cap) {
this.capacity = cap;
}
private ReentrantLock lock = new ReentrantLock(); // 互斥量
// 两个同步量
private Condition notFull = lock.newCondition();
private Condition notEmpty = lock.newCondition();
// 往缓冲区加入任务
public void put(T task) {
lock.lock();
try {
while (tasks.size() == capacity) {
try {
notFull.await(); // 缓冲区满, 生产者线程阻塞
} catch (InterruptedException e) {
e.printStackTrace();
}
}
tasks.addLast(task);
notEmpty.signal(); // 唤醒消费者线程
} finally {
lock.unlock();
}
}
// 从缓冲区获取任务
public T get() {
lock.lock();
try {
while (tasks.isEmpty()) {
try {
notEmpty.await(); // 缓冲区空, 消费者阻塞等待生产者产生任务
} catch (InterruptedException e) {
e.printStackTrace();
}
}
T t = tasks.removeFirst();
notFull.signal(); // 唤醒生产者线程
return t;
} finally {
lock.unlock();
}
}
}
上面代码基本实现了任务队列的功能,但是有一个问题:
生产者如果发现队列已满,会自我阻塞,这会导致生产者线程无法进行其它工作,有些时候这种选择是不利于系统运行的
所以应该给生产者提供一个选择,遇到队列已满时
- 选择一直等待直到队列不满
- 选择等待一段时间后返回,放弃本次提交任务
- 选择直接放弃
这里其实就是任务队列的拒绝策略(后面会详细叙述)的选择,这里仅实现前两种选择
为此添加 offer( )方法
// 等待超时模式下的添加任务
public boolean offer(T task) {
lock.lock();
// nano为剩余可等待时间
long nano = timeUnit.toNanos(timeout); // timeout为设定等待时间, 将timeout转化为毫秒
try {
while (tasks.size() == capacity) { // 任务队列已满
if(nano <= 0) {
return false; // 可等待时间归0,添加任务失败,生产者线程继续执行其它工作
}
try {
nano = notFull.awaitNanos(nano); // awaitNanos返回值为剩余等待时间
} catch (InterruptedException e) {
e.printStackTrace();
}
}
tasks.addLast(task);
notEmpty.signal();
return true;
} finally {
lock.unlock();
}
}
offer()方法基于等待超时模式,设定一个等待时间上限timeout,随着不断尝试添加任务的失败,剩余等待时间nano会减少至0,此时添加任务失败,生产者线程继续进行其它工作
同样,消费者取出任务也需要等待超时模式
// 等待超时模式的获取任务
public T poll() {
lock.lock();
try {
long nano = timeUnit.toNanos(timeout);
while (tasks.isEmpty()) {
try {
if(nano <= 0) {
return null;
}
nano = notEmpty.awaitNanos(nano); // awaitNanos返回值为剩余等待时间
} catch (InterruptedException e) {
e.printStackTrace();
}
}
notFull.signal();
return tasks.removeFirst();
} finally {
lock.unlock();
}
}
消费者
消费者是一个线程因此继承Thread类,重写run()方法实现从任务队列取出任务并处理的动作
class Worker extends Thread {
// 记录此时Worker正在运行的任务
private Runnable task;
@Override
public void run() {
// 不停从任务队列获取任务
while(task != null || (task = taskQueue.get()) != null) {
task.run(); // 直接运行任务, task是个实现了Runnable的类对象
task = null;
}
synchronized (threadPool) { // 改动消费者容器需要上锁
threadPool.remove(this); // 长时间无任务则销毁线程, 避免不停循环浪费资源
}
}
}
线程池的管理方法
public void addWorker (Worker worker) {
threadPool.add(worker);
}
public void execute(Runnable task) {
synchronized (threadPool) {
if(threadPool.size() < coreSize) { // 线程池消费者线程数量不足
// 添加消费者线程
Worker worker = new Worker();
addWorker(worker);
taskQueue.put(task);
worker.start();
} else {
// 添加任务到队列中
taskQueue.put(task); // 拒绝策略使用put()---即一直等待到队列不满为止
}
}
}
策略模式进行改进
生产者提交任务时,如果任务队列已满,线程池应该提供给生产者线程一个选择,即拒绝策略,确定此时生产者是等还是放弃任务的提交
上文中的代码拒绝策略体现在put()和offer()两个函数,一个是持续等待,一个是超时等待
而选择哪种拒绝策略在上文的代码中在execute()中体现
// 添加任务到队列中
taskQueue.put(task); // 拒绝策略使用put()---即一直等待到队列不满为止
这等于将选择的权利交给了线程池,不利于代码的通用性
较为理想的代码应该是
taskQueue.tryPut(task, 拒绝策略);
抽象出一个拒绝策略对象,无论何种拒绝策略都调用同一行代码taskQueue.tryPut(),在尝试将任务添加到任务队列,具体执行哪种拒绝策略则由生产者线程提交任务时,将这种策略作为参数传递进入线程池中
所以作如下改动:
添加策略接口
interface RejectPolicy<T> {
void reject(MyThreadPool.TaskQueue<T> taskQueue, T task);
}
线程池类添加策略对象和tryPut()函数
private RejectPolicy<Runnable> rejectPolicy;
public void tryPut(T task, RejectPolicy<T> rejectPolicy) {
lock.lock();
try {
if (tasks.size() == capacity) { // 队列已满选择拒绝生产者的策略
try {
rejectPolicy.reject(this, task);
} catch (Exception e) {
e.printStackTrace();
}
} else {
tasks.addLast(task);
notEmpty.signal();
}
} finally {
lock.unlock();
}
}
创建线程池时通过如下方式定义拒绝策略:
main()线程为生产者线程
public static void main(String[] args) {
MyThreadPool threadPool = new MyThreadPool(new RejectPolicy<Runnable>() {
@Override
public void reject(MyThreadPool.TaskQueue<Runnable> taskQueue, Runnable task) {
taskQueue.put(task);
}
});
// 生产任务
for(int i = 0; i < 5; i++) {
Runnable task = new Runnable() {
@Override
public void run() {
try {
Thread.sleep(10000L);
} catch (Exception e) {
e.printStackTrace();
}
}
};
// 执行任务
threadPool.execute(task);
}
}
使用匿名类对象实现reject()方法,该方法中写明拒绝策略
完整线程池类代码
完整的线程池类如下所示:
interface RejectPolicy<T> {
void reject(MyThreadPool.TaskQueue<T> taskQueue, T task);
}
@Slf4j(topic = "ThreadPoolTest")
public class MyThreadPool {
private int coreSize = 1; // 线程池活动线程数量
private TaskQueue<Runnable> taskQueue; // 任务队列
private HashSet<Worker> threadPool = new HashSet<>();
private long timeout = 1000; // 等待超时上限
private TimeUnit timeUnit; // 时间转换工具
private RejectPolicy<Runnable> rejectPolicy; // 拒绝策略
MyThreadPool(int coreSize, int cap) {
this.coreSize = coreSize;
this.taskQueue = new TaskQueue<>(cap);
}
MyThreadPool(int coreSize, int cap, long timeout, TimeUnit timeUnit) {
this.coreSize = coreSize;
this.taskQueue = new TaskQueue<>(cap);
this.timeUnit = timeUnit;
this.timeout = timeout;
}
MyThreadPool(int coreSize, int cap, long timeout, TimeUnit timeUnit, RejectPolicy<Runnable> rejectPolicy) {
this.coreSize = coreSize;
this.taskQueue = new TaskQueue<>(cap);
this.timeUnit = timeUnit;
this.timeout = timeout;
this.rejectPolicy = rejectPolicy;
}
public int getCoreSize() {
return this.coreSize;
}
class TaskQueue<T> {
private Deque<T> tasks = new ArrayDeque<>(); // 任务队列
private int capacity = 1; // 任务队列存储任务上限
TaskQueue (int cap) {
this.capacity = cap;
}
private ReentrantLock lock = new ReentrantLock();
// 两个同步量
private Condition notFull = lock.newCondition();
private Condition notEmpty = lock.newCondition();
// 阻塞添加
public void put(T task) {
lock.lock();
try {
while (tasks.size() == capacity) {
try {
log.debug("队列已满, 等待Worker执行任务...");
notFull.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
tasks.addLast(task);
log.debug("任务: {} 加入任务队列", task);
notEmpty.signal();
} finally {
lock.unlock();
}
}
// 等待-超时阻塞添加
public boolean offer(T task) {
lock.lock();
long nano = timeUnit.toNanos(timeout);
try {
while (tasks.size() == capacity) {
if(nano <= 0) {
log.debug("任务 [ {} ]添加失败", task);
return false;
}
try {
log.debug("队列已满, 等待Worker执行任务...");
nano = notFull.awaitNanos(nano); // 更新剩余时间
} catch (InterruptedException e) {
e.printStackTrace();
}
}
tasks.addLast(task);
log.debug("任务: {} 加入任务队列", task);
notEmpty.signal();
return true;
} finally {
lock.unlock();
}
}
public void tryPut(T task, RejectPolicy<T> rejectPolicy) {
lock.lock();
try {
if (tasks.size() == capacity) { // 队列已满选择拒绝生产者的策略
try {
rejectPolicy.reject(this, task);
} catch (Exception e) {
e.printStackTrace();
}
} else {
tasks.addLast(task);
log.debug("任务: {} 加入任务队列", task);
notEmpty.signal();
}
} finally {
lock.unlock();
}
}
public T get() {
lock.lock();
try {
while (tasks.isEmpty()) {
try {
log.debug("任务队列暂时没有任务,等待任务...");
notEmpty.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
T t = tasks.removeFirst();
notFull.signal();
return t;
} finally {
lock.unlock();
}
}
public T poll() {
lock.lock();
try {
long nano = timeUnit.toNanos(timeout);
while (tasks.isEmpty()) {
try {
if(nano <= 0) {
return null;
}
log.debug("任务队列暂时没有任务,等待任务...");
nano = notEmpty.awaitNanos(nano); // awaitNanos返回剩余时间作为新的nano
} catch (InterruptedException e) {
e.printStackTrace();
}
}
notFull.signal();
return tasks.removeFirst();
} finally {
lock.unlock();
}
}
}
class Worker extends Thread {
// 记录此时Worker正在运行的任务
private Runnable task;
@Override
public void run() {
// 不停从任务队列获取任务
while(task != null || (task = taskQueue.get()) != null) {
log.debug("正在执行{}", task);
task.run();
task = null;
}
synchronized (threadPool) {
threadPool.remove(this);
log.debug("线程:{} 已移除", this);
}
}
}
public void addWorker (Worker worker) {
threadPool.add(worker);
}
// 将任务放进任务队列即可
public void execute(Runnable task) {
synchronized (threadPool) {
if(threadPool.size() < coreSize) { // 线程池线程数量不足
Worker worker = new Worker();
addWorker(worker);
taskQueue.put(task);
log.debug("新增线程: {}", worker);
worker.start();
} else {
// 添加任务到队列中
taskQueue.tryPut(task, rejectPolicy);
}
}
}
}
测试代码
public static void main(String[] args) {
MyThreadPool threadPool = new MyThreadPool(2, 2, 1000, TimeUnit.MILLISECONDS, new RejectPolicy<Runnable>() {
@Override
public void reject(MyThreadPool.TaskQueue<Runnable> taskQueue, Runnable task) {
taskQueue.put(task);
}
});
for(int i = 0; i < 5; i++) {
Runnable task = new Runnable() {
@Override
public void run() {
try {
Thread.sleep(10000L);
} catch (Exception e) {
e.printStackTrace();
}
}
};
threadPool.execute(task);
}
}