题意
给定一个长度为 n 的数组 A,用 A[1...n] 表示, A[1...n] 是 1−n 的数一种排列组合。存在一个函数 f(l,r,k) 表示 A[l...r] 中第k大数的值,同时 f(l,r,k)=0 当 r−l+1<k 。给定 k 求解 ∑nl=1∑nr=lf(l,r,k) 。
分析
考虑每个数对答案的贡献,显然
A[1...n]
中第
1
到第
新的问题是如何快速确定大于 ai 的每个数的位置,或者说快速逐个搜索 ai 左侧和右侧比 ai 大的数的位置。由于不需要利用到比 ai 小的数,不妨从大到小枚举 1−n 中的每一个数,处理完后将对应数的位置放入某个集合中,这样每次查询这个集合时,必然都是比 ai 大的数。然而集合中快速搜索最近位置的点的复杂度为 O(log(n)) ,加上枚举的复杂度,总复杂度达到了 O(nklog(n)) 。赛时尝试了一发,不出意料的T掉了。然后考虑到优化搜索过程,不难发现搜索时都是从 ai 的位置向左或者向右逐个查询,于是想到用链表的方式保存左边和右边最近点位置。通过链表的指针就能在枚举的过程中快速搜索下一个点的位置了。插入新位置时采用二分的方式找到最近的点的位置,再利用链表关系更新即可。最后的总复杂度为 O(nk+nlog(n)) 。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<map>
using namespace std;
#define LL long long
#define MAXN 500500
const int mod=1e9+7;
struct Node{
int l,r;
}nxt[MAXN];
int bin[MAXN];
int pos[MAXN];
int lef[MAXN];
int rig[MAXN];
int n;
int lowbit(int x){
return x&-x;
}
void add(int x){
while(x<=n){
bin[x]++;
x+=lowbit(x);
}
}
int sum(int x){
int ret=0;
while(x){
ret+=bin[x];
x-=lowbit(x);
}
return ret;
}
int query(int l,int r){
return sum(r)-sum(l-1);
}
void updata(int x){
int l=1,r=x-1,mid;
while(l<=r){
mid=(l+r)>>1;
if(query(mid,x-1)<1)
r=mid-1;
else
l=mid+1;
}
nxt[x].l=r;
if(r>0)
nxt[r].r=x;
l=x+1,r=n;
while(l<=r){
mid=(l+r)>>1;
if(query(x+1,mid)<1)
l=mid+1;
else
r=mid-1;
}
nxt[x].r=l;
if(l<=n)
nxt[l].l=x;
}
int main(){
int T,k,a;
cin>>T;
while(T--){
scanf("%d %d",&n,&k);
memset(bin,0,sizeof(bin));
nxt[0].l=nxt[n+1].l=0;
nxt[0].r=nxt[n+1].r=n+1;
for(int i=1;i<=n;++i){
scanf("%d",&a);
pos[a]=i;
nxt[i].l=0;
nxt[i].r=n+1;
}
for(int i=n;i>n-k+1;i--){
updata(pos[i]);
add(pos[i]);
}
LL ans=0;
for(int i=n-k+1;i;i--){
updata(pos[i]);
add(pos[i]);
for(int j=0,curl=pos[i],curr=pos[i];j<=k;++j){
lef[j]=curl-nxt[curl].l;
rig[j]=nxt[curr].r-curr;
curl=nxt[curl].l;
curr=nxt[curr].r;
}
for(int j=0;j<k;++j)
ans+=i*1ll*lef[j]*rig[k-j-1];
}
printf("%I64d\n",ans);
}
}