题目链接: Caesar Cipher
大致题意
给定一个非负整数序列, 有两种操作: 操作1: 将[l, r]区间所有数字+1, 操作2: 询问两个子序列区间是否相同.
特别的: 对于每一个数字, 应该为[0, 65535]之间的整数(%65536).
解题思路
区间修改+区间查询+询问子序列是否相同 ==> 线段树 + 字符串哈希
AC代码
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
using namespace std;
typedef long long ll;
const int N = 5E5 + 10, mod = 1E9 + 7, MOD = 65536, B = 13331; //B为进制
int w[N], P[N];
struct node {
int l, r; int base; //base表示当前区间的基数, 仅在build时固定即可
int fmax, hash; //fmax存放当前区间数字的最大值, 根结点的fmax即为当前位的数值
int lazy;
}t[N << 2];
void pushdown(node& op, int lazy) {
op.lazy += lazy;
op.fmax += lazy;
op.hash = (op.hash + (ll)lazy * op.base) % mod;
}
void pushdown(int x) {
if (!t[x].lazy) return;
pushdown(t[x << 1], t[x].lazy), pushdown(t[x << 1 | 1], t[x].lazy);
t[x].lazy = 0;
}
void pushup(int x) {
t[x].fmax = max(t[x << 1].fmax, t[x << 1 | 1].fmax);
t[x].hash = (t[x << 1].hash + t[x << 1 | 1].hash) % mod;
}
void build(int l, int r, int x = 1) {
if (l == r) {
t[x] = { l, r, P[l], w[l], 0, 0 };
t[x].hash = (ll)P[l] * w[l] % mod;
return;
}
int mid = l + r >> 1;
build(l, mid, x << 1), build(mid + 1, r, x << 1 | 1);
t[x] = { l, r, (t[x << 1].base + t[x << 1 | 1].base) % mod, 0, 0, 0 };
pushup(x);
}
void modify(int l, int r, int c, int x = 1) {
if (l <= t[x].l && r >= t[x].r) { pushdown(t[x], c); return; }
pushdown(x);
int mid = t[x].l + t[x].r >> 1;
if (l <= mid) modify(l, r, c, x << 1);
if (r > mid) modify(l, r, c, x << 1 | 1);
pushup(x);
}
int ask(int l, int r, int x = 1) {
if (l <= t[x].l && r >= t[x].r) return t[x].hash;
pushdown(x);
int mid = t[x].l + t[x].r >> 1;
if (r <= mid) return ask(l, r, x << 1);
if (l > mid) return ask(l, r, x << 1| 1);
auto L = ask(l, r, x << 1), R = ask(l, r, x << 1 | 1);
return (L + R) % mod;
}
void outofmod(int x = 1) { //每次检测是否存在某些点的数值>MOD;
if (t[x].fmax < MOD) return;
if (t[x].l == t[x].r) {
t[x].fmax = 0, t[x].hash = 0;
return;
}
pushdown(x);
outofmod(x << 1), outofmod(x << 1 | 1);
pushup(x);
}
void init(int n) {
P[0] = 1;
rep(i, n) {
P[i] = (ll)P[i - 1] * B % mod;
scanf("%d", &w[i]);
}
}
int main()
{
int n, m; cin >> n >> m;
init(n);
build(1, n);
while (m--) {
int op; scanf("%d", &op);
if (op == 1) {
int l, r; scanf("%d %d", &l, &r);
modify(l, r, 1); //由于本题的操作是给[l, r]加1, 所以传参就是1
outofmod(); //检查溢出
}
else {
int x, y, l; scanf("%d %d %d", &x, &y, &l);
if (x > y) swap(x, y); //注意x应当小于y;
int res1 = ask(x, x + l - 1), res2 = ask(y, y + l - 1);
res1 = (ll)res1 * P[y - x] % mod; //对齐区间
printf("%s\n", res1 == res2 ? "yes" : "no");
}
}
return 0;
}