Fix basepool & 优化 repo

This commit is contained in:
XOF
2025-11-23 22:42:58 +08:00
parent 2b0b9b67dc
commit 6c7283d51b
16 changed files with 1312 additions and 723 deletions

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"math/rand"
"sort"
"strconv"
"sync"
"time"
@@ -65,7 +66,7 @@ func (s *memoryStore) startGCollector() {
}
}
// --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 ---
// --- 所有方法签名都增加了 context.Context 参数以匹配接口 ---
// --- 内存实现可以忽略该参数,用 _ 接收 ---
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
@@ -108,6 +109,17 @@ func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
return ok && !item.isExpired(), nil
}
func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
if !ok {
return ErrNotFound
}
item.expireAt = time.Now().Add(expiration)
return nil
}
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -159,6 +171,21 @@ func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any)
return nil
}
func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
if hash, ok := item.value.(map[string]string); ok {
if value, exists := hash[field]; exists {
return value, nil
}
}
return "", ErrNotFound
}
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -351,6 +378,26 @@ func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error)
return members[n], nil
}
func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
unionSet := make(map[string]struct{})
for _, key := range keys {
item, ok := s.items[key]
if !ok || item.isExpired() {
continue
}
if set, ok := item.value.(map[string]struct{}); ok {
for member := range set {
unionSet[member] = struct{}{}
}
}
}
destItem := &memoryStoreItem{value: unionSet}
s.items[destination] = destItem
return int64(len(unionSet)), nil
}
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -388,6 +435,16 @@ func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string
return list[index], nil
}
func (s *memoryStore) MSet(ctx context.Context, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
for key, value := range values {
// 内存存储不支持独立的 TTL因此我们假设永不过期
s.items[key] = &memoryStoreItem{value: value, expireAt: time.Time{}}
}
return nil
}
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -556,6 +613,22 @@ func (p *memoryPipeliner) Del(keys ...string) {
}
})
}
func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) {
capturedKey := key
capturedValue := value
p.ops = append(p.ops, func() {
var expireAt time.Time
if expiration > 0 {
expireAt = time.Now().Add(expiration)
}
p.store.items[capturedKey] = &memoryStoreItem{
value: capturedValue,
expireAt: expireAt,
}
})
}
func (p *memoryPipeliner) SAdd(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -576,6 +649,7 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) SRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -615,11 +689,125 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
item.value = append(stringValues, list...)
})
}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {}
func (p *memoryPipeliner) ZRem(key string, members ...any) {}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
capturedKey := key
capturedValue := fmt.Sprintf("%v", value)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
list, ok := item.value.([]string)
if !ok {
return
}
newList := make([]string, 0, len(list))
removed := int64(0)
for _, v := range list {
if count != 0 && v == capturedValue && (count < 0 || removed < count) {
removed++
continue
}
newList = append(newList, v)
}
item.value = newList
})
}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
capturedKey := key
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
for field, value := range capturedValues {
hash[field] = fmt.Sprintf("%v", value)
}
})
}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
capturedKey := key
capturedField := field
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
current, _ := strconv.ParseInt(hash[capturedField], 10, 64)
hash[capturedField] = strconv.FormatInt(current+incr, 10)
})
}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
capturedKey := key
capturedMembers := make(map[string]float64, len(members))
for k, v := range members {
capturedMembers[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]float64)}
p.store.items[capturedKey] = item
}
zset, ok := item.value.(map[string]float64)
if !ok {
zset = make(map[string]float64)
item.value = zset
}
for member, score := range capturedMembers {
zset[member] = score
}
})
}
func (p *memoryPipeliner) ZRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
copy(capturedMembers, members)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
zset, ok := item.value.(map[string]float64)
if !ok {
return
}
for _, member := range capturedMembers {
delete(zset, fmt.Sprintf("%v", member))
}
})
}
func (p *memoryPipeliner) MSet(values map[string]any) {
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
for key, value := range capturedValues {
p.store.items[key] = &memoryStoreItem{
value: value,
expireAt: time.Time{}, // Pipelined MSet 同样假设永不过期
}
}
})
}
type memorySubscription struct {
store *memoryStore

View File

@@ -75,10 +75,24 @@ func (s *RedisStore) Close() error {
return s.client.Close()
}
func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error {
return s.client.Expire(ctx, key, expiration).Err()
}
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
return s.client.HSet(ctx, key, values).Err()
}
func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) {
val, err := s.client.HGet(ctx, key, field).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
return "", err
}
return val, nil
}
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
return s.client.HGetAll(ctx, key).Result()
}
@@ -111,6 +125,18 @@ func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
return val, nil
}
func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error {
if len(values) == 0 {
return nil
}
// Redis MSet 命令需要 [key1, value1, key2, value2, ...] 格式的切片
pairs := make([]interface{}, 0, len(values)*2)
for k, v := range values {
pairs = append(pairs, k, v)
}
return s.client.MSet(ctx, pairs...).Err()
}
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
return s.client.SAdd(ctx, key, members...).Err()
}
@@ -141,6 +167,13 @@ func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error
return member, nil
}
func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) {
if len(keys) == 0 {
return 0, nil
}
return s.client.SUnionStore(ctx, destination, keys...).Result()
}
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
if len(members) == 0 {
return nil
@@ -216,6 +249,17 @@ func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx,
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(p.ctx, key, count, value)
}
func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) {
p.pipe.Set(p.ctx, key, value, expiration)
}
func (p *redisPipeliner) MSet(values map[string]any) {
if len(values) == 0 {
return
}
p.pipe.MSet(p.ctx, values)
}
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {

View File

@@ -35,6 +35,8 @@ type Pipeliner interface {
HIncrBy(key, field string, incr int64)
// SET
MSet(values map[string]any)
Set(key string, value []byte, expiration time.Duration)
SAdd(key string, members ...any)
SRem(key string, members ...any)
@@ -58,9 +60,11 @@ type Store interface {
Del(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
MSet(ctx context.Context, values map[string]any) error
// HASH operations
HSet(ctx context.Context, key string, values map[string]any) error
HGet(ctx context.Context, key, field string) (string, error)
HGetAll(ctx context.Context, key string) (map[string]string, error)
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
HDel(ctx context.Context, key string, fields ...string) error
@@ -70,6 +74,7 @@ type Store interface {
LRem(ctx context.Context, key string, count int64, value any) error
Rotate(ctx context.Context, key string) (string, error)
LIndex(ctx context.Context, key string, index int64) (string, error)
Expire(ctx context.Context, key string, expiration time.Duration) error
// SET operations
SAdd(ctx context.Context, key string, members ...any) error
@@ -77,6 +82,7 @@ type Store interface {
SMembers(ctx context.Context, key string) ([]string, error)
SRem(ctx context.Context, key string, members ...any) error
SRandMember(ctx context.Context, key string) (string, error)
SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error)
// Pub/Sub operations
Publish(ctx context.Context, channel string, message []byte) error