题意:一棵有根树,人从根结点出发,每个结点只能有一个人,每个人尽量向叶子走(以膜拜吉丽QAQ),如果有多个结点,则选择编号最小的。实现两种操作:1. 把x个人放到根结点,问第x个人停留在哪里。2. 取出编号为x的结点里的人,问除他以外有多少个人会移动。保证操作合法,树的结点数n和操作数m不超过100000。
手动模拟后发现,对子结点排序,人按照后序遍历的顺序依次考虑所有空结点。问题等价于,维护一个长度为n的序列,支持两种操作:1. 标记前x个未标记的结点,并返回此次最后一个被标记的结点。2. 返回x的祖先中最浅的一个被标记的结点,并把该结点置为未标记(置x为未标记,祖先向下挪)。
可不可以用set之类的东西呢?有两个困难:1. 点只能一个一个插,一个一个删。2. 难以高效地找到前x个未被标记的结点是哪些。于是我写了个线段树……
其他人的代码怎么这么短?我意识到自己的愚蠢:1. 操作一插x个点,操作二删1个点,插入和删除都不会超过(n+m)次。2. 维护未被标记的结点,而不是被标记的结点,这样就能快速找到前x个了。
#include <cstdio>
#include <vector>
#include <algorithm>
#include <cassert>
#define ALL 1, 1, n
using namespace std;
const int MAX_N = 1e5, MAX_D = 17;
vector<int> E[MAX_N+1];
int n, dfs_clock, max_d, post[MAX_N+1], rk[MAX_N+1], anc[MAX_N+1][MAX_D+1];
struct Segment_Tree {
int sum[MAX_N*4];
void maintain(int o)
{
sum[o] = sum[o*2] + sum[o*2+1];
}
void pushdown(int o)
{
if (!sum[o])
sum[o*2] = sum[o*2+1] = 0;
}
void build(int o, int l, int r)
{
if (l == r) {
sum[o] = 1;
return;
}
int m = (l+r)/2;
build(o*2, l, m);
build(o*2+1, m+1, r);
maintain(o);
}
// 查询pos
int query(int pos, int o, int l, int r)
{
if (sum[o] == 0 || l == r)
return sum[o];
int m = (l+r)/2;
return pos > m ? query(pos, o*2+1, m+1, r) : query(pos, o*2, l, m);
}
// 置pos为1
void set(int pos, int o, int l, int r)
{
if (l == r) {
sum[o] = 1;
return;
}
pushdown(o);
int m = (l+r)/2;
if (pos > m)
set(pos, o*2+1, m+1, r);
else
set(pos, o*2, l, m);
maintain(o);
}
// 返回k的位置,置前k个1为0
int clear(int k, int o, int l, int r)
{
if (l == r)
return sum[o] = 0, l;
int m = (l+r)/2, ret;
pushdown(o);
if (sum[o*2] >= k)
ret = clear(k, o*2, l, m);
else {
ret = clear(k-sum[o*2], o*2+1, m+1, r);
sum[o*2] = 0;
}
maintain(o);
return ret;
}
} T;
void dfs_1(int u, int fa)
{
for (int i = 0, v; i < E[u].size(); ++i)
if ((v=E[u][i]) != fa)
dfs_1(v, u);
post[u] = ++dfs_clock;
rk[dfs_clock] = u;
}
void dfs_2(int u, int fa)
{
int id = post[u];
for (int i = 1; i <= max_d; ++i)
anc[id][i] = anc[anc[id][i-1]][i-1];
for (int i = 0, v; i < E[u].size(); ++i)
if ((v=E[u][i]) != fa) {
anc[post[v]][0] = id;
dfs_2(v, u);
}
}
int query(int x)
{
int d = 0;
for (int i = max_d; i >= 0; --i)
if (anc[x][i] && !T.query(anc[x][i], ALL)) {
d += (1<<i);
x = anc[x][i];
}
T.set(x, ALL);
return d;
}
int main()
{
int t;
scanf("%d %d", &n, &t);
while ((1<<max_d) < n)
++max_d;
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d %d", &u, &v);
E[u].push_back(v);
E[v].push_back(u);
}
for (int i = 1; i <= n; ++i)
sort(E[i].begin(), E[i].end());
dfs_1(1, 0); // 重新编号
dfs_2(1, 0); // 倍增预处理
T.build(ALL);
while (t--) {
int op, x;
scanf("%d %d", &op, &x);
if (op == 1)
printf("%d\n", rk[T.clear(x, ALL)]);
else
printf("%d\n", query(post[x]));
}
return 0;
}