龙龙是“饱了呀”外卖软件的注册骑手,负责送帕特小区的外卖。帕特小区的构造非常特别,都是双向道路且没有构成环 —— 你可以简单地认为小区的路构成了一棵树,根结点是外卖站,树上的结点就是要送餐的地址。
每到中午 12 点,帕特小区就进入了点餐高峰。一开始,只有一两个地方点外卖,龙龙简单就送好了;但随着大数据的分析,龙龙被派了更多的单子,也就送得越来越累……
看着一大堆订单,龙龙想知道,从外卖站出发,访问所有点了外卖的地方至少一次(这样才能把外卖送到)所需的最短路程的距离到底是多少?每次新增一个点外卖的地址,他就想估算一遍整体工作量,这样他就可以搞明白新增一个地址给他带来了多少负担。
输入格式:
输入第一行是两个数 N 和 M (2≤N≤1e5, 1≤M≤1e5),分别对应树上节点的个数(包括外卖站),以及新增的送餐地址的个数。
接下来首先是一行 N 个数,第 i 个数表示第 i 个点的双亲节点的编号。节点编号从 1 到 N,外卖站的双亲编号定义为 −1。
接下来有 M 行,每行给出一个新增的送餐地点的编号 Xi。保证送餐地点中不会有外卖站,但地点有可能会重复。
为了方便计算,我们可以假设龙龙一开始一个地址的外卖都不用送,两个相邻的地点之间的路径长度统一设为 1,且从外卖站出发可以访问到所有地点。
注意:所有送餐地址可以按任意顺序访问,且完成送餐后无需返回外卖站。
输出格式:
对于每个新增的地点,在一行内输出题目需要求的最短路程的距离。
输入样例:
7 4
-1 1 1 1 2 2 3
5
6
2
4
输出样例:
2
4
4
6
解题思路:这道题本身并不难,只是题目描述的很难理解
换句话说这道题要我们实现的就是两个事情,第一个事情就是统计对于每一个询问我们在树经过的边数,第二个事情就是计算给定的询问的最大深度,因此对于每一询问那么计算公式就是所有走过的边数 * 2 减去 点餐地到原点最大距离,就是每一组询问的最短距离。
对于这两个不同事件一开始并不能发现有什么联系,那首先一定先想到的就是两个dfs,分别去求每一个事件的值。
显然可以写出下面的代码
void dfs(int u , int root , int aim)
{
if(u == root || st[u])
{
dist += aim;
return ;
}
st[u] = true;
dfs(fa[u] , root , aim + 2);
}
int get_max(int u , int sum , int last_aim)
{
if(u == last_aim) return sum;
return get_max(fa[u] , sum + 1 , last_aim);
}
毋庸置疑,肯定是超时了
那么,下来一定就是优化、优化再优化。
我们可以发现这两个dfs都是从给定询问,去寻找根,那么这两个dfs可以到一起,我们发现,当进行dfs的时候,每一次都向上走显然经过了一条边,因此可以直接统计经过的所有边数,然后返回这个订餐地点离根的最长的距离。因此优化可以为下面的代码
int dfs(int u) //点餐地到原点最大距离
{
// u已经到达根节点 或者是 距离已经被更新过了(已经经过了)
if(fa[u] == -1 || di[u]) return di[u];
dist ++;
di[u] = dfs(fa[u]) + 1; // 意思就是点餐地离原点更近了一步
return di[u];
}
这样得到边数和最长距离就不会超时了,然后带入公式,最后顺利的得到答案。
以下是完整代码
#include<iostream>
#include<vector>
#include<unordered_map>
#include<cstring>
using namespace std;
const int N = 1e5 + 10;
int n , m;
unordered_map<int , int>fa;
bool st[N];
int max_dist = 0;
int dist = 0;
//每个点到驿站的边数(边的权重为1)之和
//假设走完所有点回到原点,那么走过的距离就是所有走过的边数 * 2
//因为我们可以不用回到原点,所以res = dist * 2 - max_dist(点餐地到原点最大距离)
int di[N] = {0};
int dfs(int u) //点餐地到原点最大距离
{
// u已经到达根节点 或者是 距离已经被更新过了(已经经过了)
if(fa[u] == -1 || di[u]) return di[u];
dist ++;
di[u] = dfs(fa[u]) + 1; // 意思就是点餐地离原点更近了一步
return di[u];
}
/*
将get_max归到dfs
int get_max(int u , int sum , int last_aim)
{
if(u == last_aim) return sum;
return get_max(fa[u] , sum + 1 , last_aim);
}
*/
int main()
{
scanf("%d %d" ,&n ,&m);
int root = 0;
for(int i = 1;i <= n;i ++) scanf("%d" ,&fa[i]);
while(m --)
{
int num;
scanf("%d" ,&num);
max_dist = max(max_dist , dfs(num));
//int t = get_max(num , 0 , root);
//max_dist = max(max_dist , t);
printf("%d\n" , dist * 2 - max_dist);
}
return 0;
}