Update Context for store

This commit is contained in:
XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -1,8 +1,9 @@
// Filename: internal/store/memory_store.go (经同行审查后最终修复版)
// Filename: internal/store/memory_store.go
package store
import (
"context"
"fmt"
"math/rand"
"sort"
@@ -12,6 +13,7 @@ import (
"github.com/sirupsen/logrus"
)
// ensure memoryStore implements Store interface
var _ Store = (*memoryStore)(nil)
type memoryStoreItem struct {
@@ -32,7 +34,6 @@ type memoryStore struct {
items map[string]*memoryStoreItem
pubsub map[string][]chan *Message
mu sync.RWMutex
// [USER SUGGESTION APPLIED] 使用带锁的随机数源以保证并发安全
rng *rand.Rand
rngMu sync.Mutex
logger *logrus.Entry
@@ -42,7 +43,6 @@ func NewMemoryStore(logger *logrus.Logger) Store {
store := &memoryStore{
items: make(map[string]*memoryStoreItem),
pubsub: make(map[string][]chan *Message),
// 使用当前时间作为种子,创建一个新的随机数源
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
logger: logger.WithField("component", "store.memory 🗱"),
}
@@ -50,13 +50,12 @@ func NewMemoryStore(logger *logrus.Logger) Store {
return store
}
// [USER SUGGESTION INCORPORATED] Fix #1: 使用 now := time.Now() 进行原子性检查
func (s *memoryStore) startGCollector() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
s.mu.Lock()
now := time.Now() // 避免在循环中重复调用
now := time.Now()
for key, item := range s.items {
if !item.expireAt.IsZero() && now.After(item.expireAt) {
delete(s.items, key)
@@ -66,92 +65,10 @@ func (s *memoryStore) startGCollector() {
}
}
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: 修复了致命的nil检查和类型断言问题
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
// --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 ---
// --- 内存实现可以忽略该参数,用 _ 接收 ---
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
// 安全地进行类型断言
mainSet, mainOk = mainItem.value.(map[string]struct{})
// 确保断言成功且集合不为空
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
// 安全地进行类型断言
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
// 安全地处理冷却池
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
// SRandMember [并发修复版] 使用带锁的rng
func (s *memoryStore) SRandMember(key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
// --- 以下是其余函数的最终版本,它们都遵循了安全、原子的锁策略 ---
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
var expireAt time.Time
@@ -162,7 +79,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
return nil
}
func (s *memoryStore) Get(key string) ([]byte, error) {
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -175,7 +92,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) {
return nil, ErrNotFound
}
func (s *memoryStore) Del(keys ...string) error {
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
for _, key := range keys {
@@ -184,14 +101,14 @@ func (s *memoryStore) Del(keys ...string) error {
return nil
}
func (s *memoryStore) Exists(key string) (bool, error) {
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
return ok && !item.isExpired(), nil
}
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -208,7 +125,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool,
func (s *memoryStore) Close() error { return nil }
func (s *memoryStore) HDel(key string, fields ...string) error {
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -223,7 +140,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
return nil
}
func (s *memoryStore) HSet(key string, values map[string]any) error {
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -242,7 +159,7 @@ func (s *memoryStore) HSet(key string, values map[string]any) error {
return nil
}
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -259,7 +176,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
return make(map[string]string), nil
}
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -281,7 +198,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
return newVal, nil
}
func (s *memoryStore) LPush(key string, values ...any) error {
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -301,7 +218,7 @@ func (s *memoryStore) LPush(key string, values ...any) error {
return nil
}
func (s *memoryStore) LRem(key string, count int64, value any) error {
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -326,7 +243,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error {
return nil
}
func (s *memoryStore) SAdd(key string, members ...any) error {
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -345,7 +262,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error {
return nil
}
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -375,7 +292,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
return popped, nil
}
func (s *memoryStore) SMembers(key string) ([]string, error) {
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -393,7 +310,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) {
return members, nil
}
func (s *memoryStore) SRem(key string, members ...any) error {
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -410,7 +327,31 @@ func (s *memoryStore) SRem(key string, members ...any) error {
return nil
}
func (s *memoryStore) Rotate(key string) (string, error) {
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -426,7 +367,7 @@ func (s *memoryStore) Rotate(key string) (string, error) {
return val, nil
}
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -447,8 +388,7 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
return list[index], nil
}
// Zset methods... (ZAdd, ZRange, ZRem)
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -471,8 +411,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
for val, score := range membersMap {
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
}
// NOTE: This ZSet implementation is simple but not performant for large sets.
// A production implementation would use a skip list or a balanced tree.
sort.Slice(newZSet, func(i, j int) bool {
if newZSet[i].Score == newZSet[j].Score {
return newZSet[i].Value < newZSet[j].Value
@@ -482,7 +420,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
item.value = newZSet
return nil
}
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -515,7 +453,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
}
return result, nil
}
func (s *memoryStore) ZRem(key string, members ...any) error {
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -540,13 +478,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
return nil
}
// Pipeline implementation
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
mainSet, mainOk = mainItem.value.(map[string]struct{})
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
type memoryPipeliner struct {
store *memoryStore
ops []func()
}
func (s *memoryStore) Pipeline() Pipeliner {
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
return &memoryPipeliner{store: s}
}
func (p *memoryPipeliner) Exec() error {
@@ -559,7 +540,6 @@ func (p *memoryPipeliner) Exec() error {
}
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
capturedKey := key
p.ops = append(p.ops, func() {
if item, ok := p.store.items[capturedKey]; ok {
@@ -596,7 +576,6 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) SRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -615,7 +594,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) LPush(key string, values ...any) {
capturedKey := key
capturedValues := make([]any, len(values))
@@ -637,11 +615,12 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
item.value = append(stringValues, list...)
})
}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {}
func (p *memoryPipeliner) ZRem(key string, members ...any) {}
// --- Pub/Sub implementation (remains unchanged) ---
type memorySubscription struct {
store *memoryStore
channelName string
@@ -649,10 +628,11 @@ type memorySubscription struct {
}
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
func (ms *memorySubscription) ChannelName() string { return ms.channelName }
func (ms *memorySubscription) Close() error {
return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
}
func (s *memoryStore) Publish(channel string, message []byte) error {
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
s.mu.RLock()
defer s.mu.RUnlock()
subscribers, ok := s.pubsub[channel]
@@ -669,7 +649,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error {
}
return nil
}
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
s.mu.Lock()
defer s.mu.Unlock()
msgChan := make(chan *Message, 10)