[python刷题模板] 字符串哈希
一、 算法&数据结构
1. 描述
字符串哈希可以用O(n)的时间预处理,用O(1)的时间计算某一段的哈希值。
这样可以用O(1)时间比较两段是否相同。
字符串哈希把每个字符看做P进制数,然后用前缀和的思想减法求区间哈希。
2. 复杂度分析
- 预处理O(log2n)
- 查询,O(l1)
3. 常见应用
- 比较字符串区间段是否相同。
- 字符串段计数。
4. 常用优化
- 由于python数字溢出会自动转大数,而大数计算很慢,因此必须取模。但这也导致了容易哈希冲突。
- 注意板子里的切片效率不如调函数,勿用。
- 在LC上试了两题,MOD取1e9+7都wa了,取到1e13+7可以过
二、 模板代码
1. 比较同一个字符串内两块子段是否相同
例题: 6195. 对字母串可执行的最大删除数
313周赛T4,这题可以LCP,python的话可以切片比较不需要LCP,但DP思路类似。
DP的复杂度已经是O(n^2),数据范围4000,因此需要O(1)比较两段是否相同。
LCP的话只要共同前缀长度超过段长即可;
这里记一个字符串哈希的做法,顺便当板子。
注意切片会TLE,不如调用函数。
class StringHash:
# 字符串哈希,用O(n)时间预处理,用O(1)时间获取段的哈希值
def __init__(self, s):
n = len(s)
self.BASE = BASE = 131 # 进制 131,131313
self.MOD = MOD = 10 ** 13 + 7 # 10**9+7,998244353,10**13+7
self.h = h = [0] * (n + 1)
self.p = p = [1] * (n + 1)
for i in range(1, n + 1):
p[i] = (p[i - 1] * BASE) % MOD
h[i] = (h[i - 1] * BASE % MOD + ord(s[i - 1])) % MOD
# 用O(1)时间获取开区间[l,r)(即s[l:r])的哈希值
def get_hash(self, l, r):
return (self.h[r] - self.h[l] * self.p[r - l] % self.MOD) % self.MOD
# # 用O(1)时间获取开区间[l,r)(即s[l:r])的哈希值
# def __getitem__(self, index):
# if isinstance(index, slice):
# l, r, step = index.indices(len(self.h)-1)
# if step != 1:
# raise Exception('StringHash slice 步数仅限1'+str(index))
# return (self.h[r] - self.h[l] * self.p[r - l] % self.MOD) % self.MOD
# else:
# return (self.h[index+1] - self.h[index] * self.p[index+1 - index] % self.MOD) % self.MOD
class Solution:
def deleteString(self, s: str) -> int:
n = len(s)
sh = StringHash(s)
f = [1] * n
for i in range(n - 1, -1, -1):
for j in range(i + 1, (i + n) // 2 + 1):
if sh[i:j] == sh[j: j + j - i]:
f[i] = max(f[i], f[j] + 1)
return f[0]
2. 计数同一个字符串内定长子段数量
例题: 187. 重复的DNA序列
这题目标长度是10,因此直接切片计数也能过。
而且由于只有4类字符,可以分别用0123进行状态压缩,每个数字栈两位,一共也才20位。
但如果目标长度很大,那就不好做了,可以用字符串哈希来做。
– 注意,这里wa了很多次,因为出题人加了case卡字符串哈希,换了很多MOD,最后用了10**13+7才过。
class StringHash:
# 字符串哈希,用O(n)时间预处理,用O(1)时间获取段的哈希值
def __init__(self, s):
n = len(s)
self.BASE = BASE = 131 # 进制 131,131313
self.MOD = MOD = 10**13+7 # 10**9+7,10**13+7,998244353
self.h = h = [0] * (n + 1)
self.p = p = [1] * (n + 1)
for i in range(1, n + 1):
p[i] = (p[i - 1] * BASE) % MOD
h[i] = (h[i - 1] * BASE % MOD + ord(s[i - 1])*2) % MOD
# 用O(1)时间获取开区间[l,r)(即s[l:r])的哈希值,比切片要快
def get_hash(self, l, r):
return (self.h[r] - self.h[l] * self.p[r - l] % self.MOD) % self.MOD
# 用O(1)时间获取开区间[l,r)(即s[l:r])的哈希值;这个实测会TLE,不如用self.get_hash
def __getitem__(self, index):
if isinstance(index, slice):
l, r, step = index.indices(len(self.h)-1)
if step != 1:
raise Exception('StringHash slice 步数仅限1'+str(index))
return (self.h[r] - self.h[l] * self.p[r - l] % self.MOD) % self.MOD
else:
return (self.h[index+1] - self.h[index] * self.p[index+1 - index] % self.MOD) % self.MOD
class Solution:
def findRepeatedDnaSequences(self, s: str) -> List[str]:
sh = StringHash(s)
n = len(s)
vis = Counter()
ans = []
for i in range(n-9):
# h = sh[i:i+10]
h = sh.get_hash(i,i+10)
vis[h] += 1
if vis[h] == 2:
ans.append(s[i:i+10])
return ans
3. 计数同一个字符串内定长子段数量+二分答案
例题: 1044. 最长重复子串
这题是上一个例子187. 重复的DNA序列的升级版。
题意暴力无脑,求s中最长的出现重复的子串,len(s)<3e4。
显然至少要nlgn才能过。
于是可以套用上一题的思路在O(n)时间处理出一个长度为x的重复子串。
发现是否有重复子串和x呈单调性,毕竟如果存在长度为x的子串,那么它的前缀x-1长度的子串也是重复的。
因此可以二分。
class StringHash:
# 字符串哈希,用O(n)时间预处理,用O(1)时间获取段的哈希值
def __init__(self, s):
n = len(s)
self.BASE = BASE = 131 # 进制 131,131313
self.MOD = MOD = 10**13+7 # 10**13+7,10**13+7,998244353
self.h = h = [0] * (n + 1)
self.p = p = [1] * (n + 1)
for i in range(1, n + 1):
p[i] = (p[i - 1] * BASE) % MOD
h[i] = (h[i - 1] * BASE % MOD + ord(s[i - 1])*2) % MOD
# 用O(1)时间获取开区间[l,r)(即s[l:r])的哈希值,比切片要快
def get_hash(self, l, r):
return (self.h[r] - self.h[l] * self.p[r - l] % self.MOD) % self.MOD
# 用O(1)时间获取开区间[l,r)(即s[l:r])的哈希值;这个实测会TLE,不如用self.get_hash
def __getitem__(self, index):
if isinstance(index, slice):
l, r, step = index.indices(len(self.h)-1)
if step != 1:
raise Exception('StringHash slice 步数仅限1'+str(index))
return (self.h[r] - self.h[l] * self.p[r - l] % self.MOD) % self.MOD
else:
return (self.h[index+1] - self.h[index] * self.p[index+1 - index] % self.MOD) % self.MOD
class Solution:
def longestDupSubstring(self, s: str) -> str:
sh = StringHash(s)
n = len(s)
def calc(x): # 计算是否不存在长度为x的重复串
vis = set()
for i in range(n-x+1):
h = sh[i:i+x]
if h in vis:
return 0
vis.add(h)
return 1
p = bisect_right(range(n+1),0,lo=1,key =calc)
if p == 1:
return ''
x = p -1
vis = set()
for i in range(n-x+1):
h = sh[i:i+x]
if h in vis:
return s[i:i+x]
vis.add(h)
三、其他
- 由于是哈希,不可避免会遇到冲突,这时可以尝试换MOD和BASE。