线段树对加法乘法的优化
例题 洛谷P3373
题目描述
已知一个数列,你需要进行下面三种操作:
将某区间每一个数乘上 xx
将某区间每一个数加上 xx
求出某区间每一个数的和
输入格式
第一行包含三个整数 n,m,pn,m,p,分别表示该数列数字的个数、操作的总个数和模数。
第二行包含 nn 个用空格分隔的整数,其中第 ii 个数字表示数列第 ii 项的初始值。
接下来 mm 行每行包含若干个整数,表示一个操作,具体如下:
操作 1: 格式:1 x y k 含义:将区间 [x,y][x,y] 内每个数乘上 kk
操作 2: 格式:2 x y k 含义:将区间 [x,y][x,y] 内每个数加上 kk
操作 3: 格式:3 x y 含义:输出区间 [x,y][x,y] 内每个数的和对 pp 取模所得的结果
输出格式
输出包含若干行整数,即为所有操作 33 的结果
输入样例
5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4
输出样例
17
2
题目解析
这是一道线段树的模板题,主要考察对 laze_tag 的应用,因为这里涉及到乘法和加法,所以需要两个懒标记。
struct {
int l, r;
ll sum;
ll laze1, laze2;//laze1代表乘,laze2代表加
}tree[N << 2];
所以显然易见的,我们可以写出下面的 pushup 和 pushdown 代码来。注意:这里需要区分加法和乘法的运算顺序。 因为建树代码几乎所有线段树题都是一样的,我这里暂时不给出,唯一需要注意的一点就是在建树的时候我们需要把乘法懒标记初始化为 1 。
void update(int node, int laze1, int laze2)
{
tree[node].sum = (tree[node].sum * laze1 + laze2 * (tree[node].r - tree[node].l + 1)) % p;//更新当前节点维护的值
tree[node].laze1 = (tree[node].laze1 * laze1) % p;//更新乘法懒标记
tree[node].laze2 = (laze2 + tree[node].laze2 * laze1) % p;//更新加法懒标记
return;
}
void pushdown(int node)
{
update(node << 1, tree[node].laze1, tree[node].laze2);//更新左儿子
update(node << 1 | 1, tree[node].laze1, tree[node].laze2);//更新右儿子
tree[node].laze1 = 1;//清空懒标记
tree[node].laze2 = 0;
}
void pushup(int node)
{
tree[node].sum = tree[node << 1].sum + tree[node << 1 | 1].sum;
}
对 change 函数的优化
常规思路,我们需要写两个 change 函数来分别针对加法和乘法进行修改,太懒,不想敲这个代码 ,那么我们能不能把加法和乘法合并为一个表达式呢,其实是可以的。例:对一个数 x ,我们有这样一个式子 x * a + b,如果让 x 加 k,则使 a = 1, b = k 即可,如果让 x 乘以 k,则让 a = k, b = 0 即可。这样操作的话,我们的 change 函数也就只需要写一个了。代码如下:
void change(int node, int l, int r, int k, int type)//type为1代表乘,为2代表加
{
if (tree[node].l == l && tree[node].r == r)
{
if (type == 1)
update(node, k, 0);//乘法
else
update(node, 1, k);//加法
return;
}
if(tree[node].laze1 != 1 || tree[node].laze2 != 0)
pushdown(node);
int mid = (tree[node].l + tree[node].r) >> 1;
if (r <= mid)
change(node << 1, l, r, k, type);
else if (l > mid)
change(node << 1 | 1, l, r, k, type);
else
{
change(node << 1, l, mid, k, type);
change(node << 1 | 1, mid + 1, r, k, type);
}
pushup(node);
}
最后,附上这一题的完整代码
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
ll arr[N];
int p;
struct {
int l, r;
ll sum;
ll laze1, laze2;//laze1代表乘,laze2代表加
}tree[N << 2];
//建树
void build(int node, int l, int r)
{
tree[node].l = l, tree[node].r = r;
tree[node].laze1 = 1;
if (l == r)
{
tree[node].sum = arr[l];
return;
}
else
{
int mid = (l + r) >> 1;
build(node << 1, l, mid);
build(node << 1 | 1, mid + 1, r);
tree[node].sum = tree[node << 1].sum + tree[node << 1 | 1].sum;
}
}
//更新当前节点信息
void update(int node, int laze1, int laze2)
{
tree[node].sum = (tree[node].sum * laze1 + laze2 * (tree[node].r - tree[node].l + 1)) % p;
tree[node].laze1 = (tree[node].laze1 * laze1) % p;
tree[node].laze2 = (laze2 + tree[node].laze2 * laze1) % p;
return;
}
//下放懒标记,即更新左子节点和右子节点
void pushdown(int node)
{
update(node << 1, tree[node].laze1, tree[node].laze2);
update(node << 1 | 1, tree[node].laze1, tree[node].laze2);
tree[node].laze1 = 1;
tree[node].laze2 = 0;
}
//更新当前节点
void pushup(int node)
{
tree[node].sum = tree[node << 1].sum + tree[node << 1 | 1].sum;
}
//对区间进行修改
void change(int node, int l, int r, int k, int type)//type为1代表乘,为2代表加
{
if (tree[node].l == l && tree[node].r == r)
{
if (type == 1)
update(node, k, 0);
else
update(node, 1, k);
return;
}
if(tree[node].laze1 != 1 || tree[node].laze2 != 0)
pushdown(node);
int mid = (tree[node].l + tree[node].r) >> 1;
if (r <= mid)
change(node << 1, l, r, k, type);
else if (l > mid)
change(node << 1 | 1, l, r, k, type);
else
{
change(node << 1, l, mid, k, type);
change(node << 1 | 1, mid + 1, r, k, type);
}
pushup(node);
}
//查询区间 l 到 r
ll query(int node, int l, int r)
{
if (tree[node].l == l && tree[node].r == r)
return tree[node].sum;
if (tree[node].laze1 != 1 || tree[node].laze2 != 0)
pushdown(node);
int mid = (tree[node].l + tree[node].r) >> 1;
if (r <= mid)
return query(node << 1, l, r);
else if (l > mid)
return query(node << 1 | 1, l, r);
else
return (query(node << 1, l, mid) + query(node << 1 | 1, mid + 1, r)) % p;
}
int main()
{
int n, m;
cin >> n >> m >> p;
for (int i = 1; i <= n; i++)
cin >> arr[i];
build(1, 1, n);
while (m--)
{
int judge;
cin >> judge;
if (judge == 3)
{
int x, y;
cin >> x >> y;
cout << query(1, x, y) << endl;
}
else
{
int x, y, k;
cin >> x >> y >> k;
change(1, x, y, k, judge);
}
}
return 0;
}