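基于Redis实现分布式锁

本文通过"自定义注解 + Spring AOP 切面 + Redis"实现方法级分布式锁:进入方法前用 SET key value NX PX 加锁,方法结束后用 Lua 脚本校验持有者后再删除 key,用于防止重复提交等并发问题。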

1 自定义锁注解(CacheThreadArg)

package com.nikooh.manage.annotation;

import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;

/**
 * Marks a method whose concurrent executions are serialised through a Redis lock whose key is
 * derived from the method's arguments.
 * <p>
 * The lock advice in CacheThreadAspect builds the key from {@link #prefix()} followed by
 * {@code "unitKey:"} and the evaluated {@link #argKey()} expressions. It acquires the lock before
 * the method runs and releases it afterwards.
 * <p>
 * NOTE(review): this annotation is declared in package {@code com.nikooh.manage.annotation}, while
 * CacheThreadAspect imports {@code com.nikooh.manage.config.annotation.CacheThreadArg} — make the
 * two package names agree.
 *
 * @Author: nikooh
 * @Date: 2020/06/29 : 10:40
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface CacheThreadArg {

    /**
     * Fixed text placed in front of the generated lock key. Empty by default.
     */
    String prefix() default "";

    /**
     * SpEL expressions evaluated against the method parameters (for example {@code "#request.id"}).
     * Their string values are concatenated to form the variable part of the key. The aspect rejects an
     * empty array.
     */
    String[] argKey() default {};

    /**
     * How long the lock lives before Redis expires it, in units of {@link #timeUnit()}.
     * One unit by default.
     */
    long timeout() default 1L;

    /**
     * Unit applied to {@link #timeout()}. Minutes by default.
     */
    TimeUnit timeUnit() default TimeUnit.MINUTES;

    /**
     * Message for the exception raised when the lock is already held. When empty, the aspect falls
     * back to its own default message.
     */
    String reminder() default "";

    /**
     * Number of threads allowed through at once.
     * NOTE(review): the Redis lock advice does not read this value; the lock is always exclusive.
     * Only the deprecated thread-set implementation in the aspect used it.
     */
    String number() default "";
}

2 分布式锁专用的 RedisTemplate 配置

package com.nikooh.manage.config;

import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cache.annotation.CachingConfigurerSupport;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.RedisPassword;
import org.springframework.data.redis.connection.RedisStandaloneConfiguration;
import org.springframework.data.redis.connection.jedis.JedisClientConfiguration;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.util.ObjectUtils;
import redis.clients.jedis.JedisPoolConfig;

/**
 * Redis configuration dedicated to the distributed lock. It defines:
 * <ul>
 *   <li>a Jedis-pooled connection factory named {@code jedisConnectionFactory};</li>
 *   <li>a RedisTemplate named {@code cacheRedisTemplate} that stores keys and values as plain strings.</li>
 * </ul>
 * All settings are read from the standard {@code spring.redis.*} properties.
 * <p>
 * NOTE(review): {@code spring.redis.jedis.pool.max-wait} is bound to an {@code int} (milliseconds)
 * here, so it presumably has to be configured as a bare number such as {@code -1} rather than a
 * duration string like {@code 10s} — confirm against the application configuration.
 *
 * @Author: nikooh
 * @Date: 2020/06/29 : 11:39
 */
@Configuration
@EnableCaching
public class CacheRedisConfig {

    /** Redis database index. */
    @Value("${spring.redis.database}")
    private int database;

    /** Maximum number of pooled connections. */
    @Value("${spring.redis.jedis.pool.max-active}")
    private int maxActive;

    /** Maximum time, in milliseconds, to wait for a pooled connection. */
    @Value("${spring.redis.jedis.pool.max-wait}")
    private int maxWait;

    /** Maximum number of idle pooled connections. */
    @Value("${spring.redis.jedis.pool.max-idle}")
    private int maxIdle;

    /** Minimum number of idle pooled connections. */
    @Value("${spring.redis.jedis.pool.min-idle}")
    private int minIdle;

    /** Redis host name. */
    @Value("${spring.redis.host}")
    private String hostName;

    /** Redis port. */
    @Value("${spring.redis.port}")
    private int port;

    /**
     * Redis password. The property may be omitted or left empty; in that case no password is sent
     * (see {@link #getRedisConnectionFactory()}).
     */
    @Value("${spring.redis.password:}")
    private String password;

    /**
     * Builds the RedisTemplate used by the lock. Keys, values, hash keys and hash values all use
     * {@link StringRedisSerializer}.
     * <p>
     * The default JdkSerializationRedisSerializer would store values in an unreadable binary form.
     * Note that a non-String key (for example a {@code Long}) cannot be serialized by this template.
     *
     * @param factory the connection factory named {@code jedisConnectionFactory}
     * @return the configured template, registered as bean {@code cacheRedisTemplate}
     */
    @Bean("cacheRedisTemplate")
    public RedisTemplate<?, ?> redisTemplate(@Qualifier("jedisConnectionFactory") RedisConnectionFactory factory) {
        RedisTemplate<?, ?> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(factory);

        // One serializer instance for every slot; the previous code configured key/hash-key twice.
        RedisSerializer<String> stringSerializer = new StringRedisSerializer();
        redisTemplate.setKeySerializer(stringSerializer);
        redisTemplate.setValueSerializer(stringSerializer);
        redisTemplate.setHashKeySerializer(stringSerializer);
        redisTemplate.setHashValueSerializer(stringSerializer);
        return redisTemplate;
    }

    /**
     * Creates the pooled Jedis connection factory for a standalone Redis server.
     *
     * @return a JedisConnectionFactory configured from the {@code spring.redis.*} properties
     */
    @Bean("jedisConnectionFactory")
    public RedisConnectionFactory getRedisConnectionFactory() {
        JedisPoolConfig poolConfig = new JedisPoolConfig();
        poolConfig.setMaxTotal(maxActive);
        poolConfig.setMaxIdle(maxIdle);
        poolConfig.setMinIdle(minIdle);
        poolConfig.setMaxWaitMillis(maxWait);

        RedisStandaloneConfiguration redisStandaloneConfiguration = new RedisStandaloneConfiguration();
        redisStandaloneConfiguration.setHostName(hostName);
        redisStandaloneConfiguration.setPort(port);
        // Only send a password when one is configured.
        if (!ObjectUtils.isEmpty(password)) {
            redisStandaloneConfiguration.setPassword(RedisPassword.of(password));
        }
        redisStandaloneConfiguration.setDatabase(database);

        JedisClientConfiguration jedisClientConfiguration = JedisClientConfiguration.builder()
                .usePooling()
                .poolConfig(poolConfig)
                .build();

        return new JedisConnectionFactory(redisStandaloneConfiguration, jedisClientConfiguration);
    }
}

3 加锁 / 释放锁切面(CacheThreadAspect)

package com.nikooh.manage.config.aspect;

import com.nikooh.manage.config.annotation.CacheThread;
import com.nikooh.manage.config.annotation.CacheThreadArg;
import com.nikooh.manage.exception.ServiceException;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.After;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.LocalVariableTableParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Aspect that serialises calls to methods annotated with {@code @CacheThread} or
 * {@code @CacheThreadArg} through a Redis lock ({@link RedisDistributeLock}).
 * <p>
 * Each {@code @Before} advice builds a lock key and tries to acquire it. If the key is already held,
 * it throws {@link ServiceException} (code 1000), so the annotated method does not run. The matching
 * {@code @After} advice rebuilds the same key and releases the lock.
 * <p>
 * NOTE(review): this file declares {@code CacheThreadArg} in package
 * {@code com.nikooh.manage.annotation}, but this aspect imports
 * {@code com.nikooh.manage.config.annotation.CacheThreadArg} — make the two agree.
 * <p>
 * NOTE(review): {@code @CacheThreadArg} keys are evaluated from the arguments twice, once before and
 * once after the call. If the method mutates the referenced value (e.g. assigns {@code request.id}),
 * the release targets a different key. The real lock then only disappears when it expires.
 * <p>
 * NOTE(review): when the annotated method is also {@code @Transactional}, confirm that this aspect
 * wraps the transaction interceptor (e.g. via {@code @Order}). Otherwise the lock is released before
 * the transaction commits.
 *
 * @Author: nikooh
 * @Date: 2020/06/29 : 10:42
 */
@Aspect
@Component
public class CacheThreadAspect {
    private static final Logger LOGGER = LoggerFactory.getLogger(CacheThreadAspect.class);

    /** Resolves method parameter names so that SpEL expressions can refer to them as {@code #name}. */
    private static final ParameterNameDiscoverer PARAMETER_NAME_DISCOVERER = new DefaultParameterNameDiscoverer();

    /** Shared SpEL parser used to evaluate key expressions. */
    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

    /**
     * Default maximum number of threads per operation; only used by the deprecated thread-set implementation.
     */
    private static final Integer MAX_THREAD_NUM = 10;

    /** JVM-local lock; only used by the deprecated thread-set implementation. */
    private final ReentrantLock lock = new ReentrantLock();

    @Autowired
    private RedisTemplate cacheRedisTemplate;

    /**
     * Before advice for {@code @CacheThread}: acquires the lock named {@code prefix + key}.
     *
     * @param joinPoint   the intercepted call
     * @param cacheThread the annotation on the intercepted method
     * @throws ServiceException (code 1000) when the lock is already held
     */
    @SuppressWarnings("all")
    @Before("@annotation(cacheThread)")
    public void openCacheThreadNum(JoinPoint joinPoint, CacheThread cacheThread) {

        // 1. Read the annotation attributes.
        String uniqueKey = cacheThread.prefix() + cacheThread.key();
        String reminder = cacheThread.reminder();
        long timeout = cacheThread.timeout();
        TimeUnit timeUnit = cacheThread.timeUnit();
        LOGGER.info(uniqueKey + " openCacheThreadNum ···");

        // 2. Acquire the distributed lock.
        lockByRedis(uniqueKey, reminder, timeout, timeUnit);
    }

    /**
     * Acquires the Redis lock for {@code uniqueKey}, or fails.
     *
     * @param uniqueKey the Redis key to lock
     * @param reminder  message for the exception; a default message is used when null or empty
     * @param timeout   lock lifetime, in {@code timeUnit}
     * @param timeUnit  unit of {@code timeout}
     * @throws ServiceException    (code 1000) when the lock is already held
     * @throws ArithmeticException when the timeout does not fit into an int number of milliseconds
     */
    private void lockByRedis(String uniqueKey, String reminder, long timeout, TimeUnit timeUnit) {

        // Fail fast instead of silently wrapping a very long timeout to a negative PX value.
        int expireMsecs = Math.toIntExact(timeUnit.toMillis(timeout));
        // A fresh lock object per call: this aspect is a singleton, so a shared mutable
        // RedisDistributeLock would let concurrent callers overwrite each other's expiry.
        RedisDistributeLock redisLock = new RedisDistributeLock(cacheRedisTemplate, expireMsecs);
        if (!redisLock.getLock(uniqueKey)) {
            LOGGER.warn("current thread does not has lock, thread id is " + Thread.currentThread().getId());
            if (reminder == null || "".equals(reminder)) {
                throw new ServiceException(1000, "当前相同操作进行中,请稍后再试!");
            }
            throw new ServiceException(1000, reminder);
        }
        LOGGER.info("current thread has lock, thread id is " + Thread.currentThread().getId());
    }

    /**
     * After advice for {@code @CacheThread}: releases the lock named {@code prefix + key}.
     *
     * @param joinPoint   the intercepted call
     * @param cacheThread the annotation on the intercepted method
     */
    @SuppressWarnings("all")
    @After("@annotation(cacheThread)")
    public void closeCacheThreadNum(JoinPoint joinPoint, CacheThread cacheThread) {

        // 1. Rebuild the lock key.
        String uniqueKey = cacheThread.prefix() + cacheThread.key();
        LOGGER.info(uniqueKey + " closeCacheThreadNum ···");

        // 2. Release the distributed lock.
        unlockByRedis(uniqueKey);
    }

    /**
     * Releases the Redis lock for {@code uniqueKey} if it is still held by the current thread.
     * Failures are only logged.
     *
     * @param uniqueKey the Redis key to unlock
     */
    private void unlockByRedis(String uniqueKey) {

        boolean released = new RedisDistributeLock(cacheRedisTemplate).releaseLock(uniqueKey);
        if (released) {
            LOGGER.info("release lock success, thread id is " + Thread.currentThread().getId());
        } else {
            // Not released: the lock expired, or it was not held by this thread.
            LOGGER.warn("release lock error, thread id is " + Thread.currentThread().getId());
        }
    }

    /**
     * Before advice for {@code @CacheThreadArg}: acquires the lock whose key is built from the
     * annotation prefix and the evaluated {@code argKey} expressions.
     *
     * @param joinPoint      the intercepted call (supplies method and arguments)
     * @param cacheThreadArg the annotation on the intercepted method
     * @throws ServiceException (code 1000) when the lock is already held or {@code argKey} is empty
     */
    @SuppressWarnings("all")
    @Before("@annotation(cacheThreadArg)")
    public void openCacheThreadByArgKey(JoinPoint joinPoint, CacheThreadArg cacheThreadArg) {

        // 1. Resolve the intercepted method.
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = methodSignature.getMethod();

        // 2. Build the key and read the remaining annotation attributes.
        String argUniqueKey = cacheThreadArg.prefix() + parseKey(cacheThreadArg.argKey(), method, joinPoint.getArgs());
        String reminder = cacheThreadArg.reminder();
        long timeout = cacheThreadArg.timeout();
        TimeUnit timeUnit = cacheThreadArg.timeUnit();
        LOGGER.info(argUniqueKey + " openCacheThreadByArgKey ···");

        // 3. Acquire the lock.
        lockByRedis(argUniqueKey, reminder, timeout, timeUnit);
    }

    /**
     * After advice for {@code @CacheThreadArg}: rebuilds the key from the (current) argument values
     * and releases the lock.
     *
     * @param joinPoint      the intercepted call (supplies method and arguments)
     * @param cacheThreadArg the annotation on the intercepted method
     */
    @SuppressWarnings("all")
    @After("@annotation(cacheThreadArg)")
    public void closeCacheThreadByArgKey(JoinPoint joinPoint, CacheThreadArg cacheThreadArg) {

        // 1. Resolve the intercepted method.
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = methodSignature.getMethod();

        // 2. Rebuild the key.
        String argUniqueKey = cacheThreadArg.prefix() + parseKey(cacheThreadArg.argKey(), method, joinPoint.getArgs());
        LOGGER.info(argUniqueKey + " closeCacheThreadByArgKey ···");

        // 3. Release the lock.
        unlockByRedis(argUniqueKey);
    }

    /**
     * Adds the current thread id to a Redis-stored set of ids for {@code key}, enforcing a maximum
     * size.
     *
     * @param key      Redis key holding the set of thread ids
     * @param number   maximum set size as a decimal string; {@link #MAX_THREAD_NUM} when null or empty
     * @param reminder message used when the limit is reached
     * @param timeUnit unit of {@code timeout}
     * @param timeout  expiry of the key; no expiry when {@code timeUnit} is null or {@code timeout <= 0}
     * @deprecated superseded by the Redis lock ({@link #lockByRedis}); no longer called.
     */
    @Deprecated
    private void setThreadCache(String key, String number, String reminder, TimeUnit timeUnit, long timeout) {
        lock.lock();
        try {

            // 1. Load the thread-id set stored under the key.
            ValueOperations valueOperations = cacheRedisTemplate.opsForValue();
            Set threadIdSet = (Set) valueOperations.get(key);
            if (threadIdSet == null) {
                threadIdSet = new HashSet();
            }

            // 2. Enforce the maximum number of threads.
            if (number != null && !"".equals(number)) {
                // Caller-supplied limit.
                if (threadIdSet.size() >= Long.valueOf(number)) {
                    LOGGER.warn("operateKey: " + key + "has " + threadIdSet.size() + " thread , max thread number is " + number);
                    if (reminder == null || "".equals(reminder)) {
                        throw new ServiceException(1000, "当前相同操作进行中,请稍后再试!");
                    } else {
                        throw new ServiceException(1000, reminder + ",请稍后再试!");
                    }
                }
            } else {
                // Default limit.
                if (threadIdSet.size() >= MAX_THREAD_NUM) {
                    LOGGER.warn("operateKey: " + key + "has " + threadIdSet.size() + " thread , max thread number is " + MAX_THREAD_NUM);
                    throw new ServiceException(1000, "当前操作最大线程数为:" + MAX_THREAD_NUM + ",请稍后再试");
                }
            }

            // 3. Reject a thread that is already registered.
            long threadId = Thread.currentThread().getId();
            LOGGER.info("---->当前线程id:" + threadId);
            if (threadIdSet.contains(threadId)) {
                LOGGER.warn("operateKey: " + key + ", threadId: " + threadId + "has been already exist in thread set !");
                throw new ServiceException(1000, "当前线程已在" + key + "中,请稍后再试!");
            }

            // 4. Store the updated set back under the key.
            threadIdSet.add(threadId);

            if (timeUnit != null && timeout > 0L) {
                valueOperations.set(key, threadIdSet, timeout, timeUnit);
            } else {
                valueOperations.set(key, threadIdSet);
            }
        } catch (Exception e) {
            throw new ServiceException(e);
        } finally {
            lock.unlock();
        }
    }

    /**
     * Removes the current thread id from the Redis-stored set for {@code key}.
     *
     * @param key Redis key holding the set of thread ids
     * @deprecated superseded by the Redis lock ({@link #unlockByRedis}); no longer called.
     */
    @Deprecated
    private void removeThreadCache(String key) {

        ValueOperations valueOperations = cacheRedisTemplate.opsForValue();
        Set threadIdSet = (Set) valueOperations.get(key);
        if (threadIdSet == null || threadIdSet.isEmpty()) {
            LOGGER.warn("operateKey: " + key + ",对应的set集合不存在");
            return;
        }

        // Current thread id.
        long threadId = Thread.currentThread().getId();
        LOGGER.info("---->当前线程id:" + threadId);
        if (!threadIdSet.contains(threadId)) {
            LOGGER.warn("operateKey: " + key + ", threadId: " + threadId + "has not exist in thread set !");
            return;
        }

        // Remove this thread's id and store the set back.
        threadIdSet.remove(threadId);
        valueOperations.set(key, threadIdSet);
    }

    /**
     * Builds the variable part of a lock key from SpEL expressions over the method arguments.
     *
     * @param keyArr SpEL expressions such as {@code "#request.id"}; must be non-empty
     * @param method the intercepted method, used to resolve parameter names
     * @param args   the actual call arguments
     * @return {@code "unitKey:"} followed by the concatenated expression values
     * @throws ServiceException (code 1000) when {@code keyArr} is null or empty
     */
    private String parseKey(String[] keyArr, Method method, Object[] args) {

        if (keyArr == null || keyArr.length == 0) {
            LOGGER.error("keyArr is null or length is 0");
            throw new ServiceException(1000, "keyArr is null or length is 0");
        }

        // The LocalVariableTable-only discoverer returns null for classes compiled without debug
        // info, which used to cause an NPE here. The default discoverer also uses -parameters metadata.
        String[] paraNameArr = PARAMETER_NAME_DISCOVERER.getParameterNames(method);

        // Expose each method argument to SpEL as #parameterName.
        StandardEvaluationContext context = new StandardEvaluationContext();
        if (paraNameArr != null && args != null) {
            for (int i = 0, n = Math.min(paraNameArr.length, args.length); i < n; i++) {
                context.setVariable(paraNameArr[i], args[i]);
            }
        }

        StringBuilder uniqueKey = new StringBuilder();
        for (String s : keyArr) {
            uniqueKey.append(EXPRESSION_PARSER.parseExpression(s).getValue(context, String.class));
        }

        return "unitKey:" + uniqueKey;
    }

    /**
     * Evaluates a single SpEL expression against the method arguments.
     *
     * @param method           the intercepted method, used to resolve parameter names
     * @param args             the actual call arguments
     * @param expressionString the SpEL expression; may be null or empty
     * @return the expression value as a string, or {@code null} when {@code expressionString} is null or empty
     */
    private String parseCacheKey(Method method, Object[] args, String expressionString) {
        String[] parameterNames = PARAMETER_NAME_DISCOVERER.getParameterNames(method);
        EvaluationContext context = new StandardEvaluationContext();
        if (parameterNames != null && args != null && args.length > 0
                && args.length == parameterNames.length) {
            for (int i = 0, length = parameterNames.length; i < length; i++) {
                context.setVariable(parameterNames[i], args[i]);
            }
        }
        if (expressionString == null || "".equals(expressionString)) {
            return null;
        }
        Expression expression = EXPRESSION_PARSER.parseExpression(expressionString);
        return String.valueOf(expression.getValue(context));
    }

}

4 Redis 分布式锁实现(RedisDistributeLock)

package com.nikooh.manage.config.aspect;

import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisCluster;

import java.util.Collections;

/**
 * Minimal Redis distributed lock.
 * <ul>
 *   <li>Acquisition uses {@code SET key value NX PX millis}.</li>
 *   <li>Release uses a Lua script that deletes the key only if it still holds this owner's value.</li>
 * </ul>
 * <p>
 * The stored value is the id of the calling thread, so only the thread that acquired a key can
 * release it.
 * <p>
 * NOTE(review): {@link Thread#getId()} is only unique within one JVM. Threads in two application
 * instances can share an id, so one instance could release the other's lock. A per-JVM prefix in
 * {@link #currentOwner()} (e.g. a random UUID) would close that gap. If you change it, every
 * {@code getLock} variant must write the same value that {@code releaseLock} compares against.
 * <p>
 * Only Jedis / JedisCluster native connections are handled. With any other driver, both
 * {@link #getLock} and {@link #releaseLock} return {@code false}. Instances carry mutable
 * configuration (setters), so do not share one instance between threads that reconfigure it.
 *
 * @Author: nikooh
 * @Date: 2020/06/29 : 10:43
 */
public class RedisDistributeLock {

    /**
     * Key supplied through the three-argument constructor. The lock methods take their key as a
     * parameter and do not read this field.
     */
    private String lockKey;

    /** Lock expiry in milliseconds (PX); defaults to 60 seconds. */
    private int expireMsecs = 60 * 1000;

    private RedisTemplate cacheRedisTemplate;

    /** SET option: only set the key if it does not exist. */
    private static final String SET_IF_NOT_EXIST = "NX";

    /** SET option: expiry given in milliseconds. */
    private static final String SET_WITH_EXPIRE_TIME = "PX";

    /** Reply of a successful SET NX. */
    private static final String LOCK_OK_RESULT = "OK";

    /** Sentinel returned when the native connection type is not supported. */
    private static final String LOCK_NO_RESULT = "NO";

    /** Script reply (as string) when the key was deleted. */
    private static final String RELEASE_LOCK_OK_RESULT = "1";

    /** Sentinel returned when the native connection type is not supported. */
    private static final String RELEASE_LOCK_NO_RESULT = "0";

    /** Deletes KEYS[1] only if its value equals ARGV[1]; returns the DEL count, otherwise 0. */
    private static final String RELEASE_LOCK_SCRIPT =
            "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end";

    /**
     * Creates a lock without a template; call {@link #setRedisTemplate} before use.
     */
    public RedisDistributeLock() {}

    /**
     * @param redisTemplate template whose connection is used for the Redis commands
     */
    public RedisDistributeLock(RedisTemplate redisTemplate) {
        this.cacheRedisTemplate = redisTemplate;
    }

    /**
     * @param redisTemplate template whose connection is used for the Redis commands
     * @param expireMsecs   lock expiry in milliseconds
     */
    public RedisDistributeLock(RedisTemplate redisTemplate, int expireMsecs) {
        this(redisTemplate);
        this.expireMsecs = expireMsecs;
    }

    /**
     * @param redisTemplate template whose connection is used for the Redis commands
     * @param expireMsecs   lock expiry in milliseconds
     * @param lockKey       stored in {@link #lockKey}; not read by the lock methods
     */
    public RedisDistributeLock(RedisTemplate redisTemplate, int expireMsecs, String lockKey) {
        this(redisTemplate, expireMsecs);
        this.lockKey = lockKey;
    }

    public void setRedisTemplate(RedisTemplate redisTemplate) {
        this.cacheRedisTemplate = redisTemplate;
    }

    public void setExpireMsecs(int expireMsecs) {
        this.expireMsecs = expireMsecs;
    }

    /**
     * Tries once, without waiting, to acquire {@code lockKey} for the current thread.
     * Uses {@code SET key owner NX PX expireMsecs}.
     *
     * @param lockKey the Redis key to lock
     * @return {@code true} if Redis replied {@code OK}; {@code false} if the key is already held or
     *         the native connection type is not supported
     */
    public boolean getLock(String lockKey) {

        final String owner = currentOwner();
        Object result = cacheRedisTemplate.execute((RedisCallback<String>) connection -> {
            Object nativeConnection = connection.getNativeConnection();
            if (nativeConnection instanceof JedisCluster) {
                // Cluster mode.
                return ((JedisCluster) nativeConnection).set(lockKey, owner, SET_IF_NOT_EXIST, SET_WITH_EXPIRE_TIME, expireMsecs);
            }
            if (nativeConnection instanceof Jedis) {
                // Standalone / sentinel mode.
                return ((Jedis) nativeConnection).set(lockKey, owner, SET_IF_NOT_EXIST, SET_WITH_EXPIRE_TIME, expireMsecs);
            }
            return LOCK_NO_RESULT;
        });

        return LOCK_OK_RESULT.equals(result);
    }

    /**
     * Releases {@code lockKey} if, and only if, it is still held by the current thread. The check and
     * the delete run atomically in a Lua script.
     *
     * @param lockKey the Redis key to unlock
     * @return {@code true} if the key was deleted; {@code false} otherwise, including when the lock had
     *         expired, was held by someone else, or the connection type is not supported
     */
    public boolean releaseLock(String lockKey) {

        final String owner = currentOwner();
        Object result = cacheRedisTemplate.execute((RedisCallback<Object>) connection -> {
            Object nativeConnection = connection.getNativeConnection();
            if (nativeConnection instanceof JedisCluster) {
                // Cluster mode.
                return ((JedisCluster) nativeConnection).eval(RELEASE_LOCK_SCRIPT,
                        Collections.singletonList(lockKey), Collections.singletonList(owner));
            }
            if (nativeConnection instanceof Jedis) {
                // Standalone / sentinel mode.
                return ((Jedis) nativeConnection).eval(RELEASE_LOCK_SCRIPT,
                        Collections.singletonList(lockKey), Collections.singletonList(owner));
            }
            return RELEASE_LOCK_NO_RESULT;
        });

        // RedisTemplate.execute may return null; String.valueOf avoids the former NPE and maps null to "not released".
        return RELEASE_LOCK_OK_RESULT.equals(String.valueOf(result));
    }

    /**
     * Returns the value that identifies the lock owner: the calling thread's id, as a decimal string.
     * {@link #getLock} and {@link #releaseLock} must agree on this value.
     */
    private static String currentOwner() {
        return String.valueOf(Thread.currentThread().getId());
    }
}

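TIPS:不同版本的 Spring Boot 所依赖的 Jedis 版本不同。Jedis 3.x(例如 3.1.0)不再提供 set(key, value, nxxx, expx, time) 这个重载,需要把 getLock 改为使用 SetParams(位于 redis.clients.jedis.params 包),修改如下: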

    /**
     * Tries once, without waiting, to acquire {@code lockKey} for the current thread. This is the
     * Jedis 3.x variant: the NX / PX options are passed as a {@code SetParams} object instead of
     * string flags.
     * <p>
     * The stored value must stay identical to what {@code releaseLock} compares against (the current
     * thread id). NOTE(review): thread ids are only unique within one JVM.
     *
     * @param lockKey the Redis key to lock
     * @return {@code true} if Redis replied {@code OK}; {@code false} if the key is already held or
     *         the native connection type is not supported
     */
    public boolean getLock(String lockKey) {

        // Lock owner = calling thread id, matching releaseLock.
        final String owner = String.valueOf(Thread.currentThread().getId());
        // SET key owner NX PX expireMsecs. Fully qualified so the snippet compiles without an extra import.
        final redis.clients.jedis.params.SetParams params =
                redis.clients.jedis.params.SetParams.setParams().nx().px(expireMsecs);

        Object result = cacheRedisTemplate.execute((RedisCallback<String>) connection -> {
            Object nativeConnection = connection.getNativeConnection();
            if (nativeConnection instanceof JedisCluster) {
                // Cluster mode.
                return ((JedisCluster) nativeConnection).set(lockKey, owner, params);
            }
            if (nativeConnection instanceof Jedis) {
                // Standalone / sentinel mode.
                return ((Jedis) nativeConnection).set(lockKey, owner, params);
            }
            return LOCK_NO_RESULT;
        });

        return LOCK_OK_RESULT.equals(result);
    }

用法(在需要防止重复提交的方法上添加注解):

/**
 * Usage example: serialises concurrent edits of the same city operator.
 * <p>
 * The lock key is {@code "pavilion:base:editCityOperUser:unitKey:" + request.id}. A second call with
 * the same id, made while the first still holds the lock, fails with "请勿重复提交". The lock expires
 * after the default timeout of 1 minute.
 * <p>
 * {@code number} is kept for compatibility but is not read by the Redis lock advice.
 * NOTE(review): confirm that the lock aspect wraps the transaction; otherwise the lock is released
 * before the commit.
 */
@Transactional(rollbackFor = Exception.class)
@CacheThreadArg(prefix = "pavilion:base:editCityOperUser:", argKey = {"#request.id"}, number = "1", reminder = "请勿重复提交")
public boolean editCityOperUser(CreateCityOperUserRequest request) {
    // ... business logic ...
    return true;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值