前置知识:势能线段树(学了 懒标记 再来看)
线段树 能够通过 打懒标记实现区间修改的条件有两个:
- 能够快速处理 懒标记对区间询问结果 的影响
- 能够快速实现 懒标记的合并
但,有的区间修改不满足上面两个条件(如 区间整除 / 开方 / 取模 等,接下来要讲的例题就是 区间每个元素都加上自身的 lowbit
值)。
但某些修改存在一些 奇妙的性质,使得,序列每个元素被修改的次数 “有一个上限”。
具体做法:
-
可以 在线段树每个节点上记录一个
bool
值(要区别于 懒标记),记为ok
,表示:对应区间内是否 每个元素都达到 修改次数上限。 -
区间修改 时 暴力递归到叶子节点,如果 途中 遇到一个 节点(即 某一段区间),这个节点的 对应区间 内 每个元素 都 达到修改次数上限 则 在这个节点
return
掉。
题意:
给你一段 含 n
个数字的序列,对于这段序列可以 有 m
次操作。
操作有 两种类型:
- 1、
(1,L,R)
表示将(L,R)
区间的每个数加自身的lowbit
值(若 一个数为x
,则 其lowbit
值为x & -x
). - 2、
(2,L,R)
询问区间(L,R)
数字之和
思路:
本题肯定是要用线段树维护区间信息,但对于 操作 1
对区间每个数 都需要加上其 lowbit
值,直接用 单点修改 来做会 T
得很惨。
但是 加的值很特殊,是 该数的 lowbit
的值,仔细想想我们会发现一个事实:
- 一个数 最多加
logn
次其lowbit
值后 继续加上lowbit
值就变成了乘2
(因为 此时 该数二进制形式上只有最高位为1
)。
那么,我们 对于每个数的单点修改操作 最多 也只需要进行 nlogn
次,然后题目就变成了一个 普通线段树区间修改和区间求和 问题。
所以我们只需要在建造线段树时,每个节点加上一个 bool
值 判断是乘 2
还是加上 lowbit
的值 即可。
对于 懒标记 mul
,我们只针对 “区间整体乘 2
” 这个操作进行设置,懒标记 mul
标记的 就是 2
的次幂,如果 整个区间都是 1000
这样的二进制数,就打上 标记。
注意:
- 因为题目中存在 取模操作,对于 加上
lowbit
的情况 贸然进行 取模 可能会导致 该数二进制位情况的变化,从而 使得结果错误,所以 我们需要在当该数还未到达直接乘2
的情况时不对其进行取模,直到 该数符合乘2
的情况我们对该数修改时才进行取模操作。
接下来分析一下 代码的具体细节:(着重分析 check、pushup、pushdown、modify
四个函数)
modify
时,暴力递归到叶子节点,进行单点修改,如果后序 叶子结点的值 + 其lowbit
值 =2
* 叶子结点值(check
函数 返回为真),那么 将节点中的bool
值ok
置为true
(ok
用于判断是乘2
还是加上lowbit
的值)
if (t[u].l == t[u].r) //单点更新
{
t[u].sum = (t[u].sum + lowbit(t[u].sum));
if (check(t[u].sum)) t[u].ok = true;
return;
}
- 如果递归的 途中 遇到一个 节点(即 某一段区间),这个节点的 对应区间 内 每个元素 都 达到修改次数上限,即 整个区间都是
1000
这样的二进制数,打上标记、进行区间信息更新后,在这个节点return
掉。
if (t[u].l >= l && t[u].r <= r && t[u].ok) //区间更新
{
t[u].mul = (t[u].mul * 2) % mod;
t[u].sum = (t[u].sum * 2) % mod;
return;
}
- 树中节点的
ok
值、sum
值 的更新,我们就用pushup
函数 自子向父更新即可:
void pushup(int u) {
auto& rt = t[u], & le = t[u << 1], & ri = t[u << 1 | 1];
if (le.ok && ri.ok) rt.ok = true;
else rt.ok = false;
t[u].sum = (t[u << 1].sum + t[u << 1 | 1].sum) % mod;
}
- 对于 懒标记的下传,使用
pushdown
函数,这是个套路做法:
void pushdown(int u)
{
auto& rt = t[u], & le = t[u << 1], & ri = t[u << 1 | 1];
if (t[u].mul != 1)
{
le.mul = le.mul * rt.mul % mod;
le.sum = (le.sum * rt.mul) % mod;
ri.mul = ri.mul * rt.mul % mod;
ri.sum = (ri.sum * rt.mul) % mod;
rt.mul = 1;
}
}
对于 modify
函数,我们包含了 区间修改 和 单点修改,出于规范,我们最好将 区间修改 放在 单点修改 之前进行判断。对于 sum
的更新,在 pushup、pushdown、modify
中均有出现
总代码:
/*
一个数加一次lowbit就会使得最后一个1加1,往前进1位
没加几次就会变成类似10000这样的二进制数,此时加上lowbit,
就相当于乘以2,所以我们懒标记标记的就是2的次次幂,
如果整个区间都是1000这样的二进制数,就打上标记
*/
#define _CRT_SECURE_NO_WARNINGS 1
#include <bits/stdc++.h>
using namespace std;
#define int long long
//#define map unordered_map
const int N = 2e5 + 10, mod = 998244353;
struct node
{
int l, r;
int sum;
int mul = 1;
bool ok;
} t[N << 2];
int n, m;
int a[N];
inline int lowbit(int x) {
return x & (-x);
}
inline void Clear()
{
for (int i = 1; i <= (n << 2); ++i)
{
t[i].l = t[i].r = t[i].sum = 0;
t[i].mul = 1;
}
}
bool check(int sum)
{
return (sum + lowbit(sum) == 2 * sum) ? true : false;
}
void pushup(int u) {
auto& rt = t[u], & le = t[u << 1], & ri = t[u << 1 | 1];
if (le.ok && ri.ok) rt.ok = true;
else rt.ok = false;
t[u].sum = (t[u << 1].sum + t[u << 1 | 1].sum) % mod;
}
void pushdown(int u)
{
auto& rt = t[u], & le = t[u << 1], & ri = t[u << 1 | 1];
if (t[u].mul != 1)
{
le.mul = le.mul * rt.mul % mod;
le.sum = (le.sum * rt.mul) % mod;
ri.mul = ri.mul * rt.mul % mod;
ri.sum = (ri.sum * rt.mul) % mod;
rt.mul = 1;
}
}
void build(int u, int l, int r)
{
t[u] = { l, r };
if (l == r)
{
t[u].sum = a[l];
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int l, int r)
{
if (t[u].l >= l && t[u].r <= r && t[u].ok) //区间更新
{
t[u].mul = (t[u].mul * 2) % mod;
t[u].sum = (t[u].sum * 2) % mod;
return;
}
if (t[u].l == t[u].r) //单点更新
{
t[u].sum = (t[u].sum + lowbit(t[u].sum));
if (check(t[u].sum)) t[u].ok = true;
return;
}
pushdown(u);
int mid = t[u].l + t[u].r >> 1;
if (l <= mid) modify(u << 1, l, r);
if (r > mid) modify(u << 1 | 1, l, r);
pushup(u);
}
int ask(int u, int l, int r)
{
if (l <= t[u].l && r >= t[u].r)
{
return t[u].sum;
}
pushdown(u);
int mid = t[u].l + t[u].r >> 1;
int ans = 0;
if (l <= mid) ans = (ans + ask(u << 1, l, r)) % mod;
if (r > mid) ans = (ans + ask(u << 1 | 1, l, r)) % mod;
return ans;
}
signed main()
{
int T; cin >> T;
while (T--)
{
Clear();
cin >> n;
for (int i = 1; i <= n; ++i)
{
scanf("%lld", &a[i]);
}
build(1, 1, n);
cin >> m;
while (m--)
{
int op, l, r;
scanf("%lld%lld%lld", &op, &l, &r);
if (op == 1)
{
modify(1, l, r);
}
else
{
int res = ask(1, l, r);
printf("%lld\n", res);
}
}
}
return 0;
}