解题思路
因为没有交集的路径求很麻烦,所以我们求 (所有的路径组 − 相交的路径组)
(a,b)与(c,d)相交的路径可以分为下面两种,两条的l c a 为同一个,或是 c从(a,b)的 l c a l c a lca向上走
所以我们设 f p i fp_i fpi为以 i i i为根且 l c a l c a lca的子树中,路径为p 的个数; g p i gp_i gpi 为以i 为根的子树,路径长为q 的路径,一截在子树内一截在子树外的个数。 f q 、 g q fq、gq fq、gq同理.
那么所有的路径组就为 ( ∑ f p i ) ∗ ( ∑ f q i ) (\sum fp_i)*(\sum fq_i) (∑fpi)∗(∑fqi)枚举所有路径的 l c a l c a lca
然后相交的路径组就为 ∑ ( f p i ∗ f q i + f p i ∗ g q i + f q i ∗ g p i ) \sum(fp_i*fq_i+fp_i*gq_i+fq_i*gp_i) ∑(fpi∗fqi+fpi∗gqi+fqi∗gpi)
具体如何求
f
p
f p
fp和
g
p
g p
gp,我们可以先求
f
i
,
j
f_{i,j}
fi,j 和
g
i
,
j
g_{i,j}
gi,j ,
f
i
,
j
f_{i,j}
fi,j表示以i为根的子树中,与i 距离为j 的路径个数;
g
i
,
j
g_{i,j}
gi,j 表示以i 为根的子树,在子树外与i 距离为j 的路径个数
然后我们DP、类似换根的操作就行了
代码
#include<cstdio>
#include<iostream>
#include<queue>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#define ll long long
using namespace std;
int n,p,q,u,v,k;
int head[10010],father[3100];
ll fq[3100],fp[3100],gq[3100],gp[3100],f[3100][3100],g[3100][3100];
struct c {
int x,next;
} a[10010];
void add(int x,int y) {
a[++k].x=y;
a[k].next=head[x];
head[x]=k;
}
void dfs_f(int x,int fa){
father[x]=fa;
f[x][0]=1;
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(y==fa)continue;
dfs_f(y,x);
for(int j=1;j<=max(p,q);j++)
f[x][j]+=f[y][j-1];
}
}
void dfs_g(int x,int fa){
for(int j=1;j<=max(p,q);j++)
g[x][j]+=g[fa][j-1];
g[x][0]=1;
for(int i=head[fa];i;i=a[i].next)
{
int y=a[i].x;
if(y==x||y==father[fa])continue;
for(int j=2;j<=max(p,q);j++)
g[x][j]+=f[y][j-2];
}
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(y!=fa)
dfs_g(y,x);
}
}
void dfs(int x,int fa){
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(y==fa)continue;
for(int j=1;j<p;j++)
fp[x]+=f[y][j-1]*(f[x][p-j]-f[y][p-j-1]);
for(int j=1;j<q;j++)
fq[x]+=f[y][j-1]*(f[x][q-j]-f[y][q-j-1]);
}
fp[x]/=2,fq[x]/=2;
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(y==fa)continue;
fp[x]+=f[y][p-1],fq[x]+=f[y][q-1];
}
for(int j=1;j<=p;j++)
gp[x]+=g[x][j]*f[x][p-j];
for(int j=1;j<=q;j++)
gq[x]+=g[x][j]*f[x][q-j];
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(y==fa)continue;
dfs(y,x);
}
}
int main(){
freopen("intersection.in","r",stdin);
freopen("intersection.out","w",stdout);
scanf("%d%d%d",&n,&p,&q);
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
dfs_f(1,0);
dfs_g(1,0);
dfs(1,0);
ll sum1=0,sum2=0,ans=0;
for(int i=1;i<=n;i++)
sum1+=fp[i],sum2+=fq[i];
ans=sum1*sum2;
for(int i=1;i<=n;i++)
ans-=fp[i]*fq[i]+fp[i]*gq[i]+fq[i]*gp[i];
printf("%lld",ans*4);
}