LOJ#2473. 「九省联考 2018」秘密袭击(线段树合并+拉格朗日插值)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Timsei/article/details/84574913

一个非常强的题。
也许比较套路但是都比较生疏。
主要使用两个思想。
首先是把求第k大的权转化成枚举i 从1 - W 计算 最终的第k大 大于等于 i 的和。
然后就可以 转化成一个DP。
f[i][j][k] represents the subtree of the node i and we are considering the value of the kth node is not less than j. This block already has k nodes which have value not less than j.
如果我们把最后一维表示为生成函数。则我们可以列出这个DP的转移方程。
由于我们需要知道所有subtree的信息和。
不难想到另外开一个g[i][j]表示i的subtree 内的f[i][j] 的和。
一般这一类题目都可以用求点值然后拉格朗日插值回去的做法来做。
注意到我们只想要知道所有DP状态的和。 所以我们只需要记录所有多项式之和就可以了。
具体怎样维护可以在树上自己推式子解决。
子区间操作然后点一一对应相作运算的题一般来讲都是可以线段树合并的。 对于多个操作涉及到很多不同的操作的时候一个普遍的优化方法是用矩阵。 当矩阵不够快的时候想一下是否能用几个一次函数来表示标记以简化。
比如说这个题。 标记(a, b, c, d) 表示(S, F) -> (aS+b, cS+d+F)
最后插值的时候用拉格朗日退背包就好。

最重要的转化:
第一步 看到值域很小, 所以自然的想到差分然后转化为各项之和。
第二步 是转化为点值这一步比较套路。
第三步 用线段树维护, 这一部分是利用矩阵转移然后优化状态。
值得注意的是线段树合并的时候是 有一方两个儿子均为空时合并。
这一类题目 先要找到可以方便多项式转移的DP 形式。 这样子就成功了一半。
然后剩下的部分想着如何用数据结构维护一般都是可行的。

#include <bits/stdc++.h>
using namespace std;

typedef unsigned int ui;

const int N = 2e3 + 5;
const int M = N * 2;
const ui mod = 64123;
const int MAX = 1e5 + 5;

struct NODE {
  ui a, b, c, d;
  friend NODE operator * (NODE er, NODE la) {
    return
      (NODE) {er.a * la.a % mod, (la.a * er.b % mod + la.b) % mod
      , (la.c * er.a % mod + er.c) % mod,
	(er.b * la.c % mod + la.d + er.d) % mod};
  }
  void init() {
    a = 1; b = c = d = 0;
  }
}T[MAX];

int n, k, W, d[N], fir[N], ne[M], to[M], cnt, rt[N], sz, ch[MAX][2];

#define lc (ch[x][0])
#define rc (ch[x][1])

void add(int x, int y) {
  ne[++ cnt] = fir[x];
  fir[x] = cnt;
  to[cnt] = y;
}

void link(int x, int y) {
  add(x, y);
  add(y, x);
}

#define Foreachson(i, x) for(int i = fir[x]; i; i = ne[i])

void readin() {
  int x, y;
  scanf("%d%d%d", &n, &k, &W);
  for(int i = 1; i <= n; ++ i) scanf("%d", &d[i]);
  for(int i = 1; i < n; ++ i) {
    scanf("%d%d", &x, &y);
    link(x, y);
  }
}

ui poly[N], F[MAX], S[MAX];

int newnode() {
  ++ sz;
  F[sz] = S[sz] = 0;
  T[sz].init();
  ch[sz][0] = ch[sz][1] = 0;
  return sz;
}

void pt(int &x, NODE who) {
  if(!x) x = newnode();
  T[x] = T[x] * who;
  return;
}

void pd(int x) {
  pt(lc, T[x]);
  pt(rc, T[x]);
  T[x].init();
}

void chg(int &x, int l, int r, int L, int R, NODE who) {
  if(!x) x = newnode();
  if(l == L && r == R) {
    T[x] = T[x] * who;
    return;
  }
  pd(x);
  int mid = (l + r) >> 1;
  if(L > mid) chg(rc, mid + 1, r, L, R, who);
  else if(R <= mid) chg(lc, l, mid, L, R, who);
  else chg(lc, l, mid, L, mid, who), chg(rc, mid + 1, r, mid + 1, R, who);
}

int merge(int x, int y) {
  if(!x || !y) return x + y;
  if(!ch[x][0] && !ch[x][1]) swap(x, y);
  if(!ch[y][0] && !ch[y][1]) {
    // y's is ((a + b), d)
    T[x].a = T[x].a * T[y].b % mod;
    T[x].b = T[x].b * T[y].b % mod;
    T[x].d = (T[x].d + T[y].d) % mod;
    return x;
  }
  pd(x); pd(y);
  ch[x][0] = merge(ch[x][0], ch[y][0]);
  ch[x][1] = merge(ch[x][1], ch[y][1]);
  return x;
}

void dfs(int x, int f, ui magic) {
  rt[x] = newnode();
  pt(rt[x], (NODE){0, 1, 0, 0});
  Foreachson(i, x) {
    int V = to[i];
    if(V == f) continue;
    dfs(V, x, magic);
    rt[x] = merge(rt[x], rt[V]);
  }
  if(d[x]) chg(rt[x], 1, W, 1, d[x], (NODE){magic, 0, 0, 0});
  pt(rt[x], (NODE){1, 0, 1, 0});
  pt(rt[x], (NODE){1, 1, 0, 0});
}

void query(int &x, int l, int r, ui &ans) {
  if(l == r) {
    ans = ans + (T[x].d);
    ans %= mod;
    return;
  }
  int mid = (l + r) >> 1;
  pd(x);
  query(lc, l, mid, ans), query(rc, mid + 1, r, ans);
}

ui f[N], g[N];

int Pow(ui x, int y) {
  ui res = 1;
  for(; y; y >>= 1, x = 1LL * x * x % mod) {
    if(y & 1) {
      res = 1LL * res * x % mod;
    }
  }
  return res;
}

void dec(ui *f, ui *g, int x) {
  for(int i = 0; i <= n + 1; ++ i) g[i] = f[i];
  ui Inv = Pow(x, mod - 2);
  for(int i = 0; i <= n; ++ i) {
    g[i] = 1LL * g[i] * (mod - Inv) % mod;
    g[i + 1] = (g[i + 1] - g[i] + mod) % mod;
  }
  assert(!g[n + 1]);
}

void Lagrange(void) {
  memset(f, 0, sizeof(f));
  memset(g, 0, sizeof(g));
  f[0] = 1;
  for(int i = 1; i <= n + 1; ++ i) 
    for(int j = n + 1; j >= 0; -- j)
      f[j + 1] = (f[j + 1] + f[j]) % mod,
	f[j] = 1LL * f[j] * (mod - i) % mod;
  ui ans = 0;
  for(int i = 1; i <= n + 1; ++ i) {
    dec(f, g, i);
    ui now = 0;
    for(int j = k; j <= n; ++ j) now = (now + g[j]) % mod;
    for(int j = 1; j <= n + 1; ++ j) {
      if(i != j)
	now = now * Pow((i - j + mod) % mod, mod - 2) % mod;
    }
    now = now * poly[i] % mod;
    ans = (ans + now) % mod;
  }
  cout << ans << endl;
}

int main() {
  readin();
  for(int i = 1; i <= n + 1; ++ i) {
    sz = 0;
    dfs(1, 0, i);
    query(rt[1], 1, W, poly[i]);
    //poly[i] = (poly[i] + query(rt[1], 1, W, j)) % mo;
    cerr << i <<" " << poly[i] << endl;
  }
  Lagrange();
}

展开阅读全文

没有更多推荐了,返回首页