题意 :
给定一个数组 ,范围为 [0,65536),有以下两种操作:
- 给出 x , y 把 [x , y] 内的每个数 + 1 同时对 65536 取模。
- 给出 x,y,L , 查询区间 [x , x + L - 1] 和区间 [y , y + L - 1]是否完全相同。
思路 :
- 思路就是 线段树维护 hash ,有区间修改和查询 判断两段 hash值是否相同就可以了。
- 首先考虑一下区间合并(也就是pushup),线段树的每个节点表示这一段的 hash 值,在区间合并的操作时 大区间的 hash 值就是 左区间的 hash值 * base ^ len (len表示右区间的长度) + 右区间的 hash值 。Hash[rt] = (Hash[rt << 1] * poww[r - mid] + Hash[rt << 1 | 1])
- 然后是区间更新 ,把这个区间的值全部 + 1, hash 的变化 就是 base的前缀和 ,例如 某一个区间的hash值为
(n 为区间长度 - 1),那如果现在把每个 a[i] 都 + 1 , 那hash值的变化就是
这里用个前缀和记录一下 ,就可以很好的用 lazy维护。
4.查询操作和普通的查询不一样 , 因为在合并两个区间时 ,合并后的 hash 值 不是两个 hash的 简单相加(参考上面的pushup),也就是左区间的 hash值要先乘上 base ^ 右区间长度,再加右 区间的hash值。
5.最后就要考虑一下溢出的问题了,如果在更新过程 某个数 >= 65536 , 就要对 65536 取模了,直接在更新操作里判断=肯定不好写 ,所以我们在每次更新后都找一下有没有数 大于65536 ,这里怎么找呢 ,肯定不能暴力扫一遍 。 可以利用线段树进行一个类似二分的过程 ,维护一下每个区间的最大值 ,如果左区间最大值大于 65536 ,继续更新左区间,这样一直 下去找到那个值为止 ,复杂度 log(n) ,不用担心超时。
代码:
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 5e5 + 10, mod = 1e9 + 7, MOD = 65536, base = 131;
struct node {
ll sum, mx;
int l, r, lazy;
}tr[4 * N];
int a[N];
ll p[N], pre[N];
int n, q;
void pushup(int rt) {
tr[rt].sum = (tr[rt << 1].sum + tr[rt << 1 | 1].sum) % mod;
tr[rt].mx = max(tr[rt << 1].mx, tr[rt << 1 | 1].mx);
}
void pushdown(int rt) {
auto& root = tr[rt], & left = tr[rt << 1], & right = tr[rt << 1 | 1];
if (tr[rt].lazy) {
left.lazy += root.lazy;
right.lazy += root.lazy;
left.mx += root.lazy;
right.mx += root.lazy;
left.sum += (ll)root.lazy * (pre[left.r] - pre[left.l - 1]) % mod;
right.sum += (ll)root.lazy * (pre[right.r] - pre[right.l - 1]) % mod;
left.sum %= mod;
right.sum %= mod;
root.lazy = 0;
}
}
void build(int rt, int l, int r) {
tr[rt].l = l;
tr[rt].r = r;
if (l == r) {
tr[rt].mx = a[tr[rt].l];
tr[rt].sum = a[tr[rt].l] * p[tr[rt].l] % mod;
return;
}
int m = l + r >> 1;
build(rt << 1, l, m);
build(rt << 1 | 1, m + 1, r);
pushup(rt);
}
ll query_sum(int rt, int l, int r) {
if (l <= tr[rt].l && tr[rt].r <= r) {
return tr[rt].sum;
}
pushdown(rt);
int m = tr[rt].l + tr[rt].r >> 1;
ll res = 0;
if (l <= m) res = (res + query_sum(rt << 1, l, r)) % mod;
if (r > m) res = (res + query_sum(rt << 1 | 1, l, r)) % mod;
return res;
}
ll query_mx(int rt, int l, int r) {
if (l <= tr[rt].l && tr[rt].r <= r) {
return tr[rt].mx;
}
pushdown(rt);
int m = tr[rt].l + tr[rt].r >> 1;
ll res = 0;
if (l <= m) res = max(res, query_mx(rt << 1, l, r));
if (r > m)res = max(res, query_mx(rt << 1 | 1, l, r));
return res;
}
void modify(int rt, int l, int r, int d) {
if (l <= tr[rt].l && tr[rt].r <= r) {
tr[rt].lazy += d;
tr[rt].sum = (ll)((tr[rt].sum + pre[tr[rt].r] - pre[tr[rt].l - 1]) % mod + mod) % mod;
tr[rt].mx += d;
return;
}
pushdown(rt);
int m = tr[rt].l + tr[rt].r >> 1;
if (l <= m)modify(rt << 1, l, r, d);
if (r > m)modify(rt << 1 | 1, l, r, d);
pushup(rt);
}
void modify_all(int rt, int l, int r) {
if (l == tr[rt].l && tr[rt].r == r) {
tr[rt].mx %= MOD;
tr[rt].sum = (ll)tr[rt].mx * p[tr[rt].l] % mod;
return;
}
pushdown(rt);
int m = tr[rt].l + tr[rt].r >> 1;
if (l <= m) modify_all(rt << 1, l, r);
if (r > m)modify_all(rt << 1 | 1, l, r);
pushup(rt);
}
int main() {
scanf("%d %d", &n, &q);
p[0] = 1;
for (int i = 1; i <= n; i++) {
p[i] = p[i - 1] * base % mod;
pre[i] = pre[i - 1] + p[i];
}
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
build(1, 1, n);
int op;
while (q--) {
scanf("%d", &op);
if (op == 1) {
int l, r;
scanf("%d %d", &l, &r);
modify(1, l, r, 1);
ll mx = query_mx(1, l, r);
if (mx >= MOD) {
modify_all(1, l, r);
}
}
else {
int x, y, len;
scanf("%d %d %d", &x, &y, &len);
int l1 = x, r1 = len + x - 1;
int l2 = y, r2 = len + y - 1;
ll num1 = query_sum(1, l1, r1);
ll num2 = query_sum(1, l2, r2);
num1 = num1 * p[y - x] % mod;
if (num1 == num2) {
printf("yes\n");
}
else {
printf("no\n");
}
}
}
return 0;
}