Problem
题目简述
DX3906 星系,Melancholy 星上,我在勘测这里的地质情况。
我把这些天来已探测到的区域分为 N 组,并用二元组(D,V)对每一组进行标记:其中 D 为区域的相对距离,V 为内部地质元素的相对丰富程度。
在我的日程安排表上有 Q 项指派的计划。每项计划的形式是类似的,都是“对相对距离 D 在[L,R]之间的区域进行进一步的勘测,并在其中有次序地挑出 K 块区域的样本进行研究。”采集这 K 块的样品后,接下来在实验中,它们的研究价值即为这 K 块区域地质相对丰富程度 V 的乘积。
我对这 Q 项计划都进行了评估:一项计划的评估值 P 为所有可能选取情况的研究价值之和。但是由于仪器的原因,在一次勘测中,这其中 V 最小的区域永远不会被选取。
现在我只想知道这 Q 项计划的评估值对 2^32 取模后的值,特殊地,如果没有 K 块区域可供选择,评估值为 0。
输入格式
第一行给出两个整数,区域数 N 与计划数 Q。
第二行给出 N 个整数,代表每一块区域的相对距离 D。
第三行给出 N 个整数,代表每一块区域的内部地质元素的相对丰富程度 V。
接下来的 Q 行,每一行 3 个整数,代表相对距离的限制 L,R,以及选取的块数 K。
输出格式
输出包括 Q 行,每一行一个整数,代表这项计划的评估值对 2^32 取模后的值。
输入输出样例
Input
5 3
5 4 7 2 6
1 4 5 3 2
6 7 1
2 6 2
1 8 3
Output
5
52
924
数据范围
1<=N,Q<=10^5
1<=D,V<=10^9
1<=L<=R<=10^9
1<=K<=6
Solution
对线段树上的每个区间,我们记录min和sum[7],min表示区间最小值,sum[i]表示在[l,r]区间中选取i个数的乘积的总和,不妨记sum[0]=1
min可以简单的由区间合并得到,sum可以这样合并
t[x].sum[i]=∑i0t[lc].sum[k]∗t[rc].sum[i−k]
这里用到了乘法原理,举个栗子,abc*de=abcde
这样就可以统计长度为k的乘积和了
那怎么把最小值去掉呢?
想到使用容斥原理,有如下计算方式(另query得到的答案为Q)
Ans=Q.sum[k]−Q.sum[k−1]∗Q.min+Q.sum[k−2]∗(Q.min)2−…
这样就可以把所有包含Q.min的组合去除,具体原理不再赘述
另外这道题的d很大,需要先离散化。并且每次Query之前,要在排好序的数组中二分得到左右端点
Code
#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for(int i = (a); i <= (b); i++)
#define red(i, a, b) for(int i = (a); i >= (b); i--)
#define ui unsigned int
const int N = 200000;
const ui inf = 2000000000;
struct node{
ui mi;
ui sum[7];
}t[N * 6];
struct hbh{
ui d, v;
}a[N];
int n, T;
ui xp[10];
inline ui read() {
ui x = 0; char c = getchar();
while(!isdigit(c)) c = getchar();
while(isdigit(c)) { x = x * 10 + c - '0'; c = getchar(); }
return x;
}
bool cmp(hbh a, hbh b) { return a.d < b.d; }
void calc() { xp[1] = 1; rep(i, 2, 6) xp[i] = xp[i - 1] * i; }
node operator * (node a, node b) {
node c;
c.mi = min(a.mi, b.mi);
memset(c.sum, 0, sizeof(c.sum));
c.sum[0] = 1;
rep(i, 1, 6) rep(j, 0, i) c.sum[i] += a.sum[j] * b.sum[i - j];
return c;
}
void build(int x, int l, int r) {
int mid = (l + r) >> 1, lc = x << 1, rc = lc + 1;
t[x].sum[0] = 1;
t[x].mi = inf;
if (l == r) return;
build(lc, l, mid);
build(rc, mid + 1, r);
}
void update(int x, int l, int r, int k, int num) {
int mid = (l + r) >> 1, lc = x << 1, rc = lc + 1;
if (l == r) { t[x].sum[1] = t[x].mi = num; return; }
if (k <= mid) update(lc, l, mid, k, num);
else update(rc, mid + 1, r, k, num);
t[x] = t[lc] * t[rc];
}
node query(int x, int l, int r, int ql, int qr) {
int mid = (l + r) >> 1, lc = x << 1, rc = lc + 1;
if (ql <= l && r <= qr) return t[x];
node ans_left, ans_right, ans_x;
if (ql <= mid) ans_left = query(lc, l, mid, ql, qr);
else {
ans_left.mi = inf;
memset(ans_left.sum, 0, sizeof(ans_left.sum));
ans_left.sum[0] = 1;
}
if (qr > mid) ans_right = query(rc, mid + 1, r, ql, qr);
else {
ans_right.mi = inf;
memset(ans_right.sum, 0, sizeof(ans_right.sum));
ans_right.sum[0] = 1;
}
ans_x = ans_left * ans_right;
return ans_x;
}
ui find(int pos) {
int l = 0, r = n + 1, mid;
while(l + 1 < r) {
mid = (l + r) >> 1;
if (a[mid].d < pos) l = mid;
else r = mid;
}
return l + 1;
}
int main() {
freopen("melancholy.in", "r", stdin);
freopen("melancholy.out", "w", stdout);
scanf("%d%d", &n, &T);
memset(t, 0, sizeof(t));
calc();
rep(i, 1, n) a[i].d = read();
rep(i, 1, n) a[i].v = read();
sort(a + 1, a + n + 1, cmp);
build(1, 1, n);
rep(i, 1, n) update(1, 1, n, i, a[i].v);
while(T--) {
ui ql = read(), qr = read(), k = read(), fvck = qr;
ql = find(ql); qr = find(qr);
qr = a[qr].d == fvck ? qr : qr - 1;
if (ql > qr || qr - ql + 1 < k) {
printf("0\n");
continue;
}
node qq = query(1, 1, n, ql, qr);
ui ans = 0, base = 1;
int j = 1;
red(i, k, 0) {
ans += qq.sum[i] * base * j;
base = base * qq.mi;
j *= -1;
}
printf("%u\n", ans * xp[k]);
}
return 0;
}