Update Context for store
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
// Filename: internal/store/memory_store.go (经同行审查后最终修复版)
|
||||
// Filename: internal/store/memory_store.go
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ensure memoryStore implements Store interface
|
||||
var _ Store = (*memoryStore)(nil)
|
||||
|
||||
type memoryStoreItem struct {
|
||||
@@ -32,7 +34,6 @@ type memoryStore struct {
|
||||
items map[string]*memoryStoreItem
|
||||
pubsub map[string][]chan *Message
|
||||
mu sync.RWMutex
|
||||
// [USER SUGGESTION APPLIED] 使用带锁的随机数源以保证并发安全
|
||||
rng *rand.Rand
|
||||
rngMu sync.Mutex
|
||||
logger *logrus.Entry
|
||||
@@ -42,7 +43,6 @@ func NewMemoryStore(logger *logrus.Logger) Store {
|
||||
store := &memoryStore{
|
||||
items: make(map[string]*memoryStoreItem),
|
||||
pubsub: make(map[string][]chan *Message),
|
||||
// 使用当前时间作为种子,创建一个新的随机数源
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
logger: logger.WithField("component", "store.memory 🗱"),
|
||||
}
|
||||
@@ -50,13 +50,12 @@ func NewMemoryStore(logger *logrus.Logger) Store {
|
||||
return store
|
||||
}
|
||||
|
||||
// [USER SUGGESTION INCORPORATED] Fix #1: 使用 now := time.Now() 进行原子性检查
|
||||
func (s *memoryStore) startGCollector() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
now := time.Now() // 避免在循环中重复调用
|
||||
now := time.Now()
|
||||
for key, item := range s.items {
|
||||
if !item.expireAt.IsZero() && now.After(item.expireAt) {
|
||||
delete(s.items, key)
|
||||
@@ -66,92 +65,10 @@ func (s *memoryStore) startGCollector() {
|
||||
}
|
||||
}
|
||||
|
||||
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: 修复了致命的nil检查和类型断言问题
|
||||
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 ---
|
||||
// --- 内存实现可以忽略该参数,用 _ 接收 ---
|
||||
|
||||
mainItem, mainOk := s.items[mainKey]
|
||||
var mainSet map[string]struct{}
|
||||
|
||||
if mainOk && !mainItem.isExpired() {
|
||||
// 安全地进行类型断言
|
||||
mainSet, mainOk = mainItem.value.(map[string]struct{})
|
||||
// 确保断言成功且集合不为空
|
||||
mainOk = mainOk && len(mainSet) > 0
|
||||
} else {
|
||||
mainOk = false
|
||||
}
|
||||
|
||||
if !mainOk {
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
// 安全地进行类型断言
|
||||
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
|
||||
if !cooldownSetOk || len(cooldownSet) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
s.items[mainKey] = cooldownItem
|
||||
delete(s.items, cooldownKey)
|
||||
mainSet = cooldownSet
|
||||
}
|
||||
|
||||
var popped string
|
||||
for k := range mainSet {
|
||||
popped = k
|
||||
break
|
||||
}
|
||||
delete(mainSet, popped)
|
||||
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
s.items[cooldownKey] = cooldownItem
|
||||
}
|
||||
// 安全地处理冷却池
|
||||
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
|
||||
if !ok {
|
||||
cooldownSet = make(map[string]struct{})
|
||||
cooldownItem.value = cooldownSet
|
||||
}
|
||||
cooldownSet[popped] = struct{}{}
|
||||
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
// SRandMember [并发修复版] 使用带锁的rng
|
||||
func (s *memoryStore) SRandMember(key string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
s.rngMu.Lock()
|
||||
n := s.rng.Intn(len(members))
|
||||
s.rngMu.Unlock()
|
||||
|
||||
return members[n], nil
|
||||
}
|
||||
|
||||
// --- 以下是其余函数的最终版本,它们都遵循了安全、原子的锁策略 ---
|
||||
|
||||
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var expireAt time.Time
|
||||
@@ -162,7 +79,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Get(key string) ([]byte, error) {
|
||||
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -175,7 +92,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
func (s *memoryStore) Del(keys ...string) error {
|
||||
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, key := range keys {
|
||||
@@ -184,14 +101,14 @@ func (s *memoryStore) Del(keys ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Exists(key string) (bool, error) {
|
||||
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
return ok && !item.isExpired(), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -208,7 +125,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool,
|
||||
|
||||
func (s *memoryStore) Close() error { return nil }
|
||||
|
||||
func (s *memoryStore) HDel(key string, fields ...string) error {
|
||||
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -223,7 +140,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HSet(key string, values map[string]any) error {
|
||||
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -242,7 +159,7 @@ func (s *memoryStore) HSet(key string, values map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
|
||||
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -259,7 +176,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -281,7 +198,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
return newVal, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LPush(key string, values ...any) error {
|
||||
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -301,7 +218,7 @@ func (s *memoryStore) LPush(key string, values ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LRem(key string, count int64, value any) error {
|
||||
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -326,7 +243,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SAdd(key string, members ...any) error {
|
||||
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -345,7 +262,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -375,7 +292,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SMembers(key string) ([]string, error) {
|
||||
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -393,7 +310,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) {
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SRem(key string, members ...any) error {
|
||||
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -410,7 +327,31 @@ func (s *memoryStore) SRem(key string, members ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Rotate(key string) (string, error) {
|
||||
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.rngMu.Lock()
|
||||
n := s.rng.Intn(len(members))
|
||||
s.rngMu.Unlock()
|
||||
return members[n], nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -426,7 +367,7 @@ func (s *memoryStore) Rotate(key string) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
|
||||
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -447,8 +388,7 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
|
||||
return list[index], nil
|
||||
}
|
||||
|
||||
// Zset methods... (ZAdd, ZRange, ZRem)
|
||||
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -471,8 +411,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
for val, score := range membersMap {
|
||||
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
|
||||
}
|
||||
// NOTE: This ZSet implementation is simple but not performant for large sets.
|
||||
// A production implementation would use a skip list or a balanced tree.
|
||||
sort.Slice(newZSet, func(i, j int) bool {
|
||||
if newZSet[i].Score == newZSet[j].Score {
|
||||
return newZSet[i].Value < newZSet[j].Value
|
||||
@@ -482,7 +420,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
item.value = newZSet
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -515,7 +453,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (s *memoryStore) ZRem(key string, members ...any) error {
|
||||
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -540,13 +478,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pipeline implementation
|
||||
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
mainItem, mainOk := s.items[mainKey]
|
||||
var mainSet map[string]struct{}
|
||||
if mainOk && !mainItem.isExpired() {
|
||||
mainSet, mainOk = mainItem.value.(map[string]struct{})
|
||||
mainOk = mainOk && len(mainSet) > 0
|
||||
} else {
|
||||
mainOk = false
|
||||
}
|
||||
if !mainOk {
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
|
||||
if !cooldownSetOk || len(cooldownSet) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.items[mainKey] = cooldownItem
|
||||
delete(s.items, cooldownKey)
|
||||
mainSet = cooldownSet
|
||||
}
|
||||
var popped string
|
||||
for k := range mainSet {
|
||||
popped = k
|
||||
break
|
||||
}
|
||||
delete(mainSet, popped)
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
s.items[cooldownKey] = cooldownItem
|
||||
}
|
||||
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
|
||||
if !ok {
|
||||
cooldownSet = make(map[string]struct{})
|
||||
cooldownItem.value = cooldownSet
|
||||
}
|
||||
cooldownSet[popped] = struct{}{}
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
type memoryPipeliner struct {
|
||||
store *memoryStore
|
||||
ops []func()
|
||||
}
|
||||
|
||||
func (s *memoryStore) Pipeline() Pipeliner {
|
||||
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
|
||||
return &memoryPipeliner{store: s}
|
||||
}
|
||||
func (p *memoryPipeliner) Exec() error {
|
||||
@@ -559,7 +540,6 @@ func (p *memoryPipeliner) Exec() error {
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
|
||||
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
|
||||
capturedKey := key
|
||||
p.ops = append(p.ops, func() {
|
||||
if item, ok := p.store.items[capturedKey]; ok {
|
||||
@@ -596,7 +576,6 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) SRem(key string, members ...any) {
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
@@ -615,7 +594,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
capturedKey := key
|
||||
capturedValues := make([]any, len(values))
|
||||
@@ -637,11 +615,12 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
item.value = append(stringValues, list...)
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
|
||||
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {}
|
||||
func (p *memoryPipeliner) ZRem(key string, members ...any) {}
|
||||
|
||||
// --- Pub/Sub implementation (remains unchanged) ---
|
||||
type memorySubscription struct {
|
||||
store *memoryStore
|
||||
channelName string
|
||||
@@ -649,10 +628,11 @@ type memorySubscription struct {
|
||||
}
|
||||
|
||||
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
|
||||
func (ms *memorySubscription) ChannelName() string { return ms.channelName }
|
||||
func (ms *memorySubscription) Close() error {
|
||||
return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
|
||||
}
|
||||
func (s *memoryStore) Publish(channel string, message []byte) error {
|
||||
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
subscribers, ok := s.pubsub[channel]
|
||||
@@ -669,7 +649,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
|
||||
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
msgChan := make(chan *Message, 10)
|
||||
|
||||
Reference in New Issue
Block a user