Redis实现分布式读写锁
前言
使用Jedis构建redis连接池,使用lua脚本命令保证redis的事务,以实现分布式的读写锁。项目中需要用到分布式的读写锁,开始使用Redisson的读写锁实现,压测的时候时不时会抛异常获取锁超时,初步判断是Redisson中redis连接池设置的太小。由于项目中还自己另外维护着一个redis的连接池JedisPool,故决定自己来实现分布式的可重入读写锁。
设计思路
目标:读锁可以被多个线程获取,同一个线程可以重入读锁,如果获取读锁的时候存在写锁,则需要等待写锁被释放才能获取;写锁的获取的时候,若存在读锁,则需要等待所有的读锁释放之后才能被获取,如果有线程正在获取写锁,其他获取读锁的线程将等待写锁被获取并释放之后才能获取读锁,写锁只能被一个线程持有,可以重入。
方案:
假设我们要操作“abc”这个读写锁
- 一个读锁在redis中的存在形式为一个hash结构,key为read_lock_abc,值为键值对组成的hash,键值对的键为thread_id(这个id由连接池id+获取锁的线程id,来区分分布式中的不同线程),值为该线程重入读锁的次数。
- 一个写锁分为两部分,一个是key为write_lock_abc,值为thread_id(同上);另外一个key为reentrant_write_lock_abc,值为被重入的次数。
实现
RedisReadWriteLock.java 该类用于维护读写锁的单例
public class RedisReadWriteLock {
//读锁
private static volatile RedisReadLock redisReadLock;
//写锁
private static volatile RedisWriteLock redisWriteLock;
//双重检查锁实现单例
public static RedisReadLock readLock(){
if(redisReadLock == null){
synchronized (RedisReadLock.class){
if (redisReadLock == null){
redisReadLock = new RedisReadLock();
}
}
}
return redisReadLock;
}
public static RedisWriteLock writeLock(){
if(redisWriteLock == null){
synchronized (RedisWriteLock.class){
if (redisWriteLock == null){
redisWriteLock = new RedisWriteLock();
}
}
}
return redisWriteLock;
}
// 构建锁的key
public static String getReadLockKey(String name){
return RedisLockConf.READ_LOCK_PREFIX + name;
}
public static String getWriteLockKey(String name){
return RedisLockConf.WRITE_LOCK_PREFIX + name;
}
public static String getReentrantWriteLockKey(String name){
return RedisLockConf.REENTRANT_WRITE_LOCK_PREFIX + name;
}
//由连接池id+获取锁的线程id,来区分分布式中的不同线程
public static String getThreadUid(){
return JedisConnectPoll.JEDIS_CONNECT_POLL_UUID.toString() + ":" + Thread.currentThread().getId();
}
}
public class RedisLockConf {
public static final String READ_LOCK_PREFIX = "read_lock_";
public static final String WRITE_LOCK_PREFIX = "write_lock_";
public static final String REENTRANT_WRITE_LOCK_PREFIX = "reentrant_write_lock_";
}
RedisReadLock.java 读锁的实现
@Slf4j
public class RedisReadLock {
public void lock(String name){
tryLock(name, Long.MAX_VALUE, 30, TimeUnit.SECONDS);
}
public void lock(String name, long leaseTime, TimeUnit unit){
tryLock(name, Long.MAX_VALUE, leaseTime, unit);
}
public boolean tryLock(String name, long waitTime, long leaseTime, TimeUnit unit){
Long waitUntilTime = unit.toMillis(waitTime) + System.currentTimeMillis();
if(waitUntilTime < 0){
waitUntilTime = Long.MAX_VALUE;
}
Long leastTimeLong = unit.toMillis(leaseTime);
StringBuilder sctipt = new StringBuilder();
// write-lock read-lock uuid leaseTime,后面会专门说这段脚本
sctipt.append("if not redis.call('GET',KEYS[1]) then ")
//redis.call('GET',KEYS[1])之类的命令,若没有值返回的布尔类型的false,不是nil
.append("local count = redis.call('HGET',KEYS[2],KEYS[3]);")
.append("if count then ")
.append("count = tonumber(count) + 1;")
.append("redis.call('HSET',KEYS[2],KEYS[3],count);")
.append("else ")
.append("redis.call('HSET',KEYS[2],KEYS[3],1);")
.append("end;")
.append("local t = redis.call('PTTL', KEYS[2]);")
.append("redis.call('PEXPIRE', KEYS[2], math.max(t, ARGV[1]));")
.append("return 1;")
.append("else ")
.append("return 0;")
.append("end;");
for(;;){
if(System.currentTimeMillis() > waitUntilTime){
return false;
}
Long res = (Long) JedisTemplate.operate().eval(sctipt.toString(), 3, RedisReadWriteLock.getWriteLockKey(name), RedisReadWriteLock.getReadLockKey(name), RedisReadWriteLock.getThreadUid(), leastTimeLong.toString());
if(res.equals(1L)){
//successGetReadLock
log.debug("success get read lock, readLock={}", RedisReadWriteLock.getReadLockKey(name));
break;
}else {
//need to wait write lock to be released
log.debug("wait write lock release, writeLock={}", RedisReadWriteLock.getWriteLockKey(name));
try {
TimeUnit.MILLISECONDS.sleep(50);
} catch (InterruptedException e) {
log.error("wait write lock release exception", e);
}
}
}
return true;
}
public void unlock(String name){
StringBuilder sctipt = new StringBuilder();
sctipt.append("local count = redis.call('HGET',KEYS[1],KEYS[2]);")
.append("if count then ")
.append("if (tonumber(count) > 1) then ")
.append("count = tonumber(count) - 1;")
.append("redis.call('HSET',KEYS[1],KEYS[2],count);")
.append("else ")
.append("redis.call('HDEL',KEYS[1],KEYS[2]);")
.append("end;")
.append("end;")
.append("return;");
JedisTemplate.operate().eval(sctipt.toString(), 2, RedisReadWriteLock.getReadLockKey(name), RedisReadWriteLock.getThreadUid());
log.debug("success unlock read lock, readLock={}", RedisReadWriteLock.getReadLockKey(name));
}
}
redis执行lua脚本的命令格式为:EVAL script numkeys key [key …] arg [arg …]
redis 127.0.0.1:6379> EVAL "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}" 2 key1 key2 first second
1) "key1"
2) "key2"
3) "first"
4) "second"
-- 读锁获取的lua脚本
-- 判断不存在写锁
if not redis.call('GET',KEYS[1]) then
local count = redis.call('HGET',KEYS[2],KEYS[3])
-- 如果该线程已经获取了该读锁,就重入,重入次数加1
if count then
count = tonumber(count) + 1
redis.call('HSET',KEYS[2],KEYS[3],count)
else
redis.call('HSET',KEYS[2],KEYS[3],1)
end
-- 检查之前读锁的过期时间,和当前加的读锁的过期时间做对比,更新过期时间
local t = redis.call('PTTL', KEYS[2])
redis.call('PEXPIRE', KEYS[2], math.max(t, ARGV[1]))
return 1
else
-- 若存在写锁,返回获取失败,外层的代码做轮询尝试加锁
return 0
end;
-- 读锁释放的lua脚本
-- 获取锁被当前线程重入的次数
local count = redis.call('HGET',KEYS[1],KEYS[2])
if count then
if (tonumber(count) > 1) then
count = tonumber(count) - 1
redis.call('HSET',KEYS[1],KEYS[2],count)
else
redis.call('HDEL',KEYS[1],KEYS[2])
end
end
return
RedisWriteLock.java 写锁的实现
@Slf4j
public class RedisWriteLock {
public void lock(String name){
tryLock(name, Long.MAX_VALUE, 30, TimeUnit.SECONDS);
}
public void lock(String name, long leaseTime, TimeUnit unit){
tryLock(name, Long.MAX_VALUE, leaseTime, unit);
}
public boolean tryLock(String name, long waitTime, long leaseTime, TimeUnit unit){
Long waitUntilTime = unit.toMillis(waitTime) + System.currentTimeMillis();
if(waitUntilTime < 0){
waitUntilTime = Long.MAX_VALUE;
}
Long leastTimeLong = unit.toMillis(leaseTime);
StringBuilder sctipt = new StringBuilder();
// write-lock reentrant-write-lock uuid leaseTime
sctipt.append("if redis.call('SET',KEYS[1],ARGV[1],'NX','PX',ARGV[2]) then ")
.append("redis.call('SET',KEYS[2],1,'PX',ARGV[2]);")
.append("return 1;")
.append("else ")
.append("if (redis.call('GET',KEYS[1])== ARGV[1]) then ")
.append("local count = redis.call('GET',KEYS[2]);")
.append("if not count then ")
.append("redis.call('SET',KEYS[2],1,'PX',ARGV[2]);")
.append("return 1;")
.append("else ")
.append("count = tonumber(count) + 1;")
.append("redis.call('SET',KEYS[2],count,'PX',ARGV[2]);")
.append("return count;")
.append("end;")
.append("else ")
.append("return 0;")
.append("end;")
.append("end;");
for(;;){
if(System.currentTimeMillis() > waitUntilTime){
return false;
}
Long res = (Long) JedisTemplate.operate().eval(sctipt.toString(), 2, RedisReadWriteLock.getWriteLockKey(name), RedisReadWriteLock.getReentrantWriteLockKey(name), RedisReadWriteLock.getThreadUid(), leastTimeLong.toString());
if(res.equals(1L)){
//successGetWriteLock
log.debug("success get write lock, writeLock = {}", RedisReadWriteLock.getWriteLockKey(name));
for(;;){
if(JedisTemplate.operate().exists(RedisReadWriteLock.getReadLockKey(name))){
log.debug("wait read lock release, readLock = {}", RedisReadWriteLock.getReadLockKey(name));
try {
TimeUnit.MILLISECONDS.sleep(100);
} catch (InterruptedException e) {
log.error("wait read lock release exception", e);
}
}else{
break;
}
}
break;
}else if(res.equals(0L)){
//need to wait write lock to be released
log.debug("wait write lock release, writeLock = {}", RedisReadWriteLock.getWriteLockKey(name));
try {
TimeUnit.MILLISECONDS.sleep(100);
} catch (InterruptedException e) {
log.error("wait write lock release exception", e);
}
}else{
log.debug("success in reentrant write lock, reentrantWriteLock = {}, count now = {}", RedisReadWriteLock.getReentrantWriteLockKey(name), res);
break;
}
}
return true;
}
public void unlock(String name){
StringBuilder sctipt = new StringBuilder();
//write-lock reentrant-write-lock uuid
sctipt.append("if (redis.call('GET',KEYS[1])== ARGV[1]) then ")
.append("local count = redis.call('GET',KEYS[2]);")
.append("if count then ")
.append("if (tonumber(count) > 1) then ")
.append("count = tonumber(count) - 1;")
.append("local live = redis.call('PTTL',KEYS[2]);")
.append("redis.call('SET',KEYS[2],count,'PX',live);")
//success unlock reentrant-write-lock
.append("return count;")
.append("else ")
.append("redis.call('DEL',KEYS[2]);")
.append("redis.call('DEL',KEYS[1]);")
//success unlock
.append("return 0;")
.append("end;")
.append("else ")
.append("redis.call('DEL',KEYS[1]);")
.append("return 0;")
.append("end;")
.append("else ")
//fail unlock, thread not get the lock
.append("return -1;")
.append("end;");
Long res = (Long) JedisTemplate.operate().eval(sctipt.toString(), 2, RedisReadWriteLock.getWriteLockKey(name), RedisReadWriteLock.getReentrantWriteLockKey(name), RedisReadWriteLock.getThreadUid());
if(res.equals(0L)){
log.debug("success unlock write lock, writeLock = {}", RedisReadWriteLock.getWriteLockKey(name));
}else if(res.equals(-1L)){
log.debug("fail unlock, thread not get the lock, writeLock = {}, thread = {}", RedisReadWriteLock.getReentrantWriteLockKey(name), RedisReadWriteLock.getThreadUid());
}else {
log.debug("success unlock reentrant write lock, reentrantWriteLock = {}, count left = {}", RedisReadWriteLock.getReentrantWriteLockKey(name), res);
}
}
}
-- 写锁获取的lua脚本
-- 没有写锁,则成功设置
if redis.call('SET',KEYS[1],ARGV[1],'NX','PX',ARGV[2]) then
-- 重入数记为1
redis.call('SET',KEYS[2],1,'PX',ARGV[2])
return 1
else
-- 写锁已经被获取,判断已经获取写锁的线程是不是当前线程
if (redis.call('GET',KEYS[1])== ARGV[1]) then
-- 若是当前线程,则重入数+1
local count = redis.call('GET',KEYS[2])
if not count then
redis.call('SET',KEYS[2],1,'PX',ARGV[2])
return 1;
else
count = tonumber(count) + 1;
redis.call('SET',KEYS[2],count,'PX',ARGV[2])
return count
end
else
return 0
end
end
-- 写锁释放的lua脚本
-- 判断是否是当前线程获取的写锁
if (redis.call('GET',KEYS[1])== ARGV[1]) then
-- 是当前线程获取的写锁,判断锁是否被重入,重入数减到1后释放锁
local count = redis.call('GET',KEYS[2])
if count then
if (tonumber(count) > 1) then
count = tonumber(count) - 1
local live = redis.call('PTTL',KEYS[2])
redis.call('SET',KEYS[2],count,'PX',live)
-- 返回该线程对该锁剩余的重入次数
return count
else
redis.call('DEL',KEYS[2])
redis.call('DEL',KEYS[1])
return 0
end
else
redis.call('DEL',KEYS[1])
return 0
end
else
-- 其他线程获取的该写锁,该线程解锁失败
return -1
end
测试如下:
public static void main(String[] args) {
final int[] num = {0};
for (int i = 0; i < 10 ; i++) {
Thread thread = new Thread(() -> {
RedisReadWriteLock.writeLock().tryLock("ccc", 30,300, TimeUnit.SECONDS);
num[0]++;
System.out.println("【写】:" + num[0]);
RedisReadWriteLock.writeLock().unlock("ccc");
});
thread.start();
}
for (int i = 0; i < 100 ; i++) {
Thread thread = new Thread(() -> {
RedisReadWriteLock.readLock().tryLock("ccc", 30,300, TimeUnit.SECONDS);
System.out.println("读:" + num[0]);
RedisReadWriteLock.readLock().unlock("ccc");
});
thread.start();
if(i % 3 == 0){
try {
TimeUnit.MILLISECONDS.sleep(50);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}
运行结果
【写】:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
【写】:2
【写】:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
【写】:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
【写】:5
读:5
读:5
读:5
读:5
读:5
读:5
【写】:6
读:6
读:6
读:6
读:6
读:6
读:6
【写】:7
读:7
读:7
读:7
读:7
读:7
读:7
读:7
读:7
读:7
【写】:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
【写】:9
读:9
读:9
读:9
读:9
读:9
读:9
【写】:10
读:10
读:10
读:10
JedisTemplate其实是对从JedisPool获取的Jedis连接做的一个cglib的代理,用于使用完之后自动调用Jedis实例的close方法将连接归还到JedisPool。代理肯定会存在一点性能损耗,但是也简化了编码,以及避免了忘记归还连接,导致池的连接被耗空。
需要引入cglib的依赖
<dependency>
<groupId>cglib</groupId>
<artifactId>cglib</artifactId>
<version>3.2.10</version>
</dependency>
public class JedisTemplate {
public static Jedis operate(){
Enhancer enhancer = new Enhancer();
enhancer.setSuperclass(Jedis.class);
enhancer.setCallback(new JedisCglibProxyIntercepter());
return (Jedis) enhancer.create();
}
}
public class JedisCglibProxyIntercepter implements MethodInterceptor {
@Override
public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
//try后会自动调用jedis的close方法释放资源
try(Jedis jedis = JedisConnectPoll.getJedis()){
return method.invoke(jedis, objects);
}
}
}
JedisPool连接池的代码也贴一下:
需要添加下面的Jedis依赖
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>3.0.1</version>
</dependency>
@Slf4j的注解是需要添加lombok依赖,然后还要对应的日志依赖,这个按个人需要选择吧
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.6</version>
<scope>provided</scope>
</dependency>
<!--统一使用slf4j的logback打印日志-->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.16</version>
</dependency>
<!--桥接java common log Apache Commons Logging-->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>jcl-over-slf4j</artifactId>
<version>1.7.16</version>
</dependency>
<!--桥接log4j-->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>log4j-over-slf4j</artifactId>
<version>1.7.16</version>
</dependency>
<!--logback核心库-->
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
<version>1.1.6</version>
</dependency>
<!--logback重写log4j-->
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>1.1.6</version>
<exclusions>
<exclusion>
<artifactId>slf4j-api</artifactId>
<groupId>org.slf4j</groupId>
</exclusion>
</exclusions>
</dependency>
@Slf4j
public class JedisConnectPoll {
public static final UUID JEDIS_CONNECT_POLL_UUID = UUID.randomUUID();
//连接redis实例的ip
private static final String REDIS_ADDRESS = "127.0.0.1";
//连接redis实例的端口
private static final int PORT = "6379";
//密码
private static final String PASSWORD = "";
//多线程环境中,连接实例的最大数,如果设为-1则无上线,建议设置,否则有可能导致资源耗尽
private static final int MAX_ACTIVE = 160;
//在多线程环境中,连接池中最大空闲连接数,单线程环境没有实际意义
private static final int MAX_OLDE = 128;
//在多线程环境中,连接池中最小空闲连接数
private static final int MIN_OLDE = 8;
//多长时间将空闲线程进行回收,单位毫秒
private static final int METM = 2000;
//对象空闲多久后逐出, 当空闲时间>该值 且 空闲连接>最大空闲数 时直接逐出,不再根据MinEvictableIdleTimeMillis判断 (默认逐出策略)
private static final int SMETM = 2000;
//逐出扫描的时间间隔(毫秒) 如果为负数,则不运行逐出线程, 默认-1,只有运行了此线程,MIN_OLDE METM/SMETM才会起作用
private static final int TBERM = 1000;
//当连接池中连接不够用时,等待可用连接的最大时间,单位毫秒,默认值为-1,表示永不超时。如果超过等待时间,则直接抛出JedisConnectionException;
private static final int MAX_WAIT = 1000;
//超时时间,单位毫秒
private static final int TIME_OUT = 10000;
//在借用一个jedis连接实例时,是否提前进行有效性确认操作;如果为true,则得到的jedis实例均是可用的;
private static final boolean TEST_ON_BORROW = false;
//连接池实例
private static JedisPool jedisPool = null;
static {
initPoll();
}
private static void initPoll() {
try {
JedisPoolConfig config = new JedisPoolConfig();
config.setMaxTotal(MAX_ACTIVE);
config.setMaxIdle(MAX_OLDE);
config.setMaxWaitMillis(MAX_WAIT);
config.setTestOnBorrow(TEST_ON_BORROW);
config.setMinIdle(MIN_OLDE);
config.setMinEvictableIdleTimeMillis(METM);
config.setSoftMinEvictableIdleTimeMillis(SMETM);
config.setTimeBetweenEvictionRunsMillis(TBERM);
if(!"".equals(PASSWORD)){
jedisPool = new JedisPool(config, REDIS_ADDRESS, PORT, TIME_OUT, PASSWORD);
}else {
jedisPool = new JedisPool(config, REDIS_ADDRESS, PORT, TIME_OUT);
}
} catch (Exception e) {
log.error("initial JedisPoll fail: {}",e);
}
}
public static Jedis getJedis(){
return jedisPool.getResource();
}
}
如果本文对你有帮助,欢迎大家点赞关注收藏。作为一名具有极客精神的程序员,我会持续地给大家分享一些个人的开发经验和技术文章,若能帮助到你,那将是我莫大的荣幸。大家也可以关注我的同名微信公众号“三易程序员”,文章会在公众号和头条号同步更新,也欢迎大家私信探讨技术问题。