2021牛客暑期多校训练营4 E.Tree Xor
原题地址
题目大意:
给出n个节点的树,以及相邻节点的异或值,现在要给n个节点填上对应的权值,每个节点的权值需要在
[
l
i
,
r
i
]
[l_i,r_i]
[li,ri]以内。
1
≤
n
≤
1
e
5
,
0
≤
l
i
≤
r
i
<
2
30
,
0
≤
w
u
⊕
w
v
<
2
30
1\leq n \leq 1e5,0 \leq l_i\leq r_i < 2^{30},0 \leq w_u \oplus w_v < 2^{30}
1≤n≤1e5,0≤li≤ri<230,0≤wu⊕wv<230
注:w表示树上节点的权值。
题解:
容易发现,只要树上一个节点的权值确定,所有节点的权值都是唯一确定的。
选取1号节点作为根节点。
令 w [ 1 ] = 0 w[1]=0 w[1]=0 时可以确定树上其他节点的权值为 w [ i ] w[i] w[i]( 2 ≤ i ≤ n 2\leq i \leq n 2≤i≤n)
这时再令 w [ 1 ] = a w[1]=a w[1]=a ,那么树上其他节点的权值就会变成 w [ i ] ⊕ a w[i] \oplus a w[i]⊕a
那么原问题就转换成了以下n个不等式的解的数量:
l
1
≤
a
≤
r
1
l_1 \leq a\leq r_1
l1≤a≤r1
l 2 ≤ a ⊕ w [ 2 ] ≤ r 2 l_2 \leq a\oplus w[2]\leq r_2 l2≤a⊕w[2]≤r2
l 3 ≤ a ⊕ w [ 3 ] ≤ r 3 l_3 \leq a\oplus w[3]\leq r_3 l3≤a⊕w[3]≤r3
. . . . . . ...... ......
l n ≤ a ⊕ w [ n ] ≤ r n l_n \leq a\oplus w[n]\leq r_n ln≤a⊕w[n]≤rn
子问题:有多少个数异或 x小于y?(或者有多少个数异或x大于y?)
该子问题在2021杭电多校第二场1004.I love counting_zrr12138的博客-CSDN博客这篇博客中有非常详细的求解办法。
回到原问题:
我们容易发现问题转换成了求 2 ∗ n 2*n 2∗n个区间的交集问题,麻烦的是这 2 ∗ n 2*n 2∗n个区间并不是连续的,每个区间最多是30个小区间的并集。
如果直接做区间运算非常容易超时,因为相邻两个区间的小区间要暴力求交,并且需要合并,合并时需要排序。这样一来复杂度很好,我试图各种玄学优化卡过去以失败告终。
简单证明这样搞的复杂度为:(并且常数还比较大)
O
(
2
∗
n
∗
30
∗
30
∗
l
o
g
30
)
O(2*n*30*30*log30)
O(2∗n∗30∗30∗log30)
我们和以上提到的博客中一样不难发现,每个区间有以下性质:
- 区间内的子区间两两互不相交
- 每个子区间在二进制下的某位之前全部相同,之后遍历全0到全1.例如101000-101111
考虑以上性质,我们将查询的子区间的前缀(二进制下前面相同的部分,上述例子中为101)插入到字典树上,并且打上标记,标记的含义为以该节点向下的所有叶子节点(第30层的,不一定在字典树中有记录,但是可以统计)都在区间内。
最后进行一次询问,在字典树上从上到下进行遍历,累加标记数量(同一个节点可以打多次标记),累加标记数量的含义为以该节点向下的所有叶子节点(第30层的,不一定在字典树中有记录,但是可以统计)被区间覆盖了多少次。
当标记数量达到 2 ∗ n 2*n 2∗n时(表示 2 ∗ n 2*n 2∗n区间都覆盖了,就是要求的交集)直接统计结果返回,否则就遍历到叶子节点(有记录的)。
时间复杂度为
O
(
2
∗
n
∗
30
∗
30
)
O(2*n*30*30)
O(2∗n∗30∗30)
AC代码:
#include <bits/stdc++.h>
#define debug(x) cout<<#x"=" <<x<<'\n';
//#define debug(x) ;
using ll=long long;
using namespace std;
ll gcd(ll a,ll b)
{
return b?gcd(b,a%b):a;
}
template<typename T> void read(T &x)
{
x = 0;
char ch = getchar();
ll f = 1;
while(!isdigit(ch))
{
if(ch == '-')f*=-1;
ch=getchar();
}
while(isdigit(ch))
{
x = x*10+ch-48;
ch=getchar();
}
x*=f;
}
ll quick_pow(ll a, ll b,ll mod)
{
ll ans = 1, base = a;
while(b)
{
if(b&1)
ans=ans*base%mod;
base = base*base%mod;
b >>= 1;
}
return ans;
}
const int M=1e5+10;
int l[M],r[M],w[M];
vector<pair<int,int> >vec[M];
void dfs(int x,int fa,int val)
{
w[x]=val;
for(auto it:vec[x])
{
if(it.first==fa) continue;
dfs(it.first,x,val^it.second);
}
}
int trie[M*100][2];
int ed[M*100];
int tot=0;
const int mbit=29;
void add(const int &x,const int &j)
{
int p=0;
for(int i=mbit; i>=j; i--)
{
int t=((x>>i)&1);
if(trie[p][t]==0) trie[p][t]=++tot;
p=trie[p][t];
}
ed[p]++;
}
int res=0;
int n;
void ask(int p,int val,int deep)
{
int l=trie[p][0];
int r=trie[p][1];
val+=ed[p];
if(val==2*n)
{
res+=(1<<(deep+1));
return ;
}
if(l) ask(l,val,deep-1);
if(r) ask(r,val,deep-1);
}
int main()
{
read(n);
for(int i=1; i<=n; i++)
{
read(l[i]),read(r[i]);
}
for(int i=1; i<n; i++)
{
int x,y,w;
read(x),read(y),read(w);
vec[x].push_back(make_pair(y,w));
vec[y].push_back(make_pair(x,w));
}
w[1]=0;
for(auto it:vec[1])
{
dfs(it.first,1,it.second);
}
for(int i=1; i<=n; i++)
{
int now=0;
for(int j=mbit; j>=0; j--) //考虑w[i] ^ x <= r[i]
{
if(r[i]>>j&1)
{
if(w[i]>>j&1)
{
add(now+(1LL<<j),j);
}
else
add(now,j);
}
if(((r[i]>>j&1)==0) ^ ((w[i]>>j &1)==0)) now|=(1<<j);
}
add(w[i]^r[i],0);
now=0;
for(int j=mbit; j>=0; j--) //考虑w[i] ^ x >= l[i]
{
if((l[i]>>j& 1)==0)
{
if(w[i]>>j &1)
{
add(now,j);
}
else
add(now+(1LL<<j),j);
}
if(((l[i]>>j&1)==0) ^ ((w[i]>>j &1)==0)) now|=(1<<j);
}
add(w[i]^l[i],0);
}
ask(0,0,mbit);
printf("%d\n",res);
return 0;
}
/*
2
0 3
2 2
1 2 0
*/