首先来看如何快速求出区间[l,r]的答案。预处理pre[i]表示和a[i]相同的数上一次出现的位置,如果a[i]第一次出现则per[i]=0。那么[l,r]的答案就相当于求[l,r]中所有pre[]<l的数的和。
因此可以构建主席树,第i棵树表示右坐标为i时,里面节点[u,v]的值就是左端点在[u,v],右端点为i时的最大值。那么可以发现第i棵树相对于第i-1棵树只有pre[i+1]~i的每一个都加上了a[i],这样就可以打标记然后可持久化了。
然后用一个优先队列维护答案,每次找头结点拓展,注意结点要额外记录两个值(l,r)表示左端点范围是(l,r)。
注意主席树打标记的时候要新建结点不然其他地方指过来的话会出错。另外可以标记永久化来加速。
AC代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
#define N 100005
#define M 7000005
#define ll long long
using namespace std;
int n,m,trtot,num[N],rt[N],pre[N],ls[M],rs[M]; ll icr[M];
struct node{ ll x; int l,r,y,id; }a[N];
struct trnd{ ll x; int y; }val[M];
bool operator <(node u,node v){ return u.x<v.x || u.x==v.x && u.id<v.id; }
priority_queue<node> q;
int read(){
int x=0,fu=1; char ch=getchar();
while (ch<'0' || ch>'9'){ if (ch=='-') fu=-1; ch=getchar(); }
while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
return x*fu;
}
void build(int &k,int l,int r){
k=++trtot; val[k].y=l; int mid=(l+r)>>1;
if (l==r) return; build(ls[k],l,mid); build(rs[k],mid+1,r);
}
void add(int &x,int w){
int y=++trtot; ls[y]=ls[x]; rs[y]=rs[x];
val[y]=val[x]; val[y].x+=w; icr[y]=icr[x]+w; x=y;
}
void ins(int l,int r,int x,int &y,int u,int v,int w){
y=++trtot; val[y]=val[x]; icr[y]=icr[x]; ls[y]=ls[x]; rs[y]=rs[x];
if (l==u && r==v){ add(y,w); return; }
int mid=(l+r)>>1;
if (v<=mid) ins(l,mid,ls[x],ls[y],u,v,w); else
if (u>mid) ins(mid+1,r,rs[x],rs[y],u,v,w); else{
ins(l,mid,ls[x],ls[y],u,mid,w); ins(mid+1,r,rs[x],rs[y],mid+1,v,w);
}
if (val[ls[y]].x>val[rs[y]].x) val[y]=val[ls[y]]; else val[y]=val[rs[y]];
val[y].x+=icr[y];
}
trnd qry(int l,int r,int k,int x,int y){
if (l==x && r==y) return val[k]; int mid=(l+r)>>1;
trnd t1;
if (y<=mid) t1=qry(l,mid,ls[k],x,y); else
if (x>mid) t1=qry(mid+1,r,rs[k],x,y); else{
t1=qry(l,mid,ls[k],x,mid); trnd t2=qry(mid+1,r,rs[k],mid+1,y);
if (t2.x>t1.x) t1=t2;
}
t1.x+=icr[k]; return t1;
}
int main(){
n=read(); m=read(); int i;
for (i=1; i<=n; i++){
num[i]=read();
a[i].x=(ll)num[i]; a[i].id=i;
}
sort(a+1,a+n+1);
for (i=1; i<=n; i++)
if (i>1 && a[i].x==a[i-1].x) pre[a[i].id]=a[i-1].id;
else pre[a[i].y]=0;
node u,v; trnd t; build(rt[0],1,n);
for (i=1; i<=n; i++){
ins(1,n,rt[i-1],rt[i],pre[i]+1,i,num[i]);
t=qry(1,n,rt[i],1,i);
u.l=1; u.r=u.id=i; u.x=t.x; u.y=t.y; q.push(u);
}
for (i=1; i<=m; i++){
u=q.top(); q.pop();
if (i==m) printf("%lld\n",u.x); else{
if (u.l<u.y){
t=qry(1,n,rt[u.id],u.l,u.y-1);
v.l=u.l; v.r=u.y-1; v.id=u.id; v.x=t.x; v.y=t.y;
q.push(v);
}
if (u.y<u.r){
t=qry(1,n,rt[u.id],u.y+1,u.r);
u.l=u.y+1; u.x=t.x; u.y=t.y; q.push(u);
}
}
}
return 0;
}
by lych
2016.4.8