题意:有一颗有n个结点的树,树上存在一个污染源(位置不确定),它可以污染与它距离不超过d的节点,现给出m个被污染的节点(污染源本身也可能是被污染的节点),求污染源可能的位置数。
解题思路:很明显如果一个点能到最远的污染源那么其他都能到达。现在题目就转变成对每一个点求距离自己最远的污染点这就类似HDU2916[求所有点能到的最远距离]:不过这里要特判一下下面是否有污染点才能更新结果
#include <iostream>
#include <cstdio>
#include <stack>
#include <sstream>
#include <limits.h>
#include <vector>
#include <map>
#include <cstring>
#include <deque>
#include <cmath>
#include <iomanip>
#include <unordered_map>
#include <queue>
#include <algorithm>
#include <set>
#define mid ((l + r) >> 1)
#define Lson rt << 1, l , mid
#define Rson rt << 1|1, mid + 1, r
#define ms(a,al) memset(a,al,sizeof(a))
#define log2(a) log(a)/log(2)
#define _for(i,a,b) for( int i = (a); i < (b); ++i)
#define _rep(i,a,b) for( int i = (a); i <= (b); ++i)
#define for_(i,a,b) for( int i = (a); i >= (b); -- i)
#define rep_(i,a,b) for( int i = (a); i > (b); -- i)
#define lowbit(x) ((-x) & x)
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define INF 0x3f3f3f3f
#define LLF 0x3f3f3f3f3f3f3f3f
#define hash Hash
#define next Next
#define pb push_back
#define f first
#define s second
#define y1 Y
using namespace std;
const int N = 1e7 + 10, MOD = 1e9 + 7;
const int maxn = 4e5 + 10;
const long double eps = 1e-5;
const int EPS = 500 * 500;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
typedef pair<ll,ll> PLL;
typedef pair<double,double> PDD;
template<typename T> void read(T &x)
{
x = 0;char ch = getchar();ll f = 1;
while(!isdigit(ch)){if(ch == '-')f*=-1;ch=getchar();}
while(isdigit(ch)){x = x*10+ch-48;ch=getchar();}x*=f;
}
template<typename T, typename... Args> void read(T &first, Args& ... args)
{
read(first);
read(args...);
}
vector<int>G[maxn << 1];
int down[maxn << 1][2], up[maxn << 1][2];
//0最长路,1是次长路
int n, m, k;
bool goal[maxn];
int depth[maxn], poi;
vector<int> ans[maxn];
inline void dfs1(int u, int fa)//第一次是所有节点向下搜索的遇到的最大的距离
{
if(goal[u])
down[u][0] = down[u][1] = 0;
for(auto it : G[u])
{
if(it == fa) continue;
dfs1(it,u);
if(down[it][0] == -1) continue;
if(down[it][0] + 1 >= down[u][0])
{
down[u][1] = down[u][0];
down[u][0] = down[it][0] + 1;
}
else if(down[it][0] + 1 > down[u][1])
down[u][1] = down[it][0] + 1;
}
}
inline void dfs2(int u, int fa)
{
if(down[fa][0] != -1)
{
if(down[u][0] + 1 != down[fa][0])
{
if(down[u][0] < down[fa][0] + 1)
{
down[u][1] = down[u][0];
down[u][0] = down[fa][0] + 1;
}
else if(down[fa][0] + 1 > down[u][1])
down[u][1] = down[fa][0] + 1;
}
else
{
if(down[fa][1] != -1)
{
if(down[fa][1] + 1 > down[u][0])
{
down[u][1] = down[u][0];
down[u][0] = down[fa][1] + 1;
}
else if(down[fa][1] + 1 > down[u][1])
down[u][1] = down[fa][1] + 1;
}
}
}
for(auto it : G[u])
{
if(it == fa) continue;
dfs2(it,u);
}
}
int main()
{
ms(down,-1);
read(n,m,k);
for(int i = 0; i < m; ++ i)
{
int x;
read(x);
goal[x] = true;
}
for(int i = 1; i < n; ++ i)
{
int l, r;
read(l,r);
G[l].pb(r);
G[r].pb(l);
}
dfs1(1,0);
// for(int i = 1; i <= n; ++ i)
// {
// cout << down[i][0] << " = max " << down[i][1] << endl;
// }
// cout << endl;
for(auto it : G[1])
dfs2(it,1);
int ans = 0;
for(int i = 1; i <= n; ++i)
if(down[i][0] <= k)
ans ++;
cout << ans << endl;
// for(int i = 1; i <= n; ++ i)
// {
// cout << down[i][0] << " = max\n";
// }
return 0;
}