手写分布式锁(基于Zookeeper)
关于分布式锁有很多种实现方式,但是说到底由于是分布式的所以需要加一层共享区域来实现锁机制(一层不够就加两层☺)
- 那Zookeeper是基于什么实现分布式锁的呢?
- 借助zk(有序临时节点+watcher机制)实现
- 那并发请求来zk创建节点是如何保证创建节点的唯一性?
- zk底层是靠ConcurrentHashMap的put()实现.
基于zk的分布式Demo
借助了ThreadLocal 和 CountDownLatch实现
package zookeeper;
import org.apache.zookeeper.*;
import org.apache.zookeeper.data.Stat;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.stream.Collectors;
import static java.lang.System.out;
/**
* <Description> <br>
*
* @author shi.yuwen<br>
* @version 1.0<br>
* @taskId: <br>
* @see zookeeper <br>
*/
public class DistributedLock implements Lock{
private ZooKeeper zooKeeper = null;
private String rootPath = "/lock";
private String lockName;
private ThreadLocal<String> nodeId = new ThreadLocal<>();
private ThreadLocal<String> preNodeId = new ThreadLocal<>();
private CountDownLatch signal = new CountDownLatch(1);
private static final Integer SESSION_TIMEOUT = 3000;
private static final byte[] DATA_VALUE = new byte[0];
public DistributedLock(String url, String lockName) {
this.lockName = lockName;
try {
zooKeeper = new ZooKeeper(url, SESSION_TIMEOUT, watchedEvent -> {
if(watchedEvent.getState() == Watcher.Event.KeeperState.SyncConnected) {
signal.countDown();
}
});
signal.await();
Stat stat = zooKeeper.exists(rootPath, false);
// 创建锁根节点
if(stat == null){
zooKeeper.create(rootPath, DATA_VALUE, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
}
} catch (Exception e) {
e.printStackTrace();
}
}
@Override
public void lock() {
try {
if(tryLock()){
return;
}
waitForLock();
} catch (Exception e) {
e.printStackTrace();
}
}
private void waitForLock(){
try {
//判断前一个节点是否存在
CountDownLatch latch = new CountDownLatch(1);
Stat stat = zooKeeper.exists(preNodeId.get(), new LockWatcher(latch));
//前一个节点是否有值,有值则监听等待是否删除
if (null != stat) {
long startTime = System.currentTimeMillis();
out.println("等待线程: " + Thread.currentThread().getName() + " release lock");
latch.await();
long endTime = System.currentTimeMillis();
out.println("当前线程: " + Thread.currentThread().getName() + " get lock" + "nodeId: " + nodeId.get() + "等待时间: " + (endTime-startTime));
}
} catch (Exception e) {
e.printStackTrace();
}
}
@Override
public void lockInterruptibly() throws InterruptedException {
}
@Override
public boolean tryLock() {
try {
// 创建临时节点
String currentNode = zooKeeper.create(rootPath + "/" + lockName, DATA_VALUE, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
out.println("当前线程: " + Thread.currentThread().getName() + "创建节点: " + currentNode);
// 获取"/lock"下面所有的有序子节点,重新遍历排序
List<String> nodes = zooKeeper.getChildren(rootPath, false).stream().map(e -> {return rootPath + "/" + e;})
.sorted(String::compareTo)
.collect(Collectors.toList());
int index = nodes.indexOf(currentNode);
this.nodeId.set(currentNode);
if (index == 0) {
out.println("当前线程: " + Thread.currentThread().getName() + "get lock" + "nodeId: " + currentNode);
return true;
}
this.preNodeId.set(nodes.get(index - 1));
} catch (Exception e) {
e.printStackTrace();
}
return false;
}
@Override
public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
return false;
}
@Override
public void unlock() {
out.println(Thread.currentThread().getName() + "解锁......");
try {
if (null != nodeId.get()) {
zooKeeper.delete(nodeId.get(), -1);
}
out.println(Thread.currentThread().getName() + "解锁成功......");
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
nodeId.remove();
preNodeId.remove();
}
}
@Override
public Condition newCondition() {
return null;
}
/**
* Description <br>
* 添加监听类监听删除事件
**/
class LockWatcher implements Watcher {
private CountDownLatch latch;
public LockWatcher(CountDownLatch latch) {
this.latch = latch;
}
@Override
public void process(WatchedEvent watchedEvent) {
if(watchedEvent.getType() == Event.EventType.NodeDeleted) {
latch.countDown();
}
}
}
}
测试代码
public class TestThread {
private static volatile Integer store = 3;
private static Lock myLock = new DistributedLock("192.168.245.128:2181", "myLock");
public static void main(String[] args) throws Exception {
for (int i = 1; i <= 10; i++) {
new Thread(()->{
Resource resource = new Resource();
try {
resource.sale();
} catch (Exception e) {
e.printStackTrace();
}
}, String.valueOf(i)).start();
}
}
static class Resource {
public void sale(){
myLock.lock();
try {
if (store < 0) {
throw new RuntimeException("库存异常!!!!");
}
if (store > 0) {
store--;
System.out.println("线程: " + Thread.currentThread().getName() + " 库存: " + store);
}else {
System.out.println("库存售罄!!");
}
}catch (Exception e) {
e.printStackTrace();
}finally {
myLock.unlock();
}
}
}
}
简单说明下:
- 利用zk节点唯一性每一个线程进入创建临时有序节点时,都会遍历根/lock下所有的临时有序节点并且排序(这个排序是关键).
- 因为是临时有序节点,那按排序第一个便是得到锁的那个线程,其余线程则会等待(这边用的是CountDownLatch类)
- 什么时候等待结束? 这边使用了ThreadLocal,每个线程都会记录当前节点值和他的上一个节点值.使用ZK的监听Watcher,当节点被删除时则会执行CountDownLatch的countDown()方法,那么其后一个节点将会获得锁…一次类推,直到所有节点执行完.(没错放宽了讲这相当于是公平锁了).
代码写的比较随意,如有错误, 请大佬多多指点!!!