题意
有两棵树
T
1
T_1
T1 和
T
2
T_2
T2,大小分别为
n
1
,
n
2
n_1,n_2
n1,n2。构造一个新图,其中的每个节点有二元组
(
u
,
v
)
(
1
≤
u
≤
n
1
,
1
≤
v
≤
n
2
)
(u,v)(1\le u\le n_1,1\le v\le n_2)
(u,v)(1≤u≤n1,1≤v≤n2) 表示。
(
u
,
v
1
)
,
(
u
,
v
2
)
(u,v_1),(u,v_2)
(u,v1),(u,v2) 相邻当且仅当在
T
2
T_2
T2 中
v
1
,
v
2
v_1,v_2
v1,v2 相邻。
(
u
1
,
v
)
,
(
u
2
,
v
)
(u_1,v),(u_2,v)
(u1,v),(u2,v) 相邻当且仅当在
T
1
T_1
T1 中
u
1
,
u
2
u_1,u_2
u1,u2 相邻。问新图中有多少个不同的长度为
k
k
k 的环。
n
1
,
n
2
≤
4000
,
k
≤
75
n_1,n_2\le 4000,k\le 75
n1,n2≤4000,k≤75。
分析
实际上就是在每棵树上分别走。如果我们能对每棵树求出
s
t
e
p
i
step_i
stepi表示长度为
i
i
i的环的数量的话,就可以很容易求出答案,问题在于
s
t
e
p
i
step_i
stepi怎么求。
考虑点分治,然后求所有经过分治中心
c
c
c的环。
设
f
i
,
x
f_{i,x}
fi,x表示从
c
c
c开始走了
i
i
i步走到
x
x
x,且除了一开始以外不经过
c
c
c的方案数,
g
i
,
x
g_{i,x}
gi,x表示从
c
c
c开始走了
i
i
i步走到
x
x
x的方案数。
转移比较显然,那么对于某个点
x
x
x,从
x
x
x开始走,经过
c
c
c的大小为
i
i
i的环的数量就是
∑
j
=
0
i
f
j
,
x
∗
g
i
−
j
,
x
\sum_{j=0}^if_{j,x}*g_{i-j,x}
j=0∑ifj,x∗gi−j,x
这里可以看成是枚举第一次到达
c
c
c是在第几步,然后后面随便走。需要特判
x
x
x就是分治中心的情况。
这样的话总的时间复杂度就是
O
(
n
k
2
log
n
)
O(nk^2\log n)
O(nk2logn),如果用FFT来优化卷积的话可以做到
O
(
n
k
log
n
log
k
)
O(nk\log n\log k)
O(nklognlogk)。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
const int N=4005;
const int M=80;
const int MOD=998244353;
int m,jc[M],ny[M];
struct Tree
{
int n,cnt,last[N],f[M][N],g[M][N],size[N],ans[M],w[N],tot,a[N],sum,root;
bool vis[N];
struct edge{int to,next;}e[N*2];
void addedge(int u,int v)
{
e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}
void get_root(int x,int fa)
{
size[x]=1;w[x]=0;
for (int i=last[x];i;i=e[i].next)
{
if (e[i].to==fa||vis[e[i].to]) continue;
get_root(e[i].to,x);
size[x]+=size[e[i].to];
w[x]=std::max(w[x],size[e[i].to]);
}
w[x]=std::max(w[x],sum-size[x]);
if (!root||w[x]<w[root]) root=x;
}
void get(int x,int fa)
{
a[++tot]=x;size[x]=1;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa&&!vis[e[i].to]) get(e[i].to,x),size[x]+=size[e[i].to];
}
void calc(int x)
{
tot=0;get(x,0);
for (int i=1;i<=tot;i++) f[0][a[i]]=0;
f[0][x]=g[0][x]=1;
for (int i=1;i<=m;i++)
for (int j=1;j<=tot;j++)
{
int y=a[j];
f[i][y]=g[i][y]=0;
for (int k=last[y];k;k=e[k].next)
{
if (vis[e[k].to]) continue;
(g[i][y]+=g[i-1][e[k].to])%=MOD;
if (y!=x) (f[i][y]+=f[i-1][e[k].to])%=MOD;
}
}
for (int i=1;i<=tot;i++)
{
int y=a[i];
if (y==x)
{
for (int j=0;j<=m;j++) (ans[j]+=g[j][x])%=MOD;
continue;
}
for (int j=0;j<=m;j++)
for (int k=0;j+k<=m;k++)
(ans[j+k]+=(LL)f[j][y]*g[k][y]%MOD)%=MOD;
}
vis[x]=1;
for (int i=last[x];i;i=e[i].next)
{
if (vis[e[i].to]) continue;
root=0;sum=size[e[i].to];
get_root(e[i].to,x);
calc(root);
}
}
void solve()
{
sum=n;root=0;
get_root(1,0);
calc(root);
}
}t1,t2;
int C(int n,int m)
{
return (LL)jc[n]*ny[m]%MOD*ny[n-m]%MOD;
}
int main()
{
scanf("%d%d%d",&t1.n,&t2.n,&m);
jc[0]=jc[1]=ny[0]=ny[1]=1;
for (int i=2;i<=m;i++) jc[i]=(LL)jc[i-1]*i%MOD,ny[i]=(LL)(MOD-MOD/i)*ny[MOD%i]%MOD;
for (int i=2;i<=m;i++) ny[i]=(LL)ny[i-1]*ny[i]%MOD;
for (int i=1;i<t1.n;i++)
{
int x,y;scanf("%d%d",&x,&y);
t1.addedge(x,y);
}
for (int i=1;i<t2.n;i++)
{
int x,y;scanf("%d%d",&x,&y);
t2.addedge(x,y);
}
t1.solve();t2.solve();
int s=0;
for (int i=0;i<=m;i++)
(s+=(LL)t1.ans[i]*t2.ans[m-i]%MOD*C(m,i)%MOD)%=MOD;
printf("%d",s);
return 0;
}