// Filename: internal/store/memory_store.go package store import ( "context" "fmt" "math/rand" "sort" "sync" "time" "github.com/sirupsen/logrus" ) // ensure memoryStore implements Store interface var _ Store = (*memoryStore)(nil) type memoryStoreItem struct { value interface{} expireAt time.Time } func (item *memoryStoreItem) isExpired() bool { return !item.expireAt.IsZero() && time.Now().After(item.expireAt) } type zsetMember struct { Value string Score float64 } type memoryStore struct { items map[string]*memoryStoreItem pubsub map[string][]chan *Message mu sync.RWMutex rng *rand.Rand rngMu sync.Mutex logger *logrus.Entry } func NewMemoryStore(logger *logrus.Logger) Store { store := &memoryStore{ items: make(map[string]*memoryStoreItem), pubsub: make(map[string][]chan *Message), rng: rand.New(rand.NewSource(time.Now().UnixNano())), logger: logger.WithField("component", "store.memory 🗱"), } go store.startGCollector() return store } func (s *memoryStore) startGCollector() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { s.mu.Lock() now := time.Now() for key, item := range s.items { if !item.expireAt.IsZero() && now.After(item.expireAt) { delete(s.items, key) } } s.mu.Unlock() } } // --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 --- // --- 内存实现可以忽略该参数,用 _ 接收 --- func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error { s.mu.Lock() defer s.mu.Unlock() var expireAt time.Time if ttl > 0 { expireAt = time.Now().Add(ttl) } s.items[key] = &memoryStoreItem{value: value, expireAt: expireAt} return nil } func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] if !ok || item.isExpired() { return nil, ErrNotFound } if value, ok := item.value.([]byte); ok { return value, nil } return nil, ErrNotFound } func (s *memoryStore) Del(_ context.Context, keys ...string) error { s.mu.Lock() defer s.mu.Unlock() for _, key := range keys { delete(s.items, key) } return nil } func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] return ok && !item.isExpired(), nil } func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if ok && !item.isExpired() { return false, nil } var expireAt time.Time if ttl > 0 { expireAt = time.Now().Add(ttl) } s.items[key] = &memoryStoreItem{value: value, expireAt: expireAt} return true, nil } func (s *memoryStore) Close() error { return nil } func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { return nil } if hash, ok := item.value.(map[string]string); ok { for _, field := range fields { delete(hash, field) } } return nil } func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { item = &memoryStoreItem{value: make(map[string]string)} s.items[key] = item } hash, ok := item.value.(map[string]string) if !ok { hash = make(map[string]string) item.value = hash } for field, value := range values { hash[field] = fmt.Sprintf("%v", value) } return nil } func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] if !ok || item.isExpired() { return make(map[string]string), nil } if hash, ok := item.value.(map[string]string); ok { result := make(map[string]string, len(hash)) for k, v := range hash { result[k] = v } return result, nil } return make(map[string]string), nil } func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { item = &memoryStoreItem{value: make(map[string]string)} s.items[key] = item } hash, ok := item.value.(map[string]string) if !ok { hash = make(map[string]string) item.value = hash } var currentVal int64 if valStr, ok := hash[field]; ok { fmt.Sscanf(valStr, "%d", ¤tVal) } newVal := currentVal + incr hash[field] = fmt.Sprintf("%d", newVal) return newVal, nil } func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { item = &memoryStoreItem{value: make([]string, 0)} s.items[key] = item } list, ok := item.value.([]string) if !ok { list = make([]string, 0) } stringValues := make([]string, len(values)) for i, v := range values { stringValues[i] = fmt.Sprintf("%v", v) } item.value = append(stringValues, list...) return nil } func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { return nil } list, ok := item.value.([]string) if !ok { return nil } valToRemove := fmt.Sprintf("%v", value) newList := make([]string, 0, len(list)) removedCount := int64(0) for _, v := range list { if v == valToRemove && (count == 0 || removedCount < count) { removedCount++ } else { newList = append(newList, v) } } item.value = newList return nil } func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { item = &memoryStoreItem{value: make(map[string]struct{})} s.items[key] = item } set, ok := item.value.(map[string]struct{}) if !ok { set = make(map[string]struct{}) item.value = set } for _, member := range members { set[fmt.Sprintf("%v", member)] = struct{}{} } return nil } func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { return []string{}, nil } set, ok := item.value.(map[string]struct{}) if !ok || len(set) == 0 { return []string{}, nil } if int64(len(set)) < count { count = int64(len(set)) } popped := make([]string, 0, count) keys := make([]string, 0, len(set)) for k := range set { keys = append(keys, k) } s.rngMu.Lock() s.rng.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] }) s.rngMu.Unlock() for i := int64(0); i < count; i++ { poppedKey := keys[i] popped = append(popped, poppedKey) delete(set, poppedKey) } return popped, nil } func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] if !ok || item.isExpired() { return []string{}, nil } set, ok := item.value.(map[string]struct{}) if !ok { return []string{}, nil } members := make([]string, 0, len(set)) for member := range set { members = append(members, member) } return members, nil } func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { return nil } set, ok := item.value.(map[string]struct{}) if !ok { return nil } for _, member := range members { delete(set, fmt.Sprintf("%v", member)) } return nil } func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] if !ok || item.isExpired() { return "", ErrNotFound } set, ok := item.value.(map[string]struct{}) if !ok || len(set) == 0 { return "", ErrNotFound } members := make([]string, 0, len(set)) for member := range set { members = append(members, member) } if len(members) == 0 { return "", ErrNotFound } s.rngMu.Lock() n := s.rng.Intn(len(members)) s.rngMu.Unlock() return members[n], nil } func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { return "", ErrNotFound } list, ok := item.value.([]string) if !ok || len(list) == 0 { return "", ErrNotFound } val := list[len(list)-1] item.value = append([]string{val}, list[:len(list)-1]...) return val, nil } func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] if !ok || item.isExpired() { return "", ErrNotFound } list, ok := item.value.([]string) if !ok { return "", ErrNotFound } l := int64(len(list)) if index < 0 { index += l } if index < 0 || index >= l { return "", ErrNotFound } return list[index], nil } func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { item = &memoryStoreItem{value: make([]zsetMember, 0)} s.items[key] = item } zset, ok := item.value.([]zsetMember) if !ok { zset = make([]zsetMember, 0) } membersMap := make(map[string]float64, len(zset)) for _, z := range zset { membersMap[z.Value] = z.Score } for memberVal, score := range members { membersMap[memberVal] = score } newZSet := make([]zsetMember, 0, len(membersMap)) for val, score := range membersMap { newZSet = append(newZSet, zsetMember{Value: val, Score: score}) } sort.Slice(newZSet, func(i, j int) bool { if newZSet[i].Score == newZSet[j].Score { return newZSet[i].Value < newZSet[j].Value } return newZSet[i].Score < newZSet[j].Score }) item.value = newZSet return nil } func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) { s.mu.RLock() defer s.mu.RUnlock() item, ok := s.items[key] if !ok || item.isExpired() { return []string{}, nil } zset, ok := item.value.([]zsetMember) if !ok { return []string{}, nil } l := int64(len(zset)) if start < 0 { start += l } if stop < 0 { stop += l } if start < 0 { start = 0 } if start > stop || start >= l { return []string{}, nil } if stop >= l { stop = l - 1 } result := make([]string, 0, stop-start+1) for i := start; i <= stop; i++ { result = append(result, zset[i].Value) } return result, nil } func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error { s.mu.Lock() defer s.mu.Unlock() item, ok := s.items[key] if !ok || item.isExpired() { return nil } zset, ok := item.value.([]zsetMember) if !ok { return nil } membersToRemove := make(map[string]struct{}, len(members)) for _, m := range members { membersToRemove[fmt.Sprintf("%v", m)] = struct{}{} } newZSet := make([]zsetMember, 0, len(zset)) for _, z := range zset { if _, exists := membersToRemove[z.Value]; !exists { newZSet = append(newZSet, z) } } item.value = newZSet return nil } func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) { s.mu.Lock() defer s.mu.Unlock() mainItem, mainOk := s.items[mainKey] var mainSet map[string]struct{} if mainOk && !mainItem.isExpired() { mainSet, mainOk = mainItem.value.(map[string]struct{}) mainOk = mainOk && len(mainSet) > 0 } else { mainOk = false } if !mainOk { cooldownItem, cooldownOk := s.items[cooldownKey] if !cooldownOk || cooldownItem.isExpired() { return "", ErrNotFound } cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{}) if !cooldownSetOk || len(cooldownSet) == 0 { return "", ErrNotFound } s.items[mainKey] = cooldownItem delete(s.items, cooldownKey) mainSet = cooldownSet } var popped string for k := range mainSet { popped = k break } delete(mainSet, popped) cooldownItem, cooldownOk := s.items[cooldownKey] if !cooldownOk || cooldownItem.isExpired() { cooldownItem = &memoryStoreItem{value: make(map[string]struct{})} s.items[cooldownKey] = cooldownItem } cooldownSet, ok := cooldownItem.value.(map[string]struct{}) if !ok { cooldownSet = make(map[string]struct{}) cooldownItem.value = cooldownSet } cooldownSet[popped] = struct{}{} return popped, nil } type memoryPipeliner struct { store *memoryStore ops []func() } func (s *memoryStore) Pipeline(_ context.Context) Pipeliner { return &memoryPipeliner{store: s} } func (p *memoryPipeliner) Exec() error { p.store.mu.Lock() defer p.store.mu.Unlock() for _, op := range p.ops { op() } return nil } func (p *memoryPipeliner) Expire(key string, expiration time.Duration) { capturedKey := key p.ops = append(p.ops, func() { if item, ok := p.store.items[capturedKey]; ok { item.expireAt = time.Now().Add(expiration) } }) } func (p *memoryPipeliner) Del(keys ...string) { capturedKeys := make([]string, len(keys)) copy(capturedKeys, keys) p.ops = append(p.ops, func() { for _, key := range capturedKeys { delete(p.store.items, key) } }) } func (p *memoryPipeliner) SAdd(key string, members ...any) { capturedKey := key capturedMembers := make([]any, len(members)) copy(capturedMembers, members) p.ops = append(p.ops, func() { item, ok := p.store.items[capturedKey] if !ok || item.isExpired() { item = &memoryStoreItem{value: make(map[string]struct{})} p.store.items[capturedKey] = item } set, ok := item.value.(map[string]struct{}) if !ok { set = make(map[string]struct{}) item.value = set } for _, member := range capturedMembers { set[fmt.Sprintf("%v", member)] = struct{}{} } }) } func (p *memoryPipeliner) SRem(key string, members ...any) { capturedKey := key capturedMembers := make([]any, len(members)) copy(capturedMembers, members) p.ops = append(p.ops, func() { item, ok := p.store.items[capturedKey] if !ok || item.isExpired() { return } set, ok := item.value.(map[string]struct{}) if !ok { return } for _, member := range capturedMembers { delete(set, fmt.Sprintf("%v", member)) } }) } func (p *memoryPipeliner) LPush(key string, values ...any) { capturedKey := key capturedValues := make([]any, len(values)) copy(capturedValues, values) p.ops = append(p.ops, func() { item, ok := p.store.items[capturedKey] if !ok || item.isExpired() { item = &memoryStoreItem{value: make([]string, 0)} p.store.items[capturedKey] = item } list, ok := item.value.([]string) if !ok { list = make([]string, 0) } stringValues := make([]string, len(capturedValues)) for i, v := range capturedValues { stringValues[i] = fmt.Sprintf("%v", v) } item.value = append(stringValues, list...) }) } func (p *memoryPipeliner) LRem(key string, count int64, value any) {} func (p *memoryPipeliner) HSet(key string, values map[string]any) {} func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {} func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {} func (p *memoryPipeliner) ZRem(key string, members ...any) {} type memorySubscription struct { store *memoryStore channelName string msgChan chan *Message } func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan } func (ms *memorySubscription) ChannelName() string { return ms.channelName } func (ms *memorySubscription) Close() error { return ms.store.removeSubscriber(ms.channelName, ms.msgChan) } func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error { s.mu.RLock() defer s.mu.RUnlock() subscribers, ok := s.pubsub[channel] if !ok { return nil } msg := &Message{Channel: channel, Payload: message} for _, ch := range subscribers { select { case ch <- msg: case <-time.After(100 * time.Millisecond): s.logger.Warnf("Could not publish to a subscriber on channel '%s' within 100ms", channel) } } return nil } func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) { s.mu.Lock() defer s.mu.Unlock() msgChan := make(chan *Message, 10) sub := &memorySubscription{store: s, channelName: channel, msgChan: msgChan} s.pubsub[channel] = append(s.pubsub[channel], msgChan) return sub, nil } func (s *memoryStore) removeSubscriber(channelName string, msgChan chan *Message) error { s.mu.Lock() defer s.mu.Unlock() subscribers, ok := s.pubsub[channelName] if !ok { return nil } newSubscribers := make([]chan *Message, 0) for _, ch := range subscribers { if ch != msgChan { newSubscribers = append(newSubscribers, ch) } } if len(newSubscribers) == 0 { delete(s.pubsub, channelName) } else { s.pubsub[channelName] = newSubscribers } close(msgChan) return nil }