题意:
有一颗有n个结点的树,树上存在一个污染源(位置不确定),它可以污染与它距离不超过d的节点,现给出m个被污染的节点(污染源本身也可能是被污染的节点),求污染源可能的位置数。
分析:
解法一:
树的最长链问题,还是老套路,不过最长链求得是最远的节点,这个题换成了最远的被污染源,状态方程一样,
表示以第 i 个节点为根节点的子树距离 i 的最远污染源
表示以第 i 个节点为根节点的子树距离 i 的次远污染源
表示 i 经过 i 的父节点到最远污染源的距离
还有就是注意细节。
解法二:
计算两个相距最远的被污染体,可以容易想到如果污染源能将这两个污染,那么其他的肯定也能被污染。所以类似于计算树的直接直径的方法计算两个最远的污染点到每个点的距离即可。
代码实现:
解法一:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
int head[maxn];
int tot;
int n,m,d;
bool vis[maxn];
int dp[maxn][3];
struct Edge{
int to;
int next;
}edge[maxn<<1];
void addedge(int u,int v){
edge[tot].to=v;
edge[tot].next=head[u];
head[u]=tot++;
}
void init(){
tot=0;
memset(head,-1,sizeof head);
memset(vis,0,sizeof vis);
}
void dfs1(int u,int fa){
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].to;
if(v==fa) continue;
if(vis[v]) dp[v][0]=0;
dfs1(v,u);
if(dp[v][0]>=0){
if(dp[v][0]+1>=dp[u][0]){
dp[u][1]=dp[u][0];
dp[u][0]=dp[v][0]+1;
}else if(dp[v][0]+1>dp[u][1]){
dp[u][1]=dp[v][0]+1;
}
}
}
}
void dfs2(int u,int fa){
if(vis[u]) dp[u][2]=max(0,dp[u][2]);
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].to;
if(v==fa) continue;
int tmp;
if(dp[u][0]==dp[v][0]+1){
tmp=max(dp[u][1],dp[u][2]);
}else{
tmp=max(dp[u][0],dp[u][2]);
}
//答案合法。
if(tmp>=0) dp[v][2]=tmp+1;
dfs2(v,u);
}
}
int main(){
init();
scanf("%d%d%d",&n,&m,&d);
for(int i=0;i<m;i++){
int a;scanf("%d",&a);
vis[a]=1;
}
for(int i=0;i<n-1;i++){
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
for(int i=1;i<=n;i++){
dp[i][0]=dp[i][1]=dp[i][2]=-111;
}
dfs1(1,1);
dfs2(1,1);
int ans=0;
for(int i=1;i<=n;i++){
if(max(dp[i][0],dp[i][2])<=d) ans++;
}
printf("%d\n",ans);
return 0;
}
解法二:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
int n,m,d;
int dis1[maxn],dis2[maxn];
int head[maxn],tot;
int p[maxn];
struct Edge{
int to,next;
}edge[maxn<<1];
void init(){
memset(head,-1,sizeof head);
tot=0;
}
void addedge(int u,int v){
edge[tot].to=v;
edge[tot].next=head[u];
head[u]=tot++;
}
void dfs(int u,int fa,int dep,int dis[]){
dis[u]=dep;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].to;
if(v==fa) continue;
dfs(v,u,dep+1,dis);
}
}
int main(){
init();
scanf("%d%d%d",&n,&m,&d);
for(int i=0;i<m;i++){
scanf("%d",&p[i]);
}
for(int i=0;i<n-1;i++){
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
dfs(1,1,0,dis1);
int a=p[0];
for(int i=1;i<m;i++){
if(dis1[p[i]]>dis1[a]) a=p[i];
}
dfs(a,a,0,dis1);
int b=p[0];
for(int i=1;i<m;i++){
if(dis1[p[i]]>dis1[b]) b=p[i];
}
dfs(b,b,0,dis2);
int ans=0;
for(int i=1;i<=n;i++){
if(dis1[i]<=d&&dis2[i]<=d) ans++;
}
printf("%d\n",ans);
return 0;
}