折腾了好久也没有想到快一点的方法…
用了第二种思路也只快了十多毫秒…
继续思考，争取做得更快再来写题解… T T
Ver. 1 359ms
/**
 * Node of the interval-count tree built by newInterval() and queried by
 * getCnt().
 *
 * A node without children is a leaf and represents the closed interval
 * [lower, upper] together with its count. A node with at least one child is
 * an internal node: it routes values <= pivot to `left` and values > pivot
 * to `right`.
 */
struct Node {
    long long cnt;    // Count reported by getCnt() when this node is a leaf.
    long long lower;  // Inclusive lower bound (for internal nodes: smallest bound seen below).
    long long pivot;  // Split point; only meaningful once newInterval() has split the node.
    long long upper;  // Inclusive upper bound (for internal nodes: largest bound seen below).
    Node *left;       // Subtree holding values <= pivot, or NULL.
    Node *right;      // Subtree holding values > pivot, or NULL.

    /**
     * Creates a leaf for [lower, upper] carrying `cnt`.
     * `pivot` is zero-initialized so that it never holds an indeterminate
     * value; newInterval() overwrites it when the leaf is split.
     */
    Node(long long lower, long long upper, long long cnt)
        : cnt(cnt), lower(lower), pivot(0), upper(upper), left(NULL), right(NULL) {}
};
/**
 * Inserts the closed interval [lower, upper] into the interval-count tree
 * rooted at `root`, so that getCnt() reports `cnt` more for every point the
 * interval covers.
 *
 * Tree shape: a leaf (no children) stores one interval with its count; an
 * internal node stores a `pivot`, keeps everything <= pivot under `left` and
 * everything > pivot under `right`, and its lower/upper are widened to cover
 * every interval inserted beneath it.
 *
 * @param root   Root of the (sub)tree; a new leaf is allocated when it is NULL.
 * @param lower  Inclusive lower bound of the interval.
 * @param upper  Inclusive upper bound of the interval. An empty interval
 *               (lower > upper) covers no point and is ignored.
 * @param cnt    Count carried by the interval (defaults to 1). It is
 *               forwarded on every recursive insert.
 */
void newInterval(Node* &root, long long lower, long long upper, long long cnt = 1) {
    // An empty interval contributes nothing; splitting a leaf against it
    // would never reach the "identical interval" base case.
    if (lower > upper) return;

    if (!root) {
        root = new Node(lower, upper, cnt);
        return;
    }

    const long long new_lower = lower <= root->lower ? lower : root->lower;
    const long long new_upper = upper >= root->upper ? upper : root->upper;

    if (!root->left && !root->right) {
        // Leaf holding exactly this interval: just accumulate the count.
        if (root->lower == lower && root->upper == upper) {
            root->cnt += cnt;
            return;
        }

        // Otherwise turn the leaf into an internal node split at the middle
        // of the combined range.
        root->pivot = (new_lower + new_upper) / 2;
        // Integer division truncates toward zero, so for adjacent bounds with
        // a negative sum (e.g. -3 and -2, or -1 and 0) the quotient equals
        // new_upper. Step back to new_lower so that both halves are non-empty
        // and the recursion terminates.
        if (root->pivot <= 0 && new_lower + 1 == new_upper && new_lower != 0)
            root->pivot--;

        // Push the interval this leaf used to hold down into the children
        // (both are NULL here, so these calls create fresh leaves). The
        // node's own cnt is left as is; getCnt() only reads cnt on leaves.
        if (root->upper <= root->pivot) {
            newInterval(root->left, root->lower, root->upper, root->cnt);
        } else if (root->lower <= root->pivot) {
            newInterval(root->left, root->lower, root->pivot, root->cnt);
            newInterval(root->right, root->pivot + 1, root->upper, root->cnt);
        } else {
            newInterval(root->right, root->lower, root->upper, root->cnt);
        }
    }

    // Route the new interval through the pivot, splitting it when it
    // straddles the pivot.
    if (upper <= root->pivot) {
        newInterval(root->left, lower, upper, cnt);
    } else if (lower <= root->pivot) {
        newInterval(root->left, lower, root->pivot, cnt);
        newInterval(root->right, root->pivot + 1, upper, cnt);
    } else {
        newInterval(root->right, lower, upper, cnt);
    }

    // Widen this node's bounds so getCnt() can reject out-of-range targets.
    root->lower = new_lower;
    root->upper = new_upper;
}
/**
 * Looks up `target` in the interval-count tree rooted at `root`.
 *
 * Walks from the root towards a leaf, choosing the left child when
 * target <= pivot and the right child otherwise.
 *
 * @return The `cnt` of the leaf whose [lower, upper] contains `target`, or 0
 *         when the walk reaches a NULL child or a node whose bounds exclude
 *         `target`.
 */
long long getCnt(Node* root, long long target) {
    Node* node = root;
    while (node) {
        // Outside this node's bounds: nothing below can contain the target.
        if (target < node->lower || target > node->upper) return 0;

        const bool is_leaf = !node->left && !node->right;
        if (is_leaf) return node->cnt;

        node = (target <= node->pivot) ? node->left : node->right;
    }
    return 0;
}
class Solution {
public:
long long countRangeSum(vector<int>& nums, int lower, int upper) {
Node* root = NULL;
newInterval(root, lower, upper);
int cnt = 0;
long long offset = 0;
for (long long i = 0; i < nums.size(); i++) {
offset += nums[i];
cnt += getCnt(root, offset);
newInterval(root, lower + offset, upper + offset);
}
return cnt;
}
};
Ver. 2 313ms
#include <iostream>
#include <vector>
using namespace std;
/**
 * Node of the interval-count tree built by newInterval() and queried by
 * getCnt().
 *
 * A node without children is a leaf and represents the closed interval
 * [lower, upper] together with its count. A node with at least one child is
 * an internal node: it routes values <= pivot to `left` and values > pivot
 * to `right`.
 */
struct Node {
    long long cnt;    // Count reported by getCnt() when this node is a leaf.
    long long lower;  // Inclusive lower bound (for internal nodes: smallest bound seen below).
    long long pivot;  // Split point; only meaningful once newInterval() has split the node.
    long long upper;  // Inclusive upper bound (for internal nodes: largest bound seen below).
    Node *left;       // Subtree holding values <= pivot, or NULL.
    Node *right;      // Subtree holding values > pivot, or NULL.

    /**
     * Creates a leaf for [lower, upper] carrying `cnt`.
     * `pivot` is zero-initialized so that it never holds an indeterminate
     * value; newInterval() overwrites it when the leaf is split.
     */
    Node(long long lower, long long upper, long long cnt)
        : cnt(cnt), lower(lower), pivot(0), upper(upper), left(NULL), right(NULL) {}
};
/**
 * Frees the whole tree rooted at `root`: left subtree, then right subtree,
 * then the node itself. A NULL `root` is a no-op.
 */
void del(Node* root) {
    if (!root) return;
    del(root->left);
    del(root->right);
    delete root;
}
/**
 * Inserts the closed interval [lower, upper] into the interval-count tree
 * rooted at `root`, so that getCnt() reports `cnt` more for every point the
 * interval covers.
 *
 * Tree shape: a leaf (no children) stores one interval with its count; an
 * internal node stores a `pivot`, keeps everything <= pivot under `left` and
 * everything > pivot under `right`, and its lower/upper are widened to cover
 * every interval inserted beneath it.
 *
 * @param root   Root of the (sub)tree; a new leaf is allocated when it is NULL.
 * @param lower  Inclusive lower bound of the interval.
 * @param upper  Inclusive upper bound of the interval. An empty interval
 *               (lower > upper) covers no point and is ignored.
 * @param cnt    Count carried by the interval (defaults to 1). It is
 *               forwarded on every recursive insert.
 */
void newInterval(Node* &root, long long lower, long long upper, long long cnt = 1) {
    // An empty interval contributes nothing; splitting a leaf against it
    // would never reach the "identical interval" base case.
    if (lower > upper) return;

    if (!root) {
        root = new Node(lower, upper, cnt);
        return;
    }

    const long long new_lower = lower <= root->lower ? lower : root->lower;
    const long long new_upper = upper >= root->upper ? upper : root->upper;

    if (!root->left && !root->right) {
        // Leaf holding exactly this interval: just accumulate the count.
        if (root->lower == lower && root->upper == upper) {
            root->cnt += cnt;
            return;
        }

        // Otherwise turn the leaf into an internal node split at the middle
        // of the combined range.
        root->pivot = (new_lower + new_upper) / 2;
        // Integer division truncates toward zero, so for adjacent bounds with
        // a negative sum (e.g. -3 and -2, or -1 and 0) the quotient equals
        // new_upper. Step back to new_lower so that both halves are non-empty
        // and the recursion terminates.
        if (root->pivot <= 0 && new_lower + 1 == new_upper && new_lower != 0)
            root->pivot--;

        // Push the interval this leaf used to hold down into the children
        // (both are NULL here, so these calls create fresh leaves). The
        // node's own cnt is left as is; getCnt() only reads cnt on leaves.
        if (root->upper <= root->pivot) {
            newInterval(root->left, root->lower, root->upper, root->cnt);
        } else if (root->lower <= root->pivot) {
            newInterval(root->left, root->lower, root->pivot, root->cnt);
            newInterval(root->right, root->pivot + 1, root->upper, root->cnt);
        } else {
            newInterval(root->right, root->lower, root->upper, root->cnt);
        }
    }

    // Route the new interval through the pivot, splitting it when it
    // straddles the pivot.
    if (upper <= root->pivot) {
        newInterval(root->left, lower, upper, cnt);
    } else if (lower <= root->pivot) {
        newInterval(root->left, lower, root->pivot, cnt);
        newInterval(root->right, root->pivot + 1, upper, cnt);
    } else {
        newInterval(root->right, lower, upper, cnt);
    }

    // Widen this node's bounds so getCnt() can reject out-of-range targets.
    root->lower = new_lower;
    root->upper = new_upper;
}
/**
 * Looks up `target` in the interval-count tree rooted at `root`.
 *
 * Walks from the root towards a leaf, choosing the left child when
 * target <= pivot and the right child otherwise.
 *
 * @return The `cnt` of the leaf whose [lower, upper] contains `target`, or 0
 *         when the walk reaches a NULL child or a node whose bounds exclude
 *         `target`.
 */
long long getCnt(Node* root, long long target) {
    Node* node = root;
    while (node) {
        // Outside this node's bounds: nothing below can contain the target.
        if (target < node->lower || target > node->upper) return 0;

        const bool is_leaf = !node->left && !node->right;
        if (is_leaf) return node->cnt;

        node = (target <= node->pivot) ? node->left : node->right;
    }
    return 0;
}
class Solution {
public:
int A(vector<int>& nums, long long lower, long long upper, int i, int j) {
if (i == j) {
if (lower <= nums[i] && nums[i] <= upper) {
return 1;
}
else return 0;
}
int mid = (i + j) / 2;
int cnt = 0;
cnt += A(nums, lower, upper, i, mid);
cnt += A(nums, lower, upper, mid + 1, j);
Node* root = NULL;
long long acc = 0;
for (int idx = mid + 1; idx <= j; idx++) {
acc += nums[idx];
newInterval(root, lower - acc, upper - acc);
}
acc = 0;
for (int idx = mid; idx >= i; idx--) {
acc += nums[idx];
cnt += getCnt(root, acc);
}
del(root);
return cnt;
}
int countRangeSum(vector<int>& nums, int lower, int upper) {
if (nums.empty()) return 0;
return A(nums, lower, upper, 0, nums.size() - 1);
}
};