D. Petya and Array
time limit per test
2 seconds
memory limit per test
256 megabytes
input
standard input
output
standard output
Petya has an array $a$ consisting of $n$ integers. He has learned partial sums recently, and now he can calculate the sum of elements on any segment of the array really fast. The segment is a non-empty sequence of elements standing one next to another in the array.
Now he wonders how many segments of his array have a sum less than $t$. Help Petya to calculate this number.
More formally, you are required to calculate the number of pairs $l, r$ ($l \le r$) such that $a_l + a_{l+1} + \dots + a_{r-1} + a_r < t$.
Input
The first line contains two integers $n$ and $t$ ($1 \le n \le 200000$, $|t| \le 2 \cdot 10^{14}$).
The second line contains a sequence of integers $a_1, a_2, \dots, a_n$ ($|a_i| \le 10^9$) — the description of Petya's array. Note that there might be negative, zero and positive elements.
Output
Print the number of segments in Petya's array with the sum of elements less than $t$.
Examples
input
Copy
5 4
5 -1 3 4 -1
output
Copy
5
input
Copy
3 0
-1 2 -3
output
Copy
4
input
Copy
4 -1
-2 1 -2 3
output
Copy
3
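Problem summary: given an array a and a number t, count the pairs (l, r) with l ≤ r such that the sum of the contiguous segment a[l..r] is less than t.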
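Precompute prefix sums sum[i] (with sum[0] = 0). For a fixed right end i, we need the number of j < i such that sum[i] - sum[j] < t, i.e. sum[i] - t < sum[j]. So it suffices to maintain the prefix sums seen so far and count how many of them are greater than sum[i] - t. I previously solved this with a segment tree; this time I wrote it with a splay tree.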
#include <cstdio>
#include <vector>
#include <algorithm>
#include <cstring>
#include <iostream>
using namespace std;
// 64-bit shorthand: prefix sums reach 2e5 * 1e9 = 2e14 in magnitude and
// |t| <= 2e14, so 32-bit ints are not enough for them.
typedef long long ll;

// Child-accessor shorthands for tr[x][0] / tr[x][1]. The code indexes tr
// directly, so these are currently unused.
#define lson tr[x][0]
#define rson tr[x][1]

// Capacity: sum[0] plus one prefix sum per element, i.e. at most
// n + 1 <= 200001 tree nodes and prefix sums, with some slack.
const int maxn = 200000 + 10;
const int INF = 1000000000;

// Splay-tree multiset of prefix sums; node 0 is the null sentinel.
//   root   - current root node (0 while the tree is empty)
//   tot    - number of allocated nodes (live nodes are 1..tot)
//   fa[x]  - parent of x (0 for the root)
//   tr[x]  - left / right child of x
//   sz[x]  - total number of stored copies in x's subtree
//   num[x] - number of copies of key[x] stored at node x
//   key[x] - value held by node x
int root, tot;
int fa[maxn], tr[maxn][2], sz[maxn], num[maxn];
ll key[maxn];

// Problem data: n elements a[1..n], threshold t, prefix sums sum[0..n],
// and the running answer.
int n;
ll t, ans;
int a[maxn];
ll sum[maxn];
// Returns 1 if x is the right child of its parent, 0 otherwise.
// For the root, fa[x] is 0 and the null node's child slots are never
// written, so the root reports 0.
int judge(int x) {
    const int parent = fa[x];
    return x == tr[parent][1] ? 1 : 0;
}
// Recomputes sz[x] as the number of copies stored in x's subtree:
// the copies held at x itself plus both children's subtree sizes.
// The null node 0 is skipped so that sz[0] stays 0.
void pushup(int x) {
    if (x == 0) return;
    const int left = tr[x][0];
    const int right = tr[x][1];
    sz[x] = num[x] + sz[left] + sz[right];
}
// Single rotation that lifts x one level above its parent y while keeping
// BST order:
//   1. x's inner subtree (the side facing y) is re-hung under y;
//   2. x takes y's slot under the grandparent z (if there is one);
//   3. y becomes x's child on the side opposite to where x was.
// Only y's size is refreshed here; splay() refreshes x once it stops moving.
void rotate(int x) {
    const int y = fa[x];
    const int z = fa[y];
    const int side = judge(x);   // 1 if x is y's right child
    const int ySide = judge(y);  // which child slot of z holds y

    // Step 1: move x's inner subtree over to y.
    const int inner = tr[x][side ^ 1];
    tr[y][side] = inner;
    if (inner) fa[inner] = y;

    // Step 2: hook x into y's former place under z.
    fa[x] = z;
    if (z) tr[z][ySide] = x;

    // Step 3: y hangs below x.
    fa[y] = x;
    tr[x][side ^ 1] = y;

    pushup(y);
}
// Rotates x upward until its parent is k; with k == 0 x becomes the root.
// When two or more levels remain, the parent is rotated first if x and its
// parent sit on the same side (zig-zig), otherwise x itself (zig-zag).
// In both cases the step finishes by rotating x.
void splay(int x, int k) {
    while (fa[x] != k) {
        const int parent = fa[x];
        if (fa[parent] != k) {
            const bool sameSide = judge(x) == judge(parent);
            rotate(sameSide ? parent : x);
        }
        rotate(x);
    }
    pushup(x);
    if (k == 0) root = x;
}
// Adds one copy of val to the multiset.
// If a node with key val already exists, its count is incremented.
// Otherwise a fresh node is attached as a leaf where the BST search ended.
// Either way the touched node is splayed to the root, which also refreshes
// the subtree sizes of every node on the search path.
void insert(ll val) {
    int parent = 0;
    int cur = root;
    while (cur != 0 && key[cur] != val) {
        parent = cur;
        cur = tr[cur][val > key[cur] ? 1 : 0];
    }

    if (cur != 0) {
        ++num[cur];
        splay(cur, 0);
        return;
    }

    cur = ++tot;
    fa[cur] = parent;
    key[cur] = val;
    num[cur] = 1;
    sz[cur] = 1;
    if (parent != 0) tr[parent][val > key[parent] ? 1 : 0] = cur;
    splay(cur, 0);
}
// Walks from the root toward key x and splays the node where the walk stops:
// either the node whose key equals x, or the last node on the search path
// when the child in x's direction is missing.
// Afterwards key[root] is x if it is stored. Otherwise it is the stored key
// adjacent to x on one side (its predecessor or successor).
void find(ll x) {
    int cur = root;
    for (;;) {
        if (key[cur] == x) break;
        const int next = tr[cur][x > key[cur]];
        if (next == 0) break;
        cur = next;
    }
    splay(cur, 0);
}
int main() {
scanf("%d%I64d", &n, &t);
tot = root = ans = sum[0] = 0;
insert(0);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
sum[i] = sum[i - 1] + a[i];
}
for (int i = 1; i <= n; i++) {
find(sum[i] - t);
if (key[root] > (sum[i] - t)) ans += (sz[tr[root][1]] + num[root]);
else ans += sz[tr[root][1]];
insert(sum[i]);
}
printf("%I64d\n", ans);
return 0;
}