package com.test.juc;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * A minimal thread pool backed by a bounded task queue.
 *
 * <p>Up to {@code coreSize} worker threads are started on demand, each seeded with the task
 * that triggered its creation. While that many workers are registered, every further task is
 * handed to the {@link RejectPolicy}, which decides how (or whether) to place it on the queue.
 * A worker keeps pulling tasks from the queue and exits once {@code take(timeout, timeUnit)}
 * returns {@code null}.
 *
 * <p>NOTE(review): a task that is enqueued at the very moment the last worker gives up waiting
 * can sit in the queue until the next {@link #execute(Runnable)} call starts a new worker.
 */
public class MyThreadPool {
    private static final Logger LOGGER = Logger.getGlobal();

    // Task queue drained by the workers.
    private final WorkerQueue<Runnable> workerQueue;
    // Maximum number of workers that may be registered at the same time.
    private final int coreSize;
    // Currently registered workers.
    private final Set<Worker> workers = new CopyOnWriteArraySet<>();
    // Makes the "fewer than coreSize workers? then register one" decision atomic.
    private final Object workerLock = new Object();
    // How long an idle worker waits for a queued task before exiting.
    private final long timeout;
    private final TimeUnit timeUnit;
    // Invoked for tasks submitted while coreSize workers are already registered.
    private final RejectPolicy<Runnable> rejectPolicy;

    /**
     * Creates the pool.
     *
     * @param coreSize maximum number of worker threads
     * @param timeout  how long an idle worker waits for a task before exiting
     * @param timeUnit unit of {@code timeout}
     * @param capacity capacity of the task queue
     * @param policy   policy applied to tasks submitted while all workers are registered
     */
    public MyThreadPool(int coreSize, long timeout, TimeUnit timeUnit, int capacity, RejectPolicy<Runnable> policy) {
        this.coreSize = coreSize;
        this.timeout = timeout;
        this.timeUnit = timeUnit;
        this.workerQueue = new BlockingQueue<>(capacity);
        this.rejectPolicy = policy;
    }

    /**
     * Submits a task to the pool.
     *
     * <p>If fewer than {@code coreSize} workers are registered, a new worker thread is started
     * with {@code task} as its first job; otherwise the task is passed to the reject policy.
     *
     * @param task the task to run
     * @throws NullPointerException if {@code task} is {@code null}
     */
    public void execute(Runnable task) {
        if (task == null) {
            throw new NullPointerException();
        }
        Worker worker = null;
        // Check-then-add must be atomic, otherwise concurrent callers can exceed coreSize.
        synchronized (workerLock) {
            if (workers.size() < coreSize) {
                worker = new Worker(task);
                // Register BEFORE starting the thread: a fast worker could otherwise finish and
                // remove itself before it was added, leaving a stale entry in the set forever.
                workers.add(worker);
            }
        }
        if (worker == null) {
            // Called outside the lock because the policy may block (e.g. WorkerQueue::put).
            rejectPolicy.reject(workerQueue, task);
            return;
        }
        LOGGER.log(Level.INFO, "新增worker\t{0}", worker);
        try {
            new Thread(worker).start();
        } catch (RuntimeException | Error e) {
            // The thread never ran, so the worker cannot deregister itself.
            workers.remove(worker);
            throw e;
        }
    }

    /**
     * Worker thread body: runs its initial task, then consumes the task queue until the timed
     * {@code take} returns {@code null}.
     */
    class Worker implements Runnable {
        private Runnable task;

        public Worker(Runnable task) {
            this.task = task;
        }

        @Override
        public void run() {
            try {
                while (task != null || (task = workerQueue.take(timeout, timeUnit)) != null) {
                    task.run();
                    task = null;
                }
            } finally {
                // Always deregister, even when a task throws; otherwise the dead worker would
                // keep occupying one of the coreSize slots.
                LOGGER.log(Level.INFO, "移除work\t{0}", this);
                workers.remove(this);
            }
        }
    }
}
/**
 * Work queue abstraction; {@code MyThreadPool} holds its pending tasks in one of these.
 * The notes on each method describe what the {@code BlockingQueue} implementation in this
 * file does.
 *
 * @param <T> type of the queued tasks
 */
interface WorkerQueue<T> {
/**
 * Adds {@code task} to the queue; {@code BlockingQueue} waits while the queue is full.
 *
 * @param task the task to enqueue
 */
void put(T task);
/**
 * Adds {@code task} to the queue, waiting at most the given time for room. In
 * {@code BlockingQueue}, if the queue is still full when the time runs out the task is not
 * enqueued and the caller gets no indication of that (the method returns {@code void}).
 *
 * @param task     the task to enqueue
 * @param timeout  maximum time to wait
 * @param timeUnit unit of {@code timeout}
 */
void put(T task, long timeout, TimeUnit timeUnit);
/**
 * Removes and returns the head of the queue; {@code BlockingQueue} waits while the queue is
 * empty.
 *
 * @return the dequeued task; {@code BlockingQueue} returns {@code null} if the wait fails
 *         with an exception
 */
T take();
/**
 * Removes and returns the head of the queue, waiting at most the given time for a task.
 *
 * @param timeout  maximum time to wait
 * @param timeUnit unit of {@code timeout}
 * @return the dequeued task; {@code BlockingQueue} returns {@code null} if the queue is still
 *         empty when the time runs out
 */
T take(long timeout, TimeUnit timeUnit);
}
/**
 * Strategy for tasks the pool cannot hand to a new worker.
 *
 * <p>{@code MyThreadPool.execute} calls {@link #reject} for every task submitted while the
 * pool already has {@code coreSize} workers. The implementation decides what happens to the
 * task; the demo in this file passes {@code WorkerQueue::put}, i.e. it enqueues the task.
 *
 * @param <T> type of the tasks
 */
@FunctionalInterface
interface RejectPolicy<T> {
/**
 * Handles a task that was not given its own worker.
 *
 * @param taskQueue the pool's task queue
 * @param task      the submitted task
 */
void reject(WorkerQueue<T> taskQueue, T task);
}
/**
 * Bounded FIFO task queue guarded by a single lock.
 *
 * <p>Producers wait on {@code notFull} and consumers wait on {@code notEmpty}. Using two
 * separate conditions guarantees that a signal always reaches a thread that can make
 * progress; with one shared condition a wakeup meant for a producer could be consumed by a
 * consumer (or vice versa) and be lost.
 *
 * <p>No method throws {@link InterruptedException}. When a waiting thread is interrupted the
 * operation is abandoned ({@code put} does not enqueue, {@code take} returns {@code null})
 * and the thread's interrupt flag is restored so the caller can observe it.
 *
 * @param <T> type of the queued tasks
 */
class BlockingQueue<T> implements WorkerQueue<T> {
    private static final Logger LOGGER = Logger.getGlobal();

    // Underlying storage; only accessed while holding the lock.
    private final Queue<T> queue;
    // Maximum number of queued tasks.
    private final int capacity;
    private final Lock lock = new ReentrantLock();
    // Signalled after a task was removed, i.e. a slot became free.
    private final Condition notFull = lock.newCondition();
    // Signalled after a task was added, i.e. there is something to take.
    private final Condition notEmpty = lock.newCondition();

    /**
     * @param capacity maximum number of tasks the queue holds
     */
    public BlockingQueue(int capacity) {
        this.capacity = capacity;
        this.queue = new ArrayDeque<>(capacity);
    }

    /**
     * Enqueues {@code task}, waiting while the queue is full. If the calling thread is
     * interrupted while waiting, the task is not enqueued and the interrupt flag is restored.
     *
     * @param task the task to enqueue
     */
    @Override
    public void put(T task) {
        lock.lock();
        try {
            while (queue.size() == capacity) {
                LOGGER.log(Level.WARNING, "等待入队\t{0}", task);
                notFull.await();
            }
            enqueue(task);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.log(Level.WARNING, "入队被中断\t{0}", task);
        } finally {
            lock.unlock();
        }
    }

    /**
     * Enqueues {@code task}, waiting at most the given time for a free slot. If the queue is
     * still full when the time runs out, or the thread is interrupted, the task is dropped
     * and a warning is logged.
     *
     * @param task     the task to enqueue
     * @param timeout  maximum time to wait
     * @param timeUnit unit of {@code timeout}
     */
    @Override
    public void put(T task, long timeout, TimeUnit timeUnit) {
        long nanos = timeUnit.toNanos(timeout);
        lock.lock();
        try {
            while (queue.size() == capacity) {
                if (nanos <= 0) {
                    LOGGER.log(Level.WARNING, "入队超时\t{0}", task);
                    return;
                }
                LOGGER.log(Level.WARNING, "等待入队\t{0}", task);
                // awaitNanos returns the remaining time, so spurious wakeups do not extend the wait.
                nanos = notFull.awaitNanos(nanos);
            }
            enqueue(task);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.log(Level.WARNING, "入队被中断\t{0}", task);
        } finally {
            lock.unlock();
        }
    }

    /**
     * Dequeues the head of the queue, waiting while the queue is empty.
     *
     * @return the dequeued task, or {@code null} if the thread was interrupted while waiting
     */
    @Override
    public T take() {
        lock.lock();
        try {
            while (queue.isEmpty()) {
                notEmpty.await();
            }
            return dequeue();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.log(Level.WARNING, "出队被中断");
            return null;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Dequeues the head of the queue, waiting at most the given time for a task.
     *
     * @param timeout  maximum time to wait
     * @param timeUnit unit of {@code timeout}
     * @return the dequeued task, or {@code null} if the queue stayed empty for the whole
     *         timeout or the thread was interrupted while waiting
     */
    @Override
    public T take(long timeout, TimeUnit timeUnit) {
        long nanos = timeUnit.toNanos(timeout);
        lock.lock();
        try {
            while (queue.isEmpty()) {
                if (nanos <= 0) {
                    return null;
                }
                nanos = notEmpty.awaitNanos(nanos);
            }
            return dequeue();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.log(Level.WARNING, "出队被中断");
            return null;
        } finally {
            lock.unlock();
        }
    }

    // Appends the task and wakes one waiting consumer; caller must hold the lock.
    private void enqueue(T task) {
        queue.offer(task);
        LOGGER.log(Level.INFO, "成功入队\t{0}", task);
        notEmpty.signal();
    }

    // Removes the head and wakes one waiting producer; caller must hold the lock.
    private T dequeue() {
        T head = queue.poll();
        LOGGER.log(Level.INFO, "成功出队\t{0}", head);
        notFull.signal();
        return head;
    }
}
package com.test.juc;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
/**
 * Small driver that submits five short, randomly sleeping tasks to a {@code MyThreadPool}.
 */
public class MyThreadPoolDemo {
    public static void main(String[] args) {
        // Pool with up to 3 workers, a 1000 ms idle timeout and a task queue of capacity 1;
        // tasks that do not get their own worker are enqueued via WorkerQueue.put.
        MyThreadPool threadPool =
                new MyThreadPool(3, 1_000, TimeUnit.MILLISECONDS, 1, WorkerQueue::put);

        // Submit the tasks one after another.
        for (int i = 0; i < 5; i++) {
            threadPool.execute(newTask(i));
        }
    }

    // Builds a task that prints the executing thread plus its 1-based number, then sleeps
    // for a random time below one second.
    private static Runnable newTask(int i) {
        return () -> {
            String threadName = Thread.currentThread().getName();
            System.out.printf("当前线程:%s\t%s\n", threadName, i + 1);
            try {
                int sleepMillis = new Random().nextInt(1_000);
                Thread.sleep(sleepMillis);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        };
    }

    // Disabled snippet that exercises the JDK ThreadPoolExecutor with DiscardOldestPolicy:
    // static void test() {
    // ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(1, 2,
    // 10_000, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(10),
    // Executors.defaultThreadFactory(), new ThreadPoolExecutor.DiscardOldestPolicy());
    // IntStream.range(0, 30).forEach(
    // (i) -> threadPoolExecutor.execute(() ->
    // System.out.printf("hello world %s %s\n", i + 1, Thread.currentThread().getName())
    // ));
    // threadPoolExecutor.shutdown();
    // }
}