一、题目
题目描述
给定一棵
n
n
n个点的树,每个点有一个姓氏(不多于
5
5
5个字符),有
m
m
m个问题,每次问两个姓氏距离的最大值,如果没有最大值,输出
−
1
-1
−1,否则输出最大值。
数据范围
1
≤
n
,
m
≤
1
0
5
1\leq n,m\leq10^5
1≤n,m≤105
二、解法
首先这个最大距离是跟直径有关系的,我们考虑求出每一个姓氏的直径,最大值一定存在于 A A A姓氏两个直径点的和 B B B姓氏两个直径点的四种组合,下面给出证明。
先考虑一个点到一个姓氏的情况,考虑使用反证法。
B
1
B
2
B_1B_2
B1B2是直径,
O
O
O是中转点,假设
A
O
+
O
B
3
>
A
O
+
O
B
1
AO+OB_3>AO+OB_1
AO+OB3>AO+OB1或者
A
O
+
O
B
3
>
A
O
+
O
B
2
AO+OB_3>AO+OB_2
AO+OB3>AO+OB2,那么
O
B
3
>
O
B
1
OB_3>OB_1
OB3>OB1或
O
B
3
>
O
B
2
OB_3>OB_2
OB3>OB2,显然上述不可能成立,因为这样的话
B
3
B_3
B3就一定直径的端点,故假设不成立,所以对于单点来说,取直径的端点是最优的。
对于姓氏到姓氏的情况,类比上述方法,读者自证不难 。
所以现在的问题在于快速求出每个姓氏的直径,可以用虚树的方法,然后回答问题就很容易了,这道题给我们一个启示,虚树的思想还可以运用在预处理上,不一定要固定关键点数的询问。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
#include <stack>
#include <map>
#define inf 0x3f3f3f3f
using namespace std;
const int MAXN = 100005;
int read()
{
int x=0,flag=1;char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*flag;
}
int n,m,k,k1,Index,tot,cnt,dfin[MAXN],dfou[MAXN],f[MAXN],dep[MAXN];
int a[2*MAXN],dp[MAXN][2],fr[MAXN][2],dm[MAXN][2],vis[MAXN],fa[MAXN][20];
map<string,int> mp;string str;
vector<int> T[MAXN];
stack<int> s;
struct edge
{
int v,next;
}e[MAXN*2];
bool cmp(int a,int b)
{
int t1=a>0?dfin[a]:dfou[-a];
int t2=b>0?dfin[b]:dfou[-b];
return t1<t2;
}
void dfs(int u,int p)
{
dfin[u]=++cnt;
dep[u]=dep[p]+1;
fa[u][0]=p;
for(int i=1;i<20;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==p) continue;
dfs(v,u);
}
dfou[u]=++cnt;
}
int get(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
for(int i=19;i>=0;i--)
if(dep[fa[u][i]]>=dep[v])
u=fa[u][i];
if(u==v) return u;
for(int i=19;i>=0;i--)
if(fa[u][i]^fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
int dis(int u,int v)
{
return dep[u]+dep[v]-2*dep[get(u,v)];
}
int main()
{
while(~scanf("%d %d",&n,&m))
{
Index=cnt=tot=0;
mp.clear();
for(int i=1;i<=n;i++)
{
T[i].clear();
f[i]=0;
}
for(int i=1;i<=n;i++)
{
cin>>str;
if(!mp[str]) mp[str]=++Index;
T[mp[str]].push_back(i);
}
for(int i=2;i<=n;i++)
{
int u=read(),v=read();
e[++tot]=edge{v,f[u]},f[u]=tot;
e[++tot]=edge{u,f[v]},f[v]=tot;
}
dfs(1,0);
for(int i=1;i<=n;i++)
dp[i][0]=dp[i][1]=-inf;
for(int l=1;l<=Index;l++)
{
if(T[l].size()==1)
{
dm[l][0]=dm[l][1]=T[l][0];
continue ;
}
k=k1=T[l].size();
for(int i=1;i<=k;i++)
{
a[i]=T[l][i-1];vis[a[i]]=1;
dp[a[i]][0]=0;fr[a[i]][0]=a[i];
}
sort(a+1,a+1+k,cmp);
for(int i=1;i<k1;i++)
{
int t=get(a[i],a[i+1]);
if(!vis[t])
a[++k]=t,vis[t]=1;
}
if(!vis[1]) a[++k]=1,vis[1]=1;
k1=k;
for(int i=1;i<=k1;i++)
a[++k]=-a[i];
sort(a+1,a+1+k,cmp);
int Max=0;
for(int i=1;i<=k;i++)
{
if(a[i]>0)
s.push(a[i]);
else
{
int t=s.top();s.pop();
if(t==1) break;
int p=s.top(),tmp=dp[t][0]+dep[t]-dep[p];
if(tmp>dp[p][0])
{
fr[p][1]=fr[p][0];
dp[p][1]=dp[p][0];
dp[p][0]=tmp;
fr[p][0]=fr[t][0];
}
else
{
dp[p][1]=tmp;
fr[p][1]=fr[t][0];
}
if(dp[p][0]+dp[p][1]>Max)
{
Max=dp[p][0]+dp[p][1];
dm[l][0]=fr[p][0];
dm[l][1]=fr[p][1];
}
vis[t]=0;dp[t][0]=dp[t][1]=-inf;
}
}
vis[1]=0;dp[1][0]=dp[1][0]=-inf;
}
for(int i=1;i<=m;i++)
{
cin>>str;
int t1=mp[str];
cin>>str;
int t2=mp[str];
if(!t1 || !t2) puts("-1");
else printf("%d\n",1+max(max(dis(dm[t1][0],dm[t2][1]),dis(dm[t1][1],dm[t2][1])),max(dis(dm[t1][0],dm[t2][0]),dis(dm[t1][1],dm[t2][0]))));
}
}
}