题目大意
维护一棵树,支持:
1.动态修改某个点权值;
2.查询有多少个连通子树异或值为 $p$。
题解
这题感觉比较套路,显然可以列出一个dp方程,发现这是FWT异或卷积的形式。具体的,记 $f[i]$ 为 $i$ 的dp数组的FWT卷积,那么
$$f[i]=b[val[i]]\cdot\prod_{v\in son[i]}\left(f[v]+b[0]\right),$$
其中 $b[i]$ 表示只有 $i$ 一个数字的FWT异或卷积(式中的乘法与加法均为逐位运算)。
于是显然的动态dp就出来了。每个重链维护一棵线段树,我们最终需要求重链上每个区间的FWT卷积之和,因此我们需要维护四个值:FWT卷积,前缀FWT卷积之和,后缀FWT卷积之和,区间FWT卷积之和。然后直接线段树就可以 $O(m\log n)$ 更新了。
#include <bits/stdc++.h>
// Buffered fast-I/O helpers. Input is pulled one byte at a time through
// readc(); output is staged in _PRINT_ and written out by ioflush().
namespace IOStream {
const int MAXR = 10000000;
char _READ_[MAXR], _PRINT_[MAXR];
int _READ_POS_, _PRINT_POS_, _READ_LEN_;

// Returns the next input byte, or 0 once the input is exhausted.
// Without ONLINE_JUDGE the byte comes straight from getchar(); with it,
// stdin is consumed in chunks of up to MAXR bytes through _READ_.
inline char readc() {
#ifndef ONLINE_JUDGE
    const int c = getchar();
    // Map EOF to the same 0 sentinel the buffered path uses, instead of
    // truncating it to an arbitrary char value.
    return c == EOF ? 0 : static_cast<char>(c);
#else
    if (_READ_POS_ >= _READ_LEN_) {
        // Buffer drained (or never filled): fetch the next chunk.
        _READ_LEN_ = static_cast<int>(fread(_READ_, 1, MAXR, stdin));
        _READ_POS_ = 0;
        if (_READ_LEN_ <= 0) return 0;
    }
    return _READ_[_READ_POS_++];
#endif
}

// Parses an optionally negative decimal integer into x, skipping leading
// bytes that are neither digits nor '-'. x is left at 0 if the input ends
// before a number starts. One byte past the number is consumed.
template<typename T> inline void read(T &x) {
    x = 0;
    int sign = 1, c;
    do {
        c = readc();
    } while (c && c != '-' && (c < '0' || c > '9'));
    if (!c) return;  // end of input: nothing to parse
    if (c == '-') sign = -1;
    else x = c - '0';
    while ((c = readc()) >= '0' && c <= '9') x = x * 10 + (c - '0');
    x *= sign;
}

// Reads several integers in order.
template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
    read(a), read(x...);
}

// Reads one whitespace-delimited token into s (NUL-terminated) and returns
// its length; returns 0 with an empty string when the input is exhausted.
// The caller must supply a buffer large enough for the token.
inline int reads(char *s) {
    int len = 0, c;
    // isspace requires a value representable as unsigned char.
    do {
        c = readc();
    } while (c && isspace(static_cast<unsigned char>(c)));
    while (c && !isspace(static_cast<unsigned char>(c))) {
        s[len++] = static_cast<char>(c);
        c = readc();
    }
    s[len] = 0;
    return len;
}

// Writes the staged output to stdout and flushes it.
inline void ioflush() {
    fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
    fflush(stdout);
}

// Stages a single character, flushing when the buffer becomes full.
inline void printc(char c) {
    _PRINT_[_PRINT_POS_++] = c;
    if (_PRINT_POS_ == MAXR) ioflush();
}

// Stages a NUL-terminated string.
inline void prints(const char *s) {
    for (int i = 0; s[i]; i++) printc(s[i]);
}

// Stages the decimal representation of x followed by the terminator c.
template<typename T> inline void print(T x, char c = '\n') {
    char digits[40];
    int tp = 0;
    const bool negative = x < 0;
    // Take digits from the (possibly negative) value directly so that the
    // most negative value of T never has to be negated.
    do {
        const int d = static_cast<int>(x % 10);
        digits[tp++] = static_cast<char>('0' + (d < 0 ? -d : d));
        x /= 10;
    } while (x);
    if (negative) printc('-');
    while (tp > 0) printc(digits[--tp]);
    printc(c);
}

// Stages several values separated by spaces; the last one ends with '\n'.
template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
    print(x, ' '), print(y...);
}
}
using namespace IOStream;
using namespace std;
using ll = long long;

// Array capacities and the prime modulus used for all arithmetic.
const int MAXT = 70000;
const int MAXN = 30005;
const int MAXM = 130;
const int MOD = 10007;

// Adjacency-list edge: target vertex and index of the next edge of the
// same source vertex (0 terminates the list).
struct Edge {
    int to;
    int next;
};
Edge edge[MAXT];

// Segment trees: rt[] holds the root node of the tree built for a chain
// top, ls[]/rs[] hold the left/right child of each tree node.
int rt[MAXN];
int ls[MAXT];
int rs[MAXT];

// Vertex values and problem sizes; tot is a running counter shared by the
// edge list and the segment-tree node allocator.
int val[MAXN];
int n, m, Q, tot;

// Heavy-light decomposition data per vertex: parent, chain top, adjacency
// head, subtree size, heavy child.
int par[MAXN];
int top[MAXN];
int head[MAXN];
int sz[MAXN];
int wson[MAXN];

// base[i] is the transform of the unit vector at i; inv[] holds modular
// inverses; id[] is a vertex's position inside its chain; ans[] is the
// accumulated answer in transformed form. temp[] is not referenced in the
// code shown here.
int base[MAXM][MAXM];
int inv[MOD];
int id[MAXN];
int ans[MAXM];
int temp[MAXN];
// A residue modulo MOD kept as a product with its zero factors separated
// out, so that multiplying by zero can later be undone by "dividing" by
// zero: num is the product of all non-zero factors (mod MOD) and cnt is the
// number of zero factors currently folded in. The represented value is 0
// whenever cnt is non-zero.
struct ModInt {
    int num, cnt;

    // Default state has one zero factor, i.e. get() returns 0.
    ModInt() { num = cnt = 1; }

    // Reduces x into [0, MOD) so that multiples of MOD count as zero and
    // negative operands never reach inv[] with a negative index.
    static int norm(int x) {
        x %= MOD;
        return x < 0 ? x + MOD : x;
    }

    // Resets the value to x (mod MOD).
    ModInt& operator=(int x) {
        x = norm(x);
        if (x == 0) num = 1, cnt = 1;
        else num = x, cnt = 0;
        return *this;
    }

    // Multiplies by x; a zero factor is only counted, never folded into num.
    ModInt& operator*=(int x) {
        x = norm(x);
        if (x == 0) ++cnt;
        else (num *= x) %= MOD;
        return *this;
    }

    // Removes a factor x previously multiplied in. Dividing by zero drops
    // one counted zero factor; no check is made that such a factor exists.
    ModInt& operator/=(int x) {
        x = norm(x);
        if (x == 0) --cnt;
        else (num *= inv[x]) %= MOD;
        return *this;
    }

    // Current value: 0 while any zero factor is present, otherwise num.
    int get() const { return cnt ? 0 : num; }
} f[MAXN][MAXM];
void addedge(int u, int v) {
edge[++tot] = (Edge) { v, head[u] };
head[u] = tot;
}
// In-place sum/difference butterfly transform (the XOR transform described
// in the notes above) over a[0..n), followed by reduction of every entry
// into [0, MOD). Assumes n is a power of two -- TODO confirm with callers.
void fwt(int *a, int n) {
    for (int half = 1; 2 * half <= n; half <<= 1) {
        const int span = half << 1;
        for (int start = 0; start < n; start += span) {
            for (int k = start; k < start + half; ++k) {
                int &lo = a[k];
                int &hi = a[k + half];
                const int s = lo + hi;
                const int d = lo - hi;
                lo = s;
                hi = d;
            }
        }
    }
    for (int k = 0; k < n; ++k) {
        int r = a[k] % MOD;
        if (r < 0) r += MOD;
        a[k] = r;
    }
}
// Evaluates a single entry x of the (unnormalised) sum/difference transform
// of a by recursing over the bits from n downwards; n is the highest bit to
// process (0 means "return a[x] directly"). The result is a plain signed
// sum: the caller is responsible for any modular reduction and scaling.
int ifwt(int *a, int n, int x) {
    if (!n) return a[x];
    const int lower = n >> 1;
    const int same = ifwt(a, lower, x);         // index with bit n unchanged
    const int flipped = ifwt(a, lower, x ^ n);  // index with bit n toggled
    return (x & n) ? flipped - same : same + flipped;
}
// First pass over the tree: records each vertex's parent and subtree size
// and picks as heavy child the first child with the largest subtree.
void dfs1(int u, int fa) {
    par[u] = fa;
    sz[u] += 1;
    for (int e = head[u]; e != 0; e = edge[e].next) {
        const int child = edge[e].to;
        if (child != fa) {
            dfs1(child, u);
            sz[u] += sz[child];
            const int best = wson[u];
            if (best == 0 || sz[child] > sz[best]) wson[u] = child;
        }
    }
}
// lnk[t] lists the vertices of the chain whose top is t, from top downwards.
vector<int> lnk[MAXN];

// Per segment-tree node and per transform index: pre = sum of prefix
// products, suf = sum of suffix products, mul = product of the whole range,
// sum = sum of products over all sub-intervals. Stored as short, which is
// wide enough for residues below MOD.
short pre[MAXT][MAXM];
short suf[MAXT][MAXM];
short mul[MAXT][MAXM];
short sum[MAXT][MAXM];
void pushup(int x) {
int l = ls[x], r = rs[x];
for (int k = 0; k < m; k++) {
pre[x][k] = (pre[l][k] + (int)pre[r][k] * mul[l][k]) % MOD;
suf[x][k] = (suf[r][k] + (int)suf[l][k] * mul[r][k]) % MOD;
mul[x][k] = (int)mul[l][k] * mul[r][k] % MOD;
sum[x][k] = (sum[l][k] + sum[r][k] + (int)suf[l][k] * pre[r][k]) % MOD;
}
}
// Builds the segment tree over chain positions [l, r] of v and returns the
// id of the node covering that range. Leaves take the current value of f
// for the vertex at that position.
int build(const vector<int> &v, int l, int r) {
    const int node = ++tot;
    if (l != r) {
        const int mid = (l + r) >> 1;
        ls[node] = build(v, l, mid);
        rs[node] = build(v, mid + 1, r);
        pushup(node);
        return node;
    }
    const int vertex = v[l];
    for (int k = 0; k < m; ++k) {
        const short leaf = static_cast<short>(f[vertex][k].get());
        sum[node][k] = leaf;
        mul[node][k] = leaf;
        suf[node][k] = leaf;
        pre[node][k] = leaf;
    }
    return node;
}
// Second pass: assigns u to the chain topped by t, computes f[u] from its
// own value and its light children, and, once a chain top is finished,
// builds that chain's segment tree and adds its total to ans.
void dfs2(int u, int fa, int t) {
    // Start from the transform of the single value val[u].
    for (int k = 0; k < m; ++k) f[u][k] = base[val[u]][k];

    // Append u to its chain, remembering its position there.
    top[u] = t;
    id[u] = lnk[t].size();
    lnk[t].push_back(u);

    // The heavy child continues the same chain and is visited first so the
    // chain is stored from top to bottom.
    const int heavy = wson[u];
    if (heavy) dfs2(heavy, u, t);

    for (int e = head[u]; e; e = edge[e].next) {
        const int child = edge[e].to;
        if (child == fa || child == heavy) continue;
        // A light child starts its own chain; fold its result into f[u].
        dfs2(child, u, child);
        const int childRoot = rt[child];
        for (int k = 0; k < m; ++k)
            f[u][k] *= pre[childRoot][k] + base[0][k];
    }

    if (u != t) return;
    rt[u] = build(lnk[u], 0, lnk[u].size() - 1);
    const int root = rt[u];
    for (int k = 0; k < m; ++k) ans[k] += sum[root][k];
}
// Refreshes the leaf for chain position x inside node p (covering [l, r])
// from the current f of that vertex, then recomputes every ancestor node.
void update(const vector<int> &v, int x, int l, int r, int p) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        if (x > mid) update(v, x, mid + 1, r, rs[p]);
        else update(v, x, l, mid, ls[p]);
        pushup(p);
        return;
    }
    const int vertex = v[l];
    for (int k = 0; k < m; ++k) {
        const short leaf = static_cast<short>(f[vertex][k].get());
        sum[p][k] = leaf;
        mul[p][k] = leaf;
        suf[p][k] = leaf;
        pre[p][k] = leaf;
    }
}
char opt[10];
int main() {
read(n, m);
inv[1] = 1;
for (int i = 2; i < MOD; i++)
inv[i] = MOD - MOD / i * inv[MOD % i] % MOD;
for (int i = 1; i <= n; i++) read(val[i]);
for (int i = 1; i < n; i++) {
int u, v; read(u, v);
addedge(u, v), addedge(v, u);
}
for (int i = 0; i < m; i++) {
base[i][i] = 1;
fwt(base[i], m);
}
dfs1(1, tot = 0);
dfs2(1, tot = 0, 1);
for (int j = 0; j < m; j++) ans[j] %= MOD;
read(Q);
while (Q--) {
reads(opt);
if (opt[0] == 'Q') {
int x; read(x); print((ifwt(ans, m >> 1, x) % MOD + MOD) * inv[m] % MOD);
} else {
int x, y; read(x, y);
for (int i = 0; i < m; i++) (f[x][i] /= base[val[x]][i]) *= base[y][i];
val[x] = y;
for (; x; x = par[x]) {
int t = par[top[x]], a = id[x];
x = top[x];
if (t > 0) for (int i = 0; i < m; i++)
f[t][i] /= pre[rt[x]][i] + base[0][i];
for (int i = 0; i < m; i++) ans[i] -= sum[rt[x]][i];
update(lnk[x], a, 0, lnk[x].size() - 1, rt[x]);
for (int i = 0; i < m; i++) ans[i] += sum[rt[x]][i];
if (t > 0) for (int i = 0; i < m; i++)
f[t][i] *= pre[rt[x]][i] + base[0][i];
}
for (int i = 0; i < m; i++) ans[i] = (ans[i] % MOD + MOD) % MOD;
}
}
ioflush();
return 0;
}