这一题实际上就是线段上的最大独立集
因为带了修改,所以还有了点动态DP的意思??
比较暴力的写法是对于每次改值操作,全部重新DP
那肯定过不去
想想原因,因为每次重新DP重复计算了太多了计算过的值
而且,对于连续一段来说,他并没有后效性,就是不影响后面的选择
同时,后面的选择也不影响前面的
那么单点更新,可不可以每次调用以前算过的解呢?
当然可以,于是我们引出了线段树
因为要求独立集而且无前后效性,所以对于一个区间
他对外界的影响只有端点处,我们分四种情况讨论一下
struct node{
int l,r;
#define lch(x) x<<1
#define rch(x) x<<1|1
ll f1,f2,f3,f4;// f1 表示两端都选
// f2 表示左选右不选
// f3 表示左不选右选
// f4 表示左右都不选
}tr[maxn*4];
更新就很简单了,最基础的DP啊
void pushup(int k){
tr[k].f1=max(tr[lch(k)].f1+tr[rch(k)].f3,tr[lch(k)].f2+max(tr[rch(k)].f1,tr[rch(k)].f3));
tr[k].f2=max(tr[lch(k)].f1+tr[rch(k)].f4,tr[lch(k)].f2+max(tr[rch(k)].f2,tr[rch(k)].f4));
tr[k].f3=max(tr[lch(k)].f3+tr[rch(k)].f3,tr[lch(k)].f4+max(tr[rch(k)].f1,tr[rch(k)].f3));
tr[k].f4=max(tr[lch(k)].f3+tr[rch(k)].f4,tr[lch(k)].f4+max(tr[rch(k)].f2,tr[rch(k)].f4));
}
这样,我们只要在建树的时候计算一次全部的DP值即可,为
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
再接下来的每一次修改,我们只需要
O
(
树
高
)
=
O
(
l
o
g
n
)
O(树高)=O(logn)
O(树高)=O(logn)就可以得到答案了
#include<iostream>
#include<cstdio>
#include<cmath>
#include<queue>
#include<algorithm>
using namespace std;
const int maxn = 50007;
typedef long long ll;
int n,d;
ll val[maxn],ans;
struct node{
int l,r;
#define lch(x) x<<1
#define rch(x) x<<1|1
ll f1,f2,f3,f4;// f1 表示两端都选
// f2 表示左选右不选
// f3 表示左不选右选
// f4 表示左右都不选
}tr[maxn*4];
void pushup(int k){
tr[k].f1=max(tr[lch(k)].f1+tr[rch(k)].f3,tr[lch(k)].f2+max(tr[rch(k)].f1,tr[rch(k)].f3));
tr[k].f2=max(tr[lch(k)].f1+tr[rch(k)].f4,tr[lch(k)].f2+max(tr[rch(k)].f2,tr[rch(k)].f4));
tr[k].f3=max(tr[lch(k)].f3+tr[rch(k)].f3,tr[lch(k)].f4+max(tr[rch(k)].f1,tr[rch(k)].f3));
tr[k].f4=max(tr[lch(k)].f3+tr[rch(k)].f4,tr[lch(k)].f4+max(tr[rch(k)].f2,tr[rch(k)].f4));
}
void build(int k,int l,int r){
tr[k].l=l,tr[k].r=r;
if(l==r){
tr[k].f1=val[l];
return;
}int mid=l+r>>1;
build(lch(k),l,mid);
build(rch(k),mid+1,r);
pushup(k);
}
void change(int k,int x,ll num){
int l=tr[k].l,r=tr[k].r;
if(l==r&&l==x){
tr[k].f1=num;
tr[k].f2=tr[k].f3=tr[k].f4=0;
return ;
}
int mid=l+r>>1;
if(x<=mid) change(lch(k),x,num);
else change(rch(k),x,num);
pushup(k);
}
int main(){
scanf("%d%d",&n,&d);
for(int i=1;i<=n;i++)scanf("%lld",&val[i]);
build(1,1,n);
for(int i=1;i<=d;i++){
int x;ll y;scanf("%d%lld",&x,&y);
change(1,x,y);
ans+=max(max(tr[1].f1,tr[1].f2),max(tr[1].f3,tr[1].f4));
}
printf("%lld",ans);
return 0;
}