题意:
给出一个数组a,叫你每次询问如下等式的值。
f(l,r)=∑ri=l∑rj=igcd(ai,ai+1....aj)
解析:
思考了很久终于理解了学长的思路
给你一个序列,这个序列的子序列gcd的个数不会超过logN个(N为每个数字,最大能取到的范围)
因为求gcd是递减的,每次至少除以2,所以gcd的个数只会有logN个。
然后让我们来看看题目要求的是什么。
所有子区间的gcd的和。
比如[1, 5]这个区间可以分解成如下子区间。
[1, 1] [1, 2] [1, 3] [1, 4] [1, 5]
[2, 2] [2, 3] [2, 4] [2, 5]
[3, 3] [3, 4] [3, 5]
[4, 4] [4, 5]
[5, 5]现在我们可以开一个数组来保存[1->i, i] 子区间的gcd的值,以及满足该gcd的最大范围。
如果我们知道这个最大范围的话,以及这个范围的gcd值的话,那么就能很快的求出这个范围的值。
设最大范围为n,公共gcd为d
这个最大范围的内的子序列的和,就可以满足,一个等差数列。比如我现在有一个区间[1, 6]
[3, 6] [4, 6] [5, 6] [6, 6]的gcd值一样,而前面[1, 6],[2, 6]包含了这4个区间
假设前4个区间的gcd值是v,查询的范围越大,所加的值就越多。[6, 6] +v
[5, 6] +2v
[4, 6] +3v
[3, 6] +4v
[2, 6] +4v
[1, 6] +4v然后就可以先固定下右边界 r ,利用线段树单点查询左边界
l 总共累加了多少次。
my code
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ls (o<<1)
#define rs (o<<1|1)
#define lson ls, L, M
#define rson rs, M+1, R
using namespace std;
typedef long long ll;
const int N = (int)1e4 + 5;
int gcd(int a, int b) {
return b == 0 ? a : gcd(b, a % b);
}
struct Query {
int l, r, num;
ll length() { return r - l + 1; }
Query() {}
Query(int l, int r, int num) : l(l), r(r), num(num) {}
} Q[N], save[35], line[35];
int sn, ln;
ll a[N], ans[N];
struct Node {
int l, r;
ll sum, a1, d;
ll length() { return r - l + 1; }
void add(ll a2, ll d2) {
ll n = length();
sum += n*a2 + n*(n-1)/2 * d2;
a1 += a2, d += d2;
}
} node[N << 2];
inline void pushUp(int o) {
node[o].sum = node[ls].sum + node[rs].sum;
}
inline void pushDown(int o) {
ll a1 = node[o].a1, d = node[o].d;
ll len = node[ls].length();
node[ls].add(a1, d);
node[rs].add(a1 + d * len, d);
node[o].a1 = node[o].d = 0;
}
void build(int o, int L, int R) {
node[o].l = L, node[o].r = R;
node[o].sum = node[o].a1 = node[o].d = 0;
if(L == R) return ;
int M = (L + R) >> 1;
build(lson);
build(rson);
}
void modify(int o, int ql, int qr, ll a1, ll d) {
if(qr < ql) return ;
if(ql <= node[o].l && node[o].r <= qr) {
ll a = a1 + d * (node[o].l - ql);
node[o].add(a, d);
return ;
}
int M = (node[o].l + node[o].r) >> 1;
pushDown(o);
if(ql <= M) modify(ls, ql, qr, a1, d);
if(qr > M) modify(rs, ql, qr, a1, d);
pushUp(o);
}
ll query(int o, int pos) {
if(node[o].l == node[o].r)
return node[o].sum;
int M = (node[o].l + node[o].r) >> 1;
pushDown(o);
if(pos <= M) return query(ls, pos);
else return query(rs, pos);
}
bool cmp(Query a, Query b) {
return a.r < b.r;
}
void getSegment(int last, int pos) {
for(int i = 0; i < sn; i++)
save[i].num = gcd(save[i].num, last);
save[sn++] = Query(pos, pos, last);
ln = 0;
line[ln++] = save[0];
for(int i = 1; i < sn; i++) {
if(line[ln-1].num == save[i].num) {
line[ln-1].r = save[i].r;
}else line[ln++] = save[i];
}
sn = ln;
memcpy(save, line, ln*sizeof(Query));
}
int n, m;
int main() {
int T;
scanf("%d", &T);
while(T--) {
scanf("%d", &n);
for(int i = 1; i <= n; i++)
scanf("%lld", &a[i]);
int ql, qr;
scanf("%d", &m);
for(int i = 0; i < m; i++) {
scanf("%d%d", &ql, &qr);
Q[i] = Query(ql, qr, i);
}
sort(Q, Q+m, cmp);
build(1, 1, n);
sn = ln = 0;
int l = 0;
for(int i = 1; i <= n; i++) {
getSegment(a[i], i);
for(int j = 0; j < ln; j++) {
ll a1 = line[j].num * line[j].length();
ll d = -line[j].num;
modify(1, line[j].l, line[j].r, a1, d);
modify(1, 1, line[j].l-1, a1, 0);
}
while(l < m && Q[l].r <= i) {
ans[Q[l].num] = query(1, Q[l].l);
l++;
}
}
for(int i = 0; i < m; i++) {
printf("%lld\n", ans[i]);
}
}
return 0;
}