(这篇文章是模板向…了解具体思想还是看网上其他详细讲解吧QAQ)
LCA,即最近公共祖先,是在有根树中两个点最近的公共祖先,在树上问题中非常有用QAQ
常用LCA求法:
一、树链剖分LCA
树链剖分LCA,顾名思义,就是用树链剖分求LCA,两个节点从各自的重链往上跳,跳到一条重链上LCA就为上面的那个点
复杂度
O(n−qlog(n))
模板代码:(*题为Codevs 2370)
#include<algorithm>
#include<ctype.h>
#include<cstdio>
#define N 80050
using namespace std;
inline int read(){
int x=0,f=1;char c;
do {c=getchar();if(c=='-') f=-1;} while(!isdigit(c));
do x=(x<<3)+(x<<1)+c-'0',c=getchar(); while(isdigit(c));
return x*f;
}
int n,m,x,y,k,Top;
int fir[N],top[N],son[N],size[N],dep[N],dis[N],fa[N];
struct Edge{
int to,nex,k;
Edge(int _=0,int __=0,int ___=0):to(_),nex(__),k(___){}
}nex[N<<1];
inline void add(int x,int y,int k){
nex[++Top]=Edge(y,fir[x],k);
fir[x]=Top;
}
void dfs1(int x,int Dep,int Fa,int Dis){//剖1
dep[x]=Dep;fa[x]=Fa;dis[x]=Dis;
size[x]=1;
for(int i=fir[x];i;i=nex[i].nex){
if(nex[i].to==Fa) continue;
dfs1(nex[i].to,Dep+1,x,Dis+nex[i].k);
size[x]=size[x]+size[nex[i].to];
if(size[son[x]]<size[nex[i].to]) son[x]=nex[i].to;
}
}
void dfs2(int x,int Top){//剖2
top[x]=Top;
if(!son[x]) return;
dfs2(son[x],Top);
for(int i=fir[x];i;i=nex[i].nex){
if(nex[i].to==fa[x] || nex[i].to==son[x]) continue;
dfs2(nex[i].to,nex[i].to);
}
}
inline int Query(int x,int y){//往上跳
int sum=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
sum=sum+dis[x]-dis[fa[top[x]]];
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
sum=sum+dis[y]-dis[x];
return sum;
}
int main(){
n=read();
for(int i=1;i<n;i++){
x=read()+1;y=read()+1;k=read();
add(x,y,k);add(y,x,k);
}
dfs1(1,1,0,0);dfs2(1,1);
m=read();
for(int i=1;i<=m;i++){
x=read()+1;y=read()+1;
printf("%d\n",Query(x,y));
}
return 0;
}
二、倍增LCA
考虑一个比较暴力的做法,即先让深的点慢慢爬到和另一个点高度相同,然后一起往上爬值到爬到一起,显然这样太暴力了,倍增LCA就是对其的一种优化
在每次爬的过程中,这次爬
2k
,下次爬
2k−1
,直到爬到一起,显然这种优化效果是非常可观的
复杂度
O(n−q∗log(n))
模板代码如下:
#include<algorithm>
#include<ctype.h>
#include<cstdio>
#define N 80050
using namespace std;
inline int read(){
int x=0,f=1;char c;
do {c=getchar();if(c=='-') f=-1;} while(!isdigit(c));
do x=(x<<3)+(x<<1)+c-'0',c=getchar(); while(isdigit(c));
return x*f;
}
int n,m,x,y,k,Top;
int fir[N],f[25][N],s[25][N],dep[N];
struct Edge{
int to,nex,k;
Edge(int _=0,int __=0,int ___=0):to(_),nex(__),k(___){}
}nex[N<<1];
inline void add(int x,int y,int k){
nex[++Top]=Edge(y,fir[x],k);
fir[x]=Top;
}
void dfs(int x,int fa,int dis,int Dep){
f[0][x]=fa;s[0][x]=dis;dep[x]=Dep;
for(int i=fir[x];i;i=nex[i].nex){
if(nex[i].to==fa) continue;
dfs(nex[i].to,x,nex[i].k,Dep+1);
}
}
inline int Query(int x,int y){
int sum=0;
if(dep[x]<dep[y]) swap(x,y);
for(int i=18;i>=0;i--)
if(dep[f[i][x]]>=dep[y]){
sum=sum+s[i][x];
x=f[i][x];
}
if(x==y) return sum;//跳到一起就不跳了
for(int i=18;i>=0;i--)
if(f[i][x]!=f[i][y]){
sum=sum+s[i][x]+s[i][y];
x=f[i][x];y=f[i][y];
}
return sum+s[0][x]+s[0][y];
}
int main(){
n=read();
for(int i=1;i<n;i++){
x=read()+1;y=read()+1;k=read();
add(x,y,k);add(y,x,k);
}
dfs(1,0,0,1);
for(int j=1;j<=17;j++)
for(int i=1;i<=n;i++){
f[j][i]=f[j-1][f[j-1][i]];
s[j][i]=s[j-1][f[j-1][i]]+s[j-1][i];
}
m=read();
for(int i=1;i<=m;i++){
x=read()+1;y=read()+1;
printf("%d\n",Query(x,y));
}
return 0;
}
三、ST表
ST表是基于dp上的一种rmp算法..大概就是要求x和y的LCA为欧拉序中x与y中深度最小的那个点,具体证明还是看网上神犇们的题解吧orz…
复杂度
O(n∗log(n)−q)
模板代码如下:
#include<algorithm>
#include<ctype.h>
#include<cstdio>
#include<cmath>
#define N 800050
using namespace std;
inline int read(){
int x=0,f=1;char c;
do {c=getchar();if(c=='-') f=-1;} while(!isdigit(c));
do x=(x<<3)+(x<<1)+c-'0',c=getchar(); while(isdigit(c));
return x*f;
}
int n,m,x,y,k,top;
int fir[N],s[N<<1],dep[N],fa[N],loc[N],dis[N];
int f[25][N];
struct Edge{
int to,nex,k;
Edge(int _=0,int __=0,int ___=0):to(_),nex(__),k(___){}
}nex[N<<1];
inline void add(int x,int y,int k){
nex[++top]=Edge(y,fir[x],k);
fir[x]=top;
}
void dfs(int x,int Dep,int Fa,int Dis){
fa[x]=Fa;dis[x]=Dis;
s[++top]=x;dep[top]=Dep;
loc[x]=top;
for(int i=fir[x];i;i=nex[i].nex){
if(nex[i].to==Fa) continue;
dfs(nex[i].to,Dep+1,x,Dis+nex[i].k);
s[++top]=x;dep[top]=Dep;
}
}
inline void init(){
for(int i=1;i<=top;i++) f[0][i]=i;
for(int j=1;j<=18;j++)
for(int i=1;i<=top;i++){
if(dep[f[j-1][i]]>dep[f[j-1][i+(1<<(j-1))]])
f[j][i]=f[j-1][i+(1<<(j-1))];
else f[j][i]=f[j-1][i];
}
}
inline int rmq(int l,int r){
int tmp=0;
while((1<<(tmp+1))<=r-l+1)///手动log2
tmp++;
if(dep[f[tmp][l]]>dep[f[tmp][r-(1<<tmp)+1]]) return f[tmp][r-(1<<tmp)+1];
return f[tmp][l];//这个预处理的过程神烦
}
inline int Query(int x,int y){
if(loc[x]>loc[y]) swap(x,y);
return s[rmq(loc[x],loc[y])];
}
int main(){
n=read();
for(int i=1;i<n;i++){
x=read()+1;y=read()+1;k=read();
add(x,y,k);add(y,x,k);
}
top=0;
dfs(1,1,0,0);
init();
m=read();
for(int i=1;i<=m;i++){
x=read()+1;y=read()+1;
printf("%d\n",dis[x]+dis[y]-dis[Query(x,y)]*2);
}
return 0;
}
四、Tarjan求LCA
又是Tarjan…
Tarjan求LCA是这几个算法中唯一一个离线算法,复杂度极优
算法思想大致是随便从一个点走遍历所有子树后把所有子树的fa都改为这个点,如果遍历完x时y已经走过了,则他俩的LCA就是y的祖先(find(y))
复杂度
O(n+q)
,顺便问候Tarjan及他工作室的小伙伴
模板代码如下:
#include<ctype.h>
#include<cstdio>
#define N 80050
using namespace std;
inline int read(){
int x=0,f=1;char c;
do {c=getchar();if(c=='-') f=-1;} while(!isdigit(c));
do x=(x<<3)+(x<<1)+c-'0',c=getchar(); while(isdigit(c));
return x*f;
}
int n,m,x,y,k,top,Top;
int fa[N],dis[N],fir[N],Fir[N],f[N],ans[N],fx[N],fy[N];
bool b[N];
struct Edge{
int to,nex,k;
Edge(int _=0,int __=0,int ___=0):to(_),nex(__),k(___){}
}nex[N*2],Nex[N*2];
void dfs(int x,int Dis,int Fa){
fa[x]=Fa;dis[x]=Dis;
for(int i=fir[x];i;i=nex[i].nex){
if(nex[i].to==Fa) continue;
dfs(nex[i].to,nex[i].k+Dis,x);
}
}
inline void add(int x,int y,int k,int &top,int fir[],Edge nex[]){
nex[++top]=Edge(y,fir[x],k);
fir[x]=top;
}
int find(int x){return f[x]==x?f[x]:f[x]=find(f[x]);}
inline void link(int x,int y){
int fx=find(x),fy=find(y);
f[fx]=fy;
}
void Tarjan(int x){
b[x]=true;
for(int i=fir[x];i;i=nex[i].nex){
if(nex[i].to==fa[x]) continue;
Tarjan(nex[i].to);
link(nex[i].to,x);
}
for(int i=Fir[x];i;i=Nex[i].nex)
if(b[Nex[i].to])
ans[Nex[i].k]=find(Nex[i].to);
}
int main(){
n=read();
for(int i=1;i<n;i++){
x=read()+1;y=read()+1;k=read();
add(x,y,k,top,fir,nex);add(y,x,k,top,fir,nex);
f[i]=i;
}
f[n]=n;
dfs(1,0,0);
m=read();
for(int i=1;i<=m;i++){
x=read()+1;y=read()+1;
fx[i]=x;fy[i]=y;
add(x,y,i,Top,Fir,Nex);add(y,x,i,Top,Fir,Nex);
}
Tarjan(1);
for(int i=1;i<=m;i++){
printf("%d\n",dis[fx[i]]+dis[fy[i]]-2*dis[ans[i]]);
}
return 0;
}