before update store
This commit is contained in:
@@ -24,7 +24,7 @@ type SystemSettings struct {
|
||||
BaseKeyCheckIntervalMinutes int `json:"base_key_check_interval_minutes" default:"1440" name:"全局Key检查周期(分钟)" category:"健康检查" desc:"对所有ACTIVE状态的Key进行身份检查的周期,建议设置较长时间,例如1天(1440分钟)。"`
|
||||
BaseKeyCheckConcurrency int `json:"base_key_check_concurrency" default:"5" name:"全局Key检查并发数" category:"健康检查" desc:"执行全局Key身份检查时的并发请求数量。"`
|
||||
BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"`
|
||||
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-1.5-flash" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`
|
||||
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-2.0-flash-lite" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`
|
||||
|
||||
EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"`
|
||||
UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"`
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"io"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -82,20 +81,13 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
|
||||
|
||||
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
|
||||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||||
protocol := "default"
|
||||
if pool.Protocol != "" {
|
||||
protocol = string(pool.Protocol)
|
||||
}
|
||||
// 生成唯一的池ID,确保不同请求组合的轮询状态相互隔离
|
||||
poolID := generatePoolID(pool.CandidateGroups, protocol)
|
||||
log := r.logger.WithField("pool_id", poolID).WithField("protocol", protocol)
|
||||
poolID := generatePoolID(pool.CandidateGroups)
|
||||
log := r.logger.WithField("pool_id", poolID)
|
||||
|
||||
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
|
||||
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, err
|
||||
}
|
||||
return nil, nil, fmt.Errorf("unexpected error while ensuring base pool cache: %w", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var keyIDStr string
|
||||
@@ -154,65 +146,78 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
|
||||
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
|
||||
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
|
||||
listKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
|
||||
// --- [逻辑优化] 提前处理“毒丸”,让逻辑更清晰 ---
|
||||
exists, err := r.store.Exists(listKey)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to check existence of basepool key: %s", listKey)
|
||||
return err
|
||||
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
|
||||
return err // 直接返回读取错误
|
||||
}
|
||||
if exists {
|
||||
val, err := r.store.LIndex(listKey, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
// 如果连 LIndex 都失败,说明缓存可能已损坏,允许重建
|
||||
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
|
||||
} else {
|
||||
if val == EmptyPoolPlaceholder {
|
||||
return gorm.ErrRecordNotFound // 已知为空,直接返回
|
||||
}
|
||||
return nil // 缓存有效,直接返回
|
||||
}
|
||||
if val == EmptyPoolPlaceholder {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// --- [锁机制优化] 增加分布式锁,防止并发构建时的惊群效应 ---
|
||||
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
|
||||
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second)
|
||||
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10秒锁超时
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to acquire distributed lock for basepool build: %s", lockKey)
|
||||
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
|
||||
return err
|
||||
}
|
||||
if !acquired {
|
||||
// 未获取到锁,等待一小段时间后重试,让持有锁的协程完成构建
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return r.ensureBasePoolCacheExists(pool, poolID)
|
||||
}
|
||||
defer r.store.Del(lockKey)
|
||||
defer r.store.Del(lockKey) // 确保在函数退出时释放锁
|
||||
// 双重检查,防止在获取锁的间隙,已有其他协程完成了构建
|
||||
if exists, _ := r.store.Exists(listKey); exists {
|
||||
return nil
|
||||
}
|
||||
r.logger.Infof("BasePool cache for pool_id '%s' not found. Building now...", poolID)
|
||||
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
|
||||
var allActiveKeyIDs []string
|
||||
lruMembers := make(map[string]float64)
|
||||
|
||||
for _, group := range pool.CandidateGroups {
|
||||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||||
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
|
||||
|
||||
// --- [核心修正] ---
|
||||
// 这是整个问题的根源。我们绝不能在读取失败时,默默地`continue`。
|
||||
// 任何读取源数据的失败,都必须被视为一次构建过程的彻底失败,并立即中止。
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to get active keys for group %d during BasePool build", group.ID)
|
||||
continue
|
||||
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
|
||||
// 返回这个瞬时错误。这会导致本次请求失败,但绝不会写入“毒丸”,
|
||||
// 从而给了下一次请求一个全新的、成功的机会。
|
||||
return err
|
||||
}
|
||||
// 只有在 SMembers 成功时,才继续处理
|
||||
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
|
||||
for _, keyIDStr := range groupKeyIDs {
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
|
||||
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) || strings.Contains(err.Error(), "failed to get") {
|
||||
r.logger.WithError(err).Warnf("Cache inconsistency detected for KeyID %s in GroupID %d. Skipping.", keyIDStr, group.ID)
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
if err == nil && mapping != nil {
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
}
|
||||
allActiveKeyIDs = append(allActiveKeyIDs, keyIDStr)
|
||||
if mapping != nil && mapping.LastUsedAt != nil {
|
||||
lruMembers[keyIDStr] = float64(mapping.LastUsedAt.UnixMilli())
|
||||
lruMembers[keyIDStr] = score
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- [逻辑修正] ---
|
||||
// 只有在“我们成功读取了所有数据,但发现数据本身是空的”这种情况下,
|
||||
// 才允许写入“毒丸”。
|
||||
if len(allActiveKeyIDs) == 0 {
|
||||
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
|
||||
pipe := r.store.Pipeline()
|
||||
pipe.LPush(listKey, EmptyPoolPlaceholder)
|
||||
pipe.Expire(listKey, EmptyCacheTTL)
|
||||
@@ -221,16 +226,23 @@ func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID str
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
// 使用管道填充所有轮询结构
|
||||
pipe := r.store.Pipeline()
|
||||
// 1. 顺序
|
||||
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
// 2. 随机
|
||||
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
|
||||
// 设置合理的过期时间,例如5分钟,以防止孤儿数据
|
||||
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
|
||||
|
||||
if err := pipe.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(lruMembers) > 0 {
|
||||
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
|
||||
}
|
||||
@@ -246,7 +258,7 @@ func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID
|
||||
}
|
||||
|
||||
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
|
||||
func generatePoolID(groups []*models.KeyGroup, protocol string) string {
|
||||
func generatePoolID(groups []*models.KeyGroup) string {
|
||||
ids := make([]int, len(groups))
|
||||
for i, g := range groups {
|
||||
ids[i] = int(g.ID)
|
||||
@@ -254,7 +266,7 @@ func generatePoolID(groups []*models.KeyGroup, protocol string) string {
|
||||
sort.Ints(ids)
|
||||
|
||||
h := sha1.New()
|
||||
io.WriteString(h, fmt.Sprintf("protocol:%s;groups:%v", protocol, ids))
|
||||
io.WriteString(h, fmt.Sprintf("%v", ids))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Filename: internal/store/memory_store.go (统一存储重构版)
|
||||
// Filename: internal/store/memory_store.go (经同行审查后最终修复版)
|
||||
|
||||
package store
|
||||
|
||||
@@ -12,76 +12,144 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ensure memoryStore implements Store interface
|
||||
var _ Store = (*memoryStore)(nil)
|
||||
|
||||
// [核心重构] memoryStoreItem 现在是通用容器,可以存储任何类型的值,并自带过期时间
|
||||
type memoryStoreItem struct {
|
||||
value interface{} // 可以是 []byte, []string, map[string]string, map[string]struct{}, []zsetMember
|
||||
value interface{}
|
||||
expireAt time.Time
|
||||
}
|
||||
|
||||
// isExpired 检查一个条目是否已过期
|
||||
func (item *memoryStoreItem) isExpired() bool {
|
||||
return !item.expireAt.IsZero() && time.Now().After(item.expireAt)
|
||||
}
|
||||
|
||||
// zsetMember 保持不变
|
||||
type zsetMember struct {
|
||||
Value string
|
||||
Score float64
|
||||
}
|
||||
|
||||
// [核心重构] memoryStore 现在使用一个统一的 map 来存储所有数据
|
||||
type memoryStore struct {
|
||||
items map[string]*memoryStoreItem // 指向 item 的指针,以便原地修改
|
||||
items map[string]*memoryStoreItem
|
||||
pubsub map[string][]chan *Message
|
||||
mu sync.RWMutex
|
||||
// [USER SUGGESTION APPLIED] 使用带锁的随机数源以保证并发安全
|
||||
rng *rand.Rand
|
||||
rngMu sync.Mutex
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewMemoryStore [核心重構] 構造函數也被簡化了
|
||||
func NewMemoryStore(logger *logrus.Logger) Store {
|
||||
return &memoryStore{
|
||||
store := &memoryStore{
|
||||
items: make(map[string]*memoryStoreItem),
|
||||
pubsub: make(map[string][]chan *Message),
|
||||
// 使用当前时间作为种子,创建一个新的随机数源
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
logger: logger.WithField("component", "store.memory 🗱"),
|
||||
}
|
||||
go store.startGCollector()
|
||||
return store
|
||||
}
|
||||
|
||||
// [核心重构] getItem 是一个新的内部辅助函数,它封装了获取、检查过期和删除的通用逻辑
|
||||
func (s *memoryStore) getItem(key string, lockForWrite bool) *memoryStoreItem {
|
||||
if !lockForWrite {
|
||||
// 如果是读操作,先用读锁检查
|
||||
s.mu.RLock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
s.mu.RUnlock()
|
||||
// 如果不存在或已过期,需要尝试获取写锁来删除它
|
||||
if ok { // 只有在确定 item 存在但已过期时才需要删除
|
||||
s.mu.Lock()
|
||||
// 再次检查,防止在获取写锁期间状态已改变
|
||||
if item, ok := s.items[key]; ok && item.isExpired() {
|
||||
delete(s.items, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
// [USER SUGGESTION INCORPORATED] Fix #1: 使用 now := time.Now() 进行原子性检查
|
||||
func (s *memoryStore) startGCollector() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
now := time.Now() // 避免在循环中重复调用
|
||||
for key, item := range s.items {
|
||||
if !item.expireAt.IsZero() && now.After(item.expireAt) {
|
||||
delete(s.items, key)
|
||||
}
|
||||
return nil // 无论如何都返回 nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
return item
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// 对于写操作,直接使用写锁
|
||||
item, ok := s.items[key]
|
||||
if ok && item.isExpired() {
|
||||
delete(s.items, key)
|
||||
return nil
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
// --- 所有接口方法现在都基于新的统一结构重写 ---
|
||||
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: 修复了致命的nil检查和类型断言问题
|
||||
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
mainItem, mainOk := s.items[mainKey]
|
||||
var mainSet map[string]struct{}
|
||||
|
||||
if mainOk && !mainItem.isExpired() {
|
||||
// 安全地进行类型断言
|
||||
mainSet, mainOk = mainItem.value.(map[string]struct{})
|
||||
// 确保断言成功且集合不为空
|
||||
mainOk = mainOk && len(mainSet) > 0
|
||||
} else {
|
||||
mainOk = false
|
||||
}
|
||||
|
||||
if !mainOk {
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
// 安全地进行类型断言
|
||||
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
|
||||
if !cooldownSetOk || len(cooldownSet) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
s.items[mainKey] = cooldownItem
|
||||
delete(s.items, cooldownKey)
|
||||
mainSet = cooldownSet
|
||||
}
|
||||
|
||||
var popped string
|
||||
for k := range mainSet {
|
||||
popped = k
|
||||
break
|
||||
}
|
||||
delete(mainSet, popped)
|
||||
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
s.items[cooldownKey] = cooldownItem
|
||||
}
|
||||
// 安全地处理冷却池
|
||||
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
|
||||
if !ok {
|
||||
cooldownSet = make(map[string]struct{})
|
||||
cooldownItem.value = cooldownSet
|
||||
}
|
||||
cooldownSet[popped] = struct{}{}
|
||||
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
// SRandMember [并发修复版] 使用带锁的rng
|
||||
func (s *memoryStore) SRandMember(key string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
s.rngMu.Lock()
|
||||
n := s.rng.Intn(len(members))
|
||||
s.rngMu.Unlock()
|
||||
|
||||
return members[n], nil
|
||||
}
|
||||
|
||||
// --- 以下是其余函数的最终版本,它们都遵循了安全、原子的锁策略 ---
|
||||
|
||||
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
s.mu.Lock()
|
||||
@@ -95,14 +163,16 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
}
|
||||
|
||||
func (s *memoryStore) Get(key string) ([]byte, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if value, ok := item.value.([]byte); ok {
|
||||
return value, nil
|
||||
}
|
||||
return nil, ErrNotFound // Type mismatch, treat as not found
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
func (s *memoryStore) Del(keys ...string) error {
|
||||
@@ -114,20 +184,18 @@ func (s *memoryStore) Del(keys ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) delNoLock(keys ...string) {
|
||||
for _, key := range keys {
|
||||
delete(s.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memoryStore) Exists(key string) (bool, error) {
|
||||
return s.getItem(key, false) != nil, nil
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
return ok && !item.isExpired(), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if item := s.getItem(key, true); item != nil {
|
||||
item, ok := s.items[key]
|
||||
if ok && !item.isExpired() {
|
||||
return false, nil
|
||||
}
|
||||
var expireAt time.Time
|
||||
@@ -143,8 +211,8 @@ func (s *memoryStore) Close() error { return nil }
|
||||
func (s *memoryStore) HDel(key string, fields ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return nil
|
||||
}
|
||||
if hash, ok := item.value.(map[string]string); ok {
|
||||
@@ -158,29 +226,27 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
|
||||
func (s *memoryStore) HSet(key string, values map[string]any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.hSetNoLock(key, values)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) hSetNoLock(key string, values map[string]any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
s.items[key] = item
|
||||
}
|
||||
hash, ok := item.value.(map[string]string)
|
||||
if !ok { // If key exists but is not a hash, create a new hash
|
||||
if !ok {
|
||||
hash = make(map[string]string)
|
||||
item.value = hash
|
||||
}
|
||||
for field, value := range values {
|
||||
hash[field] = fmt.Sprintf("%v", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
if hash, ok := item.value.(map[string]string); ok {
|
||||
@@ -196,12 +262,8 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
|
||||
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.hIncrByNoLock(key, field, incr)
|
||||
}
|
||||
|
||||
func (s *memoryStore) hIncrByNoLock(key, field string, incr int64) (int64, error) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
s.items[key] = item
|
||||
}
|
||||
@@ -222,13 +284,8 @@ func (s *memoryStore) hIncrByNoLock(key, field string, incr int64) (int64, error
|
||||
func (s *memoryStore) LPush(key string, values ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.lPushNoLock(key, values...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) lPushNoLock(key string, values ...any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make([]string, 0)}
|
||||
s.items[key] = item
|
||||
}
|
||||
@@ -241,22 +298,19 @@ func (s *memoryStore) lPushNoLock(key string, values ...any) {
|
||||
stringValues[i] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
item.value = append(stringValues, list...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LRem(key string, count int64, value any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.lRemNoLock(key, count, value)
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) lRemNoLock(key string, count int64, value any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
return
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return nil
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
valToRemove := fmt.Sprintf("%v", value)
|
||||
newList := make([]string, 0, len(list))
|
||||
@@ -269,17 +323,14 @@ func (s *memoryStore) lRemNoLock(key string, count int64, value any) {
|
||||
}
|
||||
}
|
||||
item.value = newList
|
||||
}
|
||||
func (s *memoryStore) SAdd(key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sAddNoLock(key, members...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) sAddNoLock(key string, members ...any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
func (s *memoryStore) SAdd(key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
s.items[key] = item
|
||||
}
|
||||
@@ -291,12 +342,14 @@ func (s *memoryStore) sAddNoLock(key string, members ...any) {
|
||||
for _, member := range members {
|
||||
set[fmt.Sprintf("%v", member)] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return []string{}, nil
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
@@ -311,7 +364,9 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
for k := range set {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
rand.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] })
|
||||
s.rngMu.Lock()
|
||||
s.rng.Shuffle(len(keys), func(i, j int) { keys[i], keys[j] = keys[j], keys[i] })
|
||||
s.rngMu.Unlock()
|
||||
for i := int64(0); i < count; i++ {
|
||||
poppedKey := keys[i]
|
||||
popped = append(popped, poppedKey)
|
||||
@@ -319,68 +374,47 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
}
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SMembers(key string) ([]string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return []string{}, nil
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok {
|
||||
return []string{}, nil
|
||||
}
|
||||
s.mu.RLock() // Lock needed for iterating map
|
||||
defer s.mu.RUnlock()
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SRem(key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sRemNoLock(key, members...)
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) sRemNoLock(key string, members ...any) {
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
return
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return nil
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
for _, member := range members {
|
||||
delete(set, fmt.Sprintf("%v", member))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SRandMember(key string) (string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return members[rand.Intn(len(members))], nil
|
||||
}
|
||||
func (s *memoryStore) Rotate(key string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
@@ -391,17 +425,18 @@ func (s *memoryStore) Rotate(key string) (string, error) {
|
||||
item.value = append([]string{val}, list[:len(list)-1]...)
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
l := int64(len(list))
|
||||
if index < 0 {
|
||||
index += l
|
||||
@@ -416,8 +451,8 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
|
||||
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make([]zsetMember, 0)}
|
||||
s.items[key] = item
|
||||
}
|
||||
@@ -425,39 +460,39 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
if !ok {
|
||||
zset = make([]zsetMember, 0)
|
||||
}
|
||||
for memberVal, score := range members {
|
||||
found := false
|
||||
for i := range zset {
|
||||
if zset[i].Value == memberVal {
|
||||
zset[i].Score = score
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
zset = append(zset, zsetMember{Value: memberVal, Score: score})
|
||||
}
|
||||
membersMap := make(map[string]float64, len(zset))
|
||||
for _, z := range zset {
|
||||
membersMap[z.Value] = z.Score
|
||||
}
|
||||
sort.Slice(zset, func(i, j int) bool {
|
||||
if zset[i].Score == zset[j].Score {
|
||||
return zset[i].Value < zset[j].Value
|
||||
for memberVal, score := range members {
|
||||
membersMap[memberVal] = score
|
||||
}
|
||||
newZSet := make([]zsetMember, 0, len(membersMap))
|
||||
for val, score := range membersMap {
|
||||
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
|
||||
}
|
||||
// NOTE: This ZSet implementation is simple but not performant for large sets.
|
||||
// A production implementation would use a skip list or a balanced tree.
|
||||
sort.Slice(newZSet, func(i, j int) bool {
|
||||
if newZSet[i].Score == newZSet[j].Score {
|
||||
return newZSet[i].Value < newZSet[j].Value
|
||||
}
|
||||
return zset[i].Score < zset[j].Score
|
||||
return newZSet[i].Score < newZSet[j].Score
|
||||
})
|
||||
item.value = zset
|
||||
item.value = newZSet
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
item := s.getItem(key, false)
|
||||
if item == nil {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return []string{}, nil
|
||||
}
|
||||
zset, ok := item.value.([]zsetMember)
|
||||
if !ok {
|
||||
return []string{}, nil
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
l := int64(len(zset))
|
||||
if start < 0 {
|
||||
start += l
|
||||
@@ -483,8 +518,8 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
func (s *memoryStore) ZRem(key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item := s.getItem(key, true)
|
||||
if item == nil {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return nil
|
||||
}
|
||||
zset, ok := item.value.([]zsetMember)
|
||||
@@ -504,32 +539,6 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
|
||||
item.value = newZSet
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
mainItem := s.getItem(mainKey, true)
|
||||
if mainItem == nil || len(mainItem.value.(map[string]struct{})) == 0 {
|
||||
cooldownItem := s.getItem(cooldownKey, true)
|
||||
if cooldownItem == nil {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
// "Rename" by moving value and deleting old key
|
||||
s.items[mainKey] = cooldownItem
|
||||
delete(s.items, cooldownKey)
|
||||
mainItem = cooldownItem
|
||||
}
|
||||
mainSet, ok := mainItem.value.(map[string]struct{})
|
||||
if !ok || len(mainSet) == 0 {
|
||||
return "", ErrNotFound // Should not happen after cycle logic
|
||||
}
|
||||
var popped string
|
||||
for k := range mainSet {
|
||||
popped = k
|
||||
break
|
||||
}
|
||||
delete(mainSet, popped)
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
// Pipeline implementation
|
||||
type memoryPipeliner struct {
|
||||
@@ -549,41 +558,90 @@ func (p *memoryPipeliner) Exec() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// [核心修正] Expire 现在可以正确地为任何 key 设置过期时间
|
||||
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
|
||||
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
|
||||
capturedKey := key
|
||||
p.ops = append(p.ops, func() {
|
||||
// This must be called within Exec's lock
|
||||
item := p.store.getItem(key, true)
|
||||
if item != nil {
|
||||
if item, ok := p.store.items[capturedKey]; ok {
|
||||
item.expireAt = time.Now().Add(expiration)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// All other pipeliner methods...
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
|
||||
p.ops = append(p.ops, func() { p.store.hSetNoLock(key, values) })
|
||||
}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
p.ops = append(p.ops, func() { p.store.hIncrByNoLock(key, field, incr) })
|
||||
}
|
||||
func (p *memoryPipeliner) Del(keys ...string) {
|
||||
p.ops = append(p.ops, func() { p.store.delNoLock(keys...) })
|
||||
capturedKeys := make([]string, len(keys))
|
||||
copy(capturedKeys, keys)
|
||||
p.ops = append(p.ops, func() {
|
||||
for _, key := range capturedKeys {
|
||||
delete(p.store.items, key)
|
||||
}
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) SAdd(key string, members ...any) {
|
||||
p.ops = append(p.ops, func() { p.store.sAddNoLock(key, members...) })
|
||||
}
|
||||
func (p *memoryPipeliner) SRem(key string, members ...any) {
|
||||
p.ops = append(p.ops, func() { p.store.sRemNoLock(key, members...) })
|
||||
}
|
||||
func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
p.ops = append(p.ops, func() { p.store.lPushNoLock(key, values...) })
|
||||
}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
|
||||
p.ops = append(p.ops, func() { p.store.lRemNoLock(key, count, value) })
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
copy(capturedMembers, members)
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok {
|
||||
set = make(map[string]struct{})
|
||||
item.value = set
|
||||
}
|
||||
for _, member := range capturedMembers {
|
||||
set[fmt.Sprintf("%v", member)] = struct{}{}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Pub/Sub implementation (remains unchanged as it's a separate system)
|
||||
func (p *memoryPipeliner) SRem(key string, members ...any) {
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
copy(capturedMembers, members)
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
return
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, member := range capturedMembers {
|
||||
delete(set, fmt.Sprintf("%v", member))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
capturedKey := key
|
||||
capturedValues := make([]any, len(values))
|
||||
copy(capturedValues, values)
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make([]string, 0)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok {
|
||||
list = make([]string, 0)
|
||||
}
|
||||
stringValues := make([]string, len(capturedValues))
|
||||
for i, v := range capturedValues {
|
||||
stringValues[i] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
item.value = append(stringValues, list...)
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
|
||||
|
||||
// --- Pub/Sub implementation (remains unchanged) ---
|
||||
type memorySubscription struct {
|
||||
store *memoryStore
|
||||
channelName string
|
||||
|
||||
Reference in New Issue
Block a user