题意
有一棵每个点都具有权值 w i w_i wi的树,给你每个点的权值范围 [ l [ i ] , r [ i ] ] [l[i],r[i]] [l[i],r[i]], n − 1 n-1 n−1对异或关系,求满足上述条件的权值有多少种。
分析
通过
n
−
1
n-1
n−1对异或关系,可以得到任意两点之间的异或值。
可以先令
w
1
′
=
0
w'_1=0
w1′=0,通过
d
f
s
dfs
dfs计算出每个点的
w
i
′
w'_i
wi′,使用每个点的左右区间限制,以及1号点本身的左右区间,来求出满足条件的
w
1
w_1
w1的取值。
设
w
1
=
a
w_1=a
w1=a
w
i
=
w
i
′
⊕
a
w_i=w'_i\oplus a
wi=wi′⊕a
l
i
≤
w
i
≤
r
i
l_i\le w_i\le r_i
li≤wi≤ri
l
i
≤
w
i
′
⊕
a
≤
r
i
l_i \le w'_i\oplus a \le r_i
li≤wi′⊕a≤ri
S
e
g
i
=
[
l
[
i
]
,
r
[
i
]
]
⊕
w
i
′
Seg_i=[l[i],r[i]]\oplus w'_i
Segi=[l[i],r[i]]⊕wi′
当前区间的异或可能被拆成多个区间,可以使用线段树建树的方式,每次将当前区间分成左右两部分,一个子区间左端点就是当前位不变,后面位全填0,右端点就是当前位不变,后面位全填1,使用位运算实现。
之后需解决多个区间的交,排序后扫一遍就可得到答案。
代码
#include <iostream>
#include <string>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <iomanip>
#include <map>
#include <cstdio>
#include <stack>
#include <set>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int ,int> pii;
#define endl '\n'
ll gcd(ll a, ll b){
return b == 0 ? a : gcd(b, a % b);
}
void input(){
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
}
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*f;
}
const int N = 1e6+10, M = N * 2, inf = 1e8;
int n, L[N], R[N], W[N];
vector<pii> G[N];
vector<pii> seg, v;
// l r 是当前点限制的左右取值范围
// vl vr 是当前这段区间所代表的取值范围
// dep 表示深度,v是在w1取0的条件下当前点的取值
void build(int l, int r, int vl, int vr, int dep, int v){
if(l <= vl && vr <= r){
int nowl = (vl ^ v) & (((1<<30)-1) ^ ((1<<dep) - 1));
int nowr = nowl + (1<<dep) - 1;
seg.push_back({nowl, nowr});
}else{
int mid = (vl + vr) >> 1;
if(l <= mid) build(l, r, vl, mid, dep-1, v);
if(r > mid) build(l, r, mid+1, vr, dep-1, v);
}
}
// 使用dfs求解每个点的值
void dfs(int x, int fa, int val){
if(x != 1) build(L[x], R[x], 0, (1<<30)-1, 30, val);
for(auto i : G[x]){
if(i.first == fa) continue;
dfs(i.first, x, i.second^val);
}
}
int main(){
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
cin>>n;
for(int i = 1; i <= n; i++) cin>>L[i]>>R[i];
for(int i = 1; i < n; i++){
int u, v, w; cin>>u>>v>>w;
G[u].push_back({v, w});
G[v].push_back({u, w});
}
dfs(1, -1, 0);
seg.push_back({L[1], R[1]});
for(auto i : seg)
v.push_back({i.first, 1}), v.push_back({i.second + 1, -1});
sort(v.begin(), v.end());
int res = 0, sum = 0;
for(int i = 0; i < (int)v.size(); i++){
sum += v[i].second;
if(sum == n) res += v[i+1].first - v[i].first;
}
cout<<res<<endl;
return 0;
}