建议学习前，先做好如下准备：
1、阅读 Unsafe 源码（重点关注 compareAndSwapInt / compareAndSwapObject 以及 park / unpark）
2、理解 CAS（Compare-And-Swap）原理
代码
import sun.misc.Unsafe;
import java.lang.reflect.Field;
/**
 * @Author 李雷(KyLin)
 * @Desc A minimal, non-reentrant mutual-exclusion lock built directly on
 *       {@code sun.misc.Unsafe} CAS and park/unpark (a teaching version of the AQS idea).
 *       <p>{@code state} is 0 when the lock is free and 1 when it is held. A thread that
 *       loses the fast-path CAS in {@link #lock()} appends a node to a linked wait queue
 *       and only competes for the lock once its predecessor is {@code head};
 *       {@link #unlock()} wakes the successor of {@code head}. Newly arriving threads may
 *       still win the fast-path CAS ahead of queued ones, so acquisition is not strictly fair.
 *       <p>There is no owner tracking: callers must only call {@link #unlock()} from the
 *       thread holding the lock. Calling {@link #lock()} again while already holding it
 *       blocks forever, because the CAS only succeeds when {@code state} is 0.
 * @Date 2019/12/31
 */
public class Sync {
    /**
     * Formerly the dummy head shared by every instance. Sharing it let separate Sync
     * instances overwrite each other's queue links, so each instance now allocates its own
     * dummy node. Retained only so existing references keep compiling; it is never linked
     * into a queue.
     */
    static final Node EMPTY = new Node();

    /** 0 = free, 1 = held. Acquired via CAS, released with a plain volatile write. */
    private volatile int state;
    /** Front of the wait queue (a dummy node); after construction only the lock holder writes it. */
    private volatile Node head;
    /** Last node of the wait queue; appended to with CAS in {@link #enqueue()}. */
    private volatile Node tail;

    static Unsafe unsafe;
    private static final long stateOffset;
    private static final long tailOffset;

    /** Creates an unlocked instance with its own empty wait queue. */
    public Sync() {
        // Every instance needs a private dummy node: a shared sentinel would let one lock's
        // enqueue()/lock() rewrite the next link that another lock's unlock() follows.
        Node dummy = new Node();
        head = dummy;
        tail = dummy;
    }

    // Look up the Unsafe singleton and the field offsets used by the CAS helpers.
    static {
        try {
            Field field = Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            unsafe = (Unsafe) field.get(null);
            stateOffset = unsafe.objectFieldOffset(Sync.class.getDeclaredField("state"));
            tailOffset = unsafe.objectFieldOffset(Sync.class.getDeclaredField("tail"));
        } catch (ReflectiveOperationException e) {
            // Without Unsafe and the offsets, every operation would NPE later; fail class
            // initialization loudly instead of printing the trace and carrying on.
            throw new ExceptionInInitializerError(e);
        }
    }

    /**
     * Atomically sets {@code state} to {@code update} if it currently equals {@code expect}.
     *
     * @param expect the expected current state
     * @param update the new state
     * @return {@code true} if the swap happened
     */
    public boolean compareAndSetState(int expect, int update) {
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }

    /** Atomically replaces {@code tail} with {@code update} if it is still {@code expect}. */
    private boolean compareAndSetTail(Node expect, Node update) {
        return unsafe.compareAndSwapObject(this, tailOffset, expect, update);
    }

    /**
     * Acquires the lock, blocking until it is available.
     *
     * <p>First tries one CAS on {@code state}. On failure the thread enqueues itself and
     * parks until its node sits directly behind {@code head} and the CAS succeeds. If the
     * thread is interrupted while waiting, it keeps waiting and its interrupt status is
     * re-asserted once the lock is acquired.
     */
    public void lock() {
        if (compareAndSetState(0, 1)) {
            return; // fast path: the lock was free, no queueing needed
        }
        Node node = enqueue();
        Node prev = node.prev;
        boolean interrupted = false;
        // Only the node right behind head may compete, which keeps queued threads in FIFO order.
        while (prev != head || !compareAndSetState(0, 1)) {
            unsafe.park(false, 0L);
            // park returns immediately while the interrupt flag is set. Clear it so the next
            // park really blocks instead of spinning, and remember to restore it.
            if (Thread.interrupted()) {
                interrupted = true;
            }
        }
        // We own the lock: our node becomes the new dummy head and the old head is unlinked.
        head = node;
        node.thread = null;
        node.prev = null;
        prev.next = null;
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Appends a node for the current thread to the tail of the wait queue.
     *
     * @return the newly appended node, whose {@code prev} is the former tail
     */
    private Node enqueue() {
        Node node = new Node(Thread.currentThread(), null);
        while (true) {
            Node t = tail;
            node.prev = t;
            if (compareAndSetTail(t, node)) {
                // Link forward only after winning the CAS; unlock() follows this link.
                t.next = node;
                return node;
            }
        }
    }

    /**
     * Releases the lock and unparks the first queued waiter, if any.
     * Must only be called by the thread currently holding the lock.
     */
    public void unlock() {
        state = 0;
        Node next = head.next;
        if (next != null) {
            // A node clears its thread once it acquires the lock, which can race with this read.
            Thread waiter = next.thread;
            if (waiter != null) {
                unsafe.unpark(waiter);
            }
        }
    }

    /**
     * Wait-queue node. All fields are volatile because they are written by one thread and
     * read by others (enqueue/lock/unlock run on different threads).
     */
    private static class Node {
        volatile Thread thread;
        volatile Node prev;
        volatile Node next;

        public Node() {
        }

        public Node(Thread thread, Node prev) {
            this.thread = thread;
            this.prev = prev;
        }
    }
}
测试类
import java.util.concurrent.CountDownLatch;
import java.util.stream.IntStream;
/**
 * @Author 李雷(KyLin)
 * @Desc Smoke test for Sync: 100 threads each add 1000 to a shared counter while holding
 *       the lock, then the final count is printed once every thread has counted down.
 * @Date 2019/12/31
 */
public class SyncTest {
    public static int count = 0;

    public static void main(String[] args) throws InterruptedException {
        final int threadCount = 100;
        final int incrementsPerThread = 1000;
        Sync lock = new Sync();
        CountDownLatch finished = new CountDownLatch(threadCount);

        for (int t = 0; t < threadCount; t++) {
            Runnable task = () -> {
                lock.lock();
                try {
                    // Unsynchronized increment; correctness relies entirely on the lock.
                    for (int k = 0; k < incrementsPerThread; k++) {
                        count++;
                    }
                } finally {
                    lock.unlock();
                }
                finished.countDown();
            };
            new Thread(task, "pt" + t).start();
        }

        // The latch also publishes every thread's writes to count before we read it.
        finished.await();
        System.out.println(count);
    }
}