醉之见面会(莫队+树状数组)
题目描述:
有很多个人需要发言,第i个人的发言时间为ai分钟
由于时间有限,需要挑选第l个到第r个小朋友来发言,如何安排他们的发言顺序才能使得他们发言+等待的时间总和最小。
比如,如果你安排2,3,5的小朋友一次发言,发言得总时间为2+(3+2)+(5+3+2)=17分钟
现在有m个询问,每个询问为[l,r],你要对每个询问进行回答
思路:
如果需要区间[l,r]的总时间最小,那么我们必须从小到大安排这个区间,如果每次对一个询问区间排序,时间复杂度过大,如何加速?
我们可以用莫队算法,莫队并不是一个模板,而是一种思路,意思是对询问进行排序,依次回答,每次暴力地移动区间,那么就会有两种情况:新增一个数,答案会如何变化,减少一个数,答案会如何变化。
本题中我们可以注意到,假设当前区间从小到大排序后是B[1], B[2], ... B[N],当新插入一个X时,假设排序后是B[1], B[2], ... B[K-1], X, B[K],B[K+1], ... B[N]。
那么,对于x前面的数,x都要等待他们发言,此时答案相当于加上了一个比x小的数的前缀和;
对于x后面的数,每个数都要等待x发言,假设有num个,答案相当于加上num*x。
所以,我们需要维护一个已经出现的比x小的数的前缀和,和已经出现的比x大的数的个数。
所以,我们想到树状数组,树状数组维护逆序对相信都不陌生,如果一个数出现标记为1,同时前缀和加上这个数,这样就可以同时用树状数组维护了。
莫队的增加和减少也很好写,增加就是上面增加的思路,减少就用树状数组重新标记为0,答案减去上面增加的数就好了。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cmath>
#include <queue>
#include <vector>
using namespace std;
typedef long long ll;
const int maxn = 20000 + 10;
ll a[maxn];
int n, m;
//询问
struct Query {
int l, r;
int id;
Query(int _l, int _r, int _id) :l(_l), r(_r), id(_id) {}
Query() {}
}Q[maxn];
int block;//块
ll ANS[maxn];
ll ans;//记录答案
int pos[maxn];//每个询问所在的块
int cmp(const Query a, const Query b) {
if (pos[a.l] == pos[b.l]) {
return pos[a.r] < pos[b.r];
}
return pos[a.l] < pos[b.l];
}
int lowbit(int x) {
return x & -x;
}
int s[maxn];//每个数在当前询问出现了多少次
ll Sum[maxn];
//树状数组
int sum(int x,ll &presum) {
int res = 0;
presum = 0;
while (x > 0) {
res += s[x];
presum += Sum[x];
x -= lowbit(x);
}
return res;
}
void add(int x, int d) {
int k = x;
while (x < maxn) {
s[x] += d;
Sum[x] += 1ll * k * d;
x += lowbit(x);
}
}
//莫队
void ADD(int x,int L,int R) {
add(a[x], 1);
ll presum;//在他之前出现过的数的前缀和
int now = sum(a[x], presum)-1;
//printf("now是多少 %d\n", now);
int num = (R - L - now);
//printf("debug presum num %lld %d\n", presum,num);
ans += presum;
ans += 1ll * a[x]* num;
}
void DEL(int x,int L,int R) {
add(a[x], -1);
ll presum;
int now = sum(a[x], presum) - 1;
int num = (R - L - now);
ans -= presum;
ans -= 1ll * a[x] * num;
}
signed main() {
int T;
scanf("%d", &T);
while (T--) {
scanf("%d%d", &n, &m);
int sz = sqrt(n);
for (int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
pos[i] = i / sz;
}
ans = 0;
memset(s, 0, sizeof(s));
memset(Sum, 0, sizeof(Sum));
for (int i = 0; i < maxn; i++)Sum[i] = 0;
for (int i = 1; i <= m; i++) {
int l, r;
scanf("%d%d", &l, &r);
Q[i].l = l;
Q[i].r = r;
Q[i].id = i;
}
sort(Q + 1, Q + m + 1, cmp);
int L = 1, R = 0;
for (int i = 1; i <= m; i++) {
while (R < Q[i].r) {
R++;
ADD(R,L,R);
//printf("ans %lld\n", ans);
}
while (L > Q[i].l) {
L--;
ADD(L, L, R);
}
while (L < Q[i].l) {
DEL(L,L,R);
L++;
}
while (R > Q[i].r) {
DEL(R,L,R);
R--;
}
ANS[Q[i].id] = ans;
}
for (int i = 1; i <= m; i++) {
printf("%lld\n", ANS[i]);
}
}
return 0;
}