二分+暴力
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#include <vector>
using namespace std;
typedef long long int LL;
int n;
LL a[500500];
LL sum[500500];
vector<int> v1,v2;
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
cin>>a[i];
sum[i]=sum[i-1]+a[i];
}
if(sum[n]%3LL!=0)
{
puts("0");
return 0;
}
LL X=sum[n]/3;
for(int i=1;i<=n;i++)
{
if(sum[i]==X)
{
if(i!=n&&i!=n-1)
v1.push_back(i);
}
if(sum[i]==2*X)
{
if(i!=n) v2.push_back(i);
}
}
LL ans=0;
int sz1=v1.size(),sz2=v2.size();
if(X!=0)
{
for(int i=0;i<sz1;i++)
{
int p1=v1[i];
int p2=lower_bound(v2.begin(),v2.end(),p1)-v2.begin();
if(p1==v2[p2]) p2++;
if(p2>=n) continue;
ans+=sz2-p2;
}
}
else
{
int LAST=0;
for(int i=0;i<sz1;i++)
{
int id=-1;
for(int j=LAST;j<sz2;j++)
{
if(v2[j]>v1[i])
{
LAST=j;
id=j; break;
}
}
ans+=sz2-id;
}
}
cout<<ans<<endl;
return 0;
}