题目描述
给你一个长为n的序列a
m次查询
每次查询一个区间的所有子区间的gcd的和mod1e9+7的结果
输入描述:
第一行两个数n,m
之后一行n个数表示a
之后m行每行两个数l,r表示查询的区间
输出描述:
对于每个询问,输出一行一个数表示答案
示例1
输入
5 7 30 60 20 20 20 1 1 1 5 2 4 3 4 3 5 2 5 2 3
输出
30 330 160 60 120 240 100
说明
[1,1]的子区间只有[1,1],其gcd为30
[1,5]的子区间有:
[1,1]=30,[1,2]=30,[1,3]=10,[1,4]=10,[1,5]=10
[2,2]=60,[2,3]=20,[2,4]=20,[2,5]=20
[3,3]=20,[3,4]=20,[3,5]=20
[4,4]=20,[4,5]=20
[5,5]=20
总共330
[2,4]的子区间有:
[2,2]=60,[2,3]=20,[2,4]=20
[3,3]=20,[3,4]=20
[4,4]=20
总共160
[3,4]的子区间有:
[3,3]=20,[3,4]=20
[4,4]=20
总共60
[3,5]的子区间有:
[3,3]=20,[3,4]=20,[3,5]=20
[4,4]=20,[4,5]=20
[5,5]=20
总共120
[2,5]的子区间有:
[2,2]=60,[2,3]=20,[2,4]=20,[2,5]=20
[3,3]=20,[3,4]=20,[3,5]=20
[4,4]=20,[4,5]=20
[5,5]=20
总共240
[2,3]的子区间有:
[2,2]=60,[2,3]=20
[3,3]=20
总共100
备注:
对于100%的数据,有1 <= n , m , ai <= 100000
题解
倍增预处理、莫队算法。
类似的题目做过好几个了,有一个比较重要的性质:以$i$为起点的区间,区间$gcd$的值只有$log(n)$种。
这样莫队转移的时候,只要把那$log(n)$种都算一下就$ok$了。
有点卡常,优化了一点才过。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 10;
const long long mod = 1e9 + 7;
int a[maxn];
int pos[maxn];
int n, m, L, R;
long long Ans;
struct X {
int l, r, id;
}s[maxn];
long long ans[maxn];
struct P {
int end1;
int end2;
long long sum;
int GCD;
int nx;
}p[maxn * 40];
int cnt;
int List[maxn][2];
/* st */
int dp[maxn][30];
int gcd(int a, int b) {
if(b == 0) return a;
return gcd(b, a % b);
}
void init() {
for(int i = 1; i <= n; i ++) {
dp[i][0] = a[i];
}
for(int i = 1; (1 << i) <= n; i ++) {
for(int j = 1; j + (1 << i) - 1 <= n; j ++) {
dp[j][i] = gcd(dp[j][i - 1], dp[j + (1 << (i - 1))][i - 1]);
}
}
}
int query(int l, int r) {
int k = (int)(log(double(r - l + 1)) / log((double)2));
return gcd(dp[l][k], dp[r - (1 << k) + 1][k]);
}
/* st */
bool cmp(const X& a, const X& b) {
if (pos[a.l] != pos[b.l]) return a.l < b.l;
if((pos[a.l]) & 1) return a.r > b.r;
return a.r < b.r;
}
void add(int x, int op) {
int it;
for(it = List[x][op]; it != -1; it = p[it].nx) {
int id = it;
if(p[id].end2 < L || p[id].end1 > R) continue;
if(p[id].end1 >= L && p[id].end2 <= R) {
Ans = Ans + p[id].sum;
} else {
int ll = max(p[id].end1, L);
int rr = min(p[id].end2, R);
Ans = Ans + 1LL * (rr - ll + 1) * p[id].GCD;
}
}
}
void del(int x, int op) {
int it;
for(it = List[x][op]; it != -1; it = p[it].nx) {
int id = it;
if(p[id].end2 < L || p[id].end1 > R) continue;
if(p[id].end1 >= L && p[id].end2 <= R) {
Ans = Ans - p[id].sum;
} else {
int ll = max(p[id].end1, L);
int rr = min(p[id].end2, R);
Ans = Ans - 1LL * (rr - ll + 1) * p[id].GCD;
}
}
}
int main() {
scanf("%d%d", &n, &m);
int sz = sqrt(n);
for(int i = 1; i <= n; i ++) {
scanf("%d", &a[i]);
pos[i] = i / sz;
List[i][0] = List[i][1] = -1;
}
init();
for(int i = 1; i <= n; i ++) {
int ll = i, rr = i;
while(ll <= n) {
int left = ll, right = n;
int g = query(i, ll);
while(left <= right) {
int mid = (left + right) / 2;
if(g == query(i, mid)) {
rr = mid, left = mid + 1;
} else {
right = mid - 1;
}
}
p[cnt].end1 = ll;
p[cnt].end2 = rr;
p[cnt].sum = 1LL * g * (rr - ll + 1);
p[cnt].GCD = g;
p[cnt].nx = List[i][0];
List[i][0] = cnt;
ll = rr + 1;
cnt ++;
}
}
for(int i = 1; i <= n; i ++) {
int ll = i, rr = i;
while(rr >= 1) {
int left = 1, right = rr;
int g = query(rr, i);
while(left <= right) {
int mid = (left + right) / 2;
if(g == query(mid, i)) {
ll = mid, right = mid - 1;
} else {
left = mid + 1;
}
}
p[cnt].end1 = ll;
p[cnt].end2 = rr;
p[cnt].sum = 1LL * g * (rr - ll + 1);
p[cnt].GCD = g;
p[cnt].nx = List[i][1];
List[i][1] = cnt;
rr = ll - 1;
cnt ++;
}
}
for(int i = 1; i <= m; i ++) {
scanf("%d%d", &s[i].l, &s[i].r);
s[i].id = i;
}
sort(s + 1, s + m + 1, cmp);
L = s[1].l;
R = s[1].l - 1;
Ans = 0;
for(int i = s[1].l; i <= s[1].r; i ++) {
R ++;
add(i, 1);
}
ans[s[1].id] = Ans;
for(int i = 2; i <= m; i ++) {
while (L > s[i].l) { L --, add(L, 0); }
while (R < s[i].r) { R ++, add(R, 1); }
while (L < s[i].l) { del(L, 0), L ++; }
while (R > s[i].r) { del(R, 1), R --; }
ans[s[i].id] = Ans;
}
for(int i = 1; i <= m; i ++) {
printf("%lld\n", ans[i] % mod);
}
return 0;
}