1036: [ZJOI2008]树的统计Count
Time Limit: 10 Sec Memory Limit: 162 MBSubmit: 19818 Solved: 8066
[ Submit][ Status][ Discuss]
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
1
2
2
10
6
5
6
5
16
HINT
Source
【思路】
朴素算法对每个操作都执行一遍深搜,不可取,需要一种能够记住点对点路径或者部分路径的方式。所以对整棵树进行轻重链剖分,把同一条链的节点映射到一个连续区间,再对其使用线段树维护。思想是:每个节点都属于某一条链,每条链都有一个唯一顶端,那么如果两个点具有同一个顶端,则位于同一条链上,那么其间的路径信息便可较快获取。
【代码】
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int MAXN = 30005, INF = 0x3f3f3f3f;
struct edge {
int to, next;
};
struct segment {
int left, right, mid;
int sum, mx;
};
int n, q, cnt, tot;
int head[MAXN], p[MAXN], in[MAXN], id[MAXN],fa[MAXN], top[MAXN], sz[MAXN], max_son[MAXN], deep[MAXN];
edge e[MAXN << 1];
segment tree[MAXN << 2];
void addedge(int from, int to)
{
++cnt;
e[cnt].to = to;
e[cnt].next = head[from];
head[from] = cnt;
}
void dfs_1(int u, int father, int depth)
{
deep[u] = depth;
fa[u] = father; max_son[u] = 0;
sz[u] = 1;
for (int i = head[u]; i != 0; i = e[i].next) {
int v = e[i].to;
if (v == fa[u]) continue;
dfs_1(v, u, depth + 1);
sz[u] += sz[v];
if (sz[max_son[u]] < sz[v]) max_son[u] = v;
}
}
void dfs_2(int u, int tp)
{
in[u] = ++tot;
id[tot] = u;
top[u] = tp;
if (max_son[u] != 0) dfs_2(max_son[u], tp);
for (int i = head[u]; i != 0; i = e[i].next) {
int v = e[i].to;
if (v == fa[u] || v == max_son[u]) continue;
dfs_2(v, v);
}
}
void build(int left, int right, int root)
{
tree[root].left = left;
tree[root].right = right;
tree[root].mid = (left + right) >> 1;
if (left == right) {
tree[root].mx = tree[root].sum = p[id[left]];
return;
}
build(left, tree[root].mid, root << 1);
build(tree[root].mid + 1, right, root << 1 | 1);
tree[root].sum = tree[root << 1].sum + tree[root << 1 | 1].sum;
tree[root].mx = max(tree[root << 1].mx, tree[root << 1 | 1].mx);
}
void modify(int index, int num, int root)
{
if (tree[root].left == tree[root].right) {
tree[root].mx = tree[root].sum = num;
return;
}
if (index <= tree[root].mid) modify(index, num, root << 1);
if (index >= tree[root].mid + 1) modify(index, num, root << 1 | 1);
tree[root].mx = max(tree[root << 1].mx, tree[root << 1 | 1].mx);
tree[root].sum = tree[root << 1].sum + tree[root << 1 | 1].sum;
}
int sum_query(int l, int r, int root)
{
if (l <= tree[root].left && tree[root].right <= r) return tree[root].sum;
int ans = 0;
if (l <= tree[root].mid) ans += sum_query(l, r, root << 1);
if (r >= tree[root].mid + 1) ans += sum_query(l, r, root << 1 | 1);
return ans;
}
int max_query(int l, int r, int root)
{
if (l <= tree[root].left && tree[root].right <= r) return tree[root].mx;
int ans = -INF;
if (l <= tree[root].mid) ans = max(ans, max_query(l, r, root << 1));
if (r >= tree[root].mid + 1) ans = max(ans, max_query(l, r, root << 1 | 1));
return ans;
}
int main()
{
cnt = 0;
memset(head, 0, sizeof(head));
scanf("%d", &n);
for (int i = 1; i <= n - 1; i++) {
int a, b; scanf("%d %d", &a, &b);
addedge(a, b);
addedge(b, a);
}
for (int i = 1; i <= n; i++) scanf("%d", &p[i]);
dfs_1(1, 0, 1);
tot = 0;
dfs_2(1, 1);
build(1, n, 1);
scanf("%d", &q);
while (q--) {
char mes[7];
int u, v; scanf("%s %d %d", mes, &u, &v);
if (mes[0] == 'C') modify(in[u], v, 1);
if (mes[1] == 'M') {
int ans = -INF;
while (top[u] != top[v]) {
if (deep[top[u]] < deep[top[v]]) swap(u, v);
ans = max(ans, max_query(in[top[u]], in[u], 1));
u = fa[top[u]];
}
if (deep[u] > deep[v]) swap(u, v);
ans = max(ans, max_query(in[u], in[v], 1));
printf("%d\n", ans);
}
if (mes[1] == 'S') {
int ans = 0;
while (top[u] != top[v]) {
if (deep[top[u]] < deep[top[v]]) swap(u, v);
ans += sum_query(in[top[u]], in[u], 1);
u = fa[top[u]];
}
if (deep[u] > deep[v]) swap(u, v);
ans += sum_query(in[u], in[v], 1);
printf("%d\n", ans);
}
}
return 0;
}