[LeetCode解题报告] 1712. 将数组分成三个子数组的方案数
一、 题目
1. 题目描述
我们称一个分割整数数组的方案是 好的 ,当它满足:
- 数组被分成三个 非空 连续子数组,从左至右分别命名为
left
,mid
,right
。 left
中元素和小于等于mid
中元素和,mid
中元素和小于等于right
中元素和。
给你一个 非负 整数数组 nums
,请你返回 好的 分割 nums
方案数目。由于答案可能会很大,请你将结果对 109 + 7
取余后返回。
示例 1:
输入:nums = [1,1,1]
输出:1
解释:唯一一种好的分割方案是将 nums 分成 [1] [1] [1] 。
示例 2:
输入:nums = [1,2,2,2,5,0]
输出:3
解释:nums 总共有 3 种好的分割方案:
[1] [2] [2,2,5,0]
[1] [2,2] [2,5,0]
[1,2] [2,2] [5,0]
示例 3:
输入:nums = [3,2,1] 输出:0 解释:没有好的分割方案。
提示:
3 <= nums.length <= 105
0 <= nums[i] <= 104
Related Topics
- 数组
- 双指针
- 二分查找
- 前缀和
- 👍 75
- 👎 0
2. 原题链接
二、 解题报告
1. 思路分析
- 先计算前缀和使区间和消耗降为O(1)。
- 然后枚举中间数组的左端点,范围是[1,n-1]。
- 由于数据非负,我们发现左端点固定时,右端点向左移动会使mid变小、right变大,向右移动会使mid变大,right变小。
- 也就是说,左端点固定时,右端点是有个取值范围的[j_min,j_max],而这个范围可以二分。
- 二分效率较低O(nlgn),我们发现i,j_min,j_max一直是同向(右)移动的,因此可以三指针,复杂度降为O(n)。
2. 复杂度分析
最坏时间复杂度O(nlog2n)
3. 代码实现
二分
class Solution:
def waysToSplit(self, nums: List[int]) -> int:
n = len(nums)
mod = 10**9+7
presum = list(accumulate(nums,initial=0))
def sum_interval(i,j):
return presum[j+1]-presum[i]
ans = 0
# 这段4796 ms 5.09%,O(nlgn)
# 遍历mid左端点
for i in range(1,n-1):
left = sum_interval(0,i-1)
# 左端点固定,右端点向右移动时,mid不断变大,right不断变小
# 那么右端点最小值可以二分:这个区域的和要大于等于left
j_min = bisect_left(range(n-1),left,lo=i,key=lambda j:sum_interval(i,j))
# 如果找不到,j_min==n-1,则这个不能当左端点;如果以这个位置分割,right<mid,也不可以。
if j_min >= n-1 or sum_interval(j_min+1,n-1) < sum_interval(i,j_min):
continue
# 右端点最大值必须满足:right-mid>=0,即mid-right<=0;随着j增加mid-right增加;
# 由于上边排除了,因此一定能找到j_max
j_max = bisect_right(range(n-1),0,lo=j_min,key=lambda j:sum_interval(i,j)-sum_interval(j+1,n-1))-1
ans += j_max-j_min+1
ans %= mod
return ans
三指针
class Solution:
def waysToSplit(self, nums: List[int]) -> int:
n = len(nums)
mod = 10**9+7
presum = list(accumulate(nums,initial=0))
def sum_interval(i,j):
return presum[j+1]-presum[i]
ans = 0
# 随着i右移,j_min是右移的,j_max右移的,所以可以三指针寻找。O(n)
j_min,j_max = 1,1
for i in range(1,n-1):
left = sum_interval(0,i-1)
j_min = max(j_min,i)
while j_min<n-1 and sum_interval(i,j_min) < left:
j_min += 1
# print(i,j_min,j_max,ans,left,sum_interval(i,j_min),sum_interval(j_min+1,n-1))
if j_min >= n-1 :
return ans
if sum_interval(j_min+1,n-1) < sum_interval(i,j_min):
continue
while j_max+1<n-1 and sum_interval(i,j_max+1) <= sum_interval(j_max+2,n-1):
j_max += 1
ans += j_max-j_min+1
ans %= mod
return ans
三、 本题小结
- 前缀和计算区间和的应用。