Files
gemini-banlancer/internal/store/memory_store.go
2025-11-23 22:42:58 +08:00

869 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Filename: internal/store/memory_store.go
package store
import (
"context"
"fmt"
"math/rand"
"sort"
"strconv"
"sync"
"time"
"github.com/sirupsen/logrus"
)
// Compile-time guarantee that *memoryStore satisfies the Store interface;
// the build fails here if a method is missing or has the wrong signature.
var (
	_ Store = (*memoryStore)(nil)
)
// memoryStoreItem is one entry of the in-memory keyspace. value holds the
// payload in whatever representation the memoryStore methods use for the key's
// type ([]byte, map[string]string, []string, map[string]struct{},
// []zsetMember); expireAt is the absolute deadline, where the zero time means
// the entry never expires.
type memoryStoreItem struct {
	value    any
	expireAt time.Time
}

// isExpired reports whether the entry has a deadline that has already passed.
func (it *memoryStoreItem) isExpired() bool {
	if it.expireAt.IsZero() {
		return false // no deadline set
	}
	return time.Now().After(it.expireAt)
}

// zsetMember is a single sorted-set element: the member string and its score.
type zsetMember struct {
	Value string
	Score float64
}
// memoryStore is an in-process Store backed by plain Go maps.
// mu guards both the keyspace (items) and the pub/sub registry (pubsub).
// rng is a *rand.Rand, which is not safe for concurrent use, so it is
// serialised separately by rngMu.
type memoryStore struct {
	mu     sync.RWMutex
	items  map[string]*memoryStoreItem // key -> entry
	pubsub map[string][]chan *Message  // channel name -> subscriber channels

	rngMu sync.Mutex
	rng   *rand.Rand

	logger *logrus.Entry
}
// NewMemoryStore creates an empty in-memory Store, tags the given logger with a
// "component" field, and launches the background sweeper that purges expired
// keys.
func NewMemoryStore(logger *logrus.Logger) Store {
	seed := time.Now().UnixNano()
	ms := &memoryStore{
		items:  map[string]*memoryStoreItem{},
		pubsub: map[string][]chan *Message{},
		rng:    rand.New(rand.NewSource(seed)),
		logger: logger.WithField("component", "store.memory 🗱"),
	}
	go ms.startGCollector()
	return ms
}
// startGCollector loops forever, and every five minutes removes entries whose
// deadline has passed. Readers already treat expired entries as missing; this
// sweep only reclaims their memory.
// NOTE(review): there is no stop signal, so this goroutine outlives Close.
func (s *memoryStore) startGCollector() {
	const sweepEvery = 5 * time.Minute
	ticker := time.NewTicker(sweepEvery)
	defer ticker.Stop()
	for {
		<-ticker.C
		s.mu.Lock()
		cutoff := time.Now() // one timestamp for the whole sweep
		for k, it := range s.items {
			if it.expireAt.IsZero() || !cutoff.After(it.expireAt) {
				continue
			}
			delete(s.items, k)
		}
		s.mu.Unlock()
	}
}
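// --- Every method takes a context.Context parameter so that it matches the Store interface. ---
// --- The in-memory implementation does not need it and receives it as _. ---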
// Set stores value under key and replaces any existing entry of any type.
// A positive ttl sets an expiry deadline; ttl <= 0 stores the key without one.
// The slice is stored as given (not copied).
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
	entry := &memoryStoreItem{value: value}
	s.mu.Lock()
	defer s.mu.Unlock()
	if ttl > 0 {
		entry.expireAt = time.Now().Add(ttl)
	}
	s.items[key] = entry
	return nil
}
// Get returns the []byte value stored at key. ErrNotFound is returned when the
// key is absent or expired, or when it holds a value that is not a []byte.
// NOTE(review): the stored slice itself is returned; callers must not mutate it.
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if entry, present := s.items[key]; present && !entry.isExpired() {
		if raw, isBytes := entry.value.([]byte); isBytes {
			return raw, nil
		}
	}
	return nil, ErrNotFound
}
// Del removes every given key. Keys that do not exist are ignored, and the
// call never fails.
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	for i := range keys {
		delete(s.items, keys[i])
	}
	return nil
}
// Exists reports whether key is present and not expired. The error is always nil.
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	entry, present := s.items[key]
	if !present {
		return false, nil
	}
	return !entry.isExpired(), nil
}
// Expire sets key to expire after the given duration, counted from now.
// It returns ErrNotFound when the key does not exist or has already expired.
// An expired entry can linger in the map until the background sweep runs.
// Such an entry is dropped here rather than revived with a fresh deadline.
func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		delete(s.items, key) // no-op when absent; purges a stale entry otherwise
		return ErrNotFound
	}
	item.expireAt = time.Now().Add(expiration)
	return nil
}
// SetNX stores value under key only when the key is absent or expired. It
// returns true if the write happened and false if a live entry already
// existed. A positive ttl sets an expiry deadline.
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if existing, present := s.items[key]; present && !existing.isExpired() {
		return false, nil
	}
	entry := &memoryStoreItem{value: value}
	if ttl > 0 {
		entry.expireAt = time.Now().Add(ttl)
	}
	s.items[key] = entry
	return true, nil
}
// Close is a no-op for the in-memory store and always returns nil.
// NOTE(review): it does not stop the expiry goroutine started by NewMemoryStore.
func (s *memoryStore) Close() error {
	return nil
}
// HDel removes the given fields from the hash at key. Missing or expired keys
// and non-hash values are ignored, and the call never fails.
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		return nil
	}
	hash, isHash := entry.value.(map[string]string)
	if !isHash {
		return nil
	}
	for _, name := range fields {
		delete(hash, name)
	}
	return nil
}
// HSet writes each field/value pair into the hash at key, formatting values
// with %v.
//   - A missing or expired key gets a new hash with no expiry.
//   - A key holding another type has its value replaced by a hash; its expiry
//     is kept.
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry, present := s.items[key]
	if !present || entry.isExpired() {
		entry = &memoryStoreItem{value: make(map[string]string)}
		s.items[key] = entry
	}

	fields, isHash := entry.value.(map[string]string)
	if !isHash {
		fields = make(map[string]string, len(values))
		entry.value = fields
	}
	for name, raw := range values {
		fields[name] = fmt.Sprintf("%v", raw)
	}
	return nil
}
// HGet returns one field of the hash at key. ErrNotFound is returned when the
// key is missing or expired, when it is not a hash, or when the field is absent.
func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	entry, present := s.items[key]
	if present && !entry.isExpired() {
		if fields, isHash := entry.value.(map[string]string); isHash {
			if v, has := fields[field]; has {
				return v, nil
			}
		}
	}
	return "", ErrNotFound
}
// HGetAll returns a copy of the whole hash at key. Missing, expired, or
// non-hash keys yield an empty (non-nil) map. The error is always nil.
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	snapshot := make(map[string]string)
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		return snapshot, nil
	}
	fields, isHash := entry.value.(map[string]string)
	if !isHash {
		return snapshot, nil
	}
	for name, v := range fields {
		snapshot[name] = v
	}
	return snapshot, nil
}
// HIncrBy adds incr to the integer stored in field of the hash at key and
// returns the new value.
//   - A missing or expired key starts as a new hash with no expiry.
//   - A missing field starts at 0.
//   - A key holding another type is replaced by a new hash; its expiry is kept.
//
// Following Redis, an existing field that is not a base-10 int64 is an error,
// and so is an increment that would overflow int64. In both cases the hash is
// left unchanged.
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		item = &memoryStoreItem{value: make(map[string]string)}
		s.items[key] = item
	}
	hash, ok := item.value.(map[string]string)
	if !ok {
		hash = make(map[string]string)
		item.value = hash
	}
	var current int64
	if raw, exists := hash[field]; exists {
		parsed, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("hincrby %q field %q: hash value is not an integer: %w", key, field, err)
		}
		current = parsed
	}
	sum := current + incr
	// Signed overflow wraps silently in Go; detect it the same way Redis rejects it.
	if (incr > 0 && sum < current) || (incr < 0 && sum > current) {
		return 0, fmt.Errorf("hincrby %q field %q: increment would overflow", key, field)
	}
	hash[field] = strconv.FormatInt(sum, 10)
	return sum, nil
}
// LPush inserts values at the head of the list at key, matching Redis LPUSH.
// Values are pushed one at a time, so LPush(k, "a", "b", "c") yields
// [c b a ...]. Values are formatted with %v.
//   - A missing or expired key starts a new list with no expiry.
//   - A key holding a non-list value has that value discarded.
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		item = &memoryStoreItem{value: make([]string, 0)}
		s.items[key] = item
	}
	list, _ := item.value.([]string) // non-list value: start from an empty list
	n := len(values)
	head := make([]string, n, n+len(list))
	for i, v := range values {
		// The last argument ends up first, as with successive single pushes.
		head[n-1-i] = fmt.Sprintf("%v", v)
	}
	item.value = append(head, list...)
	return nil
}
// LRem removes occurrences of value (formatted with %v) from the list at key,
// following Redis LREM semantics:
//   - count > 0 removes the first count matches, scanning from the head;
//   - count < 0 removes the first |count| matches, scanning from the tail;
//   - count == 0 removes every match.
//
// Missing, expired, or non-list keys are left untouched. The call never fails.
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		return nil
	}
	list, ok := item.value.([]string)
	if !ok {
		return nil
	}
	target := fmt.Sprintf("%v", value)
	drop := make([]bool, len(list))
	var removed int64
	// withinLimit reports whether another match may still be removed. For
	// negative counts it compares -removed > count so that count itself is
	// never negated (which would overflow for math.MinInt64).
	withinLimit := func() bool {
		switch {
		case count > 0:
			return removed < count
		case count < 0:
			return -removed > count
		default:
			return true
		}
	}
	if count < 0 {
		for i := len(list) - 1; i >= 0; i-- {
			if list[i] == target && withinLimit() {
				drop[i] = true
				removed++
			}
		}
	} else {
		for i, v := range list {
			if v == target && withinLimit() {
				drop[i] = true
				removed++
			}
		}
	}
	kept := make([]string, 0, len(list)-int(removed))
	for i, v := range list {
		if !drop[i] {
			kept = append(kept, v)
		}
	}
	item.value = kept
	return nil
}
// SAdd adds members (formatted with %v) to the set at key.
//   - A missing or expired key starts a new set with no expiry.
//   - A key holding another type has its value replaced by a set; its expiry
//     is kept.
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		entry = &memoryStoreItem{value: make(map[string]struct{})}
		s.items[key] = entry
	}
	set, isSet := entry.value.(map[string]struct{})
	if !isSet {
		set = make(map[string]struct{}, len(members))
		entry.value = set
	}
	for i := range members {
		set[fmt.Sprintf("%v", members[i])] = struct{}{}
	}
	return nil
}
// SPopN removes up to count random members from the set at key and returns
// them.
//   - An empty (non-nil) slice is returned when count is 0, when the key is
//     missing or expired, or when it is not a set or is empty.
//   - A negative count returns an error; before this check,
//     make([]string, 0, count) would panic.
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
	if count < 0 {
		return nil, fmt.Errorf("spop %q: count must be non-negative, got %d", key, count)
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	item, ok := s.items[key]
	if !ok || item.isExpired() {
		return []string{}, nil
	}
	set, ok := item.value.(map[string]struct{})
	if !ok || len(set) == 0 || count == 0 {
		return []string{}, nil
	}
	if size := int64(len(set)); count > size {
		count = size
	}
	keys := make([]string, 0, len(set))
	for k := range set {
		keys = append(keys, k)
	}
	s.rngMu.Lock()
	s.rng.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] })
	s.rngMu.Unlock()
	popped := keys[:count:count]
	for _, k := range popped {
		delete(set, k)
	}
	return popped, nil
}
// SMembers returns every member of the set at key in unspecified order.
// Missing, expired, or non-set keys yield an empty (non-nil) slice.
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		return []string{}, nil
	}
	set, isSet := entry.value.(map[string]struct{})
	if !isSet {
		return []string{}, nil
	}
	out := make([]string, 0, len(set))
	for m := range set {
		out = append(out, m)
	}
	return out, nil
}
// SRem removes members (formatted with %v) from the set at key. Missing,
// expired, or non-set keys are ignored, and the call never fails.
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		return nil
	}
	if set, isSet := entry.value.(map[string]struct{}); isSet {
		for i := range members {
			delete(set, fmt.Sprintf("%v", members[i]))
		}
	}
	return nil
}
// SRandMember returns one member of the set at key, picked with the store's
// RNG, without removing it. ErrNotFound is returned for missing, expired,
// non-set, or empty keys.
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		return "", ErrNotFound
	}
	set, isSet := entry.value.(map[string]struct{})
	if !isSet || len(set) == 0 {
		return "", ErrNotFound
	}
	candidates := make([]string, 0, len(set))
	for m := range set {
		candidates = append(candidates, m)
	}
	s.rngMu.Lock()
	pick := s.rng.Intn(len(candidates))
	s.rngMu.Unlock()
	return candidates[pick], nil
}
// SUnionStore writes the union of the sets at keys into destination, replacing
// whatever was there (the new entry has no expiry), and returns the size of
// the union. Missing, expired, or non-set source keys contribute nothing.
func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	union := map[string]struct{}{}
	for _, src := range keys {
		entry, present := s.items[src]
		if !present || entry.isExpired() {
			continue
		}
		set, isSet := entry.value.(map[string]struct{})
		if !isSet {
			continue
		}
		for m := range set {
			union[m] = struct{}{}
		}
	}
	s.items[destination] = &memoryStoreItem{value: union}
	return int64(len(union)), nil
}
// Rotate moves the last element of the list at key to the front and returns
// that element. ErrNotFound is returned for missing, expired, non-list, or
// empty keys.
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		return "", ErrNotFound
	}
	list, isList := entry.value.([]string)
	if !isList || len(list) == 0 {
		return "", ErrNotFound
	}
	last := len(list) - 1
	tail := list[last]
	rotated := make([]string, 0, len(list))
	rotated = append(rotated, tail)
	rotated = append(rotated, list[:last]...)
	entry.value = rotated
	return tail, nil
}
// LIndex returns the list element at index; negative indexes count from the
// tail (-1 is the last element). ErrNotFound is returned for missing, expired,
// or non-list keys and for out-of-range indexes.
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		return "", ErrNotFound
	}
	list, isList := entry.value.([]string)
	if !isList {
		return "", ErrNotFound
	}
	size := int64(len(list))
	pos := index
	if pos < 0 {
		pos = size + pos
	}
	if pos < 0 || pos >= size {
		return "", ErrNotFound
	}
	return list[pos], nil
}
// MSet stores every key/value pair in values as a string entry. It overwrites
// existing entries of any type, including their expiry.
//
// Values are normalised to []byte so that Get can read them back:
//   - []byte is stored as given;
//   - string is converted;
//   - anything else is rendered with fmt.Sprint.
//
// The memory store has no per-key TTL for MSet, so these entries never expire.
func (s *memoryStore) MSet(_ context.Context, values map[string]any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	for key, value := range values {
		var raw []byte
		switch v := value.(type) {
		case []byte:
			raw = v
		case string:
			raw = []byte(v)
		default:
			raw = []byte(fmt.Sprint(v))
		}
		s.items[key] = &memoryStoreItem{value: raw, expireAt: time.Time{}}
	}
	return nil
}
// ZAdd inserts or updates members of the sorted set at key. The set is kept as
// a []zsetMember ordered by ascending score, with ties broken by member string.
//   - A missing or expired key starts a new sorted set with no expiry.
//   - A key holding another type is overwritten; its expiry is kept.
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		entry = &memoryStoreItem{value: make([]zsetMember, 0)}
		s.items[key] = entry
	}
	current, _ := entry.value.([]zsetMember) // non-zset value: start empty

	scores := make(map[string]float64, len(current))
	for _, m := range current {
		scores[m.Value] = m.Score
	}
	for v, sc := range members {
		scores[v] = sc
	}

	merged := make([]zsetMember, 0, len(scores))
	for v, sc := range scores {
		merged = append(merged, zsetMember{Value: v, Score: sc})
	}
	sort.Slice(merged, func(i, j int) bool {
		a, b := merged[i], merged[j]
		if a.Score != b.Score {
			return a.Score < b.Score
		}
		return a.Value < b.Value
	})
	entry.value = merged
	return nil
}
// ZRange returns the members ranked start..stop (inclusive) in ascending score
// order. Negative indexes count from the end, and out-of-range bounds are
// clamped. Missing, expired, or non-zset keys and empty ranges yield an empty
// (non-nil) slice.
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		return []string{}, nil
	}
	zset, isZSet := entry.value.([]zsetMember)
	if !isZSet {
		return []string{}, nil
	}
	size := int64(len(zset))
	if start < 0 {
		start += size
	}
	if stop < 0 {
		stop += size
	}
	if start < 0 {
		start = 0
	}
	if stop >= size {
		stop = size - 1
	}
	if start > stop || start >= size {
		return []string{}, nil
	}
	window := zset[start : stop+1]
	out := make([]string, 0, len(window))
	for _, m := range window {
		out = append(out, m.Value)
	}
	return out, nil
}
// ZRem removes members (formatted with %v) from the sorted set at key,
// preserving the order of the rest. Missing, expired, or non-zset keys are
// ignored.
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, present := s.items[key]
	if !present || entry.isExpired() {
		return nil
	}
	zset, isZSet := entry.value.([]zsetMember)
	if !isZSet {
		return nil
	}
	drop := make(map[string]struct{}, len(members))
	for _, m := range members {
		drop[fmt.Sprintf("%v", m)] = struct{}{}
	}
	kept := make([]zsetMember, 0, len(zset))
	for _, z := range zset {
		if _, gone := drop[z.Value]; gone {
			continue
		}
		kept = append(kept, z)
	}
	entry.value = kept
	return nil
}
// PopAndCycleSetMember removes one member from the set at mainKey, adds it to
// the set at cooldownKey, and returns it, all under a single write lock.
//
// The main set may be missing, expired, not a set, or empty. In that case the
// cooldown set is first promoted: its entry (including its expiry) moves to
// mainKey and cooldownKey is deleted. A fresh cooldown set with no expiry is
// then created to receive the popped member.
//
// ErrNotFound is returned when neither set has any members. The member taken is
// whichever one Go's map iteration yields first. That choice is arbitrary and
// not a uniform random pick.
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Resolve the main set; mainOk ends up true only for a live, non-empty set.
	mainItem, mainOk := s.items[mainKey]
	var mainSet map[string]struct{}
	if mainOk && !mainItem.isExpired() {
		mainSet, mainOk = mainItem.value.(map[string]struct{})
		mainOk = mainOk && len(mainSet) > 0
	} else {
		mainOk = false
	}
	if !mainOk {
		// Main set exhausted: promote the cooldown set to mainKey.
		cooldownItem, cooldownOk := s.items[cooldownKey]
		if !cooldownOk || cooldownItem.isExpired() {
			return "", ErrNotFound
		}
		cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
		if !cooldownSetOk || len(cooldownSet) == 0 {
			return "", ErrNotFound
		}
		// The same *memoryStoreItem is moved, so its expireAt carries over.
		s.items[mainKey] = cooldownItem
		delete(s.items, cooldownKey)
		mainSet = cooldownSet
	}
	// Take the first member map iteration yields (mainSet is non-empty here).
	var popped string
	for k := range mainSet {
		popped = k
		break
	}
	delete(mainSet, popped)
	// Ensure a live cooldown set exists. After a promotion cooldownKey was just
	// deleted, so a new non-expiring set is created here.
	cooldownItem, cooldownOk := s.items[cooldownKey]
	if !cooldownOk || cooldownItem.isExpired() {
		cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
		s.items[cooldownKey] = cooldownItem
	}
	cooldownSet, ok := cooldownItem.value.(map[string]struct{})
	if !ok {
		// A non-set value at cooldownKey is replaced by a set (expiry kept).
		cooldownSet = make(map[string]struct{})
		cooldownItem.value = cooldownSet
	}
	cooldownSet[popped] = struct{}{}
	return popped, nil
}
// memoryPipeliner queues store operations as closures. Exec applies them to
// the backing memoryStore in order, under the store's write lock.
type memoryPipeliner struct {
	store *memoryStore
	ops   []func()
}

// Pipeline returns a new, empty pipeliner bound to this store. The context is
// not used by the in-memory implementation.
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
	return &memoryPipeliner{store: s}
}

// Exec runs every queued operation in order while holding the store's write
// lock, so other store calls see the batch as a whole. The queue is drained
// before running. A later Exec therefore only applies operations queued after
// this call and does not replay ones already applied, as with a Redis client
// pipeline.
func (p *memoryPipeliner) Exec() error {
	ops := p.ops
	p.ops = nil
	p.store.mu.Lock()
	defer p.store.mu.Unlock()
	for _, op := range ops {
		op()
	}
	return nil
}
// Expire queues an expiry of key, set to expiration after Exec time. Missing
// keys are skipped. An expired entry that has not been swept yet is dropped,
// not revived with a new deadline.
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
	capturedKey := key
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[capturedKey]
		if !ok {
			return
		}
		if item.isExpired() {
			delete(p.store.items, capturedKey)
			return
		}
		item.expireAt = time.Now().Add(expiration)
	})
}
// Del queues deletion of keys. The key list is copied at call time, so later
// changes to the caller's slice do not affect the queued operation.
func (p *memoryPipeliner) Del(keys ...string) {
	snapshot := append([]string(nil), keys...)
	p.ops = append(p.ops, func() {
		for i := range snapshot {
			delete(p.store.items, snapshot[i])
		}
	})
}
// Set queues a write of value under key. A positive expiration becomes a
// deadline measured from Exec time; otherwise the key does not expire.
func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) {
	p.ops = append(p.ops, func() {
		entry := &memoryStoreItem{value: value}
		if expiration > 0 {
			entry.expireAt = time.Now().Add(expiration)
		}
		p.store.items[key] = entry
	})
}
// SAdd queues adding members (formatted with %v at Exec time) to the set at
// key. It creates the set the same way memoryStore.SAdd does.
func (p *memoryPipeliner) SAdd(key string, members ...any) {
	snapshot := append([]any(nil), members...)
	p.ops = append(p.ops, func() {
		entry, present := p.store.items[key]
		if !present || entry.isExpired() {
			entry = &memoryStoreItem{value: make(map[string]struct{})}
			p.store.items[key] = entry
		}
		set, isSet := entry.value.(map[string]struct{})
		if !isSet {
			set = make(map[string]struct{})
			entry.value = set
		}
		for _, m := range snapshot {
			set[fmt.Sprintf("%v", m)] = struct{}{}
		}
	})
}
// SRem queues removal of members (formatted with %v) from the set at key.
// Missing, expired, or non-set keys are left untouched.
func (p *memoryPipeliner) SRem(key string, members ...any) {
	snapshot := append([]any(nil), members...)
	p.ops = append(p.ops, func() {
		entry, present := p.store.items[key]
		if !present || entry.isExpired() {
			return
		}
		if set, isSet := entry.value.(map[string]struct{}); isSet {
			for _, m := range snapshot {
				delete(set, fmt.Sprintf("%v", m))
			}
		}
	})
}
// LPush queues a head insertion into the list at key, matching Redis LPUSH.
// Values are pushed one at a time, so LPush(k, "a", "b", "c") yields
// [c b a ...]. Values are formatted with %v at Exec time.
//   - A missing or expired key starts a new list with no expiry.
//   - A key holding a non-list value has that value discarded.
func (p *memoryPipeliner) LPush(key string, values ...any) {
	capturedKey := key
	capturedValues := make([]any, len(values))
	copy(capturedValues, values)
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[capturedKey]
		if !ok || item.isExpired() {
			item = &memoryStoreItem{value: make([]string, 0)}
			p.store.items[capturedKey] = item
		}
		list, _ := item.value.([]string) // non-list value: start from an empty list
		n := len(capturedValues)
		head := make([]string, n, n+len(list))
		for i, v := range capturedValues {
			// The last argument ends up first, as with successive single pushes.
			head[n-1-i] = fmt.Sprintf("%v", v)
		}
		item.value = append(head, list...)
	})
}
// LRem queues removal of value (formatted with %v at call time) from the list
// at key, following Redis LREM semantics:
//   - count > 0 removes the first count matches, scanning from the head;
//   - count < 0 removes the first |count| matches, scanning from the tail;
//   - count == 0 removes every match.
//
// Missing, expired, or non-list keys are left untouched.
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
	capturedKey := key
	target := fmt.Sprintf("%v", value)
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[capturedKey]
		if !ok || item.isExpired() {
			return
		}
		list, ok := item.value.([]string)
		if !ok {
			return
		}
		drop := make([]bool, len(list))
		var removed int64
		// withinLimit reports whether another match may still be removed. For
		// negative counts it compares -removed > count so that count itself is
		// never negated (which would overflow for math.MinInt64).
		withinLimit := func() bool {
			switch {
			case count > 0:
				return removed < count
			case count < 0:
				return -removed > count
			default:
				return true
			}
		}
		if count < 0 {
			for i := len(list) - 1; i >= 0; i-- {
				if list[i] == target && withinLimit() {
					drop[i] = true
					removed++
				}
			}
		} else {
			for i, v := range list {
				if v == target && withinLimit() {
					drop[i] = true
					removed++
				}
			}
		}
		kept := make([]string, 0, len(list)-int(removed))
		for i, v := range list {
			if !drop[i] {
				kept = append(kept, v)
			}
		}
		item.value = kept
	})
}
// HSet queues writing field/value pairs into the hash at key. The map is
// shallow-copied at call time; values are formatted with %v at Exec time.
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
	snapshot := make(map[string]any, len(values))
	for name, v := range values {
		snapshot[name] = v
	}
	p.ops = append(p.ops, func() {
		entry, present := p.store.items[key]
		if !present || entry.isExpired() {
			entry = &memoryStoreItem{value: make(map[string]string)}
			p.store.items[key] = entry
		}
		fields, isHash := entry.value.(map[string]string)
		if !isHash {
			fields = make(map[string]string)
			entry.value = fields
		}
		for name, v := range snapshot {
			fields[name] = fmt.Sprintf("%v", v)
		}
	})
}
// HIncrBy queues adding incr to the integer in field of the hash at key.
// The pipeline has no per-operation error path, so the ParseInt error is
// deliberately ignored:
//   - a missing or syntactically invalid value counts as 0;
//   - an out-of-range value is clamped by strconv.ParseInt.
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
	p.ops = append(p.ops, func() {
		entry, present := p.store.items[key]
		if !present || entry.isExpired() {
			entry = &memoryStoreItem{value: make(map[string]string)}
			p.store.items[key] = entry
		}
		fields, isHash := entry.value.(map[string]string)
		if !isHash {
			fields = make(map[string]string)
			entry.value = fields
		}
		base, _ := strconv.ParseInt(fields[field], 10, 64)
		fields[field] = strconv.FormatInt(base+incr, 10)
	})
}
// ZAdd queues inserting or updating members of the sorted set at key.
// It uses the same representation as memoryStore.ZAdd (a []zsetMember ordered
// by score, then member), so ZRange and ZRem can see pipelined writes.
// The members map is copied at call time.
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
	capturedKey := key
	capturedMembers := make(map[string]float64, len(members))
	for k, v := range members {
		capturedMembers[k] = v
	}
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[capturedKey]
		if !ok || item.isExpired() {
			item = &memoryStoreItem{value: make([]zsetMember, 0)}
			p.store.items[capturedKey] = item
		}
		existing, _ := item.value.([]zsetMember) // non-zset value: start empty
		scores := make(map[string]float64, len(existing)+len(capturedMembers))
		for _, z := range existing {
			scores[z.Value] = z.Score
		}
		for member, score := range capturedMembers {
			scores[member] = score
		}
		merged := make([]zsetMember, 0, len(scores))
		for member, score := range scores {
			merged = append(merged, zsetMember{Value: member, Score: score})
		}
		sort.Slice(merged, func(i, j int) bool {
			if merged[i].Score != merged[j].Score {
				return merged[i].Score < merged[j].Score
			}
			return merged[i].Value < merged[j].Value
		})
		item.value = merged
	})
}
// ZRem queues removal of members (formatted with %v) from the sorted set at
// key. It works on the []zsetMember representation used by memoryStore.ZAdd.
// The order of the remaining members is preserved. Missing, expired, or
// non-zset keys are left untouched.
func (p *memoryPipeliner) ZRem(key string, members ...any) {
	capturedKey := key
	capturedMembers := make([]any, len(members))
	copy(capturedMembers, members)
	p.ops = append(p.ops, func() {
		item, ok := p.store.items[capturedKey]
		if !ok || item.isExpired() {
			return
		}
		zset, ok := item.value.([]zsetMember)
		if !ok {
			return
		}
		drop := make(map[string]struct{}, len(capturedMembers))
		for _, m := range capturedMembers {
			drop[fmt.Sprintf("%v", m)] = struct{}{}
		}
		kept := make([]zsetMember, 0, len(zset))
		for _, z := range zset {
			if _, gone := drop[z.Value]; !gone {
				kept = append(kept, z)
			}
		}
		item.value = kept
	})
}
// MSet queues storing every key/value pair as a string entry, overwriting any
// existing entry. Values are normalised to []byte at call time so that Get can
// read them back:
//   - []byte is stored as given;
//   - string is converted;
//   - anything else is rendered with fmt.Sprint.
//
// Like memoryStore.MSet, these entries never expire.
func (p *memoryPipeliner) MSet(values map[string]any) {
	encoded := make(map[string][]byte, len(values))
	for key, value := range values {
		switch v := value.(type) {
		case []byte:
			encoded[key] = v
		case string:
			encoded[key] = []byte(v)
		default:
			encoded[key] = []byte(fmt.Sprint(v))
		}
	}
	p.ops = append(p.ops, func() {
		for key, raw := range encoded {
			p.store.items[key] = &memoryStoreItem{
				value:    raw,
				expireAt: time.Time{}, // pipelined MSet entries never expire either
			}
		}
	})
}
// memorySubscription is the Subscription handle returned by
// memoryStore.Subscribe. It wraps one subscriber channel.
type memorySubscription struct {
	store       *memoryStore
	channelName string
	msgChan     chan *Message
}

// Channel returns the receive-only stream of published messages.
func (ms *memorySubscription) Channel() <-chan *Message {
	return ms.msgChan
}

// ChannelName returns the pub/sub channel this subscription listens on.
func (ms *memorySubscription) ChannelName() string {
	return ms.channelName
}

// Close unregisters this subscriber from the store via removeSubscriber,
// which also closes the message channel.
func (ms *memorySubscription) Close() error {
	owner, name, ch := ms.store, ms.channelName, ms.msgChan
	return owner.removeSubscriber(name, ch)
}
// Publish sends message to every current subscriber of channel.
//   - Each send waits at most 100ms. A subscriber that is not ready in time
//     misses the message, and a warning is logged.
//   - The read lock is held for the whole fan-out. removeSubscriber closes
//     channels under the write lock, so no channel can be closed mid-send.
//
// NOTE(review): this can block store writers for up to 100ms per slow subscriber.
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
	const sendTimeout = 100 * time.Millisecond
	s.mu.RLock()
	defer s.mu.RUnlock()
	subs := s.pubsub[channel]
	if len(subs) == 0 {
		return nil
	}
	msg := &Message{Channel: channel, Payload: message}
	for _, sub := range subs {
		select {
		case sub <- msg:
		case <-time.After(sendTimeout):
			s.logger.Warnf("Could not publish to a subscriber on channel '%s' within 100ms", channel)
		}
	}
	return nil
}
// Subscribe registers a new subscriber on channel and returns its handle.
// Each subscriber gets its own channel with a buffer of 10 messages.
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
	const bufferSize = 10
	ch := make(chan *Message, bufferSize)
	s.mu.Lock()
	defer s.mu.Unlock()
	s.pubsub[channel] = append(s.pubsub[channel], ch)
	return &memorySubscription{store: s, channelName: channel, msgChan: ch}, nil
}
// removeSubscriber unregisters msgChan from channelName and closes it. The
// channel is closed only if it was still registered. A second Close on the same
// subscription is therefore a no-op; calling close on an already-closed
// channel would panic. When the last subscriber leaves, the channel's entry is
// dropped from the registry. The error is always nil.
func (s *memoryStore) removeSubscriber(channelName string, msgChan chan *Message) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	subscribers := s.pubsub[channelName]
	remaining := make([]chan *Message, 0, len(subscribers))
	found := false
	for _, ch := range subscribers {
		if ch == msgChan {
			found = true
			continue
		}
		remaining = append(remaining, ch)
	}
	if !found {
		return nil
	}
	if len(remaining) == 0 {
		delete(s.pubsub, channelName)
	} else {
		s.pubsub[channelName] = remaining
	}
	close(msgChan)
	return nil
}