题目描述
传送门
题目大意:给出一颗树,对于每个节点可以花费val[i]的代价把他堵住,使水无法向下流。每个节点的代价可能会在某时刻增加。对于每个询问,求出某个点的子树中的叶子节点(没有子节点)都没有水的最小代价。
题解
先考虑如果没有点权的修改该怎么做?应该是比较裸的树形DP.
dp[i]=min(val[i],∑dp[son[i]])
,我们设
h[i]=∑dp[son[i]]
然后考虑修改一个节点的点权,对其他节点的影响。影响的必然是该点到根路径上的一段连续区间。
假设对于某个节点来说
val[i]>h[i]
,现在这个节点的某个儿子的
dp[son[i]]
增加了
t
,那么如果
我们考虑对于每个节点用线段树维护两个值
tr[i]=val[i]−h[i],g[i]=h[i]
。假设现在产生增量的点为x,那么从x向上到第一个tr[i]小于等于增量的点的儿子节点(在x->i的路径上),这段区间的g[i]+=t,tr[i]-=t,我们可以用线段树的区间修改进行解决。这些点不会对他们父节点的值产生影响。对于i说,他的dp[i]发生了变化,那么他会产生新的增量,我们需要继续上述的过程,直到某个点的增量为0,或者到达根节点。
这样我们每次查询一个点的dp值,其实就是对val[i]和线段树中维护的g[i]取min.
向上找点的过程比较的恶心,处理起来比较的麻烦。
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 400003
#define LL long long
using namespace std;
const LL inf=10000000000000LL;
LL tr[N*4],delta[N*4],g[N*4],h[N],val[N],dp[N];
int n,tot,sz,q[N],pos[N],belong[N],son[N],deep[N],size[N],fa[N];
int point[N],nxt[N],v[N],m;
void add(int x,int y)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void dfs(int x,int f)
{
deep[x]=deep[f]+1; size[x]=1; dp[x]=val[x];
bool mark=false;
for (int i=point[x];i;i=nxt[i]){
if (v[i]==f) continue;
fa[v[i]]=x; mark=true;
dfs(v[i],x);
size[x]+=size[v[i]];
if (size[son[x]]<size[v[i]]) son[x]=v[i];
h[x]+=dp[v[i]];
}
if (mark) dp[x]=min(h[x],dp[x]);
else h[x]=inf;
}
void dfs1(int x,int chain)
{
belong[x]=chain; pos[x]=++sz; q[sz]=x;
if (!son[x]) return;
dfs1(son[x],chain);
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa[x]&&v[i]!=son[x]) dfs1(v[i],v[i]);
}
void update(int now)
{
tr[now]=min(tr[now<<1|1],tr[now<<1]);
}
void build(int now,int l,int r)
{
if (l==r) {
int t=q[l];
tr[now]=val[t]-h[t];
g[now]=h[t];
return;
}
int mid=(l+r)/2;
build(now<<1,l,mid); build(now<<1|1,mid+1,r);
update(now);
}
void change(int now,LL t)
{
delta[now]+=t; g[now]+=t; tr[now]-=t;
}
void pushdown(int now)
{
if (delta[now]) {
change(now<<1,delta[now]); change(now<<1|1,delta[now]);
delta[now]=0;
}
}
void pointchange(int now,int l,int r,int x,LL t)
{
if (l==r) {
tr[now]+=t;
return;
}
int mid=(l+r)/2;
pushdown(now);
if (x<=mid) pointchange(now<<1,l,mid,x,t);
else pointchange(now<<1|1,mid+1,r,x,t);
update(now);
}
LL find(int now,int l,int r,int x)
{
if (l==r) return g[now];
int mid=(l+r)/2;
pushdown(now);
if (x<=mid) return find(now<<1,l,mid,x);
else return find(now<<1|1,mid+1,r,x);
}
void qjchange(int now,int l,int r,int ll,int rr,LL v)
{
if (ll<=l&&r<=rr) {
change(now,v);
return;
}
int mid=(l+r)/2;
pushdown(now);
if (ll<=mid) qjchange(now<<1,l,mid,ll,rr,v);
if (rr>mid) qjchange(now<<1|1,mid+1,r,ll,rr,v);
update(now);
}
LL qjmin(int now,int l,int r,int ll,int rr)
{
if (ll<=l&&r<=rr) return tr[now];
int mid=(l+r)/2;
pushdown(now); LL ans=inf;
if (ll<=mid) ans=min(ans,qjmin(now<<1,l,mid,ll,rr));
if (rr>mid) ans=min(ans,qjmin(now<<1|1,mid+1,r,ll,rr));
return ans;
}
LL get_ans(int x)
{
return min(val[x],find(1,1,n,pos[x]));
}
void modify(int x,int y,LL t)
{
while (belong[x]!=belong[y]) {
if (deep[x]<deep[y]) swap(x,y);
qjchange(1,1,n,pos[belong[x]],pos[x],t);
x=fa[belong[x]];
}
if (deep[x]>deep[y]) swap(x,y);
qjchange(1,1,n,pos[x],pos[y],t);
}
int check(int now,int l,int r,int ll,int rr,LL t)
{
if (l==r) return l;
int mid=(l+r)/2;
pushdown(now);
if (ll>mid) return check(now<<1|1,mid+1,r,ll,rr,t);
if (rr<=mid) return check(now<<1,l,mid,ll,rr,t);
if (qjmin(now<<1|1,mid+1,r,ll,rr)<t)
return check(now<<1|1,mid+1,r,ll,rr,t);
return check(now<<1,l,mid,ll,rr,t);
}
int divide(int x,LL t)
{
int ans=0; int now=x; int down=x; x=fa[x];
while (true) {
LL mn=qjmin(1,1,n,pos[belong[x]],pos[x]);
if (mn<=t) {
ans=x;
break;
}
down=belong[x]; x=fa[belong[x]];
if (x==0) break;
}
if (x==0) return now;
ans=check(1,1,n,pos[belong[x]],pos[x],t);
if (ans<pos[x]) return q[++ans];
else return down;
}
void solve(int x,LL y)
{
LL last=get_ans(x); val[x]+=y;
pointchange(1,1,n,pos[x],y);
LL now=get_ans(x);
if (now==last) return;
while (x&&x!=1) {
int t=divide(x,now-last);
if (t==0) return;
if (x!=t) modify(fa[x],t,now-last);
x=fa[t];
LL add=now-last; last=get_ans(x);
qjchange(1,1,n,pos[x],pos[x],add);
now=get_ans(x);
if (now==last) return;
}
}
int main()
{
freopen("a.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d",&n);
for (int i=1;i<=n;i++) scanf("%lld",&val[i]);
for (int i=1;i<n;i++) {
int x,y; scanf("%d%d",&x,&y);
add(x,y);
}
dfs(1,0); dfs1(1,1);
// for (int i=1;i<=n;i++) printf("%d ",dp[i]); cout<<endl;
build(1,1,n);
//for (int i=1;i<=n;i++) cout<<q[i]<<" "; cout<<endl;
//cout<<find(1,1,n,2)<<endl;
scanf("%d",&m);
for (int i=1;i<=m;i++) {
char s[10]; int x; LL y;
scanf("%s",s+1);
if (s[1]=='Q') {
scanf("%d",&x);
printf("%lld\n",get_ans(x));
}
else {
scanf("%d%lld",&x,&y);
solve(x,y);
}
}
}