链接:https://ac.nowcoder.com/acm/contest/9753/C
来源:牛客网
题目描述
牛牛有一棵n个点的无权无根树,他想知道有多少个点在树的直径上,你可以帮帮他吗?
注意:树的直径可能不止一条。
示例1
输入
3,[1,2],[2,3]
返回值
3
说明
直径为1−2−3,三个节点均在直径上,故答案为3。直径为1-2-3,三个节点均在直径上,故答案为3。直径为1−2−3,三个节点均在直径上,故答案为3。
备注:
2≤n≤1e5
思路:
【错误想法及代码】这道题让我卡了整整一天,我之前的思路是通过两次dfs找到树的直径长度以及直径的某一个端点【不懂树的直径的求法可以参考我这篇博客】。然后从某一个端点出发进行dfs并记录路径长度以及之后最多能走多深,如果这两个值的和等于树的直径,则这个点就一定是直径上的点,具体代码如下:
class Solution {
public:
/**
* 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
*
* @param n int整型 节点个数
* @param u int整型vector
* @param v int整型vector
* @return int整型
*/
int ans,mxR;
int depth[100005];
vector<int> edges[100005];
int dfs1(int u,int f,int now){
int res=0;
int size=edges[u].size();
for(int i=0;i<size;i++){
int v=edges[u][i];
if(v==f) continue;
dfs1(v,u,now+1);
res=max(res,depth[v]);
}
if(res+now==mxR)
ans++;
return depth[u]=res+1;
}
void dfs(int u,int f){
depth[u]=depth[f]+1;
int size=edges[u].size();
for(int i=0;i<size;i++){
int v=edges[u][i];
if(v==f) continue;
dfs(v,u);
}
}
int PointsOnDiameter(int n, vector<int>& u, vector<int>& v) {
// write code here
ans=mxR=0;
int len=u.size();
for(int i=0;i<len;i++){
edges[u[i]].push_back(v[i]);
edges[v[i]].push_back(u[i]);
}
memset(depth, 0, sizeof(depth));
dfs(1,0);
int start=1;
for(int i=1;i<=n;i++)
if(depth[i]>depth[start])
start=i;
dfs(start,0);
for(int i=1;i<=n;i++)
mxR=max(mxR,depth[i]);
memset(depth, 0, sizeof(depth));
dfs1(start,0,1);
return ans;
}
};
然而只能过91%的样例,卡了一晚上不知道错在什么地方【最坑的是发现比赛中许多人竟然用这种方法过了,可能后期增加数据了,但比赛没有重判!】
睡了个午觉突然发现了一个卡掉我方法的样例,其示意图如下:
按道理说所有点的都在直径上的,然而我的方法一定会遗漏一个点!【这点留给读者想】
【AC】进一步改善代码,我们从直径的一个端点走到直径的中间,然后从中间为起点进行dfs,并找到所有能走到深度为直径的二分之一的点,并统计数量即可。
class Solution {
public:
/**
* 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
*
* @param n int整型 节点个数
* @param u int整型vector
* @param v int整型vector
* @return int整型
*/
int depth[100005];
int f[100005],g[100005];
vector<int> edges[100005];
void dfs(int u,int fa){
depth[u]=depth[fa]+1;
f[u]=fa;
g[u]=depth[u];
int size=edges[u].size();
for(int i=0;i<size;i++){
int v=edges[u][i];
if(v==fa) continue;
dfs(v,u);
g[u]=max(g[u], g[v]);
}
}
int PointsOnDiameter(int n, vector<int>& u, vector<int>& v) {
// write code here
int ans=0;
for(int i=0;i<n-1;i++){
edges[u[i]].push_back(v[i]);
edges[v[i]].push_back(u[i]);
}
dfs(1,0);
int start=1;
for(int i=1;i<=n;i++)
if(depth[i]>depth[start])
start=i;
dfs(start,0);
for(int i=1;i<=n;i++)
if(depth[i]>depth[start])
start=i;
int d=depth[start];
if(d%2==0){
for(int i=1;i<d/2;i++)
start=f[start];
depth[f[start]]=0;
dfs(start,f[start]);
depth[start]=0;
dfs(f[start],start);
for(int i=1;i<=n;i++)
if(g[i]==d/2)
ans++;
}
else{
for(int i=1;i<=d/2;i++)
start=f[start];
dfs(start,0);
for(int i=1;i<=n;i++)
if(g[i]==d/2+1)
ans++;
}
return ans;
}
};