Update Context for store
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
// Filename: internal/store/redis_store.go
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
@@ -8,22 +10,20 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ensure RedisStore implements Store interface
|
||||
var _ Store = (*RedisStore)(nil)
|
||||
|
||||
// RedisStore is a Redis-backed key-value store.
|
||||
type RedisStore struct {
|
||||
client *redis.Client
|
||||
popAndCycleScript *redis.Script
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewRedisStore creates a new RedisStore instance.
|
||||
func NewRedisStore(client *redis.Client) Store {
|
||||
// Lua script for atomic pop-and-cycle operation.
|
||||
// KEYS[1]: main set key
|
||||
// KEYS[2]: cooldown set key
|
||||
func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store {
|
||||
const script = `
|
||||
if redis.call('SCARD', KEYS[1]) == 0 then
|
||||
if redis.call('SCARD', KEYS[2]) == 0 then
|
||||
@@ -36,15 +36,16 @@ func NewRedisStore(client *redis.Client) Store {
|
||||
return &RedisStore{
|
||||
client: client,
|
||||
popAndCycleScript: redis.NewScript(script),
|
||||
logger: logger.WithField("component", "store.redis 🗄️"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
return s.client.Set(context.Background(), key, value, ttl).Err()
|
||||
func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return s.client.Set(ctx, key, value, ttl).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Get(key string) ([]byte, error) {
|
||||
val, err := s.client.Get(context.Background(), key).Bytes()
|
||||
func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := s.client.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrNotFound
|
||||
@@ -54,53 +55,53 @@ func (s *RedisStore) Get(key string) ([]byte, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) Del(keys ...string) error {
|
||||
func (s *RedisStore) Del(ctx context.Context, keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.Del(context.Background(), keys...).Err()
|
||||
return s.client.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Exists(key string) (bool, error) {
|
||||
val, err := s.client.Exists(context.Background(), key).Result()
|
||||
func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) {
|
||||
val, err := s.client.Exists(ctx, key).Result()
|
||||
return val > 0, err
|
||||
}
|
||||
|
||||
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
return s.client.SetNX(context.Background(), key, value, ttl).Result()
|
||||
func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
return s.client.SetNX(ctx, key, value, ttl).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Close() error {
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HSet(key string, values map[string]any) error {
|
||||
return s.client.HSet(context.Background(), key, values).Err()
|
||||
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
|
||||
return s.client.HSet(ctx, key, values).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
|
||||
return s.client.HGetAll(context.Background(), key).Result()
|
||||
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||||
return s.client.HGetAll(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
return s.client.HIncrBy(context.Background(), key, field, incr).Result()
|
||||
func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) {
|
||||
return s.client.HIncrBy(ctx, key, field, incr).Result()
|
||||
}
|
||||
func (s *RedisStore) HDel(key string, fields ...string) error {
|
||||
func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error {
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.HDel(context.Background(), key, fields...).Err()
|
||||
return s.client.HDel(ctx, key, fields...).Err()
|
||||
}
|
||||
func (s *RedisStore) LPush(key string, values ...any) error {
|
||||
return s.client.LPush(context.Background(), key, values...).Err()
|
||||
func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error {
|
||||
return s.client.LPush(ctx, key, values...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) LRem(key string, count int64, value any) error {
|
||||
return s.client.LRem(context.Background(), key, count, value).Err()
|
||||
func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error {
|
||||
return s.client.LRem(ctx, key, count, value).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Rotate(key string) (string, error) {
|
||||
val, err := s.client.RPopLPush(context.Background(), key, key).Result()
|
||||
func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
|
||||
val, err := s.client.RPopLPush(ctx, key, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
@@ -110,29 +111,28 @@ func (s *RedisStore) Rotate(key string) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) SAdd(key string, members ...any) error {
|
||||
return s.client.SAdd(context.Background(), key, members...).Err()
|
||||
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
|
||||
return s.client.SAdd(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SPopN(key string, count int64) ([]string, error) {
|
||||
return s.client.SPopN(context.Background(), key, count).Result()
|
||||
func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) {
|
||||
return s.client.SPopN(ctx, key, count).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SMembers(key string) ([]string, error) {
|
||||
return s.client.SMembers(context.Background(), key).Result()
|
||||
func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) {
|
||||
return s.client.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRem(key string, members ...any) error {
|
||||
func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.SRem(context.Background(), key, members...).Err()
|
||||
return s.client.SRem(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRandMember(key string) (string, error) {
|
||||
member, err := s.client.SRandMember(context.Background(), key).Result()
|
||||
func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) {
|
||||
member, err := s.client.SRandMember(ctx, key).Result()
|
||||
if err != nil {
|
||||
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
@@ -141,81 +141,43 @@ func (s *RedisStore) SRandMember(key string) (string, error) {
|
||||
return member, nil
|
||||
}
|
||||
|
||||
// === 新增方法实现 ===
|
||||
|
||||
func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
|
||||
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
redisMembers := make([]redis.Z, 0, len(members))
|
||||
redisMembers := make([]redis.Z, len(members))
|
||||
i := 0
|
||||
for member, score := range members {
|
||||
redisMembers = append(redisMembers, redis.Z{Score: score, Member: member})
|
||||
redisMembers[i] = redis.Z{Score: score, Member: member}
|
||||
i++
|
||||
}
|
||||
return s.client.ZAdd(context.Background(), key, redisMembers...).Err()
|
||||
return s.client.ZAdd(ctx, key, redisMembers...).Err()
|
||||
}
|
||||
func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
return s.client.ZRange(context.Background(), key, start, stop).Result()
|
||||
func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
|
||||
return s.client.ZRange(ctx, key, start, stop).Result()
|
||||
}
|
||||
func (s *RedisStore) ZRem(key string, members ...any) error {
|
||||
func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.ZRem(context.Background(), key, members...).Err()
|
||||
return s.client.ZRem(ctx, key, members...).Err()
|
||||
}
|
||||
func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result()
|
||||
func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) {
|
||||
val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
// Lua script returns a string, so we need to type assert
|
||||
if str, ok := val.(string); ok {
|
||||
return str, nil
|
||||
}
|
||||
return "", ErrNotFound // This happens if both sets were empty and the script returned nil
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
type redisPipeliner struct{ pipe redis.Pipeliner }
|
||||
|
||||
func (p *redisPipeliner) HSet(key string, values map[string]any) {
|
||||
p.pipe.HSet(context.Background(), key, values)
|
||||
}
|
||||
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
p.pipe.HIncrBy(context.Background(), key, field, incr)
|
||||
}
|
||||
func (p *redisPipeliner) Exec() error {
|
||||
_, err := p.pipe.Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Del(keys ...string) {
|
||||
if len(keys) > 0 {
|
||||
p.pipe.Del(context.Background(), keys...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) SAdd(key string, members ...any) {
|
||||
p.pipe.SAdd(context.Background(), key, members...)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) SRem(key string, members ...any) {
|
||||
if len(members) > 0 {
|
||||
p.pipe.SRem(context.Background(), key, members...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) LPush(key string, values ...any) {
|
||||
p.pipe.LPush(context.Background(), key, values...)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) LRem(key string, count int64, value any) {
|
||||
p.pipe.LRem(context.Background(), key, count, value)
|
||||
}
|
||||
|
||||
func (s *RedisStore) LIndex(key string, index int64) (string, error) {
|
||||
val, err := s.client.LIndex(context.Background(), key, index).Result()
|
||||
func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) {
|
||||
val, err := s.client.LIndex(ctx, key, index).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
@@ -225,47 +187,120 @@ func (s *RedisStore) LIndex(key string, index int64) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
|
||||
p.pipe.Expire(context.Background(), key, expiration)
|
||||
type redisPipeliner struct {
|
||||
pipe redis.Pipeliner
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *RedisStore) Pipeline() Pipeliner {
|
||||
return &redisPipeliner{pipe: s.client.Pipeline()}
|
||||
func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner {
|
||||
return &redisPipeliner{
|
||||
pipe: s.client.Pipeline(),
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Exec() error {
|
||||
_, err := p.pipe.Exec(p.ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) }
|
||||
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
|
||||
p.pipe.Expire(p.ctx, key, expiration)
|
||||
}
|
||||
func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) }
|
||||
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
p.pipe.HIncrBy(p.ctx, key, field, incr)
|
||||
}
|
||||
func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) }
|
||||
func (p *redisPipeliner) LRem(key string, count int64, value any) {
|
||||
p.pipe.LRem(p.ctx, key, count, value)
|
||||
}
|
||||
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
|
||||
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
|
||||
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {
|
||||
if len(members) == 0 {
|
||||
return
|
||||
}
|
||||
redisMembers := make([]redis.Z, len(members))
|
||||
i := 0
|
||||
for member, score := range members {
|
||||
redisMembers[i] = redis.Z{Score: score, Member: member}
|
||||
i++
|
||||
}
|
||||
p.pipe.ZAdd(p.ctx, key, redisMembers...)
|
||||
}
|
||||
func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) }
|
||||
|
||||
type redisSubscription struct {
|
||||
pubsub *redis.PubSub
|
||||
msgChan chan *Message
|
||||
once sync.Once
|
||||
pubsub *redis.PubSub
|
||||
msgChan chan *Message
|
||||
logger *logrus.Entry
|
||||
wg sync.WaitGroup
|
||||
close context.CancelFunc
|
||||
channelName string
|
||||
}
|
||||
|
||||
func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) {
|
||||
pubsub := s.client.Subscribe(ctx, channel)
|
||||
_, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
|
||||
}
|
||||
subCtx, cancel := context.WithCancel(context.Background())
|
||||
sub := &redisSubscription{
|
||||
pubsub: pubsub,
|
||||
msgChan: make(chan *Message, 10),
|
||||
logger: s.logger,
|
||||
close: cancel,
|
||||
channelName: channel,
|
||||
}
|
||||
sub.wg.Add(1)
|
||||
go sub.bridge(subCtx)
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) bridge(ctx context.Context) {
|
||||
defer rs.wg.Done()
|
||||
defer close(rs.msgChan)
|
||||
redisCh := rs.pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case redisMsg, ok := <-redisCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
msg := &Message{
|
||||
Channel: redisMsg.Channel,
|
||||
Payload: []byte(redisMsg.Payload),
|
||||
}
|
||||
select {
|
||||
case rs.msgChan <- msg:
|
||||
default:
|
||||
rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) Channel() <-chan *Message {
|
||||
rs.once.Do(func() {
|
||||
rs.msgChan = make(chan *Message)
|
||||
go func() {
|
||||
defer close(rs.msgChan)
|
||||
for redisMsg := range rs.pubsub.Channel() {
|
||||
rs.msgChan <- &Message{
|
||||
Channel: redisMsg.Channel,
|
||||
Payload: []byte(redisMsg.Payload),
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
return rs.msgChan
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) Close() error { return rs.pubsub.Close() }
|
||||
|
||||
func (s *RedisStore) Publish(channel string, message []byte) error {
|
||||
return s.client.Publish(context.Background(), channel, message).Err()
|
||||
func (rs *redisSubscription) ChannelName() string {
|
||||
return rs.channelName
|
||||
}
|
||||
|
||||
func (s *RedisStore) Subscribe(channel string) (Subscription, error) {
|
||||
pubsub := s.client.Subscribe(context.Background(), channel)
|
||||
_, err := pubsub.Receive(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
|
||||
}
|
||||
return &redisSubscription{pubsub: pubsub}, nil
|
||||
func (rs *redisSubscription) Close() error {
|
||||
rs.close()
|
||||
err := rs.pubsub.Close()
|
||||
rs.wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error {
|
||||
return s.client.Publish(ctx, channel, message).Err()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user