// Filename: internal/store/memory_store.go (统一存储重构版) package store import ( "fmt" "math/rand" "sort" "sync" "time" "github.com/sirupsen/logrus" ) // ensure memoryStore implements Store interface var _ Store = (*memoryStore)(nil) // [核心重构] memoryStoreItem 现在是通用容器,可以存储任何类型的值,并自带过期时间 type memoryStoreItem struct { value interface{} // 可以是 []byte, []string, map[string]string, map[string]struct{}, []zsetMember expireAt time.Time } // isExpired 检查一个条目是否已过期 func (item *memoryStoreItem) isExpired() bool { return !item.expireAt.IsZero() && time.Now().After(item.expireAt) } // zsetMember 保持不变 type zsetMember struct { Value string Score float64 } // [核心重构] memoryStore 现在使用一个统一的 map 来存储所有数据 type memoryStore struct { items map[string]*memoryStoreItem // 指向 item 的指针,以便原地修改 pubsub map[string][]chan *Message mu sync.RWMutex logger *logrus.Entry } // NewMemoryStore [核心重構] 構造函數也被簡化了 func NewMemoryStore(logger *logrus.Logger) Store { return &memoryStore{ items: make(map[string]*memoryStoreItem), pubsub: make(map[string][]chan *Message), logger: logger.WithField("component", "store.memory 🗱"), } } // [核心重构] getItem 是一个新的内部辅助函数,它封装了获取、检查过期和删除的通用逻辑 func (s *memoryStore) getItem(key string, lockForWrite bool) *memoryStoreItem { if !lockForWrite { // 如果是读操作,先用读锁检查 s.mu.RLock() item, ok := s.items[key] if !ok || item.isExpired() { s.mu.RUnlock() // 如果不存在或已过期,需要尝试获取写锁来删除它 if ok { // 只有在确定 item 存在但已过期时才需要删除 s.mu.Lock() // 再次检查,防止在获取写锁期间状态已改变 if item, ok := s.items[key]; ok && item.isExpired() { delete(s.items, key) } s.mu.Unlock() } return nil // 无论如何都返回 nil } s.mu.RUnlock() return item } // 对于写操作,直接使用写锁 item, ok := s.items[key] if ok && item.isExpired() { delete(s.items, key) return nil } return item } // --- 所有接口方法现在都基于新的统一结构重写 --- func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error { s.mu.Lock() defer s.mu.Unlock() var expireAt time.Time if ttl > 0 { expireAt = time.Now().Add(ttl) } s.items[key] = &memoryStoreItem{value: value, expireAt: expireAt} return nil } func (s *memoryStore) Get(key string) ([]byte, error) { item := s.getItem(key, false) if item == nil { return nil, ErrNotFound } if value, ok := item.value.([]byte); ok { return value, nil } return nil, ErrNotFound // Type mismatch, treat as not found } func (s *memoryStore) Del(keys ...string) error { s.mu.Lock() defer s.mu.Unlock() for _, key := range keys { delete(s.items, key) } return nil } func (s *memoryStore) delNoLock(keys ...string) { for _, key := range keys { delete(s.items, key) } } func (s *memoryStore) Exists(key string) (bool, error) { return s.getItem(key, false) != nil, nil } func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) { s.mu.Lock() defer s.mu.Unlock() if item := s.getItem(key, true); item != nil { return false, nil } var expireAt time.Time if ttl > 0 { expireAt = time.Now().Add(ttl) } s.items[key] = &memoryStoreItem{value: value, expireAt: expireAt} return true, nil } func (s *memoryStore) Close() error { return nil } func (s *memoryStore) HDel(key string, fields ...string) error { s.mu.Lock() defer s.mu.Unlock() item := s.getItem(key, true) if item == nil { return nil } if hash, ok := item.value.(map[string]string); ok { for _, field := range fields { delete(hash, field) } } return nil } func (s *memoryStore) HSet(key string, values map[string]any) error { s.mu.Lock() defer s.mu.Unlock() s.hSetNoLock(key, values) return nil } func (s *memoryStore) hSetNoLock(key string, values map[string]any) { item := s.getItem(key, true) if item == nil { item = &memoryStoreItem{value: make(map[string]string)} s.items[key] = item } hash, ok := item.value.(map[string]string) if !ok { // If key exists but is not a hash, create a new hash hash = make(map[string]string) item.value = hash } for field, value := range values { hash[field] = fmt.Sprintf("%v", value) } } func (s *memoryStore) HGetAll(key string) (map[string]string, error) { item := s.getItem(key, false) if item == nil { return make(map[string]string), nil } if hash, ok := item.value.(map[string]string); ok { result := make(map[string]string, len(hash)) for k, v := range hash { result[k] = v } return result, nil } return make(map[string]string), nil } func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) { s.mu.Lock() defer s.mu.Unlock() return s.hIncrByNoLock(key, field, incr) } func (s *memoryStore) hIncrByNoLock(key, field string, incr int64) (int64, error) { item := s.getItem(key, true) if item == nil { item = &memoryStoreItem{value: make(map[string]string)} s.items[key] = item } hash, ok := item.value.(map[string]string) if !ok { hash = make(map[string]string) item.value = hash } var currentVal int64 if valStr, ok := hash[field]; ok { fmt.Sscanf(valStr, "%d", ¤tVal) } newVal := currentVal + incr hash[field] = fmt.Sprintf("%d", newVal) return newVal, nil } func (s *memoryStore) LPush(key string, values ...any) error { s.mu.Lock() defer s.mu.Unlock() s.lPushNoLock(key, values...) return nil } func (s *memoryStore) lPushNoLock(key string, values ...any) { item := s.getItem(key, true) if item == nil { item = &memoryStoreItem{value: make([]string, 0)} s.items[key] = item } list, ok := item.value.([]string) if !ok { list = make([]string, 0) } stringValues := make([]string, len(values)) for i, v := range values { stringValues[i] = fmt.Sprintf("%v", v) } item.value = append(stringValues, list...) } func (s *memoryStore) LRem(key string, count int64, value any) error { s.mu.Lock() defer s.mu.Unlock() s.lRemNoLock(key, count, value) return nil } func (s *memoryStore) lRemNoLock(key string, count int64, value any) { item := s.getItem(key, true) if item == nil { return } list, ok := item.value.([]string) if !ok { return } valToRemove := fmt.Sprintf("%v", value) newList := make([]string, 0, len(list)) removedCount := int64(0) for _, v := range list { if v == valToRemove && (count == 0 || removedCount < count) { removedCount++ } else { newList = append(newList, v) } } item.value = newList } func (s *memoryStore) SAdd(key string, members ...any) error { s.mu.Lock() defer s.mu.Unlock() s.sAddNoLock(key, members...) return nil } func (s *memoryStore) sAddNoLock(key string, members ...any) { item := s.getItem(key, true) if item == nil { item = &memoryStoreItem{value: make(map[string]struct{})} s.items[key] = item } set, ok := item.value.(map[string]struct{}) if !ok { set = make(map[string]struct{}) item.value = set } for _, member := range members { set[fmt.Sprintf("%v", member)] = struct{}{} } } func (s *memoryStore) SPopN(key string, count int64) ([]string, error) { s.mu.Lock() defer s.mu.Unlock() item := s.getItem(key, true) if item == nil { return []string{}, nil } set, ok := item.value.(map[string]struct{}) if !ok || len(set) == 0 { return []string{}, nil } if int64(len(set)) < count { count = int64(len(set)) } popped := make([]string, 0, count) keys := make([]string, 0, len(set)) for k := range set { keys = append(keys, k) } rand.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] }) for i := int64(0); i < count; i++ { poppedKey := keys[i] popped = append(popped, poppedKey) delete(set, poppedKey) } return popped, nil } func (s *memoryStore) SMembers(key string) ([]string, error) { item := s.getItem(key, false) if item == nil { return []string{}, nil } set, ok := item.value.(map[string]struct{}) if !ok { return []string{}, nil } s.mu.RLock() // Lock needed for iterating map defer s.mu.RUnlock() members := make([]string, 0, len(set)) for member := range set { members = append(members, member) } return members, nil } func (s *memoryStore) SRem(key string, members ...any) error { s.mu.Lock() defer s.mu.Unlock() s.sRemNoLock(key, members...) return nil } func (s *memoryStore) sRemNoLock(key string, members ...any) { item := s.getItem(key, true) if item == nil { return } set, ok := item.value.(map[string]struct{}) if !ok { return } for _, member := range members { delete(set, fmt.Sprintf("%v", member)) } } func (s *memoryStore) SRandMember(key string) (string, error) { item := s.getItem(key, false) if item == nil { return "", ErrNotFound } set, ok := item.value.(map[string]struct{}) if !ok || len(set) == 0 { return "", ErrNotFound } s.mu.RLock() defer s.mu.RUnlock() members := make([]string, 0, len(set)) for member := range set { members = append(members, member) } if len(members) == 0 { return "", ErrNotFound } return members[rand.Intn(len(members))], nil } func (s *memoryStore) Rotate(key string) (string, error) { s.mu.Lock() defer s.mu.Unlock() item := s.getItem(key, true) if item == nil { return "", ErrNotFound } list, ok := item.value.([]string) if !ok || len(list) == 0 { return "", ErrNotFound } val := list[len(list)-1] item.value = append([]string{val}, list[:len(list)-1]...) return val, nil } func (s *memoryStore) LIndex(key string, index int64) (string, error) { item := s.getItem(key, false) if item == nil { return "", ErrNotFound } list, ok := item.value.([]string) if !ok { return "", ErrNotFound } s.mu.RLock() defer s.mu.RUnlock() l := int64(len(list)) if index < 0 { index += l } if index < 0 || index >= l { return "", ErrNotFound } return list[index], nil } // Zset methods... (ZAdd, ZRange, ZRem) func (s *memoryStore) ZAdd(key string, members map[string]float64) error { s.mu.Lock() defer s.mu.Unlock() item := s.getItem(key, true) if item == nil { item = &memoryStoreItem{value: make([]zsetMember, 0)} s.items[key] = item } zset, ok := item.value.([]zsetMember) if !ok { zset = make([]zsetMember, 0) } for memberVal, score := range members { found := false for i := range zset { if zset[i].Value == memberVal { zset[i].Score = score found = true break } } if !found { zset = append(zset, zsetMember{Value: memberVal, Score: score}) } } sort.Slice(zset, func(i, j int) bool { if zset[i].Score == zset[j].Score { return zset[i].Value < zset[j].Value } return zset[i].Score < zset[j].Score }) item.value = zset return nil } func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) { item := s.getItem(key, false) if item == nil { return []string{}, nil } zset, ok := item.value.([]zsetMember) if !ok { return []string{}, nil } s.mu.RLock() defer s.mu.RUnlock() l := int64(len(zset)) if start < 0 { start += l } if stop < 0 { stop += l } if start < 0 { start = 0 } if start > stop || start >= l { return []string{}, nil } if stop >= l { stop = l - 1 } result := make([]string, 0, stop-start+1) for i := start; i <= stop; i++ { result = append(result, zset[i].Value) } return result, nil } func (s *memoryStore) ZRem(key string, members ...any) error { s.mu.Lock() defer s.mu.Unlock() item := s.getItem(key, true) if item == nil { return nil } zset, ok := item.value.([]zsetMember) if !ok { return nil } membersToRemove := make(map[string]struct{}, len(members)) for _, m := range members { membersToRemove[fmt.Sprintf("%v", m)] = struct{}{} } newZSet := make([]zsetMember, 0, len(zset)) for _, z := range zset { if _, exists := membersToRemove[z.Value]; !exists { newZSet = append(newZSet, z) } } item.value = newZSet return nil } func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) { s.mu.Lock() defer s.mu.Unlock() mainItem := s.getItem(mainKey, true) if mainItem == nil || len(mainItem.value.(map[string]struct{})) == 0 { cooldownItem := s.getItem(cooldownKey, true) if cooldownItem == nil { return "", ErrNotFound } // "Rename" by moving value and deleting old key s.items[mainKey] = cooldownItem delete(s.items, cooldownKey) mainItem = cooldownItem } mainSet, ok := mainItem.value.(map[string]struct{}) if !ok || len(mainSet) == 0 { return "", ErrNotFound // Should not happen after cycle logic } var popped string for k := range mainSet { popped = k break } delete(mainSet, popped) return popped, nil } // Pipeline implementation type memoryPipeliner struct { store *memoryStore ops []func() } func (s *memoryStore) Pipeline() Pipeliner { return &memoryPipeliner{store: s} } func (p *memoryPipeliner) Exec() error { p.store.mu.Lock() defer p.store.mu.Unlock() for _, op := range p.ops { op() } return nil } // [核心修正] Expire 现在可以正确地为任何 key 设置过期时间 func (p *memoryPipeliner) Expire(key string, expiration time.Duration) { p.ops = append(p.ops, func() { // This must be called within Exec's lock item := p.store.getItem(key, true) if item != nil { item.expireAt = time.Now().Add(expiration) } }) } // All other pipeliner methods... func (p *memoryPipeliner) HSet(key string, values map[string]any) { p.ops = append(p.ops, func() { p.store.hSetNoLock(key, values) }) } func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) { p.ops = append(p.ops, func() { p.store.hIncrByNoLock(key, field, incr) }) } func (p *memoryPipeliner) Del(keys ...string) { p.ops = append(p.ops, func() { p.store.delNoLock(keys...) }) } func (p *memoryPipeliner) SAdd(key string, members ...any) { p.ops = append(p.ops, func() { p.store.sAddNoLock(key, members...) }) } func (p *memoryPipeliner) SRem(key string, members ...any) { p.ops = append(p.ops, func() { p.store.sRemNoLock(key, members...) }) } func (p *memoryPipeliner) LPush(key string, values ...any) { p.ops = append(p.ops, func() { p.store.lPushNoLock(key, values...) }) } func (p *memoryPipeliner) LRem(key string, count int64, value any) { p.ops = append(p.ops, func() { p.store.lRemNoLock(key, count, value) }) } // Pub/Sub implementation (remains unchanged as it's a separate system) type memorySubscription struct { store *memoryStore channelName string msgChan chan *Message } func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan } func (ms *memorySubscription) Close() error { return ms.store.removeSubscriber(ms.channelName, ms.msgChan) } func (s *memoryStore) Publish(channel string, message []byte) error { s.mu.RLock() defer s.mu.RUnlock() subscribers, ok := s.pubsub[channel] if !ok { return nil } msg := &Message{Channel: channel, Payload: message} for _, ch := range subscribers { select { case ch <- msg: case <-time.After(100 * time.Millisecond): s.logger.Warnf("Could not publish to a subscriber on channel '%s' within 100ms", channel) } } return nil } func (s *memoryStore) Subscribe(channel string) (Subscription, error) { s.mu.Lock() defer s.mu.Unlock() msgChan := make(chan *Message, 10) sub := &memorySubscription{store: s, channelName: channel, msgChan: msgChan} s.pubsub[channel] = append(s.pubsub[channel], msgChan) return sub, nil } func (s *memoryStore) removeSubscriber(channelName string, msgChan chan *Message) error { s.mu.Lock() defer s.mu.Unlock() subscribers, ok := s.pubsub[channelName] if !ok { return nil } newSubscribers := make([]chan *Message, 0) for _, ch := range subscribers { if ch != msgChan { newSubscribers = append(newSubscribers, ch) } } if len(newSubscribers) == 0 { delete(s.pubsub, channelName) } else { s.pubsub[channelName] = newSubscribers } close(msgChan) return nil }