c
一看这么多修改询问就可以猜到是线段树,但是我们没有办法直接搞线段树,于是我们对树做一次dfs生成一个2*N的序列,L[u],R[u]分别表示刚进入子树u时的dfs序列位置和刚处理完子树u时的dfs序列位置。则l[u]到r[u]这一段便是子树u的所有节点。我们对每个节点记录两个值cnt,k,每个节点v的实际值w = cnt *dep[v] + k,则原操作add u x可以看成是对每个子树中的cnt值+1,k值+ x - dep[u],这样便可以用线段树对DFS序列进行区间操作了。
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <cmath>
#include <algorithm>
using namespace std;
typedef long long ll;
const int MaxN = 50005;
const int MaxE = MaxN * 2;
const int MaxS = MaxN * 4;
struct Edge
{
int to;
Edge* next;
} memo[MaxE], *cur, *g[MaxN];
int dep[MaxN], l[MaxN], r[MaxN], N, K;
int ax[MaxN];
void init()
{
for (int i = 1; i <= N; i++)
g[i] = NULL;
cur = memo;
K = 0;
}
void add_edge(int u, int v)
{
cur->to = v;
cur->next = g[u];
g[u] = cur++;
}
void DFS(int u, int f)
{
int v;
l[u] = ++K;
ax[K] = u;
dep[u] = dep[f] + 1;
for (Edge* it = g[u]; it; it = it->next)
{
v = it->to;
if (v != f)
DFS(v, u);
}
r[u] = K;
}
struct Seg
{
int l, r;
ll sumK, delK;
ll base, sumD, delCnt;
} seg[MaxS];
void insK(int k, ll v)
{
seg[k].sumK += v * (seg[k].r - seg[k].l + 1);
seg[k].delK += v;
}
void insCnt(int k, ll v)
{
seg[k].delCnt += v;
seg[k].sumD += seg[k].base * v;
}
void update(int k)
{
seg[k].sumD = seg[k << 1].sumD + seg[k << 1 | 1].sumD;
seg[k].sumK = seg[k << 1].sumK + seg[k << 1 | 1].sumK;
}
void pushdown(int k)
{
if(seg[k].delK)
{
insK(k << 1, seg[k].delK);
insK(k << 1 | 1, seg[k].delK);
seg[k].delK = 0;
}
if(seg[k].delCnt)
{
insCnt(k << 1, seg[k].delCnt);
insCnt(k << 1 | 1, seg[k].delCnt);
seg[k].delCnt = 0;
}
}
void init(int k, int l, int r)
{
seg[k].l = l;
seg[k].r = r;
seg[k].sumK = seg[k].delK = 0;
seg[k].sumD = seg[k].delCnt = 0;
if(l == r)
{
seg[k].base = dep[ax[l]];
return;
}
int mid = (l + r) >> 1;
init(k << 1, l, mid);
init(k << 1 | 1, mid + 1, r);
seg[k].base = seg[k << 1].base + seg[k << 1 | 1].base;
}
void addK(int k, int l, int r, ll v)
{
if(seg[k].l > r || seg[k].r < l) return;
if(seg[k].l >= l && seg[k].r <= r)
{
insK(k, v);
return;
}
pushdown(k);
addK(k << 1, l, r, v);
addK(k << 1 | 1, l, r, v);
update(k);
}
void addCnt(int k, int l, int r, ll v)
{
if(seg[k].l > r || seg[k].r < l) return;
if(seg[k].l >= l && seg[k].r <= r)
{
insCnt(k, v);
return;
}
pushdown(k);
addCnt(k << 1, l, r, v);
addCnt(k << 1 | 1, l, r, v);
update(k);
}
ll read(int k, int l, int r)
{
if(seg[k].l > r || seg[k].r < l)
return 0;
if(seg[k].l >= l && seg[k].r <= r)
return seg[k].sumK + seg[k].sumD;
pushdown(k);
return read(k << 1, l, r) + read(k << 1 | 1, l, r);
}
int main()
{
int T, x;
int Q, u, cas = 1;
char op[5];
freopen("c.in", "r", stdin);
freopen("c.out", "w", stdout);
scanf("%d%d",&N,&Q);
init();
for (int i = 2; i <= N; i++)
{
scanf("%d",&x);
add_edge(x, i);
add_edge(i, x);
}
dep[0] = -1;
DFS(1, 0);
init(1, 1, N);
while (Q--)
{
scanf("%s",op);
if (op[0] == 'Q')
{
scanf("%d",&u);
printf("%I64d\n", read(1, l[u], r[u]));
}
else
{
scanf("%d%d",&u,&x);
addK(1, l[u], r[u], x - dep[u]);
addCnt(1, l[u], r[u], 1);
}
}
return 0;
}