原题链接:https://ac.nowcoder.com/acm/contest/11255/E
题意
有一棵树,每个节点都有一个权值 a [ i ] a[i] a[i]的取值范围 [ L [ i ] , R [ i ] ] [L[i],R[i]] [L[i],R[i]],每条边也有权值 w w w满足 a [ u ] ⨁ a [ v ] a[u] \bigoplus a[v] a[u]⨁a[v],询问方案数。
分析
我们知道如果确定一个点,那么所有点基本上就已经确定了。但枚举一个点的所有权值的话, O ( 2 30 ∗ n ) O(2^{30}*n) O(230∗n)复杂度肯定是不行的。
但如果我们先确定一个点为基准点,那么它与所有点的异或值都可以处理出来的,因此我们求的方案数其实就是所有点
[
L
[
i
]
,
R
[
i
]
]
⨁
a
′
[
i
]
]
[L[i],R[i]]\bigoplus a'[i]]
[L[i],R[i]]⨁a′[i]]集合取并集。还有一个问题,就是这个异或之后的区间并不是连续的,但不会超过log个区间(可以自己手算一下)。因此我们需要构造一些区间使得每次异或之后的答案是连续的,类似
[
x
x
x
0000
,
x
x
x
1111
]
[xxx0000, xxx1111]
[xxx0000,xxx1111]的区间(前缀相同)不论异或上任何值,最后的区间一定是连续的,例如
[
0
,
3
]
⨁
6
=
[
4
,
7
]
[0,3]\bigoplus6=[4,7]
[0,3]⨁6=[4,7]。很快我们发现这不就是权值线段树上节点的区间吗?
最后只要把所有区间扫一遍,利用扫描线的思想,当前值为n就统计区间长度。
这是题解中
n
l
o
g
n
2
nlogn^2
nlogn2的做法,有些偷懒了,当然还可以做到
n
l
o
g
n
nlogn
nlogn,只要在线段树上打一下懒惰标记,做一下区间加最后统计和为n的个数也是可以的。
Code
#include <bits/stdc++.h>
#define lowbit(i) i & -i
#define Debug(x) cout << x << endl
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef pair<ll, ll> PII;
const ll INF = 1e18;
const int N = 1e5 + 10;
const int M = 1e6 + 10;
const int MOD = 998244353;
int n, L[N], R[N];
struct Edge {
int to, next, w;
}e[N<<1];
int cnt, h[N], tot, rt, ls[N*40], rs[N*40], sum[N*40], tag[N*40];
vector<PII> ve;
void add(int u, int v, int w) {
e[cnt].to = v;
e[cnt].w = w;
e[cnt].next = h[u];
h[u] = cnt++;
}
void insert(int &now, int l, int r, int ql, int qr, int dep, int val) {
if (!now) now = ++tot;
if (ql <= l && qr >= r) {
int Left = (val ^ l) >> dep << dep;
int Right = Left + (1 << dep) - 1;
ve.push_back({Left, 1});
ve.push_back({Right+1, -1});
return;
}
int mid = (l + r) >> 1;
if (ql <= mid) insert(ls[now], l, mid, ql, qr, dep-1, val);
if (qr > mid) insert(rs[now], mid+1, r, ql, qr, dep-1, val);
}
void dfs(int x, int fa, int val) {
for (int i = h[x]; ~i; i = e[i].next) {
int v = e[i].to;
if (v == fa) continue;
insert(rt, 0, (1<<30)-1, L[v], R[v], 30, val ^ e[i].w);
dfs(v, x, val ^ e[i].w);
}
}
void solve() {
cin >> n;
memset(h, -1, sizeof h);
for (int i = 1; i <= n; i++) cin >> L[i] >> R[i];
for (int i = 1; i <= n-1; i++) {
int u, v, w; cin >> u >> v >> w;
add(u, v, w), add(v, u, w);
}
dfs(1, 0, 0);
ve.push_back({L[1], 1});
ve.push_back({R[1] + 1, -1});
sort(ve.begin(), ve.end());
int sum = 0;
int ans = 0;
for (int i = 0; i < ve.size(); i++) {
sum += ve[i].se;
if (sum == n && i + 1 < ve.size()) {
ans += ve[i+1].fi - ve[i].fi;
}
}
cout << ans << endl;
}
signed main() {
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#ifdef ACM_LOCAL
freopen("input", "r", stdin);
freopen("output", "w", stdout);
signed test_index_for_debug = 1;
char acm_local_for_debug = 0;
do {
if (acm_local_for_debug == '$') exit(0);
if (test_index_for_debug > 20)
throw runtime_error("Check the stdin!!!");
auto start_clock_for_debug = clock();
solve();
auto end_clock_for_debug = clock();
cout << "Test " << test_index_for_debug << " successful" << endl;
cerr << "Test " << test_index_for_debug++ << " Run Time: "
<< double(end_clock_for_debug - start_clock_for_debug) / CLOCKS_PER_SEC << "s" << endl;
cout << "--------------------------------------------------" << endl;
} while (cin >> acm_local_for_debug && cin.putback(acm_local_for_debug));
#else
solve();
#endif
return 0;
}