Update Context for store

This commit is contained in:
XOF
2025-11-22 14:20:05 +08:00
parent ac0e0a8275
commit 2b0b9b67dc
31 changed files with 817 additions and 1016 deletions

View File

@@ -1,3 +1,4 @@
// Filename: internal/store/factory.go
package store
import (
@@ -11,7 +12,6 @@ import (
// NewStore creates a new store based on the application configuration.
func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
// 检查是否有Redis配置
if cfg.Redis.DSN != "" {
opts, err := redis.ParseURL(cfg.Redis.DSN)
if err != nil {
@@ -20,10 +20,10 @@ func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
client := redis.NewClient(opts)
if err := client.Ping(context.Background()).Err(); err != nil {
logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err)
return NewMemoryStore(logger), nil // 连接失败,也回退到内存模式,但不返回错误
return NewMemoryStore(logger), nil
}
logger.Info("Successfully connected to Redis. Using Redis as store.")
return NewRedisStore(client), nil
return NewRedisStore(client, logger), nil
}
logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.")
return NewMemoryStore(logger), nil

View File

@@ -1,8 +1,9 @@
// Filename: internal/store/memory_store.go (经同行审查后最终修复版)
// Filename: internal/store/memory_store.go
package store
import (
"context"
"fmt"
"math/rand"
"sort"
@@ -12,6 +13,7 @@ import (
"github.com/sirupsen/logrus"
)
// ensure memoryStore implements Store interface
var _ Store = (*memoryStore)(nil)
type memoryStoreItem struct {
@@ -32,7 +34,6 @@ type memoryStore struct {
items map[string]*memoryStoreItem
pubsub map[string][]chan *Message
mu sync.RWMutex
// [USER SUGGESTION APPLIED] 使用带锁的随机数源以保证并发安全
rng *rand.Rand
rngMu sync.Mutex
logger *logrus.Entry
@@ -42,7 +43,6 @@ func NewMemoryStore(logger *logrus.Logger) Store {
store := &memoryStore{
items: make(map[string]*memoryStoreItem),
pubsub: make(map[string][]chan *Message),
// 使用当前时间作为种子,创建一个新的随机数源
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
logger: logger.WithField("component", "store.memory 🗱"),
}
@@ -50,13 +50,12 @@ func NewMemoryStore(logger *logrus.Logger) Store {
return store
}
// [USER SUGGESTION INCORPORATED] Fix #1: 使用 now := time.Now() 进行原子性检查
func (s *memoryStore) startGCollector() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
s.mu.Lock()
now := time.Now() // 避免在循环中重复调用
now := time.Now()
for key, item := range s.items {
if !item.expireAt.IsZero() && now.After(item.expireAt) {
delete(s.items, key)
@@ -66,92 +65,10 @@ func (s *memoryStore) startGCollector() {
}
}
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: 修复了致命的nil检查和类型断言问题
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
// --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 ---
// --- 内存实现可以忽略该参数,用 _ 接收 ---
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
// 安全地进行类型断言
mainSet, mainOk = mainItem.value.(map[string]struct{})
// 确保断言成功且集合不为空
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
// 安全地进行类型断言
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
// 安全地处理冷却池
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
// SRandMember [并发修复版] 使用带锁的rng
func (s *memoryStore) SRandMember(key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
// --- 以下是其余函数的最终版本,它们都遵循了安全、原子的锁策略 ---
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
var expireAt time.Time
@@ -162,7 +79,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
return nil
}
func (s *memoryStore) Get(key string) ([]byte, error) {
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -175,7 +92,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) {
return nil, ErrNotFound
}
func (s *memoryStore) Del(keys ...string) error {
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
for _, key := range keys {
@@ -184,14 +101,14 @@ func (s *memoryStore) Del(keys ...string) error {
return nil
}
func (s *memoryStore) Exists(key string) (bool, error) {
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
return ok && !item.isExpired(), nil
}
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -208,7 +125,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool,
func (s *memoryStore) Close() error { return nil }
func (s *memoryStore) HDel(key string, fields ...string) error {
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -223,7 +140,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
return nil
}
func (s *memoryStore) HSet(key string, values map[string]any) error {
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -242,7 +159,7 @@ func (s *memoryStore) HSet(key string, values map[string]any) error {
return nil
}
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -259,7 +176,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
return make(map[string]string), nil
}
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -281,7 +198,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
return newVal, nil
}
func (s *memoryStore) LPush(key string, values ...any) error {
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -301,7 +218,7 @@ func (s *memoryStore) LPush(key string, values ...any) error {
return nil
}
func (s *memoryStore) LRem(key string, count int64, value any) error {
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -326,7 +243,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error {
return nil
}
func (s *memoryStore) SAdd(key string, members ...any) error {
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -345,7 +262,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error {
return nil
}
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -375,7 +292,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
return popped, nil
}
func (s *memoryStore) SMembers(key string) ([]string, error) {
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -393,7 +310,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) {
return members, nil
}
func (s *memoryStore) SRem(key string, members ...any) error {
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -410,7 +327,31 @@ func (s *memoryStore) SRem(key string, members ...any) error {
return nil
}
func (s *memoryStore) Rotate(key string) (string, error) {
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -426,7 +367,7 @@ func (s *memoryStore) Rotate(key string) (string, error) {
return val, nil
}
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -447,8 +388,7 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
return list[index], nil
}
// Zset methods... (ZAdd, ZRange, ZRem)
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -471,8 +411,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
for val, score := range membersMap {
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
}
// NOTE: This ZSet implementation is simple but not performant for large sets.
// A production implementation would use a skip list or a balanced tree.
sort.Slice(newZSet, func(i, j int) bool {
if newZSet[i].Score == newZSet[j].Score {
return newZSet[i].Value < newZSet[j].Value
@@ -482,7 +420,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
item.value = newZSet
return nil
}
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -515,7 +453,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
}
return result, nil
}
func (s *memoryStore) ZRem(key string, members ...any) error {
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -540,13 +478,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
return nil
}
// Pipeline implementation
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
mainSet, mainOk = mainItem.value.(map[string]struct{})
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
type memoryPipeliner struct {
store *memoryStore
ops []func()
}
func (s *memoryStore) Pipeline() Pipeliner {
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
return &memoryPipeliner{store: s}
}
func (p *memoryPipeliner) Exec() error {
@@ -559,7 +540,6 @@ func (p *memoryPipeliner) Exec() error {
}
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
capturedKey := key
p.ops = append(p.ops, func() {
if item, ok := p.store.items[capturedKey]; ok {
@@ -596,7 +576,6 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) SRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -615,7 +594,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) LPush(key string, values ...any) {
capturedKey := key
capturedValues := make([]any, len(values))
@@ -637,11 +615,12 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
item.value = append(stringValues, list...)
})
}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {}
func (p *memoryPipeliner) ZRem(key string, members ...any) {}
// --- Pub/Sub implementation (remains unchanged) ---
type memorySubscription struct {
store *memoryStore
channelName string
@@ -649,10 +628,11 @@ type memorySubscription struct {
}
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
func (ms *memorySubscription) ChannelName() string { return ms.channelName }
func (ms *memorySubscription) Close() error {
return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
}
func (s *memoryStore) Publish(channel string, message []byte) error {
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
s.mu.RLock()
defer s.mu.RUnlock()
subscribers, ok := s.pubsub[channel]
@@ -669,7 +649,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error {
}
return nil
}
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
s.mu.Lock()
defer s.mu.Unlock()
msgChan := make(chan *Message, 10)

View File

@@ -1,3 +1,5 @@
// Filename: internal/store/redis_store.go
package store
import (
@@ -8,22 +10,20 @@ import (
"time"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
)
// ensure RedisStore implements Store interface
var _ Store = (*RedisStore)(nil)
// RedisStore is a Redis-backed key-value store.
type RedisStore struct {
client *redis.Client
popAndCycleScript *redis.Script
logger *logrus.Entry
}
// NewRedisStore creates a new RedisStore instance.
func NewRedisStore(client *redis.Client) Store {
// Lua script for atomic pop-and-cycle operation.
// KEYS[1]: main set key
// KEYS[2]: cooldown set key
func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store {
const script = `
if redis.call('SCARD', KEYS[1]) == 0 then
if redis.call('SCARD', KEYS[2]) == 0 then
@@ -36,15 +36,16 @@ func NewRedisStore(client *redis.Client) Store {
return &RedisStore{
client: client,
popAndCycleScript: redis.NewScript(script),
logger: logger.WithField("component", "store.redis 🗄️"),
}
}
func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error {
return s.client.Set(context.Background(), key, value, ttl).Err()
func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return s.client.Set(ctx, key, value, ttl).Err()
}
func (s *RedisStore) Get(key string) ([]byte, error) {
val, err := s.client.Get(context.Background(), key).Bytes()
func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) {
val, err := s.client.Get(ctx, key).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrNotFound
@@ -54,53 +55,53 @@ func (s *RedisStore) Get(key string) ([]byte, error) {
return val, nil
}
func (s *RedisStore) Del(keys ...string) error {
func (s *RedisStore) Del(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
return s.client.Del(context.Background(), keys...).Err()
return s.client.Del(ctx, keys...).Err()
}
func (s *RedisStore) Exists(key string) (bool, error) {
val, err := s.client.Exists(context.Background(), key).Result()
func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) {
val, err := s.client.Exists(ctx, key).Result()
return val > 0, err
}
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
return s.client.SetNX(context.Background(), key, value, ttl).Result()
func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
return s.client.SetNX(ctx, key, value, ttl).Result()
}
func (s *RedisStore) Close() error {
return s.client.Close()
}
func (s *RedisStore) HSet(key string, values map[string]any) error {
return s.client.HSet(context.Background(), key, values).Err()
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
return s.client.HSet(ctx, key, values).Err()
}
func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
return s.client.HGetAll(context.Background(), key).Result()
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
return s.client.HGetAll(ctx, key).Result()
}
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
return s.client.HIncrBy(context.Background(), key, field, incr).Result()
func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) {
return s.client.HIncrBy(ctx, key, field, incr).Result()
}
func (s *RedisStore) HDel(key string, fields ...string) error {
func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error {
if len(fields) == 0 {
return nil
}
return s.client.HDel(context.Background(), key, fields...).Err()
return s.client.HDel(ctx, key, fields...).Err()
}
func (s *RedisStore) LPush(key string, values ...any) error {
return s.client.LPush(context.Background(), key, values...).Err()
func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error {
return s.client.LPush(ctx, key, values...).Err()
}
func (s *RedisStore) LRem(key string, count int64, value any) error {
return s.client.LRem(context.Background(), key, count, value).Err()
func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error {
return s.client.LRem(ctx, key, count, value).Err()
}
func (s *RedisStore) Rotate(key string) (string, error) {
val, err := s.client.RPopLPush(context.Background(), key, key).Result()
func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
val, err := s.client.RPopLPush(ctx, key, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
@@ -110,29 +111,28 @@ func (s *RedisStore) Rotate(key string) (string, error) {
return val, nil
}
func (s *RedisStore) SAdd(key string, members ...any) error {
return s.client.SAdd(context.Background(), key, members...).Err()
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
return s.client.SAdd(ctx, key, members...).Err()
}
func (s *RedisStore) SPopN(key string, count int64) ([]string, error) {
return s.client.SPopN(context.Background(), key, count).Result()
func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) {
return s.client.SPopN(ctx, key, count).Result()
}
func (s *RedisStore) SMembers(key string) ([]string, error) {
return s.client.SMembers(context.Background(), key).Result()
func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) {
return s.client.SMembers(ctx, key).Result()
}
func (s *RedisStore) SRem(key string, members ...any) error {
func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 {
return nil
}
return s.client.SRem(context.Background(), key, members...).Err()
return s.client.SRem(ctx, key, members...).Err()
}
func (s *RedisStore) SRandMember(key string) (string, error) {
member, err := s.client.SRandMember(context.Background(), key).Result()
func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) {
member, err := s.client.SRandMember(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
@@ -141,81 +141,43 @@ func (s *RedisStore) SRandMember(key string) (string, error) {
return member, nil
}
// === 新增方法实现 ===
func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
if len(members) == 0 {
return nil
}
redisMembers := make([]redis.Z, 0, len(members))
redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members {
redisMembers = append(redisMembers, redis.Z{Score: score, Member: member})
redisMembers[i] = redis.Z{Score: score, Member: member}
i++
}
return s.client.ZAdd(context.Background(), key, redisMembers...).Err()
return s.client.ZAdd(ctx, key, redisMembers...).Err()
}
func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) {
return s.client.ZRange(context.Background(), key, start, stop).Result()
func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
return s.client.ZRange(ctx, key, start, stop).Result()
}
func (s *RedisStore) ZRem(key string, members ...any) error {
func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 {
return nil
}
return s.client.ZRem(context.Background(), key, members...).Err()
return s.client.ZRem(ctx, key, members...).Err()
}
func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result()
func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) {
val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
return "", err
}
// Lua script returns a string, so we need to type assert
if str, ok := val.(string); ok {
return str, nil
}
return "", ErrNotFound // This happens if both sets were empty and the script returned nil
return "", ErrNotFound
}
type redisPipeliner struct{ pipe redis.Pipeliner }
func (p *redisPipeliner) HSet(key string, values map[string]any) {
p.pipe.HSet(context.Background(), key, values)
}
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(context.Background(), key, field, incr)
}
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(context.Background())
return err
}
func (p *redisPipeliner) Del(keys ...string) {
if len(keys) > 0 {
p.pipe.Del(context.Background(), keys...)
}
}
func (p *redisPipeliner) SAdd(key string, members ...any) {
p.pipe.SAdd(context.Background(), key, members...)
}
func (p *redisPipeliner) SRem(key string, members ...any) {
if len(members) > 0 {
p.pipe.SRem(context.Background(), key, members...)
}
}
func (p *redisPipeliner) LPush(key string, values ...any) {
p.pipe.LPush(context.Background(), key, values...)
}
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(context.Background(), key, count, value)
}
func (s *RedisStore) LIndex(key string, index int64) (string, error) {
val, err := s.client.LIndex(context.Background(), key, index).Result()
func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) {
val, err := s.client.LIndex(ctx, key, index).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
@@ -225,47 +187,120 @@ func (s *RedisStore) LIndex(key string, index int64) (string, error) {
return val, nil
}
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
p.pipe.Expire(context.Background(), key, expiration)
type redisPipeliner struct {
pipe redis.Pipeliner
ctx context.Context
}
func (s *RedisStore) Pipeline() Pipeliner {
return &redisPipeliner{pipe: s.client.Pipeline()}
func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner {
return &redisPipeliner{
pipe: s.client.Pipeline(),
ctx: ctx,
}
}
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(p.ctx)
return err
}
func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) }
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
p.pipe.Expire(p.ctx, key, expiration)
}
func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) }
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(p.ctx, key, field, incr)
}
func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) }
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(p.ctx, key, count, value)
}
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {
if len(members) == 0 {
return
}
redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members {
redisMembers[i] = redis.Z{Score: score, Member: member}
i++
}
p.pipe.ZAdd(p.ctx, key, redisMembers...)
}
func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) }
type redisSubscription struct {
pubsub *redis.PubSub
msgChan chan *Message
once sync.Once
pubsub *redis.PubSub
msgChan chan *Message
logger *logrus.Entry
wg sync.WaitGroup
close context.CancelFunc
channelName string
}
func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) {
pubsub := s.client.Subscribe(ctx, channel)
_, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
}
subCtx, cancel := context.WithCancel(context.Background())
sub := &redisSubscription{
pubsub: pubsub,
msgChan: make(chan *Message, 10),
logger: s.logger,
close: cancel,
channelName: channel,
}
sub.wg.Add(1)
go sub.bridge(subCtx)
return sub, nil
}
func (rs *redisSubscription) bridge(ctx context.Context) {
defer rs.wg.Done()
defer close(rs.msgChan)
redisCh := rs.pubsub.Channel()
for {
select {
case <-ctx.Done():
return
case redisMsg, ok := <-redisCh:
if !ok {
return
}
msg := &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
select {
case rs.msgChan <- msg:
default:
rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName)
}
}
}
}
func (rs *redisSubscription) Channel() <-chan *Message {
rs.once.Do(func() {
rs.msgChan = make(chan *Message)
go func() {
defer close(rs.msgChan)
for redisMsg := range rs.pubsub.Channel() {
rs.msgChan <- &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
}
}()
})
return rs.msgChan
}
func (rs *redisSubscription) Close() error { return rs.pubsub.Close() }
func (s *RedisStore) Publish(channel string, message []byte) error {
return s.client.Publish(context.Background(), channel, message).Err()
func (rs *redisSubscription) ChannelName() string {
return rs.channelName
}
func (s *RedisStore) Subscribe(channel string) (Subscription, error) {
pubsub := s.client.Subscribe(context.Background(), channel)
_, err := pubsub.Receive(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
}
return &redisSubscription{pubsub: pubsub}, nil
func (rs *redisSubscription) Close() error {
rs.close()
err := rs.pubsub.Close()
rs.wg.Wait()
return err
}
func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error {
return s.client.Publish(ctx, channel, message).Err()
}

View File

@@ -1,6 +1,9 @@
// Filename: internal/store/store.go
package store
import (
"context"
"errors"
"time"
)
@@ -17,6 +20,7 @@ type Message struct {
// Subscription represents an active subscription to a pub/sub channel.
type Subscription interface {
Channel() <-chan *Message
ChannelName() string
Close() error
}
@@ -38,6 +42,10 @@ type Pipeliner interface {
LPush(key string, values ...any)
LRem(key string, count int64, value any)
// ZSET
ZAdd(key string, members map[string]float64)
ZRem(key string, members ...any)
// Execution
Exec() error
}
@@ -45,44 +53,44 @@ type Pipeliner interface {
// Store is the master interface for our cache service.
type Store interface {
// Basic K/V operations
Set(key string, value []byte, ttl time.Duration) error
Get(key string) ([]byte, error)
Del(keys ...string) error
Exists(key string) (bool, error)
SetNX(key string, value []byte, ttl time.Duration) (bool, error)
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
Get(ctx context.Context, key string) ([]byte, error)
Del(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
// HASH operations
HSet(key string, values map[string]any) error
HGetAll(key string) (map[string]string, error)
HIncrBy(key, field string, incr int64) (int64, error)
HDel(key string, fields ...string) error // [新增]
HSet(ctx context.Context, key string, values map[string]any) error
HGetAll(ctx context.Context, key string) (map[string]string, error)
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
HDel(ctx context.Context, key string, fields ...string) error
// LIST operations
LPush(key string, values ...any) error
LRem(key string, count int64, value any) error
Rotate(key string) (string, error)
LIndex(key string, index int64) (string, error)
LPush(ctx context.Context, key string, values ...any) error
LRem(ctx context.Context, key string, count int64, value any) error
Rotate(ctx context.Context, key string) (string, error)
LIndex(ctx context.Context, key string, index int64) (string, error)
// SET operations
SAdd(key string, members ...any) error
SPopN(key string, count int64) ([]string, error)
SMembers(key string) ([]string, error)
SRem(key string, members ...any) error
SRandMember(key string) (string, error)
SAdd(ctx context.Context, key string, members ...any) error
SPopN(ctx context.Context, key string, count int64) ([]string, error)
SMembers(ctx context.Context, key string) ([]string, error)
SRem(ctx context.Context, key string, members ...any) error
SRandMember(ctx context.Context, key string) (string, error)
// Pub/Sub operations
Publish(channel string, message []byte) error
Subscribe(channel string) (Subscription, error)
Publish(ctx context.Context, channel string, message []byte) error
Subscribe(ctx context.Context, channel string) (Subscription, error)
// Pipeline (optional) - 我们在redis实现它内存版暂时不实现
Pipeline() Pipeliner
// Pipeline
Pipeline(ctx context.Context) Pipeliner
// Close closes the store and releases any underlying resources.
Close() error
// === 新增方法,支持轮询策略 ===
ZAdd(key string, members map[string]float64) error
ZRange(key string, start, stop int64) ([]string, error)
ZRem(key string, members ...any) error
PopAndCycleSetMember(mainKey, cooldownKey string) (string, error)
// ZSET operations
ZAdd(ctx context.Context, key string, members map[string]float64) error
ZRange(ctx context.Context, key string, start, stop int64) ([]string, error)
ZRem(ctx context.Context, key string, members ...any) error
PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error)
}