题意
一棵树,每个点有一个点权。
没有修改,只有一种询问: ( x , y ) (x,y) (x,y)路径上的点不能选,在其他点之中选出一些,使得选出来的点点权异或和为 k k k,问有多少种选法。
思路
这题有好多好像很厉害的套路。。。
- 线性基方案数
首先是线性基,假如总共有 n n n个元素,线性基的大小为 k k k,而且线性基能够异或得到 x x x,那么在这 n n n个数中选出一个集合使得异或和为 x x x的方案数就是 2 n − k 2^{n-k} 2n−k。
实际上好像很显然的样子,枚举剩下的选还是不选就好了。
- dfs序
然后是dfs序。如何搞出除了某个点 u u u到根的路径之外的所有点的线性基?做两遍dfs,得到两个dfs序。假如第一遍dfs搜索儿子的顺序是 s o n 1 , s o n 2 , . . . , s o n n son_1,son_2,...,son_n son1,son2,...,sonn,那么第二遍就是反过来 s o n n , s o n n − 1 , . . . , s o n 1 son_n,son_{n-1},...,son_1 sonn,sonn−1,...,son1。然后假设 u u u在两的dfs序中的位置分别是 i d 1 id_1 id1和 i d 2 id_2 id2,那么我们所要求的线性基就是第一个dfs序中 i d 1 + 1 id_1+1 id1+1到 n n n的所有点的线性基和第2个dfs序中 i d 2 + 1 id_2+1 id2+1到 n n n的所有点的线性基的合并。
好像也挺显然的。但是这是一种挺有趣的思路。对于树上维护信息来说,在常用的树链剖分点分树之外,又是一种可以考虑的做法。
- 正式开始
然后开始试图处理询问,处理的方法实质上与上面dfs序的方法类似,也是在正反两个dfs序上做。
首先规定t1.dfn(i)
表示在
t
1
t1
t1这个dfs序中,第
i
i
i个是哪个点。t1.rfn(i)
表示在
t
1
t1
t1这个dfs序中,编号为
i
i
i的点在序列中的哪个位置。
我们设两个dfs序分别为
t
1
t1
t1和
t
2
t2
t2,并假设在
t
1
t1
t1中dfn(x)>dfn(y)
。
那么除了 ( x , y ) (x,y) (x,y)之外的所有点可以由下面这些点合并而来:
t2.dfn(x)+1~t2.n
t1.dfn(y)+1~t1.n
这两个是在
t
1
t1
t1或者
t
2
t2
t2中[dfn(x),dfn(y)]
区间之外的点。
下面是[dfn(x),dfn(y)]
区间之间的点:
假设 z z z是 y y y沿着父亲向上爬爬到深度恰好比lca大1的点。
t1.dfn(x),t1+1,dfn(z)-1
表示
t
1
t1
t1中x
之后,lca
的包含y
的子树之前。
显然还有一部分就是在z
的子树内部,y
之前,且不包含路径(z,y)
:
t2.dfn(y)+1, t2.end(z)
最后加上lca
到根的那段,大功告成。
- 猫树
像上面这样做,询问 O ( log n log 2 k ) O(\log n \log^2k) O(lognlog2k),显然不行。
因为没有修改,可以使用猫树,用预处理的时间复杂度来换快速的查询,最终复杂度是 O ( n log n log k + q log 2 k ) O(n\log n \log k + q \log^2k) O(nlognlogk+qlog2k),分别是预处理和询问。
代码
因为一直调不出来,又因为机房令人燥热难安,所以代码写的实在是有点难看。
但是跑得挺快的。我常数小我骄傲。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int mod = 998244353;
const int N = 2e5 + 10, K = 32768, M = N << 1;
const int EN = 20, EK = 15;
namespace G
{
int h[N], ecnt, nxt[M], v[M];
void clear(){
ecnt = 1;
}
void add_dir(int _u, int _v){
v[++ecnt] = _v;
nxt[ecnt] = h[_u]; h[_u] = ecnt;
}
void add_undir(int _u, int _v){
add_dir(_u, _v);
add_dir(_v, _u);
}
}
using namespace G;
struct Bas{
short d[EK], siz;
};
Bas tmpbas;
Bas empty(){
tmpbas.siz = 0;
memset(tmpbas.d, 0, sizeof(tmpbas.d));
return tmpbas;
}
Bas insert(Bas v, int x){
Bas u = v;
for (int i = EK-1; i >= 0; -- i)
if ((x >> i) & 1){
if (u.d[i]){
x ^= u.d[i];
if (x == 0) break;
}
else{
u.d[i] = x;
u.siz++;
break;
}
}
return u;
}
bool check(Bas u, int x){
for (int i = EK-1; i >= 0; -- i)
if ((x >> i) & 1){
if (u.d[i]){
x ^= u.d[i];
if (x == 0) break;
}
else return false;
}
return true;
}
Bas merge(Bas u, Bas w){
Bas v = w;
for (int i = 0; i < EK; ++ i)
if (u.d[i]){
int x = u.d[i];
for (int j = EK-1; j >= 0; -- j)
if ((x >> j) & 1){
if (v.d[j]){
x ^= v.d[j];
if (x == 0) break;
}
else{
v.d[j] = x;
v.siz++;
break;
}
}
}
return v;
}
int Log[N << 2];
int pos[N];
struct Cat_tr{
int dfn[N], rfn[N], end[N], idx;
Bas f[EN][N];
void build(int d, int u, int l, int r, int *a){
if (l == r){
f[d][l] = insert(empty(), a[dfn[l]]);
pos[l] = u;
return;
}
int mid = l + r >> 1;
f[d][mid] = insert(empty(), a[dfn[mid]]);
for (int i = mid-1; i >= l; -- i)
f[d][i] = insert(f[d][i+1], a[dfn[i]]);
f[d][mid+1] = insert(empty(), a[dfn[mid+1]]);
for (int i = mid+2; i <= r; ++ i)
f[d][i] = insert(f[d][i-1], a[dfn[i]]);
build(d+1, u<<1, l, mid, a);
build(d+1, u<<1^1, mid+1, r, a);
}
int lcp(int x, int y){
int nx = pos[x];
int ny = pos[y];
while (Log[nx] > Log[ny]) ny <<= 1;
while (Log[nx] < Log[ny]) nx <<= 1;
return Log[nx] - Log[nx ^ ny] - 1; // 这里的-1不能漏,在写的时候一定要三思
}
Bas query(int x, int y){
if (x > y) return empty();
int d = lcp(x, y);
return merge(f[d][x], f[d][y]);
}
};
int n, m, a[N], f[N][EN], dpt[N];
Bas too[N], sum[N], mus[N];
Cat_tr t, r;
template<class T> void read(T &x){
x = 0; bool fl = 0; char c = getchar();
while (!isdigit(c)){if (c == '-') fl = 1; c = getchar();}
while (isdigit(c)){x = (x<<3)+(x<<1)+c-'0'; c = getchar();}
if (fl) x = -x;
}
int fpow(int x, int y, int p){
int r = 1;
while (y){
if (y&1) r = 1LL*r*x%p;
x = 1LL*x*x%p;
y >>= 1;
}
return r;
}
void dfs1(int u, int fa)
{
too[u] = insert(too[fa], a[u]);
t.dfn[++t.idx] = u;
t.rfn[u] = t.idx;
dpt[u] = dpt[fa] + 1;
f[u][0] = fa;
for (int i = 1; i < EN; ++ i)
f[u][i] = f[f[u][i-1]][i-1];
for (int i = h[u]; i; i = nxt[i]){
if (v[i] == fa) continue;
dfs1(v[i], u);
}
}
vector<int> tmp[N];
void dfs2(int u, int fa)
{
r.dfn[++r.idx] = u;
r.rfn[u] = r.end[u] = r.idx;
for (int i = h[u]; i; i = nxt[i]){
if (v[i] == fa) continue;
tmp[u].push_back(v[i]);
}
int n3 = tmp[u].size();
for (int i = n3-1; i >= 0; -- i){
dfs2(tmp[u][i], u);
r.end[u] = max(r.end[u], r.end[tmp[u][i]]);
}
}
void print(Bas u)
{
for (int i = 0; i < EK; ++ i)
cout << u.d[i] << " ";
cout << endl;
}
int Lca(int x, int y)
{
if (dpt[x] < dpt[y]) swap(x, y);
for (int i = EN-1; i >= 0; -- i)
if (dpt[f[x][i]] >= dpt[y])
x = f[x][i];
if (x == y) return x;
for (int i = EN-1; i >= 0; -- i)
if (f[x][i] != f[y][i])
x = f[x][i], y = f[y][i];
return f[x][0];
}
int Lca_son(int x, int y)
{
if (dpt[x] < dpt[y]) swap(x, y);
for (int i = EN-1; i >= 0; -- i)
if (dpt[f[x][i]] > dpt[y])
x = f[x][i];
return x;
}
int main()
{
read(n); read(m);
for (int i = 1; i < n; ++ i){
int x, y;
read(x); read(y);
add_undir(x, y);
}
for (int i = 1; i <= n; ++ i)
read(a[i]);
Log[0] = Log[1] = 0;
for (int i = 2; i <= n*4; ++ i)
Log[i] = Log[i>>1] + 1;
t.idx = r.idx = 0;
t.dfn[0] = r.dfn[0] = 0; a[0] = 0;
too[0] = empty();
dfs1(1, 0);
dfs2(1, 0);
t.build(0, 1, 1, n, a); r.build(0, 1, 1, n, a);
for (; m--; ){
int x, y, w, z, k, cnt;
read(x); read(y); read(k);
if (t.rfn[x] > t.rfn[y]) swap(x, y);
w = Lca(x, y);
z = Lca_son(y, w);
cnt = dpt[x] + dpt[y] - 2 * dpt[w] + 1;
Bas ans;
if (w == x){ // 这种情况不能不特判,很显然不特判会导致(x,y)路径整条被加到线性基里面
ans = merge(too[f[x][0]], merge(t.query(t.rfn[y]+1, n), r.query(r.rfn[y]+1, n)));
}
else{
Bas ans1 = r.query(r.rfn[x]+1, n);
Bas ans2 = t.query(t.rfn[y]+1, n);
Bas ans3 = t.query(t.rfn[x]+1, t.rfn[z]-1);
Bas ans4 = r.query(r.rfn[y]+1, r.end[z]);
ans = merge(merge(ans1, ans2), merge(ans3, merge(ans4, too[f[w][0]])));
}
if (check(ans, k)) printf("%d\n", fpow(2, n - cnt - ans.siz, mod));
else puts("0");
}
return 0;
}