传送门
题意
给出一棵 n n n 个点组成的树,每个点权的取值范围是 [ l i , r i ] [ l_i , r_{i}] [li,ri]每条边权代表的是两点的异或值,现在问这棵树有多少种有效赋值
分析
我们如果设置一个跟节点
r
o
o
t
root
root,就可以求出跟节点到每一个节点的异或和
w
i
w_{i}
wi,这样处理完,我们只需要判断跟节点的值
x
x
x和每一个节点
w
i
w_{i}
wi,是否满足
l
i
≤
w
i
⊕
x
≤
r
i
l_{i} \leq w_{i} \oplus x \leq r_{i}
li≤wi⊕x≤ri,也就是求每一个区间合法解的并集
但问题是,一个区间异或上一个值之后,并不是一个区间,大家可以用
1
∼
10
⊕
3
1\sim10 \oplus 3
1∼10⊕3试一下,但是有一个很神奇的性质
我们可以利用
[
0
,
2
30
−
1
]
[0,2^{30-1}]
[0,230−1]的线段树, 把$ [L[i] , R[i]]$ 分成
O
(
l
o
g
W
)
O(logW)
O(logW) 个连续的区间,
每个区间的形式是 :
k
.
.
.
30
k...30
k...30 位相同, $ 0…k-1$ 位是
0
0
0 到
2
k
−
1
2^k-1
2k−1,;
这样的区间异或上
w
[
i
]
w[i]
w[i] 后仍然还是一个区间
这样处理完之后用差分操作求并集即可
代码
#pragma GCC optimize(3)
#include <bits/stdc++.h>
#define debug(x) cout<<#x<<":"<<x<<endl;
#define dl(x) printf("%lld\n",x);
#define di(x) printf("%d\n",x);
#define _CRT_SECURE_NO_WARNINGS
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef vector<int> VI;
const int INF = 0x3f3f3f3f;
const int N = 2e5 + 10;
const ll mod = 1000000007;
const double eps = 1e-9;
const double PI = acos(-1);
template<typename T>inline void read(T &a) {
char c = getchar(); T x = 0, f = 1; while (!isdigit(c)) {if (c == '-')f = -1; c = getchar();}
while (isdigit(c)) {x = (x << 1) + (x << 3) + c - '0'; c = getchar();} a = f * x;
}
int gcd(int a, int b) {return (b > 0) ? gcd(b, a % b) : a;}
int h[N],e[N],w[N],ne[N],idx;
int n;
int l[N],r[N];
vector<PII> v;
void add(int x,int y,int z){
ne[idx] = h[x],e[idx] = y,w[idx] = z,h[x] = idx++;
}
void build(int L,int R,int u,int d,int res){
if(L >= l[u] && R <= r[u]){
int vl = (L ^ res) & (((1 << 30) - 1) ^ ((1 << d) - 1));
int vr = vl + (1 << d) - 1;
v.pb({vl,1}),v.pb({vr + 1,-1});
return;
}
int mid = L + R >> 1;
if(l[u] <= mid) build(L,mid,u,d - 1,res);
if(r[u] > mid) build(mid + 1,R,u,d - 1,res);
}
void dfs(int u,int fa,int sum){
if(u != 1) build(0,(1 << 30) - 1,u,30,sum);
for(int i = h[u];~i;i = ne[i]){
int j = e[i];
if(j == fa) continue;
dfs(j,u,sum ^ w[i]);
}
}
int main() {
read(n);
memset(h,-1,sizeof h);
for(int i = 1;i <= n;i++) read(l[i]),read(r[i]);
for(int i = 1;i < n;i++){
int a,b,c;
read(a),read(b),read(c);
add(a,b,c),add(b,a,c);
}
dfs(1,0,0);
v.pb({l[1],1});v.pb({r[1] + 1,-1});
sort(all(v));
int res = 0,sum = 0;
for(int i = 0;i < v.size();i++){
sum += v[i].se;
if(sum == n){
res += v[i + 1].fi - v[i].fi;
}
}
di(res);
return 0;
}