题目描述
给定一个整数数组 A
,以及一个整数 target
作为目标值,返回满足 i < j < k
且 A[i] + A[j] + A[k] == target
的元组 i, j, k
的数量。
由于结果会非常大,请返回 结果除以 10^9 + 7 的余数
。
示例 1:
输入:A = [1,1,2,2,3,3,4,4,5,5], target = 8 输出:20 解释: 按值枚举(A[i],A[j],A[k]): (1, 2, 5) 出现 8 次; (1, 3, 4) 出现 8 次; (2, 2, 4) 出现 2 次; (2, 3, 3) 出现 2 次。
示例 2:
输入:A = [1,1,2,2,2,2], target = 5 输出:12 解释: A[i] = 1,A[j] = A[k] = 2 出现 12 次: 我们从 [1,1] 中选择一个 1,有 2 种情况, 从 [2,2,2,2] 中选出两个 2,有 6 种情况。
提示:
3 <= A.length <= 3000
0 <= A[i] <= 100
0 <= target <= 300
解题思路
本题最直观的解法是三层循环寻找A[i]+A[j]+A[k] = target (i<j<k)但是O(N3)一定会超时,然后考虑优化到O(N2logN),A[i]+A[j] = target - A[k],也就是在两层循环内部执行lower_bound()和upper_bound(),结果依旧会超时。比较好的解决办法是分析三个数的关系:(1)三个数两两都相等;(2)三个数中只有两个数相等;(3)三个数都不相等;
int threeSumMulti(vector<int>& A, int target) {
map<long long,long long> mp;
long long mod = 1e9+7;
map<long long,long long>::iterator iti,itj,itk;
int len = A.size();
long long ans = 0;
for(int i=0;i<len;i++) mp[A[i]]++;
for(iti = mp.begin();iti!=mp.end();iti++){
for(itj = mp.begin();itj!=mp.end();itj++){
int k = target - iti->first - itj->first;
itk = mp.find(k);
if(itk == mp.end()) continue;
if((iti->first == itj->first) && (itj->first == k)) ans += (iti->second * (iti->second - 1) * (iti->second - 2))/6;
else if((iti->first == itj->first) && (itj->first != k)) ans += (iti->second *(iti->second - 1)/2)*itk->second;
else if((iti->first < itj->first) && (itj->first < k)) ans += iti->second* itj->second * itk->second;
ans = ans % mod;
}
}
return ans;
}