题目大意
给出
N
N
个节点的一棵树,每个节点有开和关两种状态,且每个点有权为,给出
M
M
个操作,每次改变一个节点的状态(给出负数则关,正数则开),求每次操作过后,有多少个节点满足,自己处于开状态,且子树内处于关状态的节点数严格大于该节点的.
N,M<=100000
N
,
M
<=
100000
时限五秒
Analysis
考虑到每次修改,只会影响到其到根节点的节点对答案的贡献,每一次操作可以看成对一个节点的 ti t i 加减,求的是有多少个节点的 ti<0 t i < 0 。这个东西用数据结构及其难维护。再分析一下,每一次操作影响到的点仅是一部分点,很多点的贡献未变。那么是否可以一次计算多次操作的答案。我们可以把操作分块,每 M−−√ M 个操作放到一起做。具体做法是:考虑修改的节点提出来,建成虚树,那么虚树上每个儿子到父亲的边上的原树的点,对于一次影响它们的操作,他们的变化量是相同的。那么可以把它们拉出来排个序,去重,每条边维护一个指针,指向它们的零分界线,就可以算出答案(注意点的开关状态),修改点的答案要单独算。复杂度约为 O(NN−−√) O ( N N ) ,常数比较大,排序的时候用基数排序才能达到准确的复杂度,由于偷懒,用了快排。
代码
# include<cstdio>
# include<cstring>
# include<algorithm>
# include<cmath>
# include<vector>
using namespace std;
const int N = 1e5 + 5;
int sta[N],dfn[N],fa[N][25],dep[N],p[N][3];
int t[N],b[N],a[N],e[N],vis[N],col[N],out[N];
int st[N],to[N << 1],nx[N << 1],ans[N],q[N],f[N];
int n,m,tot,len,cnt,h;
vector <int> inl[N];
bool cmp(int x,int y) { return dfn[x] < dfn[y]; }
bool cmp1(int x,int y) { return t[x] < t[y]; }
void add(int u,int v)
{
to[++tot] = v,nx[tot] = st[u],st[u] = tot;
to[++tot] = u,nx[tot] = st[v],st[v] = tot;
}
void dfs(int u)
{
dfn[u] = ++tot,dep[u] = dep[fa[u][0]] + 1;
for (int i = 1 ; i <= 18 ; ++i) fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = st[u] ; i ; i = nx[i])
if (to[i] != fa[u][0]) dfs(to[i]);
out[u] = tot;
}
inline int getlca(int x,int y)
{
if (dep[x] < dep[y]) swap(x,y);
for (int i = 18 ; ~i ; --i)
if (dep[fa[x][i]] >= dep[y]) x = fa[x][i];
if (x == y) return x;
for (int i = 18 ; ~i ; --i)
if (fa[x][i] != fa[y][i]) x = fa[x][i],y = fa[y][i];
return fa[x][0];
}
inline void build(int l,int r)
{
memset(f,0,sizeof(f)),memset(col,0,sizeof(col));
int top = 0; b[h = 1] = 1;
for (int i = l ; i <= r ; ++i) b[++h] = q[i];
sort(b + 1,b + h + 1); h = unique(b + 1,b + h + 1) - b - 1;
sort(b + 1,b + h + 1,cmp);
int h1 = h;
for (int i = 1 ; i <= h; ++i)
{
if (!top) { sta[++top] = b[i]; continue; }
int x = b[i],lca = getlca(x,sta[top]);
while (dfn[lca] < dfn[sta[top]])
{
if (dfn[sta[top - 1]] <= dfn[lca])
{
f[sta[top--]] = lca;
if (sta[top] != lca) sta[++top] = lca,b[++h1] = lca;
break;
}else f[sta[top]] = sta[top - 1],--top;
}
sta[++top] = x;
}
while (top > 1) f[sta[top]] = sta[top - 1],--top;
h = h1;
sort(b + 1,b + h + 1,cmp);
for (int i = 1 ; i <= h ; ++i) col[b[i]] = 1;
}
inline void solve(int l,int r)
{
for (int i = 1 ; i <= n ; ++i) a[i] = i,e[i] = 0;
sort(a + 1,a + n + 1,cmp1);
for (int i = 1 ; i <= h ; ++i)
for (int j = fa[b[i]][0] ; j != f[b[i]] ; j = fa[j][0])
e[j] = b[i];
int ret = 0;
for (int i = 1 ; i <= n ; ++i)
if (e[a[i]]) inl[e[a[i]]].push_back(a[i]);
else ret += (!vis[a[i]] && !col[a[i]] && t[a[i]] < 0);
for (int i = l ; i <= r ; ++i) ans[i] = ret;
for (int i = 1 ; i <= h ; ++i)
{
int ed = 0,x = b[i];
for (int j = 0 ; j < inl[x].size() ; ++j)
if (ed && t[inl[x][j]] == t[inl[x][j - 1]]) p[ed][1] += (!vis[inl[x][j]]);
else p[++ed][0] = inl[x][j],p[ed][1] = (!vis[inl[x][j]]);
int p1 = 0,tag = 0,sum = 0;
while (p1 < ed && t[p[p1 + 1][0]] < 0) ++p1,sum += p[p1][1];
for (int j = l ; j <= r ; ++j)
{
vis[q[j]] ^= 1;
if (dfn[q[j]] >= dfn[x] && dfn[q[j]] <= out[x])
{
t[x] += (vis[q[j]]) ? -1 : 1;
if (vis[q[j]]) --tag;
else ++tag;
}
while (p1 && t[p[p1][0]] + tag >= 0) sum -= p[p1][1],--p1;
while (p1 < ed && t[p[p1 + 1][0]] + tag < 0) ++p1,sum += p[p1][1];
ans[j] += sum;
if (t[x] < 0 && !vis[x]) ++ans[j];
}
for (int j = l ; j <= r ; ++j) vis[q[j]] ^= 1;
for (int j = 0 ; j < inl[x].size() ; ++j) t[inl[x][j]] += tag;
}
for (int i = l ; i <= r ; ++i) vis[q[i]] ^= 1;
for (int i = 1 ; i <= h ; ++i) inl[b[i]].clear();
}
int main()
{
scanf("%d%d",&n,&m); len = (int)sqrt(m * 15);
for (int i = 2 ; i <= n ; ++i) scanf("%d",&fa[i][0]),add(i,fa[i][0]);
for (int i = 1 ; i <= n ; ++i) scanf("%d",&t[i]);
for (int i = 1 ; i <= m ; ++i) scanf("%d",&q[i]),q[i] = abs(q[i]);
tot = 0,dfs(1);
cnt = m / len + ((m % len) ? 1 : 0);
for (int i = 1 ; i <= cnt ; ++i)
{
int l = (i - 1) * len + 1,r = min(m,i * len);
build(l,r);
solve(l,r);
}
for (int i = 1 ; i <= m ; ++i) printf("%d ",ans[i]);
return 0;
}