[东师培训D3T2] Number
题目
题目描述
一个排列,求出了 a 数组,其中 ai 表示第 i 个数左边有多少个数比它小。计算出原来的排列。输入格式
第一行输入 n 接下来 n - 1 个整数 ai,下标从 2 开始。输出格式
输出 n 个整数表示原排列。Sample Input
5
1 2 1 0Sample Output
24
5
3
1
数据范围
对于 20% 的数据满足:1 ≤ n ≤ 10对于 50% 的数据满足:1 ≤ n ≤ 1000
对于 100% 的数据满足,1 ≤ n ≤ 100000
保证解存在
想法
本来的想法是O(n)逆序递推。在纸上推导的时候就是把1~n的所有正整数写下来,记为t[i]。然后对于原数组a[i]从n~1递推。最右端的数肯定在右边没有比它小的数(因为右边根本没有数)。所以原排列ans[i]就等于t[]中第a[i] + 1大的元素。然后把这个数从t[]中划掉,重复递推至1。
麻烦在于维护可以随时删除元素的t[]的第k大值。显然t[]不能是普通数组。作为一个打不好splay的苣蒻,我果断选择了treap…
另外正解是线段树或树状数组,挺麻烦的,就不记了。
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cstdlib>
#include <ctime>
using namespace std;
const int MAXN = 100005;
int n, a[MAXN], ans[MAXN];
struct Node
{
int key, deg, size, cnt;
Node *ch[2];
Node() {}
Node(int x)
{
key = x;
deg = rand();
size = cnt = 1;
ch[0] = ch[1] = NULL;
}
};
Node *rt = NULL;
void rotate(Node *&n, bool d)
{
Node *t = n->ch[d ^ 1];
n->ch[d ^ 1] = t->ch[d];
t->ch[d] = n;
n->size = n->cnt;
if (n->ch[0] != NULL) n->size += n->ch[0]->size;
if (n->ch[1] != NULL) n->size += n->ch[1]->size;
t->size = t->cnt;
if (t->ch[0] != NULL) t->size += t->ch[0]->size;
if (t->ch[1] != NULL) t->size += t->ch[1]->size;
n = t;
}
void insert(Node *&n, int x)
{
if (n == NULL)
{
n = new Node(x);
return;
}
if (n->key == x)
{
n->size++;
n->cnt++;
}
if (n->key > x)
{
insert(n->ch[0], x);
if (n->ch[0]->deg < n->deg) rotate(n, 1);
else n->size++;
}
else
{
insert(n->ch[1], x);
if (n->ch[1]->deg < n->deg) rotate(n, 0);
else n->size++;
}
}
void remove(Node *&n, int x)
{
if (n == NULL) return;
if (x == n->key)
{
if (n->cnt > 1)
{
n->size--;
n->cnt--;
return;
}
else
{
if (n->ch[0] == NULL)
{
Node *t = n;
n = n->ch[1];
delete t;
return;
}
else if (n->ch[1] == NULL)
{
Node *t = n;
n = n->ch[0];
delete t;
return;
}
else
{
if (n->ch[0]->deg < n->ch[1]->deg)
{
rotate(n, 1);
remove(n->ch[1], x);
}
else
{
rotate(n, 0);
remove(n->ch[0], x);
}
n->size--;
}
}
}
else
{
if (n->key > x) remove(n->ch[0], x);
else remove(n->ch[1], x);
n->size--;
}
}
int kth(Node *n, int x)
{
int s = 0;
if (n->ch[0] != NULL) s = n->ch[0]->size;
if (x <= s) return kth(n->ch[0], x);
else if (x <= s + n->cnt) return n->key;
else return kth(n->ch[1], x - s - n->cnt);
}
int main()
{
srand(time(NULL));
scanf("%d", &n);
insert(rt, 1);
for (int i = 2; i <= n; i++)
{
scanf("%d", &a[i]);
insert(rt, i);
}
for (int i = n; i >= 1; i--)
{
ans[i] = kth(rt, a[i] + 1);
remove(rt, ans[i]);
}
for (int i = 1; i <= n; i++)
printf("%d\n", ans[i]);
return 0;
}