一树上有 n 个节点,编号分别为 1 到 n,每个节点都有一个权值 w。
我们将以下面的形式来要求你对这棵树完成一些操作:
CHANGE u t
:把节点 u权值改为 t;QMAX u v
:询问点 u 到点 v 路径上的节点的最大权值;QSUM u v
:询问点 u 到点 v 路径上的节点的权值和。注意:从点 u 到点 v路径上的节点包括 u 和 v本身。
输入格式
第一行为一个数 n,表示节点个数;
接下来 n−1 行,每行两个整数 a,b,表示节点 a 与节点 b 之间有一条边相连;
接下来 n 行,每行一个整数,第 i 行的整数 wi 表示节点 i 的权值;
接下来一行,为一个整数 q,表示操作总数;
接下来 q 行,每行一个操作,以
CHANGE u t
或QMAX u v
或QSUM u v
的形式给出。输出格式
对于每个 QMAX 或 QSUM 的操作,每行输出一个整数表示要求的结果。
数据范围
1≤n≤3×104,
0≤q≤2×105,
中途操作中保证每个节点的权值 w 在 −30000 至 30000 之间。输入样例:
4 1 2 2 3 4 1 4 2 1 3 12 QMAX 3 4 QMAX 3 3 QMAX 3 2 QMAX 2 3 QSUM 3 4 QSUM 2 1 CHANGE 1 5 QMAX 3 4 CHANGE 3 6 QMAX 3 4 QMAX 2 4 QSUM 3 4
输出样例:
4 1 2 2 10 6 5 6 5 16
难度:困难 时/空限制:1s / 64MB 总通过数:86 总尝试数:279 来源:《信息学奥赛一本通》 , ZJOI2008 算法标签
挑战模式
#include <iostream>
#include <cstring>
using namespace std;
typedef long long ll;
constexpr int N=30010,M=N*2;
int n;
int h[N],e[M],ne[M],idx;
int top[N],wn[N],w[N],sz[N],fa[N],id[N],cnt,dep[N],son[N];
struct node{
int l,r;
int sum,maxx;
}tr[N*4];
void add(int a,int b){
e[idx]=b;
ne[idx]=h[a];
h[a]=idx++;
}
void dfs1(int u,int father,int depth){
dep[u]=depth,fa[u]=father,sz[u]=1;
for(int i=h[u];i!=-1;i=ne[i]){
int j=e[i];
if(j!=father){
dfs1(j,u,depth+1);
sz[u]+=sz[j];
if(sz[son[u]]<sz[j]) son[u]=j;
}
}
}
void dfs(int u,int t){
id[u]=++cnt,wn[cnt]=w[u],top[u]=t;
if (!son[u]) return;
dfs(son[u], t);
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[u] || j == son[u]) continue;
dfs(j, j);
}
}
void pushup(int u){
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
tr[u].maxx= max(tr[u<<1].maxx,tr[u<<1|1].maxx);
}
void build(int u,int l,int r){
tr[u]={l,r};
if(l==r){
tr[u].sum=wn[r];
tr[u].maxx=wn[r];
return;
}
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
pushup(u);
}
void update(int u,int x,int v){
if(tr[u].l==x&&tr[u].r==x) tr[u].maxx=v,tr[u].sum=v;
else{
int mid=(tr[u].l+tr[u].r)>>1;
if(x<=mid) update(u<<1,x,v);
else update(u<<1|1,x,v);
pushup(u);
}
}
int query1(int u,int l,int r)
{
if(tr[u].l>=l&&tr[u].r<=r) return tr[u].maxx;
int mid=tr[u].l+tr[u].r>>1;
int v=-1e9;
if(l<=mid) v=max(v,query1(u<<1,l,r));
if(r>mid) v=max(v,query1(u<<1|1,l,r));
return v;
}
int query2(int u,int l,int r) {
if (l <= tr[u].l && r >= tr[u].r) return tr[u].sum;
int mid = tr[u].l + tr[u].r >> 1;
int res = 0;
if (l <= mid) res += query2(u << 1, l, r);
if (r > mid) res += query2(u << 1 | 1, l, r);
return res;
}
int query_nax(int u,int v){
int res = -1e9;
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]]) swap(u, v);
res = max(res, query1(1, id[top[u]], id[u]));
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
res = max(res, query1(1, id[v], id[u]));
return res;
}
int query_path(int u,int v) {
int res = 0;
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]]) swap(u, v);
res += query2(1, id[top[u]], id[u]);
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
res += query2(1, id[v], id[u]);
return res;
}
int main(){
scanf("%d",&n);
memset(h,-1,sizeof h);
for(int i=1;i<n;i++){
int a,b;
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
}
for(int i=1;i<=n;i++){
scanf("%d",&w[i]);
}
dfs1(1,-1,1);
dfs(1,1);
build(1,1,n);
int m;
scanf("%d",&m);
while(m--){
char op[10];
int u,v;
scanf("%s%d%d",op,&u,&v);
if(op[0] == 'C'){
update(1,id[u],v);
}
else if(op[1] == 'M'){
printf("%d\n",query_nax(u,v));
}
else {
printf("%d\n",query_path(u,v));
}
}
return 0;
}