ACM数据结构模板

数据结构

0x01 表达式求值(栈的应用)

//(如若第一个数为负数,可以在其之前加0,注意防止爆int)
stack<int> num;
stack<char> op;
void eval()
{
    auto b = num.top(); num.pop();
    auto a = num.top(); num.pop();
    auto c = op.top(); op.pop();
    int x;
    if (c == '+') x = a + b;
    else if (c == '-') x = a - b;
    else if (c == '*') x = a * b;
    else x = a / b;
    num.push(x);
}
int get_ans(string str) // 举例:输入:(2+2)*(1+1) 输出:8
{
    unordered_map<char, int> pr = {{'+', 1}, {'-', 1}, {'*', 2}, {'/', 2}};
    for (int i = 0; i < str.size() - 1; i ++ )
    {
        auto c = str[i];
        if (isdigit(c))
        {
            int x = 0, j = i;
            while (j < str.size() && isdigit(str[j]))
                x = x * 10 + str[j ++ ] - '0';
            i = j - 1;
            num.push(x);
        }
        else if (c == '(') op.push(c);
        else if (c == ')') 
        {
            while (op.top() != '(') eval();
            op.pop();
        }
        else 
        {
            while (op.size() && op.top() != '(' && pr[op.top()] >= pr[c]) 
                eval();
            op.push(c);
        }
    }
    while (op.size()) eval();
    return num.top();
}

0x02 单调栈

int stk[N], top;
// 保持栈内元素单调,即可O(1)查询第一个比当前的数大(或小)的数
void push(int x) {
    while (top && stk[top] >= x) top -- ;
    // cout << (top ? stk[top] : -1) << ' ';
    stk[ ++ top] = x;
}

0x03 单调队列(滑动窗口求最值)

// min
int hh = 0, tt = -1;
for(int i = 0; i < n; i ++ ) {
    if(hh <= tt && i - k + 1 > q[hh]) hh ++ ;

    while(hh <= tt && a[q[tt]] >= a[i]) tt -- ;
    q[ ++ tt] = i;

    if(i >= k - 1) printf("%d ", a[q[hh]]);
}
// max
hh = 0, tt = -1;
for (int i = 0; i < n; i ++ )
{
    if (hh <= tt && i - k + 1 > q[hh]) hh ++ ;

    while (hh <= tt && a[q[tt]] <= a[i]) tt -- ;
    q[ ++ tt] = i;

    if (i >= k - 1) printf("%d ", a[q[hh]]);
}

0x04 并查集

struct DSU { // 创建:DSU g(n)
    vector<int> p, siz;
    DSU(int n) : p(n), siz(n, 1) { iota(p.begin(), p.end(), 0); }
    int find(int x) {
        if (p[x] != x) p[x] = find(p[x]);
        return p[x];
    }
    bool same(int x, int y) { return find(x) == find(y); }
    bool merge(int x, int y) {
        x = find(x);
        y = find(y);
        if (x == y) return false;
        siz[x] += siz[y];
        p[y] = x;
        return true;
    }
    int size(int x) { return siz[find(x)]; }
};

0x05 树状数组(修改add和query即可进行区间最值查询)

template <typename T> 
struct Fenwick {
    const int n;
    vector<T> tr;
    Fenwick(int n) : n(n), tr(n) {}
    int lowbit(int x) {
        return x & -x;
    }
    void add(int x, T v) {
        for (int i = x; i < n; i += lowbit(i)) tr[i] += v;
    }
    T query(int x) { 
        T res = 0;
        for (int i = x; i; i -= lowbit(i)) res += tr[i];
        return res;
    }
    T rangeSum(int l, int r) { 
        return query(r) - query(l - 1);
    }
};

0x06 线段树

1. 单点修改区间查询
// 需要建立节点Info,合并操作为重载‘+’号
// 单点修改
template<class Info, 
    class Merge = plus<Info>>
class SegmentTree {
    const int n;
    const Merge merge;
    vector<Info> tr;
    void pushup(int u) {
        tr[u] = merge(tr[u << 1], tr[u << 1 | 1]);
    }
    void modify(int u, int l, int r, int x, const Info &v) {
        if (l == r) {
            tr[u] = v;
            return;
        }
        int mid = l + r >> 1;
        if (x <= mid) modify(u << 1, l, mid, x, v);
        else modify(u << 1 | 1, mid + 1, r, x, v);
        pushup(u);
    }
    Info rangeQuery(int u, int l, int r, int L, int R) {
        if (r < L || l > R) return Info();
        if (l >= L && r <= R) return tr[u];
        int mid = l + r >> 1;
        return merge(rangeQuery(u << 1, l, mid, L, R), rangeQuery(u << 1 | 1, mid + 1, r, L, R));
    }
public:
    SegmentTree(int _n) : n(_n), merge(Merge()), tr(_n * 4) {}
    SegmentTree(vector<Info> &init) : SegmentTree(init.size()) {
        function<void(int, int, int)> build = [&](int u, int l, int r) {
            if (l == r) {
                tr[u] = init[r];
                return;
            }
            int mid = l + r >> 1;
            build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
            pushup(u);
        };
        build(1, 0, n - 1);
    }
    void modify(int x, const Info &v) {
        modify(1, 0, n - 1, x, v);
    }
    Info rangeQuery(int l, int r) {
        return rangeQuery(1, 0, n - 1, l, r);
    }
    Info query(int x) {
        return rangeQuery(x, x);
    }
};
2. 区间修改区间查询
// 除了单点修改的操作之外,需要新建Tag为懒标记,以及apply(Info &a, Tag b)和apply(Tag &a, Tag b)两个函数
template<class Info, class Tag,
    class Merge = plus<Info>>
class SegmentTree {
    const int n;
    const Merge merge;
    vector<Info> tr;
    vector<Tag> tag;
    void pushup(int u) {
        tr[u] = merge(tr[u << 1], tr[u << 1 | 1]);
    }
    // ===========================================
    void apply(Info &a, const Tag &b) {
    }
    void apply(Tag &a, const Tag &b) {
    }
    // ===========================================
    void apply(int u, const Tag &v) {
        apply(tr[u], v);
        apply(tag[u], v);
    }
    void pushdown(int u) {
        apply(u << 1, tag[u]);
        apply(u << 1 | 1, tag[u]);
        tag[u] = Tag();
    }
    void modify(int u, int l, int r, int x, const Info &v) {
        if (l == r) {
            tr[u] = v;
            return;
        }
        pushdown(u);
        int mid = l + r >> 1;
        if (x <= mid) modify(u << 1, l, mid, x, v);
        else modify(u << 1 | 1, mid + 1, r, x, v);
        pushup(u);
    }
    void rangeApply(int u, int l, int r, int L, int R, const Tag &v) {
        if (l > R || r < L) return;
        if (l >= L && r <= R) {
            apply(u, v);
            return;
        }
        pushdown(u);
        int mid = l + r >> 1;
        rangeApply(u << 1, l, mid, L, R, v);
        rangeApply(u << 1 | 1, mid + 1, r, L, R, v);
        pushup(u);
    }
    Info rangeQuery(int u, int l, int r, int L, int R) {
        if (l > R || r < L) return Info();
        if (l >= L && r <= R) return tr[u];
        pushdown(u);
        int mid = l + r >> 1;
        auto res = merge(rangeQuery(u << 1, l, mid, L, R), rangeQuery(u << 1 | 1, mid + 1, r, L, R));
        pushup(u);
        return res;
    }
public:
    SegmentTree(int _n) : n(_n), merge(Merge()), tr(_n * 4), tag(_n * 4) {}
    SegmentTree(vector<Info> &init) : SegmentTree(init.size()) {
        function<void(int, int, int)> build = [&](int u, int l, int r) {
            if (l == r) {
                tr[u] = init[r];
                return;
            }
            int mid = l + r >> 1;
            build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
            pushup(u);
        };
        build(1, 0, n - 1);
    }
    void modify(int x, const Info &v) {
        modify(1, 0, n - 1, x, v);
    }
    void rangeApply(int l, int r, const Tag &v) {
        rangeApply(1, 0, n - 1, l, r, v);
    }
    Info query(int x) {
        return rangeQuery(1, 0, n - 1, x, x);
    }
    Info rangeQuery(int l, int r) {
        return rangeQuery(1, 0, n - 1, l, r);
    }
};
3. 可持久化(区间和、区间mex)

大多情况下可持久化线段树为了保持不同版本间树的结构相同一般都采用权值线段树的方式存储数据

int n, m;
struct Node
{
    int l, r;
    int cnt = 0, id;
}tr[4 * N + 20 * N];
int root[N], idx;

// 大多情况不用写,少一个空间常数
int build(int l, int r)
{
    int p = ++ idx;
    if (l == r) return p;
    int mid = l + r >> 1;
    tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r);
    return p;
}

// l、r为节点代表的区间,x为插入的位置
int insert(int p, int l, int r, int id, int x)
{
    int q = ++ idx;
    tr[q] = tr[p];
    if (l == r) {
        tr[q].cnt ++ ;
        tr[q].id = id;
        return q;
    }
    int mid = l + r >> 1;
    if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, id, x);
    else tr[q].r = insert(tr[p].r, mid + 1, r, id, x);
    tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt;
    tr[q].id = min(tr[tr[q].l].id, tr[tr[q].r].id);
    return q;
}

// l、r为节点代表的区间,L、R为查询的区间
int querymex(int q, int l, int r, int L)
{
    if (l == r) return r;
    int mid = l + r >> 1;
    if (tr[tr[q].l].id < L) return querymex(tr[q].l, l, mid, L);
    else return querymex(tr[q].r, mid + 1, r, L);
}
 
int querysum(int q, int p, int l, int r, int L, int R)
{
    if (r < L || l > R) return 0;
    if (l >= L && r <= R) {
        return tr[q].cnt - tr[p].cnt;
    }
    int mid = l + r >> 1;
    int sum = 0;
    if (L <= mid) sum += querysum(tr[q].l, tr[p].l, l, mid, L, R);
    if (R > mid) sum += querysum(tr[q].r, tr[p].r, mid + 1, r, L, R);
    return sum;
}

root[0] = build(1, n + 1);
 
for (int i = 1; i <= n; i ++ ) {
    int x;
    cin >> x;
    if (x > n) x = n + 1; // 区间mex大于n的数显然没有用,随便给个大数就行了
    root[i] = insert(root[i - 1], 1, n + 1, i, x);
}
4. 李超线段树(动态插入直线,单点查询最值)

常用作斜率优化dp

// max
namespace LichaoTree1 {
    const int N = 1000010;
    const double eps = 1e-12;
    struct Interval {
        int l, r;
        int k, b;
        bool flag;
        Interval() { k = 0, b = -1e18, flag = false; }
        Interval(int _l, int _r, int _k, int _b) {
            this->l = _l, this->r = _r;
            this->k = _k, this->b = _b;
            this->flag = true;
        }
        int calc(int x) { return k * x + b; }
        double cross(const Interval &rhs) {
            return (double)(b - rhs.b) / (rhs.k - k);
        }
    }tr[N * 4];
    void update(int u, int l, int r, Interval k) {
        if (l >= k.l && r <= k.r) {
            if (!tr[u].flag) {
                tr[u] = k;
            } else if (k.calc(l) > tr[u].calc(l) && k.calc(r) > tr[u].calc(r)) {
                tr[u] = k;
            } else if (k.calc(l) > tr[u].calc(l) || k.calc(r) > tr[u].calc(r)) {
                int mid = l + r >> 1;
                if (k.calc(mid) > tr[u].calc(mid)) {
                    swap(k, tr[u]);
                }
                if (tr[u].cross(k) - mid < eps) update(u << 1, l, mid, k);
                else update(u << 1 | 1, mid + 1, r, k);
            }
        } else {
            int mid = l + r >> 1;
            if (l <= r) update(u << 1, l, mid, k);
            else update(u << 1 | 1, mid + 1, r, k);
        }
    }
    int query(int u, int l, int r, int x) {
        int res = tr[u].calc(x);
        if (l == r) return res;
        int mid = l + r >> 1;
        if (x <= mid) res = max(res, query(u << 1, l, mid, x));
        else res = max(res, query(u << 1 | 1, mid + 1, r, x));
        return res;
    }
    void clear(int u, int l, int r) {
        if (tr[u].flag) {
            tr[u] = Interval();
        }
        if (l == r) return;
        int mid = l + r >> 1;
        if (tr[u << 1].flag) clear(u << 1, l, mid);
        if (tr[u << 1 | 1].flag) clear(u << 1 | 1, mid + 1, r);
    }
}
// min
namespace LichaoTree2 {
    const int N = 1000010;
    const double eps = 1e-12;
    struct Interval {
        int l, r;
        int k, b;
        bool flag;
        Interval() { k = 0, b = 1e18, flag = false; }
        Interval(int _l, int _r, int _k, int _b) {
            this->l = _l, this->r = _r;
            this->k = _k, this->b = _b;
            this->flag = true;
        }
        int calc(int x) { return k * x + b; }
        double cross(const Interval &rhs) {
            return (double)(b - rhs.b) / (rhs.k - k);
        }
    }tr[N * 4];
    void update(int u, int l, int r, Interval k) {
        if (l >= k.l && r <= k.r) {
            if (!tr[u].flag) {
                tr[u] = k;
            } else if (k.calc(l) < tr[u].calc(l) && k.calc(r) < tr[u].calc(r)) {
                tr[u] = k;
            } else if (k.calc(l) < tr[u].calc(l) || k.calc(r) < tr[u].calc(r)) {
                int mid = l + r >> 1;
                if (k.calc(mid) < tr[u].calc(mid)) {
                    swap(k, tr[u]);
                }
                if (tr[u].cross(k) - mid < eps) update(u << 1, l, mid, k);
                else update(u << 1 | 1, mid + 1, r, k);
            }
        } else {
            int mid = l + r >> 1;
            if (l <= r) update(u << 1, l, mid, k);
            else update(u << 1 | 1, mid + 1, r, k);
        }
    }
    int query(int u, int l, int r, int x) {
        int res = tr[u].calc(x);
        if (l == r) return res;
        int mid = l + r >> 1;
        if (x <= mid) res = min(res, query(u << 1, l, mid, x));
        else res = min(res, query(u << 1 | 1, mid + 1, r, x));
        return res;
    }
    void clear(int u, int l, int r) {
        if (tr[u].flag) {
            tr[u] = Interval();
        }
        if (l == r) return;
        int mid = l + r >> 1;
        if (tr[u << 1].flag) clear(u << 1, l, mid);
        if (tr[u << 1 | 1].flag) clear(u << 1 | 1, mid + 1, r);
    }
}
5. 势能线段树

若势能线段树节点数为 N N N,操作数为 M M M,则时间复杂度为 O ( M × ∣ 0 势能时线段树操作时间复杂度 ∣ + N × ∣ 节点势能上限降低至 0 势能时间复杂度 ∣ + M × ∣ 线段树单次操作影响到的节点数目 ∣ × ∣ 操作额外提供的势能 ∣ ) O(M×∣0势能时线段树操作时间复杂度∣+N×∣节点势能上限降低至0势能时间复杂度∣+M×∣线段树单次操作影响到的节点数目∣×∣操作额外提供的势能∣) O(M×0势能时线段树操作时间复杂度+N×节点势能上限降低至0势能时间复杂度+M×线段树单次操作影响到的节点数目×操作额外提供的势能)

势能举例: lowbit,开根,除等

例题1

共有 m m m 只蚊子,每一只蚊子在 [ 1 , n ] [1,n] [1,n] 内的一点,并且第i只蚊子具有它的体形 a i a_i ai
你会拍 k k k 次蚊子,第i次在区间 [ L i , R i ] [L_i, R_i] [Li,Ri] 内拍死体形大于等于 p i p_i pi 的蚊子,请按顺序输出每次拍死蚊子体形大小的总和。

#include <bits/stdc++.h>
 
#define int long long
 
using namespace std;
 
using i64 = long long;
 
const int N = 100010;
 
int n, m, k;
vector<int> w[N];
struct Node
{
    int l, r;
    int minv, maxv;
    int sum;
    priority_queue<int> q;
    bool flag;
}tr[N * 4];
 
void pushup(int u)
{
    tr[u].minv = min(tr[u << 1].minv, tr[u << 1 | 1].minv);
    tr[u].maxv = max(tr[u << 1].maxv, tr[u << 1 | 1].maxv);
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
 
void pushdown(int u)
{
    if (tr[u].flag) {
        tr[u << 1].flag = tr[u << 1 | 1].flag = true;
        tr[u << 1].sum = tr[u << 1 | 1].sum = 0;
        tr[u << 1].minv = tr[u << 1 | 1].minv = 1e9;
        tr[u << 1].maxv = tr[u << 1 | 1].maxv = 0;
        tr[u].flag = false;
    }
}
 
void build(int u, int l, int r)
{
    tr[u] = {l, r, (int)1e9, 0, 0};
    for (int i = l; i <= r; i ++ ) {
        for (int j = 0; j < w[i].size(); j ++ ) {
            tr[u].minv = min(tr[u].minv, w[i][j]);
            tr[u].maxv = max(tr[u].maxv, w[i][j]);
            tr[u].sum += w[i][j];
            tr[u].q.push(w[i][j]);
        }
    }
    if (l >= r) return;
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
 
int query(int u, int l, int r, int x)
{
    if (tr[u].l >= l && tr[u].r <= r) {
        if (tr[u].maxv < x) {
            return 0;
        } else if (tr[u].minv >= x) {
            int res = tr[u].sum;
            tr[u].sum = 0;
            tr[u].maxv = 0;
            tr[u].minv = 1e9;
            tr[u].flag = true;
            return res;
        } else if (tr[u].l == tr[u].r) {
            auto &q = tr[u].q;
            int res = 0;
            while (q.size() && q.top() >= x) {
                tr[u].sum -= q.top();
                res += q.top();
                q.pop();
            }
            if (q.size()) {
                tr[u].maxv = q.top();
            } else {
                tr[u].maxv = 0;
                tr[u].minv = 1e9;
            }
            return res;
        }
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    int res = 0;
    if (l <= mid) res += query(u << 1, l, r, x);
    if (r > mid) res += query(u << 1 | 1, l, r, x);
    pushup(u);
    return res;
}
 
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
 
    cin >> n >> m >> k;
    while (m -- ) {
        int id, x;
        cin >> id >> x;
        w[id].push_back(x);
    }
    build(1, 1, n);
 
    while (k -- ) {
        int l, r, x;
        cin >> l >> r >> x;
        cout << query(1, l, r, x) << "\n";
    }
 
    return 0;
}

例题2

  1. 给定区间 [ l , r ] [l,r] [l,r]对区间中所有数字开根号向下取整,即 a i = ⌊ a i ⌋ a_i= \lfloor \sqrt{a_i}⌋ ai=ai ( l ≤ i ≤ r ) (l≤i≤r) (lir)
  2. 给定区间 [ l , r ] [l,r] [l,r],对区间中每个数字加上一个正整数 x x x
  3. 查询给定区间 [ l , r ] [l,r] [l,r] 的元素和,即求 ∑ i = l r a i \sum_{i=l}^{r}a_{i} i=lrai
struct Node
{
    int l, r;
    int minv, maxv;
    int sum;
    int add;
    int len() { return r - l + 1; }
}tr[N * 4];
 
void pushup(int u)
{
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
    tr[u].minv = min(tr[u << 1].minv, tr[u << 1 | 1].minv);
    tr[u].maxv = max(tr[u << 1].maxv, tr[u << 1 | 1].maxv);
}
 
void split(int u, int add)
{
    tr[u].sum += tr[u].len() * add;
    tr[u].minv += add;
    tr[u].maxv += add;
    tr[u].add += add;
}
 
void pushdown(int u)
{
    if (tr[u].add != 0) {
        split(u << 1, tr[u].add);
        split(u << 1 | 1, tr[u].add);
        tr[u].add = 0;
    }
}
 
void build(int u, int l, int r)
{
    tr[u] = {l, r, w[r], w[r], w[r], 0};
    if (l >= r) return;
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}
 
void modifysqrt(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r && tr[u].maxv == tr[u].minv) {
        int delta = tr[u].maxv - (int)sqrt(tr[u].maxv);
        split(u, -delta);
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid) modifysqrt(u << 1, l, r);
    if (r > mid) modifysqrt(u << 1 | 1, l, r);
    pushup(u);
}
 
void modifyadd(int u, int l, int r, int x)
{
    if (tr[u].l >= l && tr[u].r <= r) {
        split(u, x);
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid) modifyadd(u << 1, l, r, x);
    if (r > mid) modifyadd(u << 1 | 1, l, r, x);
    pushup(u);
}
 
int query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    int res = 0;
    if (l <= mid) res += query(u << 1, l, r);
    if (r > mid) res += query(u << 1 | 1, l, r);
    pushup(u);
    return res;
}

0x07 平衡树

平衡树有两种模式,一种是维护键值有序,一种是维护下标有序

1. FHQ

FHQ有两种分裂模式,一种是按键值分裂,主要用于维护值有序(类比set),一种是按排名分裂,可以支持区间反转等操作。

class FHQ {
private:
    struct Node {
        int l, r;
        int key, val;
        int sz;
    };
    vector<Node> tr;
    stack<int> stk;
    int root, idx;
    int x, y, z;
    int get_node(int key) {
        int u = stk.top();
        stk.pop();
        tr[u].l = tr[u].r = 0;
        tr[u].key = key;
        tr[u].val = rand();
        tr[u].sz = 1;
        return u;
    }
    void insert(int key) {
        split(root, key, x, y);
        root = merge(merge(x, get_node(key)), y);
    }
    void pushup(int u) {
        tr[u].sz = tr[tr[u].l].sz + tr[tr[u].r].sz + 1;
    }
    // split by key
    void split(int u, int key, int &x, int &y) {
        if (!u) {
            x = y = 0;
        } else {
            if (tr[u].key <= key) {
                x = u;
                split(tr[u].r, key, tr[u].r, y);
            } else {
                y = u;
                split(tr[u].l, key, x, tr[u].l);
            }
            pushup(u);
        }
    }
    int merge(int x, int y) {
        if (!x || !y) return x + y;
        if (tr[x].val > tr[y].val) {
            tr[x].r = merge(tr[x].r, y);
            pushup(x);
            return x;
        } else {
            tr[y].l = merge(x, tr[y].l);
            pushup(y);
            return y;
        }
    }
    void build(int w[], int l, int r) {
        for (int i = l; i <= r; i ++ ) {
            insert(w[i]);
        }
    }
    int get_rank(int key) {
        split(root, key - 1, x, y);
        int rank = tr[x].sz + 1;
        root = merge(x, y);
        return rank;
    }
    void change(int key, int nkey) {
        split(root, key, x, z);
        split(x, key - 1, x, y);
        stk.push(y);
        y = merge(tr[y].l, tr[y].r);
        root = merge(merge(x, y), z);
    }
    int get_pre(int key) {
        split(root, key - 1, x, y);
        int u = x;
        while (tr[u].r) u = tr[u].r;
        int pre = tr[u].key;
        if (pre == 0) pre = 0;
        root = merge(x, y);
        return pre;
    }
    int get_nxt(int key) {
        split(root, key, x, y);
        int u = y;
        while (tr[u].l) u = tr[u].l;
        int nxt = tr[u].key;
        if (nxt == 0) nxt = 1e9;
        root = merge(x, y);
        return nxt;
    }
    void output(int u) {
        if (!u) return;
        output(tr[u].l);
        cout << tr[u].key << " ";
        output(tr[u].r);
    }
public:
    FHQ() {}
    FHQ(int w[], int l, int r) {
        root = 0, idx = 0;
        tr.resize(r - l + 1 + 10);
        for (int i = 1; i < tr.size(); i ++ ) stk.push(i);
        build(w, l, r);
    }
    int get_k(int key) {
        return get_rank(key);
    }
    void change(int w[], int id, int x) {
        change(w[id], x);
        insert(x);
    }
    int get_prev(int key) {
        return get_pre(key);
    }
    int get_next(int key) {
        return get_nxt(key);
    }
    void output() {
        output(root);
        cout << "\n";
    }
};
struct Node {
    int l, r;
    int key, val;
    int sz, rev;
}tr[N];
int root, idx;
int x, y, z;

int get_node(int key)
{
    int u = ++ idx;
    tr[u].key = key;
    tr[u].val = rand();
    tr[u].sz = 1;
    tr[u].rev = 0;
    return u;
}
 
void pushup(int u)
{
    tr[u].sz = tr[tr[u].l].sz + tr[tr[u].r].sz + 1;
}
 
void pushdown(int u)
{
    if (tr[u].rev) {
        swap(tr[u].l, tr[u].r);
        tr[tr[u].l].rev ^= 1;
        tr[tr[u].r].rev ^= 1;
        tr[u].rev = 0;
    }
}
// split by rank
void split(int u, int sz, int &x, int &y)
{
    if (!u) {
        x = y = 0;
    } else {
        pushdown(u);
        if (tr[tr[u].l].sz < sz) {
            x = u;
            split(tr[u].r, sz - tr[tr[u].l].sz - 1, tr[u].r, y);
        } else {
            y = u;
            split(tr[u].l, sz, x, tr[u].l);
        }
        pushup(u);
    }
}
 
int merge(int x, int y)
{
    if (!x || !y) return x + y;
    if (tr[x].val > tr[y].val) {
        pushdown(x);
        tr[x].r = merge(tr[x].r, y);
        pushup(x);
        return x;
    } else {
        pushdown(y);
        tr[y].l = merge(x, tr[y].l);
        pushup(y);
        return y;
    }
}
// 区间反转
void reverse(int l, int r)
{
    split(root, l - 1, x, y);
    split(y, r - l + 1, y, z);
    tr[y].rev ^= 1;
    root = merge(merge(x, y), z);
}
 
void output(int u)
{
    if (!u) return;
    pushdown(u);
    output(tr[u].l);
    cout << tr[u].key << " ";
    output(tr[u].r);
}

FHQ 可以可持久化(待补全)

2. splay
class Splay {
private:
    struct Node {
        int p, s[2], v;
        int sz;
        void init(int _v, int _p) {
            v = _v, p = _p;
            sz = 1;
        }
    };
    vector<Node> tr;
    int root = 0, idx = 0;
        
    void pushup(int u) {
        tr[u].sz = tr[tr[u].s[0]].sz + tr[tr[u].s[1]].sz + 1;
    }
    void rotate(int x) {
        int y = tr[x].p, z = tr[y].p;
        int k = tr[y].s[1] == x;
        tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
        tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
        tr[x].s[k ^ 1] = y, tr[y].p = x;
        pushup(y), pushup(x);
    }
    void splay(int x, int k) {
        while (tr[x].p != k) {
            int y = tr[x].p, z = tr[y].p;
            if (z != k) {
                if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
                else rotate(y);
            }
            rotate(x);
        }
        if (!k) root = x;
    }
    void insert(int id, int v) {
        int u = root, p = 0;
        while (u) p = u, u = tr[u].s[v > tr[u].v];
        u = (id == 0 ? ++ idx : id);
        if (p) tr[p].s[v > tr[p].v] = u;
        tr[u].init(v, p);
        splay(u, 0);
    }
    int erase(int v) {
        int u = root;
        while (u) {
            if (tr[u].v == v) break;
            else if (tr[u].v < v) u = tr[u].s[1];
            else u = tr[u].s[0];
        }
        splay(u, 0);
        int l = tr[u].s[0], r = tr[u].s[1];
        while (tr[l].s[1]) l = tr[l].s[1];
        while (tr[r].s[0]) r = tr[r].s[0];
        splay(l, 0), splay(r, l);
        u = tr[r].s[0];
        tr[r].s[0] = 0;
        pushup(r), pushup(l);
        return u;
    }
    int get_rank(int v) {
        int u = root, res = 0;
        while (u) {
            if (tr[u].v < v) {
                res += tr[tr[u].s[0]].sz + 1;
                u = tr[u].s[1];
            } else {
                u = tr[u].s[0];
            }
        }
        return res;
    }
    int get_pre(int v) {
        int u = root, res = -inf;
        while (u) {
            if (tr[u].v < v) {
                res = max(res, tr[u].v);
                u = tr[u].s[1];
            } else {
                u = tr[u].s[0];
            }
        }
        return res;
    }
    int get_nxt(int v) {
        int u = root, res = inf;
        while (u) {
            if (tr[u].v > v) {
                res = min(res, tr[u].v);
                u = tr[u].s[0];
            } else {
                u = tr[u].s[1];
            }
        }
        return res;
    }
    void output(int u) {
        if (tr[u].s[0]) output(tr[u].s[0]);
        if (tr[u].v > -inf && tr[u].v < inf) cout << tr[u].v << " ";
        if (tr[u].s[1]) output(tr[u].s[1]);
    }
public:
    Splay() {}
    Splay(int w[], int l, int r) {
        tr.resize(r - l + 10);
        insert(0, -inf);
        for (int i = l; i <= r; i ++ ) {
            insert(0, w[i]);
        }
        insert(0, inf);
    }
    int get_k(int key) {
        return get_rank(key) - 1;
    }
    void change(int key, int nkey) {
        insert(erase(key), nkey);
    }
    int pre(int key) {
        return get_pre(key);
    }
    int nxt(int key) {
        return get_nxt(key);
    }
    void output() {
        output(root);
        cout << "\n";
    }
};

0x08 莫队

#include <bits/stdc++.h>

using namespace std;

using i64 = long long;

const int N = 50010, M = 200010, S = 10000010;

int n, m, len;
int w[N], ans[M];
struct Query
{
    int id, l, r;
}q[M];
int cnt[S];

int get(int x)
{
    return x / len;
}

bool cmp(const Query &a, const Query &b)
{
    int i = get(a.l), j = get(b.l);
    if (i != j) return i < j;
    return a.r < b.r;
}

void add(int x, int &res)
{
    if (!cnt[x]) res ++ ;
    cnt[x] ++ ;
}

void del(int x, int &res)
{
    cnt[x] -- ;
    if (!cnt[x]) res -- ;
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;
    for (int i = 1; i <= n; i ++ ) cin >> w[i];
    cin >> m;
    len = max(1, (int)sqrt((double)n * n / m));

    for (int i = 0; i < m; i ++ ) {
        int l, r;
        cin >> l >> r;
        q[i] = {i, l, r};
    }
    sort(q, q + m, cmp);

    for (int k = 0, i = 0, j = 1, res = 0; k < m; k ++ ) {
        int id = q[k].id, l = q[k].l, r = q[k].r;
        while (i < r) add(w[ ++ i], res);
        while (i > r) del(w[i -- ], res);
        while (j < l) del(w[j ++ ], res);
        while (j > l) add(w[ -- j], res);
        ans[id] = res;
    }

    for (int i = 0; i < m; i ++ ) cout << ans[i] << "\n";

    return 0;
}

0x09 CDQ分治

CDQ分治的作用实际上是为了解决偏序问题中的一维而出现的,一般用来配合数据结构解决多维偏序问题,当然其本身套娃也可以解决多维偏序。

  1. 二维偏序:第一维排序,第二维归并,每一层递归完 ( l , m i d ) , ( m i d + 1 , r ) (l,mid),(mid+1,r) (l,mid),(mid+1,r)后处理左对右的贡献即可(即归并排序求逆序对)
  2. 三维偏序:第一维排序,第二维归并,某一层递归完 ( l , m i d ) , ( m i d + 1 , r ) (l,mid),(mid+1,r) (l,mid),(mid+1,r)后,左右两边第一维 相对大小 是确定的,将其变为 ( 0 / 1 , b i , c i ) (0/1,b_i,c_i) (0/1,bi,ci)(在左边 a i a_i ai为0右边 a i a_i ai为1),于是 b i b_i bi有序, a i a_i ai可以O(1)判断,又变成了二维偏序。若当前在 ( l , m i d ) (l,mid) (l,mid) 中且 a i = 0 a_i=0 ai=0 c n t + 1 cnt+1 cnt+1,若在 ( m i d + 1 , r ) (mid+1,r) (mid+1,r) a i = 1 a_i=1 ai=1,当前答案加 c n t cnt cnt
  3. 当小于32时可以用状态压缩的方式记录之前维度的大小,过大可以用bitset,每多一个维度复杂度多一个 l o g log log
// 三维偏序(套Fenwick)
#include <bits/stdc++.h>

using namespace std;

using i64 = long long;

const int N = 100010, M = 200010;

int n, m;
struct Data
{
    int a, b, c;
    int s, res;
    bool operator< (const Data &t) const {
        if (a != t.a) return a < t.a;
        if (b != t.b) return b < t.b;
        return c < t.c;
    }
    bool operator== (const Data &t) const {
        return a == t.a && b == t.b && c == t.c;
    }
}q[N], tmp[N];
int tr[M], ans[N];

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int v)
{
    for (int i = x; i < M; i += lowbit(i)) tr[i] += v;
}

int query(int x)
{
    int res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

void merge_sort(int l, int r)
{
    if (l >= r) return;
    int mid = l + r >> 1;
    merge_sort(l, mid), merge_sort(mid + 1, r);
    int i = l, j = mid + 1, k = 0;
    while (i <= mid && j <= r) {
        if (q[i].b <= q[j].b) add(q[i].c, q[i].s), tmp[k ++ ] = q[i ++ ];
        else q[j].res += query(q[j].c), tmp[k ++ ] = q[j ++ ];
    }
    while (i <= mid) add(q[i].c, q[i].s), tmp[k ++ ] = q[i ++ ];
    while (j <= r) q[j].res += query(q[j].c), tmp[k ++ ] = q[j ++ ];
    for (i = l; i <= mid; i ++ ) add(q[i].c, -q[i].s);
    for (i = l, j = 0; j < k; i ++ , j ++ ) q[i] = tmp[j];
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    for (int i = 0; i < n; i ++ ) {
        int a, b, c;
        cin >> a >> b >> c;
        q[i] = {a, b, c, 1, 0};
    }
    sort(q, q + n);
    int k = 1;
    for (int i = 1; i < n; i ++ ) {
        if (q[i] == q[k - 1]) q[k - 1].s ++ ;
        else q[k ++ ] = q[i];
    }
    merge_sort(0, k - 1);
    for (int i = 0; i < k; i ++ ) {
        ans[q[i].res + q[i].s - 1] += q[i].s;
    }

    for (int i = 0; i < n; i ++ ) cout << ans[i] << "\n";

    return 0;
}

CDQ分治同时也可以处理dp问题,不过需要注意转移顺序

// rpg
// dp[i][j] = max(dp[x][y] + buff[x][y](i - x + j - y) + val[x][y])
// x<=i, y<=j, 此处dp映射成了一维
const int N = 100010;
 
int n, m;
int len;
struct Data
{
    int x, y, id;
    int buf, val;
}a[N], tmp[N][30];
int dp[N];
 
int get_id(int x, int y)
{
    return (x - 1) * m + y;
}
// 归并排序先求出整个状态,因为dp的转移是中序遍历而归并是后序遍历
void merge_sort(int l, int r, int d)
{
    if (l == r) {
        tmp[r][d] = a[r];
        return;
    }
    int mid = l + r >> 1;
    merge_sort(l, mid, d + 1), merge_sort(mid + 1, r, d + 1);
    int i = l, j = mid + 1, k = l;
    while (i <= mid && j <= r) {
        if (tmp[i][d + 1].y <= tmp[j][d + 1].y) {
            tmp[k ++ ][d] = tmp[i ++ ][d + 1];
        } else {
            tmp[k ++ ][d] = tmp[j ++ ][d + 1];
        }
    }
    while (i <= mid) tmp[k ++ ][d] = tmp[i ++ ][d + 1];
    while (j <= r) tmp[k ++ ][d] = tmp[j ++ ][d + 1];
}
 
void cdqDivAlgorithm(int l, int r, int d)
{
    if (l == r) return;
    int mid = l + r >> 1;
    cdqDivAlgorithm(l, mid, d + 1);
 
    LichaoTree::clear(1, 2, n + m);
    int i = l, j = mid + 1;
    while (i <= mid && j <= r) {
        auto &a = tmp[i][d + 1], &b = tmp[j][d + 1];
        if (a.y <= b.y) {
            LichaoTree::update(1, 2, n + m, LichaoTree::Interval(2, n + m, a.buf, dp[a.id] - a.buf * (a.x + a.y)));
            i ++ ;
        } else {
            dp[b.id] = max(dp[b.id], LichaoTree::query(1, 2, n + m, b.x + b.y) + b.val);
            j ++ ;
        }
    }
    while (j <= r) {
        auto &b = tmp[j][d + 1];
        dp[b.id] = max(dp[b.id], LichaoTree::query(1, 2, n + m, b.x + b.y) + b.val);
        j ++ ;
    }
 
    cdqDivAlgorithm(mid + 1, r, d + 1);
}
 
signed main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i ++ ) {
        for (int j = 1; j <= m; j ++ ) {
            len ++ ;
            a[len] = {i, j, len};
            // a[ ++ len] = {i, j, len};
        }
    }
    for (int i = 1; i <= n; i ++ ) {
        for (int j = 1; j <= m; j ++ ) {
            cin >> a[get_id(i, j)].buf;
        }
    }
    for (int i = 1; i <= n; i ++ ) {
        for (int j = 1; j <= m; j ++ ) {
            cin >> a[get_id(i, j)].val;
            dp[get_id(i, j)] = a[get_id(i, j)].val;
        }
    }
 
    merge_sort(1, len, 0);
    cdqDivAlgorithm(1, len, 0);
    cout << dp[n * m] << "\n";
 
    return 0;
}

CDQ分治也可以优化dp,如大范围的最长上升子序列问题 ( i < j , a i < a j ) (i <j , a_i < a_j) (i<j,ai<aj) ,只需找出其中存在的偏序关系同时注意转移顺序即可。

0x0A 树链剖分

树链剖分本质就是对dfs序的处理,同时可以求lca,预处理 O ( n ) O(n) O(n)查询 O ( l o g n ) O(logn) O(logn)
但树链剖分全面但不优秀,总结dfs序技巧如下:

  1. 点修改,点查询,就是传说中的数组吗doge

  2. 点修改,子树查询 -> 点修改,区间查询

  3. 子树修改,点查询 -> 区间修改,点查询

  4. 子树修改,子树查询 -> 区间修改,区间查询

  5. 点修改,链查询:设链端点为 a , b a, b a,b l c a ( a , b ) = p lca(a,b)=p lca(a,b)=p f a [ p ] = f p fa[p]=fp fa[p]=fp,我们从上至下在每条链上做前缀和,即转化为 子树修改,点查询(cnt[a]+cnt[b]-cnt[p]-cnt[fp])

  6. 链修改,点查询:同上,从下至上做差分, u 的值为 u 的原值 − u 所有儿子的原值和 u的值为u的原值-u所有儿子的原值和 u的值为u的原值u所有儿子的原值和,(即树上差分),即转化为 点修改(cnt[a]-x,cnt[b]-x,cnt[p]+x,cnt[fp]+x),子树查询

  7. 链修改,子树查询 -> 点修改,区间查询。对于 u u u的子树中的一个点 v v v,把 v v v到根加一个值 w w w,对 u u u的贡献为 w ( d e p v − d e p u + 1 ) = w ( d e p v + 1 ) − w ⋅ d e p u w(depv−depu+1)=w(depv+1)−w⋅depu w(depvdepu+1)=w(depv+1)wdepu,于是分别维护 w ( d e p v + 1 ) w(depv+1) w(depv+1) w w w的值,点修改,区间查询即可。

  8. 子树修改,链查询 -> 区间修改,点查询。对于 v v v的子树中的一个点 u u u,把 v v v的子树加一个值 w w w,对 u u u的贡献为 w ( d e p u − d e p v + 1 ) = w ⋅ d e p u − w ( d e p v − 1 ) w(depu−depv+1)=w⋅depu−w(depv−1) w(depudepv+1)=wdepuw(depv1),于是分别维护 w ( d e p v − 1 ) w(depv−1) w(depv1) w w w的值,区间修改,点查询即可。

  9. 链修改,链查询 -> 树链剖分

综上所述,除了第九种情况之外,其余情况均可 O ( l o g n ) O(logn) O(logn) 进行操作,是比树链剖分 O ( l o g 2 n ) O(log^2n) O(log2n) 优秀一点。

void dfs1(int u, int father, int depth)
{
    dep[u] = depth, fa[u] = father, sz[u] = 1;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;
        dfs1(j, u, depth + 1);
        sz[u] += sz[j];
        if (sz[j] > sz[son[u]]) son[u] = j;
    }
}
void dfs2(int u, int t)
{
    id[u] = ++ cnt, nw[cnt] = w[u], top[u] = t;
    if (!son[u]) return;
    dfs2(son[u], t);
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa[u] || j == son[u]) continue;
        dfs2(j, j);
    }
}
void update_path(int u, int v, int k)
{
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        update(1, id[top[u]], id[u], k);
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) swap(u, v);
    update(1, id[v], id[u], k);
}
i64 query_path(int u, int v)
{
    i64 res = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        res += query(1, id[top[u]], id[u]);
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) swap(u, v);
    res += query(1, id[v], id[u]);
    return res;
}
void update_tree(int u, int k)
{
    update(1, id[u], id[u] + sz[u] - 1, k);
}
i64 query_tree(int u)
{
    return query(1, id[u], id[u] + sz[u] - 1);
}

0x0B LCT(维护子树信息版)

LCT是一种维护的森林信息的数据结构(一棵树也是森林)

例题:

给定一棵 n n n 个节点的树。每条边有一种颜色。

f ( u , v ) f(u,v) f(u,v) 表示从 u u u v v v 的路径上,出现且只出现一次的颜色的数量。

∑ u = 2 n ∑ v = 1 u − 1 f ( u , v ) \sum_{u=2}^{n}\sum_{v=1}^{u-1}f(u,v) u=2nv=1u1f(u,v)

对每个颜色算贡献。可以得到如下算法。

枚举每个颜色 c c c,将所有颜色不为 c c c 的边建出来。此时树被划分为若干连通块。

对于每条颜色为 c c c 的边 ( u , v ) (u,v) (u,v),当一条路径一端在 u u u 所在连通块内,另一端在 v v v 所在连通块内时,该边对该路径有贡献。记节点 x x x 所在连通块大小为 s x s_x sx,则该边贡献为 s u × s v s_u×s_v su×sv

考虑优化。

上述过程可以用动态树维护。初始时将所有边加进去。对于每个颜色 c c c,将所有颜色为 c c c 的边断开,计算该颜色贡献,再将断开的边重新连上。

#include <bits/stdc++.h>

using namespace std;

using i64 = long long;

const int N = 5e5 + 10;

int n, m;
struct Node
{
    int s[2], p;
    int v, sv; // v记录虚儿子的节点数,sv记录当前子树的节点数
    int rev; 
}tr[N];

void pushrev(int x)
{
    swap(tr[x].s[0], tr[x].s[1]);
    tr[x].rev ^= 1;
}

void pushup(int x)
{
    tr[x].sv = tr[tr[x].s[0]].sv + tr[x].v + tr[tr[x].s[1]].sv;
}

void pushdown(int x)
{
    if (tr[x].rev) {
        pushrev(tr[x].s[0]), pushrev(tr[x].s[1]);
        tr[x].rev = 0;
    }
}

bool isroot(int x) // 判断x是不是当前splay的根
{
    return tr[tr[x].p].s[0] != x && tr[tr[x].p].s[1] != x;
}

void rotate(int x)
{
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;
    if (!isroot(y)) tr[z].s[tr[z].s[1] == y] = x;
    tr[x].p = z;
    tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
    tr[x].s[k ^ 1] = y, tr[y].p = x;
    pushup(y), pushup(x);
}

void splay(int x)
{
    static int stk[N];
    int top = 0, r = x;
    stk[ ++ top] = r;
    while (!isroot(r)) stk[ ++ top] = r = tr[r].p;
    while (top) pushdown(stk[top -- ]);
    while (!isroot(x)) {
        int y = tr[x].p, z = tr[y].p;
        if (!isroot(y)) {
            if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
}

void access(int x) // 建立一条从根到x的实边,同时将x变为splay的根节点
{
    int z = x;
    for (int y = 0; x; y = x, x = tr[x].p) {
        splay(x);
        tr[x].v -= tr[y].sv; // 因为要将y加入x的splay中,防止重复计算需要减去
        tr[x].v += tr[tr[x].s[1]].sv; // 删除了的要加上
        tr[x].s[1] = y;
        pushup(x);
    }
    splay(z);
}

void makeroot(int x) // 将x变为原树的根节点
{
    access(x);
    pushrev(x);
}

int findroot(int x) // 找到x所在原树的根节点,再将原树的根节点转到splay的根节点
{
    access(x);
    while (tr[x].s[0]) pushdown(x), x = tr[x].s[0];
    splay(x);
    return x;
}

void split(int x, int y) // 给x和y之间的路径建立一棵splay,根节点是y
{
    makeroot(x);
    access(y);
}

void link(int x, int y) // 若x和y不连通,则加入一条x和y之间的边
{
    makeroot(x);
    if (findroot(y) != x) {
        tr[x].p = y;
        // access(y);
        splay(y);
        tr[y].v += tr[x].sv;
    }
}

void cut(int x, int y) // 若x和y之间存在边,则删除
{
    makeroot(x);
    if (findroot(y) == x && tr[y].p == x && !tr[y].s[0]) {
        // tr[x].sv -= tr[y].sv; // 会在pushup时顺便改掉,所以不需要加
        tr[x].s[1] = tr[y].p = 0;
        pushup(x);
    }
}

struct Edge {
    int u, v;
};
vector<Edge> e[N];

int size(int x)
{
    makeroot(x);
    return tr[x].sv;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;
    for (int i = 1; i <= n; i ++ ) tr[i].v = tr[i].sv = 1;
    int m = 0;
    for (int i = 0; i < n - 1; i ++ ) {
        int u, v, w;
        cin >> u >> v >> w;
        e[w].push_back({u, v});
        m = max(m, w);
        link(u, v);
    }

    i64 ans = 0;
    for (int i = 1; i <= n; i ++ ) {
        for (auto [u, v] : e[i]) cut(u, v);
        for (auto [u, v] : e[i]) {
            ans += 1ll * size(u) * size(v);
        }

        for (auto [u, v] : e[i]) {
            link(u, v);
        }
    }

    cout << ans << "\n";

    return 0;
}

0x0C dsu on tree 与 长链剖分

1. dsu on tree(求子树内不同颜色的个数)
void dfs1(int u, int depth, int father)
{
    dep[u] = depth, fa[u] = father, sz[u] = 1;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;
        dfs1(j, depth + 1, u);
        if (sz[son[u]] < sz[j]) son[u] = j;
        sz[u] += sz[j];
    }
}
 
int sson;
 
void add(int u)
{
    if (cnt[w[u]] == 0) sum ++ ;
    cnt[w[u]] ++ ;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa[u] || j == sson) continue;
        add(j);
    }
}
 
void del(int u)
{
    cnt[w[u]] -- ;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa[u] || j == sson) continue;
        del(j);
    }
}
 
void dsu(int u, bool is_del)
{
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == son[u] || j == fa[u]) continue;
        dsu(j, true);
    }
    if (son[u]) {
        dsu(son[u], false);
        sson = son[u];
    }
    add(u);
    ans[u] = sum;
    if (is_del) sson = 0, del(u), sum = 0;
}

模板

dsu (x) {
    1、处理轻儿子
    2、处理重儿子
    3、把轻儿子的信息往重儿子上边并
    4、计算答案
    5、需要清空时清空统计用的数据结构
}
2. 长链剖分(求满足条件的点的数量)

如果满足 d i s ( i , j ) = k dis(i, j) = k dis(i,j)=k,我们就说 i i i j j j 旗鼓相当

u u u 的子树中的所有节点,如果 x , y x, y x,y 是旗鼓相当的,并且 x , y x,y x,y 的最近公共祖先是 u u u 且满足 u ≠ x , u ≠ y u≠x,u≠y u=x,u=y 那么 u u u r a t i n g rating rating 就会增加 a x + a y a_x + a_y ax+ay

int fa[N], dep[N], son[N], len[N];

void dfs_son(int u, int depth, int father)
{
    fa[u] = father, dep[u] = depth;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;
        dfs_son(j, depth + 1, u);
        if (len[son[u]] < len[j]) son[u] = j;
    }
    len[u] = len[son[u]] + 1;
}

int L[N], R[N], cc;
int uid[N];

void dfs_id(int u)
{ 
    L[u] = ++ cc, R[u] = L[u] + len[u] - 1;
    uid[cc] = u;
    if (son[u]) dfs_id(son[u]);
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa[u] || j == son[u]) continue;
        dfs_id(j);
    }
}

int n, K;
int w[N];
int cnt[N];
i64 sum[N];
i64 ans[N];

void dfs(int u)
{
    if (son[u]) dfs(son[u]);
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa[u] || j == son[u]) continue;
        dfs(j);
        for (int x = L[j], k = 1; x <= R[j]; x ++ , k ++ ) {
            int kk = K - k;
            if (kk > 0 && kk < len[u]) {
                kk = kk + L[u];
                ans[u] += sum[kk] * cnt[x] + sum[x] * cnt[kk];
            }
        }
        for (int x = L[j], k = 1; x <= R[j]; x ++ , k ++ ) {
            cnt[L[u] + k] += cnt[x];
            sum[L[u] + k] += sum[x];
        }
    }

    cnt[L[u]] ++ ;
    sum[L[u]] += w[u];
}

模板

dsu (x) {
    1、处理重儿子
    2、处理轻儿子同时把轻儿子的信息往重儿子上边并
    3、计算答案
}

0x0D 点分治和点分树

1. 点分治(求到点 x 距离为 d 的点的个数)
#include <bits/stdc++.h>

using namespace std;

using i64 = long long;

const int N = 100010, M = 2 * N;

int h[N], e[M], ne[M], idx;

void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

int n, d;
bool st[N];

int get_size(int u, int fa)
{
    if (st[u]) return 0;
    int sz = 1;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa) continue;
        sz += get_size(j, u);
    }
    return sz;
}

int get_wc(int u, int fa, int tot, int &wc)
{
    if (st[u]) return 0;
    int sz = 1, mx = 0;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa) continue;
        int t = get_wc(j, u, tot, wc);
        mx = max(mx, t);
        sz += t;
    }
    mx = max(mx, tot - sz);
    if (mx <= tot / 2) wc = u;
    return sz;
}

pair<int, int> p[N], q[N];

void get_dist(int u, int fa, int dist, int &qt)
{
    if (st[u]) return;
    q[qt ++ ] = {dist, u};
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa) continue;
        get_dist(j, u, dist + 1, qt);
    }
}

int ans[N];

void calc(pair<int, int> a[], int k, int tt)
{
    sort(a, a + k);
    for (int i = 0; i < k; i ++ ) {
        int t = d - a[i].first;
        int x = upper_bound(a, a + k, make_pair(t, n + 1)) - a;
        x -- ;
        x = max(x, 0);
        ans[a[i].second] += tt * x;
    }
}

void div(int u)
{
    if (st[u]) return;
    get_wc(u, -1, get_size(u, -1), u);
    st[u] = true;

    // ---------------------------------

    int pt = 0;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i], qt = 0;
        get_dist(j, -1, 1, qt);
        calc(q, qt, -1);
        for (int k = 0; k < qt; k ++ ) {
            if (q[k].first <= d) {
                ans[u] ++ ;
                ans[q[k].second] ++ ;
            }
            p[pt ++ ] = q[k];
        }
    }
    
    ans[u] ++ ;
    calc(p, pt, 1);
    // ---------------------------------

    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        div(j);
    }
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> d;
    memset(h, -1, sizeof h);
    for (int i = 0; i < n - 1; i ++ ) {
        int a, b;
        cin >> a >> b;
        add(a, b), add(b, a);
    }

    div(1);

    for (int i = 1; i <= n; i ++ ) cout << ans[i] << " ";

    return 0;
}
2. 树分治(给出树的结构,询问到点 u 距离为 d 的点的个数)
#include <bits/stdc++.h>

using namespace std;

using i64 = long long;

const int N = 100010, M = 2 * N;

int n, q;

int h[N], e[M], ne[M], idx;

void add_edge(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

struct Father
{
    int wc, son;
    int dist;
};

// 记录当前结点往上的重心、其是该重心的哪个儿子以及到该重心的距离
vector<Father> fa[N]; 

struct Child
{
    vector<vector<pair<int, int>>> son;
    vector<pair<int, int>> all;
};

// son记录该节点某个儿子的信息
// all记录全部儿子的信息
Child son[N];

bool st[N];

int get_size(int u, int father)
{
    if (st[u]) return 0;
    int res = 1;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;
        res += get_size(j, u);
    }
    return res;
}

int get_wc(int u, int father, int tot, int &wc)
{
    if (st[u]) return 0;
    int sz = 1, mx = 0;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;
        int t = get_wc(j, u, tot, wc);
        mx = max(mx, t);
        sz += t;
    }
    mx = max(mx, tot - sz);
    if (mx <= tot / 2) wc = u;
    return sz;
}

void get_dist(int u, int father, int dist, int top, int kson)
{
    if (st[u]) return;
    fa[u].push_back({top, kson, dist});
    son[top].son[kson].emplace_back(dist, u);
    son[top].all.emplace_back(dist, u);
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;
        get_dist(j, u, dist + 1, top, kson);
    }
}

void treediv(int u)
{
    if (st[u]) return;
    get_wc(u, 0, get_size(u, 0), u);
    st[u] = true;

    // -------------------------------

    for (int i = h[u], kson = 0; ~i; i = ne[i]) {
        int j = e[i];
        if (st[j]) continue;
        son[u].son.emplace_back(0);
        get_dist(j, u, 1, u, kson);
        kson ++ ;
    }

    // -------------------------------

    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        treediv(j);
    }
}

// 给定一个有序序列,求出其中距离小于等于k的数的数目
int calc(const vector<pair<int, int>> &a, int k)
{
    if (a.empty()) return 0;
    int p = upper_bound(a.begin(), a.end(), make_pair(k, n + 1)) - a.begin();
    p -- ;
    return p + 1;
}

int query(int u, int d)
{
    int res = 0;
    // 往上找重心,到u的距离为k的数目为该重心其他儿子结点到u距离为k的数目(就是点分治)
    for (const auto &_ : fa[u]) { 
        int dist = _.dist;
        int sn = _.son;
        int wc = _.wc;
        if (dist > d) continue;
        res ++ ; // 该重心到u距离小于k,加上
        // 容斥求
        res += calc(son[wc].all, d - dist); 
        res -= calc(son[wc].son[sn], d - dist);
    }
    // 统计以u为根的子树的信息
    res ++ ;
    res += calc(son[u].all, d);
    return res;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> q;
    memset(h, -1, sizeof h);
    for (int i = 0; i < n - 1; i ++ ) {
        int u, v;
        cin >> u >> v;
        add_edge(u, v), add_edge(v, u);
    }

    treediv(1);

    for (int i = 1; i <= n; i ++ ) {
        sort(son[i].all.begin(), son[i].all.end());
        for (auto &x : son[i].son) {
            sort(x.begin(), x.end());
        }
    }

    int ans = 0;
    while (q -- ) {
        int u, d;
        cin >> u >> d;
        u = (u + ans) % n + 1;
        ans = query(u, d);
        cout << ans << "\n";
    }

    return 0;
}
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值