获取锁原理:
1、客户端通过 create() 接口创建一个已存在的持久节点的临时顺序子节点,然后返回一个带序号节点名称;
2、客户端调用 getChildren() 接口获取这个持久节点的子节点列表;
3、通过子节点列表判断自己在已存在子节点中的顺序(注意:getChildren() 返回的列表是无序的,需要先按节点名末尾的序号排序),如果是第一个子节点(序号最小),那么这个客户端获取到了锁;否则判断自己创建的节点的类型:如果是读节点,那么判断是否有比自己小的写节点,若没有,则获取到了读锁,若有则通过 exists() 接口向比自己小的最后一个写节点注册一个Watcher监听;如果是写节点,那么通过 exists() 接口向比自己小的最后一个节点注册一个Watcher监听;
4、等待Watcher通知,如果客户端创建的是读节点,则获取到了读锁;否则继续进行第二步。
释放锁原理:
1、因为客户端创建的节点是一个临时节点,所以当获取到锁的客户端发生了宕机,其会话超时后 zookeeper 上的这个临时节点会被自动移除;
2、获取到锁的客户端,正常执行完业务逻辑后,客户端会主动将自己创建的临时节点删除。
简单的代码实现:代码中只实现了正常的获取锁和释放锁的逻辑。
/**
 * Defines a minimal lock contract.
 */
public interface Lock {

    /**
     * Tries to acquire the lock for the calling thread.
     *
     * @return {@code true} if the lock is now held by the caller, {@code false} otherwise
     */
    boolean lock();

    /**
     * Releases the lock previously obtained through {@link #lock()}.
     */
    void unLock();
}
/**
 * Distributed shared (read/write) lock backed by ZooKeeper.
 * <p>
 * Reentrant: a thread holding the read lock may acquire the read lock again; a thread holding
 * the write lock may acquire the write lock again and may also acquire the read lock.
 * Each held lock is represented by one ephemeral sequential child of {@link #fatherPath};
 * per-thread hold state lives in {@link #threadData}.
 */
public class SharedLock {
    private static final String SHARE_LOCK_ROOT = "/share-lock-root";
    /** Width of the counter ZooKeeper appends to sequential node names (formatted as %010d). */
    private static final int SEQUENCE_LENGTH = 10;
    private final String fatherPath;
    private final ReadLock readLock;
    private final WriteLock writeLock;
    /** Hold state per thread; a thread has an entry only while it holds a lock. */
    private final ConcurrentMap<Thread, LockData> threadData = Maps.newConcurrentMap();

    /** What a single thread currently holds. */
    @SuppressWarnings("unused")
    private static class LockData {
        final Thread owningThread;
        /** Sign of the lock that created this hold ("read" or "write"). */
        final String lockSign;
        /** Name (without parent path) of the ephemeral sequential node backing this hold. */
        final String node;
        /** Outstanding holds; starts at 1 because a LockData is created on first acquisition. */
        final AtomicInteger lockCount = new AtomicInteger(1);

        private LockData(Thread owningThread, String lockSign, String node) {
            this.owningThread = owningThread;
            this.lockSign = lockSign;
            this.node = node;
        }
    }

    /**
     * Creates the lock named {@code lockName} and tries to create its persistent parent node.
     *
     * @param lockName suffix appended to {@code /share-lock-root-} to form the parent path
     */
    public SharedLock(String lockName) {
        this.fatherPath = PathUtils.validatePath(SHARE_LOCK_ROOT + "-" + lockName);
        initFatherPath();
        readLock = new ReadLock(this);
        writeLock = new WriteLock(this);
    }

    public ReadLock readLock() {
        return readLock;
    }

    public WriteLock writeLock() {
        return writeLock;
    }

    /** Creates the persistent parent node; an already-existing node is not an error. */
    private void initFatherPath() {
        ZooKeeper zooKeeper = LockUtils.newZookeeper();
        try {
            zooKeeper.create(this.fatherPath, this.fatherPath.getBytes(), Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
        } catch (KeeperException.NodeExistsException ignored) {
            // Created earlier (persistent node) - nothing to do.
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Checks whether the calling thread already holds a lock and, if reentry is allowed,
     * records one more hold.
     *
     * @param sign sign of the requested lock ("read" or "write")
     * @return 0 - the current thread holds no lock;
     *         1 - reentry allowed (hold count incremented);
     *         2 - the thread holds the read lock and asks for the write lock (not supported)
     * @author tangjingjing
     * @date 2019-06-20
     */
    protected final int isRetryLock(String sign) {
        LockData lockData = threadData.get(Thread.currentThread());
        if (null == lockData) {
            return 0;
        }
        // Same lock again, or a read request (a write holder may also read).
        // ReadLock's sign is lower-case "read", so compare case-insensitively.
        if (sign.equals(lockData.lockSign) || "READ".equalsIgnoreCase(sign)) {
            lockData.lockCount.incrementAndGet();
            return 1;
        }
        return 2;
    }

    /**
     * Records that the calling thread has just acquired a lock backed by {@code node}.
     *
     * @return always {@code true}
     */
    protected final boolean saveLockData(String sign, String node) {
        Thread currentThread = Thread.currentThread();
        threadData.put(currentThread, new LockData(currentThread, sign, node));
        return true;
    }

    /**
     * Releases one hold of the calling thread's lock; deletes the backing node with the last hold.
     *
     * @return {@code true} if the lock was fully released, {@code false} if the thread held
     *         nothing or still holds it
     */
    protected final boolean decrementRetry() {
        Thread currentThread = Thread.currentThread();
        LockData lockData = threadData.get(currentThread);
        if (null == lockData) {
            return false;
        }
        if (lockData.lockCount.decrementAndGet() > 0) {
            return false;
        }
        threadData.remove(currentThread);
        deletedChildrenNode("/" + lockData.node);
        return true;
    }

    /**
     * Creates an ephemeral sequential node at {@code fatherPath + path}.
     *
     * @param path child path prefix starting with "/" (e.g. "/read")
     * @return the created node name without the parent path, or {@code null} on failure
     */
    protected final String createChildrenNode(String path) {
        ZooKeeper zooKeeper = LockUtils.newZookeeper();
        try {
            String realPath = zooKeeper.create(this.fatherPath + path, "".getBytes(), Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
            return realPath.substring(realPath.lastIndexOf('/') + 1);
        } catch (KeeperException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Best-effort delete of {@code fatherPath + path}, attempted up to 5 times.
     * A node that no longer exists counts as deleted.
     */
    protected final void deletedChildrenNode(String path) {
        ZooKeeper zooKeeper = LockUtils.newZookeeper();
        boolean interrupted = false;
        try {
            for (int attempt = 0; attempt < 5; attempt++) {
                try {
                    zooKeeper.delete(this.fatherPath + path, -1);
                    return;
                } catch (KeeperException.NoNodeException e) {
                    // Already gone: retrying cannot succeed and is not needed.
                    return;
                } catch (KeeperException e) {
                    e.printStackTrace();
                } catch (InterruptedException e) {
                    // Keep trying so the lock node is not leaked; restore the flag afterwards.
                    interrupted = true;
                    e.printStackTrace();
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Returns the lock nodes under the parent path, ordered by their sequence number.
     *
     * @return a mutable, sequence-ordered list; empty if the children could not be read
     */
    protected final List<String> getChildrens() {
        List<String> children = new ArrayList<>();
        ZooKeeper zooKeeper = LockUtils.newZookeeper();
        try {
            children = new ArrayList<>(zooKeeper.getChildren(this.fatherPath, false));
        } catch (KeeperException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
        // ZooKeeper returns children in no particular order. Sort by the sequence suffix, not the
        // full name, otherwise every "read..." node would sort before every "write..." node.
        children.sort((a, b) -> sequenceOf(a).compareTo(sequenceOf(b)));
        return children;
    }

    /** Extracts the fixed-width sequence suffix of a sequential node name. */
    private static String sequenceOf(String nodeName) {
        return nodeName.length() <= SEQUENCE_LENGTH
                ? nodeName
                : nodeName.substring(nodeName.length() - SEQUENCE_LENGTH);
    }

    /**
     * Counts {@code watcher} down once child {@code node} is deleted, or immediately if it
     * does not exist.
     * <p>
     * NOTE: if the exists() call itself fails with a KeeperException, the latch is never
     * released; an interrupted caller gets its interrupt flag restored instead.
     */
    protected final void watcherChildrenNode(CountDownLatch watcher, String node) {
        ZooKeeper zooKeeper = LockUtils.newZookeeper();
        try {
            Stat stat = zooKeeper.exists(this.fatherPath + "/" + node, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    if (EventType.NodeDeleted == event.getType()) {
                        watcher.countDown();
                    }
                }
            });
            if (null == stat) {
                watcher.countDown();
            }
        } catch (KeeperException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }
}
/**
 * Acquisition and release logic shared by the read and write sides of {@link SharedLock}.
 * Subclasses supply the node sign and the waiting strategy.
 */
public abstract class ReadWriteLock implements Lock {
    protected SharedLock sharedLock;

    ReadWriteLock(SharedLock sharedLock) {
        this.sharedLock = sharedLock;
    }

    /**
     * Acquires the lock for the calling thread, reentering if allowed.
     * <p>
     * If acquisition fails for any reason, including an exception from {@link #selefLock},
     * the node created for this attempt is deleted. Otherwise it would stay in the queue
     * and block every later waiter.
     *
     * @return {@code true} if the lock is now held, {@code false} otherwise
     */
    @Override
    public final boolean lock() {
        int result = sharedLock.isRetryLock(getLockSign());
        if (1 == result) {
            return true;
        } else if (2 == result) {
            return false;
        }
        String node = sharedLock.createChildrenNode(PathUtils.validatePath("/" + getLockSign()));
        if (StringUtils.isEmpty(node)) {
            return false;
        }
        boolean acquired = false;
        try {
            List<String> childrenNodes = sharedLock.getChildrens();
            if (null != childrenNodes && !childrenNodes.isEmpty()
                    // Own node first in the queue, or the subclass's wait succeeded.
                    && (node.equals(childrenNodes.get(0)) || selefLock(childrenNodes, node))) {
                acquired = sharedLock.saveLockData(getLockSign(), node);
            }
        } finally {
            if (!acquired) {
                sharedLock.deletedChildrenNode("/" + node);
            }
        }
        return acquired;
    }

    /** @return the node-name prefix identifying this lock type */
    protected abstract String getLockSign();

    /**
     * Waits until {@code node} may hold the lock.
     *
     * @param childrenNodes current children of the lock path
     * @param node          this client's node name
     * @return {@code true} if the lock was obtained
     */
    protected abstract boolean selefLock(List<String> childrenNodes, String node);

    /** Releases one hold of the calling thread's lock. */
    @Override
    public final void unLock() {
        sharedLock.decrementRetry();
    }
}
/**
 * Write side of {@link SharedLock}: the holder must be the first node in the queue.
 */
public class WriteLock extends ReadWriteLock {
    private static final String SIGN = "write";

    WriteLock(SharedLock sharedLock) {
        super(sharedLock);
    }

    @Override
    protected String getLockSign() {
        return SIGN;
    }

    /**
     * Repeatedly watches the node immediately preceding {@code node} until {@code node} is first.
     *
     * @param childrenNodes current children of the lock path
     * @param node          this client's write node name
     * @return {@code true} once {@code node} is first; {@code false} if {@code node} disappeared,
     *         the children could not be read, or the thread was interrupted
     */
    @Override
    protected boolean selefLock(List<String> childrenNodes, String node) {
        List<String> children = childrenNodes;
        while (null != children && !children.isEmpty()) {
            int selfIndex = children.indexOf(node);
            if (selfIndex < 0) {
                // Own node is gone (e.g. session expired); nothing to wait for.
                return false;
            }
            if (0 == selfIndex) {
                return true;
            }
            CountDownLatch predecessorGone = new CountDownLatch(1);
            // Watch only the immediate predecessor, then re-check the whole queue.
            sharedLock.watcherChildrenNode(predecessorGone, children.get(selfIndex - 1));
            try {
                predecessorGone.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
            children = sharedLock.getChildrens();
        }
        return false;
    }
}
/**
 * Read side of {@link SharedLock}. Read nodes are named with the prefix {@code "read"}; any
 * child without that prefix is treated as a write node.
 */
public class ReadLock extends ReadWriteLock {
    private static final String SIGN = "read";

    ReadLock(SharedLock sharedLock) {
        super(sharedLock);
    }

    @Override
    protected String getLockSign() {
        return SIGN;
    }

    /**
     * Waits until no write node precedes {@code node}.
     *
     * @param childrenNodes current children of the lock path
     * @param node          this client's read node name
     * @return {@code true} once no smaller write node remains; {@code false} if {@code node}
     *         disappeared, the children could not be read, or the thread was interrupted
     */
    @Override
    protected boolean selefLock(List<String> childrenNodes, String node) {
        List<String> children = childrenNodes;
        while (null != children && !children.isEmpty()) {
            int selfIndex = children.indexOf(node);
            if (selfIndex < 0) {
                // Own node is gone (e.g. session expired); the lock cannot be granted.
                return false;
            }
            // Find the last write node smaller than ours. Scan from index 0: a writer at the
            // head of the queue (the current write holder) blocks readers too.
            String lastWriter = null;
            for (int i = 0; i < selfIndex; i++) {
                if (!children.get(i).startsWith(SIGN)) {
                    lastWriter = children.get(i);
                }
            }
            if (null == lastWriter) {
                // No smaller write node.
                return true;
            }
            CountDownLatch writerGone = new CountDownLatch(1);
            sharedLock.watcherChildrenNode(writerGone, lastWriter);
            try {
                writerGone.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
            // Re-read the queue instead of assuming success: earlier write nodes may remain.
            children = sharedLock.getChildrens();
        }
        return false;
    }
}
/**
 * Holds the single ZooKeeper client shared by all locks in this JVM.
 * <p>
 * The client connects to {@code 127.0.0.1:2181} with a 5000 ms session timeout and is created
 * lazily on first use. It is published only after the session reports SyncConnected, so callers
 * never receive a client that is still connecting.
 */
public class LockUtils {
    // volatile: required for the unsynchronized fast-path read in newZookeeper().
    private static volatile ZooKeeper zooKeeper;

    /**
     * Returns the shared client, creating and connecting it on first call.
     *
     * @return the connected client, or {@code null} if initialization failed
     */
    public static ZooKeeper newZookeeper() {
        ZooKeeper client = zooKeeper;
        if (null == client) {
            init();
            client = zooKeeper;
        }
        return client;
    }

    /** Creates the client and blocks until it is connected; no-op if another thread already did. */
    private static void init() {
        synchronized (LockUtils.class) {
            if (null != zooKeeper) {
                return;
            }
            ZooKeeper client = null;
            try {
                CountDownLatch countDownLatch = new CountDownLatch(1);
                client = new ZooKeeper("127.0.0.1:2181", 5000, new Watcher() {
                    @Override
                    public void process(WatchedEvent event) {
                        if (KeeperState.SyncConnected == event.getState()) {
                            countDownLatch.countDown();
                        }
                    }
                });
                countDownLatch.await();
                // Publish only a connected client; ownership moves to the static field.
                zooKeeper = client;
                client = null;
                System.out.println("=======init success========");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                e.printStackTrace();
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                // Non-null only if a client was created but never published.
                closeQuietly(client);
            }
        }
    }

    /** Closes an unpublished client so its background threads do not leak. */
    private static void closeQuietly(ZooKeeper client) {
        if (null == client) {
            return;
        }
        try {
            client.close();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
测试代码及运行结果:
class TestObject {
private Object obj = null;
SharedLock sharedLock = new SharedLock("test");
Lock readLock = sharedLock.readLock();
Lock writeLock = sharedLock.writeLock();
public void write(Object obj) {
if (writeLock.lock()) {
try {
this.obj = obj;
System.out.println(Thread.currentThread().getName() + "\t" + obj);
} catch (Exception e) {
e.printStackTrace();
} finally {
writeLock.unLock();
}
}
}
public void read() {
if (readLock.lock()) {
try {
System.out.println(Thread.currentThread().getName() + "\t" + obj);
} catch (Exception e) {
e.printStackTrace();
} finally {
readLock.unLock();
}
}
}
}
}
/**
 * Manual demo for the distributed read/write lock: one writer, fifty readers and a second
 * writer all wait on a start gate and then contend for the lock at the same moment.
 */
public class TestSharedLock {

    public static void main(String[] args) {
        // Set the ZooKeeper client's log level to info (logback).
        LoggerContext context = (LoggerContext) LoggerFactory.getILoggerFactory();
        context.getLogger("org.apache.zookeeper").setLevel(Level.valueOf("info"));
        TestSharedLock demo = new TestSharedLock();
        demo.test4();
    }

    /**
     * Starts "thread-write-1", then 50 default-named reader threads, then "thread-write-2",
     * and finally opens the start gate.
     */
    public void test4() {
        TestObject target = new TestObject();
        CountDownLatch startGate = new CountDownLatch(1);
        startAfter(startGate, "thread-write-1", () -> target.write("之后只能读到 : helloworld!"));
        for (int reader = 0; reader < 50; reader++) {
            startAfter(startGate, null, target::read);
        }
        startAfter(startGate, "thread-write-2", () -> target.write("之后只能读到 : nihao!"));
        startGate.countDown();
    }

    /**
     * Starts a thread that waits on {@code gate} and then runs {@code action}.
     *
     * @param name thread name, or {@code null} to keep the JVM's default "Thread-N" name
     */
    private static void startAfter(CountDownLatch gate, String name, Runnable action) {
        Runnable task = () -> {
            try {
                gate.await();
            } catch (Exception e) {
                e.printStackTrace();
            }
            action.run();
        };
        Thread worker = (null == name) ? new Thread(task) : new Thread(task, name);
        worker.start();
    }
}
Thread-11 null
Thread-5 null
Thread-6 null
Thread-15 null
Thread-17 null
Thread-19 null
Thread-21 null
Thread-23 null
Thread-25 null
Thread-27 null
Thread-31 null
Thread-14 null
Thread-16 null
Thread-18 null
Thread-20 null
Thread-22 null
Thread-24 null
Thread-26 null
Thread-28 null
Thread-32 null
Thread-30 null
Thread-37 null
Thread-39 null
Thread-35 null
Thread-41 null
Thread-33 null
Thread-29 null
Thread-36 null
Thread-38 null
Thread-40 null
Thread-34 null
Thread-43 null
Thread-45 null
Thread-47 null
Thread-49 null
Thread-42 null
Thread-46 null
Thread-44 null
Thread-48 null
thread-write-2 之后只能读到 : nihao!
thread-write-1 之后只能读到 : helloworld!
Thread-3 之后只能读到 : helloworld!
Thread-1 之后只能读到 : helloworld!
Thread-13 之后只能读到 : helloworld!
Thread-9 之后只能读到 : helloworld!
Thread-7 之后只能读到 : helloworld!
Thread-2 之后只能读到 : helloworld!
Thread-8 之后只能读到 : helloworld!
Thread-0 之后只能读到 : helloworld!
Thread-4 之后只能读到 : helloworld!
Thread-12 之后只能读到 : helloworld!
Thread-10 之后只能读到 : helloworld!