要求的是所有的 l,r 使得:
r
−
l
=
k
∗
(
s
u
m
[
r
]
−
s
u
m
[
l
]
)
r - l = k * (sum[r] - sum[l])
r−l=k∗(sum[r]−sum[l])
设置一个块大小 block,枚举1的个数 x,按block分块
若 x ≤ b l o c k x \leq block x≤block,O(n) 枚举左端点统计所有右端点,复杂度为 O ( b l o c k ∗ n ) O(block * n) O(block∗n)
若 x > b l o c k x > block x>block,移项得到: r − k ∗ s u m [ r ] = l − k ∗ s u m [ l ] r - k*sum[r] = l - k * sum[l] r−k∗sum[r]=l−k∗sum[l],枚举 k,维护 p − k ∗ s u m [ p ] p - k * sum[p] p−k∗sum[p],按 k 分块,再按权值和下标分块,每个块内枚举每个右端点,对每个右端点统计有多少个左端点。
受空间限制,block的大小设为 ⌊ n 200 ⌋ \lfloor\frac{n}{200}\rfloor ⌊200n⌋左右,太小维护 x > b l o c k x > block x>block 的情况时会MLE
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 10;
const int mod = 201326611;
#define pii pair<int,int>
typedef long long ll;
int n,sum[maxn];
char s[maxn];
vector<int> g;
vector<pii> h[450];
int main() {
scanf("%s",s + 1);
n = strlen(s + 1);
for(int i = 1; i <= n; i++) {
if(s[i] == '1') g.push_back(i);
if(s[i] == '1') sum[i] = 1;
else sum[i] = 0;
}
g.push_back(n+1);
int sqr = min(n,250);
int sqr2 = n / sqr;
for(int i = 1; i <= n; i++)
sum[i] += sum[i - 1];
for(int j = 1; j <= sqr; j++) {
for(int i = 0; i <= n; i++) {
int v = (1ll * i - 1ll * sum[i] * j % mod) % mod;
h[j].push_back(pii(v,i));
}
sort(h[j].begin(),h[j].end());
}
ll ans = 0,cnt = 0;
for(int i = 1; i <= sqr2; i++) { //1的个数
int p = 0;
for(int j = 1; j <= n; j++) {
while(p < g.size() && g[p] < j) p++;
if(p >= g.size()) continue;
if(p + i < g.size()) {
int L = g[p + i - 1] - j + 1,R = g[p + i] - j;
ans += R / i - (L - 1) / i;
}
}
}
for(int i = 1; i <= sqr; i++) {
int cnt = 0;
for(int l = 0,r; l < h[i].size(); l = r + 1) {
r = l;
while(r + 1 < h[i].size() && h[i][r].first == h[i][r + 1].first) r++;
int ll = l;
for(int rr = l; rr <= r; rr++) {
while(sum[h[i][rr].second] - sum[h[i][ll].second] > sqr2 && ll < rr) ll++;
ans += ll - l;
}
}
}
printf("%lld\n",ans);
return 0;
}