题目链接:https://leetcode.cn/problems/maximum-number-of-ways-to-partition-an-array/description/
题目大意:给出一个数组nums[]
和一个数字k
。可以不改变数组,也可以将数组中某一个元素改为k
。求改变/不改变后的数组中,能够找到的最大的下标pivot
的数量,pivot
满足从下标区间[0, pivot-1]
的和=下标区间[pivot, n]
的和。
思路:理解题意很重要,在统计满足条件的pivot
的时候,如果改变了,那么改变的元素num[i]
是确定的,也就是不能在pivot=1
时改变nums[0]
,让两边相等;而在pivot=2
时改变nums[4]
,让两边相等这样的操作。
因此也就是对n+1
个数组,求最大的满足条件的pivot
数量。如果两次循环,肯定是会超时的。因此需要想办法减少重复计算。前缀和其实我也想到了,但是这题的关键点还是要使用哈希表。
使用两个哈希表ml, mr
,表示在改变位置i
的元素nums[i]
时,位置i
左边和右边的【到某个位置j
为止,可能的各个前缀和所对应的出现次数】。
当不改变时,就相当于改变了i=-1
位置的元素,那么i
左边没有任何位置,所以所有的前缀和都在右边,也就是mr
里面。设所有元素和为ttl
,此时只需要ttl
是偶数(这样才能被某个pivot
分为和相等的两半),然后就可以得到pivot
的数量即为mr[ttl/2]
,这就是使得pivot
的左右两半和相等的pivot
的数量。
// no change, then ttl must be even to get some ans
if (ttl % 2 == 0)
ans = mr[ttl/2];
看题解时我其实绕了好久,这里理解的重点是,分清楚【所谓的“左右”是谁的左右】。大家做题时会普遍先入为主理解,认为“左右”就是某个pivot
的左右。然而题解中,是先针对【改变的位置i
】来定义两个哈希表的,这两个哈希表代表的“左右”,是【改变的位置i
】的左右两侧,和“求和使两边相等”没有任何关系。
ml, mr
中存的,是【使得某个前缀和(实际上就是从左侧0
求和求到某个位置j
)等于某个值的「位置」的数量】,也就是说,两个哈希表中存的都是“左半边的和”,因为只要【左半边的和=全部和/2】实际上就是【左半边和=右半边和】了。到这里,“某个pivot
的左右”已经无关紧要了,因为问题已经转换成了找满足【左半边的和=某个前缀和=全部和/2】的pivot
,而我们早就计算过了前缀和。
因此题解中的“左右”、两个哈希表所代表的“左右”,都是指改变位置i
的值为k
中,这个i
的左右。对i
循环的过程中,ml
中的元素会越来越多,mr
中多元素会越来越少。我们要在这个这个i
的左边和右边,分别去找满足条件的pivot
。
每次改变nums[i]
,对总和的改变是dif = k - nums[i]
,稍微推导一下,就可以知道,如果pivot
选在i
左边,那么要满足presum[pivot] = (ttl+dif)/2
;如果pivot
选在i
右边,那么要满足presum[pivot] = (ttl-dif)/2
。那么直接查哈希表,加起来就行了。
// change nums[i] to k
for (int i = 0; i < n; i++) {
int dif = k - nums[i];
if ((ttl + dif) % 2 == 0)
ans = max(ans, ml[(ttl+dif)/2] + mr[(ttl-dif)/2]);
ml[presum[i]]++;
mr[presum[i]]--;
}
这样就只有一层循环,复杂度为 O ( N ) O(N) O(N)
完整代码
class Solution {
public:
int waysToPartition(vector<int>& nums, int k) {
int n = nums.size();
vector<long> presum(n, 0);
unordered_map<long, int> mr, ml;
presum[0] = nums[0];
for (int i = 1; i < n; i++) {
presum[i] = presum[i-1] + nums[i];
mr[presum[i-1]]++;
}
long ttl = presum.back();
int ans = 0;
// no change, then ttl must be even to get some ans
if (ttl % 2 == 0)
ans = mr[ttl/2];
// change nums[i] to k
for (int i = 0; i < n; i++) {
int dif = k - nums[i];
if ((ttl + dif) % 2 == 0)
ans = max(ans, ml[(ttl+dif)/2] + mr[(ttl-dif)/2]);
ml[presum[i]]++;
mr[presum[i]]--;
}
return ans;
}
};