bzoj4712: 洪水 动态Dp 树链剖分+线段树 或 LCT维护矩阵乘法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/lvzelong2014/article/details/81156949

bzoj4712: 洪水

Description

小A走到一个山脚下,准备给自己造一个小屋。这时候,小A的朋友(op,又叫管理员)打开了创造模式,然后飞到
山顶放了格水。于是小A面前出现了一个瀑布。作为平民的小A只好老实巴交地爬山堵水。那么问题来了:我们把这
个瀑布看成是一个n个节点的树,每个节点有权值(爬上去的代价)。小A要选择一些节点,以其权值和作为代价将
这些点删除(堵上),使得根节点与所有叶子结点不连通。问最小代价。不过到这还没结束。小A的朋友觉得这样
子太便宜小A了,于是他还会不断地修改地形,使得某个节点的权值发生变化。不过到这还没结束。小A觉得朋友做
得太绝了,于是放弃了分离所有叶子节点的方案。取而代之的是,每次他只要在某个子树中(和子树之外的点完全
无关)。于是他找到你。

Input

输入文件第一行包含一个数n,表示树的大小。
接下来一行包含n个数,表示第i个点的权值。
接下来n-1行每行包含两个数fr,to。表示书中有一条边(fr,to)。
接下来一行一个整数,表示操作的个数。
接下来m行每行表示一个操作,若该行第一个数为Q,则表示询问操作,后面跟一个参数x,表示对应子树的根;若
为C,则表示修改操作,后面接两个参数x,to,表示将点x的权值加上to。
n<=200000,保证任意to都为非负数

Output

对于每次询问操作,输出对应的答案,答案之间用换行隔开。

Sample Input

4
4 3 2 1
1 2
1 3
4 2
4
Q 1
Q 2
C 4 10
Q 1

Sample Output

3
1
4

分析

动态Dp大概的概念就是Dp中用来决策的变量会变化,在第一遍Dp的基础之上考录若干个变量的变化对Dp值的影响,并用一些奇技淫巧维护(比如数据结构和数据结构和数据结构)
这道题很典型。
首先方程很好写吧。
f[u]=min{v[u],f[sonu]}f[u]=min \{v[u], \sum f[son_u]\}
为了方便,定义h[u]=f[sonu]h[u]=\sum f[son_u]
不难发现任何修改不会减小f的值。
考虑一次修改的影响。
注意修改uu节点的值只会影响到uu的祖先。
考虑一个点的ff值的增量DiD_i
如果说Du=0D_u=0,显然不会对答案造成任何影响。
否则的话,对于uu的若干个连续祖先xx,如果他们满足h[x]+Du&lt;=v[x]h[x]+D_u&lt;=v[x],那么有h[x]h[x]+Du,f[x]Du+f[x]h[x]\to h[x]+D_u, f[x] \to D_u+f[x]
考虑这些祖先深度最低的节点zz,显然h[z]h[z]+Duh[z] \to h[z]+D_u
然而它的ff值就不一样了,f[z]v[z]f[z] \to v[z]
这个时候我们发现我们面对的是一个完全一样的子问题。
考虑这样的子问题会有多少个。
注意到一旦一个节点uu满足f[u]=v[u]f[u]=v[u],那么除非这个节点被修改,不然这个节点的ff值不会再增加。
所以说这样的子问题个数是O(n+m)O(n+m)的。
于是我们考虑如何用数据结构快速解决这个子问题。
首先找到uu的若干个满足h[x]+Du&lt;=v[x]h[x]+D_u&lt;=v[x]连续祖先xx中的深度最低的祖先zz
变换一下v[x]h[x]&gt;=Duv[x]-h[x]&gt;=D_u,用树链剖分+线段树维护min{vh}min\{v-h\},在线段树上二分即可。
修改打标记,查询Dp值的时候可以直接用叶子节点的标记来查询。
一次复杂度O(log2n)O(log^2n),总复杂度O((n+m)log2n)O((n+m)log^2n)

UPD2019.1.28: 用了一个LCT维护矩阵乘法的做法,可以适用于修改不递增的情况,而且复杂度为O((n+m)logn)O((n+m)logn)
具体的做法就是,LCT维护虚子树的ff之和,考虑把SplaySplay中的一条实链拉成序列来写方程,可以得到这样的方程:
fu=min{fu1+g[u],val[u]}f_u=\min \{f_{u-1}+g[u],val[u]\}
其中uuu1u-1对应实链上的一对父子。
考虑把这个式子写成矩阵乘法的形式。
[fu10][gu0val[u]0][fu0] \left[ \begin{matrix} f_{u-1} &amp; 0 \end{matrix} \right] \cdot \left[ \begin{matrix} g_u &amp; 0 \\ val[u] &amp; 0 \end{matrix} \right] \to \left[ \begin{matrix} f_u &amp; 0 \end{matrix} \right]
当然这个乘法是用min重载过后的。
这样的话再Access的时候维护虚实边的切换,Splay里维护链的矩阵乘法结果即可。
还有一种常数吊打LCT的全局平衡二叉树做法,出门转洛谷模板区,懒得敲了,LCT挺好写的。

代码

链剖
真滴难写。

#include<cstdio>
#include<algorithm>
#define ls p << 1
#define rs p << 1 | 1
typedef long long LL;
int ri() {
    char c = getchar(); int x = 0; for(;c < '0' || c > '9'; c = getchar()) ;
    for(;c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) - '0' + c; return x;
}
const int N = 2e5 + 10;
int f[N], de[N], ds[N], p[N], in[N], d[N], s[N], pr[N], nx[N << 1], to[N << 1], tp, tot, n;
long long h[N], g[N], v[N], tg[N << 2], t[N << 2];
void add(int u, int v) {to[++tp] = v; nx[tp] = pr[u]; pr[u] = tp;}
void adds(int u, int v) {add(u, v); add(v, u);}
void Dfs1(int u, int f) {
    ::f[u] = f; de[u] = de[f] + 1; s[u] = 1;
    for(int i = pr[u]; i; i = nx[i]) 
    if(to[i] != f) {
        Dfs1(to[i], u); s[u] += s[to[i]];
        if(s[to[i]] > s[ds[u]]) ds[u] = to[i];
        h[u] += g[to[i]];
    }
    if(s[u] == 1) h[u] = 1e18;
    g[u] = std::min(v[u], h[u]);
}
void Dfs2(int u, int c) {
    d[u] = c; p[in[u] = ++tot] = u; if(!ds[u]) return; Dfs2(ds[u], c);
    for(int i = pr[u]; i; i = nx[i]) 
    if(to[i] != f[u] && to[i] != ds[u])
        Dfs2(to[i], to[i]);
}
void T(int p, int v) {tg[p] += v; t[p] -= v;}
void Push(int p) {if(tg[p]) T(ls, tg[p]), T(rs, tg[p]), tg[p] = 0;}
void Up(int p) {t[p] = std::min(t[ls], t[rs]);}
void Build(int p, int L, int R) {
    if(L == R) {tg[p] = h[::p[L]]; t[p] = v[::p[L]] - h[::p[L]]; return ;}
    int m = L + R >> 1; Build(ls, L, m); Build(rs, m + 1, R); Up(p);
}
LL Que(int p, int L, int R, int v) {
    if(L == R) return tg[p];
    int m = L + R >> 1; Push(p); 
    return v <= m ? Que(ls, L, m, v) : Que(rs, m + 1, R, v);
}
void Modv(int p, int L, int R, int v) {
    if(L == R) return void(t[p] = ::v[::p[p]] - h[::p[p]]);
    int m = L + R >> 1; Push(p);
    v <= p ? Modv(ls, L, m, v) : Modv(rs, m + 1, R, v);
    Up(p);
}
void Modh(int p, int L, int R, int st, int ed, LL w) {
    if(L == st && ed == R) return T(p, w);
    int m = L + R >> 1; Push(p);
    if(ed > m) Modh(rs, m + 1, R, std::max(st, m + 1), ed, w);
    if(st <= m) Modh(ls, L, m, st, std::min(ed, m), w);
    Up(p);
}
LL Dp(int u) {return std::min(v[u], Que(1, 1, n, in[u]));}
void Chain(int u, int v, LL w) {
    for(;d[u] != d[v]; u = f[d[u]]) Modh(1, 1, n, in[d[u]], in[u], w);
    Modh(1, 1, n, in[v], in[u], w);
}
int Get(int p, int L, int R, int st, int ed, LL w) {
    if(L == st && ed == R) {
        if(t[p] >= w) return L;
        if(L == R) return 0;
    }
    Push(p); int m = L + R >> 1, t = -1;
    if(ed > m) t = Get(rs, m + 1, R, std::max(st, m + 1), ed, w);
    if(~t && t != m + 1) return t;
    if(st <= m) t = Get(ls, L, m, st, std::min(ed, m), w);
    return t ? t : m + 1;
}
int Get(int u, LL w) {
    int r = u; u = f[u];
    for(;u; r = d[u], u = f[r]) {
        int s = Get(1, 1, n, in[d[u]], in[u], w);
        if(!s) break; if(s > in[d[u]]) return p[s];
    }
    return r;
}
void Mod(int u, LL w) {
    LL t = Dp(u); v[u] += w; Modv(1, 1, n, in[u]); 
    LL d = Dp(u) - t; if(!d) return ;
    for(;f[u];)  {
        int v = Get(u, d);
        if(u != v) Chain(f[u], v, d);
        u = f[v]; if(!u) return ;
        t = Dp(u); Modh(1, 1, n, in[u], in[u], d);
        d = Dp(u) - t; if(!d) return;
    } 
}
int main() {
    n = ri(); for(int i = 1; i <= n; ++i) v[i] = ri();
    for(int i = 1;i < n; ++i) adds(ri(), ri());
    Dfs1(1, 0); Dfs2(1, 1); Build(1, 1, n);
    for(int m = ri();m--;) {
        char op = getchar(); for(;op != 'C' && op != 'Q'; op = getchar()) ;
        int x = ri(); 
        if(op == 'Q') printf("%lld\n", Dp(x));
        else Mod(x, ri());
    }
    return 0;
}


LCT

#include<bits/stdc++.h>
#define ls ch[p][0]
#define rs ch[p][1]
const int N = 2e5 + 10;
typedef long long LL;
int ri() {
    char c = getchar(); int x = 0, f = 1; for(;c < '0' || c > '9'; c = getchar()) if(c == '-') f = -1;
    for(;c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) - '0' + c; return x * f;
}
struct Maxtir {
    LL m[2][2];
    LL * operator [] (int x) {return m[x];}
    void Init(LL g, LL v) {
        m[0][0] = g; m[1][0] = v; 
        m[0][1] = m[1][1] = 0;
    }
    Maxtir operator * (Maxtir b) {
        Maxtir c;
        c[0][0] = std::min(m[0][0] + b[0][0], m[0][1] + b[1][0]);
        c[0][1] = std::min(m[0][0] + b[0][1], m[0][1] + b[1][1]);
        c[1][0] = std::min(m[1][0] + b[0][0], m[1][1] + b[1][0]);
        c[1][1] = std::min(m[1][0] + b[0][1], m[1][1] + b[1][1]);
        return c;
    }
}f[N];
int ch[N][2], fa[N], pr[N], to[N << 1], nx[N << 1], tp, n, m;
LL g[N], val[N];
void add(int u, int v) {to[++tp] = v; nx[tp] = pr[u]; pr[u] = tp;}
void adds(int u, int v) {add(u, v); add(v, u);}
bool wh(int p) {return ch[fa[p]][1] == p;}
bool Ir(int p) {return ch[fa[p]][0] != p && ch[fa[p]][1] != p;}
void Up(int p) {
    f[p].Init(g[p], val[p]);
    if(ls) f[p] = f[p] * f[ls];
    if(rs) f[p] = f[rs] * f[p];
}
void Rotate(int p) {
    int f = fa[p], g = fa[f], c = wh(p);
    if(!Ir(f)) ch[g][wh(f)] = p; fa[p] = g;
    ch[f][c] = ch[p][c ^ 1]; if(ch[f][c]) fa[ch[f][c]] = f;
    ch[p][c ^ 1] = f; fa[f] = p; Up(f);
}
void Splay(int p) {
    for(;!Ir(p); Rotate(p))
        if(!Ir(fa[p]))
            Rotate(wh(p) == wh(fa[p]) ? fa[p] : p);
    Up(p);
}
LL Ans(int p) {return std::min(f[p][0][0], f[p][1][0]);}
void Access(int u) {
    for(int p = u, pr = 0; p; pr = p, p = fa[p]) {
        Splay(p);
        if(rs)
        	g[p] += Ans(rs);
        if(pr)
            g[p] -= Ans(pr);
        rs = pr;
        Up(p);
    }
    Splay(u);
}
void Dp(int u, int fa) {
    ::fa[u] = fa; bool lf = true;
    for(int i = pr[u]; i; i = nx[i])
        if(to[i] != fa)
            Dp(to[i], u),
            g[u] += Ans(to[i]), 
			lf = false;
    if(lf) g[u] = 1e18;
    f[u].Init(g[u], val[u]);
}
int main() {
    n = ri();
    for(int i = 1;i <= n; ++i)
        val[i] = ri();
    for(int i = 1;i < n; ++i)
        adds(ri(), ri());
    Dp(1, 0);
    for(int m = ri();m--;) {
    	char op = getchar(); for(;op != 'Q' && op != 'C'; op = getchar()) ;
    	int u = ri(); Access(u);
    	if(op == 'Q') 
			printf("%lld\n", std::min(g[u], val[u]));
    	else
	        val[u] += ri(), Up(u);
    }
    return 0;
}
展开阅读全文

没有更多推荐了,返回首页