【bzoj4712】洪水【树链剖分】【动态dp】

版权声明:本文为蒟蒻的博主极其水的原创文章,各位大佬转载记得标注id哦。 https://blog.csdn.net/ez_2016gdgzoi471/article/details/81914548

我们首先想一个dp方程:f[u]=min{v,vson[u]f[v],val[u]}f[u]=min\{\sum_{v,v\in son[u]}{f[v]},val[u]\}
这个方程可以通过矩阵的形式来表示。
先把树轻重链剖分了。
我们设g[u]g[u]uu的所有儿子的ff总和,vvuu的重儿子。
f[u]=min{f[v]+g[u],val[u]}f[u]=min\{f[v]+g[u],val[u]\}
我们可以把这个写成一个最短路矩阵相乘的形式。
[0f[u]]=[0f[v]]×[0val[u]g[u]]\left[ \begin{matrix}0 & f[u] \end{matrix} \right] = \left[ \begin{matrix} 0 & f[v] \end{matrix} \right] \times \left[ \begin{matrix} 0 & val[u] \\ \infty & g[u] \end{matrix} \right]

这个乘法是最短路矩阵相乘,就是取max。可以证明这是满足结合律的,组合证明其实很好想。
所以我们树剖维护矩阵乘法,修改时爬链修改就好了。注意要维护从右往左的,因为我们是从dfsdfs序大的往小的转移的。

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=200005,inf=0x3f3f3f3f3f3f3f3f;
int n,q,cnt,u,v,val[N],head[N],to[N*2],nxt[N*2],f[N],g[N];
int idx,fa[N],dep[N],siz[N],son[N],dfn[N],pos[N],top[N],bot[N];
char op[5];
struct matrix{
	int a[2][2];
	matrix(){
		memset(a,63,sizeof(a));
	}
	matrix operator * (const matrix &b){
		matrix c;
		for(int i=0;i<2;i++){
			for(int j=0;j<2;j++){
				if(a[i][j]!=inf){
					for(int k=0;k<2;k++){
						c.a[i][k]=min(c.a[i][k],a[i][j]+b.a[j][k]);
					}
				}
			}
		}
		return c;
	}
}a[N],minn[N*4];
void adde(int u,int v){
	to[++cnt]=v;
	nxt[cnt]=head[u];
	head[u]=cnt;
}
void dfs(int u){
	siz[u]=1;
	int v;
	for(int i=head[u];i;i=nxt[i]){
		v=to[i];
		if(v!=fa[u]){
			fa[v]=u;
			dep[v]=dep[u]+1;
			dfs(v);
			siz[u]+=siz[v];
			if(!son[u]||siz[son[u]]<siz[v]){
				son[u]=v;
			}
		}
	}
}
void dfs(int u,int tp){
	dfn[u]=++idx;
	pos[idx]=u;
	top[u]=tp;
	bot[tp]=u;
	if(!head[u]) {
		return;
	}
	f[u]=g[u]=inf;
	if(son[u]){
		dfs(son[u],tp);
		f[u]=f[son[u]];
		g[u]=0;
	}
	int v;
	for(int i=head[u];i;i=nxt[i]){
		v=to[i];
		if(v!=fa[u]&&v!=son[u]){
			dfs(v,v);
			g[u]+=f[v];
		}
	}
	f[u]=min(f[u]+g[u],val[u]);
}
void build(int o,int l,int r){
	if(l==r){
		minn[o]=a[pos[l]];
		return;
	}
	int mid=(l+r)/2;
	build(o*2,l,mid);
	build(o*2+1,mid+1,r);
	minn[o]=minn[o*2+1]*minn[o*2];
}
void upd(int o,int l,int r,int k){
	if(l==r){
		minn[o]=a[pos[l]];
		return;
	}
	int mid=(l+r)/2;
	if(k<=mid){
		upd(o*2,l,mid,k);
	}else{
		upd(o*2+1,mid+1,r,k);
	}
	minn[o]=minn[o*2+1]*minn[o*2];
}
matrix qry(int o,int l,int r,int L,int R){
	if(L==l&&R==r){
		return minn[o];
	}
	int mid=(l+r)/2;
	if(R<=mid){
		return qry(o*2,l,mid,L,R);
	}else if(L>mid){
		return qry(o*2+1,mid+1,r,L,R);
	}else{
		return qry(o*2+1,mid+1,r,mid+1,R)*qry(o*2,l,mid,L,mid);
	}
}
int query(int u){
	return qry(1,1,n,dfn[u],dfn[bot[top[u]]]).a[0][1];
}
void update(int u){
	while(u){
		a[u].a[0][1]=val[u];
		a[u].a[1][1]=g[u];
		upd(1,1,n,dfn[u]);
		if(u==1){
			break;
		}
		u=top[u];
		g[fa[u]]-=f[u];
		f[u]=qry(1,1,n,dfn[u],dfn[bot[top[u]]]).a[0][1];
		g[fa[u]]+=f[u];
		u=fa[u];
	}
}
signed main(){
	scanf("%lld",&n);
	for(int i=1;i<=n;i++){
		scanf("%lld",&val[i]);
	}
	for(int i=1;i<n;i++){
		scanf("%lld%lld",&u,&v);
		adde(u,v);
		adde(v,u);
	}
	dfs(1);
	dfs(1,1);
	for(int i=1;i<=n;i++){
		a[i].a[0][0]=0;
		a[i].a[0][1]=val[i];
		a[i].a[1][1]=g[i];
	}
	build(1,1,n);
	scanf("%lld",&q);
	while(q--){
		scanf("%s",op);
		if(op[0]=='Q'){
			scanf("%lld",&u);
			printf("%lld\n",query(u));
		}else{
			scanf("%lld%lld",&u,&v);
			val[u]+=v;
			update(u);
		}
	}
	return 0;
}
展开阅读全文

没有更多推荐了,返回首页