Description
给出一棵 n n n个节点的树,初始每个点为白色,对于一个点集 S S S,定义 f ( s ) f(s) f(s)为:将 S S S中点染成黑色,如果存在任何白点在两个黑点路径之间则 f ( S ) = 0 f(S)=0 f(S)=0,否则选取一个路径集合,使得该路径集合中每条路径斗不包含黑点,之后把路径上的点染成红色,此时 f ( S ) f(S) f(S)为使得所有黑点的邻接点为黑点或红点的路径集合数量,要求对于每个非空子集 S S S计算 f ( S ) f(S) f(S)之和
Input
第一行一整数 n n n表示点数,之后 n − 1 n-1 n−1行每行输入两个整数表示一条树边
( 1 ≤ n ≤ 2 ⋅ 1 0 5 ) (1\le n\le 2\cdot 10^5) (1≤n≤2⋅105)
Output
输出 f ( S ) f(S) f(S)之和,结果模 998244353 998244353 998244353
Sample Input
2
1 2
Sample Output
3
Solution
第二步所选边集会将这棵树分成若干没有红点的连通分支,每个连通分支均可选来做第一步所选点集,连通分支数=未被染成红色点的数量-两个端点均未被染成红色的边的数量,那么单独考虑每个点和每条边对答案的贡献
1.一个点
u
u
u对答案的贡献即为选出不经过该点的边集数量,令
n
u
m
[
x
]
=
2
x
(
x
+
1
)
2
num[x]=2^{\frac{x(x+1)}{2}}
num[x]=22x(x+1)为有
x
x
x个点的连通分支中选出一个边集的方案数,那么点
u
u
u对答案的贡献即为
f
(
u
)
=
n
u
m
[
n
−
S
i
z
e
u
]
∏
v
∈
s
o
n
(
u
)
n
u
m
[
S
i
z
e
v
]
f(u)=num[n-Size_u]\prod\limits_{v\in son(u)}num[Size_v]
f(u)=num[n−Sizeu]v∈son(u)∏num[Sizev]
2.一条边
u
→
v
u\rightarrow v
u→v对答案的贡献为(假设
u
u
u是
v
v
v的父亲,将这条边的贡献记录在
v
v
v点)
g
(
v
)
=
n
u
m
[
n
−
S
i
z
e
u
]
∏
s
∈
s
o
n
(
u
)
−
{
v
}
n
u
m
[
S
i
z
e
s
]
∏
t
∈
s
o
n
(
v
)
n
u
m
[
S
i
z
e
t
]
g(v)=num[n-Size_u]\prod\limits_{s\in son(u)-\{v\}}num[Size_s]\prod\limits_{t\in son(v)}num[Size_t]
g(v)=num[n−Sizeu]s∈son(u)−{v}∏num[Sizes]t∈son(v)∏num[Sizet]
也即
g
(
v
)
=
f
(
u
)
n
u
m
[
S
i
z
e
v
]
f
(
v
)
n
u
m
[
n
−
S
i
z
e
v
]
g(v)=\frac{f(u)}{num[Size_v]}\frac{f(v)}{num[n-Size_v]}
g(v)=num[Sizev]f(u)num[n−Sizev]f(v)
树形
D
P
DP
DP一遍,答案即为
∑
i
=
1
n
(
f
(
i
)
−
g
(
i
)
)
\sum\limits_{i=1}^n(f(i)-g(i))
i=1∑n(f(i)−g(i)),线性预处理
n
u
m
[
x
]
num[x]
num[x]以及
n
u
m
−
1
[
x
]
num^{-1}[x]
num−1[x],时间复杂度
O
(
n
)
O(n)
O(n)
Code
#include<cstdio>
#include<vector>
using namespace std;
typedef long long ll;
#define maxn 200005
#define mod 998244353
#define inv2 499122177
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int b[maxn],c[maxn],ib[maxn],ic[maxn];
void init(int n=2e5)
{
b[0]=1;
for(int i=1;i<=n;i++)b[i]=mul(2,b[i-1]);
c[0]=1;
for(int i=1;i<=n;i++)c[i]=mul(b[i],c[i-1]);
ib[0]=1;
for(int i=1;i<=n;i++)ib[i]=mul(inv2,ib[i-1]);
ic[0]=1;
for(int i=1;i<=n;i++)ic[i]=mul(ib[i],ic[i-1]);
}
int n,Size[maxn],f[maxn],g[maxn];
vector<int>e[maxn];
void dfs(int u,int fa)
{
Size[u]=1;
f[u]=1;
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i];
if(v==fa)continue;
dfs(v,u);
f[u]=mul(f[u],c[Size[v]]);
Size[u]+=Size[v];
}
f[u]=mul(f[u],c[n-Size[u]]);
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i];
if(v==fa)continue;
g[v]=1;
g[v]=mul(f[v],ic[n-Size[v]]);//son of v
g[v]=mul(g[v],mul(f[u],ic[Size[v]]));//brother of v&&out of u
}
}
int main()
{
init();
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
e[u].push_back(v),e[v].push_back(u);
}
dfs(1,0);
int ans=0;
for(int i=1;i<=n;i++)ans=add(ans,add(f[i],mod-g[i]));
printf("%d\n",ans);
return 0;
}