Aizu 2450 Do use segment tree 树链剖分

题意:

给出一棵\(n(1 \leq n \leq 200000)\)个节点的树,每个节点有一个权值。
然后有\(2\)种操作:

  • \(1 \, a \, b \, c\):将路径\(a \to b\)上的所有点的权值都变为\(c\)
  • \(2 \, a \, b \, c\):查询路径\(a \to b\)的权值和最大的非空连续子序列
分析:

首先要树链剖分,将问题转为线性的问题:
给出一个序列,查询给定区间\([L,R]\)的最大非空连续子序列。
线段树最重要的一点就是可以由左右子区间的合并得到父亲节点的区间信息。
这里维护区间的四个信息:

  • \(sum\):就是区间的所有元素和
  • \(pre\):区间的最大前缀和
  • \(suf\):区间的最大后缀和
  • \(sub\):区间的最大子区间和,也正是题目所求的

区间合并可以这样合并:

  • \(sum_f=sum_l+sum_r\)
  • \(pre_f=max \{ pre_l, sum_l + pre_r \}\),最大前缀可能在左子区间,可能跨过区间中点
  • \(suf_f=max \{ suf_r, suf_l+sum_r \}\),最大后缀可能在右子区间,可能跨过区间中点
  • \(sub_f=max \{ sub_l, sub_r, suf_l+pre_r \}\),最大子区间可能在左子区间,可能在右子区间,也可能跨过区间中点,就是左子区间的最大后缀与右子区间的最大前缀拼接起来

然后再将问题转移到树上,就是简单的一段一段的区间合并就行了。
注意区间合并的方向,查询的时候是将两个顶点向着\(LCA\)往上跳,注意最后合并的时候将区间翻转一下。

最后如果默认将节点\(1\)作为根节点开始\(DFS\)的话会爆栈,所以我们将\(\left \lceil \frac{n}{2} \right \rceil\)作为根就可以了。
这么鸡贼的做法,我猜肯定是出过题的人想出来的 :)

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int maxn = 200000 + 10;
const int INF = 0x3f3f3f3f;

struct Edge
{
    int v, nxt;
    Edge(int v = 0, int nxt = 0): v(v), nxt(nxt) {}
};

int ecnt, head[maxn];
Edge edges[maxn * 2];

void AddEdge(int u, int v) {
    edges[ecnt] = Edge(v, head[u]);
    head[u] = ecnt++;
}

int n, q, w[maxn];

int sz[maxn], fa[maxn], son[maxn], dep[maxn];

void dfs(int u) {
    sz[u] = 1; son[u] = 0;
    for(int i = head[u]; ~i; i = edges[i].nxt) {
        int v = edges[i].v;
        if(v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs(v);
        sz[u] += sz[v];
        if(sz[v] > sz[son[u]]) son[u] = v;
    }
}

int top[maxn], id[maxn], pos[maxn], tot;

void dfs2(int u, int tp) {
    id[u] = ++tot;
    pos[tot] = u;
    top[u] = tp;
    if(!son[u]) return;
    dfs2(son[u], tp);
    for(int i = head[u]; ~i; i = edges[i].nxt) {
        int v = edges[i].v;
        if(v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}

struct Node
{
    int sum, pre, suf, sub;
    Node() {}
    Node(int a, int b, int c, int d): sum(a), pre(b), suf(c), sub(d) {}
    Node operator + (const Node& t) const {
        Node ans;
        ans.sum = sum + t.sum;
        ans.pre = max(pre, sum + t.pre);
        ans.suf = max(t.suf, t.sum + suf);
        ans.sub = max(max(sub, t.sub), suf + t.pre);
        return ans;
    }
};

Node T[maxn << 2];
int set[maxn << 2];

void build(int o, int L, int R) {
    if(L == R) {
        T[o].sum = T[o].pre = T[o].suf = T[o].sub = w[pos[L]];
        return;
    }
    int M = (L + R) / 2;
    build(o<<1, L, M);
    build(o<<1|1, M+1, R);
    T[o] = T[o<<1] + T[o<<1|1];
}

void pushdown(int o, int L, int R) {
    if(set[o] != INF) {
        set[o<<1] = set[o<<1|1] = set[o];
        int lc = o<<1, rc = o<<1|1;
        int M = (L + R) / 2;
        T[lc].sum = set[o] * (M - L + 1);
        T[rc].sum = set[o] * (R - M);
        T[lc].pre = T[lc].suf = T[lc].sub = set[o] > 0 ? T[lc].sum : set[o];
        T[rc].pre = T[rc].suf = T[rc].sub = set[o] > 0 ? T[rc].sum : set[o];
        set[o] = INF;
    }
}

void update(int o, int L, int R, int qL, int qR, int v) {
    if(qL <= L && R <= qR) {
        set[o] = v;
        T[o].sum = (R - L + 1) * v;
        T[o].pre = T[o].suf = T[o].sub = v > 0 ? T[o].sum : v;
        return;
    }
    pushdown(o, L, R);
    int M = (L + R) / 2;
    if(qL <= M) update(o<<1, L, M, qL, qR, v);
    if(qR > M) update(o<<1|1, M+1, R, qL, qR, v);
    T[o] = T[o<<1] + T[o<<1|1];
}

void UPDATE(int u, int v, int val) {
    int t1 = top[u], t2 = top[v];
    while(t1 != t2) {
        if(dep[t1] < dep[t2]) { swap(u, v); swap(t1, t2); }
        update(1, 1, n, id[t1], id[u], val);
        u = fa[t1]; t1 = top[u];
    }
    if(dep[u] > dep[v]) swap(u, v);
    update(1, 1, n, id[u], id[v], val);
}

Node query(int o, int L, int R, int qL, int qR) {
    if(qL <= L && R <= qR) return T[o];
    pushdown(o, L, R);
    int M = (L + R) / 2;
    if(qR <= M) return query(o<<1, L, M, qL, qR);
    else if(qL > M) return query(o<<1|1, M+1, R, qL, qR);
    else return query(o<<1, L, M, qL, qR) + query(o<<1|1, M+1, R, qL, qR);
}

void updateans(Node& q, bool& flag, int L, int R) {
    if(!flag) { q = query(1, 1, n, L, R); flag = true; }
    else q = query(1, 1, n, L, R) + q;
}

int QUERY(int u, int v) {
    Node q1, q2;
    bool flag1 = false, flag2 = false;
    int t1 = top[u], t2 = top[v];
    while(t1 != t2) {
        if(dep[t1] > dep[t2]) {
            updateans(q1, flag1, id[t1], id[u]);
            u = fa[t1]; t1 = top[u];
        } else {
            updateans(q2, flag2, id[t2], id[v]);;
            v = fa[t2]; t2 = top[v];
        }
    }
    if(dep[u] > dep[v]) updateans(q1, flag1, id[v], id[u]);
    else updateans(q2, flag2, id[u], id[v]);
    if(!flag1) return q2.sub;
    if(!flag2) return q1.sub;
    swap(q1.pre, q1.suf);
    return (q1 + q2).sub;
}

int main()
{
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) scanf("%d", w + i);

    ecnt = 0;
    memset(head, -1, sizeof(head));
    for(int i = 1; i < n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        AddEdge(u, v); AddEdge(v, u);
    }

    int root = (n + 1) / 2;

    dfs(root);
    tot = 0;
    dfs2(root, root);
    
    memset(set, 0x3f, sizeof(set));
    build(1, 1, n);

    while(q--) {
        int op, a, b, c; scanf("%d%d%d%d", &op, &a, &b, &c);
        if(op == 1) UPDATE(a, b, c);
        else printf("%d\n", QUERY(a, b));
    }

    return 0;
}

转载于:https://www.cnblogs.com/AOQNRMGYXLMV/p/5303355.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值