【题目链接】
【思路要点】
- 用 $exgcd$ 解出每一步对应的线性同余方程，并逐步合并每一步的结果。
- 时间复杂度 $O(N\log N)$。
【代码】
#include<bits/stdc++.h>
using namespace std;
// Capacity of the 1-indexed per-step arrays (usable indices 1..MAXN-1).
const int MAXN = 100005;
// Raises x to y when y is strictly larger; otherwise leaves x unchanged.
template <typename T> void chkmax(T &x, T y) {
    if (x < y) x = y;
}
// Reads the next integer from stdin into x.
// Non-digit characters before the number are skipped, and every '-' among them
// flips the sign. If EOF is reached before any digit, x is set to 0 instead of
// looping forever.
template <typename T> void read(T &x) {
    x = 0;
    int f = 1;
    // Keep getchar()'s result as an int: storing it in a char makes EOF
    // indistinguishable from byte 0xFF. It can also pass a negative value to
    // isdigit(), which is undefined.
    int ch = getchar();
    for (; ch != EOF && !isdigit(ch); ch = getchar())
        if (ch == '-') f = -f;
    for (; ch != EOF && isdigit(ch); ch = getchar())
        x = x * 10 + (ch - '0');  // parenthesised so no intermediate x*10 + ch overflow
    x *= f;
}
// A family of values x ≡ b (mod k).
// A value with k == -1 (and b == -1) is used as the "no solution" marker.
struct info {
    long long k;  // modulus (period of the family)
    long long b;  // residue
};
// Values available to find(); a multiset because duplicates are allowed.
multiset<long long> st;
// Combined congruence accumulated step by step in solve().
info ans;
// Per test case: n entries in each of a/p/r, and m initial values for st.
int n;
int m;
// 1-indexed per-step inputs: a[i] is the target given to find()/equation(),
// p[i] is the modulus, and r[i] is the value inserted into st after step i.
long long a[MAXN];
long long p[MAXN];
long long r[MAXN];
// Lower bound the printed answer must reach; raised inside equation().
long long least;
// Removes one value from st and returns it. It picks the largest value <= x,
// or the smallest value in st when every value exceeds x.
// Assumes st is non-empty (dereferences its first element otherwise).
long long find(long long x) {
    // First element strictly greater than x.
    multiset<long long>::iterator pick = st.upper_bound(x);
    // Step back to the largest element <= x, if such an element exists.
    if (pick != st.begin()) --pick;
    long long value = *pick;
    st.erase(pick);
    return value;
}
// Greatest common divisor by the Euclidean algorithm (iterative form).
long long gcd(long long x, long long y) {
    while (y != 0) {
        long long rem = x % y;
        x = y;
        y = rem;
    }
    return x;
}
// Extended Euclid. On return, a * x - b * y == gcd(a, b) (note the minus sign),
// so a * x ≡ gcd(a, b) (mod b).
void exgcd(long long a, long long b, long long &x, long long &y) {
    if (!b) {
        x = 1;
        y = 0;
        return;
    }
    long long sub_x = 0, sub_y = 0;
    exgcd(b, a % b, sub_x, sub_y);  // b * sub_x - (a % b) * sub_y == g
    // Substitute a % b = a - (a / b) * b and regroup the terms by a and b.
    x = -sub_y;
    y = -sub_x - (a / b) * sub_y;
}
// Overflow-safe modular product a * b % P by binary doubling.
// Intermediate sums stay below 3P when 0 <= a < P.
// Iterative form: record the halving chain b, b/2, b/4, ... (excluding 0),
// then fold it from the smallest term upward.
long long times(long long a, long long b, long long P) {
    long long chain[64];
    int len = 0;
    for (long long t = b; t != 0; t /= 2) chain[len++] = t;
    long long acc = 0;
    while (len > 0) {
        long long cur = chain[--len];
        if (cur % 2 == 0) acc = (acc + acc) % P;
        else acc = (acc + acc + a) % P;
    }
    return acc;
}
// Solves the linear congruence a * x ≡ c (mod b), for a, b > 0 and c >= 0.
// Let g = gcd(a, b):
//   - returns {b / g, x0} with 0 <= x0 < b / g; every solution is x0 + t * (b / g);
//   - returns {-1, -1} when c is not divisible by g (no solution).
// Side effect: raises the global `least` to at least ceil(c / a).
// NOTE(review): merge() also calls this, so those calls raise `least` too.
// Those bounds appear never to exceed the final answer; re-check if callers change.
info equation(long long a, long long b, long long c) {
    long long g = gcd(a, b);
    if (c % g != 0) {
        // Aggregate initialisation instead of the GNU compound literal
        // `(info){-1, -1}`, which is not valid ISO C++.
        info none = {-1, -1};
        return none;
    }
    long long x = 0, y = 0;
    a /= g, b /= g, c /= g;
    // a and b are now coprime, so exgcd gives a * x - b * y == 1,
    // i.e. x is the inverse of a modulo b.
    exgcd(a, b, x, y);
    // ceil(c / a) for c >= 0 (the ratio is unchanged by dividing both by g).
    if (c % a == 0) chkmax(least, c / a);
    else chkmax(least, c / a + 1);
    x = (x % b + b) % b;   // normalise into [0, b) so times() stays overflow-safe
    x = times(x, c, b);    // x * c mod b without 64-bit overflow
    info sol = {b, x};
    return sol;
}
// Intersects the families x ≡ vx.b (mod vx.k) and x ≡ vy.b (mod vy.k).
// Returns the combined family with modulus lcm(vx.k, vy.k), or {-1, -1} if the
// two congruences are incompatible.
info merge(info vx, info vy) {
    // Order the inputs so the residue difference passed below is non-negative.
    if (vy.b < vx.b) swap(vx, vy);
    // Find t with vx.k * t ≡ vy.b - vx.b (mod vy.k).
    info step = equation(vx.k, vy.k, vy.b - vx.b);
    if (step.k != -1) {
        step.k = step.k * vx.k;          // (vy.k / g) * vx.k == lcm of the moduli
        step.b = step.b * vx.k + vx.b;   // map t back to a common residue
    }
    return step;
}
void solve() {
least = 0; ans.k = 1, ans.b = 0;
for (int i = 1; i <= n; i++) {
long long tmp = find(a[i]);
info now = equation(tmp, p[i], a[i]);
if (now.k == -1) {
printf("-1\n");
return;
}
ans = merge(ans, now);
if (ans.k == -1) {
printf("-1\n");
return;
}
st.insert(r[i]);
}
if (ans.b >= least) printf("%lld\n", ans.b);
else {
least -= ans.b;
if (least % ans.k != 0) least = (least / ans.k + 1) * ans.k;
printf("%lld\n", least + ans.b);
}
}
// Reads cnt integers into arr[1..cnt] (1-indexed).
static void read_series(long long *arr, int cnt) {
    for (int i = 1; i <= cnt; i++) read(arr[i]);
}

// Input layout per test case: n and m, then n values each for a, p and r,
// then m initial values for st.
int main() {
    int T;
    read(T);
    for (; T != 0; --T) {
        read(n);
        read(m);
        read_series(a, n);
        read_series(p, n);
        read_series(r, n);
        st.clear();
        for (int j = 0; j < m; j++) {
            long long value;
            read(value);
            st.insert(value);
        }
        solve();
    }
    return 0;
}