很久以前第一次写这道题的时候,没能调试出来。
本来以为是思路错了,结果今天看了题解才发现思路并没有错,只是有很多代码细节需要注意。
code:
#include <cstdio>
#include <vector>
#include <map>
#include <cstring>
#include <algorithm>
// Capacity of every fixed-size global array in this file (the per-state
// tables below, the automaton tables in `sam`, and - multiplied by 20 - the
// node pool in `seg`).
#define N 600008
// Shorthand for the 64-bit integer type used for sums and the final answer.
#define ll long long
// Redirects stdin to the file "<s>.in". Its only use in this file is a
// commented-out call at the top of main.
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
// Length of the sequence fed to the automaton (main sets it to n-1); also the
// right end of the value range [1,total] covered by every segment tree.
int total;
// rt[v] : root node of the segment tree attached to automaton state v (0 = empty tree).
// arr[] : input values (1-based); main overwrites them with adjacent differences.
// rk[]  : automaton states ordered by non-decreasing len, filled by sam::get_rank.
// bu[]  : bucket counters used by the counting sort in sam::get_rank.
// id[v] : index into G of the position list currently owned by state v.
int rt[N],arr[N],rk[N],bu[N],id[N];
// G[k] : list of end positions; sam::extend seeds it and main merges lists
// from smaller into larger.
vector<int>G[N];
namespace seg
{
int tot;
int newnode() { return ++tot; }
struct data
{
int ls,rs,sum1;
ll sum2;
data() { ls=rs=sum1=sum2=0; }
data operator+(const data &b) const
{
data c;
c.sum1=sum1+b.sum1;
c.sum2=sum2+b.sum2;
return c;
}
}s[N*20],bl;
void update(int &x,int l,int r,int p,int v)
{
if(!x) x=newnode();
++s[x].sum1;
s[x].sum2+=v;
if(l==r) return;
int mid=(l+r)>>1;
if(p<=mid) update(s[x].ls,l,mid,p,v);
else update(s[x].rs,mid+1,r,p,v);
}
data query(int x,int l,int r,int L,int R)
{
if(!x||r<L||l>R||L>R) return bl;
if(l>=L&&r<=R) return s[x];
int mid=(l+r)>>1;
if(L<=mid&&R>mid) return query(s[x].ls,l,mid,L,R)+query(s[x].rs,mid+1,r,L,R);
else if(L<=mid) return query(s[x].ls,l,mid,L,R);
else return query(s[x].rs,mid+1,r,L,R);
}
int merge(int x,int y)
{
if(!x||!y) return x+y;
int now=newnode();
s[now].sum1=s[x].sum1+s[y].sum1;
s[now].sum2=s[x].sum2+s[y].sum2;
s[now].ls=merge(s[x].ls,s[y].ls);
s[now].rs=merge(s[x].rs,s[y].rs);
return now;
}
};
// Suffix automaton over an integer alphabet. State 1 is the initial state;
// state index 0 means "no state".
namespace sam
{
int tot,last;            // tot: number of states, last: state of the whole text so far
map<int,int>ch[N];       // ch[v][c]: transition of state v on symbol c (0 = absent)
int len[N],pre[N];       // len[v]: longest length in state v, pre[v]: suffix link

// Creates the initial state. Must be called before the first extend().
void init() { tot=last=1; }

// Appends symbol c to the text. Besides the usual automaton update, the new
// state's end position (its len) is recorded in G and inserted into that
// state's segment tree over [1,total].
void extend(int c)
{
    const int cur = ++tot;
    int p = last;
    len[cur] = len[p] + 1;
    last = cur;

    // Walk the suffix links, adding the missing c-transitions.
    while(p != 0 && ch[p][c] == 0)
    {
        ch[p][c] = cur;
        p = pre[p];
    }

    if(p == 0)
    {
        pre[cur] = 1;
    }
    else
    {
        const int q = ch[p][c];
        if(len[q] == len[p] + 1)
        {
            pre[cur] = q;
        }
        else
        {
            // q is too long to serve as the link target: split off a clone
            // of length len[p]+1 and redirect the affected transitions to it.
            const int clone = ++tot;
            len[clone] = len[p] + 1;
            pre[clone] = pre[q];
            pre[q] = clone;
            pre[cur] = clone;
            ch[clone] = ch[q];
            while(p != 0 && ch[p][c] == q)
            {
                ch[p][c] = clone;
                p = pre[p];
            }
        }
    }

    // Only the freshly appended state gets an end position; clones get none.
    G[cur].push_back(len[cur]);
    seg::update(rt[cur], 1, total, len[cur], len[cur]);
}

// Counting sort of the states by len: afterwards rk[1..tot] lists the states
// in non-decreasing len order. bu[] is used as scratch and is modified.
void get_rank()
{
    for(int v = 1; v <= tot; ++v) bu[len[v]] += 1;
    for(int k = 1; k <= tot; ++k) bu[k] += bu[k - 1];
    for(int v = 1; v <= tot; ++v)
    {
        const int slot = bu[len[v]];
        bu[len[v]] = slot - 1;
        rk[slot] = v;
    }
}
};
int main()
{
// setIO("input");
int i,j,n;
sam::init();
scanf("%d",&n);
for(i=1;i<=n;++i) scanf("%d",&arr[i]);
for(i=1;i<n;++i) arr[i]=arr[i+1]-arr[i];
total=n-1;
for(i=1;i<=total;++i) sam::extend(arr[i]);
sam::get_rank();
ll ans=0ll;
ans=1ll*n*(n-1)/2;
for(i=1;i<=sam::tot;++i) id[i]=i;
for(i=sam::tot;i>=2;--i)
{
int u=rk[i];
int a=rk[i];
int b=sam::pre[a];
if(G[id[a]].size()>G[id[b]].size()) swap(a,b);
// id[a] < id[b]
for(j=0;j<G[id[a]].size();++j)
{
int x=G[id[a]][j];
ans+=(ll)sam::len[sam::pre[u]]*seg::query(rt[b],1,total,1,x-sam::len[sam::pre[u]]-1).sum1;
ans+=(ll)sam::len[sam::pre[u]]*seg::query(rt[b],1,total,x+sam::len[sam::pre[u]]+1,total).sum1;
seg::data tmp1=seg::query(rt[b],1,total,x-sam::len[sam::pre[u]],x-2);
seg::data tmp2=seg::query(rt[b],1,total,x+2,x+sam::len[sam::pre[u]]);
ans+=(ll)(x-1)*tmp1.sum1-tmp1.sum2;
ans+=tmp2.sum2-(ll)(x+1)*tmp2.sum1;
G[id[b]].push_back(x);
}
id[sam::pre[u]]=id[b];
rt[sam::pre[u]]=seg::merge(rt[a],rt[b]);
}
printf("%lld\n",ans);
return 0;
}