链接 : G Caesar Cipher
题意 :
给定一个数组 ,范围为 [0,65536),有以下两种操作:
- 给出 x , y 把 [x , y] 内的每个数 + 1 同时对 65536 取模。
- 给出 x,y,L , 查询区间 [x , x + L - 1] 和区间 [y , y + L - 1]是否完全相同。
思路 :
- 思路就是 线段树维护 hash ,有区间修改和查询 判断两段 hash值是否相同就可以了。
- 首先考虑一下区间合并(也就是pushup),线段树的每个节点表示这一段的 hash 值,在区间合并的操作时 大区间的 hash 值就是 左区间的 hash值 * base ^ len (len表示右区间的长度) + 右区间的 hash值 。
Hash[rt] = (Hash[rt << 1] * poww[r - mid] + Hash[rt << 1 | 1])
- 然后是区间更新 ,把这个区间的值全部 + 1, hash 的变化 就是 base的前缀和 ,例如 某一个区间的hash值为 ∑ i = 0 n \sum_{i=0}^n ∑i=0na[i] * base ^ i (n 为区间长度 - 1),那如果现在把每个 a[i] 都 + 1 , 那hash值的变化就是 ∑ i = 0 n \sum_{i=0}^n ∑i=0n base ^ i , 这里用个前缀和记录一下 ,就可以很好的用 lazy维护。
- 查询操作和普通的查询不一样 , 因为在合并两个区间时 ,合并后的 hash 值 不是两个 hash的 简单相加(参考上面的pushup) , 也就是左区间的 hash值要先乘上 base ^ len(len为右区间长度) 再加右区间。
- 最后就要考虑一下溢出的问题了,如果在更新过程 某个数 >= 65536 , 就要对 65536 取模了 ,直接在更新操作里判断=肯定不好写 ,所以我们在每次更新后都找一下有没有数 大于 65536 ,这里怎么找呢 ,肯定不能暴力扫一遍 。 可以利用线段树进行一个类似二分的过程 ,维护一下每个区间的最大值 ,如果左区间最大值大于 65536 ,继续更新左区间,这样一直下去找到那个值为止 ,复杂度 log(n) ,不用担心超时。
代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<stack>
#include<set>
#define iss ios::sync_with_stdio(false)
using namespace std;
typedef long long ll;
const int mod = 65536;
const int mod1 = 1e9 + 7;
const int mod2 = 998244353;
const int r = 31;
const int maxn = 2e6 + 7;
ll Hash[maxn],ma[maxn],la[maxn];
ll pre[maxn],poww[maxn]; // base 的幂次 前缀和
int a[maxn] ,n,q,x,y,op,L;
void pushup(int l , int r, int rt){
int mid = (l + r) / 2;
Hash[rt] = (Hash[rt << 1] * poww[r - mid] % mod1 + Hash[rt << 1 | 1]) % mod1;
ma[rt] = max(ma[rt << 1] , ma[rt << 1 | 1]);
}
void pushdown(int l,int r,int rt){
if(la[rt] == 0) return ;
int mid = (l + r) / 2;
Hash[rt << 1] = (Hash[rt << 1] + la[rt] * pre[mid - l] % mod1) % mod1; //加上前缀和的 幂次
Hash[rt << 1 | 1] = (Hash[rt << 1 | 1] + la[rt] * pre[r - mid - 1] % mod1) % mod1;
ma[rt << 1] += la[rt];
ma[rt << 1|1] += la[rt];
la[rt<<1] += la[rt];
la[rt<<1|1] += la[rt];
la[rt] = 0;
}
void update(int L,int R,int l,int r,int rt){
if(L <= l && R >= r){
Hash[rt] = (Hash[rt] + pre[r - l]) % mod1;
la[rt] ++;
ma[rt] ++;
return ;
}
pushdown(l , r, rt);
int mid = (l + r) / 2;
if(R > mid) update(L , R ,mid + 1 ,r ,rt << 1 | 1);
if(L <= mid) update(L ,R ,l , mid , rt << 1);
pushup(l ,r , rt);
}
void update_mod(int l,int r,int rt){ //考虑溢出
if(ma[rt] < mod){ //没有超过 mod的 直接退出
return ;
}
if(l == r){
ma[rt] -= mod;
Hash[rt] -= mod;
return ;
}
pushdown(l , r, rt);
int mid = (l + r) / 2;
if(ma[rt << 1] >= mod) update_mod( l , mid ,rt << 1);
if(ma[rt << 1 | 1] >= mod) update_mod(mid + 1 , r, rt << 1 | 1);
pushup(l , r, rt);
}
ll query(int L,int R,int l,int r,int rt){
ll s = 0;
if(L <= l && R >= r){
return Hash[rt];
}
pushdown(l , r, rt);
int mid = (l + r) / 2;
if(R > mid) s = (s + query(L,R,mid + 1,r,rt<<1|1) ) % mod1;
if(L <= mid) s = (s + poww[max(0,min(R , r)- mid)] * query(L,R,l,mid,rt<<1) % mod1) % mod1;
return s;
}
void build(int l,int r ,int rt){
if(l == r){
Hash[rt] = a[l];
ma[rt] = a[l];
return ;
}
int mid = (l + r) / 2;
build(l , mid ,rt << 1);
build(mid + 1 , r,rt << 1 | 1);
pushup(l,r, rt);
}
int main (){
poww[0] = pre[0] = 1;
for(int i = 1; i <= 5e5 ; i ++){
poww[i] = poww[i-1] * r % mod1;
}
for(int i = 1; i <= 5e5 ; i ++){
pre[i] = (pre[i-1] + poww[i]) % mod1;
}
scanf("%d%d",&n,&q);
for(int i = 1; i<= n; i ++){
scanf("%d",&a[i]);
}
build(1 , n,1);
while(q--){
scanf("%d",&op);
if(op == 1){
scanf("%d%d",&x,&y);
update(x,y,1,n,1);
update_mod(1 , n , 1);
}
if(op == 2){
scanf("%d%d%d",&x,&y,&L);
ll h1 = query(x , x + L - 1 ,1 , n, 1);
ll h2 = query(y , y + L - 1 ,1 , n, 1);
if(h1 == h2 ) printf ("yes\n");
else printf ("no\n");
}
}
}