基于Redis实现分布式锁
1 自定义注解：按方法参数加分布式锁（@CacheThreadArg）
package com.nikooh.manage.annotation;
import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;
/**
 * Marks a method whose concurrent invocations are serialized by a Redis lock whose key is
 * derived from the method's arguments.
 *
 * <p>The aspect in this project builds the key as {@code prefix() + "unitKey:" + values}, where
 * {@code values} is the concatenation of the {@link #argKey()} SpEL expressions evaluated against
 * the method parameters.
 *
 * <p>NOTE(review): the aspect imports this type as
 * {@code com.nikooh.manage.config.annotation.CacheThreadArg}, while it is declared in
 * {@code com.nikooh.manage.annotation}; confirm the two packages agree.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface CacheThreadArg {

    /** Fixed prefix placed in front of the argument-derived part of the lock key. */
    String prefix() default "";

    /** SpEL expressions such as {@code "#request.id"}; their values are concatenated into the key. */
    String[] argKey() default {};

    /**
     * Concurrency limit as a decimal string. The visible Redis-lock advice does not read this
     * element; a lock held by one caller already excludes all others.
     */
    String number() default "";

    /** Message of the error raised when the lock is taken; an empty value selects a generic message. */
    String reminder() default "";

    /** Lifetime of the lock, expressed in {@link #timeUnit()}. */
    long timeout() default 1L;

    /** Unit applied to {@link #timeout()}. */
    TimeUnit timeUnit() default TimeUnit.MINUTES;
}
2 分布式锁专用的 RedisTemplate 与 Jedis 连接工厂配置
package com.nikooh.manage.config;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cache.annotation.CachingConfigurerSupport;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.RedisPassword;
import org.springframework.data.redis.connection.RedisStandaloneConfiguration;
import org.springframework.data.redis.connection.jedis.JedisClientConfiguration;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.util.ObjectUtils;
import redis.clients.jedis.JedisPoolConfig;
/**
 * Redis configuration used by the distributed lock.
 *
 * <p>Defines a pooled Jedis {@link RedisConnectionFactory} named {@code jedisConnectionFactory}
 * for a standalone server, and a {@link RedisTemplate} named {@code cacheRedisTemplate} whose
 * key, value, hash-key and hash-value serializers are all {@link StringRedisSerializer}.
 */
@Configuration
@EnableCaching
public class CacheRedisConfig {
    @Value("${spring.redis.database}")
    private int database;
    @Value("${spring.redis.jedis.pool.max-active}")
    private int maxActive;
    @Value("${spring.redis.jedis.pool.max-wait}")
    private int maxWait;
    @Value("${spring.redis.jedis.pool.max-idle}")
    private int maxIdle;
    @Value("${spring.redis.jedis.pool.min-idle}")
    private int minIdle;
    @Value("${spring.redis.host}")
    private String hostName;
    @Value("${spring.redis.port}")
    private int port;
    // Defaults to "" so a password-less server can omit the property; an empty value is skipped
    // when the connection factory is built.
    @Value("${spring.redis.password:}")
    private String password;

    /**
     * Creates the string-serialized template used by the lock aspect.
     *
     * @param factory the connection factory registered as {@code jedisConnectionFactory}
     * @return a template that serializes keys, values and hash entries as plain strings
     */
    @Bean("cacheRedisTemplate")
    public RedisTemplate<?, ?> redisTemplate(@Qualifier("jedisConnectionFactory") RedisConnectionFactory factory) {
        RedisTemplate<?, ?> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(factory);
        // A single serializer instance covers every slot; it is stateless.
        RedisSerializer<String> stringSerializer = new StringRedisSerializer();
        redisTemplate.setKeySerializer(stringSerializer);
        redisTemplate.setValueSerializer(stringSerializer);
        redisTemplate.setHashKeySerializer(stringSerializer);
        redisTemplate.setHashValueSerializer(stringSerializer);
        return redisTemplate;
    }

    /**
     * Builds a pooled Jedis connection factory for a standalone Redis server from the
     * {@code spring.redis.*} properties injected above.
     *
     * @return the connection factory registered as {@code jedisConnectionFactory}
     */
    @Bean("jedisConnectionFactory")
    public RedisConnectionFactory getRedisConnectionFactory() {
        JedisPoolConfig poolConfig = new JedisPoolConfig();
        poolConfig.setMaxTotal(maxActive);
        poolConfig.setMaxIdle(maxIdle);
        poolConfig.setMinIdle(minIdle);
        poolConfig.setMaxWaitMillis(maxWait);

        RedisStandaloneConfiguration redisStandaloneConfiguration = new RedisStandaloneConfiguration();
        redisStandaloneConfiguration.setHostName(hostName);
        redisStandaloneConfiguration.setPort(port);
        // Only send AUTH when a password is actually configured.
        if (!ObjectUtils.isEmpty(password)) {
            redisStandaloneConfiguration.setPassword(RedisPassword.of(password));
        }
        redisStandaloneConfiguration.setDatabase(database);

        JedisClientConfiguration jedisClientConfiguration = JedisClientConfiguration.builder()
                .usePooling()
                .poolConfig(poolConfig)
                .build();
        return new JedisConnectionFactory(redisStandaloneConfiguration, jedisClientConfiguration);
    }
}
3 分布式锁切面：目标方法执行前加锁，执行后释放锁
package com.nikooh.manage.config.aspect;
import com.nikooh.manage.config.annotation.CacheThread;
import com.nikooh.manage.config.annotation.CacheThreadArg;
import com.nikooh.manage.exception.ServiceException;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.After;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.LocalVariableTableParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;
import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
/**
 * Serializes invocations of methods annotated with {@link CacheThread} or {@link CacheThreadArg}
 * through a Redis lock.
 *
 * <p>The lock is taken in a {@code @Before} advice and released in an {@code @After} advice.
 * When the lock cannot be taken, the {@code @Before} advice throws a {@link ServiceException}
 * with code 1000.
 */
@Aspect
@Component
public class CacheThreadAspect {
    private static final Logger LOGGER = LoggerFactory.getLogger(CacheThreadAspect.class);
    private static final ParameterNameDiscoverer PARAMETER_NAME_DISCOVERER = new DefaultParameterNameDiscoverer();
    /** Default concurrency cap used by the deprecated thread-set implementation. */
    private static final Integer MAX_THREAD_NUM = 10;
    /** JVM-local guard for the deprecated read-modify-write in setThreadCache. */
    private final ReentrantLock lock = new ReentrantLock();
    @Autowired
    private RedisTemplate cacheRedisTemplate;

    /**
     * Acquires the lock {@code prefix + key} before a {@link CacheThread}-annotated method runs.
     *
     * @throws ServiceException (code 1000) when the lock is already held
     */
    @SuppressWarnings("all")
    @Before("@annotation(cacheThread)")
    public void openCacheThreadNum(JoinPoint joinPoint, CacheThread cacheThread) {
        String uniqueKey = cacheThread.prefix() + cacheThread.key();
        String reminder = cacheThread.reminder();
        long timeout = cacheThread.timeout();
        TimeUnit timeUnit = cacheThread.timeUnit();
        LOGGER.info(uniqueKey + " openCacheThreadNum ···");
        lockByRedis(uniqueKey, reminder, timeout, timeUnit);
    }

    /**
     * Tries to take the Redis lock {@code uniqueKey} for the given lifetime.
     *
     * @param uniqueKey full Redis key of the lock
     * @param reminder  message for the rejection error; null or empty selects the default message
     * @param timeout   lock lifetime, expressed in {@code timeUnit}
     * @param timeUnit  unit of {@code timeout}
     * @throws ServiceException (code 1000) when the lock could not be acquired
     */
    private void lockByRedis(String uniqueKey, String reminder, long timeout, TimeUnit timeUnit) {
        // RedisDistributeLock takes an int; clamp so a long lifetime cannot overflow into a negative PX.
        int expireMsecs = (int) Math.min(timeUnit.toMillis(timeout), Integer.MAX_VALUE);
        // A fresh lock object per call: this @Component is a singleton, so a shared instance's
        // expireMsecs could be overwritten by a concurrent call before getLock reads it.
        RedisDistributeLock redisDistributeLock = new RedisDistributeLock(cacheRedisTemplate, expireMsecs);
        boolean isLock = redisDistributeLock.getLock(uniqueKey);
        if (!isLock) {
            LOGGER.warn("current thread does not has lock, thread id is " + Thread.currentThread().getId());
            if (reminder == null || "".equals(reminder)) {
                throw new ServiceException(1000, "当前相同操作进行中,请稍后再试!");
            } else {
                throw new ServiceException(1000, reminder);
            }
        } else {
            LOGGER.warn("current thread has lock, thread id is " + Thread.currentThread().getId());
        }
    }

    /**
     * Releases the lock {@code prefix + key} after a {@link CacheThread}-annotated method finishes.
     */
    @SuppressWarnings("all")
    @After("@annotation(cacheThread)")
    public void closeCacheThreadNum(JoinPoint joinPoint, CacheThread cacheThread) {
        String uniqueKey = cacheThread.prefix() + cacheThread.key();
        LOGGER.info(uniqueKey + " closeCacheThreadNum ···");
        unlockByRedis(uniqueKey);
    }

    /**
     * Releases {@code uniqueKey} if it is still held by the current thread; only logs the outcome.
     *
     * @param uniqueKey full Redis key of the lock
     */
    private void unlockByRedis(String uniqueKey) {
        RedisDistributeLock redisDistributeLock = new RedisDistributeLock(cacheRedisTemplate);
        boolean releaseLock = redisDistributeLock.releaseLock(uniqueKey);
        if (releaseLock) {
            LOGGER.info("release lock success, thread id is " + Thread.currentThread().getId());
        } else {
            LOGGER.info("release lock error, thread id is " + Thread.currentThread().getId());
        }
    }

    /**
     * Acquires the argument-derived lock before a {@link CacheThreadArg}-annotated method runs.
     *
     * @throws ServiceException (code 1000) when the lock is already held or the key cannot be built
     */
    @SuppressWarnings("all")
    @Before("@annotation(cacheThreadArg)")
    public void openCacheThreadByArgKey(JoinPoint joinPoint, CacheThreadArg cacheThreadArg) {
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = methodSignature.getMethod();
        String argUniqueKey = cacheThreadArg.prefix() + parseKey(cacheThreadArg.argKey(), method, joinPoint.getArgs());
        String reminder = cacheThreadArg.reminder();
        long timeout = cacheThreadArg.timeout();
        TimeUnit timeUnit = cacheThreadArg.timeUnit();
        LOGGER.info(argUniqueKey + " openCacheThreadByArgKey ···");
        lockByRedis(argUniqueKey, reminder, timeout, timeUnit);
    }

    /**
     * Releases the argument-derived lock after a {@link CacheThreadArg}-annotated method finishes.
     *
     * <p>NOTE(review): the key is re-derived from the arguments after the method body has run. If
     * the method mutates a value referenced by {@code argKey} (e.g. assigns {@code request.id}),
     * this key differs from the locked one, and that lock is only freed by its PX expiry.
     */
    @SuppressWarnings("all")
    @After("@annotation(cacheThreadArg)")
    public void closeCacheThreadByArgKey(JoinPoint joinPoint, CacheThreadArg cacheThreadArg) {
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = methodSignature.getMethod();
        String argUniqueKey = cacheThreadArg.prefix() + parseKey(cacheThreadArg.argKey(), method, joinPoint.getArgs());
        LOGGER.info(argUniqueKey + " closeCacheThreadByArgKey ···");
        unlockByRedis(argUniqueKey);
    }

    /**
     * Legacy approach: records the current thread id in a Set stored under {@code key} and
     * rejects the call when the set is full or already contains this thread.
     *
     * <p>NOTE(review): cacheRedisTemplate is configured with string serializers, so storing or
     * reading a {@code Set} through it presumably fails; likely why this is deprecated.
     */
    @Deprecated
    private void setThreadCache(String key, String number, String reminder, TimeUnit timeUnit, long timeout) {
        lock.lock();
        try {
            ValueOperations valueOperations = cacheRedisTemplate.opsForValue();
            Set threadIdSet = (Set) valueOperations.get(key);
            if (threadIdSet == null) {
                threadIdSet = new HashSet();
            }
            if (number != null && !"".equals(number)) {
                if (threadIdSet.size() >= Long.valueOf(number)) {
                    LOGGER.warn("operateKey: " + key + "has " + threadIdSet.size() + " thread , max thread number is " + number);
                    if (reminder == null || "".equals(reminder)) {
                        throw new ServiceException(1000, "当前相同操作进行中,请稍后再试!");
                    } else {
                        throw new ServiceException(1000, reminder + ",请稍后再试!");
                    }
                }
            } else {
                if (threadIdSet.size() >= MAX_THREAD_NUM) {
                    LOGGER.warn("operateKey: " + key + "has " + threadIdSet.size() + " thread , max thread number is " + MAX_THREAD_NUM);
                    throw new ServiceException(1000, "当前操作最大线程数为:" + MAX_THREAD_NUM + ",请稍后再试");
                }
            }
            long threadId = Thread.currentThread().getId();
            LOGGER.info("---->当前线程id:" + threadId);
            if (threadIdSet.contains(threadId)) {
                LOGGER.warn("operateKey: " + key + ", threadId: " + threadId + "has been already exist in thread set !");
                throw new ServiceException(1000, "当前线程已在" + key + "中,请稍后再试!");
            }
            threadIdSet.add(threadId);
            if (timeUnit != null && timeout > 0L) {
                valueOperations.set(key, threadIdSet, timeout, timeUnit);
            } else {
                valueOperations.set(key, threadIdSet);
            }
        } catch (ServiceException e) {
            // Propagate the intentional rejections unchanged so they keep their code and message.
            throw e;
        } catch (Exception e) {
            throw new ServiceException(e);
        } finally {
            lock.unlock();
        }
    }

    /**
     * Legacy counterpart of setThreadCache: removes the current thread id from the Set under
     * {@code key}, logging and returning when the set or the id is missing.
     */
    @Deprecated
    private void removeThreadCache(String key) {
        ValueOperations valueOperations = cacheRedisTemplate.opsForValue();
        Set threadIdSet = (Set) valueOperations.get(key);
        if (threadIdSet == null || threadIdSet.isEmpty()) {
            LOGGER.warn("operateKey: " + key + ",对应的set集合不存在");
            return;
        }
        long threadId = Thread.currentThread().getId();
        LOGGER.info("---->当前线程id:" + threadId);
        if (!threadIdSet.contains(threadId)) {
            LOGGER.warn("operateKey: " + key + ", threadId: " + threadId + "has not exist in thread set !");
            return;
        }
        threadIdSet.remove(threadId);
        valueOperations.set(key, threadIdSet);
    }

    /**
     * Builds the argument-derived part of a lock key.
     *
     * @param keyArr SpEL expressions evaluated against the method parameters, e.g. {@code "#request.id"}
     * @param method the intercepted method, used to resolve parameter names
     * @param args   the actual invocation arguments
     * @return {@code "unitKey:"} followed by the concatenated expression values
     * @throws ServiceException (code 1000) when no expressions are given or parameter names are unavailable
     */
    private String parseKey(String[] keyArr, Method method, Object[] args) {
        if (keyArr == null || keyArr.length == 0) {
            LOGGER.error("keyArr is null or length is 0");
            throw new ServiceException(1000, "keyArr is null or length is 0");
        }
        // DefaultParameterNameDiscoverer tries reflection (-parameters) before debug info, and is
        // the shared discoverer already used by parseCacheKey.
        String[] paraNameArr = PARAMETER_NAME_DISCOVERER.getParameterNames(method);
        if (paraNameArr == null) {
            // Without parameter names no "#name" variable can be bound; fail clearly instead of an NPE.
            LOGGER.error("cannot resolve parameter names of " + method);
            throw new ServiceException(1000, "cannot resolve parameter names of " + method);
        }
        ExpressionParser parser = new SpelExpressionParser();
        StandardEvaluationContext context = new StandardEvaluationContext();
        int bound = Math.min(paraNameArr.length, args == null ? 0 : args.length);
        for (int i = 0; i < bound; i++) {
            context.setVariable(paraNameArr[i], args[i]);
        }
        StringBuilder uniqueKey = new StringBuilder();
        for (String s : keyArr) {
            uniqueKey.append(parser.parseExpression(s).getValue(context, String.class));
        }
        return "unitKey:" + uniqueKey;
    }

    /**
     * Evaluates a single SpEL expression against the method parameters (not called in this class).
     *
     * @return the expression value as a string, or {@code null} when the expression is empty
     */
    private String parseCacheKey(Method method, Object[] args, String expressionString) {
        String[] parameterNames = PARAMETER_NAME_DISCOVERER.getParameterNames(method);
        EvaluationContext context = new StandardEvaluationContext();
        if (parameterNames != null && args != null && args.length > 0
                && args.length == parameterNames.length) {
            for (int i = 0, length = parameterNames.length; i < length; i++) {
                context.setVariable(parameterNames[i], args[i]);
            }
        }
        ExpressionParser parser = new SpelExpressionParser();
        Expression expression = null;
        if (expressionString != null && !"".equals(expressionString)) {
            expression = parser.parseExpression(expressionString);
            return String.valueOf(expression.getValue(context));
        } else {
            return null;
        }
    }
}
4 Redis 分布式锁实现（SET NX PX 加锁，Lua 脚本校验持有者后释放）
package com.nikooh.manage.config.aspect;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisCluster;
import java.util.Collections;
/**
 * Minimal Redis lock. A key is acquired with {@code SET key value NX PX expireMsecs} and released
 * with a Lua script that deletes the key only if it still holds the caller's value.
 *
 * <p>The lock value is the calling thread's id, so only the acquiring thread can release a key.
 * NOTE(review): thread ids are unique only within one JVM. If several application instances
 * share this Redis, a thread on one instance could release a lock held by a same-id thread on
 * another. Consider prefixing the value with a per-process random token.
 *
 * <p>Only Jedis native connections ({@link Jedis}, {@link JedisCluster}) are handled. With any
 * other client, {@link #getLock} returns {@code false} and {@link #releaseLock} returns
 * {@code false} without touching Redis. The {@code set(key, value, "NX", "PX", millis)} overload
 * used here is the Jedis 2.x API; Jedis 3.x requires {@code SetParams} instead. The setters
 * mutate instance state, so an instance should not be shared between differently configured callers.
 */
public class RedisDistributeLock {
    /** Key given to the 3-argument constructor; getLock/releaseLock take their key explicitly. */
    private String lockKey;
    /** Lock lifetime in milliseconds, sent as the PX argument; one minute by default. */
    private int expireMsecs = 60 * 1000;
    private RedisTemplate cacheRedisTemplate;
    private static final String SET_IF_NOT_EXIST = "NX";
    private static final String SET_WITH_EXPIRE_TIME = "PX";
    private static final String LOCK_OK_RESULT = "OK";
    private static final String LOCK_NO_RESULT = "NO";
    private static final String RELEASE_LOCK_OK_RESULT = "1";
    private static final String RELEASE_LOCK_NO_RESULT = "0";

    public RedisDistributeLock() {}

    public RedisDistributeLock(RedisTemplate redisTemplate) {
        this.cacheRedisTemplate = redisTemplate;
    }

    public RedisDistributeLock(RedisTemplate redisTemplate, int expireMsecs) {
        this(redisTemplate);
        this.expireMsecs = expireMsecs;
    }

    public RedisDistributeLock(RedisTemplate redisTemplate, int expireMsecs, String lockKey) {
        this(redisTemplate, expireMsecs);
        this.lockKey = lockKey;
    }

    public void setRedisTemplate(RedisTemplate redisTemplate) {
        this.cacheRedisTemplate = redisTemplate;
    }

    public void setExpireMsecs(int expireMsecs) {
        this.expireMsecs = expireMsecs;
    }

    /**
     * Tries to acquire {@code lockKey} for {@code expireMsecs} milliseconds without blocking.
     *
     * @param lockKey Redis key of the lock
     * @return {@code true} if Redis replied {@code OK}, i.e. the key was absent and is now held
     */
    public boolean getLock(String lockKey) {
        // RedisTemplate.execute runs the callback on the calling thread, so the id is the caller's.
        String lockValue = currentLockValue();
        String result = (String) cacheRedisTemplate.execute((RedisCallback<String>) connection -> {
            Object nativeConnection = connection.getNativeConnection();
            if (nativeConnection instanceof JedisCluster) {
                return ((JedisCluster) nativeConnection).set(lockKey, lockValue, SET_IF_NOT_EXIST, SET_WITH_EXPIRE_TIME, expireMsecs);
            }
            if (nativeConnection instanceof Jedis) {
                return ((Jedis) nativeConnection).set(lockKey, lockValue, SET_IF_NOT_EXIST, SET_WITH_EXPIRE_TIME, expireMsecs);
            }
            return LOCK_NO_RESULT;
        });
        return LOCK_OK_RESULT.equals(result);
    }

    /**
     * Deletes {@code lockKey} if, and only if, it still holds the current thread's lock value.
     *
     * <p>The check and the delete run atomically in one Lua script, so an expired lock that has
     * been re-acquired by someone else is not deleted.
     *
     * @param lockKey Redis key of the lock
     * @return {@code true} if the key was deleted
     */
    public boolean releaseLock(String lockKey) {
        String script = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end";
        String lockValue = currentLockValue();
        Object result = cacheRedisTemplate.execute((RedisCallback<Object>) connection -> {
            Object nativeConnection = connection.getNativeConnection();
            if (nativeConnection instanceof JedisCluster) {
                return ((JedisCluster) nativeConnection).eval(script, Collections.singletonList(lockKey), Collections.singletonList(lockValue));
            }
            if (nativeConnection instanceof Jedis) {
                return ((Jedis) nativeConnection).eval(script, Collections.singletonList(lockKey), Collections.singletonList(lockValue));
            }
            return RELEASE_LOCK_NO_RESULT;
        });
        // String.valueOf tolerates a null reply, where result.toString() would throw an NPE.
        return RELEASE_LOCK_OK_RESULT.equals(String.valueOf(result));
    }

    /** Value stored under a held key: the current thread's id, in decimal. */
    private String currentLockValue() {
        return String.valueOf(Thread.currentThread().getId());
    }
}
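TIPS：不同版本的 Spring Boot 依赖的 Jedis 版本不同。Jedis 3.x（例如 3.1.0）中已没有 set(key, value, "NX", "PX", time) 这一重载，需要改用 redis.clients.jedis.params.SetParams 传入 NX/PX 参数，getLock 修改如下：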
/**
 * Jedis 3.x variant of {@code RedisDistributeLock.getLock}: NX and PX are passed through
 * {@code SetParams} instead of the removed string-flag overload. The fully-qualified
 * {@code SetParams} name means the replacement compiles without adding an import.
 *
 * @param lockKey Redis key of the lock
 * @return {@code true} if Redis replied {@code OK}, i.e. this thread now holds the lock
 */
public boolean getLock(String lockKey) {
    String lockValue = String.valueOf(Thread.currentThread().getId());
    // SET lockKey lockValue NX PX expireMsecs, shared by both client types below.
    redis.clients.jedis.params.SetParams params =
            redis.clients.jedis.params.SetParams.setParams().nx().px(expireMsecs);
    String result = (String) cacheRedisTemplate.execute((RedisCallback<String>) connection -> {
        Object nativeConnection = connection.getNativeConnection();
        if (nativeConnection instanceof JedisCluster) {
            return ((JedisCluster) nativeConnection).set(lockKey, lockValue, params);
        }
        if (nativeConnection instanceof Jedis) {
            return ((Jedis) nativeConnection).set(lockKey, lockValue, params);
        }
        return LOCK_NO_RESULT;
    });
    return LOCK_OK_RESULT.equals(result);
}
用法:
/**
 * Usage example: at most one concurrent edit per {@code request.id}.
 *
 * <p>The aspect locks the key {@code "pavilion:base:editCityOperUser:unitKey:" + request.id}
 * for the default lifetime of 1 minute. A concurrent call is rejected with a
 * {@code ServiceException} carrying the message "请勿重复提交". {@code number = "1"} is not read
 * by the Redis-lock advice; the lock itself already admits a single holder.
 *
 * <p>NOTE(review): whether the lock is released before or after the transaction commits depends
 * on the relative order of the transaction advisor and the lock aspect. If it is released first,
 * a concurrent caller may observe uncommitted state; confirm the advisor ordering.
 */
@Transactional(rollbackFor = Exception.class)
@CacheThreadArg(prefix = "pavilion:base:editCityOperUser:", argKey = {"#request.id"}, number = "1", reminder = "请勿重复提交")
public boolean editCityOperUser(CreateCityOperUserRequest request) {
    // ... business logic ...
    return true; // placeholder so the example compiles; return the real outcome here
}