题目链接: 智乃酱的平方数列
2021.11.6更新了第二种做法
大致题意
给定一个长度为 n n n的序列, 有 m m m次如下操作:
1 l r
表示对区间
[
l
,
r
]
[l, r]
[l,r]加上一个如
{
1
,
4
,
9
,
16
,
.
.
.
,
l
e
n
2
}
\{ 1, 4, 9, 16, ..., len^2 \}
{1,4,9,16,...,len2}的平方数列
2 l r
表示询问区间
[
l
,
r
]
[l, r]
[l,r]的和.
解题思路
思路一: 线段树维护高阶前缀和
我们只需要用线段树维护三阶前缀和即可.
思路二: 线段树维护二次函数
我们考虑给 [ l , r ] [l, r] [l,r]添加平方数列, 对于位置 x ∈ [ l , r ] x \in [l, r] x∈[l,r], 增加的值应当是 ( x − ( l − 1 ) ) 2 (x - (l - 1))^2 (x−(l−1))2.
我们将上式展开得到: x 2 − 2 ( l − 1 ) x + ( l − 1 ) 2 x^2 - 2(l - 1)x + (l - 1)^2 x2−2(l−1)x+(l−1)2. 这样我们得到了一个关于 x x x的二次函数, 我们维护二次函数的系数即可.
代码中用了三个懒标记, 分别维护二次项 x 2 x^2 x2的系数和, 一次项 x x x的系数和,以及常数项的系数和.
AC代码
/* 思路一: 线段树维护高阶前缀和 */
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
using namespace std;
typedef long long ll;
const int N = 5E5 + 10, mod = 1E9 + 7, INV2 = mod - mod / 2;
struct node {
int l, r;
int sum1, sum2, sum3;
int lazy;
int base2, base3;
}t[N << 2];
void pushdown(node& op, int lazy) {
op.sum1 = (op.sum1 + 1ll * (op.r - op.l + 1) * lazy % mod) % mod;
op.sum2 = (op.sum2 + 1ll * op.base2 * lazy % mod) % mod;
op.sum3 = (op.sum3 + 1ll * op.base3 * lazy % mod) % mod;
op.lazy = (op.lazy + lazy) % mod;
}
void pushdown(int x) {
if (!t[x].lazy) return;
pushdown(t[x << 1], t[x].lazy), pushdown(t[x << 1 | 1], t[x].lazy);
t[x].lazy = 0;
}
void pushup(int x) {
node& op = t[x], &l = t[x << 1], &r = t[x << 1 | 1];
op.sum1 = (l.sum1 + r.sum1) % mod;
op.sum2 = (l.sum2 + r.sum2) % mod;
op.sum3 = (l.sum3 + r.sum3) % mod;
}
void build(int l, int r, int x = 1) {
t[x] = { l, r, 0, 0, 0, 0 };
if (l == r) {
t[x].base2 = l;
t[x].base3 = 1ll * l * l % mod;
return;
}
int mid = l + r >> 1;
build(l, mid, x << 1), build(mid + 1, r, x << 1 | 1);
t[x].base2 = (t[x << 1].base2 + t[x << 1 | 1].base2) % mod;
t[x].base3 = (t[x << 1].base3 + t[x << 1 | 1].base3) % mod;
}
void modify(int l, int r, int c, int x = 1) {
if (l <= t[x].l and r >= t[x].r) {
pushdown(t[x], c);
return;
}
pushdown(x);
int mid = t[x].l + t[x].r >> 1;
if (l <= mid) modify(l, r, c, x << 1);
if (r > mid) modify(l, r, c, x << 1 | 1);
pushup(x);
}
int ask(int l, int r, int x = 1) {
if (l <= t[x].l and r >= t[x].r) {
int res = (1ll * t[x].sum1 * (r + 1) % mod * (r + 2) % mod - 1ll * t[x].sum2 * (2 * r + 3) % mod + t[x].sum3) % mod;
return 1ll * ((res + mod) % mod) * INV2 % mod;
}
pushdown(x);
int mid = t[x].l + t[x].r >> 1;
int res = 0;
if (l <= mid) res = ask(l, r, x << 1);
if (r > mid) res = (res + ask(l, r, x << 1 | 1)) % mod;
return res;
}
int main()
{
int n, m; cin >> n >> m;
build(1, n);
rep(i, m) {
int tp, l, r; scanf("%d %d %d", &tp, &l, &r);
if (tp == 1) {
int len = r - l + 1;
modify(l, l, 1);
if (l != r) modify(l + 1, r, 2);
if (r + 1 <= n) {
int val = (1ll * len * len % mod + 2 * len - 1) % mod;
val = (-val + mod) % mod;
modify(r + 1, r + 1, val);
}
if (r + 2 <= n) modify(r + 2, r + 2, 1ll * len * len % mod);
}
else {
int res = ask(1, r);
if (l - 1 >= 1) res = ((res - ask(1, l - 1)) % mod + mod) % mod;
printf("%d\n", res);
}
}
return 0;
}
/* 思路二: 线段树维护二次函数 */
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
using namespace std;
typedef long long ll;
const int N = 5E5 + 10, mod = 1E9 + 7;
struct node {
int l, r;
int sum;
int lazy1, lazy2, lazy3;
int base1, base2;
}t[N << 2];
void pushdown(node& op, int lazy1, int lazy2, int lazy3) {
const int len = op.r - op.l + 1;
op.sum = (op.sum + 1ll * op.base1 * lazy1) % mod;
op.sum = ((op.sum + -2ll * op.base2 * lazy2) % mod + mod) % mod;
op.sum = (op.sum + 1ll * len * lazy3) % mod;
op.lazy1 = (op.lazy1 + lazy1) % mod;
op.lazy2 = (op.lazy2 + lazy2) % mod;
op.lazy3 = (op.lazy3 + lazy3) % mod;
}
void pushdown(int x) {
if (!t[x].lazy1 and !t[x].lazy2 and !t[x].lazy3) return;
pushdown(t[x << 1], t[x].lazy1, t[x].lazy2, t[x].lazy3), pushdown(t[x << 1 | 1], t[x].lazy1, t[x].lazy2, t[x].lazy3);
t[x].lazy1 = t[x].lazy2 = t[x].lazy3 = 0;
}
void pushup(int x) { t[x].sum = (t[x << 1].sum + t[x << 1 | 1].sum) % mod; }
void build(int l, int r, int x = 1) {
t[x] = { l, r, 0, 0, 0, 0 };
if (l == r) {
t[x].base1 = 1ll * l * l % mod;
t[x].base2 = l;
return;
}
int mid = l + r >> 1;
build(l, mid, x << 1), build(mid + 1, r, x << 1 | 1);
t[x].base1 = (t[x << 1].base1 + t[x << 1 | 1].base1) % mod;
t[x].base2 = (t[x << 1].base2 + t[x << 1 | 1].base2) % mod;
}
void modify(int l, int r, int c, int x = 1) {
if (l <= t[x].l and r >= t[x].r) {
pushdown(t[x], 1, c, 1ll * c * c % mod);
return;
}
pushdown(x);
int mid = t[x].l + t[x].r >> 1;
if (l <= mid) modify(l, r, c, x << 1);
if (r > mid) modify(l, r, c, x << 1 | 1);
pushup(x);
}
int ask(int l, int r, int x = 1) {
if (l <= t[x].l and r >= t[x].r) return t[x].sum;
pushdown(x);
int mid = t[x].l + t[x].r >> 1;
int res = 0;
if (l <= mid) res = ask(l, r, x << 1);
if (r > mid) res = (res + ask(l, r, x << 1 | 1)) % mod;
return res;
}
int main()
{
int n, m; cin >> n >> m;
build(1, n);
rep(i, m) {
int tp, l, r; scanf("%d %d %d", &tp, &l, &r);
if (tp == 1) modify(l, r, l - 1);
else printf("%d\n", ask(l, r));
}
return 0;
}