题目意思
给出一颗树,每个结点有一个值,求两节点之间第k小的值是谁
题目思路
一开始看到要两个结点之间的问题,准备用树链剖分,然后dfs序建主席树,但是想了一会发现这样建树之后,树链剖分的方法不好找到一个连续的区间来求第k小,多个区间也不晓得咋转移啥的,于是这个思路搁置。
后来发现思维被框在dfs序和普通树里面了。
因为主席树跟普通树区别是很大的,对于主席树我们每次建一棵树是可以任意选择连接哪一颗树的,并不是只能连接前一个树。
所以我们可以在遍历储存数据的树时,边建主席树。
这点想明白了后面的思路就全出来了。
对于x到y的区间,其实就是(x对应的主席树位置+y对应的主席树位置-x,y的lca的父亲节点对应的主席树位置-x,y的lca对应的主席树位置)。
那么我们求个x,y的lca,再根据主席树的前缀性质就能求出答案了
ac代码
const int maxn = 2e5+10;
const int inf = 1e9+10;
const ll llinf =1e18+10;
const ll mod = 1e9+7;
const double pi = acos(-1);
int a[maxn],first[maxn];
int dep[maxn],dp[maxn][20];
int n,m;
vector<int>vec;
struct node
{
int l,r,sum;
}hjt[maxn*40];
int cot,cnt,root[maxn];
int getid(int x)
{
return lower_bound(vec.begin(),vec.end(),x)-vec.begin()+1;
}
struct edge
{
int to,next;
}e[maxn];
void add(int u,int v)
{
e[++cnt].to=v;
e[cnt].next=first[u];
first[u]=cnt;
}
void inserts(int l,int r,int pre,int &now,int p)
{
hjt[++cot]=hjt[pre];
now=cot;
hjt[now].sum++;
if(l==r)return;
int mid=(l+r)>>1;
if(p<=mid)
inserts(l,mid,hjt[pre].l,hjt[now].l,p);
else
inserts(mid+1,r,hjt[pre].r,hjt[now].r,p);
}
int query(int l,int r,int L,int R,int lca,int falca,int k)
{
if(l==r)return l;
int mid=(l+r)>>1;
int tem=hjt[hjt[L].l].sum+hjt[hjt[R].l].sum-hjt[hjt[lca].l].sum-hjt[hjt[falca].l].sum;
if(tem>=k)
return query(l,mid,hjt[L].l,hjt[R].l,hjt[lca].l,hjt[falca].l,k);
else
return query(mid+1,r,hjt[L].r,hjt[R].r,hjt[lca].r,hjt[falca].r,k-tem);
}
void dfs(int x,int fa,int d)
{
dp[x][0]=fa,dep[x]=d;
inserts(1,n,root[fa],root[x],getid(a[x]));
for(int i=first[x];i;i=e[i].next)
{
int to=e[i].to;
if(to==fa)continue;
dfs(to,x,d+1);
}
}
void init()
{
for(int j = 0;(1 << (j+1)) < n;j++)
{
for(int i = 1;i <= n;i++)
{
if(dp[i][j] < 0) dp[i][j+1] = -1;
else dp[i][j+1] = dp[dp[i][j]][j];
}
}
}
int LCA(int u,int v)
{
if(dep[u] > dep[v]) swap(u,v);
int temp = dep[v] - dep[u];
for(int i = 0;(1 << i) <= temp;i++)
{
if((1<<i) & temp) v = dp[v][i];
}
if(v == u) return u;
for(int i = log(n*1.0) / log(2.0);i >= 0;i--)
{
if(dp[u][i] != dp[v][i])
u = dp[u][i],v = dp[v][i];
}
return dp[u][0];
}
signed main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&a[i]),vec.push_back(a[i]);
sort(vec.begin(),vec.end());
vec.erase(unique(vec.begin(),vec.end()),vec.end());
for(int i=1,u,v;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs(1,0,0);
init();
while(m--)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
int lca=LCA(x,y);
printf("%d\n",vec[query(1,n,root[x],root[y],root[lca],root[dp[lca][0]],z)-1]);
}
}