New
This commit is contained in:
30
internal/store/factory.go
Normal file
30
internal/store/factory.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewStore creates a new store based on the application configuration.
|
||||
func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
|
||||
// 检查是否有Redis配置
|
||||
if cfg.Redis.DSN != "" {
|
||||
opts, err := redis.ParseURL(cfg.Redis.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse redis DSN: %w", err)
|
||||
}
|
||||
client := redis.NewClient(opts)
|
||||
if err := client.Ping(context.Background()).Err(); err != nil {
|
||||
logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err)
|
||||
return NewMemoryStore(logger), nil // 连接失败,也回退到内存模式,但不返回错误
|
||||
}
|
||||
logger.Info("Successfully connected to Redis. Using Redis as store.")
|
||||
return NewRedisStore(client), nil
|
||||
}
|
||||
logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.")
|
||||
return NewMemoryStore(logger), nil
|
||||
}
|
||||
642
internal/store/memory_store.go
Normal file
642
internal/store/memory_store.go
Normal file
@@ -0,0 +1,642 @@
|
||||
// Filename: internal/store/memory_store.go (统一存储重构版)
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ensure memoryStore implements Store interface
|
||||
var _ Store = (*memoryStore)(nil)
|
||||
|
||||
// [核心重构] memoryStoreItem 现在是通用容器,可以存储任何类型的值,并自带过期时间
|
||||
type memoryStoreItem struct {
|
||||
value interface{} // 可以是 []byte, []string, map[string]string, map[string]struct{}, []zsetMember
|
||||
expireAt time.Time
|
||||
}
|
||||
|
||||
// isExpired 检查一个条目是否已过期
|
||||
func (item *memoryStoreItem) isExpired() bool {
|
||||
return !item.expireAt.IsZero() && time.Now().After(item.expireAt)
|
||||
}
|
||||
|
||||
// zsetMember 保持不变
|
||||
type zsetMember struct {
|
||||
Value string
|
||||
Score float64
|
||||
}
|
||||
|
||||
// [核心重构] memoryStore 现在使用一个统一的 map 来存储所有数据
|
||||
type memoryStore struct {
|
||||
items map[string]*memoryStoreItem // 指向 item 的指针,以便原地修改
|
||||
pubsub map[string][]chan *Message
|
||||
mu sync.RWMutex
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewMemoryStore [核心重構] 構造函數也被簡化了
|
||||
func NewMemoryStore(logger *logrus.Logger) Store {
|
||||
return &memoryStore{
|
||||
items: make(map[string]*memoryStoreItem),
|
||||
pubsub: make(map[string][]chan *Message),
|
||||
logger: logger.WithField("component", "store.memory 🗱"),
|
||||
}
|
||||
}
|
||||
|
||||
// [核心重构] getItem 是一个新的内部辅助函数,它封装了获取、检查过期和删除的通用逻辑
|
||||
func (s *memoryStore) getItem(key string, lockForWrite bool) *memoryStoreItem {
|
||||
if !lockForWrite {
|
||||
// 如果是读操作,先用读锁检查
|
||||
s.mu.RLock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
s.mu.RUnlock()
|
||||
// 如果不存在或已过期,需要尝试获取写锁来删除它
|
||||
if ok { // 只有在确定 item 存在但已过期时才需要删除
|
||||
s.mu.Lock()
|
||||
// 再次检查,防止在获取写锁期间状态已改变
|
||||
if item, ok := s.items[key]; ok && item.isExpired() {
|
||||
delete(s.items, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
return nil // 无论如何都返回 nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
return item
|
||||
}
|
||||
|
||||
// 对于写操作,直接使用写锁
|
||||
item, ok := s.items[key]
|
||||
if ok && item.isExpired() {
|
||||
delete(s.items, key)
|
||||
return nil
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
// --- 所有接口方法现在都基于新的统一结构重写 ---
|
||||
|
||||
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var expireAt time.Time
|
||||
if ttl > 0 {
|
||||
expireAt = time.Now().Add(ttl)
|
||||
}
|
||||
s.items[key] = &memoryStoreItem{value: value, expireAt: expireAt}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Get(key string) ([]byte, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if value, ok := item.value.([]byte); ok {
|
||||
return value, nil
|
||||
}
|
||||
return nil, ErrNotFound // Type mismatch, treat as not found
|
||||
}
|
||||
|
||||
func (s *memoryStore) Del(keys ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, key := range keys {
|
||||
delete(s.items, key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) delNoLock(keys ...string) {
|
||||
for _, key := range keys {
|
||||
delete(s.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memoryStore) Exists(key string) (bool, error) {
|
||||
return s.getItem(key, false) != nil, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if item := s.getItem(key, true); item != nil {
|
||||
return false, nil
|
||||
}
|
||||
var expireAt time.Time
|
||||
if ttl > 0 {
|
||||
expireAt = time.Now().Add(ttl)
|
||||
}
|
||||
s.items[key] = &memoryStoreItem{value: value, expireAt: expireAt}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Close() error { return nil }
|
||||
|
||||
func (s *memoryStore) HDel(key string, fields ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
return nil
|
||||
}
|
||||
if hash, ok := item.value.(map[string]string); ok {
|
||||
for _, field := range fields {
|
||||
delete(hash, field)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HSet(key string, values map[string]any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.hSetNoLock(key, values)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) hSetNoLock(key string, values map[string]any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
s.items[key] = item
|
||||
}
|
||||
hash, ok := item.value.(map[string]string)
|
||||
if !ok { // If key exists but is not a hash, create a new hash
|
||||
hash = make(map[string]string)
|
||||
item.value = hash
|
||||
}
|
||||
for field, value := range values {
|
||||
hash[field] = fmt.Sprintf("%v", value)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
if hash, ok := item.value.(map[string]string); ok {
|
||||
result := make(map[string]string, len(hash))
|
||||
for k, v := range hash {
|
||||
result[k] = v
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.hIncrByNoLock(key, field, incr)
|
||||
}
|
||||
|
||||
func (s *memoryStore) hIncrByNoLock(key, field string, incr int64) (int64, error) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
s.items[key] = item
|
||||
}
|
||||
hash, ok := item.value.(map[string]string)
|
||||
if !ok {
|
||||
hash = make(map[string]string)
|
||||
item.value = hash
|
||||
}
|
||||
var currentVal int64
|
||||
if valStr, ok := hash[field]; ok {
|
||||
fmt.Sscanf(valStr, "%d", ¤tVal)
|
||||
}
|
||||
newVal := currentVal + incr
|
||||
hash[field] = fmt.Sprintf("%d", newVal)
|
||||
return newVal, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LPush(key string, values ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.lPushNoLock(key, values...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) lPushNoLock(key string, values ...any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item = &memoryStoreItem{value: make([]string, 0)}
|
||||
s.items[key] = item
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok {
|
||||
list = make([]string, 0)
|
||||
}
|
||||
stringValues := make([]string, len(values))
|
||||
for i, v := range values {
|
||||
stringValues[i] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
item.value = append(stringValues, list...)
|
||||
}
|
||||
|
||||
func (s *memoryStore) LRem(key string, count int64, value any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.lRemNoLock(key, count, value)
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) lRemNoLock(key string, count int64, value any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
return
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
valToRemove := fmt.Sprintf("%v", value)
|
||||
newList := make([]string, 0, len(list))
|
||||
removedCount := int64(0)
|
||||
for _, v := range list {
|
||||
if v == valToRemove && (count == 0 || removedCount < count) {
|
||||
removedCount++
|
||||
} else {
|
||||
newList = append(newList, v)
|
||||
}
|
||||
}
|
||||
item.value = newList
|
||||
}
|
||||
func (s *memoryStore) SAdd(key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sAddNoLock(key, members...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) sAddNoLock(key string, members ...any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
s.items[key] = item
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok {
|
||||
set = make(map[string]struct{})
|
||||
item.value = set
|
||||
}
|
||||
for _, member := range members {
|
||||
set[fmt.Sprintf("%v", member)] = struct{}{}
|
||||
}
|
||||
}
|
||||
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
return []string{}, nil
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return []string{}, nil
|
||||
}
|
||||
if int64(len(set)) < count {
|
||||
count = int64(len(set))
|
||||
}
|
||||
popped := make([]string, 0, count)
|
||||
keys := make([]string, 0, len(set))
|
||||
for k := range set {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
rand.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] })
|
||||
for i := int64(0); i < count; i++ {
|
||||
poppedKey := keys[i]
|
||||
popped = append(popped, poppedKey)
|
||||
delete(set, poppedKey)
|
||||
}
|
||||
return popped, nil
|
||||
}
|
||||
func (s *memoryStore) SMembers(key string) ([]string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
return []string{}, nil
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok {
|
||||
return []string{}, nil
|
||||
}
|
||||
s.mu.RLock() // Lock needed for iterating map
|
||||
defer s.mu.RUnlock()
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
return members, nil
|
||||
}
|
||||
func (s *memoryStore) SRem(key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sRemNoLock(key, members...)
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) sRemNoLock(key string, members ...any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
return
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, member := range members {
|
||||
delete(set, fmt.Sprintf("%v", member))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memoryStore) SRandMember(key string) (string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return members[rand.Intn(len(members))], nil
|
||||
}
|
||||
func (s *memoryStore) Rotate(key string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok || len(list) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
val := list[len(list)-1]
|
||||
item.value = append([]string{val}, list[:len(list)-1]...)
|
||||
return val, nil
|
||||
}
|
||||
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
l := int64(len(list))
|
||||
if index < 0 {
|
||||
index += l
|
||||
}
|
||||
if index < 0 || index >= l {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return list[index], nil
|
||||
}
|
||||
|
||||
// Zset methods... (ZAdd, ZRange, ZRem)
|
||||
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item = &memoryStoreItem{value: make([]zsetMember, 0)}
|
||||
s.items[key] = item
|
||||
}
|
||||
zset, ok := item.value.([]zsetMember)
|
||||
if !ok {
|
||||
zset = make([]zsetMember, 0)
|
||||
}
|
||||
for memberVal, score := range members {
|
||||
found := false
|
||||
for i := range zset {
|
||||
if zset[i].Value == memberVal {
|
||||
zset[i].Score = score
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
zset = append(zset, zsetMember{Value: memberVal, Score: score})
|
||||
}
|
||||
}
|
||||
sort.Slice(zset, func(i, j int) bool {
|
||||
if zset[i].Score == zset[j].Score {
|
||||
return zset[i].Value < zset[j].Value
|
||||
}
|
||||
return zset[i].Score < zset[j].Score
|
||||
})
|
||||
item.value = zset
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
return []string{}, nil
|
||||
}
|
||||
zset, ok := item.value.([]zsetMember)
|
||||
if !ok {
|
||||
return []string{}, nil
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
l := int64(len(zset))
|
||||
if start < 0 {
|
||||
start += l
|
||||
}
|
||||
if stop < 0 {
|
||||
stop += l
|
||||
}
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
if start > stop || start >= l {
|
||||
return []string{}, nil
|
||||
}
|
||||
if stop >= l {
|
||||
stop = l - 1
|
||||
}
|
||||
result := make([]string, 0, stop-start+1)
|
||||
for i := start; i <= stop; i++ {
|
||||
result = append(result, zset[i].Value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (s *memoryStore) ZRem(key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
return nil
|
||||
}
|
||||
zset, ok := item.value.([]zsetMember)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
membersToRemove := make(map[string]struct{}, len(members))
|
||||
for _, m := range members {
|
||||
membersToRemove[fmt.Sprintf("%v", m)] = struct{}{}
|
||||
}
|
||||
newZSet := make([]zsetMember, 0, len(zset))
|
||||
for _, z := range zset {
|
||||
if _, exists := membersToRemove[z.Value]; !exists {
|
||||
newZSet = append(newZSet, z)
|
||||
}
|
||||
}
|
||||
item.value = newZSet
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
mainItem := s.getItem(mainKey, true)
|
||||
if mainItem == nil || len(mainItem.value.(map[string]struct{})) == 0 {
|
||||
cooldownItem := s.getItem(cooldownKey, true)
|
||||
if cooldownItem == nil {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
// "Rename" by moving value and deleting old key
|
||||
s.items[mainKey] = cooldownItem
|
||||
delete(s.items, cooldownKey)
|
||||
mainItem = cooldownItem
|
||||
}
|
||||
mainSet, ok := mainItem.value.(map[string]struct{})
|
||||
if !ok || len(mainSet) == 0 {
|
||||
return "", ErrNotFound // Should not happen after cycle logic
|
||||
}
|
||||
var popped string
|
||||
for k := range mainSet {
|
||||
popped = k
|
||||
break
|
||||
}
|
||||
delete(mainSet, popped)
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
// Pipeline implementation
|
||||
type memoryPipeliner struct {
|
||||
store *memoryStore
|
||||
ops []func()
|
||||
}
|
||||
|
||||
func (s *memoryStore) Pipeline() Pipeliner {
|
||||
return &memoryPipeliner{store: s}
|
||||
}
|
||||
func (p *memoryPipeliner) Exec() error {
|
||||
p.store.mu.Lock()
|
||||
defer p.store.mu.Unlock()
|
||||
for _, op := range p.ops {
|
||||
op()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// [核心修正] Expire 现在可以正确地为任何 key 设置过期时间
|
||||
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
|
||||
p.ops = append(p.ops, func() {
|
||||
// This must be called within Exec's lock
|
||||
item := p.store.getItem(key, true)
|
||||
if item != nil {
|
||||
item.expireAt = time.Now().Add(expiration)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// All other pipeliner methods...
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
|
||||
p.ops = append(p.ops, func() { p.store.hSetNoLock(key, values) })
|
||||
}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
p.ops = append(p.ops, func() { p.store.hIncrByNoLock(key, field, incr) })
|
||||
}
|
||||
func (p *memoryPipeliner) Del(keys ...string) {
|
||||
p.ops = append(p.ops, func() { p.store.delNoLock(keys...) })
|
||||
}
|
||||
func (p *memoryPipeliner) SAdd(key string, members ...any) {
|
||||
p.ops = append(p.ops, func() { p.store.sAddNoLock(key, members...) })
|
||||
}
|
||||
func (p *memoryPipeliner) SRem(key string, members ...any) {
|
||||
p.ops = append(p.ops, func() { p.store.sRemNoLock(key, members...) })
|
||||
}
|
||||
func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
p.ops = append(p.ops, func() { p.store.lPushNoLock(key, values...) })
|
||||
}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
|
||||
p.ops = append(p.ops, func() { p.store.lRemNoLock(key, count, value) })
|
||||
}
|
||||
|
||||
// Pub/Sub implementation (remains unchanged as it's a separate system)
|
||||
type memorySubscription struct {
|
||||
store *memoryStore
|
||||
channelName string
|
||||
msgChan chan *Message
|
||||
}
|
||||
|
||||
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
|
||||
func (ms *memorySubscription) Close() error {
|
||||
return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
|
||||
}
|
||||
func (s *memoryStore) Publish(channel string, message []byte) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
subscribers, ok := s.pubsub[channel]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
msg := &Message{Channel: channel, Payload: message}
|
||||
for _, ch := range subscribers {
|
||||
select {
|
||||
case ch <- msg:
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
s.logger.Warnf("Could not publish to a subscriber on channel '%s' within 100ms", channel)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
msgChan := make(chan *Message, 10)
|
||||
sub := &memorySubscription{store: s, channelName: channel, msgChan: msgChan}
|
||||
s.pubsub[channel] = append(s.pubsub[channel], msgChan)
|
||||
return sub, nil
|
||||
}
|
||||
func (s *memoryStore) removeSubscriber(channelName string, msgChan chan *Message) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
subscribers, ok := s.pubsub[channelName]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
newSubscribers := make([]chan *Message, 0)
|
||||
for _, ch := range subscribers {
|
||||
if ch != msgChan {
|
||||
newSubscribers = append(newSubscribers, ch)
|
||||
}
|
||||
}
|
||||
if len(newSubscribers) == 0 {
|
||||
delete(s.pubsub, channelName)
|
||||
} else {
|
||||
s.pubsub[channelName] = newSubscribers
|
||||
}
|
||||
close(msgChan)
|
||||
return nil
|
||||
}
|
||||
271
internal/store/redis_store.go
Normal file
271
internal/store/redis_store.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ensure RedisStore implements Store interface
|
||||
var _ Store = (*RedisStore)(nil)
|
||||
|
||||
// RedisStore is a Redis-backed key-value store.
|
||||
type RedisStore struct {
|
||||
client *redis.Client
|
||||
popAndCycleScript *redis.Script
|
||||
}
|
||||
|
||||
// NewRedisStore creates a new RedisStore instance.
|
||||
func NewRedisStore(client *redis.Client) Store {
|
||||
// Lua script for atomic pop-and-cycle operation.
|
||||
// KEYS[1]: main set key
|
||||
// KEYS[2]: cooldown set key
|
||||
const script = `
|
||||
if redis.call('SCARD', KEYS[1]) == 0 then
|
||||
if redis.call('SCARD', KEYS[2]) == 0 then
|
||||
return nil
|
||||
end
|
||||
redis.call('RENAME', KEYS[2], KEYS[1])
|
||||
end
|
||||
return redis.call('SPOP', KEYS[1])
|
||||
`
|
||||
return &RedisStore{
|
||||
client: client,
|
||||
popAndCycleScript: redis.NewScript(script),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
return s.client.Set(context.Background(), key, value, ttl).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Get(key string) ([]byte, error) {
|
||||
val, err := s.client.Get(context.Background(), key).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) Del(keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.Del(context.Background(), keys...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Exists(key string) (bool, error) {
|
||||
val, err := s.client.Exists(context.Background(), key).Result()
|
||||
return val > 0, err
|
||||
}
|
||||
|
||||
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
return s.client.SetNX(context.Background(), key, value, ttl).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Close() error {
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HSet(key string, values map[string]any) error {
|
||||
return s.client.HSet(context.Background(), key, values).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
|
||||
return s.client.HGetAll(context.Background(), key).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
return s.client.HIncrBy(context.Background(), key, field, incr).Result()
|
||||
}
|
||||
func (s *RedisStore) HDel(key string, fields ...string) error {
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.HDel(context.Background(), key, fields...).Err()
|
||||
}
|
||||
func (s *RedisStore) LPush(key string, values ...any) error {
|
||||
return s.client.LPush(context.Background(), key, values...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) LRem(key string, count int64, value any) error {
|
||||
return s.client.LRem(context.Background(), key, count, value).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Rotate(key string) (string, error) {
|
||||
val, err := s.client.RPopLPush(context.Background(), key, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) SAdd(key string, members ...any) error {
|
||||
return s.client.SAdd(context.Background(), key, members...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SPopN(key string, count int64) ([]string, error) {
|
||||
return s.client.SPopN(context.Background(), key, count).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SMembers(key string) ([]string, error) {
|
||||
return s.client.SMembers(context.Background(), key).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRem(key string, members ...any) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.SRem(context.Background(), key, members...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRandMember(key string) (string, error) {
|
||||
member, err := s.client.SRandMember(context.Background(), key).Result()
|
||||
if err != nil {
|
||||
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return member, nil
|
||||
}
|
||||
|
||||
// === 新增方法实现 ===
|
||||
|
||||
func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
redisMembers := make([]redis.Z, 0, len(members))
|
||||
for member, score := range members {
|
||||
redisMembers = append(redisMembers, redis.Z{Score: score, Member: member})
|
||||
}
|
||||
return s.client.ZAdd(context.Background(), key, redisMembers...).Err()
|
||||
}
|
||||
func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
return s.client.ZRange(context.Background(), key, start, stop).Result()
|
||||
}
|
||||
func (s *RedisStore) ZRem(key string, members ...any) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.ZRem(context.Background(), key, members...).Err()
|
||||
}
|
||||
func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
// Lua script returns a string, so we need to type assert
|
||||
if str, ok := val.(string); ok {
|
||||
return str, nil
|
||||
}
|
||||
return "", ErrNotFound // This happens if both sets were empty and the script returned nil
|
||||
}
|
||||
|
||||
type redisPipeliner struct{ pipe redis.Pipeliner }
|
||||
|
||||
func (p *redisPipeliner) HSet(key string, values map[string]any) {
|
||||
p.pipe.HSet(context.Background(), key, values)
|
||||
}
|
||||
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
p.pipe.HIncrBy(context.Background(), key, field, incr)
|
||||
}
|
||||
func (p *redisPipeliner) Exec() error {
|
||||
_, err := p.pipe.Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Del(keys ...string) {
|
||||
if len(keys) > 0 {
|
||||
p.pipe.Del(context.Background(), keys...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) SAdd(key string, members ...any) {
|
||||
p.pipe.SAdd(context.Background(), key, members...)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) SRem(key string, members ...any) {
|
||||
if len(members) > 0 {
|
||||
p.pipe.SRem(context.Background(), key, members...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) LPush(key string, values ...any) {
|
||||
p.pipe.LPush(context.Background(), key, values...)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) LRem(key string, count int64, value any) {
|
||||
p.pipe.LRem(context.Background(), key, count, value)
|
||||
}
|
||||
|
||||
func (s *RedisStore) LIndex(key string, index int64) (string, error) {
|
||||
val, err := s.client.LIndex(context.Background(), key, index).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
|
||||
p.pipe.Expire(context.Background(), key, expiration)
|
||||
}
|
||||
|
||||
func (s *RedisStore) Pipeline() Pipeliner {
|
||||
return &redisPipeliner{pipe: s.client.Pipeline()}
|
||||
}
|
||||
|
||||
type redisSubscription struct {
|
||||
pubsub *redis.PubSub
|
||||
msgChan chan *Message
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) Channel() <-chan *Message {
|
||||
rs.once.Do(func() {
|
||||
rs.msgChan = make(chan *Message)
|
||||
go func() {
|
||||
defer close(rs.msgChan)
|
||||
for redisMsg := range rs.pubsub.Channel() {
|
||||
rs.msgChan <- &Message{
|
||||
Channel: redisMsg.Channel,
|
||||
Payload: []byte(redisMsg.Payload),
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
return rs.msgChan
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) Close() error { return rs.pubsub.Close() }
|
||||
|
||||
func (s *RedisStore) Publish(channel string, message []byte) error {
|
||||
return s.client.Publish(context.Background(), channel, message).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Subscribe(channel string) (Subscription, error) {
|
||||
pubsub := s.client.Subscribe(context.Background(), channel)
|
||||
_, err := pubsub.Receive(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
|
||||
}
|
||||
return &redisSubscription{pubsub: pubsub}, nil
|
||||
}
|
||||
88
internal/store/store.go
Normal file
88
internal/store/store.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned when a key is not found in the store.
|
||||
var ErrNotFound = errors.New("key not found")
|
||||
|
||||
// Message is the struct for received pub/sub messages.
|
||||
type Message struct {
|
||||
Channel string
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
// Subscription represents an active subscription to a pub/sub channel.
|
||||
type Subscription interface {
|
||||
Channel() <-chan *Message
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Pipeliner defines an interface for executing a batch of commands.
|
||||
type Pipeliner interface {
|
||||
// General
|
||||
Del(keys ...string)
|
||||
Expire(key string, expiration time.Duration)
|
||||
|
||||
// HASH
|
||||
HSet(key string, values map[string]any)
|
||||
HIncrBy(key, field string, incr int64)
|
||||
|
||||
// SET
|
||||
SAdd(key string, members ...any)
|
||||
SRem(key string, members ...any)
|
||||
|
||||
// LIST
|
||||
LPush(key string, values ...any)
|
||||
LRem(key string, count int64, value any)
|
||||
|
||||
// Execution
|
||||
Exec() error
|
||||
}
|
||||
|
||||
// Store is the master interface for our cache service.
|
||||
type Store interface {
|
||||
// Basic K/V operations
|
||||
Set(key string, value []byte, ttl time.Duration) error
|
||||
Get(key string) ([]byte, error)
|
||||
Del(keys ...string) error
|
||||
Exists(key string) (bool, error)
|
||||
SetNX(key string, value []byte, ttl time.Duration) (bool, error)
|
||||
|
||||
// HASH operations
|
||||
HSet(key string, values map[string]any) error
|
||||
HGetAll(key string) (map[string]string, error)
|
||||
HIncrBy(key, field string, incr int64) (int64, error)
|
||||
HDel(key string, fields ...string) error // [新增]
|
||||
|
||||
// LIST operations
|
||||
LPush(key string, values ...any) error
|
||||
LRem(key string, count int64, value any) error
|
||||
Rotate(key string) (string, error)
|
||||
LIndex(key string, index int64) (string, error)
|
||||
|
||||
// SET operations
|
||||
SAdd(key string, members ...any) error
|
||||
SPopN(key string, count int64) ([]string, error)
|
||||
SMembers(key string) ([]string, error)
|
||||
SRem(key string, members ...any) error
|
||||
SRandMember(key string) (string, error)
|
||||
|
||||
// Pub/Sub operations
|
||||
Publish(channel string, message []byte) error
|
||||
Subscribe(channel string) (Subscription, error)
|
||||
|
||||
// Pipeline (optional) - 我们在redis实现它,内存版暂时不实现
|
||||
Pipeline() Pipeliner
|
||||
|
||||
// Close closes the store and releases any underlying resources.
|
||||
Close() error
|
||||
|
||||
// === 新增方法,支持轮询策略 ===
|
||||
ZAdd(key string, members map[string]float64) error
|
||||
ZRange(key string, start, stop int64) ([]string, error)
|
||||
ZRem(key string, members ...any) error
|
||||
PopAndCycleSetMember(mainKey, cooldownKey string) (string, error)
|
||||
}
|
||||
Reference in New Issue
Block a user