一、概述
哈希表用于存储“键值对”(key-value),最常见莫过于 python 的 dict。
它的优点是查找速度快,可以在近似 O(1) 的时间内确定 key 的位置。使用“双数组”也可以存储“键值对”,但查找 key 的时间为 O(n)。
哈希表是如何实现的呢?
我们知道,数组是时间复杂度为 O(1) 的数据结构,因为知道某一数据的序号,立刻就可以通过序号访问。
如果能够为 key 转化为一个独一无二的序号,再用数组对应的编号来存储 value ,即可实现 O(1) 的快速访问。
二、原理
- key 一般来说是“字符串”(如果是数值,也可转化为str)。而每一个字符(英文),根据 ASCII 编码规则可以转化为一个 0~255的数值(8位,1字节)。同样地,中文等其他字符,也可以根据 Unicode 转换为至多 4 个字节(对应的数值)。
- 为减少编码冲突,例如:"abcd" 与 "bcda" 的编码相同,采用下列转换公式:
ℎ𝑎𝑠ℎ𝑐𝑜𝑑𝑒=𝑎𝑝0+𝑏𝑝1+𝑐𝑝2+𝑑𝑝3+... , 其中 p 是作为基底的质数,一般取 31 或 37。
- 为减少冲突,哈希表的大小b(容量)大于数据量n(存储量),一般来说 n/b < 0.7。
- 冲突是不可避免的,一般有两种方式处理:
- 封闭寻址(链表法):遇到冲突时,在同一地址用链表保存数据
- 开放寻址:遇到冲突时,用第二个哈希函数处理,送至新的地址 (也可以寻找相邻的空单元)
三、代码实现
1、根据 key 生成哈希编码
def get_hash_code(self, key):
key = str(key)
num_buckets = len(self.bucket_array) # 哈希表的容量
current_coefficient = 1
hash_code = 0
for character in key:
hash_code += ord(character) * current_coefficient
hash_code = hash_code % num_buckets # compress hash_code
current_coefficient *= self.p
current_coefficient = current_coefficient % num_buckets # compress coefficient
return hash_code % num_buckets # one last compression before returning
其中,compress hash_code 和 compress coefficient 两行代码可以隐藏,结果是一样的。
2、冲突处理:封闭寻址
def put(self, key, value):
new_node = LinkedListNode(key, value) # 把键值对生产一个结点(用于存储)
bucket_index = self.get_bucket_index(key) # 根据键值得到 bucket_index 即array的序号
#检查输入的key是否已经存在
head = self.bucket_array[bucket_index] #把对应的队列空间定义为head
while head is not None: # 如果head(队列空间)不为空,进入循环
if head.key == key: # 如果head(队列空间)的key 等于输入的key,表示需要更新这个key对应的value
head.value = value # 更新value
return # 完成更新,put函数运行结束
head = head.next # 继续遍历队列单元存储的链表
#输入的key不存在
head = self.bucket_array[bucket_index] # 前面head可能修改过了,需要重新指向
new_node.next = head # 新结点指向原有的链表头(旧结点), 如果原来为None,则新结点的next指向None
self.bucket_array[bucket_index] = new_node # bucket_array[bucket_index] 重新指向新的结点
self.num_entries += 1
def get(self, key):
bucket_index = self.get_bucket_index(key)
head = self.bucket_array[bucket_index]
# 在链表中查找key及对应的value
while head is not None:
if head.key == key:
return head.value
head = head.next
return None
3、哈希表扩容
def _rehash(self):
old_num = len(self.bucket_array)
old_bucket_array = self.bucket_array
num_buckets = old_num * 2
self.bucket_array = [None for _ in range(num_buckets)]
# 复制扩容前的数据
for head in old_bucket_array:
while head is not None:
key = head.key
value = head.value
self.put(key, value)
head = head.next
四、应用
通过哈希表实现caching,提高递归效率。
经典问题:如果每次可以爬1、2、3级,爬 n 级楼梯有多少种方式?
传统方法:
# 原始状态: L(4) = L(3)+L(2)+L(1) =7 被计算了多次
def staircase(n):
if n == 1:
return 1
if n == 2:
return 2
if n == 3:
return 4
value = staircase(n-1) + staircase(n-2) + staircase(n-3)
print(value) # 观察递归的重复执行次数
return value
staircase(7)
输出:可见 staircase(4) 被递归调用了4次
增加chaching:
# 加入caching: L(4) = L(3)+L(2)+L(1) =7 只计算1次
def countstair(n):
cache = {} # caching 字典
def staircase(n):
if n == 1:
return 1
if n == 2:
return 2
if n == 3:
return 4
if n in cache:
return cache.get(n) # 如果在 cache 中,直接返回
value = staircase(n-1) + staircase(n-2) + staircase(n-3)
cache[n] = value # 计算结果加入 cache
print(value) # 观察递归的重复执行次数
return value
return staircase(n)
countstair(7)
输出: 没有重复执行
另一实现方式:
# 利用functools 也有同样效果
import functools
@functools.lru_cache()
def staircase(n):
if n == 1:
return 1
if n == 2:
return 2
if n == 3:
return 4
value = staircase(n-1) + staircase(n-2) + staircase(n-3)
print(value)
return value
常见的哈希函数
一个哈希函数的好不好,取决于以下三点
- 哈希函数的定义域必须包括需要存储的全部关键码,而如果哈希表允许有m个地址时,其值域必须在0 到m-1之间
- 哈希函数计算出来的地址能均匀分布在整个空间中
- 哈希函数应该比较简单
除留余数法(最常用)
函数:Hash(key)=key MOD p (p<=m m为表长),求出来的hash(key)就为存储该key的下标
例如有一下数据{2, 4, 6, 8, 9}
表长为10,也就是数组容量为10
直接定制法(常用)
取关键字的某个线性函数为散列地址(A、B为常数):Hash(Key)= A*Key + B
优点:简单、均匀
缺点:需要事先知道关键字的分布情况
适用场景:适合查找较小数据范围且连续的情况
平方取中法(少)
如果关键字的每一位都有某些数字重复出现频率很高的现象,可以先求关键字的平方值,通过平方扩大差异,而后取中间数位作为最终存储地址。
使用举例
比如key=1234 1234^2=1522756 取227作hash地址
比如key=4321 4321^2=18671041 取671作hash地址
适用场景:事先不知道数据并且数据长度较小的情况
哈希冲突
即不同的key通过同一哈希函数产生了相同的哈希位置,H(key1)=H(key2),例如我们在除留余数法中的例子,如果此时插入一个12,其hash(12)为2,此时下标为2的位置已经有元素,此时就会产生哈希冲突
处理哈希冲突
解决哈希冲突主要有两个方案:闭散列 和 开散列
闭散列
闭散列:也叫开放定址法,当发生哈希冲突时,如果哈希表未被装满,说明在哈希表中必然还有空位置,那 么可以把key存放到冲突位置中的“下一个” 空位置中去
闭散列中主要处理方法有 线性探测 和 二次探测
线性探测
思想:从计算的哈希位置开始,往后找到第一个空闲的位置存放数据
插入:插入就是计算哈希地址,将数据存放在计算出来的哈希位置上,如果该位置有数据则往后查找第一个空闲位置插入。但是当我们的元素越多时,我们产生的哈希冲突的次数就会越多,
删除:当我们要删除一个元素时,不能物理上直接删除,例如我们把15删除了,此时下标为8的位置为空,当我们要查找25这个元素时,也是会从下标为5这个位置开始查找,当5这个位置不是25时,说明产生了哈希冲突,且该插入是使用的是线性探测,也就是第一个空位置插入。我们往后查找时,如果该数据存在,则在空位置之前一定存在该数。但是此时我们物理上把15删除了。查找会查找到下标为8的位置就结束查找,此时也就不会找到25这个数据了。
所以使用线性探测方法,删除并不是实际意义上的删除,而是一种伪删除,我们可以定义三种状态,分别是:EMPTY、EXIST、DELETE。EMPTY表示该位置从来没存放过数据,是一个空位置;EXIST表示该位置存在数据;DELETE表示该位置之前存放过数据,只是已经删除了而已
此时我们想要删除8这个位置上的数据时,就将该位置的状态置为DELETE,我们再次查找25这个数字时,遇到8位置就不会停止搜索,会继续往后搜索,直至遇到状态为EMPTY的位置为止。但是次方法会造成一个问题,就是有可能数据满了,如果此时还一直搜索,就不会找到空的位置,会一直搜索下去。而且如果数据比较极端且数据越来越多,产生的哈希冲突会越来越多。这就不符合我们的哈希要求的高效率的插入与查找。解决办法就是进行扩容
【文章福利】需要C/C++ Linux服务器架构师学习资料加群812855908(资料包括C/C++,Linux,golang技术,Nginx,ZeroMQ,MySQL,Redis,fastdfs,MongoDB,ZK,流媒体,CDN,P2P,K8S,Docker,TCP/IP,协程,DPDK,ffmpeg等)
扩容:扩容并不是一定要等到数据满了才扩容。我们知道当数据越来越多,产生哈希冲突的次数就越多,所以我们要设定一个阈值,也就是当数据达到一定的数量时,就有必要进行扩容。而这决定这个阈值的高低的是一个叫负载因子。负载因子 = 实际存放元素 / 数组容量,范围在0~1之间,我们通常将负载因子置为[0.6, 0.8]之间。例如我们数组大小有10个,负载因为为0.7,则当插入第8个元素的时候就需要进行扩容,因为8/10=0.8>.07,也就是大于我们的负载因子就需要进行扩容。扩容的时候要注意,我们需要将原来的数据移动到新的表中,但是如果是单纯的赋值获取,那哈希冲突并没有解决,而此时我们应该将旧表中的数据重新以插入的方式插入到新的表中,从而减少哈希冲突的次数
```python
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
from transformers.models.bert.configuration_bert import *
import torch
config = BertConfig.from_pretrained("bert-base-uncased")
bert_pooler = BertPooler(config=config)
print("input to bert pooler size: {}".format(config.hidden_size))
batch_size = 1
seq_len = 2
hidden_size = 768
x = torch.rand(batch_size, seq_len, hidden_size)
y = bert_pooler(x)
print(y.size())
```