import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
 * A bounded FIFO blocking queue backed by a singly linked list.
 *
 * <p>It uses two locks: producers ({@code put}/{@code offer}) serialize on {@code putLock}
 * and only touch the tail ({@code last}), while consumers ({@code take}/{@code poll})
 * serialize on {@code takeLock} and only touch the head ({@code head}). The element count
 * is an {@link AtomicInteger} shared by both sides, so one producer and one consumer can
 * run concurrently. Null elements are rejected.
 */
public class MyBlockingQueue {
    /** Lock held by consumers (take/poll). */
    private final Lock takeLock = new ReentrantLock();
    /** Consumers wait here until the queue is non-empty; bound to {@code takeLock}. */
    private final Condition notEmpty = takeLock.newCondition();
    /** Lock held by producers (put/offer). */
    private final Lock putLock = new ReentrantLock();
    /** Producers wait here until the queue has free space; bound to {@code putLock}. */
    private final Condition notFull = putLock.newCondition();
    /**
     * Current number of elements. Producers increment it after linking a node and consumers
     * read it before unlinking one, so the atomic update also publishes the new node to the
     * consumer side. Unlocked reads are used only as a fast-path hint.
     */
    private final AtomicInteger size;
    /** Dummy node whose {@code next} is the first real element; guarded by {@code takeLock}. */
    private Node head;
    /** Last node in the list; guarded by {@code putLock}. */
    private Node last;
    /** Maximum number of elements the queue may hold. */
    private final int capacity;

    /** Singly linked list node. */
    private static final class Node {
        Object item;
        Node next;

        Node(Object data) {
            this.item = data;
        }
    }

    /**
     * Creates an empty queue that holds at most {@code capacity} elements.
     *
     * @param capacity maximum number of elements
     * @throws IllegalArgumentException if {@code capacity} is not positive (such a queue
     *     could never accept an element)
     */
    public MyBlockingQueue(int capacity) {
        if (capacity <= 0) {
            throw new IllegalArgumentException("capacity must be positive: " + capacity);
        }
        this.capacity = capacity;
        this.size = new AtomicInteger(0);
        head = last = new Node(null);
    }

    /** Wakes one waiting consumer. Called by producers, which do not hold {@code takeLock}. */
    private void signalNotEmpty() {
        takeLock.lock();
        try {
            notEmpty.signal();
        } finally {
            takeLock.unlock();
        }
    }

    /** Wakes one waiting producer. Called by consumers, which do not hold {@code putLock}. */
    private void signalNotFull() {
        putLock.lock();
        try {
            notFull.signal();
        } finally {
            putLock.unlock();
        }
    }

    /** Links {@code o} at the tail. Caller must hold {@code putLock}. */
    private void enqueue(Object o) {
        last = last.next = new Node(o);
    }

    /**
     * Unlinks and returns the first element. Caller must hold {@code takeLock} and must have
     * observed {@code size > 0}.
     */
    private Object dequeue() {
        Node oldHead = head;
        Node first = oldHead.next;
        Object item = first.item;
        first.item = null;       // first becomes the new dummy head
        oldHead.next = oldHead;  // self-link the discarded node so it no longer references live nodes
        head = first;
        return item;
    }

    /**
     * Inserts {@code o} if there is room, without waiting.
     *
     * @param o element to insert
     * @return {@code true} if inserted, {@code false} if the queue was full
     * @throws NullPointerException if {@code o} is null
     */
    public boolean offer(Object o) {
        if (o == null) {
            throw new NullPointerException();
        }
        if (size.get() == capacity) {
            return false; // fast path: already full, skip locking
        }
        int count = -1; // size before insertion; stays -1 if nothing was inserted
        putLock.lock();
        try {
            // Re-check under the lock: another producer may have filled the queue meanwhile.
            if (size.get() < capacity) {
                enqueue(o);
                count = size.getAndIncrement();
                if (count + 1 < capacity) {
                    // Room remains: wake just one more producer instead of all of them.
                    notFull.signal();
                }
            }
        } finally {
            putLock.unlock();
        }
        if (count == 0) {
            // The queue went from empty to non-empty, so consumers may be waiting.
            signalNotEmpty();
        }
        return count >= 0;
    }

    /**
     * Inserts {@code o}, waiting up to the given time for space to become available.
     *
     * @param o element to insert
     * @param timeout maximum time to wait
     * @param unit unit of {@code timeout}
     * @return {@code true} if inserted, {@code false} if the wait timed out
     * @throws NullPointerException if {@code o} is null
     * @throws InterruptedException if interrupted while acquiring the lock or waiting
     */
    public boolean offer(Object o, long timeout, TimeUnit unit) throws InterruptedException {
        if (o == null) {
            throw new NullPointerException();
        }
        long nanos = unit.toNanos(timeout);
        int count;
        putLock.lockInterruptibly();
        try {
            while (size.get() == capacity) {
                // awaitNanos returns <= 0 once no time remains; 0 also means "timed out".
                if (nanos <= 0L) {
                    return false;
                }
                nanos = notFull.awaitNanos(nanos);
            }
            enqueue(o);
            count = size.getAndIncrement();
            if (count + 1 < capacity) {
                notFull.signal();
            }
        } finally {
            putLock.unlock();
        }
        if (count == 0) {
            signalNotEmpty();
        }
        return true;
    }

    /**
     * Inserts {@code o}, waiting as long as necessary for space to become available.
     *
     * @param o element to insert
     * @throws NullPointerException if {@code o} is null
     * @throws InterruptedException if interrupted while acquiring the lock or waiting
     */
    public void put(Object o) throws InterruptedException {
        if (o == null) {
            throw new NullPointerException();
        }
        int count;
        putLock.lockInterruptibly();
        try {
            while (size.get() == capacity) {
                notFull.await();
            }
            enqueue(o);
            count = size.getAndIncrement();
            if (count + 1 < capacity) {
                notFull.signal();
            }
        } finally {
            putLock.unlock();
        }
        if (count == 0) {
            signalNotEmpty();
        }
    }

    /**
     * Removes and returns the head element without waiting.
     *
     * @return the head element, or {@code null} if the queue is empty
     */
    public Object poll() {
        if (size.get() == 0) {
            return null; // fast path: empty, skip locking
        }
        int count = -1; // size before removal; stays -1 if nothing was removed
        Object item = null;
        takeLock.lock();
        try {
            // Re-check under the lock: another consumer may have drained the queue meanwhile.
            if (size.get() > 0) {
                item = dequeue();
                count = size.getAndDecrement();
                if (count > 1) {
                    notEmpty.signal(); // elements remain: wake one more consumer
                }
            }
        } finally {
            takeLock.unlock();
        }
        if (count == capacity) {
            // The queue went from full to not-full, so producers may be waiting.
            signalNotFull();
        }
        return item;
    }

    /**
     * Removes and returns the head element, waiting up to the given time for one to arrive.
     *
     * @param timeout maximum time to wait
     * @param unit unit of {@code timeout}
     * @return the head element, or {@code null} if the wait timed out
     * @throws InterruptedException if interrupted while acquiring the lock or waiting
     */
    public Object poll(long timeout, TimeUnit unit) throws InterruptedException {
        long nanos = unit.toNanos(timeout);
        Object item;
        int count;
        takeLock.lockInterruptibly();
        try {
            while (size.get() == 0) {
                // awaitNanos returns <= 0 once no time remains; 0 also means "timed out".
                if (nanos <= 0L) {
                    return null;
                }
                nanos = notEmpty.awaitNanos(nanos);
            }
            item = dequeue();
            count = size.getAndDecrement();
            if (count > 1) {
                notEmpty.signal();
            }
        } finally {
            takeLock.unlock();
        }
        if (count == capacity) {
            signalNotFull();
        }
        return item;
    }

    /**
     * Removes and returns the head element, waiting as long as necessary for one to arrive.
     *
     * @return the head element (never null)
     * @throws InterruptedException if interrupted while acquiring the lock or waiting
     */
    public Object take() throws InterruptedException {
        Object item;
        int count;
        takeLock.lockInterruptibly();
        try {
            while (size.get() == 0) {
                notEmpty.await();
            }
            item = dequeue();
            count = size.getAndDecrement();
            if (count > 1) {
                notEmpty.signal();
            }
        } finally {
            takeLock.unlock();
        }
        if (count == capacity) {
            signalNotFull();
        }
        return item;
    }
}
在生产者-消费者问题中的测试:
生产者(Producer)代码:
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Producer task: until {@link #stop()} is called, repeatedly pauses for a random interval,
 * creates a uniquely numbered string item and offers it to the shared queue, waiting at
 * most two seconds for space.
 */
public class Producer implements Runnable {
    /** Exclusive upper bound, in milliseconds, of the random pause before each item. */
    private static final int DEFAULT_RANGE_FOR_SLEEP = 1000;
    /** Sequence counter shared by every Producer instance, so item numbers never repeat. */
    private static final AtomicInteger count = new AtomicInteger();

    /** Cleared by {@link #stop()}; volatile so the worker thread sees the change. */
    private volatile boolean isRunning = true;
    private final MyBlockingQueue queue;

    public Producer(MyBlockingQueue queue) {
        this.queue = queue;
    }

    @Override
    public void run() {
        final Random random = new Random();
        System.out.println("启动生产者线程!");
        try {
            while (isRunning) {
                System.out.println("正在生产数据...");
                Thread.sleep(random.nextInt(DEFAULT_RANGE_FOR_SLEEP));
                final String item = "data:" + count.incrementAndGet();
                System.out.println("将数据:" + item + "放入队列...");
                final boolean accepted = queue.offer(item, 2, TimeUnit.SECONDS);
                if (!accepted) {
                    System.out.println("放入数据失败:" + item);
                }
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
            // Restore the interrupt flag for the executing thread's owner.
            Thread.currentThread().interrupt();
        } finally {
            System.out.println("退出生产者线程!");
        }
    }

    /** Asks the production loop to finish after its current iteration. */
    public void stop() {
        isRunning = false;
    }
}
消费者(Consumer)代码:
import java.util.Random;
import java.util.concurrent.TimeUnit;
/**
 * Consumer task: repeatedly takes items from the shared queue and "consumes" each one by
 * pausing for a random interval. It exits once no item arrives within two seconds.
 */
public class Consumer implements Runnable {
    /** Exclusive upper bound, in milliseconds, of the random pause per consumed item. */
    private static final int DEFAULT_RANGE_FOR_SLEEP = 1000;

    private final MyBlockingQueue queue;

    public Consumer(MyBlockingQueue queue) {
        this.queue = queue;
    }

    @Override
    public void run() {
        System.out.println("启动消费者线程!");
        final Random random = new Random();
        try {
            for (;;) {
                System.out.println("正从队列获取数据...");
                final Object item = queue.poll(2, TimeUnit.SECONDS);
                if (item == null) {
                    // No data for over 2 s: assume all producers have exited and stop consuming.
                    break;
                }
                System.out.println("拿到数据:" + item);
                System.out.println("正在消费数据:" + item);
                Thread.sleep(random.nextInt(DEFAULT_RANGE_FOR_SLEEP));
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
            // Restore the interrupt flag for the executing thread's owner.
            Thread.currentThread().interrupt();
        } finally {
            System.out.println("退出消费者线程!");
        }
    }
}
测试主程序:
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
 * Demo driver: runs three producers and one consumer against a queue of capacity 10 for
 * about ten seconds, stops the producers, waits two more seconds so the consumer can drain
 * the queue and time out, then shuts the executor down.
 */
public class MyBlockingQueueTest {
    public static void main(String[] args) throws InterruptedException {
        final MyBlockingQueue queue = new MyBlockingQueue(10);
        final Producer[] producers = {
            new Producer(queue), new Producer(queue), new Producer(queue)
        };
        final Consumer consumer = new Consumer(queue);

        // Run every task on a cached thread pool: producers first, then the consumer.
        final ExecutorService service = Executors.newCachedThreadPool();
        for (Producer producer : producers) {
            service.execute(producer);
        }
        service.execute(consumer);

        // Let the system run for 10 s, then ask each producer to stop.
        Thread.sleep(10 * 1000);
        for (Producer producer : producers) {
            producer.stop();
        }

        // Grace period for the consumer, then stop accepting new tasks.
        Thread.sleep(2000);
        service.shutdown();
    }
}
注:生产者-消费者示例程序来自:http://blog.itpub.net/143526/viewspace-1060365/