代码:
package conSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
 * A minimal concurrent singly linked FIFO queue built on compare-and-set operations.
 *
 * <p>The list starts with a sentinel {@code header} node that never carries an element.
 * {@code tail} refers to the last node, or to {@code header} while the queue is empty.
 * No locks are taken, but a thread that observes a half-finished operation of another
 * thread retries in a loop until that operation has completed or has been rolled back.
 *
 * <p>{@code null} elements are rejected because {@link #poll()} and {@link #top()} use
 * {@code null} to report an empty queue.
 *
 * @author dingchd
 *
 * @param <T> type of the elements held in the queue
 */
public class NoBlockQueue<T> {
    /** Sentinel node; its {@code next} reference points at the first real element. */
    private final Node<T> header;
    /** Last node of the list, or {@code header} when the queue holds no element. */
    private final AtomicReference<Node<T>> tail;
    /** Element counter, updated after a node has been linked or unlinked. */
    private final AtomicInteger size;

    /** Creates an empty queue whose tail is the sentinel header node. */
    public NoBlockQueue() {
        header = new Node<T>();
        tail = new AtomicReference<Node<T>>(header);
        size = new AtomicInteger(0);
    }

    /**
     * Appends an element to the end of the queue.
     *
     * <p>Insertion takes two steps: atomically set the current tail's {@code next} to the
     * new node, then atomically swing {@code tail} to the new node. If the second step
     * fails, the first step is undone and the whole insertion is retried.
     *
     * @param t the element to append; must not be {@code null}
     * @throws NullPointerException if {@code t} is {@code null}
     */
    public void add(T t) {
        // null is the "queue is empty" signal of poll()/top(); storing it would also
        // break the size bookkeeping in poll(), so refuse it up front.
        if (t == null) {
            throw new NullPointerException("NoBlockQueue does not accept null elements");
        }
        Node<T> node = new Node<T>();
        node.value = t;
        for (;;) {
            Node<T> curTail = tail.get();
            if (curTail.next.get() == null) {
                if (casNext(curTail, null, node)) {
                    if (casTail(curTail, node)) {
                        size.incrementAndGet();
                        return;
                    } else {
                        // tail moved away from curTail: undo the link and retry
                        curTail.next.getAndSet(null);
                    }
                }
            }
        }
    }

    /**
     * Removes and returns the element at the front of the queue.
     *
     * <p>When the first element is also the tail, removal takes two steps: atomically clear
     * {@code header.next}, then atomically move {@code tail} back to {@code header}. If the
     * second step fails, the first step is undone and the removal is retried. Otherwise
     * {@code header.next} is atomically advanced from the first node to the second one.
     *
     * @return the removed element, or {@code null} if the queue is empty
     */
    public T poll() {
        Node<T> first = null;
        T value = null;
        for (;;) {
            first = header.next.get();
            Node<T> curTail = tail.get();
            // queue is empty
            if (curTail == header && first == null) {
                break;
            }
            // intermediate state left by an operation still in progress: retry
            if ((first != null && curTail == header)
                    || (first == null && curTail != header)) {
                continue;
            }
            if (first != null) {
                // tail is the first element: take it, then move tail back to header
                if (curTail == first) {
                    if (casHeaderNext(first, null)) {
                        if (casTail(curTail, header)) {
                            value = first.value;
                            break;
                        } else {
                            // tail no longer equals first: restore the head link and retry
                            header.next.getAndSet(first);
                        }
                    }
                } else {
                    Node<T> second = first.next.get();
                    // a null second means the first node read above was already taken
                    // by another thread
                    if (second != null) {
                        if (casHeaderNext(first, second)) {
                            value = first.value;
                            break;
                        }
                    }
                }
            }
        }
        // elements are never null, so a non-null value means a node was removed
        if (value != null) {
            size.decrementAndGet();
        }
        return value;
    }

    /**
     * Reports whether the queue is empty, judged by {@code tail} being the sentinel header.
     *
     * @return {@code true} if {@code tail} currently refers to the header node
     */
    public boolean isEmpty() {
        return tail.get() == header;
    }

    /**
     * Returns the element at the front of the queue without removing it.
     *
     * @return the first element, or {@code null} if {@code header.next} is unset
     */
    public T top() {
        Node<T> first = header.next.get();
        return first == null ? null : first.value;
    }

    /**
     * Returns the current value of the element counter. The counter is adjusted after a
     * node has been linked or unlinked, so under concurrent use it may briefly differ from
     * the number of linked elements.
     *
     * @return the element counter
     */
    public int size() {
        return size.get();
    }

    /** Atomically replaces {@code header.next} with {@code after} if it equals {@code before}. */
    private final boolean casHeaderNext(Node<T> before, Node<T> after) {
        return header.next.compareAndSet(before, after);
    }

    /** Atomically replaces {@code tail} with {@code after} if it equals {@code before}. */
    private final boolean casTail(Node<T> before, Node<T> after) {
        return tail.compareAndSet(before, after);
    }

    /** Atomically replaces {@code node.next} with {@code after} if it equals {@code before}. */
    private final boolean casNext(Node<T> node, Node<T> before, Node<T> after) {
        return node.next.compareAndSet(before, after);
    }

    /** Singly linked list node: an element plus an atomic reference to its successor. */
    static class Node<T> {
        T value;
        final AtomicReference<Node<T>> next = new AtomicReference<Node<T>>();
    }
}
测试代码:
package conSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.ConcurrentLinkedQueue;
/**
 * Multi-threaded stress test for {@link NoBlockQueue}.
 *
 * <p>Each round starts {@link #C_NUM} producer threads and {@link #C_NUM} consumer threads
 * on one shared queue. Producers record every value they add in {@code input}; consumers
 * record every value they poll in {@code output}. After the consumers finish, both records
 * are sorted and compared element by element.
 */
public class NoBlockQueueTest2 {
    /** Number of elements each producer thread adds. */
    public static int SIZE = 10000;
    /** Number of producer threads, and also the number of consumer threads. */
    public static int C_NUM = 10;

    /**
     * Runs the stress round 10000 times.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        for (int i = 0; i < 10000; i++) {
            test();
        }
    }

    /**
     * Executes one producer/consumer round and verifies that the values polled from the
     * queue are exactly the values that were added.
     *
     * @throws RuntimeException if the recorded inputs and outputs differ in count or content
     * @throws IllegalStateException if the calling thread is interrupted while waiting for
     *     the consumer threads
     */
    public static void test() {
        NoBlockQueue<String> queue = new NoBlockQueue<String>();
        Queue<String> input = new ConcurrentLinkedQueue<String>();
        Queue<String> output = new ConcurrentLinkedQueue<String>();

        for (int i = 0; i < C_NUM; i++) {
            new Thread(new MP(queue, input)).start();
        }

        List<Thread> consumers = new ArrayList<Thread>();
        for (int i = 0; i < C_NUM; i++) {
            Thread t = new Thread(new MC(queue, output));
            t.start();
            consumers.add(t);
        }

        for (Thread t : consumers) {
            try {
                t.join();
            } catch (InterruptedException e) {
                // Restore the flag and stop: verifying while consumers are still running
                // would report a misleading mismatch.
                Thread.currentThread().interrupt();
                throw new IllegalStateException(
                        "interrupted while waiting for consumer threads", e);
            }
        }

        List<String> sort1 = new ArrayList<String>();
        List<String> sort2 = new ArrayList<String>();
        while (!input.isEmpty()) {
            sort1.add(input.poll());
        }
        while (!output.isEmpty()) {
            sort2.add(output.poll());
        }
        Collections.sort(sort1);
        Collections.sort(sort2);

        if (sort1.size() != sort2.size()) {
            throw new RuntimeException("test error,size not equal");
        }
        for (int i = 0; i < sort1.size(); i++) {
            String left = sort1.get(i);
            String right = sort2.get(i);
            if (!left.equals(right)) {
                throw new RuntimeException("test error,data wrong");
            }
        }
        System.out.println("test ok size=" + queue.size());
    }

    /** Producer: adds {@link #SIZE} random UUID strings to the queue, recording each one. */
    static class MP implements Runnable {
        final NoBlockQueue<String> queue;
        final Queue<String> input;

        public MP(NoBlockQueue<String> queue, Queue<String> input) {
            super();
            this.queue = queue;
            this.input = input;
        }

        @Override
        public void run() {
            for (int i = 0; i < NoBlockQueueTest2.SIZE; i++) {
                String s = UUID.randomUUID().toString();
                // record before publishing, so every polled value is already in input
                input.add(s);
                queue.add(s);
            }
        }
    }

    /**
     * Consumer: polls the queue in a busy loop and records every value obtained, stopping
     * once the shared output holds {@code C_NUM * SIZE} values.
     */
    static class MC implements Runnable {
        final NoBlockQueue<String> queue;
        final Queue<String> output;

        public MC(NoBlockQueue<String> queue, Queue<String> output) {
            super();
            this.queue = queue;
            this.output = output;
        }

        @Override
        public void run() {
            final int count = NoBlockQueueTest2.C_NUM * NoBlockQueueTest2.SIZE;
            for (;;) {
                String s = queue.poll();
                if (s != null) {
                    output.add(s);
                } else if (output.size() == count) {
                    // Only checked when the queue looked empty: size() on a
                    // ConcurrentLinkedQueue traverses the whole queue.
                    break;
                }
            }
        }
    }
}
由于尚未实现 remove 和迭代器(iterator)功能,该实现的复杂度较低;经过 10000 轮反复测试,尚未发现测试失败。