查询树链第K大 。
每个版本的线段树维护的是 从这个节点到 根的 树链的版本, 由于树链第K大,在统计比X 小的数个数时 是可以 进行加减法运算的,所以 就可以用可持久化数据结构。
维护个数时 , sum = f(a) + f(b) - f(c) -f(d) : c 为 a,b 的最近公共祖先, d 为 c 的父亲节点。这样就是 四个版本运算。
同时:二分可以直接在树上跑,判断 左半区域的和 是否大于K,大于K 说明第K大的值 还在 左区间, 相反在右区间里查 第K -sum 大的数。
复杂度 O(nlgn) 如果直接二分区间 复杂度是O(nlgnlgn)。
倍增 LCA 算法:
const int K = 18;
int d[maxn];
int p[maxn][K];
void dfs(int rt,int f){
d[rt]=d[f]+1;
p[rt][0]=f;
int pos = mp[num[rt]];
root[rt] = update(pos,1,n,1,root[f]);
for(int i=1;i<K;i++) p[rt][i] = p[p[rt][i-1]][i-1];
for(int i=head[rt];i!=-1;i= edge[i].next){
int son = edge[i].v;
if(son==f)continue;
dfs(son,rt);
}
}
int lca(int a,int b){
if(d[a]>d[b]) swap(a,b);
if(d[a]<d[b]){
int del = d[b]-d[a];
for(int i=0;i<K;i++) if(del &(1<<i)) b= p[b][i];
}
if(a!=b){
for(int i= K-1;i>=0;i--){
if(p[a][i]!= p[b][i]){
a = p[a][i],b = p[b][i];
}
}
a= p[a][0],b = p[b][0];
}
return a;
}
代码:
#include <vector>
#include <list>
#include <map>
#include <set>
#include <deque>
#include <stack>
#include <cstring>
#include <bitset>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <assert.h>
#include <queue>
#define REP(i,n) for(int i=0;i<n;i++)
#define TR(i,x) for(typeof(x.begin()) i=x.begin();i!=x.end();i++)
#define ALLL(x) x.begin(),x.end()
#define SORT(x) sort(ALLL(x))
#define CLEAR(x) memset(x,0,sizeof(x))
#define FILLL(x,c) memset(x,c,sizeof(x))
using namespace std;
const double eps = 1e-9;
#define LL long long
#define pb push_back
const int maxn = 101000;
const int K = 18;
int n ,m ;
int num[maxn];
int d[maxn];
int p[maxn][K];
map<int,int>mp;
map<int,int>::iterator it;
int idx[maxn];
int head[maxn];
struct Edge{
int v;
int next;
}edge[2*maxn];
int tot;
void init(){
memset(head,-1,sizeof(head));
CLEAR(d);
CLEAR(p);
tot = 0;
}
void add(int u,int v){
tot ++;
edge[tot].v= v;
edge[tot].next = head[u];
head[u] = tot;
}
struct Node{
Node *l,*r;
int sum;
}nodes[maxn*40];
Node *root[maxn];
Node *null;
int C;
void inits(){
C= 0;
null = &nodes[C++];
root[0] = null;
null->l = null->r = null;
null->sum = 0;
}
Node *update(int pos,int left ,int right,int val,Node *root){
Node *rt = &nodes[C++];
rt->l = root->l;
rt->r = root->r;
rt->sum = root->sum;
if(left ==right){
rt->sum +=val;
return rt;
}
int mid =(left +right)/2;
if(pos<=mid){
rt ->l =update(pos,left,mid,val,root->l);
}else{
rt ->r = update(pos,mid+1,right,val,root->r);
}
rt->sum = rt->l->sum + rt->r->sum;
return rt;
}
int query(int k,int left ,int right,Node *rt,Node *rt2,Node *rt3,Node *rt4){
// cout << left << " lr "<<right<<endl;
if(left ==right){
return left;
}
int mid = (left +right)/2;
// cout <<rt->sum<<" "<< rt2->sum <<" "<<rt3->sum<<" "<<rt4->sum<<endl;
int sum = rt->l->sum + rt2->l->sum - rt3->l->sum - rt4->l->sum;
// cout << sum <<" sum k " << k << " "<<mid << endl;
if(sum>=k){
return query(k,left,mid,rt->l,rt2->l,rt3->l,rt4->l);
}else{
return query(k-sum,mid+1,right,rt->r,rt2->r,rt3->r,rt4->r);
}
}
int get(int a,int b,int c,int d,int k){
return query(k,1,n,root[a],root[b],root[c],root[d]);
}
void dfs(int rt,int f){
d[rt]=d[f]+1;
p[rt][0]=f;
int pos = mp[num[rt]];
root[rt] = update(pos,1,n,1,root[f]);
for(int i=1;i<K;i++) p[rt][i] = p[p[rt][i-1]][i-1];
for(int i=head[rt];i!=-1;i= edge[i].next){
int son = edge[i].v;
if(son==f)continue;
dfs(son,rt);
}
}
int lca(int a,int b){
if(d[a]>d[b]) swap(a,b);
if(d[a]<d[b]){
int del = d[b]-d[a];
for(int i=0;i<K;i++) if(del &(1<<i)) b= p[b][i];
}
if(a!=b){
for(int i= K-1;i>=0;i--){
if(p[a][i]!= p[b][i]){
a = p[a][i],b = p[b][i];
}
}
a= p[a][0],b = p[b][0];
}
return a;
}
void solve(){
init();
inits();
for(int i =1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs(1,0);
for(int i=1;i<=m;i++){
int a,b,k;
scanf("%d%d%d",&a,&b,&k);
int t1 = lca(a,b);
int t2 = p[t1][0];
int ans = get(a,b,t1,t2,k);
printf("%d\n",idx[ans]);
}
}
int main(){
while(~scanf("%d%d",&n,&m)){
mp.clear();
for(int i=1;i<=n;i++){
scanf("%d",&num[i]);
mp[num[i]] = 1;
}
int tot2 = 0;
for(it = mp.begin();it!=mp.end();it++){
tot2 ++ ;
it->second = tot2;
//cout << tot2 << " "<<it->first<<endl;
idx[tot2] = it->first;
}
solve();
}
return 0;
}