参考于
LCA-倍增法(在线)O(nlogn)-O(logn) - ousuo
/*题为codevs1036
假设有N个城镇,首都编号为1,商人从首都出发,其他各城镇之间都有道路连接,任意两个城镇之间如果有直连道路,在他们之间行驶需要花费单位时间。该国公路网络发达,从首都出发能到达任意一个城镇,并且公路网络不会存在环。
你的任务是帮助该商人计算一下他的最短旅行时间。
e.g.
Input
5
1 2 1 5 3 5 4 5
4 1 3 2 5
Output
7
*/
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
int n,m;
int head[30001],next[60001],to[60001],sum;
int deep[30001];
int f[30001][20];
int S,T;
int ans_lca,ans;
void dfs(int u)//遍历所有节点;
{
for(int i=head[u];i!=0;i=next[i])
{
if(deep[to[i]]==0)
{
deep[to[i]]=deep[u]+1;
f[to[i]][0]=u;
dfs(to[i]);
}
}
}
void init()
{
for(int i=1;(1<<i)<=n;i++)
for(int j=1;j<=n;j++)
if(f[j][i-1]) f[j][i]=f[f[j][i-1]][i-1];
}
int lca(int a,int b)
{
int s=a,e=b;
if(deep[s]<deep[e]) {int temp=s;s=e;e=temp;}
int t=0;
while((1<<t)<deep[s]) t++;
for(int i=t;i>=0;i--) if(deep[s]-(1<<i)>=deep[e]) s=f[s][i];
if(s==e) return s;
for(int i=t;i>=0;i--)
{
if(f[s][i]&&f[s][i]!=f[e][i])
{
s=f[s][i];
e=f[e][i];
}
}
return f[s][0];
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) {f[i][0]=i;deep[i]=0;}
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
to[++sum]=v;
next[sum]=head[u];
head[u]=sum;
to[++sum]=u;
next[sum]=head[v];
head[v]=sum;//***双向建边;
}
dfs(1);
init();
scanf("%d",&m);
S=1;
for(int i=1;i<=m;i++)
{
if(i>1) S=T;
scanf("%d",&T);
ans_lca=lca(S,T);
ans+=deep[S]+deep[T]-2*deep[ans_lca];
}
printf("%d",ans);
}
/*进阶codevs2370*/
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
int n,m,ans_lca;
int head[50005],next[100005],to[100005],w[100005],sum;
int deep[50005];
int f[50005][20],d[50005][20];
int S,T;
void dfs(int u)
{
for(int i=head[u];i!=0;i=next[i])
if(deep[to[i]]==0)
{
f[to[i]][0]=u;
deep[to[i]]=deep[u]+1;
d[to[i]][0]=w[i];
dfs(to[i]);
}
}
void init()
{
for(int i=1;(1<<i)<n;i++)
for(int j=0;j<n;j++)
{
if(deep[j]>(1<<i))
{
f[j][i]=f[f[j][i-1]][i-1];
d[j][i]=d[j][i-1]+d[f[j][i-1]][i-1];
}
}
}
int lca(int a,int b)
{
if(deep[a]<deep[b]) {int x=a;a=b;b=x;}
int t=0;
while((1<<t)<deep[a]) t++;
for(int i=t;i>=0;i--) if(deep[a]-(1<<i)>=deep[b]) a=f[a][i];
if(a==b) return a;
for(int i=t;i>=0;i--)
if(f[a][i]&&f[a][i]!=f[b][i])
{
a=f[a][i];
b=f[b][i];
}
return f[a][0];
}
int tot(int a,int b)
{
int tot=0;
int t=deep[a]-deep[b];
for(int i=0;i<=17;i++)
{
if(t&(1<<i))
{
tot+=d[a][i];
a=f[a][i];
}
}
return tot;
}
int main()
{
scanf("%d",&n);
for(int i=0;i<=n;i++) f[i][0]=i;
for(int i=1;i<n;i++) {
int u,v,x;
scanf("%d%d%d",&u,&v,&x);
to[++sum]=v;
w[sum]=x;
next[sum]=head[u];
head[u]=sum;
to[++sum]=u;
next[sum]=head[v];
w[sum]=x;
head[v]=sum;
}
dfs(0);
init();
scanf("%d",&m);
for(int i=1;i<=m;i++)
{
scanf("%d%d",&S,&T);
int a=S,b=T;
ans_lca=lca(a,b);
printf("%d\n",tot(S,ans_lca)+tot(T,ans_lca));
}
}
/*进阶(水题…?)*/
#include<cstdio>
#include<iostream>
#include<vector>
#include<cstring>
using namespace std;
const int N=200001,pow=20,maxx=999999;
vector<int> g[N];
int d[N],f[N][pow],maxv[N][pow]={0},minv[N][pow],diff[N][pow]={0},dife[N][pow]={0},q[N];
// diff:顺区间最大差
// dife:逆区间最大差
void dfs(int x,int fa)
{
int i,t;
maxv[x][0]=max(q[x],q[fa]);
minv[x][0]=min(q[x],q[fa]);
diff[x][0]=q[x]-q[fa];
dife[x][0]=q[fa]-q[x];
d[x]=d[fa]+1;
f[x][0]=fa;
for(i=1;i<pow;i++)
{
f[x][i]=f[f[x][i-1]][i-1];
maxv[x][i]=max(maxv[x][i-1],maxv[f[x][i-1]][i-1]);
minv[x][i]=min(minv[x][i-1],minv[f[x][i-1]][i-1]);
diff[x][i]=max(maxv[x][i-1]-minv[f[x][i-1]][i-1],max(diff[x][i-1],diff[f[x][i-1]][i-1]));
dife[x][i]=max(maxv[f[x][i-1]][i-1]-minv[x][i-1],max(dife[x][i-1],dife[f[x][i-1]][i-1]));
}
for(i=0;i<g[x].size();i++)
if(g[x][i]!=fa)
dfs(g[x][i],x);
}
int lca(int a,int b)
{
int i,t,total=0,flag=0,maxn=0,minn=maxx;
if(d[a]>d[b]){
flag=1;
a^=b,b^=a,a^=b;
}
if(d[a]<d[b]){
t=d[b]-d[a];
for(i=0;i<pow;i++)
if(t&(1<<i)){
if(!flag){
total=max(total,maxn-minv[b][i]);
maxn=max(maxn,maxv[b][i]);
total=max(total,diff[b][i]);
}
else{
total=max(total,maxv[b][i]-minn);
minn=min(minn,minv[b][i]);
total=max(total,dife[b][i]);
}
b=f[b][i];
}
if(!flag) minn=q[a];
else maxn=q[a];
}
else{
if(!flag) minn=q[a],maxn=q[b];
else maxn=q[a],minn=q[b];
}
if(a!=b){
for(i=pow-1;i>=0;i--)
if(f[a][i]!=f[b][i]){
if(!flag){
total=max(total,maxv[a][i]-minn);
total=max(total,maxn-minv[b][i]);
total=max(total,dife[a][i]);
total=max(total,diff[b][i]);
maxn=max(maxn,maxv[b][i]);
minn=min(minn,minv[a][i]);
}
else{
total=max(total,maxv[b][i]-minn);
total=max(total,maxn-minv[a][i]);
total=max(total,diff[a][i]);
total=max(total,dife[b][i]);
maxn=max(maxn,maxv[a][i]);
minn=min(minn,minv[b][i]);
}
a=f[a][i],b=f[b][i];
}
total=max(total,maxn-minn);
if(!flag){
total=max(total,maxn-q[f[a][0]]);
total=max(total,q[f[b][0]]-minn);
}
else{
total=max(total,maxn-q[f[b][0]]);
total=max(total,q[f[a][0]]-minn);
}
}
return total;
}
int main()
{
memset(minv,27,sizeof(minv));
int n,m,i,j,x,y;
scanf("%d",&n);
for(i=1;i<=n;i++)
scanf("%d",&q[i]);
for(i=1;i<n;i++){
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
for(i=0;i<g[1].size();i++)
dfs(g[1][i],1);
scanf("%d",&m);
for(i=1;i<=m;i++){
scanf("%d%d",&x,&y);
printf("%d\n",lca(x,y));
}
return 0;
}