diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..4a28a97 --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,170 @@ +// cache/cache.go +package cache + +import ( + "crypto/sha256" + "encoding/hex" + "sync" + "time" +) + +type CacheEntry struct { + Data []byte + Headers map[string]string + CreatedAt time.Time + ExpiresAt time.Time + Size int64 +} + +type MemoryCache struct { + entries sync.Map + maxSize int64 + currentSize int64 + ttl time.Duration + mu sync.Mutex +} + +func NewMemoryCache(maxSize int64, ttl time.Duration) *MemoryCache { + mc := &MemoryCache{ + maxSize: maxSize, + ttl: ttl, + } + + // 启动清理协程 + go mc.cleanup() + + return mc +} + +func (mc *MemoryCache) Get(key string) (*CacheEntry, bool) { + val, ok := mc.entries.Load(key) + if !ok { + return nil, false + } + + entry := val.(*CacheEntry) + + // 检查是否过期 + if time.Now().After(entry.ExpiresAt) { + mc.Delete(key) + return nil, false + } + + return entry, true +} + +func (mc *MemoryCache) Set(key string, data []byte, headers map[string]string) bool { + size := int64(len(data)) + + mc.mu.Lock() + defer mc.mu.Unlock() + + // 检查是否超过最大缓存大小 + if mc.currentSize+size > mc.maxSize { + // 尝试清理过期条目 + mc.evictExpired() + + // 如果还是不够,使用 LRU 清理 + if mc.currentSize+size > mc.maxSize { + mc.evictOldest(size) + } + } + + now := time.Now() + entry := &CacheEntry{ + Data: data, + Headers: headers, + CreatedAt: now, + ExpiresAt: now.Add(mc.ttl), + Size: size, + } + + mc.entries.Store(key, entry) + mc.currentSize += size + + return true +} + +func (mc *MemoryCache) Delete(key string) { + val, ok := mc.entries.LoadAndDelete(key) + if ok { + entry := val.(*CacheEntry) + mc.mu.Lock() + mc.currentSize -= entry.Size + mc.mu.Unlock() + } +} + +func (mc *MemoryCache) GenerateKey(url string) string { + hash := sha256.Sum256([]byte(url)) + return hex.EncodeToString(hash[:]) +} + +func (mc *MemoryCache) evictExpired() { + now := time.Now() + mc.entries.Range(func(key, value interface{}) bool { + entry := value.(*CacheEntry) + if now.After(entry.ExpiresAt) { + mc.entries.Delete(key) + mc.currentSize -= entry.Size + } + return true + }) +} + +func (mc *MemoryCache) evictOldest(needed int64) { + type entryWithKey struct { + key string + entry *CacheEntry + } + + var entries []entryWithKey + + mc.entries.Range(func(key, value interface{}) bool { + entries = append(entries, entryWithKey{ + key: key.(string), + entry: value.(*CacheEntry), + }) + return true + }) + + // 按创建时间排序 + for i := 0; i < len(entries)-1; i++ { + for j := i + 1; j < len(entries); j++ { + if entries[i].entry.CreatedAt.After(entries[j].entry.CreatedAt) { + entries[i], entries[j] = entries[j], entries[i] + } + } + } + + // 删除最旧的条目直到有足够空间 + freed := int64(0) + for _, e := range entries { + if freed >= needed { + break + } + mc.entries.Delete(e.key) + mc.currentSize -= e.entry.Size + freed += e.entry.Size + } +} + +func (mc *MemoryCache) cleanup() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + mc.mu.Lock() + mc.evictExpired() + mc.mu.Unlock() + } +} + +func (mc *MemoryCache) Stats() (entries int, size int64) { + count := 0 + mc.entries.Range(func(key, value interface{}) bool { + count++ + return true + }) + return count, mc.currentSize +}