是否需要一个支持过期时间的 Map?
如果需要,请看下面的代码实现。
代码实现:
package utils
import (
"fmt"
"hash/fnv"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
)
// Package-level defaults that NewExpiringMap copies into every new map before
// applying options. They are variables rather than constants, so callers can
// change them before constructing a map.
var (
	// DefaultCleanupTime is the interval between background sweeps that remove
	// expired entries.
	DefaultCleanupTime = time.Minute
	// DefaultExpiryValue is the time-to-live Set uses when the caller does not
	// pass an explicit expiry.
	DefaultExpiryValue = 5 * time.Minute
)
// ExpiringValue stores a value together with the instant at which it expires.
// It is the element type ExpiringMap keeps in its shards; Get treats an entry
// as expired once time.Now() is no longer before ExpiryTime.
type ExpiringValue struct {
	// Value is the caller-supplied payload.
	Value interface{}
	// ExpiryTime is the absolute expiry instant (insertion time plus TTL).
	ExpiryTime time.Time
}
// ExpiringMapOption configures an ExpiringMap. NewExpiringMap applies options
// in the order given, after the package defaults have been set.
type ExpiringMapOption func(expiringMap *ExpiringMap)
// ExpiringMap is a concurrent key/value store whose entries expire after a
// time-to-live. Keys are spread across several sync.Map shards (chosen by an
// FNV-1a hash of the key) to reduce contention, and a background goroutine
// periodically removes expired entries.
type ExpiringMap struct {
	// activeCount is updated with sync/atomic 64-bit operations. On 386, ARM
	// and 32-bit MIPS those require 8-byte alignment, which Go only guarantees
	// for the first word of an allocated struct, so this field must stay first
	// (and the struct must not be embedded at a non-zero offset elsewhere).
	activeCount int64         // number of entries currently stored
	data        []sync.Map    // one sync.Map per shard, to reduce lock contention
	shards      int           // number of shards; equals len(data)
	cleanupTime time.Duration // interval between background expiry sweeps
	expiryTime  time.Duration // default TTL used by Set
}
// WithExpiryTime returns an option that replaces the default time-to-live
// (DefaultExpiryValue) that Set applies when no per-call expiry is given.
func WithExpiryTime(expiryTime time.Duration) ExpiringMapOption {
	setTTL := func(target *ExpiringMap) {
		target.expiryTime = expiryTime
	}
	return setTTL
}
// WithCleanupTime returns an option that sets how often the background
// cleanup goroutine sweeps the shards for expired entries, replacing
// DefaultCleanupTime.
func WithCleanupTime(duration time.Duration) ExpiringMapOption {
	var opt ExpiringMapOption = func(target *ExpiringMap) {
		target.cleanupTime = duration
	}
	return opt
}
// NewExpiringMap creates an ExpiringMap split into shardCount sync.Map shards
// and starts its background cleanup goroutine.
//
// DefaultCleanupTime and DefaultExpiryValue are applied first, then options in
// order. Invalid settings are normalised rather than left to crash later:
//   - a shardCount below 1 is treated as 1 (zero would make hash divide by
//     zero, and a negative count makes make panic);
//   - a non-positive cleanup interval falls back to DefaultCleanupTime,
//     because time.NewTicker panics on it inside the cleanup goroutine.
//
// The cleanup goroutine has no stop mechanism: it runs for the rest of the
// process and keeps the returned map reachable.
func NewExpiringMap(shardCount int, options ...ExpiringMapOption) *ExpiringMap {
	if shardCount < 1 {
		shardCount = 1
	}
	em := &ExpiringMap{
		data:        make([]sync.Map, shardCount),
		shards:      shardCount,
		cleanupTime: DefaultCleanupTime,
		expiryTime:  DefaultExpiryValue,
	}
	for _, option := range options {
		option(em)
	}
	if em.cleanupTime <= 0 {
		em.cleanupTime = DefaultCleanupTime
	}
	go em.cleanup() // background expiry sweep
	return em
}
// hash maps key to a shard index in [0, em.shards) using the 32-bit FNV-1a
// hash of the key.
//
// The modulo is taken on the unsigned hash before converting to int. On
// 32-bit platforms int(uint32) can be negative, which would yield a negative
// shard index and an index-out-of-range panic.
func (em *ExpiringMap) hash(key string) int {
	return int(hashString(key) % uint32(em.shards))
}
// Set stores value under key. The entry expires after expiry[0] when given;
// otherwise it uses the map's default expiry time. Only the first element of
// expiry is used. Overwriting an existing key replaces its value and expiry.
//
// sync.Map.Swap both writes the entry and reports whether the key already
// existed, in one atomic step. The previous Load-then-Store sequence let two
// concurrent Sets of the same new key both increment activeCount.
func (em *ExpiringMap) Set(key string, value interface{}, expiry ...time.Duration) {
	ttl := em.expiryTime
	if len(expiry) > 0 {
		ttl = expiry[0]
	}
	entry := ExpiringValue{
		Value:      value,
		ExpiryTime: time.Now().Add(ttl),
	}
	if _, existed := em.data[em.hash(key)].Swap(key, entry); !existed {
		atomic.AddInt64(&em.activeCount, 1)
	}
}
// Get returns the value stored under key and true, if the key exists and has
// not expired. Otherwise it returns nil and false. An expired entry found here
// is removed immediately instead of waiting for the background sweep.
func (em *ExpiringMap) Get(key string) (interface{}, bool) {
	shard := &em.data[em.hash(key)] // pointer: sync.Map must not be copied
	val, found := shard.Load(key)
	if !found {
		return nil, false
	}
	it := val.(ExpiringValue)
	if time.Now().Before(it.ExpiryTime) {
		return it.Value, true
	}
	// LoadAndDelete reports whether *this* call removed the entry. The counter
	// is therefore decremented exactly once, even if Delete or the cleanup
	// sweep removes the same key concurrently.
	// NOTE(review): if a concurrent Set refreshes key between Load and
	// LoadAndDelete, the fresh value is removed too.
	if _, removed := shard.LoadAndDelete(key); removed {
		atomic.AddInt64(&em.activeCount, -1)
	}
	return nil, false
}
// Delete removes key from the map. Deleting a missing key is a no-op.
//
// LoadAndDelete makes "was it present?" and "remove it" one atomic step. That
// way a concurrent Get, Delete or cleanup sweep on the same key cannot cause
// activeCount to be decremented twice.
func (em *ExpiringMap) Delete(key string) {
	if _, removed := em.data[em.hash(key)].LoadAndDelete(key); removed {
		atomic.AddInt64(&em.activeCount, -1)
	}
}
// cleanup runs forever. Once per cleanup interval it sweeps every shard and
// removes entries whose expiry time has been reached. NewExpiringMap starts
// it, and it has no stop mechanism.
//
// An entry counts as expired when now is not before its ExpiryTime, matching
// the check in Get.
func (em *ExpiringMap) cleanup() {
	interval := em.cleanupTime
	if interval <= 0 {
		// time.NewTicker panics on a non-positive interval, and a panic in
		// this background goroutine would take down the whole program.
		interval = DefaultCleanupTime
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for range ticker.C {
		now := time.Now()
		for i := range em.data {
			shard := &em.data[i] // pointer: sync.Map must not be copied
			shard.Range(func(key, value interface{}) bool {
				if !now.Before(value.(ExpiringValue).ExpiryTime) {
					// Decrement only if this sweep actually removed the entry,
					// so a concurrent Get/Delete of the same key cannot make
					// activeCount drop twice.
					if _, removed := shard.LoadAndDelete(key); removed {
						atomic.AddInt64(&em.activeCount, -1)
					}
				}
				return true
			})
		}
	}
}
// Stats returns the number of entries currently counted as stored.
//
// The count only drops when an entry is removed: by Delete, by Get finding
// it expired, or by the periodic sweep. It can therefore include entries that
// have expired but have not been removed yet.
func (em *ExpiringMap) Stats() int64 {
	active := atomic.LoadInt64(&em.activeCount)
	return active
}
// mapToString renders m as "key:value," pairs, each formatted with %v and
// sorted by the key's text. For example, map[string]int{"b": 2, "a": 1}
// becomes "a:1,b:2,". Keys with identical text (possible with interface-typed
// keys such as 1 and "1") are further ordered by the value's text, so the
// result is deterministic.
//
// Each key and value is formatted exactly once, rather than on every sort
// comparison. The value comes from the range loop itself instead of a second
// m[k] lookup, which would miss NaN keys.
//
// The encoding is not injective: ':' or ',' inside keys or values can make
// different maps render identically.
func mapToString[K comparable, V comparable](m map[K]V) string {
	pairs := make([][2]string, 0, len(m)) // [0] = key text, [1] = value text
	for k, v := range m {
		pairs = append(pairs, [2]string{fmt.Sprintf("%v", k), fmt.Sprintf("%v", v)})
	}
	sort.Slice(pairs, func(i, j int) bool {
		if pairs[i][0] != pairs[j][0] {
			return pairs[i][0] < pairs[j][0]
		}
		return pairs[i][1] < pairs[j][1]
	})
	var sb strings.Builder
	for _, p := range pairs {
		sb.WriteString(p[0])
		sb.WriteByte(':')
		sb.WriteString(p[1])
		sb.WriteByte(',')
	}
	return sb.String()
}
// hashString returns the 32-bit FNV-1a hash of s.
func hashString(s string) uint32 {
	hasher := fnv.New32a()
	// hash.Hash documents that Write never returns an error.
	_, _ = hasher.Write([]byte(s))
	return hasher.Sum32()
}
// MapsEqual reports whether m1 and m2 contain exactly the same key/value
// pairs, comparing values with ==. Nil and empty maps are equal.
//
// It used to compare 32-bit hashes of a string rendering of each map. That
// gave false positives from hash collisions and from the ambiguous rendering:
// {"a:b": "c"} and {"a": "b:c"} rendered identically.
//
// As with ==, comparing interface-typed values whose dynamic type is not
// comparable (e.g. slices stored in V = any) panics. NaN values never compare
// equal.
func MapsEqual[K comparable, V comparable](m1, m2 map[K]V) bool {
	if len(m1) != len(m2) {
		return false
	}
	for k, v1 := range m1 {
		if v2, ok := m2[k]; !ok || v1 != v2 {
			return false
		}
	}
	return true
}
测试代码(简单验证,放在单独的 _test.go 文件中;除上面的导入外,还需要导入 testing 和 math/rand):
// TestExpiringMap is a smoke test for ExpiringMap. It first checks
// Set/Get/expiry on a single key. It then bulk-inserts one million entries
// with random TTLs and randomly deletes some of them while the background
// cleanup runs, logging elapsed time and the stored-entry count.
func TestExpiringMap(t *testing.T) {
	now := time.Now()
	// The elapsed time must be computed inside a closure: the arguments of a
	// deferred call are evaluated when the defer statement executes, so a
	// plain deferred call would always report about 0 seconds.
	defer func() {
		t.Logf("总耗时:%v", time.Since(now).Seconds())
	}()

	// 8 shards, cleanup sweep every 500ms, default TTL 15s.
	expiringMap := NewExpiringMap(8, WithCleanupTime(500*time.Millisecond), WithExpiryTime(15*time.Second))

	// key1 has not been stored yet, so it must not be found.
	if value, found := expiringMap.Get("key1"); found {
		t.Fatalf("Get(key1) before Set = %v, want not found", value)
	}

	// Store key1 with a short TTL: it must be readable right away...
	expiringMap.Set("key1", "value1", 50*time.Millisecond)
	if value, found := expiringMap.Get("key1"); !found || value != "value1" {
		t.Fatalf("Get(key1) = %v, %v; want value1, true", value, found)
	}
	// ...and gone once the TTL has clearly passed.
	time.Sleep(100 * time.Millisecond)
	if value, found := expiringMap.Get("key1"); found {
		t.Fatalf("Get(key1) after expiry = %v, want not found", value)
	}

	const total = 1000000
	for i := 0; i < total; i++ {
		expiringMap.Set("key"+fmt.Sprintf("%d", i), "value"+fmt.Sprintf("%d", i),
			time.Duration(RandInt(0, 60))*time.Second)
	}
	// Randomly delete 10000 keys drawn from key0..key9999.
	for i := 0; i < 10000; i++ {
		a := RandInt(0, 10000-1)
		expiringMap.Delete("key" + fmt.Sprintf("%d", a))
	}
	t.Logf("总耗时:%v", time.Since(now).Seconds())

	// Entries expire and are swept concurrently, so only bound-check the count.
	// A negative value would indicate double-decrement bugs.
	if n := expiringMap.Stats(); n < 0 || n > total {
		t.Fatalf("Stats() = %d, want within [0, %d]", n, total)
	}
	t.Logf("Current active item count: %d", expiringMap.Stats())
}
// RandInt returns a pseudo-random integer in the closed range [min, max],
// drawn from math/rand's shared source. Bounds given in reverse order are
// swapped; otherwise rand.Intn would panic on its non-positive argument.
// NOTE(review): a range whose width overflows int (max-min+1) still panics.
func RandInt(min, max int) int {
	if max < min {
		min, max = max, min
	}
	return rand.Intn(max-min+1) + min
}