SPOJ COT Count on a tree
主席树,倍增LCA
题意
求树上A,B两点路径上第K小的数
思路
同样是可持久化线段树,只是这一次我们用它来维护树上的信息。
我们之前已经知道,可持久化线段树实际上是维护的一个前缀和,而前缀和不一定要出现在一个线性表上。
比如说我们从一棵树的根节点进行DFS,得到根节点到各节点的距离dist[x]——这是一个根-x路径上点与根节点距离的前缀和。
利用这个前缀和,我们可以解决一些树上任意路径的问题,比如在线询问[a,b]点对的距离——答案自然是dist[a]+dist[b]-2*dist[lca(a,b)]。
同理,我们可以利用可持久化线段树来解决树上任意路径的问题。
DFS遍历整棵树,然后在每个节点上建立一棵线段树,某一棵线段树的“前一版本”是位于该节点父亲节点fa的线段树。
利用与之前类似的方法插入点权(排序离散)。那么对于询问[a,b],答案就是root[a]+root[b]-root[lca(a,b)]-root[fa[lca(a,b)]]上的第k大。
代码
#include<bits/stdc++.h>
#define M(a,b) memset(a,b,sizeof(a))
typedef long long LL;
const int MAXN=100007;
using namespace std;
int cnt,root[MAXN],a[MAXN];
struct Node{int l, r, sum;}T[MAXN*40];
void update(int l,int r,int &x,int y,int pos,int c)
{
x=++cnt;
T[x]=T[y];T[x].sum+=c;
if(l==r) return;
int mid=(l+r)>>1;
if(pos<=mid) update(l,mid,T[x].l,T[y].l,pos,c);
else update(mid+1,r,T[x].r,T[y].r,pos,c);
}
int find_Kth(int l,int r,int rx,int ry,int rlca,int rflca,int k)
{
if(l>=r) return l;
int mid=(l+r)>>1;
int sum=T[T[rx].l].sum+T[T[ry].l].sum-T[T[rlca].l].sum-T[T[rflca].l].sum;
if(sum>=k) return find_Kth(l,mid,T[rx].l,T[ry].l,T[rlca].l,T[rflca].l,k);
else return find_Kth(mid+1,r,T[rx].r,T[ry].r,T[rlca].r,T[rflca].r,k-sum);
}
int val[MAXN];
int s[MAXN];
int hs[MAXN];
struct Edge
{
int to,ne;
}e[MAXN<<1];
int head[MAXN],ecnt;
void addedge(int from,int to)
{
ecnt++;e[ecnt].to=to,e[ecnt].ne=head[from];head[from]=ecnt;
ecnt++;e[ecnt].to=from,e[ecnt].ne=head[to];head[to]=ecnt;
}
int parent[MAXN][50] ,depth[MAXN];
void dfs(int n,int u,int la,int d)
{
parent[u][0]=la;
depth[u]=d;
update(1,n,root[u],root[la],hs[u],1);
for(int i=head[u];~i;i=e[i].ne)
{
if(e[i].to!=la)
{
dfs(n,e[i].to,u,d+1);
}
}
}
void init_lca(int n,int sz)
{
memset(depth, -1, sizeof depth);
for(int i=1; i<=n; i++)
if(depth[i]<0)
dfs(sz,i, 1, 0);
for(int k=0; k+1<30; k++)
{
for(int i=1; i<=n; i++)
{
if(parent[i][k]<0) parent[i][k+1] = -1;
else
{
parent[i][k+1] = parent[parent[i][k]][k];
}
}
}
}
LL lca(int u, int v)
{
if(depth[v] > depth[u]) swap(u, v);
int dis=depth[u]-depth[v],k=0;
while(dis)
{
if(dis&1)
u=parent[u][k];
dis>>=1;
k++;
}
k=0;
while (u!=v)
{
if ( parent[u][k]!= parent[v][k] || (parent[u][k]== parent[v][k] && k ==0) )
{
u=parent[u][k];
v=parent[v][k];
k++;
}
else k--;
}
return u;
}
int main()
{
int n,m;scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&val[i]),s[i]=val[i];
sort(s+1,s+n+1);
int sz=unique(s+1,s+n+1)-s-1;
for(int i=1;i<=n;i++)
hs[i]=lower_bound(s+1,s+sz+1,val[i])-s;
M(head,-1);ecnt=0;
for(int i=1;i<n;i++)
{
int a,b;scanf("%d%d",&a,&b);
addedge(a,b);
}
init_lca(n,sz);
parent[1][0]=0;
while(m--)
{
int a,b,k;scanf("%d%d%d",&a,&b,&k);
int l=lca(a,b);
//printf("%d %d %d %d\n",a,b,l,parent[l][0]);
//printf("%d %d\n",depth[a],depth[b]);
int res=find_Kth(1,sz,root[a],root[b],root[l],root[parent[l][0]],k);
printf("%d\n",s[res]);
}
}