通常在统计区间贡献的问题时,我们会用分治思想来降低统计的次数。
我们通过分治把一些有相同属性的区间集中在一次统计贡献中。
题目链接
题目大意
一个序列 A A A , 问这个序列有多少个子区间满足以下条件:
- 对于区间
l
,
r
l,r
l,r中最大值
m
a
x
N
maxN
maxN 与最小值
m
i
n
N
minN
minN 的二进制表示中,
1
的个数相同。
解题思路
对于一个长度为
N
N
N 的序列,子区间会有
N
×
N
÷
2
N \times N \div 2
N×N÷2 个,如果暴力统计满足条件的区间个数,显然不能满足题目的效率要求。
此时,我们就需要转变以下思路,我们需要计算对每一个最大值,最小值的贡献即可。
考虑分治。
对于分治区间
(
l
,
r
)
(l,r)
(l,r),我们统计所有跨越区间中点
m
i
d
=
(
l
+
r
)
/
2
mid =( l+r)/2
mid=(l+r)/2 的子区间的贡献。
需要分四种情况讨论:
- 子区间最大值出现在 l , m i d l,mid l,mid区间内,子区间最小值也出现在 l , m i d l,mid l,mid 区间内
- 子区间最大值出现在 m i d + 1 , r mid+1,r mid+1,r区间内,子区间最小值也出现在 m i d + 1 , r mid+1,r mid+1,r 区间内
- 子区间最大值出现在 l , m i d l,mid l,mid区间内,子区间最小值出现在 m i d + 1 , r mid+1,r mid+1,r 区间内
- 子区间最大值出现在 m i d + 1 , r mid+1,r mid+1,r区间内,子区间最小值出现在 l , m i d l,mid l,mid 区间内
我们使用双指针(
p
1
,
p
2
p1,p2
p1,p2)分别从区间中点
m
i
d
mid
mid开始,访问对应的左右区间
(
l
,
m
i
d
)
,
(
m
i
d
+
1
,
r
)
(l,mid),(mid+1,r)
(l,mid),(mid+1,r)。
一个显然的结论,随着双指针从中间
m
i
d
mid
mid往两边访问的过程中,访问的区间(
p
1
,
m
i
d
p1,mid
p1,mid),(
m
i
d
+
1
,
p
2
mid+1,p2
mid+1,p2)的最大值递增,最小值递减。
对于第一种情况,我们需要保证
m
a
x
(
p
1
,
m
i
d
)
≥
m
a
x
(
m
i
d
+
1
,
p
2
)
并且
m
i
n
(
p
1
,
m
i
d
)
≤
m
i
n
(
m
i
d
+
1
,
r
)
max(p1,mid) \ge max(mid+1,p2)并且min(p1,mid) \le min(mid+1,r)
max(p1,mid)≥max(mid+1,p2)并且min(p1,mid)≤min(mid+1,r),对于每一个满足这个条件的
p
1
,
p
2
p1,p2
p1,p2,都有
(
p
1
,
m
i
d
+
1
→
p
2
)
(p1,mid+1\to p2)
(p1,mid+1→p2) 的区间的最值满足最大值是
m
a
x
(
p
1
,
m
i
d
)
max(p1,mid)
max(p1,mid) 最小值是
m
i
n
(
p
1
,
m
i
d
)
min(p1,mid)
min(p1,mid)
第二种情况和第一种情况相同。
对于第三种情况,我们需要保证
m
a
x
(
p
1
,
m
i
d
)
≥
m
a
x
(
m
i
d
+
1
,
p
2
)
并且
m
i
n
(
p
1
,
m
i
d
)
>
m
i
n
(
m
i
d
+
1
,
r
)
max(p1,mid) \ge max(mid+1,p2)并且min(p1,mid) > min(mid+1,r)
max(p1,mid)≥max(mid+1,p2)并且min(p1,mid)>min(mid+1,r),对于这种情况,需要用两个指针来
p
2
,
p
2
′
p2,p2'
p2,p2′ 来控制右区间的指针,
p
1
p_1
p1 来控制左区间。
第四种情况和第三种情况相同。
对于三 、四两种情况,我们需要记录
p
2
,
p
2
′
p_2,p_2'
p2,p2′ 这个区间内合法的贡献数。这题只需要先预处理序列
A
A
A 每个值的二进制1
的位数即可
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 1e6 + 10;
ll a[N];
int bit[N];
ll maxl[N];
ll minl[N];
int bit_cnt[64];
int n;
ll ans;
int count_1(ll x) {
int cnt = 0;
while (x) {
cnt ++;
x = x & (x - 1);
}
return cnt;
}
void divde(int l, int r) {
if (l == r) {
ans ++;
return;
}
int mid = (l + r) >> 1;
ll minn = mid + 1;
ll maxn = mid + 1;
for (int i = mid + 1; i <= r; i ++) {
if (a[minn] > a[i]) minn = i;
if (a[maxn] < a[i]) maxn = i;
maxl[i] = maxn;
minl[i] = minn;
}
minn = mid;
maxn = mid;
for (int i = mid; i >= l; i --) {
if (a[minn] > a[i]) minn = i;
if (a[maxn] < a[i]) maxn = i;
maxl[i] = maxn;
minl[i] = minn;
}
int p1, p2;
// max_number appear in left interval
// min_numver appear in left interval
// p1 is the right inerval ptr.
p1 = mid + 1;
for (int i = mid; i >= l; i --) {
while (a[maxl[p1]] <= a[maxl[i]] && a[minl[p1]] >= a[minl[i]] && p1 <= r ) p1 ++;
if (bit[maxl[i]] == bit[minl[i]]) ans = (ans + p1 - mid - 1);
}
// max_number appear in right interval
// min_number appear in right interval
// p1 is the left interval ptr.
p1 = mid;
for (int i = mid + 1; i <= r; i ++) {
while (a[maxl[p1]] <= a[maxl[i]] && a[minl[p1]] >= a[minl[i]] && p1 >= l && (a[maxl[p1]] != a[maxl[i]] || a[minl[p1]] != a[minl[i]]) ) p1 --;
if (bit[maxl[i]] == bit[minl[i]]) ans = (ans + mid - p1);
}
// max_number appear in left interval
// min_numver appear in right interval
// p1 is the max bound, and p2 is the min bound.
for (int i = 0; i <= 60; i ++) bit_cnt[i] = 0;
p1 = mid + 1;
p2 = mid + 1;
for (int i = mid; i >= l; i --) {
while (a[maxl[i]] > a[maxl[p1]] && p1 <= r) {bit_cnt[bit[minl[p1]]] ++ ;p1 ++;}
while (a[minl[i]] <= a[minl[p2]] && p2 < p1 ) {bit_cnt[bit[minl[p2]]] -- ;p2 ++;}
ans = (ans + bit_cnt[bit[maxl[i]]]);
}
// cout << ans << endl;
// max_number appear in right interval
// min_number appear in left interval
// ....
for (int i = 0; i <= 60; i ++) bit_cnt[i] = 0;
p1 = mid + 1;
p2 = mid + 1;
for (int i = mid; i >= l; i --) {
while (a[minl[i]] < a[minl[p2]] && p2 <= r) { bit_cnt[bit[maxl[p2]]] ++ ;p2 ++;}
while (a[maxl[i]] >= a[maxl[p1]] && p1 < p2) {bit_cnt[bit[maxl[p1]]] -- ;p1 ++;}
ans = (ans + bit_cnt[bit[minl[i]]]);
}
// cout << l << " " << r << " " << ans - now_ans << endl;
divde(l, mid);
divde(mid + 1, r);
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i ++) {
cin >> a[i];
bit[i] = count_1(a[i]);
}
divde(1, n);
cout << ans;
}