1. 问题描述:
给定一个长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:
C l r d,表示把 A[l],A[l+1],…,A[r] 都加上 d。
Q l r,表示询问数列中第 l∼r 个数的和。
对于每个询问,输出一个整数表示答案。
输入格式
第一行两个整数 N,M。
第二行 N 个整数 A[i]。
接下来 M 行表示 M 条指令,每条指令的格式如题目描述所示。
输出格式
对于每个询问,输出一个整数表示答案。
每个答案占一行。
数据范围
1 ≤ N,M ≤ 10 ^ 5,
|d| ≤ 10000,
|A[i]| ≤ 10 ^ 9
输入样例:
10 5
1 2 3 4 5 6 7 8 9 10
Q 4 4
Q 1 10
Q 2 4
C 3 6 3
Q 2 4
输出样例:
4
55
9
15
来源:https://www.acwing.com/problem/content/244/
2. 思路分析:
分析题目可以知道这是一道经典的区间修改与区间查询的线段树题目,对于区间修改的大部分题目我们都需要使用带有懒标记的线段树进行求解(对于区间修改的一些题目我们是可以转化为差分的思路从而转为不带有懒标记的线段树进行求解),带有懒标记的线段树比没有懒标记的线段树多了一个pushdown操作,pushdown操作可以将父节点的信息下传到两个子节点,其实这个pushdown操作的灵感来自于区间查询操作,线段树执行查询操作的时候若当前节点的区间包含于查询区间的时候那么直接返回当前节点或者是当前节点有关的信息,对于区间修改操作我们也是可以借助于类似的想法,对于区间修改的操作我们需要在线段树节点中保存一个懒标记add(一般是区间加操作)用于pushdown操作将父节点的信息下传到子节点,区间修改也是一个递归的过程,当我们发现当前节点的区间包含在了待修改的区间之内那么我们就需要执行pushown操作,将父节点的懒标记add下传到两个子节点,这样有了pushdown操作可以避免修改单点操作那样需要修改区间中的所有点了。我们一般是在这两个操作需要使用pushdown操作:第一个是区间修改的操作,在递归调用之前需要调用pushdown方法,这样可以将当前节点的懒标记下传到子区间,第二个是区间查询操作的时候也需要调用pushdown方法。
3. 代码如下:
java:
import java.util.Scanner;
public class Main {
static Tree []tree;
static long []w;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
tree = new Tree[4 * n];
w = new long[n + 1];
for (int i = 1; i <= n; ++i){
w[i] = sc.nextInt();
}
// 初始化线段树节点
for (int i = 0; i < 4 * n; ++i){
tree[i] = new Tree();
}
// 注意需要接受回车符
sc.nextLine();
build(1, 1, n);
for (int i = 0; i < m; ++i){
String []s = sc.nextLine().split(" ");
String t = s[0];
// Q为查询操作
if (t.equals("Q")){
System.out.println(query(1, Integer.parseInt(s[1]), Integer.parseInt(s[2])));
}else {
int l = Integer.parseInt(s[1]), r = Integer.parseInt(s[2]), v = Integer.parseInt(s[3]);
modify(1, l, r, v);
}
}
}
// 将子节点的信息传递到父节点
public static void pushup(int u){
tree[u].sum = tree[u << 1].sum + tree[u << 1 | 1].sum;
}
// 将父节点懒标记下传
public static void pushdown(int u){
Tree root = tree[u], left = tree[u << 1], right = tree[u << 1 | 1];
// 当前根节点有懒标记那么就往下传(add不等于0)
if (root.add != 0){
left.add += root.add;
right.add += root.add;
left.sum += (left.r - left.l + 1) * root.add;
right.sum += (right.r - right.l + 1) * root.add;
// 注意需要清空根节点的标记
root.add = 0;
}
}
public static void build(int u, int l, int r){
tree[u].l = l;
tree[u].r = r;
if (l == r){
tree[u].sum = w[l];
// 懒标记为0
tree[u].add = 0;
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
// 修改的思路其实是借助了查询的思路
public static void modify(int u, int l, int r, int v){
if (tree[u].l >= l && tree[u].r <= r){
tree[u].sum += (tree[u].r - tree[u].l + 1) * v;
tree[u].add += v;
return;
}
pushdown(u);
int mid = tree[u].l + tree[u].r >> 1;
// 因为是区间修改所以需要判断两个方向
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify(u << 1 | 1, l, r, v);
pushup(u);
}
public static long query(int u, int l, int r){
if (tree[u].l >= l && tree[u].r <= r) return tree[u].sum;
pushdown(u);
int mid = tree[u].l + tree[u].r >> 1;
long v = 0;
if (l <= mid) v = query(u << 1, l, r);
if (r > mid) v += query(u << 1 | 1, l, r);
return v;
}
public static class Tree{
private int l, r;
// add为懒标记
private long sum, add;
}
}
c++:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
int n, m;
int w[N];
struct Node
{
int l, r;
LL sum, add;
}tr[N * 4];
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add)
{
left.add += root.add, left.sum += (LL)(left.r - left.l + 1) * root.add;
right.add += root.add, right.sum += (LL)(right.r - right.l + 1) * root.add;
root.add = 0;
}
}
void build(int u, int l, int r)
{
if (l == r) tr[u] = {l, r, w[r], 0};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int l, int r, int d)
{
if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].sum += (LL)(tr[u].r - tr[u].l + 1) * d;
tr[u].add += d;
}
else // 一定要分裂
{
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, d);
if (r > mid) modify(u << 1 | 1, l, r, d);
pushup(u);
}
}
LL query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL sum = 0;
if (l <= mid) sum = query(u << 1, l, r);
if (r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
build(1, 1, n);
char op[2];
int l, r, d;
while (m -- )
{
scanf("%s%d%d", op, &l, &r);
if (*op == 'C')
{
scanf("%d", &d);
modify(1, l, r, d);
}
else printf("%lld\n", query(1, l, r));
}
return 0;
}