Leetcode 3130. Find All Possible Stable Binary Arrays II

文章讲述了作者尝试解决LeetCode题目3130的过程,通过递归和动态规划优化算法,从最初的版本到最终的优化,探讨了如何处理超出限制的排列组合问题。最终,作者展示了两种高效算法,特别是通过预计算阶乘和逆元,显著降低了时间复杂度和内存使用。
摘要由CSDN通过智能技术生成

0. 序言

这道题和题目3129本质上就是一道题目,唯一的差别就是取值范围的差异,题目3129的范围在 [ 1 , 200 ] [1,200] [1,200],而这道题在 [ 1 , 1000 ] [1, 1000] [1,1000],因此复杂度会更高。

很不幸,我只搞定了题目3129,而这道题则是死活都没有搞定,最后是看了一下题目3129当中其他大佬们的高效率算法之后改写了一下通过了最后的测试,不过不幸的是,具体到思路上依然是没有看懂,真的是有点伤……

所以,这里的话我会现在前两部分讲一下我自己的算法思路,以及为了提升算法效率而做的优化,总体来说的话,在题目3129当中将算法效率提升了10倍左右,耗时从3727ms降至574ms,不过不幸的是依然无法通过题目3130的测试样例……

因此,如果有读者对这部分内容不感兴趣的话可以直接跳到第三部分看一下大佬们的高效算法即可。

1. 算法思路

这一题我整体的算法思路是当做一个数学上的排列组合问题来做的,显然,如果0和1的个数 n , m ≤ l i m i t n, m \leq limit n,mlimit的话,那么显然我们可以直接给出答案 C n + m n C_{n+m}^{n} Cn+mn

但问题就在于如果有 n , m > l i m i t n, m > limit n,m>limit的情况,此时就不能直接用数学方法解了,本来考虑如果将 l i m i t + 1 limit+1 limit+1(不妨简记 l i m i t + 1 = k limit+1=k limit+1=k)个元素进行绑定然后填充的方式,倒是可以直接计算组合数为: C n + m − k n ⋅ ( n + 1 ) C_{n+m-k}^{n} \cdot (n+1) Cn+mkn(n+1)

不过这种情况仅限于 n < k n < k n<k k < m < 2 k k < m < 2k k<m<2k的情况,即要确保不可能存在两个组内的元素均不少于 k k k个,否则就会出现重复计数的情况。

最后,关于其他的情况,我们就是能用动态规划进行暴力求解了,就很繁琐。

2. 代码实现

1. 第一版本

首先,我们给出我们的第一版本的代码实现如下:

MOD = 10**9+7
FACTORIALS = [1 for _ in range(401)]
for i in range(1, 401):
    FACTORIALS[i] = i * FACTORIALS[i-1] % MOD
    
def rev(x):
    return pow(x, -1, MOD)

def C(n, m):
    if m < 0:
        return 0
    return FACTORIALS[n] * rev(FACTORIALS[m]) * rev(FACTORIALS[n-m]) % MOD

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        
        @lru_cache(None)
        def dp(n,m,k,p):
            # n -> zero, m -> one, k -> pre count, p -> pre element
            if n + (1-p) * k > limit * (m+1) or m + p * k > limit * (n+1):
                return 0
            elif n + (1-p) * k <= limit and m + p * k <= limit:
                ans = C(n+m, n)
            elif p == 0 and n + k <= limit and m > limit and m <= 2 * limit:
                ans = C(n+m, n) - C(n+m-limit-1, n) * (n+1)
            elif p == 1 and n > limit and m + k <= limit and n <= 2 * limit:
                ans = C(n+m, n) - C(n+m-limit-1, m) * (m+1)
            else:
                ans = 0
                if k*(1-p)+1 <= limit:
                    ans = (ans + dp(n-1, m, k*(1-p)+1, 0)) % MOD
                if k*p+1 <= limit:
                    ans = (ans + dp(m-1, n, k*p+1, 0)) % MOD
            return ans % MOD
        
        ans = dp(zero, one, 0, 0)
        return ans

这个实现基本就是翻译了一下我们的上述实现,在题目3129上的评测结果如下:耗时3727ms,占用内存756.5MB。

2. 第二版本

然后,我们注意到这里的n, m事实上是完全等价的,因此,我们就可以取消掉p这个元素,也就是无需再记录前一个元素是什么,直接对换n,m的值即可,选用是默认从n当中进行元素选择作为开头,这样就可以进一步提升cache的利用率了。

给出第二版代码实现如下:

MOD = 10**9+7
FACTORIALS = [1 for _ in range(401)]
for i in range(1, 401):
    FACTORIALS[i] = i * FACTORIALS[i-1] % MOD
    
def rev(x):
    return pow(x, -1, MOD)

def C(n, m):
    if m < 0:
        return 0
    return FACTORIALS[n] * rev(FACTORIALS[m]) * rev(FACTORIALS[n-m]) % MOD

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        
        @lru_cache(None)
        def dp(n,m,k):
            if n + k > limit * (m+1) or m > limit * (n+1):
                return 0
            elif n + k <= limit and m <= limit:
                ans = C(n+m, n)
            elif n + k <= limit and m > limit and m <= 2 * limit:
                ans = C(n+m, n) - C(n+m-limit-1, n) * (n+1)
            else:
                ans = dp(m-1, n, 1)
                if k+1 <= limit:
                    ans = (ans + dp(n-1, m, k+1)) % MOD
            return ans % MOD
        
        ans = dp(zero, one, 0)
        return ans

提交代码评测得到:耗时3157ms,占用内存684.4MB。

3. 第三版本

然后,我们注意到这个k也很碍事,既然n,m的地位完全等价了,我们只需要默认每次都必须从n当中选择1到limit个元素即可,这样就可以去掉k这个参数了,可以进一步优化cache。

MOD = 10**9+7
FACTORIALS = [1 for _ in range(401)]
for i in range(1, 401):
    FACTORIALS[i] = i * FACTORIALS[i-1] % MOD
    
def rev(x):
    return pow(x, -1, MOD)

def C(n, m):
    if m < 0:
        return 0
    return FACTORIALS[n] * rev(FACTORIALS[m]) * rev(FACTORIALS[n-m]) % MOD

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        
        @lru_cache(None)
        def dp(n,m):
            if n == 0:
                return 0
            elif m == 0:
                return 1 if n <= limit else 0
            elif n > limit * (m+1) or m > limit * n:
                return 0
            elif n <= limit and m <= limit:
                ans = C(n+m-1, m)
            elif n <= limit and m > limit and m <= 2 * limit:
                ans = C(n+m-1, m) - C(n+m-limit-2, n-1) * n
            else:
                ans = 0
                for i in range(1, min(limit, n) + 1):
                    ans = (ans + dp(m, n-i))
            return ans % MOD
        
        ans = (dp(zero, one) + dp(one, zero)) % MOD
        return ans

提交代码评测得到:耗时707ms,占用内存32.1MB。

4. 第四版本

最后,我们还对上述 C n m C_{n}^{m} Cnm的实现进行了一下优化,具体来说的话就是每次都算pow(n, -1 MOD)太耗时了,因此我们也像阶乘一样预先算好保存在一个数组当中即可。

给出python代码实现如下:

MOD = 10**9+7
FACTORIALS = [1 for _ in range(401)]
for i in range(1, 401):
    FACTORIALS[i] = i * FACTORIALS[i-1] % MOD
    
Inv_FACTORIALS = [pow(x, -1, MOD) for x in FACTORIALS]

def C(n: int, m: int):
    if m < 0:
        return 0
    return (FACTORIALS[n] * Inv_FACTORIALS[m] * Inv_FACTORIALS[n-m]) % MOD

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        
        @lru_cache(None)
        def dp(n,m):
            if n == 0:
                return 0
            elif m == 0:
                return 1 if n <= limit else 0
            elif n > limit * (m+1) or m > limit * n:
                return 0
            elif n <= limit and m <= limit:
                ans = C(n+m-1, m)
            elif n <= limit and m > limit and m <= 2 * limit:
                ans = C(n+m-1, m) - C(n+m-limit-2, n-1) * n
            else:
                ans = 0
                for i in range(1, min(limit, n) + 1):
                    ans = (ans + dp(m, n-i))
            return ans % MOD
        
        ans = (dp(zero, one) + dp(one, zero)) % MOD
        return ans

提交代码评测得到:耗时574ms,占用内存32MB。

3. 算法优化

不过可惜的是,即便如此,上述的算法依然无法通过题目3130的测试样例,还是会出现超时的情况。因此在下面的小节里面,我们摘录了两个大佬的两个算法实现,分别来自题目3129和题目3130的解答当中耗时最优的方法,然后稍微用我自己感觉更好理解的方式稍微翻译了一下,虽然我自己依然没有完全看明白具体的数学含义就是了。

不过插句题外话,虽然代码实现两个算法不太一样,不过从具体的思路以及参数计算来看,我觉得这俩实现很可能来自同一个大佬……

只能说,大佬牛逼……

1. 算法实现一

MOD = 10**9+7
FACTORIALS = [1 for _ in range(1001)]
for i in range(1, 1001):
    FACTORIALS[i] = i * FACTORIALS[i-1] % MOD
    
Inv_FACTORIALS = [pow(x, -1, MOD) for x in FACTORIALS]

def C(n: int, m: int):
    if m < 0:
        return 0
    return (FACTORIALS[n] * Inv_FACTORIALS[m] * Inv_FACTORIALS[n-m]) % MOD

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        ans = 0
        N = zero + one
        min_zero_group = (zero - 1) // limit + 1
        min_one_group = (one - 1) // limit + 1
        
        @lru_cache(None)
        def count(n, g, k):
            ans = C(n+g-1, g-1)
            r, flag = 1, -1
            while n - r * (k+1) >= 0:
                ans = (ans + flag * C(n - r*(k+1) + g-1, g-1) * C(g, r)) % MOD
                r += 1
                flag *= -1
            return ans
        
        for n in range(min_zero_group, zero+1):
            for m in range(n-1, n+1+1):
                if m < min_one_group or m > one:
                    continue
                flag = 1 if n != m else 2
                ans = (ans + flag * count(zero-n, n, limit - 1) * count(one-m, m, limit - 1)) % MOD
        return ans % MOD

提交代码评测得到:耗时72ms,占用内存17.1MB。

2. 算法实现二

MOD = 10**9+7
FACTORIALS = [1 for _ in range(1001)]
for i in range(1, 1001):
    FACTORIALS[i] = i * FACTORIALS[i-1] % MOD
    
Inv_FACTORIALS = [pow(x, -1, MOD) for x in FACTORIALS]

def C(n: int, m: int):
    if m < 0:
        return 0
    return (FACTORIALS[n] * Inv_FACTORIALS[m] * Inv_FACTORIALS[n-m]) % MOD

class Solution:
    
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        ans = 0
        N = zero + one
        min_zero_group = (zero - 1) // limit + 1
        min_one_group = (one - 1) // limit + 1
        
        def count(n, g, k):
            ans = C(n+g-1, g-1)
            r, flag = 1, -1
            while n - r * (k+1) >= 0:
                ans = (ans + flag * C(n - r*(k+1) + g-1, g-1) * C(g, r)) % MOD
                r += 1
                flag *= -1
            return ans
        
        for n in range(min_zero_group, zero+1):
            for m in range(n-1, n+1+1):
                if m < min_one_group or m > one:
                    continue
                flag = 1 if n != m else 2
                ans = (ans + flag * count(zero-n, n, limit - 1) * count(one-m, m, limit - 1)) % MOD
        return ans % MOD

提交代码评测得到:耗时94ms,占用内存16.7MB。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值