题意:
数据范围:
Analysis:
首先膜拜这题出的实在是过于神仙,蒟蒻瑟瑟发抖,完全不会做。
此题分为三个部分,我们分别解决。
先浅显的分析一些性质:
我们发现我们只需要两棵树边的交集,然后把它们拿出来。
会形成森林,那么对于这片森林的每一个联通块,数字必须一样。
设
E
1
E1
E1为第一棵树的边集,
E
2
E2
E2为第二颗树的边集。
由于
n
n
n个点的森林若一共有
k
k
k条边,那么一共会有
n
−
k
n-k
n−k个联通块。
所以答案即为:
y
n
−
∣
E
1
⋂
E
2
∣
y^{n-|E1\bigcap E2|}
yn−∣E1⋂E2∣。
对于
T
a
s
k
1
:
Task 1:
Task1:
两棵树都已确定,只要确定交集有多少条边即可,用
m
a
p
map
map随意维护一下。
对于
T
a
s
k
2
:
Task 2:
Task2:
我们发现第一棵树没有确定,我们需要枚举
E
1
E1
E1。
这是一个集合枚举,方案数极难计算,不妨考虑容斥。
套上一个集合容斥,我们设:
F
S
=
∑
T
⊆
S
∑
T
1
⊆
T
(
−
1
)
∣
T
∣
−
∣
T
1
∣
F
T
1
F_S=\sum_{T \subseteq S}\sum_{T1 \subseteq T}(-1)^{|T|-|T1|}F_{T1}
FS=T⊆S∑T1⊆T∑(−1)∣T∣−∣T1∣FT1
这是子集反演的套路,要记住。
让
F
S
=
y
n
−
∣
S
∣
F_S=y^{n-|S|}
FS=yn−∣S∣,表示边集交集为
S
S
S的答案,那么我们可以得到:
a
n
s
=
∑
E
1
∑
S
⊆
E
1
⋂
E
2
∑
T
⊆
S
(
−
1
)
∣
S
∣
−
∣
T
∣
∗
y
n
−
∣
T
∣
ans=\sum_{E1}\sum_{S \subseteq E1 \bigcap E2}\sum_{T \subseteq S} (-1)^{|S|-|T|}*y^{n-|T|}
ans=E1∑S⊆E1⋂E2∑T⊆S∑(−1)∣S∣−∣T∣∗yn−∣T∣
=
∑
S
⊆
E
2
y
n
−
∣
S
∣
∗
(
∑
T
⊆
S
(
−
1
)
∣
S
∣
−
∣
T
∣
∗
y
∣
S
∣
−
∣
T
∣
)
∗
G
S
=\sum_{S \subseteq E2}y^{n-|S|}*(\sum_{T \subseteq S}(-1)^{|S|-|T|}*y^{|S|-|T|})*G_S
=S⊆E2∑yn−∣S∣∗(T⊆S∑(−1)∣S∣−∣T∣∗y∣S∣−∣T∣)∗GS
其中
G
S
G_S
GS表示包含边集
S
S
S的边集
E
1
E1
E1有多少个,对于括号里面用二项式定理,我们能推得:
a
n
s
=
∑
S
⊆
E
2
y
n
−
∣
S
∣
∗
(
1
−
y
)
∣
S
∣
∗
G
S
ans=\sum_{S \subseteq E2}y^{n-|S|}*(1-y)^{|S|}*G_S
ans=S⊆E2∑yn−∣S∣∗(1−y)∣S∣∗GS
我们先考虑
G
S
G_S
GS如何算,首先我们把边集
S
S
S里的点全部连上,然后现在是一片森林,我们需要继续连边,既然若干连通块独立,我们不妨将一个连通块看成一个点,然后求树的形态个数,每次连边都是在连通块里随意选一个点去连。
因此设在边集
S
S
S中共有
k
k
k个连通块,且每个连通块点数为
a
i
a_i
ai,随意化一下公式就可以得到:
G
S
=
n
k
−
2
∗
∏
i
=
1
k
a
i
G_S=n^{k-2}*\prod_{i=1}^ka_i
GS=nk−2∗∏i=1kai
我们接着化原式:
a
n
s
=
∑
S
⊆
E
2
y
k
∗
(
1
−
y
)
n
−
k
∗
n
k
−
2
∗
∏
i
=
1
k
a
i
ans=\sum_{S \subseteq E2}y^{k}*(1-y)^{n-k}*n^{k-2}*\prod_{i=1}^ka_i
ans=S⊆E2∑yk∗(1−y)n−k∗nk−2∗i=1∏kai
=
(
1
−
y
)
n
n
2
∑
S
⊆
E
2
∏
i
=
1
k
n
y
1
−
y
a
i
=\frac{(1-y)^n}{n^2}\sum_{S \subseteq E2}\prod_{i=1}^k\frac{ny}{1-y}a_i
=n2(1−y)nS⊆E2∑i=1∏k1−ynyai
这样一个式子就简单很多了,既然子集数目很多,那我们考虑用
D
P
DP
DP去解决这个问题。
我们设
k
=
n
y
1
−
y
k=\frac{ny}{1-y}
k=1−yny,我们可以设
f
i
,
j
f_{i,j}
fi,j表示包含
i
i
i的连通块大小为
j
j
j,且
i
i
i所在连通块的贡献还没计算,子树内的所有选择方案,的连通块的贡献之和,即式子中的最右边那部分,那么答案就是
k
∑
j
j
∗
f
1
,
j
k\sum_jj*f_{1,j}
k∑jj∗f1,j。
但这样还不够,我们需要继续优化。
我们考虑设这个
D
P
DP
DP的生成函数为
F
(
x
)
F(x)
F(x),即
F
(
x
)
=
∑
i
>
0
f
i
x
i
F(x)=\sum_{i>0} f_ix^i
F(x)=∑i>0fixi。
我们再设
Z
i
=
k
∑
j
∗
f
i
,
j
Z_i=k\sum j*f_{i,j}
Zi=k∑j∗fi,j。
那么
F
i
(
x
)
=
x
∏
y
∈
s
o
n
(
Z
y
+
F
y
(
x
)
)
F_i(x)=x\prod_{y \in son}(Z_y+F_y(x))
Fi(x)=x∏y∈son(Zy+Fy(x))。
然后我们观察
Z
i
Z_i
Zi中的形式,会发现类似于求导后的系数,于是我们可以得到。
Z
x
=
k
F
x
′
(
1
)
=
k
(
x
∏
y
∈
s
o
n
(
Z
y
+
F
y
(
1
)
)
)
′
=
k
∏
y
∈
s
o
n
(
Z
y
+
F
y
(
1
)
)
+
(
∏
y
∈
s
o
n
(
Z
y
+
F
y
(
1
)
)
)
∑
y
∈
s
o
n
k
F
y
′
(
1
)
Z
y
+
F
y
(
1
)
Z_x=kF_x'(1)=k(x\prod_{y \in son}(Z_y+F_y(1)))'=k\prod_{y \in son}(Z_y+F_y(1))+(\prod_{y \in son}(Z_y+F_y(1)))\sum_{y \in son} \frac{kF_y'(1)}{Z_y+F_y(1)}
Zx=kFx′(1)=k(xy∈son∏(Zy+Fy(1)))′=ky∈son∏(Zy+Fy(1))+(y∈son∏(Zy+Fy(1)))y∈son∑Zy+Fy(1)kFy′(1)。
后面那一部分是通过如下式子得到的:
(
∏
i
=
1
n
F
(
x
)
)
′
=
∑
i
=
1
n
F
′
(
i
)
∏
j
≠
i
F
(
j
)
=
∏
i
=
1
n
F
(
i
)
∑
j
=
1
n
F
′
(
j
)
F
(
j
)
(\prod_{i=1}^nF(x))'=\sum_{i=1}^nF'(i)\prod_{j \neq i}F(j)=\prod_{i=1}^nF(i)\sum_{j=1}^n \frac{F'(j)}{F(j)}
(i=1∏nF(x))′=i=1∑nF′(i)j̸=i∏F(j)=i=1∏nF(i)j=1∑nF(j)F′(j)
再观察一下这个式子的形式,我们就可以发现可以
O
(
n
)
O(n)
O(n)递推了。
我们额外设
t
i
=
F
i
(
1
)
t_i=F_i(1)
ti=Fi(1),那么可以得到:
Z
i
=
t
i
(
k
+
∑
y
Z
y
Z
y
+
t
y
)
Z_i=t_i(k+\sum_y \frac{Z_y}{Z_y+t_y})
Zi=ti(k+y∑Zy+tyZy)
t
i
=
∏
y
(
Z
y
+
t
y
)
t_i=\prod_y(Z_y+t_y)
ti=y∏(Zy+ty)
Task3:
此时我们按照上一个
T
a
s
k
Task
Task里面的方法推导,最终的式子大约如下:
a
n
s
=
∑
S
y
n
−
∣
S
∣
∗
(
1
−
y
)
∣
S
∣
∗
G
S
2
ans=\sum_Sy^{n-|S|}*(1-y)^{|S|}*G^2_S
ans=S∑yn−∣S∣∗(1−y)∣S∣∗GS2
=
(
1
−
y
)
n
n
4
∑
S
∏
i
=
1
k
n
2
y
1
−
y
a
i
2
=\frac{(1-y)^n}{n^4}\sum_S\prod_{i=1}^k\frac{n^2y}{1-y}a_i^2
=n4(1−y)nS∑i=1∏k1−yn2yai2
一棵点数为
a
i
−
2
a_i-2
ai−2的树,形态有
a
i
a
i
−
2
种
a_i^{a_i-2}种
aiai−2种,再分配标号方案为
n
!
∏
i
a
i
!
\frac{n!}{\prod_ia_i!}
∏iai!n!,然后再去重,除掉连通块个数的阶乘。
我们考虑生成函数
F
(
x
)
=
∑
i
>
0
n
2
y
1
−
y
∗
i
i
i
!
F(x)=\sum_{i>0}\frac{n^2y}{1-y}*\frac{i^i}{i!}
F(x)=∑i>01−yn2y∗i!ii。
a
n
s
=
(
1
−
y
)
n
n
!
n
4
∑
i
=
1
n
[
x
n
]
f
(
x
)
i
ans=\frac{(1-y)^nn!}{n^4}\sum_{i=1}^n[x^n]f(x)^i
ans=n4(1−y)nn!i=1∑n[xn]f(x)i。
会发现后面是一个多项式
e
x
p
exp
exp的形式,因此只需要求一次
e
x
p
exp
exp即可。
复杂度
O
(
n
log
n
)
O(n \log n)
O(nlogn)。
Code:
# include<cstdio>
# include<cstring>
# include<algorithm>
# include<map>
using namespace std;
const int N = 1e5 + 5;
const int mo = 998244353;
const int invg = (mo + 1) / 3;
typedef long long ll;
map <ll,int> Q;
int z[N << 3],ans[N << 3],rev[N << 3],F[N << 3],E[N << 3];
int A[N << 3],B[N << 3],C[N << 3],D[N << 3],inv[N << 3];
int f[N],g[N],fac[N],st[N],to[N << 1],nx[N << 1];
int n,Y,op,L,len,tot,X;
inline void add(int u,int v)
{
to[++tot] = v,nx[tot] = st[u],st[u] = tot;
to[++tot] = u,nx[tot] = st[v],st[v] = tot;
}
inline int pow(int x,int p)
{
int ret = 1;
for (; p ; p >>= 1,x = (ll)x * x % mo)
if (p & 1) ret = (ll)ret * x % mo;
return ret;
}
inline int inc(int x,int y) { return x + y >= mo ? x + y - mo : x + y; }
inline int dec(int x,int y) { return x - y < 0 ? x - y + mo : x - y; }
inline void dft(int *f,int n,int op)
{
for (len = 1,L = 0 ; len <= n ; len <<= 1,++L);
for (int i = 0 ; i < len ; ++i)
{
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
if (i < rev[i]) swap(f[i],f[rev[i]]);
}
for (int i = 1 ; i < len ; i <<= 1)
{
int wn = pow(~op ? 3 : invg,(mo - 1) / (i << 1));
for (int j = 0 ; j < len ; j += (i << 1))
{
int w = 1;
for (int k = 0 ; k < i ; ++k,w = (ll)w * wn % mo)
{
int x = f[j + k],y = (ll)w * f[i + j + k] % mo;
f[j + k] = inc(x,y),f[i + j + k] = dec(x,y);
}
}
}
if (op == -1)
{
int x = pow(len,mo - 2);
for (int i = 0 ; i < len ; ++i) f[i] = (ll)f[i] * x % mo;
}
}
inline void Ginv(int *a,int *b,int n)
{
if (n == 1) { b[0] = pow(a[0],mo - 2); return; }
Ginv(a,b,n >> 1);
for (int i = 0 ; i < n ; ++i) A[i] = a[i],B[i] = b[i];
dft(A,n,1),dft(B,n,1);
for (int i = 0 ; i < len ; ++i) A[i] = (ll)A[i] * B[i] % mo * B[i] % mo;
dft(A,n,-1);
for (int i = 0 ; i < n ; ++i) b[i] = dec(inc(b[i],b[i]),A[i]);
for (int i = 0 ; i < len ; ++i) A[i] = B[i] = 0;
}
inline void Gln(int *a,int *b,int n)
{
Ginv(a,C,n);
for (int i = 0 ; i < n - 1 ; ++i) D[i] = (ll)(i + 1) * a[i + 1] % mo;
dft(C,n,1),dft(D,n,1);
for (int i = 0 ; i < len ; ++i) C[i] = (ll)C[i] * D[i] % mo;
dft(C,n,-1),b[0] = 0;
for (int i = 1 ; i < n ; ++i) b[i] = (ll)inv[i] * C[i - 1] % mo;
for (int i = 0 ; i < len ; ++i) C[i] = D[i] = 0;
}
inline void Gexp(int *a,int *b,int n)
{
if (n == 1) { b[0] = 1; return; }
Gexp(a,b,n >> 1);
for (int i = 0 ; i < n ; ++i) E[i] = b[i];
Gln(b,F,n);
for (int i = 0 ; i < n ; ++i) F[i] = dec(a[i],F[i]);
F[0] = inc(F[0],1);
dft(E,n,1),dft(F,n,1);
for (int i = 0 ; i < len ; ++i) E[i] = (ll)E[i] * F[i] % mo;
dft(E,n,-1);
for (int i = 0 ; i < n ; ++i) b[i] = E[i];
for (int i = 0 ; i < len ; ++i) E[i] = F[i] = 0;
}
inline void dfs(int x,int F)
{
g[x] = 1,f[x] = X;
for (int i = st[x] ; i ; i = nx[i])
if (to[i] != F)
{
dfs(to[i],x),g[x] = (ll)g[x] * inc(f[to[i]],g[to[i]]) % mo;
f[x] = (f[x] + (ll)f[to[i]] * pow(inc(f[to[i]],g[to[i]]),mo - 2) % mo) % mo;
} f[x] = (ll)f[x] * g[x] % mo;
}
inline int calc()
{
for (int i = 1 ; i < n ; ++i)
{
int u,v; scanf("%d%d",&u,&v);
Q[(ll)u * (n - 1) + v] = Q[(ll)v * (n - 1) + u] = 1;
} int cnt = 0;
for (int i = 1 ; i < n ; ++i)
{
int u,v; scanf("%d%d",&u,&v);
if (Q.count((ll)u * (n - 1) + v) || Q.count((ll)v * (n - 1) + u)) ++cnt;
} return pow(Y,n - cnt);
}
inline int calc1()
{
if (Y == 1) return pow(n,n - 2);
for (int i = 1 ; i < n ; ++i)
{
int u,v; scanf("%d%d",&u,&v);
add(u,v);
} X = (ll)n * Y % mo * pow(dec(1,Y),mo - 2) % mo,dfs(1,0);
X = pow(n,mo - 2);
return (ll)pow(dec(1,Y),n) * X % mo * X % mo * f[1] % mo;
}
inline int calc2()
{
if (Y == 1) return pow(n,2 * n - 4);
X = (ll)n * n % mo * Y % mo * pow(dec(1,Y),mo - 2) % mo;
int zs = 1; for (; zs <= n ; zs <<= 1); inv[1] = 1;
for (int i = 2 ; i < zs ; ++i) inv[i] = (ll)(mo - mo / i) * inv[mo % i] % mo;
int c = 1;
for (int i = 1 ; i < zs ; ++i,c = (ll)c * i % mo)
z[i] = (ll)pow(i,i) * pow(c,mo - 2) % mo * X % mo;
Gexp(z,ans,zs); X = (ll)pow(dec(1,Y),n) * pow(inv[n],4) % mo;
c = 1; for (int i = 2 ; i <= n ; ++i) c = (ll)c * i % mo;
return (ll)ans[n] * X % mo * c % mo;
}
int main()
{
scanf("%d%d%d",&n,&Y,&op);
if (!op) printf("%d\n",calc());
else if (op == 1) printf("%d\n",calc1());
else printf("%d\n",calc2());
return 0;
}