题意
有一棵树,初始所有点是黑色,给出一个概率 p p p ,有以下两种操作:
- 把点 x x x 和与他直接相连的点都染色,有 p p p 的概率染白, 1 − p 1-p 1−p 的概率染黑
- 询问当前状况下白点个数的方差
思路
首先要知道方差是什么。 方差等于平方的期望减去期望的平方。
那么考虑如何设计随机变量。先假设 a i = [ c o l o r i = w h i t e ] a_i=[color_i=white] ai=[colori=white] 。然后我们要求的方差就是:
V ( ∑ a i ) = E [ ( ∑ a i ) 2 ] − E ( ∑ a i ) 2 V(\sum a_i)=E[(\sum a_i)^2]-E(\sum a_i)^2 V(∑ai)=E[(∑ai)2]−E(∑ai)2
对于期望,因为每个点不管被重复染色几次,只要他曾经被染色,那么是白色的概率就肯定是 p p p ,所以
E ( ∑ a i ) = ∑ E ( a i ) = p ∗ c n t ( 染 过 色 的 节 点 ) E(\sum a_i)=\sum E(a_i)=p*cnt(染过色的节点) E(∑ai)=∑E(ai)=p∗cnt(染过色的节点)
对于平方的期望,我们考虑维护某个点的最后一次染色是被谁染的。当两个点不是被同一个点染色的,那么他们互相独立,否则他们显然同色,所以
E [ ( ∑ a i ) 2 ] = E ( ∑ ∑ a i ∗ a j ) = { p , b e l o n g i = b e l o n g j p 2 , b e l o n g i ≠ b e l o n g j E[(\sum a_i)^2]=E(\sum\sum a_i*a_j)= \left\{ \begin{aligned} &p&,belong_i=belong_j \\&p^2&,belong_i\neq belong_j \end{aligned} \right. E[(∑ai)2]=E(∑∑ai∗aj)={pp2,belongi=belongj,belongi=belongj
所以我们只要维护有多少个点最后一次是被同一个点染色的就行了。
至于如何维护,可以对一个点记录树上有哪些儿子不是被他染色的,存在一个 vector 里面。每次染色,暴力修改他和他的父亲,然后把 vector 里面的儿子全都修改并弹出。因为一个点只可能作为被染色的点或者被染色的点的父亲进入 vector 的,所以可以证明这样做的复杂度是线性的。
代码
#include<bits/stdc++.h>
using namespace std;
const int N = 5e5 + 10, M = N<<1, mod = 998244353;
namespace Graph
{
int h[N], ecnt, nxt[M], v[M];
void clear(){ecnt = 1;}
void add_dir(int _u, int _v){
v[++ecnt] = _v;
nxt[ecnt] = h[_u]; h[_u] = ecnt;
}
void add_undir(int _u, int _v){
add_dir(_u, _v);
add_dir(_v, _u);
}
}
using namespace Graph;
int n, q, p, p2, faz[N];
vector<int> son[N];
bool vis[N];
int bel[N], cnt[N], tot, ncnt;
template<class T>inline void read(T &x){
x = 0; bool fl = 0; char c = getchar();
while (!isdigit(c)){if (c == '-') fl = 1; c = getchar();}
while (isdigit(c)){x = (x<<3)+(x<<1)+c-'0'; c = getchar();}
if (fl) x = -x;
}
inline int add(int &x, int y){x += y; if (x >= mod) x -= mod;}
void dfs(int u, int fa){
faz[u] = fa;
for (int i = h[u]; i; i = nxt[i])
if (v[i] != fa){
son[u].push_back(v[i]);
dfs(v[i], u);
}
}
inline void check(int x){
if (vis[x] || x == 0) return;
vis[x] = 1; ++ncnt;
}
inline void inc(int x){
if (x == 0) return;
add(tot, 2*cnt[x]+1);
cnt[x]++;
}
inline void dec(int x){
if (x == 0) return;
cnt[x]--;
add(tot, mod-2*cnt[x]-1);
}
inline void update(int u){
if (faz[u] != 0 && bel[faz[u]] != u){
check(faz[u]);
if (bel[faz[u]] == faz[faz[u]]) son[faz[faz[u]]].push_back(faz[u]);
inc(u), dec(bel[faz[u]]), bel[faz[u]] = u;
}
if (bel[u] != u){
check(u);
if (bel[u] == faz[u]) son[faz[u]].push_back(u);
inc(u), dec(bel[u]), bel[u] = u;
}
for (int i = 0, qw = son[u].size(); i < qw; ++ i){
int v = son[u][i];
if (bel[v] != u){
check(v);
inc(u), dec(bel[v]), bel[v] = u;
}
}
son[u].clear();
}
inline int get_ans(){
return (1LL*tot*(p-p2)%mod+mod)%mod;
}
int main()
{
read(n); read(q); read(p);
p2 = 1LL*p*p%mod;
clear();
for (int i = 1; i < n; ++ i){
int x, y;
read(x); read(y);
add_undir(x, y);
}
dfs(1, 0);
for (int i = 1; i <= q; ++ i){
int opt, x;
read(opt);
if (opt == 1){
read(x);
update(x);
}
else printf("%d\n", get_ans());
}
return 0;
}