无脑暴力+O2=AC
题目要统计距离两两相等的三个点的组数,这三个点之间显然有一个点,并且这三个点到这个点的距离都相同.所以枚举中间这个点作为根,然后bfs整棵树,对于每一层,把以根的某个儿子的子树中在这一层点的数量统计出来,那么这样三元组的数量就是在这些点里面选3个点,并且分别来自不同子树的方案,\(f_{i,0/1/2/3}\)即可
详见代码
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define LL long long
#define il inline
#define re register
using namespace std;
const int N=5000+10;
il int rd()
{
int x=0,w=1;char ch=0;
while(ch<'0'||ch>'9') {if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') {x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*w;
}
LL ans,f[N][4];
int to[N<<1],nt[N<<1],hd[N],dg[N],tot=1,a[N];
bool v[N];
il void add(int x,int y)
{
++tot,to[tot]=y,nt[tot]=hd[x],hd[x]=tot,++dg[x];
++tot,to[tot]=x,nt[tot]=hd[y],hd[y]=tot,++dg[y];
}
int n,m;
int main()
{
n=rd();
for(int i=1;i<n;++i) add(rd(),rd());
for(int i=0;i<=n;++i) f[i][0]=1;
queue<int> id[2],q[2];
for(int i=1;i<=n;++i)
{
memset(v,0,sizeof(v));
v[i]=1;
while(!id[0].empty()) id[0].pop();
while(!id[1].empty()) id[1].pop();
while(!q[0].empty()) q[0].pop();
while(!q[1].empty()) q[1].pop();
m=dg[i];
int nw=1,la=0;
for(int j=hd[i],k=1;j;j=nt[j],++k) id[0].push(k),q[0].push(to[j]);
while(!q[la].empty())
{
memset(a,0,4*(m+1));
while(!q[la].empty())
{
int k=id[la].front(),x=q[la].front();
id[la].pop(),q[la].pop();
++a[k],v[x]=true;
for(int j=hd[x];j;j=nt[j])
{
int y=to[j];
if(!v[y]) id[nw].push(k),q[nw].push(y);
}
}
for(int j=1;j<=m;++j)
{
for(int k=1;k<=3;++k) f[j][k]=f[j-1][k]+f[j-1][k-1]*a[j];
}
ans+=f[m][3];
swap(nw,la);
}
}
printf("%lld\n",ans);
return 0;
}
正解是长链剖分
咕咕咕
其实上面那个dp比较沙雕,可以直接设\(f_{i,j}\)为点\(i\)子树内到\(i\)距离为\(j\)的点个数,\(g_{i,j}\)为点\(i\)子树内,到lca距离为\(d\),且这个lca到\(i\)距离为\(d-j\)的点对个数,然后转移就是
\[ans=\sum_{x}g_{x,0}+\sum_{y=son_x}\sum_{j}f_{x,j-1}*g_{y,j}\]\[f_{x,j}=\sum_{y=son_x} f_{y,j-1}\]\[g_{x,j}=\sum_{y=son_x,z=son_x,y<z} f_{y,j-1}*f_{z,j-1}+\sum_{y=son_x} g_{y,j+1}\]
对整棵树长链剖分之后,那么转移时就直接可以继承重儿子的信息,轻儿子直接暴力合并,因为每个点只会在链顶被暴力合并上去,所以复杂度是\(O(n)\)的
#include<bits/stdc++.h>
#define LL long long
#define il inline
#define re register
using namespace std;
const int N=100000+10;
il int rd()
{
int x=0,w=1;char ch=0;
while(ch<'0'||ch>'9') {if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') {x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*w;
}
int to[N<<1],nt[N<<1],hd[N],tot=1;
il void add(int x,int y)
{
++tot,to[tot]=y,nt[tot]=hd[x],hd[x]=tot;
++tot,to[tot]=x,nt[tot]=hd[y],hd[y]=tot;
}
int n,m;
int ff[N],de[N],dpt[N],son[N];
LL *f[N],*g[N],rbq[N<<3],*uc=rbq,ans;
void dfs1(int x)
{
dpt[x]=de[x];
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(y==ff[x]) continue;
ff[y]=x,de[y]=de[x]+1,dfs1(y),dpt[x]=max(dpt[x],dpt[y]);
if(dpt[son[x]]<dpt[y]) son[x]=y;
}
}
void dd(int x)
{
if(son[x]) f[son[x]]=f[x]+1,g[son[x]]=g[x]-1,dd(son[x]);
f[x][0]=1,ans+=g[x][0];
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(y==ff[x]||y==son[x]) continue;
f[y]=uc,uc+=(dpt[y]-de[x]+1)<<1,g[y]=uc,uc+=(dpt[y]-de[x]+1)<<1;
dd(y);
for(int j=0;j<dpt[y]-de[x];++j)
{
if(j) ans+=f[x][j-1]*g[y][j];
ans+=g[x][j+1]*f[y][j];
}
for(int j=0;j<dpt[y]-de[x];++j)
{
g[x][j+1]+=f[x][j+1]*f[y][j];
if(j) g[x][j-1]+=g[y][j];
f[x][j+1]+=f[y][j];
}
}
}
int main()
{
n=rd();
for(int i=1;i<n;++i) add(rd(),rd());
de[1]=1,dfs1(1);
f[1]=uc,uc+=(dpt[1]+1)<<1,g[1]=uc,uc+=(dpt[1]+1)<<1;
dd(1);
printf("%lld\n",ans);
return 0;
}