Files
godns/internal/cache/badger_cache.go
2026-01-06 02:25:24 +08:00

248 lines
6.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package cache
import (
	"encoding/binary"
	"errors"
	"fmt"
	"path/filepath"
	"sync"
	"time"

	"github.com/dgraph-io/badger/v4"
	"github.com/dgraph-io/badger/v4/options"
	"github.com/miekg/dns"

	"godns/pkg/logger"
)
// Cache is the key/value store used to cache DNS responses.
//
// Entries are *CachedMsg values addressed by a string key; Set takes the
// TTL to apply to the entry it writes.
type Cache interface {
	// Set stores msg under key with the given ttl.
	Set(key string, msg *CachedMsg, ttl time.Duration) error
	// Get looks up key; the bool reports whether a message was returned.
	Get(key string) (*CachedMsg, bool)
	// Delete removes key from the cache.
	Delete(key string) error
	// Stats returns a textual summary of the cache's state.
	Stats() string
	// Close releases the resources held by the cache.
	Close() error
}
// CachedMsg represents a cached DNS message together with its expiration time.
//
// Msg is the DNS message to replay to clients; Expires is the instant after
// which the entry should be considered stale.
//
// NOTE: the JSON tags are not used by BadgerCache in this file. BadgerCache
// stores Expires as whole Unix seconds followed by the packed DNS wire format,
// so sub-second precision of Expires does not survive a Set/Get round trip.
type CachedMsg struct {
	Msg     *dns.Msg  `json:"msg"`
	Expires time.Time `json:"expires"`
}
// BadgerCache is a Cache backed by BadgerDB, tuned for DNS response caching
// on memory-constrained devices.
//
// A BadgerCache must be created with NewBadgerCache and released with Close.
// NewBadgerCache starts one background maintenance goroutine (value-log GC,
// disk-usage check, periodic compaction); Close stops that goroutine and waits
// for it to exit before closing the database, so no maintenance call can run
// against a closed DB.
type BadgerCache struct {
	db     *badger.DB
	logger logger.Logger

	// stop is closed by Close to ask the maintenance goroutine to exit.
	stop chan struct{}
	// closeOnce makes closing stop safe when Close is called more than once.
	closeOnce sync.Once
	// wg tracks the maintenance goroutine so Close can wait for it.
	wg sync.WaitGroup
}

// Compile-time check that BadgerCache satisfies Cache.
var _ Cache = (*BadgerCache)(nil)

// badgerExpiryPrefixLen is the size of the big-endian Unix-seconds expiry
// header that precedes the DNS wire-format payload in every stored value.
const badgerExpiryPrefixLen = 8

// Maintenance schedule and value-log GC tuning.
const (
	badgerGCInterval         = 15 * time.Minute
	badgerCompactionInterval = 1 * time.Hour
	badgerGCDiscardRatio     = 0.5
	badgerGCMaxRuns          = 10
	badgerGCPause            = 500 * time.Millisecond
)

// NewBadgerCache opens (or creates) a BadgerDB cache in
// filepath.Join(dataPath, "cache"), using settings sized for embedded devices.
// It then starts the background maintenance goroutine. Call Close to stop the
// goroutine and release the database.
//
// It returns an error wrapping badger.Open's error if the database cannot be
// opened.
func NewBadgerCache(dataPath string, log logger.Logger) (*BadgerCache, error) {
	dbPath := filepath.Join(dataPath, "cache")
	opts := badger.DefaultOptions(dbPath)

	// Tuning for Raspberry Pi-class embedded devices. The memory target is
	// roughly 32MB: 2 memtables x 4MB + 16MB block cache + 8MB index cache.

	// Memtables are 4MB each. The count is pinned to 2 explicitly because
	// DefaultOptions uses a larger NumMemtables, which would exceed the
	// budget above.
	opts.MemTableSize = 4 << 20
	opts.NumMemtables = 2
	// Value-log files of 4MB each.
	opts.ValueLogFileSize = 4 << 20
	// 16MB block cache for read performance.
	opts.BlockCacheSize = 16 << 20
	// 8MB index cache to speed up index lookups.
	opts.IndexCacheSize = 8 << 20
	// Level-0 table count and the level-0 count at which writes stall.
	opts.NumLevelZeroTables = 2
	opts.NumLevelZeroTablesStall = 4
	// No compression, to save CPU.
	opts.Compression = options.None
	// DNS responses are usually small. Values up to this size are kept in the
	// LSM tree instead of the value log, which reduces disk reads.
	opts.ValueThreshold = 512
	// Asynchronous writes for throughput. Recent entries may be lost on a
	// crash, which is acceptable for a cache.
	opts.SyncWrites = false
	// Maximum number of entries per value-log file.
	opts.ValueLogMaxEntries = 50000
	// Number of compaction workers.
	opts.NumCompactors = 2
	// No conflict detection: improves write performance.
	opts.DetectConflicts = false
	// Silence badger's internal logging.
	opts.Logger = nil

	db, err := badger.Open(opts)
	if err != nil {
		return nil, fmt.Errorf("failed to open BadgerDB: %w", err)
	}

	cache := &BadgerCache{
		db:     db,
		logger: log,
		stop:   make(chan struct{}),
	}
	cache.wg.Add(1)
	go cache.runMaintenance()
	return cache, nil
}

// Set stores msg under key and lets badger expire the entry after ttl.
//
// The stored value has two parts: an 8-byte big-endian Unix-seconds timestamp
// taken from msg.Expires, followed by the packed DNS wire-format message. This
// avoids JSON encoding overhead.
//
// It returns an error if msg or msg.Msg is nil, if the DNS message cannot be
// packed, or if the write transaction fails.
func (bc *BadgerCache) Set(key string, msg *CachedMsg, ttl time.Duration) error {
	if msg == nil || msg.Msg == nil {
		return errors.New("cannot cache nil DNS message")
	}

	dnsData, err := msg.Msg.Pack()
	if err != nil {
		return fmt.Errorf("failed to pack DNS message: %w", err)
	}

	data := make([]byte, badgerExpiryPrefixLen, badgerExpiryPrefixLen+len(dnsData))
	binary.BigEndian.PutUint64(data, uint64(msg.Expires.Unix()))
	data = append(data, dnsData...)

	return bc.db.Update(func(txn *badger.Txn) error {
		return txn.SetEntry(badger.NewEntry([]byte(key), data).WithTTL(ttl))
	})
}

// Get returns the message stored under key and true, or nil and false.
//
// A missing key yields (nil, false) silently. Any other failure is logged and
// also reported as a miss; a later Set overwrites the entry. Such failures
// include a value shorter than the expiry header and a payload that does not
// unpack as a DNS message.
func (bc *BadgerCache) Get(key string) (*CachedMsg, bool) {
	var cached *CachedMsg

	err := bc.db.View(func(txn *badger.Txn) error {
		item, err := txn.Get([]byte(key))
		if err != nil {
			return err
		}
		return item.Value(func(val []byte) error {
			if len(val) < badgerExpiryPrefixLen {
				return errors.New("invalid cache data: too short")
			}
			expires := time.Unix(int64(binary.BigEndian.Uint64(val[:badgerExpiryPrefixLen])), 0)

			msg := new(dns.Msg)
			if err := msg.Unpack(val[badgerExpiryPrefixLen:]); err != nil {
				return fmt.Errorf("failed to unpack DNS message: %w", err)
			}
			cached = &CachedMsg{Msg: msg, Expires: expires}
			return nil
		})
	})

	switch {
	case err == nil:
		return cached, true
	case errors.Is(err, badger.ErrKeyNotFound):
		return nil, false
	default:
		// Corrupt or format-incompatible data: report a miss.
		bc.logger.Printf("Cache get error for key %q: %v", key, err)
		return nil, false
	}
}

// Delete removes key from the cache in its own read-write transaction.
func (bc *BadgerCache) Delete(key string) error {
	return bc.db.Update(func(txn *badger.Txn) error {
		return txn.Delete([]byte(key))
	})
}

// Close closes the cache in three steps:
//  1. It signals the maintenance goroutine to stop.
//  2. It waits for any in-progress GC or compaction pass to finish.
//  3. It closes the database and returns the result of badger's Close.
//
// Repeated calls do not re-signal the goroutine, but each call still invokes
// badger's Close.
func (bc *BadgerCache) Close() error {
	bc.closeOnce.Do(func() {
		if bc.stop != nil {
			close(bc.stop)
		}
	})
	bc.wg.Wait()
	return bc.db.Close()
}

// runMaintenance is the single background goroutine started by
// NewBadgerCache. It runs a value-log GC pass every 15 minutes and a
// compaction pass every hour. Running them from one goroutine means GC,
// the disk-usage Flatten and the periodic Flatten never overlap. It returns
// once Close closes bc.stop.
func (bc *BadgerCache) runMaintenance() {
	defer bc.wg.Done()

	gcTicker := time.NewTicker(badgerGCInterval)
	defer gcTicker.Stop()
	compactionTicker := time.NewTicker(badgerCompactionInterval)
	defer compactionTicker.Stop()

	for {
		select {
		case <-bc.stop:
			return
		case <-gcTicker.C:
			bc.runGC()
		case <-compactionTicker.C:
			bc.runCompaction()
		}
	}
}

// stopping reports whether Close has requested shutdown.
func (bc *BadgerCache) stopping() bool {
	select {
	case <-bc.stop:
		return true
	default:
		return false
	}
}

// runGC performs one value-log GC pass. It calls RunValueLogGC (discard
// ratio 0.5) repeatedly until one of these happens:
//   - badger reports nothing left to rewrite,
//   - an error occurs,
//   - 10 rewrites have been done,
//   - shutdown is requested.
//
// It pauses 500ms between rewrites to limit CPU usage. Unless shutting down,
// it then checks the on-disk size via checkAndCleanDiskUsage.
func (bc *BadgerCache) runGC() {
	runs := 0
rewrite:
	for {
		if err := bc.db.RunValueLogGC(badgerGCDiscardRatio); err != nil {
			// ErrNoRewrite only means no value-log file was worth rewriting.
			if !errors.Is(err, badger.ErrNoRewrite) {
				bc.logger.Printf("BadgerDB GC error: %v", err)
			}
			break
		}
		runs++
		if runs >= badgerGCMaxRuns {
			bc.logger.Printf("BadgerDB GC: reached max runs limit (%d)", badgerGCMaxRuns)
			break
		}
		// Pause between rewrites, but wake immediately if Close is called.
		select {
		case <-bc.stop:
			break rewrite
		case <-time.After(badgerGCPause):
		}
	}
	if runs > 0 {
		bc.logger.Printf("BadgerDB GC: completed %d runs", runs)
	}
	if bc.stopping() {
		return
	}
	bc.checkAndCleanDiskUsage()
}

// runCompaction performs one LSM compaction pass (Flatten with one worker).
// The aim is to clean up key metadata left by expired entries. Errors are
// logged and not returned.
func (bc *BadgerCache) runCompaction() {
	if err := bc.db.Flatten(1); err != nil {
		bc.logger.Printf("BadgerDB compaction error: %v", err)
	}
}
// checkAndCleanDiskUsage enforces a disk budget on the cache.
//
// It adds up the LSM and value-log sizes reported by badger and compares the
// total with a fixed 50MB budget. When the total exceeds the budget, it logs
// the overage and forces a Flatten(2) compaction. It only compacts and does
// not delete live entries. A Flatten error is logged, not returned.
func (bc *BadgerCache) checkAndCleanDiskUsage() {
	// 50MB budget, sized for home routers and similar embedded devices.
	const limit = int64(50 << 20)

	lsmBytes, vlogBytes := bc.db.Size()
	used := lsmBytes + vlogBytes
	if used <= limit {
		return
	}

	bc.logger.Printf("Cache size %d MB exceeds limit %d MB, triggering cleanup", used>>20, limit>>20)
	// Force a compaction to shrink the on-disk footprint.
	if err := bc.db.Flatten(2); err != nil {
		bc.logger.Printf("BadgerDB flatten error: %v", err)
	}
}
// Stats reports the current LSM-tree and value-log sizes, in bytes, as a
// single human-readable line.
func (bc *BadgerCache) Stats() string {
	const layout = "LSM size: %d bytes, Value log size: %d bytes"
	lsmBytes, vlogBytes := bc.db.Size()
	return fmt.Sprintf(layout, lsmBytes, vlogBytes)
}