ReentrantLock实现原理-如何实现一把锁(一)


一、使用CAS实现一把锁

锁作用可以抽象理解为避免共享资源被并发访问。按照这条概念我们在JAVA中可以定义一下实现。
1. 定义一个锁变量state。
2. 当多个线程同时范围同一个共享资源时,我们通过cas保证只有一个线程修改这个锁变量state成功,即获得锁。其他没有获得锁的线程,不断自旋尝试获得锁。
3. 当使用完共享资源时,还原state的值,让其他线程获得锁。

定义锁接口

public interface Lock {
    void lock();
    void unlock();
}

按照上面原则具体实现如下:

public class SpinLock implements Lock {
    AtomicInteger state = new AtomicInteger();
    @Override
    public void  lock() {
        boolean flag;
        do {
        //自旋
            flag = this.state.compareAndSet(0, 1);
        }
        while (!flag);
    }
    @Override
    public void unlock() {
        state.compareAndSet(1,0);
    }
}

测试

public class Main {

    static int value = 0;
    public static void main(String[] args) throws InterruptedException {
        SpinLock spinLock = new SpinLock();
        final CyclicBarrier cyclicBarrier = new CyclicBarrier(10);
        for (int i = 0; i < 10; i++) {
            new Thread(new Runnable() {
                public void run() {
                    try {
                        cyclicBarrier.await();
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    spinLock.lock();
                    for (int j = 0; j < 100; j++) {
                        value++;
                    }
                    spinLock.unlock();

                }
            }).start();
        }
        TimeUnit.SECONDS.sleep(3);
        System.out.println("value: " + value);
    }
}

结果

value: 1000

二、实现可重入

当我们判断是同一个线程再次获得锁时,把state自增1。代表获得锁的次数,即可实现可重入。

为了后面讲解ReentrantLock方便,我们重构代码。定义CustomAbstractQueuedSynchronizer抽象类并继承AbstractOwnableSynchronizer。AbstractOwnableSynchronizer是JDK提供的抽象类,用于设置和获取当前获得锁的线程。为了使用state方便,改用unsafe对state进行操作。

public abstract class AbstractOwnableSynchronizer
    implements java.io.Serializable {

    private static final long serialVersionUID = 3737899427754241961L;

 
    protected AbstractOwnableSynchronizer() { }


    private transient Thread exclusiveOwnerThread;


    protected final void setExclusiveOwnerThread(Thread thread) {
        exclusiveOwnerThread = thread;
    }

  
    protected final Thread getExclusiveOwnerThread() {
        return exclusiveOwnerThread;
    }
}
public abstract class CustomAbstractQueuedSynchronizer extends AbstractOwnableSynchronizer {
    /**
     * The synchronization state.
     */
    private volatile int state;
    private static final long stateOffset;

    static {
        try {
            Field field =
                    Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            unsafe = (Unsafe) field.get(null);

            stateOffset = unsafe.objectFieldOffset
                    (CustomAbstractQueuedSynchronizer.class.getDeclaredField("state"));
        } catch (Exception ex) { throw new Error(ex); }
    }

    protected final int getState() {
        return state;
    }

    protected final void setState(int newState) {
        state = newState;
    }

    protected final boolean compareAndSetState(int expect, int update) {
        // See below for intrinsics setup to support this
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }

}

重入锁的实现如下:
实现逻辑很简单,当有线程获得锁时调用setExclusiveOwnerThread方法设置当前获得锁的线程。当cas获得锁失败,判断是否是同一个线程再次获得锁,如果是则state加1。释放锁的时state减1。如果state为0,清空当前获得锁的线程。

public class SpinReentrantLock implements Lock {


    private Sync sync;

    public SpinReentrantLock() {
        sync = new SimpleNonfairSync();
    }

 abstract static class Sync extends CustomAbstractQueuedSynchronizer {
        protected abstract void lock();

        protected abstract void unlock();
    }
  static final class SimpleNonfairSync extends Sync {
        @Override
        protected void lock() {
            boolean flag;
            do {
                Thread current = Thread.currentThread();
                if (flag = compareAndSetState(0, 1)) {
                    //System.out.println(current.getName() + " 获得锁");
                    setExclusiveOwnerThread(current);

                } else if (current == getExclusiveOwnerThread()) {
                    int c = getState();
                    int nextc = c + 1;
                    if (nextc < 0) {
                        // overflow
                        throw new Error("Maximum lock count exceeded");
                    }
                    //System.out.println(current.getName() + " 重入state:" + nextc);
                    setState(nextc);
                    flag = true;

                }
            }
            while (!flag);

        }

        @Override
        protected void unlock() {
            int c = getState() - 1;
            if (Thread.currentThread() != getExclusiveOwnerThread())
                throw new IllegalMonitorStateException();
            if (c == 0) {
                setExclusiveOwnerThread(null);
            }
           // System.out.println(Thread.currentThread().getName() + " 释放锁state: " + c);
            setState(c);
        }

    }

    @Override
    public void lock() {
        sync.lock();
    }

    @Override
    public void unlock() {
        sync.unlock();
    }
 }

三、队列

当并发比较高的时候大量的CAS失败可能导致SpinReentrantLock锁的效率比较低,且自旋比较消耗CUP。所以当线程获取锁失败,我们把线程放入队列中并挂起。当线程释放锁时唤起挂起的线程。
image.png
在抽象类CustomAbstractQueuedSynchronizer中加入一个线程安全的链表threadQueue用于存放被挂起的线程。head变量的作用是记录队列的头结点。acquire方法使用的是模板设计模式,tryAcquire获得锁的逻辑,交由子类实现,当线程获得锁失败,调用LockSupport.park(this)挂起线程,如果获得锁成功线程出队,并更新head。完整代码如下

public abstract class CustomAbstractQueuedSynchronizer extends AbstractOwnableSynchronizer {
    /**
     * The synchronization state.
     */
    private volatile int state;

    private static final Unsafe unsafe;
    private static final long stateOffset;

    private transient volatile Thread head;

    protected Queue<Thread> threadQueue = new ConcurrentLinkedQueue<>();

    static {
        try {
            Field field =
                    Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            unsafe = (Unsafe) field.get(null);

            stateOffset = unsafe.objectFieldOffset
                    (CustomAbstractQueuedSynchronizer.class.getDeclaredField("state"));
        } catch (Exception ex) {
            throw new Error(ex);
        }
    }

    protected final int getState() {
        return state;
    }

    protected final void setState(int newState) {
        state = newState;
    }


    protected final boolean compareAndSetState(int expect, int update) {
        // See below for intrinsics setup to support this
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }

    public Thread getHead() {
        return head;
    }

    public void setHead(Thread head) {
        this.head = head;
    }

    /**
     * 获取锁的逻辑,交由子类实现
     * @param arg
     * @return
     */
    protected boolean tryAcquire(int arg) {
        throw new UnsupportedOperationException();
    }
    /**
     * 判断队列中是否为空
     * @return
     */
    public final boolean hasQueuedPredecessors() {
        return threadQueue.isEmpty();
    }

/**
     * 释放锁的逻辑,交由子类实现
     * @param arg
     * @return
     */
    protected boolean tryRelease(int arg) {
        throw new UnsupportedOperationException();
    }

    /**
     * 获得锁和线程入队,以及唤醒后的逻辑
     * @param arg
     */
    public final void acquire(int arg) {
        Thread current = Thread.currentThread();
        //调用tryAcquire获得锁失败,线程放入队列中
        if (!tryAcquire(arg) && threadQueue.add(current)) {
            if (getHead() == null) {
                setHead(threadQueue.peek());
            }
            //只要获得锁成功才能跳出循环
            for (; ; ) {
                if (current == getHead() && tryAcquire(arg)) {
                //任务出队
                    threadQueue.poll();
                    //头部元素出队之后,更新头元素
                    setHead(threadQueue.peek());
                    return;
                }
               // System.out.println("挂起线程: " + current.getName() + " size: " + Arrays.toString(threadQueue.toArray()));
               //获得锁失败,挂起线程
                LockSupport.park(this);
            }
        }
    }
}

Sync的unlock方法逻辑如下
1. 重写tryRelease方法,当sate等于0的时候返回true表示释放锁成功。
2. 如果释放锁成功,则调用threadQueue.peek()方法获得头结点,并通过LockSupport.unpark(poll)唤起线程。

 abstract static class Sync extends CustomAbstractQueuedSynchronizer {
        protected abstract void lock();

        protected void unlock() {
         if (tryRelease(1)){
             Thread poll = threadQueue.peek();
             if (poll != null) {
                 //System.out.println(Thread.currentThread().getName() + " 唤起线程: " + poll.getName() + " size: "+threadQueue.size());
                 LockSupport.unpark(poll);
             } else {
                 setHead(null);
             }
         }
        }

        @Override
        protected boolean tryRelease(int arg) {
            int c = getState() - 1;
            if (Thread.currentThread() != getExclusiveOwnerThread()){
                throw new IllegalMonitorStateException();
            }
            boolean free = false;
            if (c == 0) {
                free=true;
                setExclusiveOwnerThread(null);
            }
            // System.out.println(Thread.currentThread().getName() + " 释放锁state: " + c);
            setState(c);
            return free;
        }
    }

NonfairSync类方法如下。
正如上面提到acquire使用的是模板设计模式,获得锁的逻辑由tryAcquire实现。(tryAcquire的实现是一种非公平的模式)

 static final class NonfairSync extends Sync {

        @Override
        protected void lock() {
            Thread current = Thread.currentThread();
            if (compareAndSetState(0, 1)) {
                // System.out.println(current.getName() + " 获得锁");
                setExclusiveOwnerThread(current);

            }else {
                acquire(1);
            }

        }

        @Override
        protected boolean tryAcquire(int arg) {
            return nonfairTryAcquire(arg);
        }

        final boolean nonfairTryAcquire(int acquires) {
            final Thread current = Thread.currentThread();
            int c = getState();
            if (c == 0) {
                if (compareAndSetState(0, acquires)) {
                    //   System.out.println(current.getName() + " 获得锁");
                    setExclusiveOwnerThread(current);
                    return true;
                }
            } else if (current == getExclusiveOwnerThread()) {
                int nextc = c + acquires;
                if (nextc < 0) // overflow
                    throw new Error("Maximum lock count exceeded");
                // System.out.println(current.getName() + " 重入state:" + nextc);
                setState(nextc);
                return true;
            }
            return false;
        }
    }

测试

public class Main {


    static int value = 0;
    public static void main(String[] args) throws InterruptedException {
        SpinReentrantLock spinReentrantLock = new SpinReentrantLock(true);
        final CyclicBarrier cyclicBarrier = new CyclicBarrier(1000);
        final CountDownLatch countDownLatch = new CountDownLatch(1000);
        long start = System.currentTimeMillis();
        for (int i = 0; i < 1000 ; i++) {
            new Thread(new Runnable() {
                public void run() {
                    try {
                        cyclicBarrier.await();
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    spinReentrantLock.lock();
                    // System.out.println(Thread.currentThread().getName() + " 获得锁");
                    for (int j = 0; j < 1000; j++) {
                        value++;
                    }
                    spinReentrantLock.unlock();
                    countDownLatch.countDown();
                }
            },"thread:"+i).start();
        }
        countDownLatch.await();
        long end = System.currentTimeMillis();
        System.out.println("执行时间:" + (end - start));
        System.out.println("value: " + value);
    }
}
执行时间:70
value: 1000000

四、公平锁

队列中的任务线程优先执行,后到的线程只能只能排队等待。代码实现如下:

可以看到相对于非公平锁,公平锁的实现只是在获得锁前,调用hasQueuedPredecessors方法检查队列中是否有值。

  static final class FairSync extends Sync {

        @Override
        protected void lock() {
            acquire(1);
        }

        protected final boolean tryAcquire(int acquires) {
            final Thread current = Thread.currentThread();
            int c = getState();
            if (c == 0) {
                if (!hasQueuedPredecessors() &&
                        compareAndSetState(0, acquires)) {
                    setExclusiveOwnerThread(current);
                    return true;
                }
            } else if (current == getExclusiveOwnerThread()) {
                int nextc = c + acquires;
                if (nextc < 0)
                    throw new Error("Maximum lock count exceeded");
                setState(nextc);
                return true;
            }
            return false;
        }
    }

测试

public class Main {
    static int value = 0;
    public static void main(String[] args) throws InterruptedException {
        SpinReentrantLock spinReentrantLock = new SpinReentrantLock(true);
        final CyclicBarrier cyclicBarrier = new CyclicBarrier(10);
        final CountDownLatch countDownLatch = new CountDownLatch(10);
        long start = System.currentTimeMillis();
        for (int i = 0; i < 10 ; i++) {
            new Thread(new Runnable() {
                public void run() {
                    try {
                        cyclicBarrier.await();
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    spinReentrantLock.lock();
                    // System.out.println(Thread.currentThread().getName() + " 获得锁");
                    for (int j = 0; j < 1000; j++) {
                        value++;
                    }
                    spinReentrantLock.unlock();
                    countDownLatch.countDown();
                }
            },"thread:"+i).start();
        }
        countDownLatch.await();
        long end = System.currentTimeMillis();
        System.out.println("执行时间:" + (end - start));
        System.out.println("value: " + value);
    }
}

结果,可以看到任务都是按照入队的顺序执行。

thread:0获得锁
挂起线程: thread:6 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main]]
挂起线程: thread:5 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main]]
挂起线程: thread:1 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
挂起线程: thread:2 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
挂起线程: thread:3 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
挂起线程: thread:4 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
挂起线程: thread:7 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main], Thread[thread:7,5,main]]
挂起线程: thread:9 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main], Thread[thread:7,5,main], Thread[thread:8,5,main]]
挂起线程: thread:8 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main], Thread[thread:7,5,main], Thread[thread:8,5,main]]
thread:9获得锁
thread:1获得锁
thread:2获得锁
thread:3获得锁
thread:4获得锁
thread:5获得锁
thread:6获得锁
thread:7获得锁
thread:8获得锁
执行时间:3
value: 10000

五、总结:

最后附上SpinReentrantLock完整实现。

public class SpinReentrantLock implements Lock {


    private Sync sync;

    public SpinReentrantLock() {
        sync = new NonfairSync();

    }

    public SpinReentrantLock(boolean fair) {
        if (fair){
            sync = new FairSync();
        }else {
            sync = new NonfairSync();
        }
    }

    static final class FairSync extends Sync {

        @Override
        protected void lock() {
            acquire(1);
        }

//        public final void acquire(int arg) {
//            Thread current = Thread.currentThread();
//            if (!tryAcquire(arg) &&threadQueue.add(current)) {
//                if (getHead() == null) {
//                    setHead(threadQueue.peek());
//                }
//                for (; ; ) {
//                    if (current == getHead() && tryAcquire(arg)) {
//                        threadQueue.poll();
//                        //头部元素出队之后,更新头元素
//                        setHead(threadQueue.peek());
//                        return;
//                    }
//                     System.out.println("挂起线程: " +current.getName()+" size: "+ Arrays.toString(threadQueue.toArray()));
//                    LockSupport.park(this);
//                }
//            }
//        }

        protected final boolean tryAcquire(int acquires) {
            final Thread current = Thread.currentThread();
            int c = getState();
            if (c == 0) {
                if (!hasQueuedPredecessors() &&
                        compareAndSetState(0, acquires)) {
                    setExclusiveOwnerThread(current);
                    return true;
                }
            } else if (current == getExclusiveOwnerThread()) {
                int nextc = c + acquires;
                if (nextc < 0)
                    throw new Error("Maximum lock count exceeded");
                setState(nextc);
                return true;
            }
            return false;
        }
    }

    abstract static class Sync extends CustomAbstractQueuedSynchronizer {
        protected abstract void lock();

        protected void unlock() {
         if (tryRelease(1)){
             Thread poll = threadQueue.peek();
             if (poll != null) {
                 //System.out.println(Thread.currentThread().getName() + " 唤起线程: " + poll.getName() + " size: "+threadQueue.size());
                 LockSupport.unpark(poll);
             } else {
                 setHead(null);
             }
         }
        }

        @Override
        protected boolean tryRelease(int arg) {
            int c = getState() - 1;
            if (Thread.currentThread() != getExclusiveOwnerThread()){
                throw new IllegalMonitorStateException();
            }
            boolean free = false;
            if (c == 0) {
                free=true;
                setExclusiveOwnerThread(null);
            }
             System.out.println(Thread.currentThread().getName() + " 释放锁state: " + c);
            setState(c);
            return free;
        }
    }

    static final class NonfairSync extends Sync {

        @Override
        protected void lock() {
            Thread current = Thread.currentThread();
            if (compareAndSetState(0, 1)) {
                // System.out.println(current.getName() + " 获得锁");
                setExclusiveOwnerThread(current);

            }else {
                acquire(1);
            }
//            else if (!tryAcquire(1) && threadQueue.add(current)) {
//                //每次都是从头部元素开始唤起
//                if (getHead() == null) {
//                    setHead(threadQueue.peek());
//                }
//                for (; ; ) {
//                    if (current == getHead() && tryAcquire(1)) {
//                        threadQueue.poll();
//                        //头部元素出队之后,更新头元素
//                        setHead(threadQueue.peek());
//                        return;
//                    }
//                  //  System.out.println("挂起线程: " +current.getName()+" size: "+ Arrays.toString(threadQueue.toArray()));
//                    LockSupport.park(this);
//                }
//            }
        }

        @Override
        protected boolean tryAcquire(int arg) {
            return nonfairTryAcquire(arg);
        }

        final boolean nonfairTryAcquire(int acquires) {
            final Thread current = Thread.currentThread();
            int c = getState();
            if (c == 0) {
                if (compareAndSetState(0, acquires)) {
                    //   System.out.println(current.getName() + " 获得锁");
                    setExclusiveOwnerThread(current);
                    return true;
                }
            } else if (current == getExclusiveOwnerThread()) {
                int nextc = c + acquires;
                if (nextc < 0) // overflow
                    throw new Error("Maximum lock count exceeded");
                // System.out.println(current.getName() + " 重入state:" + nextc);
                setState(nextc);
                return true;
            }
            return false;
        }
    }

    static final class SimpleNonfairSync extends Sync {
        @Override
        protected void lock() {
            boolean flag;
            do {
                Thread current = Thread.currentThread();
                if (flag = compareAndSetState(0, 1)) {
                    System.out.println(current.getName() + " 获得锁");
                    setExclusiveOwnerThread(current);

                } else if (current == getExclusiveOwnerThread()) {
                    int c = getState();
                    int nextc = c + 1;
                    if (nextc < 0) {
                        // overflow
                        throw new Error("Maximum lock count exceeded");
                    }
                    System.out.println(current.getName() + " 重入state:" + nextc);
                    setState(nextc);
                    flag = true;

                }
            }
            while (!flag);

        }

        @Override
        protected void unlock() {
            int c = getState() - 1;
            if (Thread.currentThread() != getExclusiveOwnerThread())
                throw new IllegalMonitorStateException();
            if (c == 0) {
                setExclusiveOwnerThread(null);
            }
           // System.out.println(Thread.currentThread().getName() + " 释放锁state: " + c);
            setState(c);
        }

    }

    @Override
    public void lock() {
        sync.lock();
    }

    @Override
    public void unlock() {
        sync.unlock();
    }
}

上述实现的锁功能还比较简单,比如暂时还不支持响应中断,或者超时挂起等,但实现起来并不难,这里就不在赘述。

下一节我们探讨线程并发工具的基石AQS

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值