题意:
给定长度为
n
n
n的正整数序列
a
a
a。
有
m
m
m组询问,每组询问包括一个正整数
x
x
x。
对每个询问计算共有多少个区间
[
l
,
r
]
[l,r]
[l,r],满足
a
l
,
a
l
+
1
,
…
,
a
r
a_l,a_{l+1},\dots,a_r
al,al+1,…,ar的最大公约数为
x
x
x,即
g
c
d
(
a
l
,
a
l
+
1
,
…
,
a
r
)
=
x
gcd(a_l,a_{l+1},\dots,a_r)=x
gcd(al,al+1,…,ar)=x。
数据范围:
题解:
1.首先,我们使用 S T ST ST表存储区间之间的最大公约数。
2.对每个询问 x x x,枚举左端点 L L L。
①对每个左端点 L L L,在最大公约数为 g g g的情况下,计算出满足 g c d ( [ L , R ] ) = g gcd([L,R])=g gcd([L,R])=g的最大的 R R R, g g g初始化为 s t [ L ] [ 0 ] st[L][0] st[L][0]。
②那么对于最大公约数为 g g g,且左端点为 L L L的情况下,满足条件的区间个数为 R − L + 1 R -L+1 R−L+1。接着我们将 g g g更新为 g c d ( [ L , R + 1 ] ) gcd([L,R+1]) gcd([L,R+1]),重复之前的操作,直到 R > n R>n R>n。
由于区间的最大公约数随着区间右端点的不断增加,具有不增加的性质,因此,我们可以对右端点 R R R进行二分查找。
此时, g c d ( a , b ) = m i n ( a , b ) 且 g c d ( a , b ) ≤ m i n ( a , b ) 2 gcd(a,b)=min(a,b)且gcd(a,b) \leq \frac{min(a,b)}2 gcd(a,b)=min(a,b)且gcd(a,b)≤2min(a,b),此时, g c d gcd gcd最多只会变化 l o g n logn logn次,也就是第 2 2 2步最多只会执行 l o g n logn logn次。
我们在 O ( n ) O(n) O(n)枚举左端点的同时,二分查找满足条件的最大右端点,因此时间复杂度为 O ( n l o g n ) O(nlogn) O(nlogn)。
实现细节见代码:
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define int long long
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int MAXN = 1e5 + 10;
const int mod = 1e9 + 7;
int st[MAXN][30], logn[MAXN], a[MAXN];
map<int, int> ans;
void init() {
logn[1] = 0;
for (int i = 2; i < MAXN; i++) {
logn[i] = logn[i / 2] + 1;
}
}
void get_st(int n) {
for (int j = 1; j <= 25; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
st[i][j] = __gcd(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}
}
}
int query(int l, int r) {
int k = logn[r - l + 1];
return __gcd(st[l][k], st[r - (1 << k) + 1][k]);
}
int check(int g, int L, int l, int n) {
int r = n;
while (l < r) {
int mid = l + r >> 1;
if (query(L, mid) != g) {
r = mid;
}
else {
l = mid + 1;
}
}
if (query(L, l) != g) l--; // 注意特判当前l是否符合要求
return l;
}
void solve(int l, int n) {
int r = l, g = st[l][0];
while (r <= n) {
int pre = r;
r = check(g, l, r, n); // 二分满足区间公约数为g的最大右端点
ans[g] += r - pre + 1;
g = query(l, r + 1);
r++;
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int n;
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
st[i][0] = a[i];
}
init();
get_st(n);
for (int i = 1; i <= n; i++) { // 枚举左端点
solve(i, n);
}
int q;
cin >> q;
while (q--) {
int x;
cin >> x;
cout << ans[x] << endl;
}
return 0;
}