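/*
 * 第一次写 splay,参考了多方代码。
 * 由于每次插入的数都会被旋转到根,求前驱、后继的函数都改为直接求根节点的前驱和后继。
 * 其他 splay 的经典操作还没涉及到,慢慢练吧 :P
 *
 * (First splay-tree implementation, written with reference to several sources.
 *  Because every inserted value is splayed to the root, the predecessor/successor
 *  helpers only need to look up the predecessor and successor of the root.
 *  Other classic splay operations are not covered yet -- more practice to come :P)
 */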
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
// Capacity of the node pool. newNode() hands out slots starting at index 1,
// so at most maxn - 1 nodes can be allocated. The headroom over 32768 presumably
// covers the input bound on n (TODO confirm against the problem statement).
const int maxn = 32768 + 100;
// 64-bit integer alias. A typedef (rather than a #define) is scoped like any
// other declaration and does not textually rewrite every later `ll` token.
typedef long long ll;
// Reads the next integer from stdin.
// Any non-digit characters before the first digit are skipped; a '-' among
// them makes the result negative. Reading stops at the first non-digit after
// the number.
// Returns 0 if end of input is reached before any digit is seen, so a
// truncated input file terminates cleanly instead of hanging.
int read(){
    int s = 0, f = 1;
    // int, not char: getchar() must be able to report EOF distinctly. With a
    // char, EOF becomes -1 or 255, which the skip loop below never leaves.
    int c = getchar();
    // Stop skipping at EOF, otherwise this loop never terminates once the
    // input runs out.
    while (c != EOF && (c > '9' || c < '0')) {if (c == '-') f = -1; c = getchar();}
    while (c >= '0' && c <= '9') {s = s * 10 + c - '0'; c = getchar();}
    return s * f;
}
// Splay-tree node. Nodes are taken from the static array `pool`, which has
// static storage and is therefore zero-initialized: a freshly handed-out node
// starts with both children NULL.
struct _ {
    _ *son[2];  // son[0] = left child, son[1] = right child
    _ *fa;      // parent pointer; NULL for the tree root
    int val;    // key stored in this node
};

_ *root;          // current tree root; splay() updates it when splaying to the top
_ pool[maxn];     // node storage; newNode() never hands out slot 0
_ *null = NULL;   // not referenced by the visible code
int n, ans, tmp;  // n: count of input values, ans: running total, tmp: scratch
// Hands out the next unused node from `pool`, setting its key to v and its
// parent to fa. Children are left as-is (NULL, since pool is zero-initialized
// and each slot is handed out once). No bounds check is done against maxn.
_ *newNode(int v, _ *fa) {
    static int cnt = 1;  // next free slot; slot 0 is intentionally skipped
    _ *node = pool + cnt;
    ++cnt;
    node->val = v;
    node->fa = fa;
    return node;
}
// Single rotation that lifts cur above its parent.
// f == 1: right rotation (cur is a left child); f == 0: left rotation (cur is
// a right child). The subtree cur->son[f] is re-attached under the old parent.
// The global `root` is NOT updated here; splay() does that.
void rotate(_ *cur, int f){
    _ *parent = cur->fa;
    _ *grand = parent->fa;
    _ *moved = cur->son[f];  // subtree that changes owner

    parent->son[f ^ 1] = moved;
    if (moved != NULL) moved->fa = parent;
    cur->son[f] = parent;

    // Hook cur into the slot the parent used to occupy under the grandparent.
    if (grand != NULL) grand->son[grand->son[0] == parent ? 0 : 1] = cur;

    cur->fa = grand;
    parent->fa = cur;
}
// Rotates cur upward until its parent is r. If r is NULL, cur ends up as the
// tree root and the global `root` is updated to point at it.
void splay(_ *cur, _ *r) {
    while (cur->fa != r) {
        _ *p = cur->fa;
        // Rotation direction that lifts cur above p: 1 (right) if cur is a
        // left child, 0 (left) otherwise.
        const int up = (cur == p->son[0]) ? 1 : 0;
        if (p->fa == r) {  // parent is directly under the target: one rotation
            rotate(cur, up);
            break;
        }
        // Same encoding for p relative to the grandparent.
        const int pUp = (p == p->fa->son[0]) ? 1 : 0;
        if (up == pUp) {
            // zig-zig / zag-zag: rotate the parent first, then cur.
            rotate(p, pUp);
            rotate(cur, up);
        } else {
            // zig-zag / zag-zig: rotate cur twice, in opposite directions.
            rotate(cur, up);
            rotate(cur, up ^ 1);
        }
    }
    if (r == NULL) root = cur;
}
// Inserts x into the splay tree and splays the node holding x to the root.
// Returns true if x was already present (no node is created in that case),
// false if a new node was allocated. Either way, `root` points at the node
// with value x afterwards.
bool insert(int x){
    _ *p = root;
    _ *fa = NULL;  // must start as NULL: it stays NULL when the tree is empty
    while (p != NULL) {
        if (x == p->val) { splay(p, NULL); return true; }  // duplicate value
        fa = p;
        p = (x < p->val) ? p->son[0] : p->son[1];
    }
    p = newNode(x, fa);
    if (fa == NULL) root = p;  // empty tree: the new node becomes the root
    else if (x < fa->val) fa->son[0] = p;
    else fa->son[1] = p;
    splay(p, NULL);  // bring the new node to the root
    return false;
}
// Returns the largest value in cur's left subtree (cur's in-order predecessor
// when cur is the root), or 0x3f3f3f3f as a "none" sentinel if cur has no
// left child.
int findpre(_ *cur){
    const int kNone = 0x3f3f3f3f;
    _ *walk = cur->son[0];
    if (walk == NULL) return kNone;
    for (_ *next = walk->son[1]; next != NULL; next = next->son[1]) walk = next;
    return walk->val;
}
// Returns the smallest value in cur's right subtree (cur's in-order successor
// when cur is the root), or 0x3f3f3f3f as a "none" sentinel if cur has no
// right child.
int findnxt(_ *cur){
    const int kNone = 0x3f3f3f3f;
    _ *walk = cur->son[1];
    if (walk == NULL) return kNone;
    for (_ *next = walk->son[0]; next != NULL; next = next->son[0]) walk = next;
    return walk->val;
}
// Reads n and then n values. The first value is added to the answer as-is.
// Each later value adds the smallest absolute difference between it and any
// earlier value; a value seen before adds 0. Prints the total.
int main(){
    n = read();
    ans += read();
    root = newNode(ans, NULL);
    for (int day = 2; day <= n; ++day) {
        tmp = read();
        // insert() leaves the new value at the root, so its nearest neighbours
        // are the root's predecessor and successor.
        if (insert(tmp)) continue;  // repeated value: difference is 0
        const int gapBelow = abs(root->val - findpre(root));
        const int gapAbove = abs(root->val - findnxt(root));
        tmp = (gapAbove < gapBelow) ? gapAbove : gapBelow;
        ans += tmp;
    }
    printf("%d\n", ans);
    return 0;
}