题解:
定义二元组
(x,d)
(
x
,
d
)
为一个选择
x
x
为关键点, 为最大距离的染色方案。
先考虑全为黑色点的做法:
我们设
f(x,d)
f
(
x
,
d
)
表示与
x
x
距离不超过 的点的集合。不妨假设
f(x,d)
f
(
x
,
d
)
不是全集,这样答案就是所有本质不同的集合 f(x,d) 个数再加上 1(因为全集总是一种合法的染色方案)。
如果对于每个
x
x
直接求出有多少 的取值并累加显然是不对的,因为这样会有本质相同的方案被重复计算。
我们考虑对于每个点对
(i,j)
(
i
,
j
)
,染色方案
(i,d1)
(
i
,
d
1
)
和
(j,d2)
(
j
,
d
2
)
何时会重复:设
S=f(i,d1)=f(j,d2)
S
=
f
(
i
,
d
1
)
=
f
(
j
,
d
2
)
,那么对于任意
i
i
到 路径上的节点
k
k
均满足 ,其中
dist(i,k)
d
i
s
t
(
i
,
k
)
表示
i
i
在树上与 的距离。故我们找到路径上最接近
i
i
的节点(即与 相邻)的节点
k
k
,只需要关心这些节点即可。
我们试图计算有多少 满足
f(i,d)
f
(
i
,
d
)
不与任何
f(j,d2)
f
(
j
,
d
2
)
重复,据题意得:
1、
f(i,d)
f
(
i
,
d
)
不为全集。
2、对于任意与
i
i
相邻的 ,
f(i,d)≠f(k,d−1)
f
(
i
,
d
)
≠
f
(
k
,
d
−
1
)
。
第1种条件的解决方法很简单,我们对于每个节点
i
i
求出 到树上最远节点的距离
disi
d
i
s
i
,那么
d∈[0,disi−1]
d
∈
[
0
,
d
i
s
i
−
1
]
。
第2种条件可以等价于方案
(i,d)
(
i
,
d
)
中以
i
i
为根的树除了 所在子树中的节点之外均被染黑而方案
(k,d−1)
(
k
,
d
−
1
)
中这些节点也同样被染黑,那么我们对于每条边
(i,k)
(
i
,
k
)
求出从
i
i
出发不经过 到树上最远节点的距离
dis2(i,k)
d
i
s
2
(
i
,
k
)
,则
d∈[0,dis2i,k+1]。
d
∈
[
0
,
d
i
s
2
i
,
k
+
1
]
。
disi
d
i
s
i
和
dis2(i,k)
d
i
s
2
(
i
,
k
)
的计算均可以通过简单的树形DP在
O(N)
O
(
N
)
的时间内解决。
现在有些节点不能被直接被选定,我们需要重新考虑能成为合法染色方案的
f(i,d)
f
(
i
,
d
)
需要满足什么条件:
1、若
i
i
是一个特殊节点,则 可以成为任意自然数。
2、若
i
i
不是一个特殊节点,考虑将 作为整棵树的根,那么若存在一个特殊节点
j
j
满足方案
(i,d)
(
i
,
d
)
中
j
j
所在子树内所有节点均被染成黑色, 就是一个合法的染色方案。故我们只需要求出从
i
i
出发的至少经过 个特殊节点到达
j
j
子树中的最远节点的距离的最小值,就是可行的
d
d
的最小值。这也可以通过简单的树形DP在
O(N)
O
(
N
)
的时间内解决。
对于每个节点
i
i
求出
d
d
的下界后,问题就转化成了上一个问题。
#include <bits/stdc++.h>
using namespace std;
inline int rd() {
char ch=getchar(); int i=0,f=1;
while(!isdigit(ch)) {if(ch=='-')f=-1; ch=getchar();}
while(isdigit(ch)) {i=(i<<1)+(i<<3)+ch-'0'; ch=getchar();}
return i*f;
}
const int N=2e5+50,INF=0x3f3f3f3f;
int n,col[N],fa[N],sze[N];
int d1[N]; //子树最远点
int d2[N]; //父亲最远点
int d3[N]; //经过存在黑点的子树最近点
int d4[N]; //经过存在黑点的父亲最近点
int d5[N]; //不经过某子树最远点
int d6[N]; //下界
vector <int> edge[N];
inline void dfs1(int x,int f) {
fa[x]=f; sze[x]=col[x];
if((x!=1 && edge[x].size()==1) || (x==1 && !edge[x].size())) {
d1[x]=0;
d3[x]=(col[x]?0:INF);
return;
}
int mx=0,mx2=INF;
for(auto v:edge[x]) {
if(v==f) continue;
dfs1(v,x); sze[x]+=sze[v];
mx=max(mx,d1[v]);
if(d3[v]<INF) mx2=min(mx2,d1[v]);
}
d1[x]=mx+1; d3[x]=mx2+1;
if(col[x]) d3[x]=min(d3[x],d1[x]);
}
inline void dfs2(int x,int f) {
d2[x]=max(d2[x],(f!=0)+d2[f]);
if(f) d5[x]=max(d5[x],d2[x]-1);
int mx,mx2;
mx=-1;
for(int e=0;e<edge[x].size();++e) {
int v=edge[x][e]; if(v==f) continue;
d2[v]=max(d2[v],mx+2);
mx=max(mx,d1[v]);
}
mx=-1;
for(int e=edge[x].size()-1;e>=0;e--) {
int v=edge[x][e]; if(v==f) continue;
d2[v]=max(d2[v],mx+2);
mx=max(mx,d1[v]);
}
for(auto v:edge[x]) if(v!=f) dfs2(v,x);
}
int main() {
n=rd();
memset(d3,0x3f,sizeof(d3));
memset(d4,0x3f,sizeof(d4));
for(int i=1;i<n;i++) {
int x=rd(), y=rd();
edge[x].push_back(y);
edge[y].push_back(x);
}
string s; cin>>s; int bz=0;
for(int i=0;i<n;i++) col[i+1]=s[i]-'0', bz|=col[i+1];
if(!bz) {puts("0"); return 0;}
dfs1(1,0);
dfs2(1,0);
long long ans=0;
for(int i=1;i<=n;i++) {
if(col[i]) d6[i]=0;
else d6[i]=min(d3[i],(sze[i]==sze[1])?INF:d2[i]);
int u=max(d1[i],d2[i])-1;
for(auto v:edge[i]) {
if(v==fa[i]) u=min(u,d1[i]+1);
else u=min(u,d5[v]+1);
}
if(u>=d6[i]) ans=ans+u-d6[i]+1;
}
cout<<ans+1;
}