BZOJ 3531: [SDOI2014] 旅行 (Travel) — 动态开点线段树 (one dynamically allocated segment tree per religion) + 树链剖分 (heavy-light decomposition)
Code:
#include <bits/stdc++.h>
// Debug helper: redirect stdin from "<s>.in". Kept as a macro because it relies
// on string-literal concatenation.
#define setIO(s) freopen(s".in","r",stdin)
// 64-bit integer type for ratings and aggregated path sums.
typedef long long ll;
// Magnitude of the "no value" sentinel; -inf is what an empty range reports
// for max queries.
const ll inf = 100000000000LL;
// Capacity of the per-vertex / per-religion arrays (also scales the node pool).
const int maxn = 500000;
// Every segment tree covers positions [1, N] (positions are DFS timestamps).
const int N = 200003;
using namespace std;
namespace Seg
{
#define lson t[x].l
#define rson t[x].r
int n, tot;
struct Node
{
int l, r;
ll sumv, maxv;
}t[maxn << 2];
void pushup(int x)
{
t[x].sumv = t[lson].sumv + t[rson].sumv;
t[x].maxv = max(t[lson].maxv, t[rson].maxv);
}
void ins(int &x, int l, int r, int p, ll v)
{
if(!x) x = ++ tot;
if(l == r)
{
t[x].sumv = t[x].maxv = v;
return ;
}
int mid = (l + r) >> 1;
if(p <= mid) ins(lson, l, mid, p, v);
else ins(rson, mid + 1, r, p, v);
pushup(x);
}
void del(int x, int l, int r, int p)
{
if(l == r)
{
t[x].sumv = t[x].maxv = 0;
return ;
}
int mid = (l + r) >> 1;
if(p <= mid) del(lson, l, mid, p);
else del(rson, mid + 1, r, p);
pushup(x);
}
ll query_sum(int l, int r, int x, int L, int R)
{
if(!x) return 0;
if(l >= L && r <= R) return t[x].sumv;
ll tmp = 0;
int mid = (l + r) >> 1;
if(L <= mid) tmp += query_sum(l, mid, lson, L, R);
if(R > mid) tmp += query_sum(mid + 1, r, rson, L, R);
return tmp;
}
ll query_max(int l, int r, int x, int L, int R)
{
if(!x) return -inf;
if(l >= L && r <= R) return t[x].maxv;
ll tmp = -inf;
int mid = (l + r) >> 1;
if(L <= mid) tmp = max(tmp, query_max(l, mid, lson, L, R));
if(R > mid) tmp = max(tmp, query_max(mid + 1, r, rson, L, R));
return tmp;
}
#undef lson
#undef rson
};
// Command token for each query; main dispatches on its second character.
char str[10];
// n: vertex count, Q: query count, edges: last edge id, tim: DFS timestamp counter.
int n, Q, edges, tim;
// Adjacency list: hd[u] = first edge of u, to[e] = edge target, nex[e] = next edge.
int hd[maxn], to[maxn << 1], nex[maxn << 1];
// Per-vertex input: rating W and religion C.
int W[maxn], C[maxn];
// First-pass HLD data: parent, depth, subtree size, heavy child.
int fa[maxn], dep[maxn], siz[maxn], hson[maxn];
// Second-pass HLD data: chain head, timestamp, vertex-at-timestamp, chain tail.
int top[maxn], dfn[maxn], ln[maxn], bot[maxn];
// rt[c] = root node of religion c's segment tree.
int rt[maxn];
// Appends the directed edge u -> v to the front of u's adjacency list.
void add(int u, int v)
{
    ++edges;
    to[edges] = v;
    nex[edges] = hd[u];
    hd[u] = edges;
}
// First HLD pass over the subtree of u (whose parent is ff): records parent,
// depth and subtree size, and picks hson[u] as the child with the largest
// subtree (the first such child on ties).
void dfs1(int u, int ff)
{
    fa[u] = ff;
    dep[u] = dep[ff] + 1;
    siz[u] = 1;
    for(int e = hd[u]; e != 0; e = nex[e])
    {
        const int child = to[e];
        if(child == ff)
            continue;
        dfs1(child, u);
        siz[u] += siz[child];
        if(siz[child] > siz[hson[u]])
            hson[u] = child;
    }
}
// Second HLD pass. It visits the heavy child first, so each heavy chain gets
// consecutive timestamps. It records the chain head (top) and chain tail (bot)
// of every vertex, and inserts W[u] into religion C[u]'s segment tree at
// position dfn[u].
void dfs2(int u, int tp)
{
    dfn[u] = ++tim;
    ln[tim] = u;
    top[u] = tp;
    Seg :: ins(rt[C[u]], 1, N, dfn[u], 1ll * W[u]);
    if(!hson[u])
    {
        bot[u] = u;              // no children: the chain ends at u
    }
    else
    {
        dfs2(hson[u], tp);       // the heavy child continues the current chain
        bot[u] = bot[hson[u]];
    }
    for(int e = hd[u]; e; e = nex[e])
    {
        const int child = to[e];
        if(child != fa[u] && child != hson[u])
            dfs2(child, child);  // every light child starts a new chain
    }
}
// Sum of ratings of the vertices on the tree path x..y whose religion is C[y].
// It climbs heavy chains and queries religion C[y]'s segment tree over each
// contiguous dfn range.
ll _query_sum(int x, int y)
{
    const int religion = C[y];
    ll total = 0;
    while(top[x] != top[y])
    {
        // Always advance the endpoint whose chain head is deeper (kept in y).
        if(dep[top[y]] < dep[top[x]])
            swap(x, y);
        total += Seg :: query_sum(1, N, rt[religion], dfn[top[y]], dfn[y]);
        y = fa[top[y]];
    }
    // Both endpoints are on one chain; the shallower one has the smaller dfn.
    int lo = x, hi = y;
    if(dep[lo] > dep[hi])
        swap(lo, hi);
    total += Seg :: query_sum(1, N, rt[religion], dfn[lo], dfn[hi]);
    return total;
}
// Maximum rating among the vertices on the tree path x..y whose religion is C[y].
// The accumulator starts from -inf, the same "no value" sentinel that
// Seg::query_max returns for empty ranges. Starting from 0 would clamp the
// result at 0 and report a wrong maximum whenever every matching rating is
// zero or negative.
ll _query_max(int x, int y)
{
    int ty = C[y];
    ll tmp = -inf;
    while(top[x] != top[y])
    {
        // Advance the endpoint whose chain head is deeper (kept in y).
        if(dep[top[x]] > dep[top[y]]) swap(x, y);
        tmp = max(tmp, Seg :: query_max(1, N, rt[ty], dfn[top[y]], dfn[y]));
        y = fa[top[y]];
    }
    // Same chain: query from the shallower endpoint down to the deeper one.
    if(dep[x] > dep[y]) swap(x, y);
    tmp = max(tmp, Seg :: query_max(1, N, rt[ty], dfn[x], dfn[y]));
    return tmp;
}
// Reads the tree (ratings W, religions C, n - 1 edges), builds the HLD and the
// per-religion segment trees, then processes Q commands. The second character
// of each command token selects the action:
//   'C': set C[x] = c    'W': set W[x] = w
//   'S': print path sum  'M': print path max   (both over religion C[y])
int main()
{
    // setIO("input");
    if(scanf("%d%d", &n, &Q) != 2) return 0;
    for(int i = 1; i <= n; ++i) scanf("%d%d", &W[i], &C[i]);
    for(int i = 1, u, v; i < n; ++i)
    {
        scanf("%d%d", &u, &v);
        add(u, v), add(v, u);
    }
    // Node 0 stands in for every absent child in the segment trees; give it the
    // empty-range max so pushup/query_max ignore missing subtrees.
    Seg :: t[0].maxv = -inf;
    dfs1(1, 0), dfs2(1, 1);
    while(Q--)
    {
        // The width-limited read keeps the token inside str[10]. On EOF or bad
        // input, stop instead of replaying the previous command with
        // uninitialised operands.
        if(scanf("%9s", str) != 1) break;
        int x, w, c, y;
        if(str[1] == 'C')
        {
            if(scanf("%d%d", &x, &c) != 2) break;
            // Move x's rating from its old religion's tree to the new one.
            Seg :: del(rt[C[x]], 1, N, dfn[x]);
            C[x] = c;
            Seg :: ins(rt[C[x]], 1, N, dfn[x], 1ll * W[x]);
        }
        else if(str[1] == 'W')
        {
            if(scanf("%d%d", &x, &w) != 2) break;
            W[x] = w;
            Seg :: ins(rt[C[x]], 1, N, dfn[x], 1ll * W[x]);
        }
        else if(str[1] == 'S')
        {
            if(scanf("%d%d", &x, &y) != 2) break;
            printf("%lld\n", _query_sum(x, y));
        }
        else if(str[1] == 'M')
        {
            if(scanf("%d%d", &x, &y) != 2) break;
            printf("%lld\n", _query_max(x, y));
        }
    }
    return 0;
}