比较优秀的LCA写法(位运算)
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
int step=dep[x]-dep[y];
for(int i=0;i<=20;i++)
if((1<<i)&step)
x=dis[x][i];
if(x==y) return x;
for(int i=20;i>=0;i--)
if(dis[x][i]!=dis[y][i])
x=dis[x][i],y=dis[y][i];
return fa[x];
}
例:BZOJ1787
画图发现规律,三个点两两之间最近公共祖先点的集合最多只有两个元素,分类讨论即可。
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
typedef pair<int,int> pii;
const int maxn = 500010;
int first[maxn];
struct edg
{
int next;
int to;
}e[maxn<<1];
int fa[maxn],dep[maxn];
int dis[maxn][30];
int n,m;
int e_sum;
void add_edg(int x,int y)
{
e_sum++;
e[e_sum].next=first[x];
first[x]=e_sum;
e[e_sum].to=y;
}
void dfs(int x,int f,int d)
{
fa[x]=f;
dep[x]=d;
dis[x][0]=fa[x];
for(int i=1;i<=20;i++)
dis[x][i]=dis[dis[x][i-1]][i-1];
for(int i=first[x];i;i=e[i].next)
{
int w=e[i].to;
if(w==f) continue;
dfs(w,x,d+1);
}
}
int lca(int x,int y)
{
int now=x,i=0;
if(dep[x]!=dep[y])
{
if(dep[x]<dep[y]) swap(x,y);
now=x;
while(dep[fa[now]]!=dep[y])
{
i=0;
while(dep[dis[now][i]]>dep[y]) i++;
now=dis[now][--i];
}
x=fa[now];
}
if(x==y) return x;
while(fa[x]!=fa[y])
{
i=0;
while(dis[x][i]!=dis[y][i]) i++;
i--;
x=dis[x][i];y=dis[y][i];
}
return fa[x];
}
int solve(int x,int y,int z)
{
int xf=lca(x,y);
int yf=lca(y,z);
int zf=lca(x,z);
int ans=0;
if(xf==yf)
{
printf("%d ",zf);
ans=(dep[x]-dep[zf]-dep[zf]+dep[z]);
ans+=(dep[zf]-dep[xf]-dep[xf]+dep[y]);
}
else if(yf==zf)
{
printf("%d ",xf);
ans=(dep[x]-dep[xf]-dep[xf]+dep[y]);
ans+=(dep[xf]-dep[yf]-dep[yf]+dep[z]);
}
else if(xf==zf)
{
printf("%d ",yf);
ans=(dep[y]-dep[yf]+dep[z]-dep[yf]);
ans+=(dep[yf]-dep[zf]+dep[x]-dep[zf]);
}
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add_edg(x,y);
add_edg(y,x);
}
dfs(1,1,1);
for(int i=1;i<=m;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
printf("%d\n",solve(x,y,z));
}
return 0;
}
BZOJ3732
类似货车运输,注意倍增写法。
#include<bits/stdc++.h>
#define debug(x) cout<<#x<<"="<<x<<endl
using namespace std;
const int maxn = 15010;
const int maxm = 30010;
struct bian
{
int from,to;
int val;
}ee[maxm];
int first[maxn];
int fa[maxn];
int p[maxn],dep[maxn],dis[maxn][25],mm[maxn][25];
struct edg
{
int next;
int to;
int val;
}e[maxm];
int n,m;
int k;
int e_sum;
bool cmp(bian x,bian y)
{
return x.val<y.val;
}
void add_edg(int x,int y,int z)
{
e_sum++;
e[e_sum].next=first[x];
first[x]=e_sum;
e[e_sum].to=y;
e[e_sum].val=z;
}
int find(int x)
{
if(p[x]==x) return x;
return p[x]=find(p[x]);
}
void merge(int x,int y)
{
int xf=find(x);
int yf=find(y);
p[xf]=yf;
}
void dfs(int x,int f,int d)
{
dep[x]=d;
dis[x][0]=fa[x];
for(int i=1;i<=20;i++)
{
dis[x][i]=dis[dis[x][i-1]][i-1];
mm[x][i]=max(mm[x][i-1],mm[dis[x][i-1]][i-1]);
}
for(int i=first[x];i;i=e[i].next)
{
int w=e[i].to;
if(w==f) continue;
fa[w]=x;
mm[w][0]=e[i].val;
dfs(w,x,d+1);
}
}
int lca(int x,int y)
{
int now=x,i=0;
int ans=-0x3f3f3f3f;
if(dep[x]!=dep[y])
{
if(dep[x]<dep[y]) swap(x,y);
now=x;
while(dep[fa[now]]!=dep[y])
{
i=0;
while(dep[dis[now][i]]>dep[y]) i++;
ans=max(ans,mm[now][--i]);
now=dis[now][i];
}
ans=max(ans,mm[now][0]);
x=fa[now];
}
if(x==y) return ans;
while(fa[x]!=fa[y])
{
i=0;
while(dis[x][i]!=dis[y][i]) i++;
i--;
ans=max(ans,mm[x][i]);
ans=max(ans,mm[y][i]);
x=dis[x][i];y=dis[y][i];
}
ans=max(max(mm[x][0],mm[y][0]),ans);
return ans;
}
int main()
{
scanf("%d%d%d",&n,&m,&k);
for(int i=1;i<=m;i++)
scanf("%d%d%d",&ee[i].to,&ee[i].from,&ee[i].val);
for(int i=1;i<=n;i++) p[i]=i;
memset(mm,128,sizeof mm);
sort(ee+1,ee+1+m,cmp);
for(int i=1;i<=m;i++)
{
int x=ee[i].to;
int y=ee[i].from;
if(find(x)!=find(y))
{
add_edg(x,y,ee[i].val);
add_edg(y,x,ee[i].val);
merge(x,y);
}
}
fa[1]=1;
dfs(1,-1,1);
for(int i=1;i<=k;i++)
{
int x,y;
scanf("%d%d",&x,&y);
int num=lca(x,y);
printf("%d\n",num);
}
return 0;
}