Fix basepool & 优化 repo
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -65,7 +66,7 @@ func (s *memoryStore) startGCollector() {
|
||||
}
|
||||
}
|
||||
|
||||
// --- [架构修正] 所有方法签名都增加了 context.Context 参数以匹配接口 ---
|
||||
// --- 所有方法签名都增加了 context.Context 参数以匹配接口 ---
|
||||
// --- 内存实现可以忽略该参数,用 _ 接收 ---
|
||||
|
||||
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
@@ -108,6 +109,17 @@ func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
|
||||
return ok && !item.isExpired(), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok {
|
||||
return ErrNotFound
|
||||
}
|
||||
item.expireAt = time.Now().Add(expiration)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -159,6 +171,21 @@ func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
if hash, ok := item.value.(map[string]string); ok {
|
||||
if value, exists := hash[field]; exists {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -351,6 +378,26 @@ func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error)
|
||||
return members[n], nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
unionSet := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
continue
|
||||
}
|
||||
if set, ok := item.value.(map[string]struct{}); ok {
|
||||
for member := range set {
|
||||
unionSet[member] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
destItem := &memoryStoreItem{value: unionSet}
|
||||
s.items[destination] = destItem
|
||||
return int64(len(unionSet)), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -388,6 +435,16 @@ func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string
|
||||
return list[index], nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) MSet(ctx context.Context, values map[string]any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for key, value := range values {
|
||||
// 内存存储不支持独立的 TTL,因此我们假设永不过期
|
||||
s.items[key] = &memoryStoreItem{value: value, expireAt: time.Time{}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -556,6 +613,22 @@ func (p *memoryPipeliner) Del(keys ...string) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) {
|
||||
capturedKey := key
|
||||
capturedValue := value
|
||||
p.ops = append(p.ops, func() {
|
||||
var expireAt time.Time
|
||||
if expiration > 0 {
|
||||
expireAt = time.Now().Add(expiration)
|
||||
}
|
||||
p.store.items[capturedKey] = &memoryStoreItem{
|
||||
value: capturedValue,
|
||||
expireAt: expireAt,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) SAdd(key string, members ...any) {
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
@@ -576,6 +649,7 @@ func (p *memoryPipeliner) SAdd(key string, members ...any) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) SRem(key string, members ...any) {
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
@@ -615,11 +689,125 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
item.value = append(stringValues, list...)
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
|
||||
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {}
|
||||
func (p *memoryPipeliner) ZRem(key string, members ...any) {}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
|
||||
capturedKey := key
|
||||
capturedValue := fmt.Sprintf("%v", value)
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
return
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
newList := make([]string, 0, len(list))
|
||||
removed := int64(0)
|
||||
for _, v := range list {
|
||||
if count != 0 && v == capturedValue && (count < 0 || removed < count) {
|
||||
removed++
|
||||
continue
|
||||
}
|
||||
newList = append(newList, v)
|
||||
}
|
||||
item.value = newList
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
|
||||
capturedKey := key
|
||||
capturedValues := make(map[string]any, len(values))
|
||||
for k, v := range values {
|
||||
capturedValues[k] = v
|
||||
}
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
hash, ok := item.value.(map[string]string)
|
||||
if !ok {
|
||||
hash = make(map[string]string)
|
||||
item.value = hash
|
||||
}
|
||||
for field, value := range capturedValues {
|
||||
hash[field] = fmt.Sprintf("%v", value)
|
||||
}
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
capturedKey := key
|
||||
capturedField := field
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
hash, ok := item.value.(map[string]string)
|
||||
if !ok {
|
||||
hash = make(map[string]string)
|
||||
item.value = hash
|
||||
}
|
||||
current, _ := strconv.ParseInt(hash[capturedField], 10, 64)
|
||||
hash[capturedField] = strconv.FormatInt(current+incr, 10)
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
|
||||
capturedKey := key
|
||||
capturedMembers := make(map[string]float64, len(members))
|
||||
for k, v := range members {
|
||||
capturedMembers[k] = v
|
||||
}
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]float64)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
zset, ok := item.value.(map[string]float64)
|
||||
if !ok {
|
||||
zset = make(map[string]float64)
|
||||
item.value = zset
|
||||
}
|
||||
for member, score := range capturedMembers {
|
||||
zset[member] = score
|
||||
}
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) ZRem(key string, members ...any) {
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
copy(capturedMembers, members)
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
return
|
||||
}
|
||||
zset, ok := item.value.(map[string]float64)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, member := range capturedMembers {
|
||||
delete(zset, fmt.Sprintf("%v", member))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) MSet(values map[string]any) {
|
||||
capturedValues := make(map[string]any, len(values))
|
||||
for k, v := range values {
|
||||
capturedValues[k] = v
|
||||
}
|
||||
p.ops = append(p.ops, func() {
|
||||
for key, value := range capturedValues {
|
||||
p.store.items[key] = &memoryStoreItem{
|
||||
value: value,
|
||||
expireAt: time.Time{}, // Pipelined MSet 同样假设永不过期
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type memorySubscription struct {
|
||||
store *memoryStore
|
||||
|
||||
@@ -75,10 +75,24 @@ func (s *RedisStore) Close() error {
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error {
|
||||
return s.client.Expire(ctx, key, expiration).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
|
||||
return s.client.HSet(ctx, key, values).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) {
|
||||
val, err := s.client.HGet(ctx, key, field).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||||
return s.client.HGetAll(ctx, key).Result()
|
||||
}
|
||||
@@ -111,6 +125,18 @@ func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Redis MSet 命令需要 [key1, value1, key2, value2, ...] 格式的切片
|
||||
pairs := make([]interface{}, 0, len(values)*2)
|
||||
for k, v := range values {
|
||||
pairs = append(pairs, k, v)
|
||||
}
|
||||
return s.client.MSet(ctx, pairs...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
|
||||
return s.client.SAdd(ctx, key, members...).Err()
|
||||
}
|
||||
@@ -141,6 +167,13 @@ func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error
|
||||
return member, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) {
|
||||
if len(keys) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return s.client.SUnionStore(ctx, destination, keys...).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
@@ -216,6 +249,17 @@ func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx,
|
||||
func (p *redisPipeliner) LRem(key string, count int64, value any) {
|
||||
p.pipe.LRem(p.ctx, key, count, value)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) {
|
||||
p.pipe.Set(p.ctx, key, value, expiration)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) MSet(values map[string]any) {
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
p.pipe.MSet(p.ctx, values)
|
||||
}
|
||||
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
|
||||
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
|
||||
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {
|
||||
|
||||
@@ -35,6 +35,8 @@ type Pipeliner interface {
|
||||
HIncrBy(key, field string, incr int64)
|
||||
|
||||
// SET
|
||||
MSet(values map[string]any)
|
||||
Set(key string, value []byte, expiration time.Duration)
|
||||
SAdd(key string, members ...any)
|
||||
SRem(key string, members ...any)
|
||||
|
||||
@@ -58,9 +60,11 @@ type Store interface {
|
||||
Del(ctx context.Context, keys ...string) error
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
|
||||
MSet(ctx context.Context, values map[string]any) error
|
||||
|
||||
// HASH operations
|
||||
HSet(ctx context.Context, key string, values map[string]any) error
|
||||
HGet(ctx context.Context, key, field string) (string, error)
|
||||
HGetAll(ctx context.Context, key string) (map[string]string, error)
|
||||
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
|
||||
HDel(ctx context.Context, key string, fields ...string) error
|
||||
@@ -70,6 +74,7 @@ type Store interface {
|
||||
LRem(ctx context.Context, key string, count int64, value any) error
|
||||
Rotate(ctx context.Context, key string) (string, error)
|
||||
LIndex(ctx context.Context, key string, index int64) (string, error)
|
||||
Expire(ctx context.Context, key string, expiration time.Duration) error
|
||||
|
||||
// SET operations
|
||||
SAdd(ctx context.Context, key string, members ...any) error
|
||||
@@ -77,6 +82,7 @@ type Store interface {
|
||||
SMembers(ctx context.Context, key string) ([]string, error)
|
||||
SRem(ctx context.Context, key string, members ...any) error
|
||||
SRandMember(ctx context.Context, key string) (string, error)
|
||||
SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error)
|
||||
|
||||
// Pub/Sub operations
|
||||
Publish(ctx context.Context, channel string, message []byte) error
|
||||
|
||||
Reference in New Issue
Block a user