题意:
一颗有向边的树,根节点为 1,q次询问,询问以该点为根节点的子树的重心编号
暴力思路:
对于每一个结点进行一次DFS,深搜的时间复杂度是O(n)
于是改算法的时间复杂度是O(n^2)。对于结点树n是3十万,显然会超时但还是给出代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
const int INF = 0x3f3f3f3f;
const int maxn = 300010;
int fa[maxn],siz[maxn],mu[maxn],ans[maxn];
vector<int> way[maxn];
inline void init(int n)
{
for(int i=0;i<maxn;i++) way[i].clear();
fa[1] = 0;
siz[1] = n;
}
inline void add(int u,int v)
{
way[u].push_back(v);
way[v].push_back(u);
}
void dfsroot(int u,int fa,int allnode,int &root)
{
siz[u] = 1;
mu[u] = 0;
for(int i=0;i<way[u].size();i++)
{
int to = way[u][i];
if(to == fa) continue;
dfsroot(to,u,allnode,root);
siz[u] += siz[to];
mu[u] = max(mu[u],siz[to]);
}
mu[u] = max(mu[u],allnode - siz[u]);
if(mu[u] < mu[root]) root = u;
}
void getroot(int u,int fa)
{
int root;
mu[root = 0] = INF;
dfsroot(u,fa,siz[u],root);
ans[u] = root;
}
int main()
{
int n,q;
while(~scanf("%d%d",&n,&q))
{
init(n);
for(int i=2;i<=n;i++) {
scanf("%d",&fa[i]);
add(i,fa[i]);
}
for(int i=1;i<=n;i++) getroot(i,fa[i]);
for(int i=1,x;i<=q;i++){
scanf("%d",&x);
printf("%d\n",ans[x]);
}
}
return 0;
}
思路二:
利用树的重心定义和性质
定义1:树的重心定义为:找到一个点,其所有的子树中最大的子树节点数最少,那么这个点就是这棵树的重心,删去重心后,生成的多棵树尽可能平衡.
定义2:以这个点(重心)为根,那么所有的子树(不算整个树自身)的大小都不超过整个树大小的一半。(size[v]*2<=size[root],root是树的重心,v的root的子结点)
性质1:树中所有点到某个点的距离和中,到重心的距离和是最小的,如果有两个距离和,他们的距离和一样。
性质2:把两棵树通过某一点相连得到一颗新的树,新的树的重心必然在连接原来两棵树重心的路径上。
性质3:一棵树添加或者删除一个节点,树的重心最多只移动一条边的位置。
本题用到了定义2和性质2;对于结点u要求以u为根节点的重心ans[u],现在我们已知其所有子树的大小。由定义2可知,对于最大的子树v,如果size[v] * 2 <= size[u] ,则结点u就是以u为根节点的树的重心。ans[u] = ans[u]
若size[v] * 2 > size[u],由性质2可知重心一定在v子树的重心到结点u的路径上。
先令u的重心等于ans[v],即ans[u] = ans[v],通过回溯时,(size[u] - size[ans[u]]) * 2 <= size[u]来判断ans[u] 是否为u树重心。
(size[ans[u]] * 2 = size[ans[v]] * 2 <= size[v] < size[u], 则 size[ans[u]] * 2 < size[u],于是只需要判断剩下的子树是否满足。)
如果满足 (size[u] - size[ans[u]]) * 2 <= size[u] 则ans[u]成立。
否则ans[u] = fa[ans[u]] ,一直向上寻找其重心。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
const int maxn = 300100;
int fa[maxn],siz[maxn],ans[maxn];
vector<int> way[maxn];
void dfs(int u)
{
siz[u] = 1;
ans[u] = u;
int root = 0;
for(int i=0;i<way[u].size();i++){
int v = way[u][i];
dfs(v);
if(siz[v] > siz[root]) root = v;
siz[u] += siz[v];
}
if(siz[root]*2>siz[u]) ans[u] = ans[root];
while((siz[u] - siz[ans[u]])*2>siz[u]) ans[u] = fa[ans[u]];
}
int main()
{
int n,q;siz[0] = 0;
scanf("%d%d",&n,&q);
for(int i=2;i<=n;i++) {
scanf("%d",&fa[i]);
way[fa[i]].push_back(i);
}
dfs(1);
while(q--){
int x;scanf("%d",&x);
printf("%d\n",ans[x]);
}
return 0;
}