【题目链接】
【思路要点】
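- 设\(dp_{u,x}\)表示在\(u\)的子树中选点、且所选点的权值(离散化后)均不超过\(x\)时最多能选的点数,显然有\(O(N^2)\)的DP。用线段树维护每一个点的DP数组:合并子树时将各儿子的DP数组逐位相加(线段树合并),再对\(x\ge val_u\)令\(dp_{u,x}=\max(dp_{u,x},\ dp_{u,val_u-1}+1)\)完成转移。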
- 注意DP数组关于\(x\)单调不降,因此对后缀\([val_u, n]\)取最大值只会改变一段连续区间\([val_u, p-1]\)(\(p\)为第一个取值不小于新值的位置),可以将该操作看作区间赋值操作。
- 时间复杂度\(O(N\log N)\)。
【代码】
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

const int MAXN = 2e5 + 5;
// Node pool size for all dynamic segment trees.
// NOTE(review): 1e7 nodes * 5 ints is ~200 MB of static storage; confirm it
// fits the judge's memory limit.
const int MAXP = 1e7 + 5;

template <typename T> void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> void chkmin(T &x, T y) { x = min(x, y); }

// Reads one (possibly negative) integer from stdin.
template <typename T> void read(T &x) {
    x = 0;
    int sign = 1;
    // Keep the character as int: EOF must stay distinguishable, and isdigit()
    // is only defined for EOF and unsigned-char values.
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar())
        if (c == '-') sign = -sign;
    for (; isdigit(c); c = getchar()) x = x * 10 + (c - '0');
    x *= sign;
}

// Writes an integer to stdout.
template <typename T> void write(T x) {
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

// Writes an integer followed by a newline.
template <typename T> void writeln(T x) {
    write(x);
    puts("");
}

// Pool of dynamically allocated segment trees over ranks [1, n]; each tree
// stores one DP array. Node 0 is the shared null node (Max == 0) and is never
// written to. Per-node fields:
//   tag - if non-zero, the whole subtree equals the constant `tag`
//         (add is always 0 on such a node). A tag of 0 means "no assignment",
//         so callers only ever assign values >= 1.
//   add - pending increment not yet pushed into the children.
//   Max - maximum value stored in the subtree.
struct SegmentTree {
    struct Node {
        int lc, rc;
        int tag, add, Max;
    } a[MAXP];
    int n, size;

    // Resets the tree universe to ranks [1, x] and empties the node pool.
    void init(int x) {
        n = x;
        size = 0;
    }

    // Recomputes a node's maximum from its children.
    void update(int root) {
        a[root].Max = 0;
        chkmax(a[root].Max, a[a[root].lc].Max);
        chkmax(a[root].Max, a[a[root].rc].Max);
    }

    // Adds v to every rank covered by `child`. A missing child represents an
    // all-zero range, so it is created as the constant v.
    void applyAdd(int &child, int v) {
        if (child == 0) {
            child = ++size;
            a[child].tag = v;
            a[child].Max = v;
        } else if (a[child].tag) {
            a[child].tag += v;
            a[child].Max = a[child].tag;
        } else {
            a[child].add += v;
            a[child].Max += v;
        }
    }

    // Overwrites every rank covered by `child` with the constant v.
    void applyAssign(int &child, int v) {
        if (child == 0) child = ++size;
        a[child].tag = a[child].Max = v;
        a[child].add = 0;
    }

    // Pushes a pending add or assignment of `root` down to its two children,
    // creating them if necessary.
    void pushdown(int root) {
        if (a[root].tag == 0 && a[root].add == 0) return;
        if (a[root].tag == 0) {
            applyAdd(a[root].lc, a[root].add);
            applyAdd(a[root].rc, a[root].add);
            a[root].add = 0;
            return;
        }
        applyAssign(a[root].lc, a[root].tag);
        applyAssign(a[root].rc, a[root].tag);
        a[root].tag = 0;
    }

    // Maximum value over ranks [ql, qr], which must lie inside [l, r].
    int query(int root, int l, int r, int ql, int qr) {
        if (root == 0) return 0;
        if (l == ql && r == qr) return a[root].Max;
        pushdown(root);
        // pushdown() has cleared any tag, so the running maximum starts at 0.
        int mid = (l + r) / 2, ans = 0;
        if (mid >= ql) chkmax(ans, query(a[root].lc, l, mid, ql, min(mid, qr)));
        if (mid + 1 <= qr) chkmax(ans, query(a[root].rc, mid + 1, r, max(mid + 1, ql), qr));
        return ans;
    }

    // Maximum value over ranks [l, r]; an empty range yields 0.
    int query(int root, int l, int r) {
        if (l > r) return 0;
        return query(root, 1, n, l, r);
    }

    // Finds the first rank >= pos whose value is >= val, or r + 1 if none.
    // The left half is entered without clamping to pos. That is correct only
    // because callers pass a val larger than every value before pos (the
    // stored arrays are non-decreasing).
    int getr(int &root, int l, int r, int pos, int val) {
        if (root == 0 || a[root].Max < val) return r + 1;
        if (l == r) return l;
        pushdown(root);
        int mid = (l + r) / 2;
        if (a[a[root].lc].Max >= val && mid >= pos) return getr(a[root].lc, l, mid, pos, val);
        return getr(a[root].rc, mid + 1, r, pos, val);
    }

    // Assigns val to ranks [ql, qr] (inside [l, r]), creating nodes as needed.
    void modify(int &root, int l, int r, int ql, int qr, int val) {
        if (root == 0) root = ++size;
        if (l == ql && r == qr) {
            a[root].Max = val;
            a[root].tag = val;
            a[root].add = 0;
            return;
        }
        pushdown(root);
        int mid = (l + r) / 2;
        if (mid >= ql) modify(a[root].lc, l, mid, ql, min(mid, qr), val);
        if (mid + 1 <= qr) modify(a[root].rc, mid + 1, r, max(mid + 1, ql), qr, val);
        update(root);
    }

    // Applies value[x] = max(value[x], val) for every x in [pos, n].
    // The array is non-decreasing, so only the ranks before the first value
    // >= val change. This makes the update a single range assignment.
    void modify(int &root, int pos, int val) {
        int first = getr(root, 1, n, pos, val);
        if (first == pos) return;
        modify(root, 1, n, pos, first - 1, val);
    }

    // Pointwise sum of trees x and y over [l, r]. Nodes of both inputs are
    // reused, so the inputs must not be used afterwards. A constant
    // (tagged) subtree is absorbed into the other side as a lazy add.
    int merge(int x, int y, int l, int r) {
        if (x == 0 || y == 0) return x + y;
        if (a[x].tag && a[y].tag) {
            a[x].tag += a[y].tag;
            a[x].Max += a[y].Max;
            return x;
        }
        if (a[x].tag) {
            a[y].add += a[x].tag;
            a[y].Max += a[x].tag;
            return y;
        }
        if (a[y].tag) {
            a[x].add += a[y].tag;
            a[x].Max += a[y].tag;
            return x;
        }
        pushdown(x);
        pushdown(y);
        int mid = (l + r) / 2;
        a[x].lc = merge(a[x].lc, a[y].lc, l, mid);
        a[x].rc = merge(a[x].rc, a[y].rc, mid + 1, r);
        update(x);
        return x;
    }

    // x := x + y (pointwise); y is consumed.
    void join(int &x, int y) {
        if (x == 0 || y == 0) {
            x += y;
            return;
        }
        x = merge(x, y, 1, n);
    }

    // Maximum value stored in the tree.
    int getans(int root) { return a[root].Max; }
} ST;

int n, val[MAXN], f[MAXN];
int tot, tmp[MAXN], root[MAXN];
vector<int> a[MAXN];

// Builds root[pos], the DP tree of pos's subtree: dp[x] is the largest number
// of nodes that can be picked when every picked value rank is <= x, and the
// descendants picked under a picked node all have strictly smaller ranks.
// Children are finished and summed into their parent in the same order a
// recursive post-order would use. An explicit stack is used so that a
// path-shaped tree of up to MAXN nodes cannot overflow the call stack.
void work(int pos) {
    // Each frame: (node, index of the next child to descend into).
    vector<pair<int, unsigned>> frames;
    frames.emplace_back(pos, 0u);
    while (!frames.empty()) {
        int u = frames.back().first;
        unsigned next = frames.back().second;
        if (next < a[u].size()) {
            frames.back().second = next + 1;
            frames.emplace_back(a[u][next], 0u);
            continue;
        }
        // All children of u are merged. Picking u adds 1 on top of the best
        // selection whose ranks are all strictly below val[u].
        int below = ST.query(root[u], 1, val[u] - 1);
        ST.modify(root[u], val[u], below + 1);
        frames.pop_back();
        if (!frames.empty()) ST.join(root[frames.back().first], root[u]);
    }
}

int main() {
    read(n);
    for (int i = 1; i <= n; i++) {
        read(val[i]);
        read(f[i]);
        tmp[++tot] = val[i];
    }
    // Coordinate-compress the values to ranks 1..tot.
    sort(tmp + 1, tmp + tot + 1);
    tot = unique(tmp + 1, tmp + tot + 1) - tmp - 1;
    for (int i = 1; i <= n; i++) {
        val[i] = lower_bound(tmp + 1, tmp + tot + 1, val[i]) - tmp;
        if (i != 1) a[f[i]].push_back(i);  // node 1 is treated as the root
    }
    ST.init(tot);
    work(1);
    writeln(ST.getans(root[1]));
    return 0;
}