题意:有 n n n 个城市连成一棵树,每个城市有 a i a_i ai 个人。接下来 m m m 天每天会发生 k i k_i ki 次灾难,每个灾难会让一个给定城市的人全部死掉。每个人一天可以走一条边,也可以不动。求最多多少人能活过这 m m m 天。
n ≤ 1 0 6 , ∑ k i ≤ 2 × 1 0 6 n\leq 10^6,\sum k_i\leq 2\times 10^6 n≤106,∑ki≤2×106
倒过来考虑,维护在哪些位置的人可以活到最后,相当于要支持以下操作
- 把一个点染黑。
- 所有白点扩展 1 1 1 步。
对于 2 2 2 操作,一个黑点会变白当且仅当和一个原来的白点相邻。考虑分别维护这个白点是父亲还是儿子。
需要维护的东西:
h u h_u hu 表示真实情况下 u u u 的白儿子个数。
n n n 个 vector 记录每个点 可能 需要被自己更新为白点的所有黑儿子。
封装两个线性结构 q , s q q,sq q,sq ,记录可能会被儿子更新为白点的结点、可能会用来更新黑儿子的结点。
染黑的时候,自己进 q q q,父亲进 s q sq sq,更新父亲的 h h h 和 vector。
然后这一天的行动,把 q q q 和 s q sq sq 一个个取出来判断是否合法并更新自己或 vector 中的儿子,需要染白的记下来稍后处理。
染白的时候,自己进 s q sq sq,父亲进 q q q,更新父亲的 h h h。因为这里更新不了 vector,只能不删,所以不能用 vector 的大小来代替 h h h。
复杂度 O ( n + m + ∑ k i ) O(n+m+\sum k_i) O(n+m+∑ki)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#define MAXN 1000005
using namespace std;
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
typedef long long ll;
vector<int> e[MAXN],lis[MAXN],son[MAXN];
int a[MAXN],col[MAXN],fa[MAXN],h[MAXN];
struct que
{
int a[MAXN],vis[MAXN],cnt;
inline void push(int x){if (!vis[x]) vis[a[++cnt]=x]=1;}
inline int pop(){if (cnt) return vis[a[cnt]]=0,a[cnt--];return 0;}
}q,sq,ans;
void dfs(int u,int f){fa[u]=f;for (int i=0;i<(int)e[u].size();i++) if (e[u][i]!=f) ++h[u],dfs(e[u][i],u);}
inline void die(int u)
{
if (col[u]) return;
col[u]=1;
q.push(u);
if (fa[u])
{
son[fa[u]].push_back(u),--h[fa[u]];
sq.push(fa[u]);
}
}
inline void live(int u)
{
if (!col[u]) return;
col[u]=0,sq.push(u);
if (fa[u]) ++h[fa[u]],q.push(fa[u]);
}
int main()
{
int n,m;
n=read(),m=read();
for (int i=1;i<=n;i++) a[i]=read();
for (int i=1;i<n;i++)
{
int u,v;
u=read(),v=read();
e[u].push_back(v),e[v].push_back(u);
}
dfs(1,0);
for (int i=1;i<=m;i++)
{
lis[i].resize(read());
for (int j=0;j<(int)lis[i].size();j++) lis[i][j]=read();
}
for (int T=m;T>=1;T--)
{
for (int i=0;i<(int)lis[T].size();i++) die(lis[T][i]);
for (int u=q.pop();u;u=q.pop()) if (h[u]) ans.push(u);
for (int u=sq.pop();u;u=sq.pop())
if (!col[u])
{
for (int i=0;i<(int)son[u].size();i++)
ans.push(son[u][i]);
son[u].clear();
}
for (int u=ans.pop();u;u=ans.pop()) live(u);
}
ll ans=0;
for (int i=1;i<=n;i++) if (!col[i]) ans+=a[i];
cout<<ans;
return 0;
}