题意:写着DP总不可能是dp吧,就是给定一个dp[i][j]的求解方式,如图中分段函数所示,然后给你一个l r k,然后问你dp[ r-l+1 ][ k ]的值为多少?
思路:既然题目那这DP,那应该就不是DP吧。手写几项后,反正我写不出来 ,发现我们需要求得就是(
∑
i
=
1
r
−
l
+
1
i
2
\sum_{i=1}^{r-l+1}i^2
∑i=1r−l+1i2) + [
b
l
b_l
bl,
b
l
+
1
,
.
.
.
.
.
.
b
r
b_{l+1},......b_r
bl+1,......br]中前k大的数和。
前面的式子很好维护,至于后面的区间前k大和,一看到区间,k大,那就是用主席树来维护这个东西了。具体实现看代码。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 7;
int root[MAXN],n,q,size,num;//size 内存池的大小
ll a[MAXN],s[MAXN],b[MAXN];
struct HJT_tree{
int l,r,cnt;//cnt记录出现次数 sum记录和 val记录节点的值
ll sum,val;
}tree[MAXN*40];
void build(int &now,int l,int r){
now = ++size;
tree[now].cnt = tree[now].sum = tree[now].val = 0;
if(l == r) return ;
int mid = (l+r)>>1;
build(tree[now].l,l,mid);
build(tree[now].r,mid+1,r);
}
void modify(int &now,int pre,int l,int r,int p,int val){
tree[now = ++size] = tree[pre];
tree[now].cnt++,tree[now].sum += val;
if(l == r){
tree[now].val = val;
return ;
}
int mid = (l+r)>>1;
if(p <= mid) modify(tree[now].l,tree[pre].l,l,mid,p,val);
else modify(tree[now].r,tree[pre].r,mid+1,r,p,val);
}
//查前k大值,因为节点存的是值 所以找大的 应该先往右边找 看右边(也就是大的)的数量符合情况嘛
ll query(int now,int pre,int l,int r,int k){
if(l == r) return tree[now].val * k;//因为一个值可能出现多次 那么计算答案的时候要把他们都算上
int mid = (l+r)>>1;
int tmp = tree[tree[now].r].cnt - tree[tree[pre].r].cnt;
if(tmp >= k) return query(tree[now].r,tree[pre].r,mid+1,r,k);
else return query(tree[now].l,tree[pre].l,l,mid,k-tmp) +
tree[tree[now].r].sum - tree[tree[pre].r].sum;//注意左边的k要变成k-tmp的形式
}
int getid(ll x){ return lower_bound(b+1,b+1+num,x)-b; }
int main(){
int T;
scanf("%d",&T);
while(T--){
memset(root,0,sizeof(root));
size = 0;
scanf("%d",&n);
for(int i = 1;i <= n;i ++){
scanf("%lld",&a[i]);
b[i] = a[i];
s[i] = s[i-1] + (1ll*i*i);//这里i会爆 所以开ll
}
/*****离散化 + 初始化****/
sort(b+1,b+1+n);
num = unique(b+1,b+1+n)-b-1;
build(root[0],1,num);
/***********************/
for(int i = 1;i <= n;i ++){
int p = getid(a[i]);
modify(root[i],root[i-1],1,num,p,a[i]);
}
scanf("%d",&q);
int l,r,k;
while(q--){
scanf("%d%d%d",&l,&r,&k);
ll ans = query(root[r],root[l-1],1,num,k);
ans += s[r-l+1];
printf("%lld\n",ans);
}
}
return 0;
}