Description
给你一个区间[l,r],这个区间由二进制表示,
现在询问你在这个区间内的数在二进制的表示下有c00个00,c01个01,c10个10,c11个11的有多少个。
Sample Input
10
1001
0
0
1
1
Sample Output
1
我们思考其实一个01,一个10就相当于一个分割线,将连续一段的0或1分割开来。
那么我们对于一个连续的k个0,那么其实就会有k - 1个00,也就是说我们现在如果有a段,要求要g个00,那么0的个数为:a+g
那么我们即可根据c00,c01,c10,c11求出这个序列中有多少个0,1,并且被分成了几段。
那么问题就转化成了这样:给你一个序列要求有a个0,b个1,有k0段连续的0,k1段连续的1,有多少种满足的方案。
那我们可以这么想一段连续的0用k0-1个隔开,那么也就是在n-1个位置防止k0-1个数的方案数,即为:C(n-1,k0-1)。
那么对于0是如此,根据乘法原理1的情况与0的情况相乘即可得出。
然后在考虑[l,r]范围的限制,这个就比较套路了,然后上面那个是O(1)算的,于是时间复杂度变为O(n)。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const LL mod = 1e9 + 7;
int _min(int x, int y) {return x < y ? x : y;}
int _max(int x, int y) {return x > y ? x : y;}
int read() {
int s = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
return s * f;
}
int len[2], s[2][110000];
LL jc[110000], inv[110000];
char ss[2][110000];
LL pow_mod(LL a, LL k) {
LL ans = 1;
while(k) {
if(k & 1) (ans *= a) %= mod;
(a *= a) %= mod; k /= 2;
} return ans;
}
void pre() {
inv[0] = jc[0] = 1;
for(int i = 1; i <= 100000; i++) jc[i] = (LL)jc[i - 1] * i % mod;
inv[100000] = pow_mod(jc[100000], mod - 2);
for(int i = 100000 - 1; i >= 1; i--) inv[i] = (LL)inv[i + 1] * (i + 1) % mod;
}
LL C(LL n, LL m) {if(n < m) return 0; return jc[n] * inv[m] % mod * inv[n - m] % mod;}
LL get(LL a, LL b) {if(a == 0) return 1; return C(a + b - 1, b - 1);}
LL solve0(LL c00, LL c01, LL c10, LL c11) {
if(c00 < 0 || c01 < 0 || c10 < 0 || c11 < 0) return 0;
if(c10 == c01) return get(c11, c10) * get(c00, c01 + 1) % mod;
else if(c10 + 1 == c01) return get(c11, c10 + 1) * get(c00, c01) % mod;
else return 0;
}
LL solve1(LL c00, LL c01, LL c10, LL c11) {
if(c00 < 0 || c01 < 0 || c10 < 0 || c11 < 0) return 0;
if(c10 == c01) return get(c11, c10 + 1) * get(c00, c01) % mod;
else if(c10 == c01 + 1) return get(c11, c10) * get(c00, c01 + 1) % mod;
else return 0;
}
LL calc(LL c00, LL c01, LL c10, LL c11, int opt) {
LL ans = 0;
for(int i = 2; i <= len[opt]; i++) {
if(s[opt][i] == 1) {
if(s[opt][i - 1] == 0) (ans += solve0(c00 - 1, c01, c10, c11)) %= mod;
else (ans += solve0(c00, c01, c10 - 1, c11)) %= mod;
}
if(s[opt][i - 1] == 0 && s[opt][i] == 0) c00--;
else if(s[opt][i - 1] == 0 && s[opt][i] == 1) c01--;
else if(s[opt][i - 1] == 1 && s[opt][i] == 0) c10--;
else if(s[opt][i - 1] == 1 && s[opt][i] == 1) c11--;
} if(opt && !c00 && !c01 && !c10 && !c11) (ans += 1) %= mod;
return ans;
}
int main() {
pre();
LL c00, c01, c10, c11;
scanf("%s%s", ss[0] + 1, ss[1] + 1);
len[0] = strlen(ss[0] + 1), len[1] = strlen(ss[1] + 1);
for(int i = 1; i <= len[0]; i++) s[0][i] = ss[0][i] - '0';
for(int i = 1; i <= len[1]; i++) s[1][i] = ss[1][i] - '0';
c00 = read(), c01 = read(), c10 = read(), c11 = read();
LL sum = c00 + c01 + c10 + c11 + 1;
if(sum > len[1] || sum < len[0]) {printf("0\n"); return 0;}
LL ans = 0;
if(sum < len[1]) (ans += solve1(c00, c01, c10, c11)) %= mod;
else (ans += calc(c00, c01, c10, c11, 1)) %= mod;
if(sum == len[0]) (ans -= calc(c00, c01, c10, c11, 0)) %= mod;
printf("%lld\n", (ans + mod) % mod);
return 0;
}