before update store

This commit is contained in:
XOF
2025-11-22 11:51:20 +08:00
parent 6a0f344e5c
commit ac0e0a8275
5 changed files with 316 additions and 250 deletions

View File

@@ -24,7 +24,7 @@ type SystemSettings struct {
BaseKeyCheckIntervalMinutes int `json:"base_key_check_interval_minutes" default:"1440" name:"全局Key检查周期(分钟)" category:"健康检查" desc:"对所有ACTIVE状态的Key进行身份检查的周期建议设置较长时间例如1天(1440分钟)。"`
BaseKeyCheckConcurrency int `json:"base_key_check_concurrency" default:"5" name:"全局Key检查并发数" category:"健康检查" desc:"执行全局Key身份检查时的并发请求数量。"`
BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"`
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-1.5-flash" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-2.0-flash-lite" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`
EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"`
UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"`

View File

@@ -11,7 +11,6 @@ import (
"io"
"sort"
"strconv"
"strings"
"time"
"gorm.io/gorm"
@@ -82,20 +81,13 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
protocol := "default"
if pool.Protocol != "" {
protocol = string(pool.Protocol)
}
// 生成唯一的池ID确保不同请求组合的轮询状态相互隔离
poolID := generatePoolID(pool.CandidateGroups, protocol)
log := r.logger.WithField("pool_id", poolID).WithField("protocol", protocol)
poolID := generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, err
}
return nil, nil, fmt.Errorf("unexpected error while ensuring base pool cache: %w", err)
return nil, nil, err
}
var keyIDStr string
@@ -154,65 +146,78 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID)
// --- [逻辑优化] 提前处理“毒丸”,让逻辑更清晰 ---
exists, err := r.store.Exists(listKey)
if err != nil {
r.logger.WithError(err).Errorf("Failed to check existence of basepool key: %s", listKey)
return err
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err // 直接返回读取错误
}
if exists {
val, err := r.store.LIndex(listKey, 0)
if err != nil {
return err
// 如果连 LIndex 都失败,说明缓存可能已损坏,允许重建
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else {
if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound // 已知为空,直接返回
}
return nil // 缓存有效,直接返回
}
if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound
}
return nil
}
// --- [锁机制优化] 增加分布式锁,防止并发构建时的惊群效应 ---
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second)
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10秒锁超时
if err != nil {
r.logger.WithError(err).Errorf("Failed to acquire distributed lock for basepool build: %s", lockKey)
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err
}
if !acquired {
// 未获取到锁,等待一小段时间后重试,让持有锁的协程完成构建
time.Sleep(100 * time.Millisecond)
return r.ensureBasePoolCacheExists(pool, poolID)
}
defer r.store.Del(lockKey)
defer r.store.Del(lockKey) // 确保在函数退出时释放锁
// 双重检查,防止在获取锁的间隙,已有其他协程完成了构建
if exists, _ := r.store.Exists(listKey); exists {
return nil
}
r.logger.Infof("BasePool cache for pool_id '%s' not found. Building now...", poolID)
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
var allActiveKeyIDs []string
lruMembers := make(map[string]float64)
for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
// --- [核心修正] ---
// 这是整个问题的根源。我们绝不能在读取失败时,默默地`continue`。
// 任何读取源数据的失败,都必须被视为一次构建过程的彻底失败,并立即中止。
if err != nil {
r.logger.WithError(err).Warnf("Failed to get active keys for group %d during BasePool build", group.ID)
continue
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
// 返回这个瞬时错误。这会导致本次请求失败,但绝不会写入“毒丸”,
// 从而给了下一次请求一个全新的、成功的机会。
return err
}
// 只有在 SMembers 成功时,才继续处理
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
for _, keyIDStr := range groupKeyIDs {
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
if err != nil {
if errors.Is(err, store.ErrNotFound) || strings.Contains(err.Error(), "failed to get") {
r.logger.WithError(err).Warnf("Cache inconsistency detected for KeyID %s in GroupID %d. Skipping.", keyIDStr, group.ID)
continue
} else {
return err
if err == nil && mapping != nil {
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
}
allActiveKeyIDs = append(allActiveKeyIDs, keyIDStr)
if mapping != nil && mapping.LastUsedAt != nil {
lruMembers[keyIDStr] = float64(mapping.LastUsedAt.UnixMilli())
lruMembers[keyIDStr] = score
}
}
}
// --- [逻辑修正] ---
// 只有在“我们成功读取了所有数据,但发现数据本身是空的”这种情况下,
// 才允许写入“毒丸”。
if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline()
pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL)
@@ -221,16 +226,23 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
}
return gorm.ErrRecordNotFound
}
// 使用管道填充所有轮询结构
pipe := r.store.Pipeline()
// 1. 顺序
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 2. 随机
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 设置合理的过期时间例如5分钟以防止孤儿数据
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
if err := pipe.Exec(); err != nil {
return err
}
if len(lruMembers) > 0 {
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
}
@@ -246,7 +258,7 @@ func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID
}
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
func generatePoolID(groups []*models.KeyGroup, protocol string) string {
func generatePoolID(groups []*models.KeyGroup) string {
ids := make([]int, len(groups))
for i, g := range groups {
ids[i] = int(g.ID)
@@ -254,7 +266,7 @@ func generatePoolID(groups []*models.KeyGroup, protocol string) string {
sort.Ints(ids)
h := sha1.New()
io.WriteString(h, fmt.Sprintf("protocol:%s;groups:%v", protocol, ids))
io.WriteString(h, fmt.Sprintf("%v", ids))
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -1,4 +1,4 @@
// Filename: internal/store/memory_store.go (统一存储重构版)
// Filename: internal/store/memory_store.go (经同行审查后最终修复版)
package store
@@ -12,76 +12,144 @@ import (
"github.com/sirupsen/logrus"
)
// ensure memoryStore implements Store interface
var _ Store = (*memoryStore)(nil)
// [核心重构] memoryStoreItem 现在是通用容器,可以存储任何类型的值,并自带过期时间
type memoryStoreItem struct {
value interface{} // 可以是 []byte, []string, map[string]string, map[string]struct{}, []zsetMember
value interface{}
expireAt time.Time
}
// isExpired 检查一个条目是否已过期
func (item *memoryStoreItem) isExpired() bool {
return !item.expireAt.IsZero() && time.Now().After(item.expireAt)
}
// zsetMember 保持不变
type zsetMember struct {
Value string
Score float64
}
// [核心重构] memoryStore 现在使用一个统一的 map 来存储所有数据
type memoryStore struct {
items map[string]*memoryStoreItem // 指向 item 的指针,以便原地修改
items map[string]*memoryStoreItem
pubsub map[string][]chan *Message
mu sync.RWMutex
// [USER SUGGESTION APPLIED] 使用带锁的随机数源以保证并发安全
rng *rand.Rand
rngMu sync.Mutex
logger *logrus.Entry
}
// NewMemoryStore [核心重構] 構造函數也被簡化了
func NewMemoryStore(logger *logrus.Logger) Store {
return &memoryStore{
store := &memoryStore{
items: make(map[string]*memoryStoreItem),
pubsub: make(map[string][]chan *Message),
// 使用当前时间作为种子,创建一个新的随机数源
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
logger: logger.WithField("component", "store.memory 🗱"),
}
go store.startGCollector()
return store
}
// [核心重构] getItem 是一个新的内部辅助函数,它封装了获取、检查过期和删除的通用逻辑
func (s *memoryStore) getItem(key string, lockForWrite bool) *memoryStoreItem {
if !lockForWrite {
// 如果是读操作,先用读锁检查
s.mu.RLock()
item, ok := s.items[key]
if !ok || item.isExpired() {
s.mu.RUnlock()
// 如果不存在或已过期,需要尝试获取写锁来删除它
if ok { // 只有在确定 item 存在但已过期时才需要删除
s.mu.Lock()
// 再次检查,防止在获取写锁期间状态已改变
if item, ok := s.items[key]; ok && item.isExpired() {
delete(s.items, key)
}
s.mu.Unlock()
// [USER SUGGESTION INCORPORATED] Fix #1: 使用 now := time.Now() 进行原子性检查
func (s *memoryStore) startGCollector() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
s.mu.Lock()
now := time.Now() // 避免在循环中重复调用
for key, item := range s.items {
if !item.expireAt.IsZero() && now.After(item.expireAt) {
delete(s.items, key)
}
return nil // 无论如何都返回 nil
}
s.mu.RUnlock()
return item
s.mu.Unlock()
}
// 对于写操作,直接使用写锁
item, ok := s.items[key]
if ok && item.isExpired() {
delete(s.items, key)
return nil
}
return item
}
// --- 所有接口方法现在都基于新的统一结构重写 ---
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: 修复了致命的nil检查和类型断言问题
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
// 安全地进行类型断言
mainSet, mainOk = mainItem.value.(map[string]struct{})
// 确保断言成功且集合不为空
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
// 安全地进行类型断言
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
// 安全地处理冷却池
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
// SRandMember [并发修复版] 使用带锁的rng
func (s *memoryStore) SRandMember(key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
// --- 以下是其余函数的最终版本,它们都遵循了安全、原子的锁策略 ---
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
s.mu.Lock()
@@ -95,14 +163,16 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
}
func (s *memoryStore) Get(key string) ([]byte, error) {
item := s.getItem(key, false)
if item == nil {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return nil, ErrNotFound
}
if value, ok := item.value.([]byte); ok {
return value, nil
}
return nil, ErrNotFound // Type mismatch, treat as not found
return nil, ErrNotFound
}
func (s *memoryStore) Del(keys ...string) error {
@@ -114,20 +184,18 @@ func (s *memoryStore) Del(keys ...string) error {
return nil
}
func (s *memoryStore) delNoLock(keys ...string) {
for _, key := range keys {
delete(s.items, key)
}
}
func (s *memoryStore) Exists(key string) (bool, error) {
return s.getItem(key, false) != nil, nil
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
return ok && !item.isExpired(), nil
}
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if item := s.getItem(key, true); item != nil {
item, ok := s.items[key]
if ok && !item.isExpired() {
return false, nil
}
var expireAt time.Time
@@ -143,8 +211,8 @@ func (s *memoryStore) Close() error { return nil }
func (s *memoryStore) HDel(key string, fields ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
item := s.getItem(key, true)
if item == nil {
item, ok := s.items[key]
if !ok || item.isExpired() {
return nil
}
if hash, ok := item.value.(map[string]string); ok {
@@ -158,29 +226,27 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
func (s *memoryStore) HSet(key string, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
s.hSetNoLock(key, values)
return nil
}
func (s *memoryStore) hSetNoLock(key string, values map[string]any) {
item := s.getItem(key, true)
if item == nil {
item, ok := s.items[key]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
s.items[key] = item
}
hash, ok := item.value.(map[string]string)
if !ok { // If key exists but is not a hash, create a new hash
if !ok {
hash = make(map[string]string)
item.value = hash
}
for field, value := range values {
hash[field] = fmt.Sprintf("%v", value)
}
return nil
}
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
item := s.getItem(key, false)
if item == nil {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return make(map[string]string), nil
}
if hash, ok := item.value.(map[string]string); ok {
@@ -196,12 +262,8 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.hIncrByNoLock(key, field, incr)
}
func (s *memoryStore) hIncrByNoLock(key, field string, incr int64) (int64, error) {
item := s.getItem(key, true)
if item == nil {
item, ok := s.items[key]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
s.items[key] = item
}
@@ -222,13 +284,8 @@ func (s *memoryStore) hIncrByNoLock(key, field string, incr int64) (int64, error
func (s *memoryStore) LPush(key string, values ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
s.lPushNoLock(key, values...)
return nil
}
func (s *memoryStore) lPushNoLock(key string, values ...any) {
item := s.getItem(key, true)
if item == nil {
item, ok := s.items[key]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make([]string, 0)}
s.items[key] = item
}
@@ -241,22 +298,19 @@ func (s *memoryStore) lPushNoLock(key string, values ...any) {
stringValues[i] = fmt.Sprintf("%v", v)
}
item.value = append(stringValues, list...)
return nil
}
func (s *memoryStore) LRem(key string, count int64, value any) error {
s.mu.Lock()
defer s.mu.Unlock()
s.lRemNoLock(key, count, value)
return nil
}
func (s *memoryStore) lRemNoLock(key string, count int64, value any) {
item := s.getItem(key, true)
if item == nil {
return
item, ok := s.items[key]
if !ok || item.isExpired() {
return nil
}
list, ok := item.value.([]string)
if !ok {
return
return nil
}
valToRemove := fmt.Sprintf("%v", value)
newList := make([]string, 0, len(list))
@@ -269,17 +323,14 @@ func (s *memoryStore) lRemNoLock(key string, count int64, value any) {
}
}
item.value = newList
}
func (s *memoryStore) SAdd(key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
s.sAddNoLock(key, members...)
return nil
}
func (s *memoryStore) sAddNoLock(key string, members ...any) {
item := s.getItem(key, true)
if item == nil {
func (s *memoryStore) SAdd(key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]struct{})}
s.items[key] = item
}
@@ -291,12 +342,14 @@ func (s *memoryStore) sAddNoLock(key string, members ...any) {
for _, member := range members {
set[fmt.Sprintf("%v", member)] = struct{}{}
}
return nil
}
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item := s.getItem(key, true)
if item == nil {
item, ok := s.items[key]
if !ok || item.isExpired() {
return []string{}, nil
}
set, ok := item.value.(map[string]struct{})
@@ -311,7 +364,9 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
for k := range set {
keys = append(keys, k)
}
rand.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] })
s.rngMu.Lock()
s.rng.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] })
s.rngMu.Unlock()
for i := int64(0); i < count; i++ {
poppedKey := keys[i]
popped = append(popped, poppedKey)
@@ -319,68 +374,47 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
}
return popped, nil
}
func (s *memoryStore) SMembers(key string) ([]string, error) {
item := s.getItem(key, false)
if item == nil {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return []string{}, nil
}
set, ok := item.value.(map[string]struct{})
if !ok {
return []string{}, nil
}
s.mu.RLock() // Lock needed for iterating map
defer s.mu.RUnlock()
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
return members, nil
}
func (s *memoryStore) SRem(key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
s.sRemNoLock(key, members...)
return nil
}
func (s *memoryStore) sRemNoLock(key string, members ...any) {
item := s.getItem(key, true)
if item == nil {
return
item, ok := s.items[key]
if !ok || item.isExpired() {
return nil
}
set, ok := item.value.(map[string]struct{})
if !ok {
return
return nil
}
for _, member := range members {
delete(set, fmt.Sprintf("%v", member))
}
return nil
}
func (s *memoryStore) SRandMember(key string) (string, error) {
item := s.getItem(key, false)
if item == nil {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
s.mu.RLock()
defer s.mu.RUnlock()
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
return members[rand.Intn(len(members))], nil
}
func (s *memoryStore) Rotate(key string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item := s.getItem(key, true)
if item == nil {
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
list, ok := item.value.([]string)
@@ -391,17 +425,18 @@ func (s *memoryStore) Rotate(key string) (string, error) {
item.value = append([]string{val}, list[:len(list)-1]...)
return val, nil
}
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
item := s.getItem(key, false)
if item == nil {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
list, ok := item.value.([]string)
if !ok {
return "", ErrNotFound
}
s.mu.RLock()
defer s.mu.RUnlock()
l := int64(len(list))
if index < 0 {
index += l
@@ -416,8 +451,8 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
s.mu.Lock()
defer s.mu.Unlock()
item := s.getItem(key, true)
if item == nil {
item, ok := s.items[key]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make([]zsetMember, 0)}
s.items[key] = item
}
@@ -425,39 +460,39 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
if !ok {
zset = make([]zsetMember, 0)
}
for memberVal, score := range members {
found := false
for i := range zset {
if zset[i].Value == memberVal {
zset[i].Score = score
found = true
break
}
}
if !found {
zset = append(zset, zsetMember{Value: memberVal, Score: score})
}
membersMap := make(map[string]float64, len(zset))
for _, z := range zset {
membersMap[z.Value] = z.Score
}
sort.Slice(zset, func(i, j int) bool {
if zset[i].Score == zset[j].Score {
return zset[i].Value < zset[j].Value
for memberVal, score := range members {
membersMap[memberVal] = score
}
newZSet := make([]zsetMember, 0, len(membersMap))
for val, score := range membersMap {
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
}
// NOTE: This ZSet implementation is simple but not performant for large sets.
// A production implementation would use a skip list or a balanced tree.
sort.Slice(newZSet, func(i, j int) bool {
if newZSet[i].Score == newZSet[j].Score {
return newZSet[i].Value < newZSet[j].Value
}
return zset[i].Score < zset[j].Score
return newZSet[i].Score < newZSet[j].Score
})
item.value = zset
item.value = newZSet
return nil
}
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
item := s.getItem(key, false)
if item == nil {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return []string{}, nil
}
zset, ok := item.value.([]zsetMember)
if !ok {
return []string{}, nil
}
s.mu.RLock()
defer s.mu.RUnlock()
l := int64(len(zset))
if start < 0 {
start += l
@@ -483,8 +518,8 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
func (s *memoryStore) ZRem(key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item := s.getItem(key, true)
if item == nil {
item, ok := s.items[key]
if !ok || item.isExpired() {
return nil
}
zset, ok := item.value.([]zsetMember)
@@ -504,32 +539,6 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
item.value = newZSet
return nil
}
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
mainItem := s.getItem(mainKey, true)
if mainItem == nil || len(mainItem.value.(map[string]struct{})) == 0 {
cooldownItem := s.getItem(cooldownKey, true)
if cooldownItem == nil {
return "", ErrNotFound
}
// "Rename" by moving value and deleting old key
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainItem = cooldownItem
}
mainSet, ok := mainItem.value.(map[string]struct{})
if !ok || len(mainSet) == 0 {
return "", ErrNotFound // Should not happen after cycle logic
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
return popped, nil
}
// Pipeline implementation
type memoryPipeliner struct {
@@ -549,41 +558,90 @@ func (p *memoryPipeliner) Exec() error {
return nil
}
// [核心修正] Expire 现在可以正确地为任何 key 设置过期时间
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
capturedKey := key
p.ops = append(p.ops, func() {
// This must be called within Exec's lock
item := p.store.getItem(key, true)
if item != nil {
if item, ok := p.store.items[capturedKey]; ok {
item.expireAt = time.Now().Add(expiration)
}
})
}
// All other pipeliner methods...
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
p.ops = append(p.ops, func() { p.store.hSetNoLock(key, values) })
}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
p.ops = append(p.ops, func() { p.store.hIncrByNoLock(key, field, incr) })
}
func (p *memoryPipeliner) Del(keys ...string) {
p.ops = append(p.ops, func() { p.store.delNoLock(keys...) })
capturedKeys := make([]string, len(keys))
copy(capturedKeys, keys)
p.ops = append(p.ops, func() {
for _, key := range capturedKeys {
delete(p.store.items, key)
}
})
}
func (p *memoryPipeliner) SAdd(key string, members ...any) {
p.ops = append(p.ops, func() { p.store.sAddNoLock(key, members...) })
}
func (p *memoryPipeliner) SRem(key string, members ...any) {
p.ops = append(p.ops, func() { p.store.sRemNoLock(key, members...) })
}
func (p *memoryPipeliner) LPush(key string, values ...any) {
p.ops = append(p.ops, func() { p.store.lPushNoLock(key, values...) })
}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
p.ops = append(p.ops, func() { p.store.lRemNoLock(key, count, value) })
capturedKey := key
capturedMembers := make([]any, len(members))
copy(capturedMembers, members)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]struct{})}
p.store.items[capturedKey] = item
}
set, ok := item.value.(map[string]struct{})
if !ok {
set = make(map[string]struct{})
item.value = set
}
for _, member := range capturedMembers {
set[fmt.Sprintf("%v", member)] = struct{}{}
}
})
}
// Pub/Sub implementation (remains unchanged as it's a separate system)
func (p *memoryPipeliner) SRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
copy(capturedMembers, members)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
set, ok := item.value.(map[string]struct{})
if !ok {
return
}
for _, member := range capturedMembers {
delete(set, fmt.Sprintf("%v", member))
}
})
}
func (p *memoryPipeliner) LPush(key string, values ...any) {
capturedKey := key
capturedValues := make([]any, len(values))
copy(capturedValues, values)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make([]string, 0)}
p.store.items[capturedKey] = item
}
list, ok := item.value.([]string)
if !ok {
list = make([]string, 0)
}
stringValues := make([]string, len(capturedValues))
for i, v := range capturedValues {
stringValues[i] = fmt.Sprintf("%v", v)
}
item.value = append(stringValues, list...)
})
}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
// --- Pub/Sub implementation (remains unchanged) ---
type memorySubscription struct {
store *memoryStore
channelName string