// Filename: internal/store/redis_store.go package store import ( "context" "errors" "fmt" "sync" "time" "github.com/redis/go-redis/v9" "github.com/sirupsen/logrus" ) // ensure RedisStore implements Store interface var _ Store = (*RedisStore)(nil) type RedisStore struct { client *redis.Client popAndCycleScript *redis.Script logger *logrus.Entry } // NewRedisStore creates a new RedisStore instance. func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store { const script = ` if redis.call('SCARD', KEYS[1]) == 0 then if redis.call('SCARD', KEYS[2]) == 0 then return nil end redis.call('RENAME', KEYS[2], KEYS[1]) end return redis.call('SPOP', KEYS[1]) ` return &RedisStore{ client: client, popAndCycleScript: redis.NewScript(script), logger: logger.WithField("component", "store.redis πŸ—„οΈ"), } } func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { return s.client.Set(ctx, key, value, ttl).Err() } func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) { val, err := s.client.Get(ctx, key).Bytes() if err != nil { if errors.Is(err, redis.Nil) { return nil, ErrNotFound } return nil, err } return val, nil } func (s *RedisStore) Del(ctx context.Context, keys ...string) error { if len(keys) == 0 { return nil } return s.client.Del(ctx, keys...).Err() } func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) { val, err := s.client.Exists(ctx, key).Result() return val > 0, err } func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) { return s.client.SetNX(ctx, key, value, ttl).Result() } func (s *RedisStore) Close() error { return s.client.Close() } func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error { return s.client.Expire(ctx, key, expiration).Err() } func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error { return s.client.HSet(ctx, key, values).Err() } func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) { val, err := s.client.HGet(ctx, key, field).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } return val, nil } func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) { return s.client.HGetAll(ctx, key).Result() } func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) { return s.client.HIncrBy(ctx, key, field, incr).Result() } func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error { if len(fields) == 0 { return nil } return s.client.HDel(ctx, key, fields...).Err() } func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error { return s.client.LPush(ctx, key, values...).Err() } func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error { return s.client.LRem(ctx, key, count, value).Err() } func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) { val, err := s.client.RPopLPush(ctx, key, key).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } return val, nil } func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error { if len(values) == 0 { return nil } // Redis MSet ε‘½δ»€ιœ€θ¦ [key1, value1, key2, value2, ...] ζ ΌεΌηš„εˆ‡η‰‡ pairs := make([]interface{}, 0, len(values)*2) for k, v := range values { pairs = append(pairs, k, v) } return s.client.MSet(ctx, pairs...).Err() } func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error { return s.client.SAdd(ctx, key, members...).Err() } func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) { return s.client.SPopN(ctx, key, count).Result() } func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) { return s.client.SMembers(ctx, key).Result() } func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error { if len(members) == 0 { return nil } return s.client.SRem(ctx, key, members...).Err() } func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) { member, err := s.client.SRandMember(ctx, key).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } return member, nil } func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) { if len(keys) == 0 { return 0, nil } return s.client.SUnionStore(ctx, destination, keys...).Result() } func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error { if len(members) == 0 { return nil } redisMembers := make([]redis.Z, len(members)) i := 0 for member, score := range members { redisMembers[i] = redis.Z{Score: score, Member: member} i++ } return s.client.ZAdd(ctx, key, redisMembers...).Err() } func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) { return s.client.ZRange(ctx, key, start, stop).Result() } func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error { if len(members) == 0 { return nil } return s.client.ZRem(ctx, key, members...).Err() } func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) { val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } if str, ok := val.(string); ok { return str, nil } return "", ErrNotFound } func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) { val, err := s.client.LIndex(ctx, key, index).Result() if err != nil { if errors.Is(err, redis.Nil) { return "", ErrNotFound } return "", err } return val, nil } type redisPipeliner struct { pipe redis.Pipeliner ctx context.Context } func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner { return &redisPipeliner{ pipe: s.client.Pipeline(), ctx: ctx, } } func (p *redisPipeliner) Exec() error { _, err := p.pipe.Exec(p.ctx) return err } func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) } func (p *redisPipeliner) Expire(key string, expiration time.Duration) { p.pipe.Expire(p.ctx, key, expiration) } func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) } func (p *redisPipeliner) HIncrBy(key, field string, incr int64) { p.pipe.HIncrBy(p.ctx, key, field, incr) } func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) } func (p *redisPipeliner) LRem(key string, count int64, value any) { p.pipe.LRem(p.ctx, key, count, value) } func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) { p.pipe.Set(p.ctx, key, value, expiration) } func (p *redisPipeliner) MSet(values map[string]any) { if len(values) == 0 { return } p.pipe.MSet(p.ctx, values) } func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) } func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) } func (p *redisPipeliner) ZAdd(key string, members map[string]float64) { if len(members) == 0 { return } redisMembers := make([]redis.Z, len(members)) i := 0 for member, score := range members { redisMembers[i] = redis.Z{Score: score, Member: member} i++ } p.pipe.ZAdd(p.ctx, key, redisMembers...) } func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) } type redisSubscription struct { pubsub *redis.PubSub msgChan chan *Message logger *logrus.Entry wg sync.WaitGroup close context.CancelFunc channelName string } func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) { pubsub := s.client.Subscribe(ctx, channel) _, err := pubsub.Receive(ctx) if err != nil { _ = pubsub.Close() return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err) } subCtx, cancel := context.WithCancel(context.Background()) sub := &redisSubscription{ pubsub: pubsub, msgChan: make(chan *Message, 10), logger: s.logger, close: cancel, channelName: channel, } sub.wg.Add(1) go sub.bridge(subCtx) return sub, nil } func (rs *redisSubscription) bridge(ctx context.Context) { defer rs.wg.Done() defer close(rs.msgChan) redisCh := rs.pubsub.Channel() for { select { case <-ctx.Done(): return case redisMsg, ok := <-redisCh: if !ok { return } msg := &Message{ Channel: redisMsg.Channel, Payload: []byte(redisMsg.Payload), } select { case rs.msgChan <- msg: default: rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName) } } } } func (rs *redisSubscription) Channel() <-chan *Message { return rs.msgChan } func (rs *redisSubscription) ChannelName() string { return rs.channelName } func (rs *redisSubscription) Close() error { rs.close() err := rs.pubsub.Close() rs.wg.Wait() return err } func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error { return s.client.Publish(ctx, channel, message).Err() }