You've got array a[1], a[2], ..., a[n], consisting of n integers. Count the number of ways to split all the elements of the array into three contiguous parts so that the sum of elements in each part is the same.
More formally, you need to find the number of such pairs of indices i, j (2 ≤ i ≤ j ≤ n - 1), that $\sum_{k=1}^{i-1} a_k = \sum_{k=i}^{j} a_k = \sum_{k=j+1}^{n} a_k$.
The first line contains integer n (1 ≤ n ≤ 5·10^5), showing how many numbers are in the array. The second line contains n integers a[1], a[2], ..., a[n] (|a[i]| ≤ 10^9) — the elements of array a.
Print a single integer — the number of ways to split the array into three parts with the same sum.
5 1 2 3 0 3
2
4 0 1 -1 0
1
2 4 1
0
首先,我们用数组lsum[]记录前缀和,设总和为sum;若sum不能被3整除(或n<3),答案直接为0。此时,朴素的思路是枚举i,j两个位置变量,使lsum[i]=sum/3且lsum[j]=sum*2/3,但这样的做法是O(n²),在n最大为5·10^5时显然不可行。另一个方法是只遍历一遍:每当遇到lsum[i]=sum*2/3的位置(且i不是最后一个元素,以保证第三部分非空),就把此前已出现的lsum=sum/3的位置个数累加到答案中,这个个数我们用一个变量保存;然后再判断当前位置是否满足lsum[i]=sum/3并计数。先累加、后计数可以保证中间部分非空(这在sum=0时尤为重要)。这样的复杂度是O(n)。
#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<stack>
#include<queue>
#include<algorithm>
#include<string>
#include<cstring>
#include<cmath>
#include<vector>
#include<map>
#include<set>
// Shorthand macros used throughout this contest-style source.
#define eps 1e-8
// True when x lies within eps of zero. The previous expansion,
// (((x>0?(x):-(x))-eps), had unbalanced parentheses (and an unparenthesized
// `x>0`), so any use failed to compile; restored to the usual |x| < eps test.
#define zero(x) (((x)>0?(x):-(x))<eps)
#define mem(a,b) memset(a,b,sizeof(a))
#define memmax(a) memset(a,0x3f,sizeof(a))
#define pfn printf("\n")
// Standard 64-bit integer type. __int64 is an MSVC extension that GCC/Clang
// do not provide on non-Windows targets.
#define ll long long
#define ull unsigned long long
#define sf(a) scanf("%d",&a)
// 64-bit reads use the standard %lld conversion. %I64d is an MSVC C-runtime
// extension; other C libraries (e.g. glibc) do not treat I64 as a length
// modifier, so the long long target may be only partially written.
#define sf64(a) scanf("%lld",&a)
#define sf264(a,b) scanf("%lld%lld",&a,&b)
#define sf364(a,b,c) scanf("%lld%lld%lld",&a,&b,&c)
#define sf464(a,b,c,d) scanf("%lld%lld%lld%lld",&a,&b,&c,&d)
#define sf2(a,b) scanf("%d%d",&a,&b)
#define sf3(a,b,c) scanf("%d%d%d",&a,&b,&c)
#define sf4(a,b,c,d) scanf("%d%d%d%d",&a,&b,&c,&d)
#define sff(a) scanf("%f",&a)
#define sfs(a) scanf("%s",a)
#define sfs2(a,b) scanf("%s%s",a,b)
#define sfs3(a,b,c) scanf("%s%s%s",a,b,c)
#define sfd(a) scanf("%lf",&a)
#define sfd2(a,b) scanf("%lf%lf",&a,&b)
#define sfd3(a,b,c) scanf("%lf%lf%lf",&a,&b,&c)
#define sfd4(a,b,c,d) scanf("%lf%lf%lf%lf",&a,&b,&c,&d)
#define sfc(a) scanf("%c",&a)
#define pp pair<int,int>
#define debug printf("***\n")
#define pi 3.1415927
#define mod 1000000007
// Half-open integer loop: i runs over [a, b).
#define rep(i,a,b) for(int i=a;i<b;i++)
// Commonly used numeric constants.
const double PI = std::acos(-1.0);  // pi, computed as arccos(-1)
const double e = std::exp(1.0);     // Euler's number
const int INF = 2147483647;         // 0x7fffffff, the largest 32-bit signed int
// Greatest common divisor by Euclid's algorithm; gcd(a, 0) == a. With a
// negative argument the result may be negative, since it follows operator%.
template<class T> T gcd(T a, T b) { return b ? gcd(b, a % b) : a; }
// Least common multiple; divides before multiplying to reduce overflow risk.
// Returns 0 when gcd(a, b) is 0 (i.e. a == b == 0) instead of dividing by zero.
template<class T> T lcm(T a, T b)
{
    const T g = gcd(a, b);
    return g == T(0) ? T(0) : a / g * b;
}
// Smaller / larger of two values (on a tie the second argument is returned).
template<class T> inline T Min(T a, T b) { return a < b ? a : b; }
template<class T> inline T Max(T a, T b) { return a > b ? a : b; }
// Comparators on ints: descending (cmpbig) and ascending (cmpsmall) order.
bool cmpbig(int a, int b){ return a>b; }
bool cmpsmall(int a, int b){ return a<b; }
using namespace std;
// Lexicographic "less than" on int pairs: order by .first, break ties by .second.
bool cmp(pp a,pp b)
{
    if (a.first < b.first)
        return true;
    if (b.first < a.first)
        return false;
    return a.second < b.second;
}
// Counts the ways to cut `values` into three non-empty contiguous parts with
// equal sums, i.e. pairs of cut positions p < q (first part ends at p, second
// at q) with q before the last element. One pass over the prefix sums.
// The count can exceed 32 bits for large inputs (e.g. all zeros), hence long long.
static long long countEqualThreeWaySplits(const std::vector<long long>& values)
{
    const std::size_t n = values.size();
    long long total = 0;
    for (std::size_t i = 0; i < n; ++i)
        total += values[i];
    // Three non-empty parts need at least three elements, and equal part sums
    // require the total to be divisible by 3.
    if (n < 3 || total % 3 != 0)
        return 0;

    const long long oneThird = total / 3;
    const long long twoThirds = 2 * oneThird;
    long long ways = 0;
    long long firstCuts = 0;  // prefixes seen so far whose sum equals oneThird
    long long prefix = 0;
    // Stop before the last element so the third part is never empty.
    for (std::size_t i = 0; i + 1 < n; ++i) {
        prefix += values[i];
        // Test for a second cut before recording i as a first cut, so the
        // middle part is non-empty (this matters when total == 0).
        if (prefix == twoThirds)
            ways += firstCuts;
        if (prefix == oneThird)
            ++firstCuts;
    }
    return ways;
}

// Reads test cases "n a[1] ... a[n]" until end of input and prints, for each,
// the number of equal-sum three-way splits. Uses the portable %lld conversion
// and a vector sized to the input instead of fixed-size global arrays.
int main()
{
    //freopen("data.in","r",stdin);
    //freopen("data.out" ,"w",stdout);
    long long n;
    while (scanf("%lld", &n) == 1) {
        std::vector<long long> values;
        if (n > 0)
            values.reserve(static_cast<std::size_t>(n));
        for (long long i = 0; i < n; ++i) {
            long long x;
            if (scanf("%lld", &x) != 1)
                return 0;  // truncated input: no further case can be answered
            values.push_back(x);
        }
        // A zero answer no longer ends the loop; later test cases are still processed.
        printf("%lld\n", countEqualThreeWaySplits(values));
    }
    return 0;
}