学习了一下跳表的代码逻辑。固定层级的跳表无法适应链表长度的动态变化,可以考虑根据链表长度动态扩充层级。
skiplist.go:
package skiplist
import (
"fmt"
"golang.org/x/exp/constraints"
"math/rand"
"time"
)
// CompareKey constrains the key type of a SkipList to types that support the
// ordering operators; the list relies on < and > to keep its keys sorted.
type CompareKey interface{ constraints.Ordered }
// node is one element of the skip list. The same type is used for the head
// sentinel, whose key is never compared during a search.
type node[KeyType CompareKey] struct {
	key   KeyType // ordering key; generic so one implementation serves every ordered key type
	value any     // payload stored under key
	level int     // number of levels this node takes part in (== len(nextLevelNodes))
	// nextLevelNodes[i] is this node's successor on level i, or nil at the end of that level.
	nextLevelNodes []*node[KeyType]
}
// SkipList is a sorted key/value container implemented as a skip list whose
// number of levels is fixed when the list is created. It does no internal
// locking.
type SkipList[KeyType CompareKey] struct {
	head     *node[KeyType] // sentinel spanning all maxLevel levels; nil until the first Insert
	rander   *rand.Rand     // random source used when choosing a new node's level
	maxLevel int            // number of levels in the list
}
// newNode allocates a node holding key/value that takes part in `level`
// levels; all of its forward links start out nil.
func newNode[KeyType CompareKey](key KeyType, value any, level int) *node[KeyType] {
	return &node[KeyType]{
		key:            key,
		value:          value,
		level:          level,
		nextLevelNodes: make([]*node[KeyType], level),
	}
}
// NewSkipList returns an empty skip list with maxLevel levels. A maxLevel
// below 1 is treated as 1: every list needs at least its base level
// (level 0), and a non-positive value would otherwise make the first Insert
// panic. The level generator is seeded from the current time.
func NewSkipList[KeyType CompareKey](maxLevel int) *SkipList[KeyType] {
	if maxLevel < 1 {
		maxLevel = 1
	}
	return &SkipList[KeyType]{
		rander:   rand.New(rand.NewSource(time.Now().UnixNano())),
		maxLevel: maxLevel,
	}
}
// Find looks up key. It returns the stored value and true when the key is
// present, and nil, false otherwise.
func (sl *SkipList[K]) Find(key K) (value any, find bool) {
	_, hit, ok := sl.find(key)
	if !ok {
		return nil, false
	}
	return hit.value, true
}
// Insert stores value under key. If key is already present its value is
// overwritten and replace is true; otherwise a new node is linked into the
// list and replace is false.
func (sl *SkipList[K]) Insert(key K, value any) (replace bool) {
	if sl.head == nil {
		// Empty list: create the head sentinel spanning every level. Its key
		// is only a placeholder; searches never compare the head's key.
		sl.head = newNode(key, nil, sl.maxLevel)
	}

	level := sl.randomLevel()

	// update[i] is the rightmost node on level i whose key is < key, i.e. the
	// node the new entry must be linked after on that level.
	update := make([]*node[K], level)
	p := sl.head
	for i := sl.maxLevel - 1; i >= 0; i-- {
		for next := p.nextLevelNodes[i]; next != nil && next.key < key; next = p.nextLevelNodes[i] {
			p = next
		}
		if i < level {
			update[i] = p
		}
	}

	// Every node lives on level 0, so an existing entry for key (if any) is
	// the level-0 successor of p. Check this before linking anything so a
	// replace can never leave a half-linked duplicate node behind.
	if next := p.nextLevelNodes[0]; next != nil && next.key == key {
		next.value = value
		return true
	}

	insertNode := newNode(key, value, level)
	for i, prev := range update {
		insertNode.nextLevelNodes[i] = prev.nextLevelNodes[i]
		prev.nextLevelNodes[i] = insertNode
	}
	return false
}

// randomLevel picks the height of a new node by repeated fair coin flips:
// it starts at 1 and grows by one level for each "heads", capped at
// sl.maxLevel. A node therefore reaches level k with probability 2^-(k-1).
// This geometric distribution keeps each level about half the size of the
// one below and gives the expected O(log n) search cost.
func (sl *SkipList[K]) randomLevel() int {
	level := 1
	for level < sl.maxLevel && sl.rander.Int63()&1 == 1 {
		level++
	}
	return level
}
// Delete removes key from the list and reports whether it was present.
//
// The node is unlinked on every level it occupies, each time through that
// level's own predecessor. Re-pointing a single (top-level) predecessor on
// all levels would silently drop any nodes lying between it and the deleted
// node on the lower levels.
func (sl *SkipList[K]) Delete(key K) (find bool) {
	if sl.head == nil {
		return false
	}
	p := sl.head
	for i := sl.maxLevel - 1; i >= 0; i-- {
		// Advance to the last node on level i whose key is < key.
		for p.nextLevelNodes[i] != nil && p.nextLevelNodes[i].key < key {
			p = p.nextLevelNodes[i]
		}
		if next := p.nextLevelNodes[i]; next != nil && next.key == key {
			p.nextLevelNodes[i] = next.nextLevelNodes[i]
			find = true
		}
	}
	return find
}
// find searches for key, descending from the top level. On success it
// returns the node preceding the match on the highest level where the key
// was met, the node holding key, and true. Otherwise, including on an empty
// list, it returns nil, nil, false.
func (sl *SkipList[K]) find(key K) (*node[K], *node[K], bool) {
	if sl.head == nil {
		return nil, nil, false
	}
	cur := sl.head
	for lvl := sl.maxLevel - 1; lvl >= 0; lvl-- {
		for {
			next := cur.nextLevelNodes[lvl]
			if next == nil || next.key > key {
				break // end of level or overshoot: drop one level
			}
			if next.key == key {
				return cur, next, true
			}
			cur = next
		}
	}
	return nil, nil, false
}
// Print renders the list one level per line, from the highest level down to
// level 0, which holds every key. Each key is written with %v followed by a
// single space, and every line, including empty ones, ends with "\n". An
// empty list (no Insert yet) renders as maxLevel empty lines instead of
// dereferencing a nil head.
func (sl *SkipList[K]) Print() string {
	// Appending into one growing buffer keeps rendering linear in the output
	// size. Repeated string concatenation would recopy the text for every key.
	var buf []byte
	for i := sl.maxLevel - 1; i >= 0; i-- {
		if sl.head != nil {
			for p := sl.head.nextLevelNodes[i]; p != nil; p = p.nextLevelNodes[i] {
				buf = append(buf, fmt.Sprintf("%v ", p.key)...)
			}
		}
		buf = append(buf, '\n')
	}
	return string(buf)
}
skiplist_test.go:
package skiplist
import (
"fmt"
"testing"
"time"
)
func TestSkipList(t *testing.T) {
t1 := time.Now()
maxLevel := 50
sl := NewSkipList[int](maxLevel)
testNumberCount := 10000000
for i := testNumberCount; i > 0; i-- {
sl.Insert(i, i)
}
p := sl.head.nextLevelNodes[maxLevel-1]
topCount := 0
for p != nil {
topCount++
p = p.nextLevelNodes[maxLevel-1]
}
t2 := time.Since(t1)
fmt.Printf("build list with %v elements ok, cost %v microseconds, max level:%v, top level key length:%v\n",
testNumberCount, t2.Microseconds(), maxLevel, topCount)
assertFind(sl, 4990000, 4990000)
assertFind(sl, 4, 4)
assertNotFind(sl, 1000000000)
assertFind(sl, 2000000, 2000000)
assertAdd(sl, 100000000, 100000000, false)
assertDelete(sl, 4000000)
assertAdd(sl, 4000000, 4000000, false)
assertAdd(sl, 4000000, 4000000, true)
}
// assertFind panics unless sl holds key and the stored value == expectedValue.
func assertFind[K CompareKey](sl *SkipList[K], key K, expectedValue any) {
	got, ok := sl.Find(key)
	switch {
	case !ok:
		panic(fmt.Sprintf("assertFind not found key:%v", key))
	case got != expectedValue:
		panic(fmt.Sprintf("assertFind value not equal:%v/%v", got, expectedValue))
	}
}
// assertNotFind panics if sl contains key, reporting the value stored for it.
func assertNotFind[K CompareKey](sl *SkipList[K], key K) {
	if v, present := sl.Find(key); present {
		panic(fmt.Sprintf("assertNotFind found key:%v value:%v", key, v))
	}
}
// assertAdd inserts key -> value, panics if Insert's replace flag differs from
// expectedReplace, and then checks that key reads back as value.
func assertAdd[K CompareKey](sl *SkipList[K], key K, value any, expectedReplace bool) {
	if replaced := sl.Insert(key, value); replaced != expectedReplace {
		panic(fmt.Sprintf("assertAdd expected replace not equal:%v/%v", replaced, expectedReplace))
	}
	assertFind(sl, key, value)
}
// assertDelete removes key, panicking if Delete reports it absent, and then
// checks that the key can no longer be found.
func assertDelete[K CompareKey](sl *SkipList[K], key K) {
	if removed := sl.Delete(key); !removed {
		panic(fmt.Sprintf("assertDelete not found key:%v", key))
	}
	assertNotFind(sl, key)
}