Files
gemini-banlancer/internal/store/memory_store.go
2025-11-22 14:20:05 +08:00

681 lines
16 KiB
Go

// Filename: internal/store/memory_store.go
package store
import (
"context"
"fmt"
"math/rand"
"sort"
"sync"
"time"
"github.com/sirupsen/logrus"
)
var (
	// Compile-time proof that *memoryStore implements Store: the build fails
	// here if any interface method is missing or has the wrong signature.
	_ Store = (*memoryStore)(nil)
)
// memoryStoreItem is one entry in the keyspace. value holds whichever type the
// writing command stores: []byte (Set/SetNX), map[string]string (hashes),
// []string (lists), map[string]struct{} (sets) or []zsetMember (sorted sets).
type memoryStoreItem struct {
	value    any
	expireAt time.Time // zero means the entry never expires
}

// isExpired reports whether the entry carries a deadline that has already
// passed. Entries without a deadline never expire.
func (it *memoryStoreItem) isExpired() bool {
	if it.expireAt.IsZero() {
		return false
	}
	return time.Now().After(it.expireAt)
}
// zsetMember is one entry of a sorted set. ZAdd stores a key's members as a
// []zsetMember ordered by ascending Score, breaking ties by Value.
type zsetMember struct {
Value string // the member name
Score float64 // the member's score, the primary sort key
}
// memoryStore is an in-process implementation of Store. It keeps Redis-like
// strings, hashes, lists, sets and sorted sets in a single map; every access
// to items and pubsub is guarded by mu.
type memoryStore struct {
	items  map[string]*memoryStoreItem
	pubsub map[string][]chan *Message
	mu     sync.RWMutex

	// A *rand.Rand built with rand.New is not safe for concurrent use, so rng
	// is guarded by its own rngMu.
	rng    *rand.Rand
	rngMu  sync.Mutex
	logger *logrus.Entry

	// stopGC is closed, at most once via closeOnce, by Close to stop the
	// background expiry sweeper launched by NewMemoryStore.
	stopGC    chan struct{}
	closeOnce sync.Once
}

// NewMemoryStore returns an empty in-memory Store and starts a goroutine that
// purges expired keys every five minutes. Call Close to stop that goroutine
// once the store is no longer used.
func NewMemoryStore(logger *logrus.Logger) Store {
	store := &memoryStore{
		items:  make(map[string]*memoryStoreItem),
		pubsub: make(map[string][]chan *Message),
		rng:    rand.New(rand.NewSource(time.Now().UnixNano())),
		logger: logger.WithField("component", "store.memory 🗱"),
		stopGC: make(chan struct{}),
	}
	go store.startGCollector()
	return store
}

// startGCollector sweeps expired entries every five minutes until stopGC is
// closed. Reads already treat expired entries as absent; the sweep only frees
// memory held by keys that are never touched again.
func (s *memoryStore) startGCollector() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-s.stopGC:
			return
		case <-ticker.C:
			s.purgeExpired()
		}
	}
}

// purgeExpired deletes every entry whose deadline has already passed.
func (s *memoryStore) purgeExpired() {
	s.mu.Lock()
	defer s.mu.Unlock()
	now := time.Now()
	for key, item := range s.items {
		if !item.expireAt.IsZero() && now.After(item.expireAt) {
			delete(s.items, key)
		}
	}
}

// ttlDeadline converts a TTL into an absolute expiry time. ttl <= 0 yields the
// zero time, which isExpired treats as "never expires".
func ttlDeadline(ttl time.Duration) time.Time {
	if ttl <= 0 {
		return time.Time{}
	}
	return time.Now().Add(ttl)
}

// copyStoredBytes returns an independent copy of b (nil stays nil). String
// values are copied on the way in and on the way out so that neither the
// caller nor the store can mutate the other's bytes outside mu, the same
// isolation a networked backend gives naturally.
func copyStoredBytes(b []byte) []byte {
	if b == nil {
		return nil
	}
	c := make([]byte, len(b))
	copy(c, b)
	return c
}

// All methods take a context.Context so that memoryStore satisfies the Store
// interface; the in-memory implementation has nothing to cancel and ignores
// the argument (receiving it as _).

// Set stores a copy of value under key, replacing any existing entry of any
// type. A positive ttl sets an expiry; ttl <= 0 stores the key without one.
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.items[key] = &memoryStoreItem{value: copyStoredBytes(value), expireAt: ttlDeadline(ttl)}
	return nil
}

// Get returns a copy of the bytes stored under key. ErrNotFound is returned
// when the key is missing, expired, or holds a non-[]byte value (hash, list,
// set or sorted set).
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		return nil, ErrNotFound
	}
	value, ok := item.value.([]byte)
	if !ok {
		return nil, ErrNotFound
	}
	return copyStoredBytes(value), nil
}

// Del removes every listed key regardless of its type. Missing keys are
// ignored; the error is always nil.
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, key := range keys {
		delete(s.items, key)
	}
	return nil
}

// Exists reports whether key is present and not expired, whatever its type.
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	item, ok := s.items[key]
	return ok && !item.isExpired(), nil
}

// SetNX stores a copy of value under key only if no live entry exists there,
// reporting whether the write happened. An expired entry counts as absent.
// ttl follows the same rules as Set.
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if item, ok := s.items[key]; ok && !item.isExpired() {
		return false, nil
	}
	s.items[key] = &memoryStoreItem{value: copyStoredBytes(value), expireAt: ttlDeadline(ttl)}
	return true, nil
}

// Close stops the background expiry sweeper. It is safe to call more than
// once, and the stored data remains usable afterwards. It always returns nil.
func (s *memoryStore) Close() error {
	s.closeOnce.Do(func() {
		if s.stopGC != nil {
			close(s.stopGC)
		}
	})
	return nil
}
// HDel removes the given fields from the hash at key. Keys that are missing,
// expired, or hold a non-hash value are left untouched; the error is always nil.
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if entry, found := s.items[key]; found && !entry.isExpired() {
		if h, isHash := entry.value.(map[string]string); isHash {
			for _, f := range fields {
				delete(h, f)
			}
		}
	}
	return nil
}
// HSet writes each field/value pair into the hash at key, formatting values
// with %v.
//
// A missing or expired key starts as a fresh hash with no expiry. A live key
// holding some other type has its value replaced by a new hash, and its
// existing expiry (if any) is kept.
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var h map[string]string
	entry, found := s.items[key]
	switch {
	case !found || entry.isExpired():
		h = make(map[string]string, len(values))
		s.items[key] = &memoryStoreItem{value: h}
	default:
		existing, isHash := entry.value.(map[string]string)
		if !isHash {
			existing = make(map[string]string, len(values))
			entry.value = existing
		}
		h = existing
	}

	for f, v := range values {
		h[f] = fmt.Sprintf("%v", v)
	}
	return nil
}
// HGetAll returns a copy of every field/value pair in the hash at key. A
// missing, expired or non-hash key yields an empty, non-nil map; the error is
// always nil.
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	var source map[string]string
	if entry, found := s.items[key]; found && !entry.isExpired() {
		source, _ = entry.value.(map[string]string)
	}
	snapshot := make(map[string]string, len(source))
	for field, value := range source {
		snapshot[field] = value
	}
	return snapshot, nil
}
// HIncrBy adds incr to the integer stored in field of the hash at key and
// returns the new value, mirroring Redis HINCRBY. A missing field counts as 0.
// A missing or expired key starts as a fresh hash with no expiry; a live key
// holding some other type has its value replaced by a new hash.
//
// It returns an error, leaving the hash unchanged, when the existing field is
// not a canonical base-10 int64 or when the addition would overflow int64.
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		item = &memoryStoreItem{value: make(map[string]string)}
		s.items[key] = item
	}
	hash, ok := item.value.(map[string]string)
	if !ok {
		hash = make(map[string]string)
		item.value = hash
	}

	var currentVal int64
	if valStr, exists := hash[field]; exists {
		// Sscanf alone accepts input such as " 12" or "12abc" and leaves 0 on
		// failure. The round-trip comparison accepts only the exact decimal form
		// this method writes; Redis likewise rejects such values.
		if _, err := fmt.Sscanf(valStr, "%d", &currentVal); err != nil || fmt.Sprint(currentVal) != valStr {
			return 0, fmt.Errorf("hincrby %q field %q: hash value %q is not an integer", key, field, valStr)
		}
	}

	newVal := currentVal + incr
	// Signed overflow wraps silently in Go; detect it from the sign of incr.
	if (incr > 0 && newVal < currentVal) || (incr < 0 && newVal > currentVal) {
		return 0, fmt.Errorf("hincrby %q field %q: increment would overflow", key, field)
	}
	hash[field] = fmt.Sprintf("%d", newVal)
	return newVal, nil
}
// LPush inserts values, each formatted with %v, at the head of the list at key.
// Like Redis LPUSH, the values are pushed one after another, so the last
// argument ends up first: LPush(k, "a", "b") on an empty key yields [b a].
// A missing or expired key starts as a new list with no expiry; a live key
// holding some other type has its value replaced by the new list.
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		item = &memoryStoreItem{value: make([]string, 0)}
		s.items[key] = item
	}
	list, _ := item.value.([]string) // a non-list value is discarded
	newList := make([]string, 0, len(values)+len(list))
	for i := len(values) - 1; i >= 0; i-- {
		newList = append(newList, fmt.Sprintf("%v", values[i]))
	}
	item.value = append(newList, list...)
	return nil
}
// LRem removes occurrences of value (formatted with %v) from the list at key,
// following Redis LREM semantics:
//   - count > 0: remove up to count matches, scanning from head to tail;
//   - count < 0: remove up to -count matches, scanning from tail to head;
//   - count == 0: remove every match.
//
// Keys that are missing, expired, or hold a non-list value are left untouched.
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		return nil
	}
	list, ok := item.value.([]string)
	if !ok {
		return nil
	}
	target := fmt.Sprintf("%v", value)

	// maxRemovals caps how many matches are dropped; len(list) means "all".
	maxRemovals := len(list)
	abs := count
	if abs < 0 {
		abs = -abs // -math.MinInt64 wraps to itself and falls through to "all"
	}
	if abs > 0 && abs < int64(len(list)) {
		maxRemovals = int(abs)
	}

	drop := make([]bool, len(list))
	removed := 0
	for step := 0; step < len(list) && removed < maxRemovals; step++ {
		i := step
		if count < 0 {
			i = len(list) - 1 - step // a negative count scans from the tail
		}
		if list[i] == target {
			drop[i] = true
			removed++
		}
	}

	kept := make([]string, 0, len(list)-removed)
	for i, v := range list {
		if !drop[i] {
			kept = append(kept, v)
		}
	}
	item.value = kept
	return nil
}
// SAdd inserts every member (formatted with %v) into the set at key. A missing
// or expired key starts as a new set with no expiry; a live key holding some
// other type has its value replaced by the new set.
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var set map[string]struct{}
	if entry, found := s.items[key]; found && !entry.isExpired() {
		existing, isSet := entry.value.(map[string]struct{})
		if !isSet {
			existing = make(map[string]struct{})
			entry.value = existing
		}
		set = existing
	} else {
		set = make(map[string]struct{})
		s.items[key] = &memoryStoreItem{value: set}
	}

	for _, m := range members {
		set[fmt.Sprintf("%v", m)] = struct{}{}
	}
	return nil
}
// SPopN removes and returns up to count members chosen at random from the set
// at key, in random order. It returns an empty slice when count <= 0 or when
// the key is missing, expired, not a set, or empty.
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
	// A negative count is not a valid slice size, and zero pops nothing, so
	// both are answered before touching the store.
	if count <= 0 {
		return []string{}, nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		return []string{}, nil
	}
	set, ok := item.value.(map[string]struct{})
	if !ok || len(set) == 0 {
		return []string{}, nil
	}

	n := len(set)
	if count < int64(n) {
		n = int(count)
	}
	keys := make([]string, 0, len(set))
	for k := range set {
		keys = append(keys, k)
	}

	// Partial Fisher-Yates shuffle: only the first n slots need random picks.
	s.rngMu.Lock()
	for i := 0; i < n; i++ {
		j := i + s.rng.Intn(len(keys)-i)
		keys[i], keys[j] = keys[j], keys[i]
	}
	s.rngMu.Unlock()

	// Copy out so the result does not pin the full keys backing array.
	popped := make([]string, n)
	copy(popped, keys[:n])
	for _, k := range popped {
		delete(set, k)
	}
	return popped, nil
}
// SMembers returns every member of the set at key in unspecified order. A
// missing, expired or non-set key yields an empty, non-nil slice; the error is
// always nil.
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	var set map[string]struct{}
	if entry, found := s.items[key]; found && !entry.isExpired() {
		set, _ = entry.value.(map[string]struct{})
	}
	out := make([]string, 0, len(set))
	for m := range set {
		out = append(out, m)
	}
	return out, nil
}
// SRem removes the given members (formatted with %v) from the set at key.
// Keys that are missing, expired, or hold a non-set value are ignored; the
// error is always nil.
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, found := s.items[key]
	if !found || entry.isExpired() {
		return nil
	}
	if set, isSet := entry.value.(map[string]struct{}); isSet {
		for _, m := range members {
			delete(set, fmt.Sprintf("%v", m))
		}
	}
	return nil
}
// SRandMember returns one member of the set at key, chosen with the store's
// RNG, without removing it. ErrNotFound is returned when the key is missing,
// expired, not a set, or empty.
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	entry, found := s.items[key]
	if !found || entry.isExpired() {
		return "", ErrNotFound
	}
	set, isSet := entry.value.(map[string]struct{})
	if !isSet || len(set) == 0 {
		return "", ErrNotFound
	}

	s.rngMu.Lock()
	target := s.rng.Intn(len(set))
	s.rngMu.Unlock()

	// Walk the set to the chosen position instead of copying every member
	// into a temporary slice first.
	pos := 0
	for m := range set {
		if pos == target {
			return m, nil
		}
		pos++
	}
	return "", ErrNotFound // unreachable: target < len(set)
}
// Rotate moves the last element of the list at key to the head and returns it,
// which is the effect of RPOPLPUSH with the same key as source and
// destination. ErrNotFound is returned when the key is missing, expired, not a
// list, or empty.
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, found := s.items[key]
	if !found || entry.isExpired() {
		return "", ErrNotFound
	}
	list, isList := entry.value.([]string)
	if !isList || len(list) == 0 {
		return "", ErrNotFound
	}
	last := len(list) - 1
	tail := list[last]
	rotated := make([]string, 0, len(list))
	rotated = append(rotated, tail)
	rotated = append(rotated, list[:last]...)
	entry.value = rotated
	return tail, nil
}
// LIndex returns the element at index in the list at key. A negative index
// counts from the tail (-1 is the last element). ErrNotFound is returned for a
// missing, expired or non-list key and for an out-of-range index.
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	entry, found := s.items[key]
	if !found || entry.isExpired() {
		return "", ErrNotFound
	}
	list, isList := entry.value.([]string)
	if !isList {
		return "", ErrNotFound
	}
	size := int64(len(list))
	pos := index
	if pos < 0 {
		pos = size + pos
	}
	if pos >= 0 && pos < size {
		return list[pos], nil
	}
	return "", ErrNotFound
}
// ZAdd sets the score of each given member in the sorted set at key, adding
// members that are not present yet. The set is stored as a slice ordered by
// ascending score, with ties broken by member string.
//
// A missing or expired key starts as a new set with no expiry; a live key
// holding some other type has its value replaced.
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, found := s.items[key]
	if !found || entry.isExpired() {
		entry = &memoryStoreItem{value: make([]zsetMember, 0)}
		s.items[key] = entry
	}
	current, _ := entry.value.([]zsetMember)

	scores := make(map[string]float64, len(current)+len(members))
	for _, zm := range current {
		scores[zm.Value] = zm.Score
	}
	for member, score := range members {
		scores[member] = score
	}

	merged := make([]zsetMember, 0, len(scores))
	for member, score := range scores {
		merged = append(merged, zsetMember{Value: member, Score: score})
	}
	sort.Slice(merged, func(a, b int) bool {
		if merged[a].Score != merged[b].Score {
			return merged[a].Score < merged[b].Score
		}
		return merged[a].Value < merged[b].Value
	})
	entry.value = merged
	return nil
}
// ZRange returns the members of the sorted set at key whose zero-based ranks
// fall between start and stop, inclusive, in ascending score order.
//
// Negative ranks count from the end (-1 is the last member), and the range is
// clamped to the set. An empty slice is returned when the range is empty or
// when the key is missing, expired or not a sorted set.
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	entry, found := s.items[key]
	if !found || entry.isExpired() {
		return []string{}, nil
	}
	zset, isZSet := entry.value.([]zsetMember)
	if !isZSet {
		return []string{}, nil
	}

	size := int64(len(zset))
	from, to := start, stop
	if from < 0 {
		from += size
	}
	if to < 0 {
		to += size
	}
	if from < 0 {
		from = 0
	}
	if to >= size {
		to = size - 1
	}
	if from > to {
		return []string{}, nil
	}

	window := zset[from : to+1]
	out := make([]string, len(window))
	for i, zm := range window {
		out[i] = zm.Value
	}
	return out, nil
}
// ZRem deletes the given members (formatted with %v) from the sorted set at
// key, preserving the order of the remaining members. Keys that are missing,
// expired, or hold a non-zset value are ignored; the error is always nil.
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, found := s.items[key]
	if !found || entry.isExpired() {
		return nil
	}
	zset, isZSet := entry.value.([]zsetMember)
	if !isZSet {
		return nil
	}
	doomed := make(map[string]bool, len(members))
	for _, m := range members {
		doomed[fmt.Sprintf("%v", m)] = true
	}
	survivors := make([]zsetMember, 0, len(zset))
	for _, zm := range zset {
		if !doomed[zm.Value] {
			survivors = append(survivors, zm)
		}
	}
	entry.value = survivors
	return nil
}
// PopAndCycleSetMember moves one member from the set at mainKey into the set at
// cooldownKey and returns it, so that members cycle between the two keys.
//
// If mainKey does not hold a live, non-empty set (missing, expired, wrong type
// or empty), the whole cooldownKey entry, including its expiry, is first moved
// to mainKey and cooldownKey is removed. ErrNotFound is returned when the
// cooldown key cannot supply a live, non-empty set either.
//
// After popping, a cooldownKey that is missing or expired is recreated with no
// expiry, and one holding a non-set value has that value replaced by a new set.
//
// The popped member is the first key produced by ranging over the map. Go
// leaves map iteration order unspecified, so the choice is arbitrary rather
// than guaranteed uniformly random.
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
mainSet, mainOk = mainItem.value.(map[string]struct{})
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
// From here on mainOk means "mainKey holds a live, non-empty set".
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
// Refill: promote the cooldown entry to mainKey (replacing any stale or
// wrongly typed main entry) and clear cooldownKey.
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
// Take the first member the map iteration yields.
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
// Look cooldownKey up again: after a refill it was just deleted, so a fresh
// set with no expiry is created here.
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
// memoryPipeliner queues store mutations as closures; Exec runs them all while
// holding the store's write lock.
type memoryPipeliner struct {
	store *memoryStore
	ops   []func() // each op reads/writes store.items directly; Exec holds store.mu
}

// Pipeline returns a new, empty pipeline bound to this store. Queued commands
// have no effect until Exec is called.
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
	p := new(memoryPipeliner)
	p.store = s
	return p
}
// Exec runs every queued operation, in order, while holding the store's write
// lock, so other store callers never observe a partially applied pipeline.
// The queue is then cleared, so calling Exec again does not replay the same
// commands; go-redis's Pipeline.Exec likewise discards its queued commands.
// It always returns nil.
func (p *memoryPipeliner) Exec() error {
	p.store.mu.Lock()
	defer p.store.mu.Unlock()
	ops := p.ops
	p.ops = nil
	for _, op := range ops {
		op()
	}
	return nil
}
// Expire queues giving key a new deadline of expiration, measured from the
// moment Exec runs.
//
// Like Redis EXPIRE, it has no effect on a key that is missing or already
// expired, so stale data that the sweeper has not removed yet is never
// resurrected. A non-positive expiration deletes the key immediately.
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[key]
		if !ok || item.isExpired() {
			return
		}
		if expiration <= 0 {
			delete(p.store.items, key)
			return
		}
		item.expireAt = time.Now().Add(expiration)
	})
}
// Del queues deletion of every listed key. The key list is copied so the
// caller may reuse its slice before Exec.
func (p *memoryPipeliner) Del(keys ...string) {
	snapshot := append([]string(nil), keys...)
	p.ops = append(p.ops, func() {
		for _, k := range snapshot {
			delete(p.store.items, k)
		}
	})
}
// SAdd queues adding members (formatted with %v when Exec runs) to the set at
// key. A missing or expired key becomes a new set with no expiry; a live key
// holding some other type has its value replaced by the new set.
func (p *memoryPipeliner) SAdd(key string, members ...any) {
	queued := append([]any(nil), members...)
	p.ops = append(p.ops, func() {
		var set map[string]struct{}
		if entry, found := p.store.items[key]; found && !entry.isExpired() {
			existing, isSet := entry.value.(map[string]struct{})
			if !isSet {
				existing = make(map[string]struct{})
				entry.value = existing
			}
			set = existing
		} else {
			set = make(map[string]struct{})
			p.store.items[key] = &memoryStoreItem{value: set}
		}
		for _, m := range queued {
			set[fmt.Sprintf("%v", m)] = struct{}{}
		}
	})
}
// SRem queues removing members (formatted with %v when Exec runs) from the set
// at key. Keys that are missing, expired, or hold a non-set value are ignored.
func (p *memoryPipeliner) SRem(key string, members ...any) {
	queued := append([]any(nil), members...)
	p.ops = append(p.ops, func() {
		entry, found := p.store.items[key]
		if !found || entry.isExpired() {
			return
		}
		set, isSet := entry.value.(map[string]struct{})
		if !isSet {
			return
		}
		for _, m := range queued {
			delete(set, fmt.Sprintf("%v", m))
		}
	})
}
// LPush queues inserting values (formatted with %v when Exec runs) at the head
// of the list at key. Like Redis LPUSH, the values are pushed one after
// another, so the last argument ends up first. A missing or expired key starts
// as a new list with no expiry; a live key holding some other type has its
// value replaced by the new list.
func (p *memoryPipeliner) LPush(key string, values ...any) {
	queued := make([]any, len(values))
	copy(queued, values)
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[key]
		if !ok || item.isExpired() {
			item = &memoryStoreItem{value: make([]string, 0)}
			p.store.items[key] = item
		}
		list, _ := item.value.([]string) // a non-list value is discarded
		newList := make([]string, 0, len(queued)+len(list))
		for i := len(queued) - 1; i >= 0; i-- {
			newList = append(newList, fmt.Sprintf("%v", queued[i]))
		}
		item.value = append(newList, list...)
	})
}
// LRem queues removing occurrences of value (formatted with %v) from the list
// at key, with Redis LREM semantics:
//   - count > 0: remove up to count matches, scanning from the head;
//   - count < 0: remove up to -count matches, scanning from the tail;
//   - count == 0: remove every match.
//
// Keys that are missing, expired, or hold a non-list value are left untouched.
// Previously this was an empty stub, so the command was silently dropped.
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[key]
		if !ok || item.isExpired() {
			return
		}
		list, ok := item.value.([]string)
		if !ok {
			return
		}
		target := fmt.Sprintf("%v", value)

		// maxRemovals caps how many matches are dropped; len(list) means "all".
		maxRemovals := len(list)
		abs := count
		if abs < 0 {
			abs = -abs // -math.MinInt64 wraps to itself and falls through to "all"
		}
		if abs > 0 && abs < int64(len(list)) {
			maxRemovals = int(abs)
		}

		drop := make([]bool, len(list))
		removed := 0
		for step := 0; step < len(list) && removed < maxRemovals; step++ {
			i := step
			if count < 0 {
				i = len(list) - 1 - step
			}
			if list[i] == target {
				drop[i] = true
				removed++
			}
		}

		kept := make([]string, 0, len(list)-removed)
		for i, v := range list {
			if !drop[i] {
				kept = append(kept, v)
			}
		}
		item.value = kept
	})
}
// HSet queues writing each field/value pair into the hash at key; values are
// formatted with %v when Exec runs. The map is copied up front so later
// changes by the caller do not alter the queued command.
//
// A missing or expired key starts as a fresh hash with no expiry; a live key
// holding some other type has its value replaced by a new hash. Previously
// this was an empty stub, so the command was silently dropped.
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
	fields := make(map[string]any, len(values))
	for f, v := range values {
		fields[f] = v
	}
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[key]
		if !ok || item.isExpired() {
			item = &memoryStoreItem{value: make(map[string]string)}
			p.store.items[key] = item
		}
		hash, ok := item.value.(map[string]string)
		if !ok {
			hash = make(map[string]string)
			item.value = hash
		}
		for f, v := range fields {
			hash[f] = fmt.Sprintf("%v", v)
		}
	})
}
// HIncrBy queues adding incr to the integer in field of the hash at key (a
// missing field counts as 0). A missing or expired key starts as a fresh hash;
// a live key holding some other type has its value replaced by a new hash.
//
// Queued operations cannot report errors through Exec. So when the field does
// not hold a canonical base-10 int64, or the addition would overflow, the
// command is skipped and the hash is left unchanged. Previously this was an
// empty stub, so the command was silently dropped.
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[key]
		if !ok || item.isExpired() {
			item = &memoryStoreItem{value: make(map[string]string)}
			p.store.items[key] = item
		}
		hash, ok := item.value.(map[string]string)
		if !ok {
			hash = make(map[string]string)
			item.value = hash
		}
		var current int64
		if raw, exists := hash[field]; exists {
			// The round-trip check rejects inputs Sscanf would half-accept,
			// such as " 12" or "12abc".
			if _, err := fmt.Sscanf(raw, "%d", &current); err != nil || fmt.Sprint(current) != raw {
				return
			}
		}
		next := current + incr
		if (incr > 0 && next < current) || (incr < 0 && next > current) {
			return // int64 overflow
		}
		hash[field] = fmt.Sprintf("%d", next)
	})
}
// ZAdd queues setting the score of each given member in the sorted set at key.
// The resulting set is ordered by ascending score, with ties broken by member
// string. The input map is copied up front so later caller changes do not
// alter the queued command.
//
// A missing or expired key starts as a new set with no expiry; a live key
// holding some other type has its value replaced. Previously this was an empty
// stub, so the command was silently dropped.
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
	scores := make(map[string]float64, len(members))
	for m, sc := range members {
		scores[m] = sc
	}
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[key]
		if !ok || item.isExpired() {
			item = &memoryStoreItem{value: make([]zsetMember, 0)}
			p.store.items[key] = item
		}
		existing, _ := item.value.([]zsetMember)
		merged := make(map[string]float64, len(existing)+len(scores))
		for _, z := range existing {
			merged[z.Value] = z.Score
		}
		for m, sc := range scores {
			merged[m] = sc
		}
		zset := make([]zsetMember, 0, len(merged))
		for m, sc := range merged {
			zset = append(zset, zsetMember{Value: m, Score: sc})
		}
		sort.Slice(zset, func(i, j int) bool {
			if zset[i].Score == zset[j].Score {
				return zset[i].Value < zset[j].Value
			}
			return zset[i].Score < zset[j].Score
		})
		item.value = zset
	})
}
// ZRem queues deleting the given members (formatted with %v when Exec runs)
// from the sorted set at key, preserving the order of the remaining members.
// Keys that are missing, expired, or hold a non-zset value are ignored.
// Previously this was an empty stub, so the command was silently dropped.
func (p *memoryPipeliner) ZRem(key string, members ...any) {
	queued := make([]any, len(members))
	copy(queued, members)
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[key]
		if !ok || item.isExpired() {
			return
		}
		zset, ok := item.value.([]zsetMember)
		if !ok {
			return
		}
		remove := make(map[string]struct{}, len(queued))
		for _, m := range queued {
			remove[fmt.Sprintf("%v", m)] = struct{}{}
		}
		kept := make([]zsetMember, 0, len(zset))
		for _, z := range zset {
			if _, drop := remove[z.Value]; !drop {
				kept = append(kept, z)
			}
		}
		item.value = kept
	})
}
// memorySubscription is the handle returned by memoryStore.Subscribe. It wraps
// the subscriber's message channel together with the pub/sub channel name.
type memorySubscription struct {
	store       *memoryStore
	channelName string
	msgChan     chan *Message
}

// Channel returns the receive-only side of the subscriber's message channel.
func (ms *memorySubscription) Channel() <-chan *Message {
	return ms.msgChan
}

// ChannelName returns the pub/sub channel this subscription listens on.
func (ms *memorySubscription) ChannelName() string {
	return ms.channelName
}

// Close unregisters the subscription from its store via removeSubscriber,
// which also closes the message channel.
func (ms *memorySubscription) Close() error {
	owner, name, inbox := ms.store, ms.channelName, ms.msgChan
	return owner.removeSubscriber(name, inbox)
}
// Publish fans message out to every subscriber currently registered on
// channel. Each send waits up to 100ms for buffer space; on timeout, that
// subscriber misses the message and a warning is logged.
//
// The read lock is held for the whole fan-out. This keeps removeSubscriber
// from closing a channel mid-send, but it also means a slow subscriber delays
// every writer to the store. The error is always nil.
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
	s.mu.RLock()
	defer s.mu.RUnlock()
	targets, registered := s.pubsub[channel]
	if !registered {
		return nil
	}
	envelope := &Message{Channel: channel, Payload: message}
	for i := range targets {
		select {
		case targets[i] <- envelope:
			// delivered
		case <-time.After(100 * time.Millisecond):
			s.logger.Warnf("Could not publish to a subscriber on channel '%s' within 100ms", channel)
		}
	}
	return nil
}
// Subscribe registers a new subscriber on channel and returns its handle. The
// subscriber's message channel buffers up to 10 messages. The error is always
// nil.
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
	inbox := make(chan *Message, 10)
	s.mu.Lock()
	defer s.mu.Unlock()
	s.pubsub[channel] = append(s.pubsub[channel], inbox)
	return &memorySubscription{store: s, channelName: channel, msgChan: inbox}, nil
}
// removeSubscriber detaches msgChan from channelName and closes it, so that
// readers ranging over the subscription's channel terminate. The channel name
// is dropped entirely once its last subscriber leaves.
//
// It is idempotent. If msgChan is no longer registered (for example because
// Close was already called), nothing is closed, which avoids a "close of
// closed channel" panic when other subscribers remain on the same channel.
// The error is always nil.
func (s *memoryStore) removeSubscriber(channelName string, msgChan chan *Message) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	subscribers := s.pubsub[channelName]
	remaining := make([]chan *Message, 0, len(subscribers))
	found := false
	for _, ch := range subscribers {
		if ch == msgChan {
			found = true
			continue
		}
		remaining = append(remaining, ch)
	}
	if !found {
		return nil
	}
	if len(remaining) == 0 {
		delete(s.pubsub, channelName)
	} else {
		s.pubsub[channelName] = remaining
	}
	close(msgChan)
	return nil
}