LRU缓存机制扩展
上一篇文章已经给出了 LRU 缓存机制的编码实现，本文在此基础上对 LRU 缓存做如下扩展：
- key和value支持任意类型
- 支持超时时间
- 保证线程安全
package lru
import (
"sync"
"time"
)
// Cache describes an LRU cache whose keys and values may be of any
// (comparable, for keys) type, with optional per-entry time-to-live.
type Cache interface {
	// Put stores value under key without scheduling an expiry.
	Put(key, value interface{})
	// PutWithTTL stores value under key; a positive expires is the
	// entry's time-to-live.
	PutWithTTL(key, value interface{}, expires time.Duration)

	// Get returns the value stored under key and whether it was present.
	Get(key interface{}) (interface{}, bool)
	// GetWithTTL is like Get but additionally reports the remaining TTL.
	GetWithTTL(key interface{}) (interface{}, time.Duration, bool)

	// Remove deletes key if it is present.
	Remove(key interface{})
	// Len reports the number of cached entries.
	Len() int
}
// LRUCache is a fixed-capacity least-recently-used cache. Every entry is
// reachable both through a hash map (lookup by key) and through a doubly
// linked list ordered from most recently used (right after head) to least
// recently used (right before tail). Public methods serialise on m.
type LRUCache struct {
	size     int                         // number of entries currently cached
	capacity int                         // maximum entries kept before evicting
	cache    map[interface{}]*LinkedNode // key -> list node
	head     *LinkedNode                 // sentinel in front of the most recently used node
	tail     *LinkedNode                 // sentinel behind the least recently used node
	m        sync.Mutex                  // guards all fields above
}
// LinkedNode is one element of the cache's doubly linked list. The list's
// head and tail are empty sentinel nodes that carry no key or value.
type LinkedNode struct {
	key        interface{} // cache key; used to delete the map entry on removal
	value      interface{} // cached value
	prev       *LinkedNode // neighbour towards the head (more recently used)
	next       *LinkedNode // neighbour towards the tail (less recently used)
	expireTime time.Time   // absolute deadline recorded when the entry was stored
}
// New returns an empty LRUCache that holds at most capacity entries.
// The list starts as two sentinel nodes pointing at each other, so every
// real node always has non-nil neighbours while it is linked.
func New(capacity int) *LRUCache {
	head, tail := &LinkedNode{}, &LinkedNode{}
	head.next, tail.prev = tail, head
	return &LRUCache{
		capacity: capacity,
		cache:    make(map[interface{}]*LinkedNode),
		head:     head,
		tail:     tail,
	}
}
// Remove deletes key from the cache; it is a no-op when key is absent.
func (lru *LRUCache) Remove(key interface{}) {
	lru.m.Lock()
	defer lru.m.Unlock()
	node, ok := lru.cache[key]
	if !ok {
		return
	}
	lru.removeNode(node)
	delete(lru.cache, node.key)
	// Keep size in step with the map. Without this, Len over-reports and
	// PutWithTTL's capacity check evicts live entries too early.
	lru.size--
}
// Len reports how many entries the cache currently holds.
func (lru *LRUCache) Len() int {
	lru.m.Lock()
	n := lru.size
	lru.m.Unlock()
	return n
}
// Get returns the value cached under key and whether it was found. A hit
// promotes the entry to most recently used; the remaining TTL is discarded.
func (lru *LRUCache) Get(key interface{}) (interface{}, bool) {
	v, _, found := lru.GetWithTTL(key)
	return v, found
}
// GetWithTTL looks up key. On a hit the entry becomes the most recently used
// one and is returned together with the time left until its expireTime.
// Entries stored through Put get a deadline already in the past, so their
// reported duration is negative. On a miss it returns (nil, 0, false).
func (lru *LRUCache) GetWithTTL(key interface{}) (interface{}, time.Duration, bool) {
	lru.m.Lock()
	defer lru.m.Unlock()
	node, found := lru.cache[key]
	if !found {
		return nil, 0, false
	}
	// Touching an entry promotes it so it is evicted last.
	lru.moveToHead(node)
	return node.value, time.Until(node.expireTime), true
}
// Put stores value under key without an expiry timer.
func (lru *LRUCache) Put(key interface{}, value interface{}) {
	// Any non-positive TTL tells PutWithTTL not to schedule expiry.
	const noTTL time.Duration = -1
	lru.PutWithTTL(key, value, noTTL)
}
// PutWithTTL stores value under key and marks it as most recently used.
//
// When expires > 0 the entry is removed automatically once that duration
// has elapsed. A non-positive expires stores the entry without a timer.
// Writing an existing key replaces both its value and its deadline, so a
// timer armed by an earlier write for that key no longer removes it.
// Inserting a new key that pushes the cache over capacity evicts the least
// recently used entry.
func (lru *LRUCache) PutWithTTL(key interface{}, value interface{}, expires time.Duration) {
	lru.m.Lock()
	defer lru.m.Unlock()

	deadline := time.Now().Add(expires)
	node, exists := lru.cache[key]
	if exists {
		// Existing key: update in place and refresh its deadline.
		node.value = value
		node.expireTime = deadline
		lru.moveToHead(node)
	} else {
		node = &LinkedNode{key: key, value: value, expireTime: deadline}
		lru.cache[key] = node
		lru.addToHead(node)
		lru.size++
		if lru.size > lru.capacity {
			// Over capacity: drop the least recently used entry.
			oldTail := lru.removeTail()
			delete(lru.cache, oldTail.key)
			lru.size--
		}
	}

	if expires > 0 {
		// AfterFunc runs the callback on its own goroutine once the duration
		// elapses, and there is no ticker left running afterwards.
		time.AfterFunc(expires, func() { lru.expireNode(node, deadline) })
	}
}

// expireNode is the timer callback armed by PutWithTTL. It removes node
// only if node is still the live entry for its key and its deadline is
// still the one the timer was armed for. It locks lru.m because it runs on
// a timer goroutine, concurrently with callers of the public methods.
func (lru *LRUCache) expireNode(node *LinkedNode, deadline time.Time) {
	lru.m.Lock()
	defer lru.m.Unlock()
	// The node may have been evicted, removed, or replaced by a new node for
	// the same key since the timer was armed. Unlinking it again would corrupt
	// the list or drop a newer entry.
	if cur, ok := lru.cache[node.key]; !ok || cur != node {
		return
	}
	// A later write to the same key moved the deadline; that write owns expiry.
	if !node.expireTime.Equal(deadline) {
		return
	}
	lru.removeNode(node)
	delete(lru.cache, node.key)
	lru.size--
}
// removeNode unlinks node from the doubly linked list. It does not touch the
// hash map or size; callers handle those.
//
// The detached node's links are cleared, so a repeated call on the same node
// is a no-op. One example is an expiry firing for an entry that was already
// evicted. Without this, a repeated call would splice the node's stale
// neighbours together and silently drop whatever was inserted between them.
func (lru *LRUCache) removeNode(node *LinkedNode) {
	if node.prev == nil || node.next == nil {
		// Not currently linked (already removed, never added, or a sentinel).
		return
	}
	node.prev.next = node.next
	node.next.prev = node.prev
	node.prev, node.next = nil, nil
}
// addToHead links node directly after the head sentinel, making it the most
// recently used entry.
func (lru *LRUCache) addToHead(node *LinkedNode) {
	first := lru.head.next
	node.prev, node.next = lru.head, first
	first.prev = node
	lru.head.next = node
}
// removeTail unlinks and returns the least recently used node, i.e. the node
// sitting just before the tail sentinel. The caller removes its map entry.
func (lru *LRUCache) removeTail() *LinkedNode {
	victim := lru.tail.prev
	lru.removeNode(victim)
	return victim
}
// moveToHead marks node as the most recently used entry. It unlinks the node
// from its current position and re-inserts it directly after the head
// sentinel. Within this file it is called by GetWithTTL and PutWithTTL while
// they hold lru.m.
func (lru *LRUCache) moveToHead(node *LinkedNode) {
lru.removeNode(node)
lru.addToHead(node)
}
测试
package main
import (
"fmt"
"lru"
"time"
)
// main exercises capacity-based eviction and TTL expiry on a two-slot cache.
func main() {
	c := lru.New(2)

	// Three puts into two slots: key 1 is least recently used and is evicted.
	for i, v := range []string{"a", "b", "c"} {
		c.Put(i+1, v)
	}
	exitPrint(c, 1)

	// Key 4 lives for 10s; one second later it is still there with ~9s left.
	c.PutWithTTL(4, "d", 10*time.Second)
	time.Sleep(time.Second)
	exitPrint(c, 4)
	if _, remaining, ok := c.GetWithTTL(4); ok {
		fmt.Printf("%v \n", remaining)
	}

	// Ten more seconds pass the deadline, so the entry has expired.
	time.Sleep(10 * time.Second)
	exitPrint(c, 4)
}
// exitPrint looks key up in cache and prints "key : value" on a hit or
// "key not exit!" on a miss. NOTE(review): "exit" reads like a typo for
// "exist"; the printed text is left unchanged to match the documented output.
func exitPrint(cache *lru.LRUCache, key interface{}) {
	value, found := cache.Get(key)
	if !found {
		fmt.Println(key, "not exit!")
		return
	}
	fmt.Println(key, ":", value)
}
运行结果（剩余有效期的具体数值每次运行会略有不同）：
1 not exit!
4 : d
8.9995341s
4 not exit!