题意:
求区间 [L, R] 内有多少个整数 y,满足 y = a1·k + b1 且 y = a2·l + b2,其中 k、l 为非负整数(k, l ≥ 0)。
题解:
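用扩展欧几里得解不定方程 a1·k + b1 = a2·l + b2:若 gcd(a1, a2) 不整除 b2 − b1,则无解,答案为 0。否则记 A = lcm(a1, a2) / a1 = a2 / gcd(a1, a2),t = k mod A,则所有解中的 k 都可以写成 k = q·A + t(q 为整数)。由 k、l ≥ 0 以及 y ∈ [L, R] 可知 q 有上下界;求出 q 的最小值 qmin 和最大值 qmax(可以二分),答案即为 max(0, qmax − qmin + 1)。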
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <stack>
#include <vector>
#include <queue>
#include <cstring>
#include <string>
#include <map>
#include <set>
using namespace std;
// Typed replacements for the former object-like macros.
// MAXN used to expand to the unparenthesized text `1000+10`, so an expression
// such as `MAXN * 2` silently became 1000 + 20; as a constant it is the
// intended 1010. (MAXN is not referenced by the code in this file.)
const int MAXN = 1000 + 10;
// Large sentinel value (4e9).
const long long oo = 4000000000LL;
typedef long long LL;
typedef long double LD;

// Values read by main(): the first progression is a1*k + b1, the second is
// a2*l + b2, and [L, R] is the queried range.
LL a1, b1, a2, b2, L, R;
// Extended Euclidean algorithm.
// Returns g = gcd(a, b) and writes Bezout coefficients into x and y so that
// a*x + b*y == g.
LL gcd(LL a, LL b, LL& x, LL& y) {
    if (b != 0) {
        // Solve the smaller problem (b, a % b) with the roles of x and y
        // swapped, then correct y using a % b == a - (a / b) * b.
        const LL g = gcd(b, a % b, y, x);
        y -= a / b * x;
        return g;
    }
    // Base case: a*1 + 0*0 == a.
    x = 1;
    y = 0;
    return a;
}
// Euclidean algorithm: greatest common divisor of a and b.
LL gcd(LL a, LL b) {
    while (b != 0) {
        const LL rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
int main() {
cin >> a1 >> b1 >> a2 >> b2 >> L >> R;
LL k, t;
LL d = gcd(a1, a2, k, t);
if((b2 - b1) % d != 0) return puts("0"), 0;
k *= (b2 - b1) / d; t *= (b2 - b1) / d;
LL A2 = a2 / gcd(a1, a2);
LL mod = (k % A2 + A2) % A2, al, ar;
// printf("%lld %lld\n", k, mod);
LL l, r; l = -oo - 1; r = oo + 1;
// printf("%lld %lld\n", l, r);
while(l < r) {
LL mid = l + (r - l) / 2;
LL lsid = (L - b1) % a1 != 0 ? (L - b1) / a1 + (L - b1 > 0 ? 1 : 0) : (L - b1) / a1,
rsid = (R - b1) % a1 != 0 ? (R - b1) / a1 + (R - b1 > 0 ? 0 : -1) : (R - b1) / a1,
x = mid * A2 + mod;
LL l2 = (L - b2) % a2 != 0 ? (L - b2) / a2 + (L - b2 > 0 ? 1 : 0) : (L - b2) / a2,
r2 = (R - b2) % a2 != 0 ? (R - b2) / a2 + (R - b2 > 0 ? 0 : -1) : (R - b2) / a2,
y = ((LD)a1 * x - b2 + b1) / a2;
// printf("%lld %lld %lld %lld %lld [%lld, %lld]\n", lsid, l2, mid, y, x, l, r);
if(lsid <= x && l2 <= y && x >= 0 && y >= 0) r = mid;
else l = mid + 1;
}
al = l;
l = -oo - 1; r = oo + 1;
// printf("%lld %lld\n", l, r);1 -2000000000 2 2000000000 -2000000000 2000000000
while(l < r - 1) {
LL mid = l + (r - l) / 2;
LL lsid = (L - b1) % a1 != 0 ? (L - b1) / a1 + (L - b1 > 0 ? 1 : 0) : (L - b1) / a1,
rsid = (R - b1) % a1 != 0 ? (R - b1) / a1 + (R - b1 > 0 ? 0 : -1) : (R - b1) / a1,
x = mid * A2 + mod;
LL l2 = (L - b2) % a2 != 0 ? (L - b2) / a2 + (L - b2 > 0 ? 1 : 0) : (L - b2) / a2,
r2 = (R - b2) % a2 != 0 ? (R - b2) / a2 + (R - b2 > 0 ? 0 : -1) : (R - b2) / a2,
y = ((LD)a1 * x - b2 + b1) / a2;
// printf("%lld %lld %lld %lld %lld [%lld, %lld]\n", mid, x, y, rsid, r2, l, r);
if(x <= rsid && y <= r2) l = mid;
else r = mid;
}
ar = l;
// printf("%lld %lld\n", al, ar);
LL mid = l + (r - l) / 2;
LL lsid = (L - b1) % a1 != 0 ? (L - b1) / a1 + (L - b1 > 0 ? 1 : 0) : (L - b1) / a1,
rsid = (R - b1) % a1 != 0 ? (R - b1) / a1 + (R - b1 > 0 ? 0 : -1) : (R - b1) / a1,
x = mid * A2 + mod;
LL l2 = (L - b2) % a2 != 0 ? (L - b2) / a2 + (L - b2 > 0 ? 1 : 0) : (L - b2) / a2,
r2 = (R - b2) % a2 != 0 ? (R - b2) / a2 + (R - b2 > 0 ? 0 : -1) : (R - b2) / a2,
y = ((LD)a1 * x - b2 + b1) / a2;
if(lsid <= x && x <= rsid && l2 <= y && y <= r2 && x >= 0 && y >= 0 && al <= ar)
cout << ar - al + 1 << endl;
else puts("0");
return 0;
}