package store import ( "context" "errors" "fmt" "sync" "time" "github.com/redis/go-redis/v9" ) // ensure RedisStore implements Store interface var _ Store = (*RedisStore)(nil) // RedisStore is a Redis-backed key-value store. type RedisStore struct { client *redis.Client popAndCycleScript *redis.Script } // NewRedisStore creates a new RedisStore instance. func NewRedisStore(client *redis.Client) Store { // Lua script for atomic pop-and-cycle operation. // KEYS[1]: main set key // KEYS[2]: cooldown set key const script = ` if redis.call('SCARD', KEYS[1]) == 0 then if redis.call('SCARD', KEYS[2]) == 0 then return nil end redis.call('RENAME', KEYS[2], KEYS[1]) end return redis.call('SPOP', KEYS[1]) ` return &RedisStore{ client: client, popAndCycleScript: redis.NewScript(script), } } func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error { return s.client.Set(context.Background(), key, value, ttl).Err() } func (s *RedisStore) Get(key string) ([]byte, error) { val, err := s.client.Get(context.Background(), key).Bytes() if err != nil { if errors.Is(err, redis.Nil) { return nil, ErrNotFound } return nil, err } return val, nil } func (s *RedisStore) Del(keys ...string) error { if len(keys) == 0 { return nil } return s.client.Del(context.Background(), keys...).Err() } func (s *RedisStore) Exists(key string) (bool, error) { val, err := s.client.Exists(context.Background(), key).Result() return val > 0, err } func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) { return s.client.SetNX(context.Background(), key, value, ttl).Result() } func (s *RedisStore) Close() error { return s.client.Close() } func (s *RedisStore) HSet(key string, values map[string]any) error { return s.client.HSet(context.Background(), key, values).Err() } func (s *RedisStore) HGetAll(key string) (map[string]string, error) { return s.client.HGetAll(context.Background(), key).Result() } func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) { return s.client.HIncrBy(context.Background(), key, field, incr).Result() } func (s *RedisStore) HDel(key string, fields ...string) error { if len(fields) == 0 { return nil } return s.client.HDel(context.Background(), key, fields...).Err() } func (s *RedisStore) LPush(key string, values ...any) error { return s.client.LPush(context.Background(), key, values...).Err() } func (s *RedisStore) LRem(key string, count int64, value any) error { return s.client.LRem(context.Background(), key, count, value).Err() } func (s *RedisStore) Rotate(key string) (string, error) { val, err := s.client.RPopLPush(context.Background(), key, key).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } return val, nil } func (s *RedisStore) SAdd(key string, members ...any) error { return s.client.SAdd(context.Background(), key, members...).Err() } func (s *RedisStore) SPopN(key string, count int64) ([]string, error) { return s.client.SPopN(context.Background(), key, count).Result() } func (s *RedisStore) SMembers(key string) ([]string, error) { return s.client.SMembers(context.Background(), key).Result() } func (s *RedisStore) SRem(key string, members ...any) error { if len(members) == 0 { return nil } return s.client.SRem(context.Background(), key, members...).Err() } func (s *RedisStore) SRandMember(key string) (string, error) { member, err := s.client.SRandMember(context.Background(), key).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } return member, nil } // === 新增方法实现 === func (s *RedisStore) ZAdd(key string, members map[string]float64) error { if len(members) == 0 { return nil } redisMembers := make([]redis.Z, 0, len(members)) for member, score := range members { redisMembers = append(redisMembers, redis.Z{Score: score, Member: member}) } return s.client.ZAdd(context.Background(), key, redisMembers...).Err() } func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) { return s.client.ZRange(context.Background(), key, start, stop).Result() } func (s *RedisStore) ZRem(key string, members ...any) error { if len(members) == 0 { return nil } return s.client.ZRem(context.Background(), key, members...).Err() } func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) { val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } // Lua script returns a string, so we need to type assert if str, ok := val.(string); ok { return str, nil } return "", ErrNotFound // This happens if both sets were empty and the script returned nil } type redisPipeliner struct{ pipe redis.Pipeliner } func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(context.Background(), key, values) } func (p *redisPipeliner) HIncrBy(key, field string, incr int64) { p.pipe.HIncrBy(context.Background(), key, field, incr) } func (p *redisPipeliner) Exec() error { _, err := p.pipe.Exec(context.Background()) return err } func (p *redisPipeliner) Del(keys ...string) { if len(keys) > 0 { p.pipe.Del(context.Background(), keys...) } } func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(context.Background(), key, members...) } func (p *redisPipeliner) SRem(key string, members ...any) { if len(members) > 0 { p.pipe.SRem(context.Background(), key, members...) } } func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(context.Background(), key, values...) } func (p *redisPipeliner) LRem(key string, count int64, value any) { p.pipe.LRem(context.Background(), key, count, value) } func (s *RedisStore) LIndex(key string, index int64) (string, error) { val, err := s.client.LIndex(context.Background(), key, index).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } return val, nil } func (p *redisPipeliner) Expire(key string, expiration time.Duration) { p.pipe.Expire(context.Background(), key, expiration) } func (s *RedisStore) Pipeline() Pipeliner { return &redisPipeliner{pipe: s.client.Pipeline()} } type redisSubscription struct { pubsub *redis.PubSub msgChan chan *Message once sync.Once } func (rs *redisSubscription) Channel() <-chan *Message { rs.once.Do(func() { rs.msgChan = make(chan *Message) go func() { defer close(rs.msgChan) for redisMsg := range rs.pubsub.Channel() { rs.msgChan <- &Message{ Channel: redisMsg.Channel, Payload: []byte(redisMsg.Payload), } } }() }) return rs.msgChan } func (rs *redisSubscription) Close() error { return rs.pubsub.Close() } func (s *RedisStore) Publish(channel string, message []byte) error { return s.client.Publish(context.Background(), channel, message).Err() } func (s *RedisStore) Subscribe(channel string) (Subscription, error) { pubsub := s.client.Subscribe(context.Background(), channel) _, err := pubsub.Receive(context.Background()) if err != nil { return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err) } return &redisSubscription{pubsub: pubsub}, nil }