题意:
链接:https://ac.nowcoder.com/acm/contest/5278/G
来源:牛客网
Compute 有一棵 n 个点,编号分别为
1
∼
n
1\sim n
1∼n 的树,其中 s 号点为根。
Compute 在树上养了很多松鼠,在第 i 个点上住了
a
i
a_i
ai
个松鼠。
因为某些缘故,它们开始同时向根节点移动,但它们相当不安分,如果在同一个节点上,它们就会打起来,简单地来说以下事件会依序发生:
如果一个节点上有 2 只或 2 只以上的松鼠,他们会打架,然后这个节点上松鼠的数量会减少 1;
根节点的所有松鼠移动到地面,位于地面上的松鼠不会再打架;
所有松鼠同时朝它们的父节点移动。
所有事件各自都在一瞬间完成,直至树上没有松鼠。
现在 Compute 想知道最终有多少只松鼠到达了地面。
题解:
我是没想这道题目,感觉我可能也做不出来。
由于会聚集的只有相同深度的松鼠,那么一层一层做,然后枚举所有会将两拨松鼠聚集的点,这些点是按照dfs序排序之后相邻两个点的lca(这是个结论)。就两个点两个点并查集合并即可,同时减掉路上的节点数。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=4e5+5;
int fa[N][21],dep[N],dfn[N],tim;
ll a[N],sum[N];
struct node{
int to,next;
}e[N*2];
int cnt,head[N];
void add(int x,int y){
e[cnt].to=y;
e[cnt].next=head[x];
head[x]=cnt++;
}
void dfs(int x,int f){
dep[x]=dep[f]+1;
dfn[x]=tim++;
fa[x][0]=f;
for(int i=head[x];~i;i=e[i].next){
int ne=e[i].to;
if(ne==f)continue;
dfs(ne,x);
}
}
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=20;i>=0;i--)
if(dep[fa[x][i]]>=dep[y])
x=fa[x][i];
if(x==y)return x;
for(int i=20;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int p[N];
int cmp(int x,int y){
if(dep[x]==dep[y])
return dfn[x]<dfn[y];
return dep[x]>dep[y];
}
struct Point{
int x,y,l;
bool operator< (const Point& pp)const {
if(dep[l]==dep[pp.l])
return dfn[l]<dfn[pp.l];
return dep[l]>dep[pp.l];
}
}tmp[N];
int f[N];
int finds(int x){return x==f[x]?f[x]:f[x]=finds(f[x]);}
int main()
{
memset(head,-1,sizeof(head));
int n,s,x,y;
scanf("%d%d",&n,&s);
for(int i=1;i<=n;i++)
scanf("%lld",&a[i]);
for(int i=1;i<n;i++)
scanf("%d%d",&x,&y),add(x,y),add(y,x);
dfs(s,0);
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i<=n;i++)
fa[i][j]=fa[fa[i][j-1]][j-1];
int tot=0;
for(int i=1;i<=n;i++)
if(a[i])
p[++tot]=i;
sort(p+1,p+1+tot,cmp);
int i,j=1;
ll ans=0;
for(i=1;i<=tot;i=j){
while(dep[p[i]]==dep[p[j]])j++;
if(j==i+1){
ans+=max(1ll,a[p[i]]-dep[p[i]]);
continue;
}
int all=0;
for(int k=i;k<j-1;k++){
f[p[k]]=p[k],f[p[k+1]]=p[k+1];
sum[p[k]]=a[p[k]],sum[p[k+1]]=a[p[k+1]];
tmp[++all].x=p[k],tmp[all].y=p[k+1];
tmp[all].l=lca(p[k],p[k+1]);
f[tmp[all].l]=tmp[all].l;
sum[tmp[all].l]=0;
}
sort(tmp+1,tmp+1+all);
for(int k=1;k<=all;k++){
int fax=finds(tmp[k].x),fay=finds(tmp[k].y);
if(fax!=tmp[k].l)
sum[tmp[k].l]+=max(1ll,sum[fax]-(dep[fax]-dep[tmp[k].l])),f[fax]=tmp[k].l;
if(fay!=tmp[k].l)
sum[tmp[k].l]+=max(1ll,sum[fay]-(dep[fay]-dep[tmp[k].l])),f[fay]=tmp[k].l;
}
ans+=max(1ll,sum[tmp[all].l]-dep[tmp[all].l]);
}
printf("%lld\n",ans);
return 0;
}