SplayTree区间操作
-----------区间修改
题目:http://poj.org/problem?id=3468
题目大意:给出一组数字,区间整体增加一个值,区间查询和
思路:很经典的区间操作的题目,因此思路也不用自己想, 都是各路以例题的形式给出
之前用线段数写的,用了1938MS,今天用splay写,耗时2875.
说一下splay操作的几个要点:(其实核心还是编程珠玑的主题---抓住问题的本质)
1:操作的区间包括1或n的时候怎么办,如果操作的是[a,b],我们要进行splay(a-1, nul),
Splay(b + 1, root),这个时候就要考虑a-1和b+1越界的问题,再求最值的时候我是将区间扩建成[0,n + 1],但是这个求和的由于用了size,所以多出来节点是不允许的,不过也可以加一些限制,就是变麻烦了。其实分情况讨论一下就可以了,当a=1或b=n时,只旋转a-1
或b+1就可以了,这样也能那个得到我们想要的一颗子树(这个子树里只包括a到b的节点)
在这里卡了一段时间,还是对问题的本质不理解啊,只怪一开始没能理解旋转a-1和b+1的目的是为了构造一颗子树(这个子树里只包括a到b的节点)。
2:一个节点记录的信息到底代表什么,节点的sum域代表的是以这个节点为根的所有子树的总和,那么当我们修改的一个节点不是根节点的时候,我们就必须将这个节点的所有祖先节点全部更新。
不过最后时间是2800+,看到网上很多5000的,也算给自己安慰了
AC code:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long I64;
#define maxn 1000000
I64 num[maxn], n;
struct splayTreeNode{
I64 key, sum, size, renew, id;
splayTreeNode* son[2],* father;
void init(I64 _key, I64 _sum, I64 _size, I64 _renew, I64 _id, splayTreeNode* l, splayTreeNode* r, splayTreeNode* parent){
key = _key, sum = _sum, size = _size, renew = _renew, id = _id, son[0] = l, son[1] = r, father = parent;
}
};
struct splayTree{
#define root nul->son[0]
splayTreeNode* nul, *link;
I64 ad;
splayTree(){
link = new splayTreeNode[maxn];
ad = 0;
nul = &link[ad ++];
nul->init(0, 0, 0, 0, 0, NULL, NULL, NULL);
}
~splayTree(){ delete[] link; }
void copy(splayTreeNode* &x, splayTreeNode* &y, int co){
x->size += co * y->size;
x->sum += co * y->sum;
}
void rotate(splayTreeNode* &rt, int son1, int son2){
splayTreeNode* temp = rt->son[son1];
if(rt->renew) down(rt);
if(temp->renew) down(temp);
rt->son[son1] = temp->son[son2];
copy(rt, temp, -1);
if(temp->son[son2] != NULL){
copy(rt, temp->son[son2], 1);
copy(temp, temp->son[son2], -1);
temp->son[son2]->father = rt;
temp->son[son2]->id = son1;
}
temp->son[son2] = rt;
temp->father = rt->father;
copy(temp, rt, 1);
temp->id = rt->id;
rt->father = temp;
rt->id = son2;
rt = temp;
}
void splay(splayTreeNode* x, splayTreeNode* rt){
if(x == rt || x->father == rt) return;
splayTreeNode* y = x->father, * z = x->father->father;
if(z == rt){
if(y->son[0] == x) rotate(z->son[y->id], 0, 1);
else rotate(z->son[y->id], 1, 0);
}
else{
if(y->id == x->id){
rotate(z->father->son[z->id], x->id, (x->id)^1);
rotate(y->father->son[y->id], x->id, (x->id)^1);
}
else{
rotate(y->father->son[y->id], x->id, (x->id)^1);
rotate(z->father->son[z->id], y->id, (y->id)^1);
}
}
splay(x, rt);
}
void built(int l, int r, int sid, splayTreeNode* &rt, splayTreeNode* fa){
if(l > r) return;
int mid = (l + r) >> 1;
rt = &link[ad ++];
rt->init(mid, num[mid], 1, 0, sid, NULL, NULL, fa);
built(l, mid - 1, 0, rt->son[0], rt);
built(mid + 1, r, 1, rt->son[1], rt);
for(int i = 0; i < 2; ++ i){
if(rt->son[i] != NULL) copy(rt, rt->son[i], 1);
}
}
void replace(splayTreeNode* &rt, I64 data){
rt->sum += rt->size * data;
rt->renew += data;
}
void down(splayTreeNode* &rt){
for(int i = 0; i < 2; ++ i){
if(rt->son[i] != NULL) replace(rt->son[i], rt->renew);
}
rt->renew = 0;
}
splayTreeNode* find(I64 x, splayTreeNode* rt){
if(rt == NULL || rt->key == x) return rt;
if(rt->renew) down(rt);
if(x > rt->key) return find(x, rt->son[1]);
return find(x, rt->son[0]);
}
void updata(I64 a, I64 b, I64 c){
splayTreeNode *temp;
if(a == 1){
if(b == n){
replace(this->root, c);
}
else{
temp = this->find(b + 1, this->root);
splay(temp, this->nul);
replace(this->root->son[0], c);
this->root->sum += c * this->root->son[0]->size;
}
}
else{
if(b == n){
temp = this->find(a - 1, this->root);
splay(temp, this->nul);
replace(this->root->son[1], c);
this->root->sum += c * this->root->son[1]->size;
}
else{
temp = this->find(a - 1, this->root);
splay(temp, this->nul);
temp = this->find(b + 1, this->root);
splay(temp, this->root);
replace(this->root->son[1]->son[0], c);
this->root->son[1]->sum += c * this->root->son[1]->son[0]->size;
this->root->sum += c * this->root->son[1]->son[0]->size;
}
}
}
I64 answer(I64 a, I64 b){
splayTreeNode *temp;
if(a == 1){
if(b == n){
return this->root->sum;
}
else{
temp = this->find(b + 1, this->root);
splay(temp, this->nul);
return(this->root->son[0]->sum);
}
}
else{
if(b == n){
temp = this->find(a - 1, this->root);
splay(temp, this->nul);
return(this->root->son[1]->sum);
}
else{
temp = this->find(a - 1, this->root);
splay(temp, this->nul);
temp = this->find(b + 1, this->root);
splay(temp, this->root);
return(this->root->son[1]->son[0]->sum);
}
}
}
};
int main(){
char ord[4];
I64 a, b, c, Q;
while(~scanf("%d %d", &n, &Q)){
splayTree spl;
for(int i = 1; i <= n; ++ i) scanf("%lld", &num[i]);
spl.built(1, n, 0, spl.root, spl.nul);
while(Q --){
scanf("%1s", ord);
switch(ord[0]){
case 'C':
scanf("%lld %lld %lld", &a, &b, &c);
spl.updata(a, b, c);
break;
case 'Q':
scanf("%lld %lld", &a, &b);
printf("%lld\n", spl.answer(a, b));
break;
}
}
}
return 0;
}