三分数列
给出一个有n 个整数的数组a[1],a[2],...,a[n], 有多少种方法把数组分成3 个连续的子序列,使得各子序列的元素之和相等。
也就是说,有多少个下标对i,j (2≤i≤j≤n-1), 满足:sum(a[1]..a[i-1]) = sum(a[i]..a[j]) = sum(a[j+1]..a[n])
详细题解参见代码
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
inline void _read(long long &x){
char ch=getchar(); bool mark=false;
for(;!isdigit(ch);ch=getchar())if(ch=='-')mark=true;
for(x=0;isdigit(ch);ch=getchar())x=x*10+ch-'0';
if(mark)x=-x;
}
long long a[500005],sum[500005];
long long pos1[500005];
long long pos2[500005];
long long cnt1=0;
long long cnt2=0;
int main(){
long long n,i,j,k,tot,tot2,cnt=0,pos;
_read(n);
for(i=1;i<=n;i++){
_read(a[i]);sum[i]=sum[i-1]+a[i];
}
if(n<3||sum[n]%3!=0){
cout<<"0";return 0;
}
tot=sum[n]/3;tot2=2*tot;
for(i=1;i<=n;i++){
if(sum[i]==tot){
pos1[++cnt1]=i;//pos1记录下sum[i]==sum/3的i;
}
if(sum[i]==tot2){
pos2[++cnt2]=i;//pos1记录下sum[i]==sum*2/3的i;
}
}
while(pos1[cnt1]>=n-1){
cnt1--;//除掉不可能的情况
}
if(pos2[cnt2]==n){
cnt2--;//除掉不可能的情况
}
for(i=1;i<=cnt1;i++){
pos=lower_bound(pos2+1,pos2+1+cnt2,pos1[i])-pos2;//使用二分查找
if(pos2[pos]<=pos1[i])pos++;//处理掉特殊情况
if(pos>cnt2||pos<1||pos2[pos]<=pos1[i])continue;
cnt=cnt+cnt2-pos+1;//加上个数
}
cout<<cnt;
}
Solution 2 by spark :
分析:
记数列的前缀和为sum[],显然,当sum[n]%3!=0时无解;
设第一段结尾的数的下标为x,第二段数结尾为y
那么一定会有 sum[x]= sum[n]/3 ;sum[y]= sum[n] * 2/3 ; (1<=x<n-2,2<=y<=n-1),注意这里的x,y,范围很重要,否则会wa;
将 前n个前缀和中等于sum[n]/3的个数记为cnt[n];
那么对于每一个sum[j]==sum[n] *2/3,与前面每一个sum[i]==sum[n]/3都是一种方案,就有,cnt[j-1]种方案。
注意使用 long long
代码如下:
#include<iostream>
#include<cstdio>
#include<algorithm>
#define LL long long
using namespace std;
const int maxn=500000+5;
LL n,s[maxn],sum[maxn],cnt[maxn];
inline void _read(LL &x){
char ch=getchar(); bool mark=false;
for(;!isdigit(ch);ch=getchar())if(ch=='-')mark=true;
for(x=0;isdigit(ch);ch=getchar())x=x*10+ch-'0';
if(mark)x=-x;
}
int main(){
LL i,j,ans=0,cnt1=0,cnt2=0;
_read(n);
for(i=1;i<=n;i++){
_read(s[i]);
sum[i]=sum[i-1]+s[i];
}
if(sum[n]%3!=0){ //无解
cout<<"0"; return 0;
}
for(i=1;i<=n;i++)
if(sum[i]==sum[n]/3)cnt[i]=cnt[i-1]+1;
else cnt[i]=cnt[i-1];
for(j=2;j<n;j++) //注意j从2到n-1;
if(sum[j]==sum[n]*2/3)ans+=cnt[j-1];
cout<<ans;
}