Files
gemini-banlancer/internal/store/memory_store.go
2025-11-20 12:24:05 +08:00

643 lines
16 KiB
Go

// Filename: internal/store/memory_store.go (unified-storage refactor)
package store
import (
	"fmt"
	"math/rand"
	"sort"
	"strconv"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
)
// Compile-time assertion that *memoryStore satisfies the Store interface.
var _ Store = (*memoryStore)(nil)
// memoryStoreItem is the generic container stored under every key. value is
// one of: []byte (plain value), []string (list), map[string]string (hash),
// map[string]struct{} (set) or []zsetMember (sorted set). expireAt is the
// absolute deadline; the zero time means the item never expires.
type memoryStoreItem struct {
	value    interface{}
	expireAt time.Time
}

// isExpired reports whether the item has a deadline and that deadline has passed.
func (it *memoryStoreItem) isExpired() bool {
	if it.expireAt.IsZero() {
		return false
	}
	return time.Now().After(it.expireAt)
}
// zsetMember is one element of a sorted set. ZAdd keeps a key's []zsetMember
// ordered by ascending Score, breaking ties by ascending Value.
type zsetMember struct {
	Value string
	Score float64
}
// memoryStore is the in-process Store implementation. Every key, whatever its
// type, lives in the single items map; pubsub maps a channel name to its
// subscribers' message channels. mu guards both maps, and writers mutate the
// entries reachable from items only while holding mu for writing.
type memoryStore struct {
	items  map[string]*memoryStoreItem // pointers, so entries can be modified in place
	pubsub map[string][]chan *Message
	mu     sync.RWMutex
	logger *logrus.Entry
}
// NewMemoryStore returns an empty in-memory Store. Log lines are emitted
// through logger tagged with a "component" field naming this backend.
func NewMemoryStore(logger *logrus.Logger) Store {
	entry := logger.WithField("component", "store.memory 🗱")
	store := &memoryStore{logger: entry}
	store.items = make(map[string]*memoryStoreItem)
	store.pubsub = make(map[string][]chan *Message)
	return store
}
// getItem returns the live entry for key, or nil when the key is absent or
// its TTL has elapsed. Expired entries are removed from the map.
//
// lockForWrite selects the locking contract:
//   - true: the caller already holds s.mu for writing; getItem takes no lock.
//   - false: the caller must NOT hold s.mu. getItem read-locks for the lookup
//     and briefly write-locks to reclaim an expired entry. All locks are
//     released before returning, so a caller that then reads item.value must
//     hold s.mu (at least for reading): writers replace item.value in place.
func (s *memoryStore) getItem(key string, lockForWrite bool) *memoryStoreItem {
	if lockForWrite {
		entry, found := s.items[key]
		if found && entry.isExpired() {
			delete(s.items, key)
			return nil
		}
		return entry
	}

	s.mu.RLock()
	entry, found := s.items[key]
	expired := found && entry.isExpired()
	s.mu.RUnlock()

	switch {
	case !found:
		return nil
	case !expired:
		return entry
	}

	// Upgrade to the write lock to reclaim the entry. Re-check first: another
	// goroutine may have replaced or removed it while no lock was held.
	s.mu.Lock()
	if current, still := s.items[key]; still && current.isExpired() {
		delete(s.items, key)
	}
	s.mu.Unlock()
	return nil
}
// --- All interface methods below are implemented on top of the unified items map ---
// Set stores value under key, replacing any existing entry regardless of its
// type. A positive ttl sets an expiry; ttl <= 0 stores the key without one.
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry := &memoryStoreItem{value: value}
	if ttl > 0 {
		entry.expireAt = time.Now().Add(ttl)
	}
	s.items[key] = entry
	return nil
}
// Get returns the byte value stored under key. It returns ErrNotFound when
// the key is absent or expired, or when it holds a non-[]byte value (list,
// hash, set, ...).
func (s *memoryStore) Get(key string) ([]byte, error) {
	item := s.getItem(key, false)
	if item == nil {
		return nil, ErrNotFound
	}
	// getItem has already released the lock, and writers (HSet, LPush, ZAdd on
	// the same key) replace item.value in place under the write lock. Reading
	// it without the read lock is a data race.
	s.mu.RLock()
	defer s.mu.RUnlock()
	value, ok := item.value.([]byte)
	if !ok {
		return nil, ErrNotFound // type mismatch is reported as not found
	}
	return value, nil
}
// Del removes every listed key, whatever type it holds. Missing keys are
// ignored and the call always succeeds.
func (s *memoryStore) Del(keys ...string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.delNoLock(keys...)
	return nil
}

// delNoLock removes the given keys. The caller must already hold s.mu for
// writing (Del and the pipeline's Exec both do).
func (s *memoryStore) delNoLock(keys ...string) {
	for i := range keys {
		delete(s.items, keys[i])
	}
}
// Exists reports whether key currently holds a live (present, unexpired)
// entry of any type. It never returns an error.
func (s *memoryStore) Exists(key string) (bool, error) {
	found := s.getItem(key, false) != nil
	return found, nil
}
// SetNX writes value under key only if the key is currently absent (or has
// expired), and reports whether the write happened. A positive ttl sets an
// expiry; ttl <= 0 stores the key without one.
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.getItem(key, true) == nil {
		entry := &memoryStoreItem{value: value}
		if ttl > 0 {
			entry.expireAt = time.Now().Add(ttl)
		}
		s.items[key] = entry
		return true, nil
	}
	return false, nil
}
// Close is a no-op: the in-memory store owns no connections or files.
// Subscriber channels are not closed here.
func (s *memoryStore) Close() error {
	return nil
}
// HDel removes fields from the hash at key. Missing keys, missing fields and
// non-hash values are silently ignored; the call always succeeds.
func (s *memoryStore) HDel(key string, fields ...string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry := s.getItem(key, true)
	if entry == nil {
		return nil
	}
	hash, isHash := entry.value.(map[string]string)
	if !isHash {
		return nil
	}
	for i := range fields {
		delete(hash, fields[i])
	}
	return nil
}
// HSet merges values into the hash at key; see hSetNoLock for the details.
func (s *memoryStore) HSet(key string, values map[string]any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.hSetNoLock(key, values)
	return nil
}

// hSetNoLock stores every field of values, formatted with %v, into the hash
// at key. A missing key becomes a new hash without expiry; a key holding a
// non-hash value is converted into an empty hash first, keeping its TTL.
// The caller must hold s.mu for writing.
func (s *memoryStore) hSetNoLock(key string, values map[string]any) {
	var hash map[string]string
	if entry := s.getItem(key, true); entry == nil {
		hash = make(map[string]string, len(values))
		s.items[key] = &memoryStoreItem{value: hash}
	} else if existing, isHash := entry.value.(map[string]string); isHash {
		hash = existing
	} else {
		hash = make(map[string]string, len(values))
		entry.value = hash
	}
	for field, v := range values {
		hash[field] = fmt.Sprintf("%v", v)
	}
}
// HGetAll returns a copy of the hash stored at key. A missing or expired key,
// or one holding a non-hash value, yields an empty non-nil map.
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
	item := s.getItem(key, false)
	if item == nil {
		return make(map[string]string), nil
	}
	// HSet/HIncrBy/HDel mutate the hash in place under the write lock. Iterating
	// it without the read lock is a data race, and the Go runtime may abort with
	// "concurrent map iteration and map write".
	s.mu.RLock()
	defer s.mu.RUnlock()
	hash, ok := item.value.(map[string]string)
	if !ok {
		return make(map[string]string), nil
	}
	result := make(map[string]string, len(hash))
	for k, v := range hash {
		result[k] = v
	}
	return result, nil
}
// HIncrBy adds incr to field of the hash at key and returns the new value;
// see hIncrByNoLock for the parsing rules.
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.hIncrByNoLock(key, field, incr)
}

// hIncrByNoLock adds incr to the integer stored in field of the hash at key
// and returns the result. A missing key or field counts as 0. A key holding a
// non-hash value is replaced by an empty hash, keeping its TTL.
// If the field holds something that is not a base-10 int64 (e.g. "1.5" or
// "abc"), the hash is left untouched and an error is returned. The old code
// silently treated such values as 0, or truncated them via Sscanf.
// The caller must hold s.mu for writing.
func (s *memoryStore) hIncrByNoLock(key, field string, incr int64) (int64, error) {
	item := s.getItem(key, true)
	if item == nil {
		item = &memoryStoreItem{value: make(map[string]string)}
		s.items[key] = item
	}
	hash, ok := item.value.(map[string]string)
	if !ok {
		hash = make(map[string]string)
		item.value = hash
	}

	var current int64
	if raw, exists := hash[field]; exists {
		parsed, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("hincrby %q field %q: value %q is not an integer: %w", key, field, raw, err)
		}
		current = parsed
	}
	newVal := current + incr
	hash[field] = strconv.FormatInt(newVal, 10)
	return newVal, nil
}
// LPush inserts values at the head of the list at key; see lPushNoLock for
// the resulting element order.
func (s *memoryStore) LPush(key string, values ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.lPushNoLock(key, values...)
	return nil
}

// lPushNoLock inserts values (formatted with %v) at the head of the list at
// key with Redis LPUSH semantics: the values are pushed one after another, so
// pushing a, b, c onto an empty list yields [c b a]. The previous version
// produced [a b c], which diverged from Redis for multi-value pushes.
// A missing key becomes a new list without expiry; a key holding a non-list
// value is replaced by the new list, keeping its TTL.
// The caller must hold s.mu for writing.
func (s *memoryStore) lPushNoLock(key string, values ...any) {
	item := s.getItem(key, true)
	if item == nil {
		item = &memoryStoreItem{value: make([]string, 0)}
		s.items[key] = item
	}
	list, _ := item.value.([]string) // a non-list value is discarded

	// Build a fresh slice: the new elements in reverse argument order (the last
	// argument ends up at the head), followed by the old list.
	pushed := make([]string, len(values), len(values)+len(list))
	for i, v := range values {
		pushed[len(values)-1-i] = fmt.Sprintf("%v", v)
	}
	item.value = append(pushed, list...)
}
// LRem removes occurrences of value from the list at key; see lRemNoLock for
// how count is interpreted.
func (s *memoryStore) LRem(key string, count int64, value any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.lRemNoLock(key, count, value)
	return nil
}

// lRemNoLock removes elements equal to value (compared after %v formatting)
// from the list at key, following Redis LREM semantics:
//
//   - count > 0: remove the first count matches, scanning head to tail;
//   - count < 0: remove the last |count| matches, scanning tail to head;
//   - count == 0: remove every match.
//
// The previous version removed nothing at all for count < 0.
// Missing keys and non-list values are ignored. The caller must hold s.mu
// for writing.
func (s *memoryStore) lRemNoLock(key string, count int64, value any) {
	item := s.getItem(key, true)
	if item == nil {
		return
	}
	list, ok := item.value.([]string)
	if !ok {
		return
	}
	target := fmt.Sprintf("%v", value)

	// Mark the indices to drop. For count < 0 the scan runs from the tail, so
	// the matches closest to the tail are the ones removed.
	drop := make([]bool, len(list))
	removed := int64(0)
	for step := range list {
		i := step
		if count < 0 {
			i = len(list) - 1 - step
		}
		if list[i] != target {
			continue
		}
		// removed+count >= 0 means "removed >= |count|" without negating
		// count, which would overflow for math.MinInt64.
		if (count > 0 && removed >= count) || (count < 0 && removed+count >= 0) {
			break
		}
		drop[i] = true
		removed++
	}

	kept := make([]string, 0, len(list)-int(removed))
	for i, v := range list {
		if !drop[i] {
			kept = append(kept, v)
		}
	}
	item.value = kept
}
// SAdd adds members to the set at key; see sAddNoLock.
func (s *memoryStore) SAdd(key string, members ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.sAddNoLock(key, members...)
	return nil
}

// sAddNoLock inserts each member, formatted with %v, into the set at key.
// A missing key becomes a new set without expiry; a key holding a non-set
// value is converted into an empty set first, keeping its TTL.
// The caller must hold s.mu for writing.
func (s *memoryStore) sAddNoLock(key string, members ...any) {
	entry := s.getItem(key, true)
	if entry == nil {
		entry = &memoryStoreItem{}
		s.items[key] = entry
	}
	set, isSet := entry.value.(map[string]struct{})
	if !isSet {
		set = make(map[string]struct{}, len(members))
		entry.value = set
	}
	for i := range members {
		set[fmt.Sprintf("%v", members[i])] = struct{}{}
	}
}
// SPopN removes up to count randomly chosen members from the set at key and
// returns them. A missing or expired key, a non-set value, or an empty set
// yields an empty slice, and count == 0 pops nothing. A negative count is
// rejected with an error; previously it panicked inside make().
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
	if count < 0 {
		return nil, fmt.Errorf("spop %q: count must not be negative, got %d", key, count)
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	item := s.getItem(key, true)
	if item == nil {
		return []string{}, nil
	}
	set, ok := item.value.(map[string]struct{})
	if !ok || len(set) == 0 {
		return []string{}, nil
	}
	if int64(len(set)) < count {
		count = int64(len(set))
	}

	keys := make([]string, 0, len(set))
	for k := range set {
		keys = append(keys, k)
	}
	// Shuffle so the popped members are a uniform random sample rather than
	// whatever prefix Go's map iteration happens to produce.
	rand.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] })

	// Copy out the result so it does not pin the full keys array in memory.
	popped := make([]string, count)
	copy(popped, keys)
	for _, k := range popped {
		delete(set, k)
	}
	return popped, nil
}
// SMembers returns the members of the set at key, in unspecified order.
// A missing or expired key, or one holding a non-set value, yields an empty
// slice.
func (s *memoryStore) SMembers(key string) ([]string, error) {
	item := s.getItem(key, false)
	if item == nil {
		return []string{}, nil
	}
	// item.value can be replaced by writers under the write lock, so the type
	// assertion must be made under the read lock too, not just the iteration.
	s.mu.RLock()
	defer s.mu.RUnlock()
	set, ok := item.value.(map[string]struct{})
	if !ok {
		return []string{}, nil
	}
	members := make([]string, 0, len(set))
	for member := range set {
		members = append(members, member)
	}
	return members, nil
}
// SRem removes members from the set at key; see sRemNoLock.
func (s *memoryStore) SRem(key string, members ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.sRemNoLock(key, members...)
	return nil
}

// sRemNoLock deletes each member, formatted with %v, from the set at key.
// Missing keys and non-set values are left untouched.
// The caller must hold s.mu for writing.
func (s *memoryStore) sRemNoLock(key string, members ...any) {
	entry := s.getItem(key, true)
	if entry == nil {
		return
	}
	if set, isSet := entry.value.(map[string]struct{}); isSet {
		for i := range members {
			delete(set, fmt.Sprintf("%v", members[i]))
		}
	}
}
// SRandMember returns a uniformly random member of the set at key without
// removing it. It returns ErrNotFound for a missing or expired key, a non-set
// value, or an empty set.
func (s *memoryStore) SRandMember(key string) (string, error) {
	item := s.getItem(key, false)
	if item == nil {
		return "", ErrNotFound
	}
	// Both the type assertion and len(set) must happen under the read lock:
	// writers replace item.value and mutate the set under the write lock.
	s.mu.RLock()
	defer s.mu.RUnlock()
	set, ok := item.value.(map[string]struct{})
	if !ok || len(set) == 0 {
		return "", ErrNotFound
	}
	// Walk to a random position instead of copying every member into a slice.
	target := rand.Intn(len(set))
	for member := range set {
		if target == 0 {
			return member, nil
		}
		target--
	}
	return "", ErrNotFound // unreachable: target < len(set)
}
// Rotate moves the last element of the list at key to the head and returns
// it (the same effect as RPOPLPUSH key key). It returns ErrNotFound for a
// missing or expired key, a non-list value, or an empty list.
func (s *memoryStore) Rotate(key string) (string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry := s.getItem(key, true)
	if entry == nil {
		return "", ErrNotFound
	}
	list, isList := entry.value.([]string)
	if !isList || len(list) == 0 {
		return "", ErrNotFound
	}
	last := len(list) - 1
	tail := list[last]
	rotated := make([]string, 0, len(list))
	rotated = append(rotated, tail)
	rotated = append(rotated, list[:last]...)
	entry.value = rotated
	return tail, nil
}
// LIndex returns the element at index in the list at key. Negative indexes
// count from the tail (-1 is the last element). It returns ErrNotFound for a
// missing or expired key, a non-list value, or an out-of-range index.
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
	item := s.getItem(key, false)
	if item == nil {
		return "", ErrNotFound
	}
	// LPush/LRem/Rotate replace item.value under the write lock, so it must be
	// read under the read lock.
	s.mu.RLock()
	defer s.mu.RUnlock()
	list, ok := item.value.([]string)
	if !ok {
		return "", ErrNotFound
	}
	n := int64(len(list))
	if index < 0 {
		index += n
	}
	if index < 0 || index >= n {
		return "", ErrNotFound
	}
	return list[index], nil
}
// Sorted-set methods (ZAdd, ZRange, ZRem). A sorted set is stored as a
// []zsetMember kept ordered by (Score, Value).

// ZAdd inserts members into the sorted set at key, or updates the score of
// members that are already present, then re-sorts by ascending score with
// ties broken by ascending member value. A missing key becomes a new sorted
// set without expiry; a key holding another type is replaced, keeping its TTL.
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry := s.getItem(key, true)
	if entry == nil {
		entry = &memoryStoreItem{value: make([]zsetMember, 0)}
		s.items[key] = entry
	}
	zset, isZSet := entry.value.([]zsetMember)
	if !isZSet {
		zset = make([]zsetMember, 0)
	}

	// Index existing members by value so each update is a map lookup rather
	// than a scan of the whole set.
	position := make(map[string]int, len(zset))
	for i := range zset {
		if _, dup := position[zset[i].Value]; !dup {
			position[zset[i].Value] = i
		}
	}
	for value, score := range members {
		if i, present := position[value]; present {
			zset[i].Score = score
			continue
		}
		position[value] = len(zset)
		zset = append(zset, zsetMember{Value: value, Score: score})
	}

	sort.Slice(zset, func(a, b int) bool {
		left, right := zset[a], zset[b]
		if left.Score != right.Score {
			return left.Score < right.Score
		}
		return left.Value < right.Value
	})
	entry.value = zset
	return nil
}
// ZRange returns the members of the sorted set at key whose ranks lie in the
// inclusive range [start, stop] (0-based, ascending score order). Negative
// ranks count from the end (-1 is the last member), and out-of-range bounds
// are clamped. A missing or expired key, or one holding a non-zset value,
// yields an empty slice.
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
	item := s.getItem(key, false)
	if item == nil {
		return []string{}, nil
	}
	// ZAdd reassigns item.value and updates scores in place under the write
	// lock, so every access to the set must happen under the read lock.
	s.mu.RLock()
	defer s.mu.RUnlock()
	zset, ok := item.value.([]zsetMember)
	if !ok {
		return []string{}, nil
	}
	n := int64(len(zset))
	if start < 0 {
		start += n
	}
	if stop < 0 {
		stop += n
	}
	if start < 0 {
		start = 0
	}
	if stop >= n {
		stop = n - 1
	}
	if start > stop {
		return []string{}, nil
	}
	result := make([]string, 0, stop-start+1)
	for _, m := range zset[start : stop+1] {
		result = append(result, m.Value)
	}
	return result, nil
}
// ZRem removes the given members (compared after %v formatting) from the
// sorted set at key. Missing keys and non-zset values are ignored.
func (s *memoryStore) ZRem(key string, members ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry := s.getItem(key, true)
	if entry == nil {
		return nil
	}
	zset, isZSet := entry.value.([]zsetMember)
	if !isZSet {
		return nil
	}
	doomed := make(map[string]struct{}, len(members))
	for i := range members {
		doomed[fmt.Sprintf("%v", members[i])] = struct{}{}
	}
	remaining := make([]zsetMember, 0, len(zset))
	for _, z := range zset {
		if _, drop := doomed[z.Value]; drop {
			continue
		}
		remaining = append(remaining, z)
	}
	entry.value = remaining
	return nil
}
// PopAndCycleSetMember removes and returns one member of the set at mainKey.
// If mainKey is missing, expired, empty, or does not hold a set, the entry at
// cooldownKey is first moved to mainKey (a rename: cooldownKey disappears and
// its TTL moves with it), and the member is popped from there.
// It returns ErrNotFound when no member can be popped.
// Which member is popped depends on Go's map iteration order.
// NOTE(review): the popped member is not added to cooldownKey here; confirm
// that callers (or the Redis-backed implementation) handle re-queuing it.
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var mainSet map[string]struct{}
	if mainItem := s.getItem(mainKey, true); mainItem != nil {
		// comma-ok: the former unchecked assertion panicked when mainKey held a
		// non-set value. Such a value is now treated like an empty set.
		mainSet, _ = mainItem.value.(map[string]struct{})
	}

	if len(mainSet) == 0 {
		cooldownItem := s.getItem(cooldownKey, true)
		if cooldownItem == nil {
			return "", ErrNotFound
		}
		// "Rename" cooldownKey to mainKey by moving the entry.
		s.items[mainKey] = cooldownItem
		delete(s.items, cooldownKey)
		cooldownSet, ok := cooldownItem.value.(map[string]struct{})
		if !ok || len(cooldownSet) == 0 {
			return "", ErrNotFound
		}
		mainSet = cooldownSet
	}

	var popped string
	for member := range mainSet {
		popped = member
		break
	}
	delete(mainSet, popped)
	return popped, nil
}
// Pipeline implementation

// memoryPipeliner collects operations as closures over the store; Exec runs
// them in order while holding the store's write lock.
type memoryPipeliner struct {
	store *memoryStore
	ops   []func()
}
// Pipeline returns a new, empty command queue bound to this store.
func (s *memoryStore) Pipeline() Pipeliner {
	p := &memoryPipeliner{store: s}
	return p
}
// Exec runs every queued command, in order, while holding the store's write
// lock, so the batch is applied atomically with respect to other store
// operations. The queue is then empty: previously it was never cleared, so a
// second Exec replayed every command again.
func (p *memoryPipeliner) Exec() error {
	p.store.mu.Lock()
	defer p.store.mu.Unlock()
	ops := p.ops
	p.ops = nil
	for _, op := range ops {
		op()
	}
	return nil
}
// Expire queues a command that makes key (of any type) expire after
// expiration, measured from the moment Exec runs the command. Keys that are
// missing or already expired at that point are left alone. A non-positive
// expiration sets a deadline that has already passed, so the key is treated as
// expired from then on.
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
	op := func() {
		// Runs inside Exec, which holds the store's write lock.
		if entry := p.store.getItem(key, true); entry != nil {
			entry.expireAt = time.Now().Add(expiration)
		}
	}
	p.ops = append(p.ops, op)
}
// All other pipeliner methods. Each one queues a closure that Exec runs while
// holding the store's write lock. Slice and map arguments are copied when the
// command is queued, so a caller that reuses or mutates its buffers before
// Exec does not change what was queued. Previously the closures captured the
// caller's slices/map by reference.

// HSet queues writing a snapshot of values into the hash at key.
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
	snapshot := make(map[string]any, len(values))
	for field, v := range values {
		snapshot[field] = v
	}
	p.ops = append(p.ops, func() { p.store.hSetNoLock(key, snapshot) })
}

// HIncrBy queues a hash-field increment; its result and error are discarded.
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
	p.ops = append(p.ops, func() { _, _ = p.store.hIncrByNoLock(key, field, incr) })
}

// Del queues deletion of a snapshot of keys.
func (p *memoryPipeliner) Del(keys ...string) {
	snapshot := append([]string(nil), keys...)
	p.ops = append(p.ops, func() { p.store.delNoLock(snapshot...) })
}

// SAdd queues adding a snapshot of members to the set at key.
func (p *memoryPipeliner) SAdd(key string, members ...any) {
	snapshot := append([]any(nil), members...)
	p.ops = append(p.ops, func() { p.store.sAddNoLock(key, snapshot...) })
}

// SRem queues removing a snapshot of members from the set at key.
func (p *memoryPipeliner) SRem(key string, members ...any) {
	snapshot := append([]any(nil), members...)
	p.ops = append(p.ops, func() { p.store.sRemNoLock(key, snapshot...) })
}

// LPush queues pushing a snapshot of values onto the head of the list at key.
func (p *memoryPipeliner) LPush(key string, values ...any) {
	snapshot := append([]any(nil), values...)
	p.ops = append(p.ops, func() { p.store.lPushNoLock(key, snapshot...) })
}

// LRem queues removing occurrences of value from the list at key. Its
// arguments are scalars, so there is nothing to copy.
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
	p.ops = append(p.ops, func() { p.store.lRemNoLock(key, count, value) })
}
// Pub/Sub implementation. It is independent of the keyed data above and only
// shares the store's mutex.

// memorySubscription is the Subscription returned by Subscribe. It owns one
// message channel registered under channelName in store.pubsub.
type memorySubscription struct {
	store       *memoryStore
	channelName string
	msgChan     chan *Message
}

// Channel returns the receive-only side of the subscription's message channel.
func (ms *memorySubscription) Channel() <-chan *Message {
	return ms.msgChan
}

// Close unregisters the subscription via removeSubscriber, which also closes
// the message channel.
func (ms *memorySubscription) Close() error {
	err := ms.store.removeSubscriber(ms.channelName, ms.msgChan)
	return err
}
// Publish delivers message to every current subscriber of channel. Each send
// waits at most 100ms; a subscriber that cannot accept the message in time is
// skipped with a warning. Publishing to a channel with no subscribers is a
// no-op.
// The read lock is held for the whole fan-out, which keeps removeSubscriber
// from closing a channel mid-send. It also means slow subscribers delay every
// writer to the store.
// NOTE(review): with many stalled subscribers this holds s.mu for up to
// N x 100ms; a dedicated pub/sub lock would isolate the keyed data.
func (s *memoryStore) Publish(channel string, message []byte) error {
	s.mu.RLock()
	defer s.mu.RUnlock()
	subs := s.pubsub[channel]
	if len(subs) == 0 {
		return nil
	}
	msg := &Message{Channel: channel, Payload: message}
	for i := range subs {
		select {
		case subs[i] <- msg:
		case <-time.After(100 * time.Millisecond):
			s.logger.Warnf("Could not publish to a subscriber on channel '%s' within 100ms", channel)
		}
	}
	return nil
}
// Subscribe registers a new subscriber on channel and returns its
// Subscription. The subscriber's message channel buffers up to 10 messages.
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
	ch := make(chan *Message, 10)
	s.mu.Lock()
	defer s.mu.Unlock()
	s.pubsub[channel] = append(s.pubsub[channel], ch)
	return &memorySubscription{store: s, channelName: channel, msgChan: ch}, nil
}
// removeSubscriber unregisters msgChan from channelName and closes it. The
// channel's entry is deleted once its last subscriber is gone.
// It is idempotent. A channel that is not (or no longer) registered is
// neither closed nor reported as an error. Previously it was closed
// unconditionally, so a second Close panicked with "close of closed channel"
// whenever other subscribers remained on the same channel.
func (s *memoryStore) removeSubscriber(channelName string, msgChan chan *Message) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	subscribers, ok := s.pubsub[channelName]
	if !ok {
		return nil
	}
	remaining := make([]chan *Message, 0, len(subscribers))
	found := false
	for _, ch := range subscribers {
		if ch == msgChan {
			found = true
			continue
		}
		remaining = append(remaining, ch)
	}
	if !found {
		return nil
	}
	if len(remaining) == 0 {
		delete(s.pubsub, channelName)
	} else {
		s.pubsub[channelName] = remaining
	}
	// Closing under the write lock guarantees that no concurrent Publish
	// (which sends while holding the read lock) can send on the closed channel.
	close(msgChan)
	return nil
}