切树游戏
题目链接:luogu P3781 / LOJ 2269
题目大意
给你一棵树,会有单点修改,要你在其中求有多少棵子树的权值异或和是一个询问的 k。
思路
首先考虑没有修改的 DP。
先是最暴力的:
f
i
,
j
f_{i,j}
fi,j 为
i
i
i 的子树,当前选的子树有
i
i
i,异或和为
j
j
j。
询问答案
k
k
k 就是
∑
i
=
1
n
f
k
,
j
\sum\limits_{i=1}^nf_{k,j}
i=1∑nfk,j。
然后转移是枚举子树
s
o
n
son
son:
f
i
,
j
=
∑
x
⊕
y
=
j
f
i
,
x
f
s
o
n
,
y
+
f
i
,
j
f_{i,j}=\sum\limits_{x\oplus y=j}f_{i,x}f_{son,y}+f_{i,j}
fi,j=x⊕y=j∑fi,xfson,y+fi,j。
然后这个显然是可以用 FWT 优化的:
n
f
i
=
F
W
T
[
a
i
]
nf_{i}=FWT[a_i]
nfi=FWT[ai]
(其实卷积起来就是
f
i
,
k
=
n
f
i
,
k
∏
s
o
n
(
f
s
o
n
,
k
+
1
)
f_{i,k}=nf_{i,k}\prod\limits_{son}(f_{son,k}+1)
fi,k=nfi,kson∏(fson,k+1))
然后至于答案我们可以搞 h i , k = f i , k + ∑ s o n h s o n , k h_{i,k}=f_{i,k}+\sum\limits_{son}h_{son,k} hi,k=fi,k+son∑hson,k。
那我们这个就可以搞,然后搞出来最后 IFWT 转回去。
那一次的复杂度是
O
(
n
m
log
m
)
O(nm\log m)
O(nmlogm)。
考虑加上动态改点,发现上面的 DP 其实还算简单,我们考虑 DDP。
那还是轻重链剖分,然后重儿子是
h
s
o
n
x
hson_x
hsonx。
那我们设
l
f
x
,
k
=
n
f
x
,
k
∏
s
o
n
∧
s
o
n
≠
h
s
o
n
x
(
f
s
o
n
,
k
+
1
)
lf_{x,k}=nf_{x,k}\prod\limits_{son\wedge son\neq hson_{x}}(f_{son,k}+1)
lfx,k=nfx,kson∧son=hsonx∏(fson,k+1)
l
h
x
,
k
=
∑
s
o
n
∧
s
o
n
≠
h
s
o
n
x
l
h
s
o
n
,
k
lh_{x,k}=\sum\limits_{son\wedge son\neq hson_{x}}lh_{son,k}
lhx,k=son∧son=hsonx∑lhson,k
然后你会发现这个
k
k
k 一直只跟自己有关联,然后一直挂在这里很烦,那我们考虑把上面转移那些去掉
k
k
k 的一维,然后把它们当做数组转移。
(因为你都跟自己有关,所以加减乘除都是
O
(
m
)
O(m)
O(m) 的,就不用卷积)
然后看看怎么转移,准备上矩阵。
f
x
=
l
f
x
(
f
h
s
o
n
x
+
1
)
f_{x}=lf_{x}(f_{hson_x}+1)
fx=lfx(fhsonx+1)
h
x
=
f
x
+
l
h
x
+
h
h
s
o
n
x
=
l
f
x
(
f
h
s
o
n
x
+
1
)
+
l
h
x
+
h
h
s
o
n
x
h_x=f_x+lh_x+h_{hson_x}=lf_{x}(f_{hson_x}+1)+lh_x+h_{hson_x}
hx=fx+lhx+hhsonx=lfx(fhsonx+1)+lhx+hhsonx
然后你就可以尝试建矩阵啦!
先搞一个
∣
f
x
h
x
1
∣
\begin{vmatrix}f_x\\ h_x\\ 1\end{vmatrix}
∣∣∣∣∣∣fxhx1∣∣∣∣∣∣
然后要通过乘一个矩阵把
∣
f
h
s
o
n
x
h
h
s
o
n
x
1
∣
\begin{vmatrix}f_{hson_x}\\ h_{hson_x}\\ 1\end{vmatrix}
∣∣∣∣∣∣fhsonxhhsonx1∣∣∣∣∣∣ 变成
∣
f
x
h
x
1
∣
\begin{vmatrix}f_x\\ h_x\\ 1\end{vmatrix}
∣∣∣∣∣∣fxhx1∣∣∣∣∣∣。
然后可以构造出:
∣
l
f
x
0
l
f
x
l
f
x
1
l
f
x
+
l
h
x
0
0
1
∣
∗
∣
f
h
s
o
n
x
h
h
s
o
n
x
1
∣
=
∣
f
x
h
x
1
∣
\begin{vmatrix}lf_x&0& lf_x\\ lf_x&1&lf_x+lh_x\\ 0&0&1\end{vmatrix}*\begin{vmatrix}f_{hson_x}\\ h_{hson_x}\\ 1\end{vmatrix}=\begin{vmatrix}f_x\\ h_x\\ 1\end{vmatrix}
∣∣∣∣∣∣lfxlfx0010lfxlfx+lhx1∣∣∣∣∣∣∗∣∣∣∣∣∣fhsonxhhsonx1∣∣∣∣∣∣=∣∣∣∣∣∣fxhx1∣∣∣∣∣∣
然后就可以用这个矩阵搞,复杂度为 O ( q l o g ( 2 ) n m ∗ 27 ) O(qlog^{(2)}nm*27) O(qlog(2)nm∗27),好像有点大,就算用全局平衡二叉树也过不去。
然后我们看到这个矩阵的样子比较特别,考虑手玩一下:
∣
a
1
0
b
1
c
1
1
d
1
0
0
1
∣
∗
∣
a
2
0
b
2
c
2
1
d
2
0
0
1
∣
=
∣
a
1
a
2
0
a
1
b
2
+
b
1
a
2
c
1
+
c
2
1
c
1
b
2
+
d
2
+
d
1
0
0
1
∣
\begin{vmatrix}a_1&0& b_1\\ c_1&1&d_1\\ 0&0&1\end{vmatrix}*\begin{vmatrix}a_2&0& b_2\\ c_2&1&d_2\\ 0&0&1\end{vmatrix}=\begin{vmatrix}a_1a_2&0& a_1b_2+b_1\\ a_2c_1+c_2&1&c_1b_2+d_2+d_1\\ 0&0&1\end{vmatrix}
∣∣∣∣∣∣a1c10010b1d11∣∣∣∣∣∣∗∣∣∣∣∣∣a2c20010b2d21∣∣∣∣∣∣=∣∣∣∣∣∣a1a2a2c1+c20010a1b2+b1c1b2+d2+d11∣∣∣∣∣∣
然后你会发现你只用维护四个值,而且它们怎么维护是固定的,所以常数就变成了
4
4
4,用平衡二叉树就可以过啦!
(好像说 luogu 树链剖分被卡了的说)
然后至于实现的话。。。多用结构体。
具体一下就是你矩阵里面每个值是数组用结构体,轻边的维护不能直接转移,因为你修改的时候要除,但是里面可能是
0
0
0,所以你要再弄一个结构体专门来轻边的转移,就是一个正常的数组加上一个表示数组每一位乘了
0
0
0 的个数。
然后你乘
0
0
0 就不乘而是加
0
0
0 的个数,除就是减,然后弄一个函数把它转回乘整除的数组,就是
O
(
m
)
O(m)
O(m) 枚举一次把有
0
0
0 的变成
0
0
0。
然后建议全局平衡二叉树也可以封装一下。
(麻了代码是真的长)
代码
#include<cstdio>
#include<algorithm>
using namespace std;
const int mo = 1e4 + 7;
const int N = 3e4 + 10;
const int M = 128;
struct node {
int to, nxt;
}e[N << 1];
int n, m, a[N], inv[N], le[N], KK;
int sz[N], son[N], ans[M];
char c;
int jia(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int jian(int x, int y) {return x < y ? x - y + mo : x - y;}
int cheng(int x, int y) {return x * y % mo;}
void FWT(int *f, int limit, int op) {//FWT
for (int mid = 1; mid < limit; mid <<= 1) {
for (int R = mid << 1, j = 0; j < limit; j += R)
for (int k = 0; k < mid; k++) {
int x = f[j | k], y = f[j | mid | k];
f[j | k] = jia(x, y); f[j | mid | k] = jian(x, y);
if (op == -1) f[j | k] = cheng(f[j | k], inv[2]), f[j | mid | k] = cheng(f[j | mid | k], inv[2]);
}
}
}
void add(int x, int y) {e[++KK] = (node){y, le[x]}; le[x] = KK;}
void dfs(int now, int father) {//重链剖分
sz[now] = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father) {
dfs(e[i].to, now); sz[now] += sz[e[i].to];
if (sz[e[i].to] > sz[son[now]]) son[now] = e[i].to;
}
}
struct poly {//记得你矩阵里面每个值都是一个数组,那加减我们得维护
int f[M];
int& operator [](int& x) {return f[x];}
poly operator +(poly y) {poly re; for (int i = 0; i < m; i++) re[i] = jia(f[i], y[i]); return re;}
poly operator -(poly y) {poly re; for (int i = 0; i < m; i++) re[i] = jian(f[i], y[i]); return re;}
poly operator *(poly y) {poly re; for (int i = 0; i < m; i++) re[i] = cheng(f[i], y[i]); return re;}
poly operator /(poly y) {poly re; for (int i = 0; i < m; i++) re[i] = cheng(f[i], inv[y[i]]); return re;}
void operator +=(poly y) {for (int i = 0; i < m; i++) f[i] = jia(f[i], y[i]);}
void operator -=(poly y) {for (int i = 0; i < m; i++) f[i] = jian(f[i], y[i]);}
}one, ee[N];
struct matrix {
poly a[2][2];
poly* operator [](const int& x) {return a[x];}
matrix operator *(matrix b) {//优化了的矩阵转移
matrix re;
re[0][0] = a[0][0] * b[0][0];
re[0][1] = a[0][0] * b[0][1] + a[0][1];
re[1][0] = b[0][0] * a[1][0] + b[1][0];
re[1][1] = a[1][0] * b[0][1] + b[1][1] + a[1][1];
return re;
}
};
struct Light {//对于轻边上的转移我们可以单独开一个结构体,因为要记录 0 个数
int num0[M], val[M];
void change(int pl, int va) {
if (!va) num0[pl] = val[pl] = 1;
else num0[pl] = 0, val[pl] = va;
}
void operator *=(poly f) {
for (int i = 0; i < m; i++)
if (!f.f[i]) num0[i]++;
else val[i] = cheng(val[i], f.f[i]);
}
void operator /=(poly f) {
for (int i = 0; i < m; i++)
if (!f.f[i]) num0[i]--;
else val[i] = cheng(val[i], inv[f.f[i]]);
}
};
poly get_poly(Light &b) {//把轻边的结构体转回给数组
poly x; for (int i = 0; i < m; i++) x[i] = b.num0[i] ? 0 : b.val[i]; return x;
}
struct BST {//全局平衡二叉树
Light lf[N]; poly lh[N];
int fa[N], ls[N], rs[N], root, sta[N], ssz[N];
matrix val[N], sum[N];
bool nrt(int x) {
return ls[fa[x]] == x || rs[fa[x]] == x;
}
void Make_val(int now, int to) {//求 lf,lh (轻边的值转移)
lf[now] *= (sum[to][1][0] + one);
lh[now] += sum[to][1][1];
}
void Clean_val(int now, int to) {
lf[now] /= (sum[to][1][0] + one);
lh[now] -= sum[to][1][1];
}
void Make_Val(int now) {//建矩阵
val[now][0][0] = val[now][0][1] = val[now][1][0] = val[now][1][1] = get_poly(lf[now]);
val[now][1][1] += lh[now];
}
void up(int now) {
sum[now] = sum[ls[now]] * val[now] * sum[rs[now]];
}
int buildT(int l, int r) {
if (l > r) return 0;
int tot = 0; for (int i = l; i <= r; i++) tot += ssz[sta[i]];
for (int i = l, now = ssz[sta[i]]; i <= r; i++, now += ssz[sta[i]])
if (now * 2 >= tot) {
ls[sta[i]] = buildT(l, i - 1); rs[sta[i]] = buildT(i + 1, r);
fa[ls[sta[i]]] = fa[rs[sta[i]]] = sta[i]; up(sta[i]); return sta[i];
}
}
int build(int now, int fr) {
for (int i = now; i; fr = i, i = son[i]) {
for (int j = le[i]; j; j = e[j].nxt)
if (e[j].to != fr && e[j].to != son[i]) {
int x = build(e[j].to, i); fa[x] = i;
Make_val(i, x);
}
Make_Val(i);
}
sta[0] = 0;
for (int i = now; i; i = son[i]) sta[++sta[0]] = i, ssz[i] = sz[i] - sz[son[i]];
reverse(sta + 1, sta + sta[0] + 1);//反转,因为也是从下到上DP的
return buildT(1, sta[0]);
}
void Init() {
for (int i = 1; i <= n; i++)
for (int j = 0; j < m; j++)
lf[i].change(j, ee[i][j]);
val[0][0][0] = sum[0][0][0] = one;
root = build(1, 0);
}
void change(int x, int y) {
lf[x] /= ee[x];
for (int i = 0; i < m; i++) ee[x][i] = 0; a[x] = y; ee[x][a[x]] = 1;
FWT(ee[x].f, m, 1); lf[x] *= ee[x]; Make_Val(x);
for (; x; x = fa[x]) {
if (nrt(x)) up(x);//重链上直接上传
else {//轻链要修改好父亲的值,上传
Clean_val(fa[x], x); up(x); Make_val(fa[x], x); Make_Val(fa[x]);
}
}
}
void update_ans() {
for (int i = 0; i < m; i++) ans[i] = sum[root][1][1][i];
FWT(ans, m, -1);
}
}T;
void Init() {
for (int i = 0; i < m; i++) one[i] = 1;
inv[0] = inv[1] = 1; for (int i = 2; i < mo; i++) inv[i] = cheng(inv[mo % i], mo - mo / i);
}
int main() {
scanf("%d %d", &n, &m); Init();
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), ee[i][a[i]] = 1, FWT(ee[i].f, m, 1);
for (int i = 1; i < n; i++) {
int x, y; scanf("%d %d", &x, &y); add(x, y); add(y, x);
}
dfs(1, 0);
T.Init();
T.update_ans();
int q; scanf("%d", &q);
while (q--) {
c = getchar(); while (c != 'C' && c != 'Q') c = getchar();
if (c == 'C') {
while (c != ' ') c = getchar();
int x, y; scanf("%d %d", &x, &y);
T.change(x, y); T.update_ans();
}
else {
while (c != ' ') c = getchar();
int x; scanf("%d", &x);
printf("%d\n", ans[x]);
}
}
return 0;
}