传送门:Hdu 5589 Tree
题意:
一棵树有N个节点,编号为1到N,每条边都有边权。定义f(u,v)为从u到v路径上所有边权的异或和。给定一个数M,有Q次查询,每次给定一个区间[l,r],询问有多少对(u,v)满足f(u,v)>M (l≤ u< v≤r).
思路:
f[u]表示u到根的边的异或
树上两点之间的异或值为f[u]^f[v],
然后将查询用莫队算法分块,每个点插入到字典树中,利用字典树维护两点异或值大于等于M复杂度O(N^(3/2)*logM)
#include<bits/stdc++.h>
using namespace std;
const int maxn=50010;
typedef pair<int,int> PI;
vector<PI>G[maxn];
int a[maxn],unit,q;
bitset <20> cnt;
int vis[maxn];
long long ans[maxn];
struct node{
int l,r,id;
}Q[maxn];
bool cmp(node u,node v){
if(u.l/unit!=v.l/unit)
return u.l/unit<v.l/unit;
return u.r<v.r;
}
void bfs(int u){
queue<int>Q;
Q.push(1);
memset(vis,0,sizeof(vis));
vis[1]=1;
while(!Q.empty()){
int u=Q.front();
Q.pop();
for(int i=0;i<G[u].size();i++){
int v=G[u][i].first;
if(vis[v])
continue;
a[v]=a[u]^G[u][i].second;
Q.push(v);
vis[v]=1;
}
}
}
struct Trie{
int val[maxn*20],next[maxn*20][2];
int sz;
void init(){
sz=1;
memset(next[0],0,sizeof(next[0]));
}
void insert(int num,int x){
int u=0,c;
for(int i=17;i>=0;i--){
if((1<<i)&num)
c=1;
else
c=0;
if(!next[u][c])
memset(next[sz],0,sizeof(next[sz])),val[sz]=0,next[u][c]=sz++;
u=next[u][c];
val[u]+=x;
}
}
int query(int num){
int ans=0,c,u=0;
for(int i=17;i>=0;i--){
if((1<<i)&num)
c=1;
else
c=0;
if(cnt[i]==0){
if(next[u][c^1])
ans+=val[next[u][c^1]];
if(!next[u][c]||val[next[u][c]]==0)
return ans;
u=next[u][c];
}
else{
if(!next[u][c^1]||val[next[u][c^1]]==0)
return ans;
u=next[u][c^1];
}
}
return ans;
}
};
Trie trie;
void solve(){
int L=1,R=0;
trie.init();
long long tmp=0;
for(int i=1;i<=q;i++){
while(R<Q[i].r){
R++;
tmp+=trie.query(a[R]);
trie.insert(a[R],1);
}
while(R>Q[i].r){
trie.insert(a[R],-1);
tmp-=trie.query(a[R]);
R--;
}
while(L<Q[i].l){
trie.insert(a[L],-1);
tmp-=trie.query(a[L]);
L++;
}
while(L>Q[i].l){
L--;
tmp+=trie.query(a[L]);
trie.insert(a[L],1);
}
ans[Q[i].id]=tmp;
}
}
int main(){
int n,m;
while(scanf("%d%d%d",&n,&m,&q)!=EOF){
int u,v,w;
for(int i=1;i<=n;i++)
G[i].clear();
for(int i=1;i<n;i++){
scanf("%d%d%d",&u,&v,&w);
G[u].push_back(PI{v,w});
G[v].push_back(PI{u,w});
}
cnt=m;
a[1]=0,unit=sqrt(n);
bfs(1);
for(int i=1;i<=q;i++){
scanf("%d%d",&Q[i].l,&Q[i].r);
Q[i].id=i;
}
sort(Q+1,Q+q+1,cmp);
solve();
for(int i=1;i<=q;i++)
printf("%lld\n",ans[i]);
}
return 0;
}