题目链接:https://leetcode.cn/problems/partition-array-into-two-arrays-to-minimize-sum-difference/solution/
题目大意:给出一个长度为2*N
的数组nums[]
,求将其分为等长的两个数组,使得两个数组的和的差的绝对值最小。求这个绝对值。
思路:一开始以为可以这样做:先排序,两个两个遍历,其中一个分到arr1[]
中,另一个分到arr2[]
中,如何分根据目前两个数组谁的和大谁的和小来决定。后来发现这样子太天真了,并不一定是两个数都是分开的,很有可能连着的这两个数被分到同一个数组中。
真的有点难,让chatGPT做,它也做的是错的。于是看了好几个题解,理解了一下。
首先,设nums[]
所有元素之和为tot
,tgt = tot/2
,那么目的就是使两个数组的和尽可能接近tgt
。
将数组分为前后等长的两部分。假设某个数组在前部分取k
个,那么后部分就取N-k
个。枚举前半部分所有可能取的情况,一共有2^N
种情况,可以用一个N
位二进制数来表示。
a1[k]
中的元素表示【在前半部分取k
个元素的和】的所有可能值,其余类推。计算出所有前半部分和后半部分取若干个数的和的可能值。
很明显,k
的可能取值就是[0, N]
,因此a1, a2
长度都是N+1
。countBits()
函数计算当前的状态掩码i
中有多少个1
,即取了多少个数。
vector<vector<int>> a1(N+1, vector<int>());
vector<vector<int>> a2(N+1, vector<int>());
for (int i = 0; i < (1 << N); i++) {
int cnt1 = 0, cnt2 = 0;
for (int j = 0; j < N; j++) {
if (i & (1 << j)) {
cnt1 += nums[j];
cnt2 += nums[N+j];
}
}
int k = countBits(i);
a1[k].push_back(cnt1);
a2[k].push_back(cnt2);
}
为了方便后序双指针求和,将所有的结果排序。
for (int i = 0; i <= N; i++) {
sort(a1[i].begin(), a1[i].end());
sort(a2[i].begin(), a2[i].end());
}
用两个指针l, r
来分别指向a1[k], a2[N-k]
,此时情况表示从前部分取k
个,从后部分取N-k
个。计算差,更新结果。如果和比tot
大,说明太靠右了,要往左,让r--
;如果和比tot
小,说明太靠左了,要往右,让l++
;如果刚好和等于tot
,说明这种分法能让两数组和的差为0,直接返回0
即可。
int ret = INT_MAX;
for (int k = 0; k <= N; k++) {
auto x = a1[k];
auto y = a2[N-k];
int l = 0, r = y.size()-1;
while (l < x.size() && r >= 0) {
ret = min(ret, abs(tot - x[l] - y[r] - x[l] - y[r]));
if (x[l] + y[r] + x[l] + y[r] > tot)
r--;
else if (x[l] + y[r] + x[l] + y[r] < tot)
l++;
else
return 0;
}
}
完整代码
class Solution {
public:
int countBits(int num) {
int ret = 0;
while (num != 0) {
ret += num & 1;
num >>= 1;
}
return ret;
}
int minimumDifference(vector<int>& nums) {
int N = nums.size() / 2;
if (N == 1)
return abs(nums[1] - nums[0]);
int tot = 0;
for (auto x : nums)
tot += x;
int tgt = tot / 2;
vector<vector<int>> a1(N+1, vector<int>());
vector<vector<int>> a2(N+1, vector<int>());
for (int i = 0; i < (1 << N); i++) {
int cnt1 = 0, cnt2 = 0;
for (int j = 0; j < N; j++) {
if (i & (1 << j)) {
cnt1 += nums[j];
cnt2 += nums[N+j];
}
}
int k = countBits(i);
a1[k].push_back(cnt1);
a2[k].push_back(cnt2);
}
for (int i = 0; i <= N; i++) {
sort(a1[i].begin(), a1[i].end());
sort(a2[i].begin(), a2[i].end());
}
int ret = INT_MAX;
for (int k = 0; k <= N; k++) {
auto x = a1[k];
auto y = a2[N-k];
int l = 0, r = y.size()-1;
while (l < x.size() && r >= 0) {
ret = min(ret, abs(tot - x[l] - y[r] - x[l] - y[r]));
if (x[l] + y[r] + x[l] + y[r] > tot)
r--;
else if (x[l] + y[r] + x[l] + y[r] < tot)
l++;
else
return 0;
}
}
return ret;
}
};