Using redlock in Python: basic usage, source-code walkthrough, and an attempted improvement

Basic usage of RedLock

import redlock
import redis
from setting import REDIS_HOST as host, REDIS_PORT as port, REDIS_DB as db
from time import sleep
# use a connection pool
pool = redis.ConnectionPool(host=host, port=port, max_connections=100)  # at most 100 connections
r = redis.StrictRedis(host=host, port=port, db=db, password="", socket_timeout=1, connection_pool=pool)
# with several Redis servers, just append each client to this list
connection_details = [r, ]
lock_name = "test"  # the key to lock on
retry_times = 5  # how many times to retry when acquiring fails
# create a RedLock instance
lock = redlock.RedLock(lock_name, connection_details, retry_times, ttl=5000)
# simple usage
lock.acquire()
sleep(3)  # your business code goes here
lock.release()
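
Since RedLock also defines __enter__ and __exit__ (see the source below), the same lock can be used as a context manager; a minimal sketch, where __enter__ raises RedLockError if the lock cannot be acquired:

with redlock.RedLock(lock_name, connection_details, retry_times, ttl=5000):
    sleep(3)  # business code; the lock is released automatically on exit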

Testing whether RedLock can be acquired more than once

lock = redlock.RedLock(lock_name, connection_details, retry_times, ttl=5000)
lock.acquire()
# try to acquire the lock a second time
print(lock.acquire())  # False
lock.release()
print(r.get('test'))  # b'a681f0b23f474eb8a800aee2980e3604'

So not only is the lock not reentrant, the second call also makes the lock impossible to delete: acquire() regenerates the UUID value every time it is called, so release() can no longer match the value that was actually stored in Redis.
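
A quick way to see the problem, using the lock_key attribute that acquire() assigns (visible in the source below):

lock = redlock.RedLock(lock_name, connection_details, retry_times, ttl=5000)
lock.acquire()
first_key = lock.lock_key           # the value stored in Redis
lock.acquire()                      # returns False, but still overwrites lock.lock_key
print(first_key == lock.lock_key)   # False, so release() no longer matches the stored value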

redlock's reentrant lock: ReentrantRedLock

lock = redlock.ReentrantRedLock(lock_name, connection_details, retry_times, ttl=5000)
# usage
lock.acquire()
if lock.acquire():
    print("acquired again")  # output: acquired again
lock.release()
print(r.get('test'))  # b'875e119f35e34466a4bb2ac7fff4c99d'
lock.release()
print(r.get('test'))  # None

Source code walkthrough (the explanations are in the comments)

DEFAULT_RETRY_TIMES = 3
DEFAULT_RETRY_DELAY = 200
DEFAULT_TTL = 100000
CLOCK_DRIFT_FACTOR = 0.01

# Lua script for releasing the lock; it makes the GET and DEL a single atomic operation
RELEASE_LUA_SCRIPT = """
    if redis.call("get",KEYS[1]) == ARGV[1] then
        return redis.call("del",KEYS[1])
    else
        return 0
    end
"""

class RedLock(object):

    # resource is the key to lock on; retry_times is the maximum number of acquire attempts;
    # retry_delay keeps its default; ttl is how long the lock is held, in milliseconds
    def __init__(self, resource, connection_details=None,
                 retry_times=DEFAULT_RETRY_TIMES,
                 retry_delay=DEFAULT_RETRY_DELAY,
                 ttl=DEFAULT_TTL,
                 created_by_factory=False):

        self.resource = resource
        self.retry_times = retry_times
        self.retry_delay = retry_delay
        self.ttl = ttl
        # created through the factory (RedLockFactory supplies the Redis nodes)
        if created_by_factory:
            self.factory = None
            return

        self.redis_nodes = []
        # If the connection_details parameter is not provided,
        # use redis://127.0.0.1:6379/0
        if connection_details is None:
            connection_details = [{
                'host': 'localhost',
                'port': 6379,
                'db': 0,
            }]
        # If an entry in connection_details is already a redis.StrictRedis instance, use it as-is.
        # If it is a dict containing a 'url' key, build a StrictRedis instance from that URL plus the remaining kwargs.
        # Otherwise build a StrictRedis instance directly from the dict's kwargs.
        for conn in connection_details:
            if isinstance(conn, redis.StrictRedis):
                node = conn
            elif 'url' in conn:
                url = conn.pop('url')
                node = redis.StrictRedis.from_url(url, **conn)
            else:
                node = redis.StrictRedis(**conn)
            node._release_script = node.register_script(RELEASE_LUA_SCRIPT)
            # collect every Redis node in self.redis_nodes
            self.redis_nodes.append(node)
        # minimum number of nodes that must grant the lock (a majority)
        self.quorum = len(self.redis_nodes) // 2 + 1

    def __enter__(self):
        if not self.acquire():
            raise RedLockError('failed to acquire lock')

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()
    # helper for converting a timedelta into milliseconds
    def _total_ms(self, delta):
        """
        Get the total number of milliseconds in a timedelta object with
        microsecond precision.
        """
        delta_seconds = delta.seconds + delta.days * 24 * 3600
        return (delta.microseconds + delta_seconds * 10**6) / 10**3
    # acquire the lock on a single node
    def acquire_node(self, node):
        """
        acquire a single redis node
        """
        return node.set(self.resource, self.lock_key, nx=True, px=self.ttl)
    # release the lock on a single node
    def release_node(self, node):
        """
        release a single redis node
        """
        # run the Lua script so that the GET and DEL happen atomically
        node._release_script(keys=[self.resource], args=[self.lock_key])
    # acquire the lock on every node
    def acquire(self):
        self.lock_key = uuid.uuid4().hex  # unique value identifying this lock holder
        # Overall flow: keep trying to lock all the Redis nodes until one round succeeds
        # or the retries are exhausted.
        for retry in range(self.retry_times):
            # number of nodes that granted the lock in this round
            acquired_node_count = 0
            # record the UTC start time
            start_time = datetime.utcnow()

            # try to lock each node and count the successes
            for node in self.redis_nodes:
                if self.acquire_node(node):
                    acquired_node_count += 1
            # record the UTC end time
            end_time = datetime.utcnow()
            # total time spent on all the requests, in milliseconds
            elapsed_milliseconds = self._total_ms(end_time - start_time)

            drift = (self.ttl * CLOCK_DRIFT_FACTOR) + 2  # allowance for clock drift and Redis expire precision
            # The lock is considered acquired when a majority of nodes granted it and the TTL
            # is still longer than the time spent acquiring plus the drift allowance.
            if acquired_node_count >= self.quorum and \
               self.ttl > (elapsed_milliseconds + drift):
                return True
            else:  # otherwise release the lock on every node
                for node in self.redis_nodes:
                    self.release_node(node)
                # retry after a random delay so that competing clients do not all retry in lockstep
                time.sleep(random.randint(0, self.retry_delay) / 1000)
        return False
    
    # release the lock on every node
    def release(self):  
        for node in self.redis_nodes:
            self.release_node(node)

# Reentrant lock. The implementation is simple: if the lock is already held, acquire() just
# increments self._acquired; release() decrements it, and only when the counter drops back
# to 0 is the underlying lock actually released.
class ReentrantRedLock(RedLock):
    def __init__(self, *args, **kwargs):
        super(ReentrantRedLock, self).__init__(*args, **kwargs)
        self._acquired = 0

    def acquire(self):
        if self._acquired == 0:
            result = super(ReentrantRedLock, self).acquire()
            if result:
                self._acquired += 1
            return result
        else:
            self._acquired += 1
            return True

    def release(self):
        if self._acquired > 0:
            self._acquired -= 1
            if self._acquired == 0:
                return super(ReentrantRedLock, self).release()
            return True
        return False

So far this solves the master/replica consistency problem, supports retries, and provides a reentrant lock, but it still does not solve the problem of the lock expiring before the work finishes.

Personally I think it would be better to refresh the lock's TTL on a reentrant acquire, or to refresh it on release when self._acquired - 1 != 0; of course, if a watchdog handles the expiration problem, this small issue no longer needs special treatment.
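
A minimal sketch of that idea (the subclass name RefreshingReentrantRedLock is made up here, and a bare PEXPIRE is not atomic with the ownership check; the Lua-based watchdog later in this post avoids that problem):

class RefreshingReentrantRedLock(redlock.ReentrantRedLock):
    def acquire(self):
        if self._acquired > 0:
            # refresh the TTL on every reentrant acquire (not atomic with the ownership check)
            for node in self.redis_nodes:
                node.pexpire(self.resource, self.ttl)  # self.ttl is in milliseconds
        return super(RefreshingReentrantRedLock, self).acquire()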

A reference article on solving the premature-expiration problem: https://blog.csdn.net/weixin_42289273/article/details/120160978

The part of that article that refreshes the lock's TTL:

def lock_renewal(lock_key, lock_value, expires=30):
    """Runs in a child thread and keeps extending the lock"""
    while True:
        rv = redis.get(lock_key)
        if rv is not None and rv.decode() == lock_value:
            print('extending the lock...')
            set_with_ttl(lock_key, lock_value, expires)  # reset the lock's expiration back to 30 seconds
        else:
            break
        time.sleep(expires // 3)

This code has an atomicity problem (the GET and the TTL reset are separate commands), so we can replace it with a Lua script.

For a single node, solving the premature-expiration problem looks roughly like the following pseudocode:

def lock_renewal(lock_key, lock_value, expires=30):
    """Runs in a child thread and keeps extending the lock"""
    while True:
        time.sleep(expires // 3)
        if not exec(lua_script()):  # if the extension fails, the business code has finished and the lock is already released
            break
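
A runnable single-node sketch of that pseudocode (the name renewal_script and the reuse of the client r are my own choices; the Lua body is the same check-value-then-EXPIRE idea used in the full implementation below):

import time

RENEWAL_LUA = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("expire", KEYS[1], ARGV[2])
end
return 0
"""
renewal_script = r.register_script(RENEWAL_LUA)  # r is the StrictRedis client from the first example

def lock_renewal(lock_key, lock_value, expires=30):
    """Runs in a child thread and keeps extending the lock on a single node."""
    while True:
        time.sleep(expires // 3)
        # atomically extend the TTL, but only if we still hold the lock
        if not renewal_script(keys=[lock_key], args=[lock_value, expires]):
            break  # the lock has been released (or taken over), so stop renewing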

But we still need to handle multiple nodes.

The concrete implementation is below; it addresses the expiration problem. My own testing has not turned up any issues, and you are welcome to use and test it!

"""
Distributed locks with Redis
Redis doc: http://redis.io/topics/distlock
"""
from __future__ import division
from datetime import datetime
import random
import time
import uuid

import redis
import threading

DEFAULT_RETRY_TIMES = 3
DEFAULT_RETRY_DELAY = 200
DEFAULT_TTL = 100000
CLOCK_DRIFT_FACTOR = 0.01

# Reference:  http://redis.io/topics/distlock
# Section Correct implementation with a single instance
RELEASE_LUA_SCRIPT = """
    if redis.call("get",KEYS[1]) == ARGV[1] then
        return redis.call("del",KEYS[1])
    else
        return 0
    end
"""
delay_lua_script = """
local key = KEYS[1]
local threadId = ARGV[1]
local time = ARGV[2]
local id = redis.call('get', key)
-- if the key exists and its value equals threadId, extend the TTL
if (id and id == threadId) then
    redis.call('expire', key, time)
    return 1
end
-- otherwise the extension did not succeed
return 0
"""


class RedLockError(Exception):
    pass


class RedLockFactory(object):
    """
    A Factory class that helps reuse multiple Redis connections.
    """

    def __init__(self, connection_details):
        """

        """
        self.redis_nodes = []

        for conn in connection_details:
            if isinstance(conn, redis.StrictRedis):
                node = conn
            elif 'url' in conn:
                url = conn.pop('url')
                node = redis.StrictRedis.from_url(url, **conn)
            else:
                node = redis.StrictRedis(**conn)
            node._release_script = node.register_script(RELEASE_LUA_SCRIPT)
            self.redis_nodes.append(node)
            self.quorum = len(self.redis_nodes) // 2 + 1

    def create_lock(self, resource, **kwargs):
        """
        Create a new RedLock object and reuse stored Redis clients.
        All the kwargs it received would be passed to the RedLock's __init__
        function.
        """
        lock = RedLock(resource=resource, created_by_factory=True, **kwargs)
        lock.redis_nodes = self.redis_nodes
        lock.quorum = self.quorum
        lock.factory = self
        return lock


class RedLock(object):
    """
    A distributed lock implementation based on Redis.
    It shares a similar API with the `threading.Lock` class in the
    Python Standard Library.
    """

    def __init__(self, resource, connection_details=None,
                 retry_times=DEFAULT_RETRY_TIMES,
                 retry_delay=DEFAULT_RETRY_DELAY,
                 ttl=DEFAULT_TTL,
                 created_by_factory=False):

        self.resource = resource
        self.retry_times = retry_times
        self.retry_delay = retry_delay
        self.ttl = ttl

        if created_by_factory:
            self.factory = None
            return

        self.redis_nodes = []
        # If the connection_details parameter is not provided,
        # use redis://127.0.0.1:6379/0
        if connection_details is None:
            connection_details = [{
                'host': 'localhost',
                'port': 6379,
                'db': 0,
            }]

        for conn in connection_details:
            if isinstance(conn, redis.StrictRedis):
                node = conn
            elif 'url' in conn:
                url = conn.pop('url')
                node = redis.StrictRedis.from_url(url, **conn)
            else:
                node = redis.StrictRedis(**conn)
            node._release_script = node.register_script(RELEASE_LUA_SCRIPT)
            node._delay_script = node.register_script(delay_lua_script)
            self.redis_nodes.append(node)
        self.quorum = len(self.redis_nodes) // 2 + 1

    def __enter__(self):
        if not self.acquire():
            raise RedLockError('failed to acquire lock')

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def delay_lock(self):
        """Runs in a child thread and keeps extending the lock (a watchdog)."""
        expire = self.ttl // 1000
        while True:
            count = 0
            time.sleep(expire // 3)
            for node in self.redis_nodes:
                # count the nodes where the Lua script returned 0 (i.e. the extension failed)
                count += 1 - node._delay_script(keys=[self.resource], args=[self.lock_key, expire])
            # If the extension failed on a majority of nodes, the lock is gone
            # (the main thread has finished and released it), so stop the watchdog.
            if count >= self.quorum:
                break

    def _total_ms(self, delta):
        """
        Get the total number of milliseconds in a timedelta object with
        microsecond precision.
        """
        delta_seconds = delta.seconds + delta.days * 24 * 3600
        return (delta.microseconds + delta_seconds * 10 ** 6) / 10 ** 3

    def acquire_node(self, node):
        """
        acquire a single redis node
        """
        return node.set(self.resource, self.lock_key, nx=True, px=self.ttl)

    def release_node(self, node):
        """
        release a single redis node
        """
        # use the lua script to release the lock in a safe way
        node._release_script(keys=[self.resource], args=[self.lock_key])

    def acquire(self):

        # lock_key should be random and unique
        self.lock_key = uuid.uuid4().hex

        for retry in range(self.retry_times):
            acquired_node_count = 0
            start_time = datetime.utcnow()

            # acquire the lock in all the redis instances sequentially
            for node in self.redis_nodes:
                if self.acquire_node(node):
                    acquired_node_count += 1

            end_time = datetime.utcnow()
            elapsed_milliseconds = self._total_ms(end_time - start_time)

            # Add 2 milliseconds to the drift to account for Redis expires
            # precision, which is 1 millisecond, plus 1 millisecond min drift
            # for small TTLs.
            drift = (self.ttl * CLOCK_DRIFT_FACTOR) + 2

            if acquired_node_count >= self.quorum and self.ttl > (elapsed_milliseconds + drift):
                # lock acquired: start a watchdog thread that keeps resetting the TTL
                threading.Thread(target=self.delay_lock).start()
                return True
            else:
                for node in self.redis_nodes:
                    self.release_node(node)
                time.sleep(random.randint(0, self.retry_delay) / 1000)
        return False

    def release(self):
        for node in self.redis_nodes:
            self.release_node(node)


class ReentrantRedLock(RedLock):
    def __init__(self, *args, **kwargs):
        super(ReentrantRedLock, self).__init__(*args, **kwargs)
        self._acquired = 0

    def acquire(self):
        if self._acquired == 0:
            result = super(ReentrantRedLock, self).acquire()
            if result:
                self._acquired += 1
            return result
        else:
            self._acquired += 1
            return True

    def release(self):
        if self._acquired > 0:
            self._acquired -= 1
            if self._acquired == 0:
                return super(ReentrantRedLock, self).release()
            return True
        return False
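
Usage of the improved lock is the same as before; with the watchdog running, a job that outlives the original TTL no longer loses the lock. A quick sketch, assuming the code above is saved as a module named mylock (a placeholder name):

from time import sleep
from mylock import RedLock

lock = RedLock("test", [{'host': 'localhost', 'port': 6379, 'db': 0}], ttl=5000)
if lock.acquire():
    try:
        sleep(30)  # business code that runs longer than the 5-second TTL
    finally:
        lock.release()  # deletes the key; the watchdog then stops at its next check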
        