最近在重学Java多线程的内容。Java中线程的同步基本靠两种方式:一是Object自带的Monitor机制,通过Object上的wait/notify实现等待/通知模式;二是JUC并发包下的Lock系列API,底层通过LockSupport的park/unpark来实现等待/通知(LockSupport的park/unpark最终调用Unsafe中的native方法,由JVM实现,在Linux下借助pthread_cond_wait和pthread_cond_signal实现)。
手写线程池
使用Object的wait/notify机制,手写一个简易的线程池,进行练习,以加深对并发的理解(包括线程中断机制等)
定义一个线程池接口
package experiment;
/**
 * Minimal thread-pool abstraction: callers submit tasks, resize the set of
 * worker threads, inspect the pool, and stop it.
 *
 * @Author yogurtzzz
 * @Date 2022/4/20 15:08
 **/
public interface ThreadPool {

    /**
     * Submits a task to the pool.
     *
     * @param runnable the task to run
     */
    void execute(Runnable runnable);

    /**
     * Adds worker threads to the pool.
     *
     * @param n number of workers to add
     */
    void addWorker(int n);

    /**
     * Removes worker threads from the pool.
     *
     * @param n number of workers to remove
     */
    void removeWorker(int n);

    /** Shuts the pool down. */
    void shutdown();

    /** Prints the pool's current status. */
    void printStatus();
}
编写实现类
package experiment;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
 * A bounded thread pool built on {@code Object.wait/notify}.
 *
 * <p>Between {@code MIN_SIZE} and {@code MAX_SIZE} worker threads take jobs from a shared FIFO
 * queue. All mutable bookkeeping ({@code jobList}, {@code workerSet}, {@code busyWorkers},
 * {@code idleWorkers}, each worker's {@code retired} flag) is guarded by the single monitor
 * {@code jobList}, which is also the condition idle workers wait on. Using one lock avoids both
 * data races on the {@code HashSet}s and lock-ordering deadlocks.
 *
 * @Author yogurtzzz
 * @Date 2022/4/20 15:09
 **/
public class DefaultThreadPool implements ThreadPool {
    enum State {
        INITIALIZING, RUNNING, TERMINATED
    }

    private static final int MAX_SIZE = 100;
    private static final int MIN_SIZE = 1;
    private static final AtomicInteger threadNum = new AtomicInteger();

    // Pending jobs; also the pool's only lock and the condition idle workers wait on.
    private final LinkedList<Runnable> jobList = new LinkedList<>();
    // Every live worker that has not been retired. Guarded by jobList.
    private final Set<Worker> workerSet = new HashSet<>();
    // Workers that took a job and have not gone back to waiting. Guarded by jobList.
    private final Set<Worker> busyWorkers = new HashSet<>();
    // Workers waiting for a job. Guarded by jobList.
    private final Set<Worker> idleWorkers = new HashSet<>();
    private final AtomicLong completedJobCnt = new AtomicLong();
    private final AtomicLong processingJobCnt = new AtomicLong();
    // volatile: written under jobList but read lock-free by printStatus().
    private volatile State state;

    /**
     * Creates the pool and starts its initial workers.
     *
     * @param initialSize number of workers to start, clamped to [MIN_SIZE, MAX_SIZE]
     */
    public DefaultThreadPool(int initialSize) {
        state = State.INITIALIZING;
        addWorkers(Math.max(MIN_SIZE, Math.min(MAX_SIZE, initialSize)));
        state = State.RUNNING;
    }

    /** Starts up to {@code n} new workers without exceeding MAX_SIZE live workers. */
    private void addWorkers(int n) {
        if (n <= 0) return;
        synchronized (jobList) {
            int toAdd = Math.min(n, MAX_SIZE - workerSet.size());
            for (int i = 0; i < toAdd; i++) {
                Worker worker = new Worker();
                workerSet.add(worker);
                worker.t.start();
            }
        }
    }

    /**
     * Retires up to {@code n} workers while keeping at least MIN_SIZE. Idle workers are chosen
     * first, then busy ones, then workers that have not yet reached either state. A retired
     * worker leaves the bookkeeping immediately, so repeated calls never pick it twice.
     */
    private void removeWorkers(int n) {
        if (n <= 0) return;
        synchronized (jobList) {
            int remaining = Math.min(n, workerSet.size() - MIN_SIZE);
            remaining = retireUpTo(idleWorkers, remaining);
            remaining = retireUpTo(busyWorkers, remaining);
            retireUpTo(workerSet, remaining);
        }
    }

    /**
     * Retires at most {@code limit} workers taken from {@code group}. Caller must hold jobList.
     *
     * @return how many workers are still left to retire
     */
    private int retireUpTo(Set<Worker> group, int limit) {
        // Iterate over a snapshot: retire() removes the worker from group.
        for (Worker w : new LinkedList<>(group)) {
            if (limit <= 0) break;
            w.retire();
            limit--;
        }
        return limit;
    }

    /**
     * Queues a job and wakes one idle worker.
     *
     * @throws NullPointerException if {@code runnable} is null (a null job would be
     *     indistinguishable from "no job" inside the worker loop)
     */
    @Override
    public void execute(Runnable runnable) {
        if (runnable == null) {
            throw new NullPointerException("runnable");
        }
        synchronized (jobList) {
            jobList.add(runnable);
            jobList.notify();
        }
    }

    @Override
    public void addWorker(int n) {
        addWorkers(n);
    }

    @Override
    public void removeWorker(int n) {
        removeWorkers(n);
    }

    /**
     * Retires every worker (interrupting any job in progress) and marks the pool TERMINATED.
     * Jobs still queued are not run unless workers are added again.
     */
    @Override
    public void shutdown() {
        synchronized (jobList) {
            retireUpTo(workerSet, workerSet.size());
            state = State.TERMINATED;
        }
    }

    /** Prints a consistent snapshot of the pool; WorkerCnt excludes workers already retired. */
    @Override
    public void printStatus() {
        int workers;
        int busy;
        int idle;
        int waiting;
        synchronized (jobList) {
            workers = workerSet.size();
            busy = busyWorkers.size();
            idle = idleWorkers.size();
            waiting = jobList.size();
        }
        System.out.printf("Pool State: %s, WorkerCnt: %d, BusyWorker: %d, IdleWorker: %d, CompletedTask: %d, ProcessingTask: %d, WaitingTask: %d\n",
                state, workers, busy, idle, completedJobCnt.get(), processingJobCnt.get(), waiting);
    }

    /** One pool thread: repeatedly takes a job from jobList and runs it until retired or interrupted. */
    private class Worker implements Runnable {
        private final Thread t;
        // Set when the pool asks this worker to stop. Guarded by jobList.
        private boolean retired;

        Worker() {
            t = new Thread(this, "Worker-" + threadNum.getAndIncrement());
        }

        @Override
        public void run() {
            try {
                Runnable job;
                while ((job = takeJob()) != null) {
                    runJob(job);
                }
            } finally {
                // Runs even if a job throws an Error, so the bookkeeping never keeps a dead worker.
                synchronized (jobList) {
                    idleWorkers.remove(this);
                    busyWorkers.remove(this);
                    workerSet.remove(this);
                }
                System.out.printf("%s is going to terminate...\n", Thread.currentThread().getName());
            }
        }

        /**
         * Blocks until a job is available.
         *
         * @return the next job, or null once this worker is retired or its thread is interrupted
         */
        private Runnable takeJob() {
            synchronized (jobList) {
                while (true) {
                    if (retired || Thread.currentThread().isInterrupted()) {
                        // We may have consumed the notify meant for a queued job; pass it on.
                        if (!jobList.isEmpty()) {
                            jobList.notify();
                        }
                        return null;
                    }
                    if (!jobList.isEmpty()) {
                        idleWorkers.remove(this);
                        busyWorkers.add(this);
                        return jobList.poll();
                    }
                    busyWorkers.remove(this);
                    idleWorkers.add(this);
                    System.out.printf("%s is waiting for job...\n", Thread.currentThread().getName());
                    try {
                        jobList.wait();
                    } catch (InterruptedException e) {
                        // Restore the flag; the check at the top of the loop then ends this worker.
                        Thread.currentThread().interrupt();
                    }
                }
            }
        }

        /** Runs one job, keeping the counters consistent and this worker alive if it throws. */
        private void runJob(Runnable job) {
            processingJobCnt.incrementAndGet();
            try {
                job.run();
                completedJobCnt.incrementAndGet();
            } catch (RuntimeException e) {
                // Report the failure the same way an uncaught exception would, then keep serving.
                Thread self = Thread.currentThread();
                self.getUncaughtExceptionHandler().uncaughtException(self, e);
            } finally {
                processingJobCnt.decrementAndGet();
            }
        }

        /** Removes this worker from all bookkeeping and interrupts its thread. Caller must hold jobList. */
        void retire() {
            retired = true;
            workerSet.remove(this);
            idleWorkers.remove(this);
            busyWorkers.remove(this);
            t.interrupt();
        }
    }
}
编写测试类
package experiment;
import java.util.Random;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Interactive console driver for {@link DefaultThreadPool}.
 *
 * <p>Reads one command per line from standard input until end of input:
 * <ul>
 *   <li>{@code 1}, or any line that is not an integer: submit a task</li>
 *   <li>{@code 2}: add one worker</li>
 *   <li>{@code 3}: remove one worker</li>
 *   <li>{@code 4}: print the pool status</li>
 *   <li>{@code 5}: shut the pool down</li>
 * </ul>
 * Other integers are ignored. When input ends, the pool is shut down and the loop exits.
 *
 * @Author yogurtzzz
 * @Date 2022/4/20 16:05
 **/
public class Test {
    private static final Scanner scanner = new Scanner(System.in);
    private static final AtomicInteger taskCnt = new AtomicInteger();
    private static final Random random = new Random();

    public static void main(String[] args) {
        ThreadPool threadPool = new DefaultThreadPool(20);
        // Each task sleeps for a random 0-99 seconds to simulate work.
        Runnable task = () -> {
            int i = random.nextInt(100);
            System.out.printf("task-%d is processing...will take %d seconds \n", taskCnt.getAndIncrement(), i);
            try {
                TimeUnit.SECONDS.sleep(i);
            } catch (InterruptedException e) {
                // Restore the flag so the worker running this task can observe the interrupt.
                Thread.currentThread().interrupt();
            }
        };
        // Stop at end of input: nextLine() would otherwise throw on every iteration forever.
        while (scanner.hasNextLine()) {
            switch (parseOp(scanner.nextLine())) {
                case 1:
                    threadPool.execute(task);
                    break;
                case 2:
                    threadPool.addWorker(1);
                    break;
                case 3:
                    threadPool.removeWorker(1);
                    break;
                case 4:
                    threadPool.printStatus();
                    break;
                case 5:
                    threadPool.shutdown();
                    break;
                default:
                    break;
            }
        }
        threadPool.shutdown();
    }

    /** Parses a command line; a blank or non-numeric line means "submit a task" (1). */
    private static int parseOp(String line) {
        try {
            return Integer.parseInt(line.trim());
        } catch (NumberFormatException ignored) {
            return 1;
        }
    }
}
效果:
手写阻塞队列
手写一版简单的 ArrayBlockingQueue,底层用普通数组模拟一个循环队列,阻塞机制仍然借助Object的wait/notify
package experiment;
/**
 * A simple bounded blocking FIFO queue backed by a circular array.
 *
 * <p>All state is guarded by the single monitor {@code lock}. Producers and consumers both wait on
 * it, and every state change calls {@code notifyAll}, so a waiting thread can never miss a wakeup.
 *
 * <p>{@code put} and {@code get} block uninterruptibly. If the caller is interrupted while waiting,
 * it keeps waiting, and its interrupt status is restored before the method returns.
 *
 * @Author yogurtzzz
 * @Date 2022/4/20 16:58
 **/
public class ArrayBlockingQueue<T> {
    private final Object lock = new Object();
    private final Object[] elements;
    // Index of the oldest element, i.e. the next one get() returns. Guarded by lock.
    private int head;
    // Index where the next put() stores its element. Guarded by lock.
    private int tail;
    // Number of stored elements. Guarded by lock.
    private int count;

    /**
     * Creates an empty queue.
     *
     * @param size capacity of the queue
     * @throws IllegalArgumentException if {@code size <= 0}
     */
    public ArrayBlockingQueue(int size) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive: " + size);
        }
        elements = new Object[size];
    }

    /** Appends {@code e} at the tail, waiting while the queue is full. */
    public void put(T e) {
        boolean interrupted = false;
        synchronized (lock) {
            while (isFull()) {
                interrupted |= awaitQuietly();
            }
            elements[tail] = e;
            tail = (tail + 1) % elements.length;
            count++;
            lock.notifyAll();
        }
        if (interrupted) {
            Thread.currentThread().interrupt(); // propagate the interrupt to the caller
        }
    }

    /** Removes and returns the head element, waiting while the queue is empty. */
    public T get() {
        boolean interrupted = false;
        T res;
        synchronized (lock) {
            while (isEmpty()) {
                interrupted |= awaitQuietly();
            }
            @SuppressWarnings("unchecked") // only put(T) stores into elements
            T item = (T) elements[head];
            res = item;
            elements[head] = null; // drop the reference so it can be garbage-collected
            head = (head + 1) % elements.length;
            count--;
            lock.notifyAll();
        }
        if (interrupted) {
            Thread.currentThread().interrupt(); // propagate the interrupt to the caller
        }
        return res;
    }

    /** Returns the number of elements currently in the queue. */
    public int size() {
        synchronized (lock) {
            return count;
        }
    }

    /**
     * Waits once on {@code lock}; caller must hold it. Deliberately does not re-interrupt here,
     * because {@code wait()} would throw again immediately and the caller's loop would busy-spin.
     *
     * @return true if the wait was ended by an interrupt
     */
    private boolean awaitQuietly() {
        try {
            lock.wait();
            return false;
        } catch (InterruptedException e) {
            return true;
        }
    }

    // Caller must hold lock.
    private boolean isFull() {
        return count == elements.length;
    }

    // Caller must hold lock.
    private boolean isEmpty() {
        return count == 0;
    }
}