链接:http://hihocoder.com/problemset/problem/1387
题意: 给一棵n个节点的树,每个节点上有个姓氏,Q个询问,询问两种姓氏,问从这两种姓氏中各取一个节点,这两点路径最大的节点是多少。
分析:我们如果知道了某种姓氏(a)的直径(在树中最远距离)的两个端点(au,av),那么询问它与另外一个姓氏(b)的答案就是max{dis(au,bu),dis(au,bv),dis(av,bu),dis(av,bv)}
。
求某个姓氏的直径可以用O(nlongn)的复杂度求:
在dfs中添加某种姓氏的节点(x),原直径(u,v),那么新的直径是(u,v),(x,u),(x,v)最大一个。
代码:
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<vector>
#include<queue>
#include<cmath>
#include<stack>
#include<set>
#include<map>
#define INF 0x3f3f3f3f
#define Mn 100010
#define Mm 2000005
#define mod 1000000007
#define CLR(a,b) memset((a),(b),sizeof((a)))
#define CLRS(a,b,Size) memset((a),(b),sizeof((a[0]))*(Size+1))
#define CPY(a,b) memcpy ((a), (b), sizeof((a)))
#pragma comment(linker, "/STACK:102400000,102400000")
#define ul u<<1
#define ur (u<<1)|1
using namespace std;
typedef long long ll;
struct edge {
int v,next;
}e[Mm];
int head[Mn],tot;
void addedge(int u,int v) {
e[tot].v=v;
e[tot].next=head[u];
head[u]=tot++;
}
int fa[Mn][20];
int deep[Mn];
void dfs(int u,int f,int de) {
fa[u][0]=f;
deep[u]=de;
for(int i=head[u];~i;i=e[i].next) {
int v=e[i].v;
if(v==f) continue;
dfs(v,u,de+1);
}
}
void build(int n) {
for(int j=1;j<19;j++) {
for(int i=1;i<=n;i++) {
fa[i][j]=fa[fa[i][j-1]][j-1];
}
}
}
int lca(int x,int y) {
int len=deep[x]-deep[y];
for(int i=0;len;i++) {
if(len&1) x=fa[x][i];
len>>=1;
}
for(int i=18;i>=0;i--) {
if(fa[x][i]!=fa[y][i]) {
x=fa[x][i];
y=fa[y][i];
}
}
if(x==y) return x;
return fa[x][0];
}
int getdis(int x,int y) {
if(deep[x]<deep[y]) swap(x,y);
int r=lca(x,y);
return deep[x]+deep[y]-2*deep[r]+1;
}
int Lpoint[Mn][2];
void getL(int cl,int u) {
if(Lpoint[cl][0]==-1){
Lpoint[cl][0]=u;
Lpoint[cl][1]=u;
return ;
}
int L12=getdis(Lpoint[cl][0],Lpoint[cl][1]);
int L1u=getdis(u,Lpoint[cl][0]);
int L2u=getdis(u,Lpoint[cl][1]);
//cout<<cl<<" "<<L12<<" "<<L1u<<" "<<L2u<<endl;
if(L1u>=L2u) {
if(L1u>L12)
Lpoint[cl][1]=u;
} else if(L2u>L12){
Lpoint[cl][0]=u;
}
}
int a[Mn];
void dfs1(int u,int f) {
getL(a[u],u);
for(int i=head[u];~i;i=e[i].next) {
int v=e[i].v;
if(v==f) continue;
dfs1(v,u);
}
}
void init() {
CLR(head,-1);
CLR(fa,-1);
CLR(Lpoint,-1);
tot=0;
}
map<string,int>mp;
int main() {
int n,m,u,v;
char sm[10];
while(~scanf("%d%d",&n,&m)) {
init();
getchar();
int cnt=1;
mp.clear();
for(int i=1;i<=n;i++) {
scanf("%s",sm);
if(mp[string(sm)]==0) {
mp[string(sm)]=cnt;
a[i]=cnt++;
} else a[i]=mp[string(sm)];
}
for(int i=1;i<n;i++) {
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
dfs(1,0,0);
build(n);
dfs1(1,0);
int st,ed;
getchar();
for(int i=0;i<m;i++) {
scanf("%s ",sm);st=mp[string(sm)];
scanf("%s",sm);ed=mp[string(sm)];
if(st==0||ed==0) {
printf("-1\n");
continue;
}
int a=Lpoint[st][0],b=Lpoint[st][1];
int c=Lpoint[ed][0],d=Lpoint[ed][1];
int ans=max(getdis(a,c),getdis(a,d));
ans=max(ans,max(getdis(b,c),getdis(b,d)));
printf("%d\n",ans);
}
}
return 0;
}