以下代码均使用倍增算法（binary lifting）求 LCA（最近公共祖先）
最近公共祖先
POJ 1330
#include<bits/stdc++.h>
using namespace std;
// Array capacities; node ids are used directly as indices.
const int N = 10010;   // node-array capacity
const int M = 20010;   // edge-pool capacity: every tree edge is stored in both directions
int h[N];              // h[u]: index of the first edge out of u; -1 terminates the list
int ne[M];             // ne[e]: index of the next edge in the same adjacency list
int to[M];             // to[e]: endpoint that edge e points to
int cnt;               // next free slot in the edge pool
int root;              // tree root / BFS start node
int fa[N][17];         // fa[u][k]: 2^k-th ancestor of u, 0 when it does not exist
void add(int a, int b)
{
to[cnt] = b,ne[cnt] = h[a],h[a] = cnt++;
}
std::queue<int> q;  // BFS frontier; bfs() runs until it is empty again
int depth[N];       // depth[u]: 1 for the root, -1 for nodes bfs() did not reach
void bfs()
{
memset(depth,-1,sizeof(depth));
depth[root] = 1;
int now = root;
q.push(now);
while (!q.empty())
{
now = q.front();
q.pop();
for (int i = h[now]; i != -1; i = ne[i])
{
int j = to[i];
if(depth[j]!=-1) continue;
depth[j] = depth[now] + 1;
q.push(j);
fa[j][0] = now;
for (int x = 1; x <= 16; x++)
{
fa[j][x] = fa[fa[j][x - 1]][x - 1];
}
}
}
}
// Lowest common ancestor of x and y, using the depth[] / fa[][] tables built by bfs().
int lca(int x, int y)
{
    // Work with x as the deeper (or equally deep) node.
    if (depth[y] > depth[x])
        swap(x, y);
    // Raise x to y's depth. A jump past the root lands on node 0, whose depth
    // stays -1 after bfs() (assuming node ids start at 1), so it is never accepted.
    for (int k = 16; k >= 0; --k)
        if (depth[fa[x][k]] >= depth[y])
            x = fa[x][k];
    if (x == y)
        return y;  // y is an ancestor of the original x
    // Raise both nodes while their 2^k-th ancestors differ; they stop just below the LCA.
    for (int k = 16; k >= 0; --k)
    {
        if (fa[x][k] == fa[y][k])
            continue;
        x = fa[x][k];
        y = fa[y][k];
    }
    return fa[x][0];
}
// One POJ 1330 test case. Read n, then n-1 "parent child" pairs, then one query
// pair; locate the tree root and print the LCA of the query pair.
void solve( ){
    memset(h, -1, sizeof h);    // -1 = empty adjacency list
    memset(fa, 0, sizeof(fa));  // 0 = "no parent / no ancestor"
    cnt = 0;                    // reset the edge pool BEFORE this test case's add() calls
    int n, x, y;
    cin >> n;
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &x, &y);
        add(x, y), add(y, x);
        fa[y][0] = x;           // x is the parent of y
    }
    // Walk parent pointers up from node 1 until reaching a node with no parent.
    // fa was zero-filled above, so the "no parent" sentinel is 0. The previous
    // test against -1 never matched, so it stepped onto node 0 and looped forever.
    root = 1;
    while (fa[root][0] != 0) {
        root = fa[root][0];
    }
    bfs();
    scanf("%d%d", &x, &y);
    cout << lca(x, y) << '\n';
}
// Entry point: the first input value is the number of test cases.
int main(){
    int t;
    cin >> t;
    while (t-- != 0)
    {
        solve();
    }
    return 0;
}
树上距离（多次询问树上两点间的路径长度）
HDU 2586
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;  // 64-bit type for accumulated path lengths
const int N = 40010;   // node-array capacity
const int M = N + N;   // edge-pool capacity: every tree edge is stored in both directions
int h[N];              // h[u]: index of the first edge out of u; -1 terminates the list
int ne[M];             // ne[e]: index of the next edge in the same adjacency list
int to[M];             // to[e]: endpoint that edge e points to
int val[M];            // val[e]: weight of edge e
int cnt;               // next free slot in the edge pool
int root;              // BFS start node
int fa[N][17];         // fa[u][k]: 2^k-th ancestor of u, 0 when it does not exist
ll dis[N][17];         // dis[u][k]: path weight from u up to fa[u][k] (meaningful only if it exists)
void add(int a, int b,int c)
{
val[cnt]=c,to[cnt] = b,ne[cnt] = h[a],h[a] = cnt++;
}
std::queue<int> q;  // BFS frontier; bfs() runs until it is empty again
int depth[N];       // depth[u]: 1 for the root, -1 for nodes bfs() did not reach
void bfs()
{
memset(depth,-1,sizeof(depth));
depth[root] = 1;
int now = root;
q.push(now);
while (!q.empty())
{
now = q.front();
q.pop();
for (int i = h[now]; i != -1; i = ne[i])
{
int j = to[i];
if(depth[j]!=-1) continue;
depth[j] = depth[now] + 1;
q.push(j);
fa[j][0] = now;
dis[j][0]=val[i];
for (int x = 1; x <= 16; x++)
{
fa[j][x] = fa[fa[j][x - 1]][x - 1];
dis[j][x]=dis[fa[j][x-1]][x-1]+dis[j][x-1];
}
}
}
}
// Despite its name, returns the weighted length of the tree path between x and y
// (sum of edge weights), obtained by lifting both nodes to their LCA with the
// depth[] / fa[][] / dis[][] tables built by bfs().
// The sum is accumulated and returned as 64-bit; the old `int` return type
// silently truncated the long long accumulator.
long long lca(int x, int y)
{
    long long res = 0;
    if (depth[x] < depth[y])
        swap(x, y);  // x is now the deeper (or equally deep) node
    // Lift x to y's depth, adding the weight of every jump taken. A jump past the
    // root lands on node 0, whose depth stays -1 (assuming node ids start at 1).
    for (int i = 16; i >= 0; i--)
    {
        if (depth[fa[x][i]] >= depth[y])
        {
            res += dis[x][i];
            x = fa[x][i];
        }
    }
    if (x == y)
        return res;  // y was an ancestor of x
    // Lift both while their 2^i-th ancestors differ; they end as children of the LCA.
    for (int i = 16; i >= 0; i--)
    {
        if (fa[x][i] != fa[y][i])
        {
            res += dis[x][i] + dis[y][i];
            x = fa[x][i], y = fa[y][i];
        }
    }
    // Final edge from each side up to the LCA.
    return res + dis[x][0] + dis[y][0];
}
// One HDU 2586 test case. Read n and m, then n-1 weighted edges "x y z"; build
// the lifting tables rooted at node 1; print the path length for each of the m queries.
void solve( ){
    int n, m, x, y, z;
    memset(h, -1, sizeof h);    // -1 = empty adjacency list
    memset(fa, 0, sizeof(fa));  // 0 = "no ancestor"
    // Reset the edge pool BEFORE inserting. It used to be reset only after the
    // add() calls, which was correct only because the previous call left it at 0.
    cnt = 0;
    root = 1;
    cin >> n >> m;
    for (int i = 1; i < n; i++) {
        scanf("%d%d%d", &x, &y, &z);
        add(x, y, z), add(y, x, z);
    }
    bfs();
    for (int i = 0; i < m; i++) {
        scanf("%d%d", &x, &y);
        cout << lca(x, y) << '\n';
    }
}
// Entry point: the first input value is the number of test cases.
int main(){
    int t;
    cin >> t;
    while (t-- != 0)
    {
        solve();
    }
    return 0;
}
距离查询（单组数据，多次询问树上两点间的距离）
#include <algorithm> //STL通用算法
#include <cmath>//定义数学函数
#include <complex> //复数类
#include <cstdio>//定义输入/输出函数
#include <cstdlib>//定义杂项函数及内存分配函数
#include <cstring>//字符串处理
#include <deque> //STL双端队列容器
#include <map> //STL 映射容器
#include <iostream>//数据流输入/输出
#include <queue> //STL队列容器
#include <set> //STL 集合容器
#include <sstream>//基于字符串的流
#include <stack> //STL堆栈容器
#include <string>//字符串类
#include <vector>//STL动态数组容器
using namespace std;
typedef long long ll;  // 64-bit type for accumulated path lengths
const int N = 40010;   // node-array capacity
const int M = N + N;   // edge-pool capacity: every tree edge is stored in both directions
int h[N];              // h[u]: index of the first edge out of u; -1 terminates the list
int ne[M];             // ne[e]: index of the next edge in the same adjacency list
int to[M];             // to[e]: endpoint that edge e points to
int val[M];            // val[e]: weight of edge e
int cnt;               // next free slot in the edge pool
int root;              // BFS start node
int fa[N][17];         // fa[u][k]: 2^k-th ancestor of u, 0 when it does not exist
ll dis[N][17];         // dis[u][k]: path weight from u up to fa[u][k] (meaningful only if it exists)
void add(int a, int b,int c)
{
val[cnt]=c,to[cnt] = b,ne[cnt] = h[a],h[a] = cnt++;
}
std::queue<int> q;  // BFS frontier; bfs() runs until it is empty again
int depth[N];       // depth[u]: 1 for the root, -1 for nodes bfs() did not reach
void bfs()
{
memset(depth,-1,sizeof(depth));
depth[root] = 1;
int now = root;
q.push(now);
while (!q.empty())
{
now = q.front();
q.pop();
for (int i = h[now]; i != -1; i = ne[i])
{
int j = to[i];
if(depth[j]!=-1) continue;
depth[j] = depth[now] + 1;
q.push(j);
fa[j][0] = now;
dis[j][0]=val[i];
for (int x = 1; x <= 16; x++)
{
fa[j][x] = fa[fa[j][x - 1]][x - 1];
dis[j][x]=dis[fa[j][x-1]][x-1]+dis[j][x-1];
}
}
}
}
// Despite its name, returns the weighted length of the tree path between x and y
// (sum of edge weights), obtained by lifting both nodes to their LCA with the
// depth[] / fa[][] / dis[][] tables built by bfs().
// The accumulator used to be `int` while dis[][] is long long. Each addition
// narrowed and could overflow, so the sum and return type are now 64-bit.
long long lca(int x, int y)
{
    long long res = 0;
    if (depth[x] < depth[y])
        swap(x, y);  // x is now the deeper (or equally deep) node
    // Lift x to y's depth, adding the weight of every jump taken. A jump past the
    // root lands on node 0, whose depth stays -1 (assuming node ids start at 1).
    for (int i = 16; i >= 0; i--)
    {
        if (depth[fa[x][i]] >= depth[y])
        {
            res += dis[x][i];
            x = fa[x][i];
        }
    }
    if (x == y)
        return res;  // y was an ancestor of x
    // Lift both while their 2^i-th ancestors differ; they end as children of the LCA.
    for (int i = 16; i >= 0; i--)
    {
        if (fa[x][i] != fa[y][i])
        {
            res += dis[x][i] + dis[y][i];
            x = fa[x][i], y = fa[y][i];
        }
    }
    // Final edge from each side up to the LCA.
    return res + dis[x][0] + dis[y][0];
}
// Read a weighted tree (n, then n-1 lines "x y z"), build the lifting tables
// rooted at node 1, then read m and print the path length for each query pair.
void solve( ){
    int n, m, x, y, z;
    memset(h, -1, sizeof(h));   // -1 = empty adjacency list
    memset(fa, 0, sizeof(fa));  // 0 = "no ancestor"
    // Reset the edge pool BEFORE inserting. It used to be reset only after the
    // add() calls, which was correct only because the global started at 0.
    cnt = 0;
    root = 1;
    cin >> n;
    for (int i = 0; i < n - 1; i++) {
        scanf("%d%d%d", &x, &y, &z);
        add(x, y, z), add(y, x, z);
    }
    bfs();
    cin >> m;
    for (int i = 0; i < m; i++) {
        scanf("%d%d", &x, &y);
        cout << lca(x, y) << '\n';
    }
}
// Entry point: this task has exactly one test case.
int main(){
    solve();
    return 0;
}