New
This commit is contained in:
197
internal/service/analytics_service.go
Normal file
197
internal/service/analytics_service.go
Normal file
@@ -0,0 +1,197 @@
|
||||
// Filename: internal/service/analytics_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/db/dialect"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
flushLoopInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
type AnalyticsServiceLogger struct{ *logrus.Entry }
|
||||
|
||||
type AnalyticsService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
dialect dialect.DialectAdapter
|
||||
}
|
||||
|
||||
func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d dialect.DialectAdapter) *AnalyticsService {
|
||||
return &AnalyticsService{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "Analytics📊"),
|
||||
stopChan: make(chan struct{}),
|
||||
dialect: d,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) Start() {
|
||||
s.wg.Add(2) // 2 (flushLoop, eventListener)
|
||||
go s.flushLoop()
|
||||
go s.eventListener()
|
||||
s.logger.Info("AnalyticsService (Command Side) started.")
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) Stop() {
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
|
||||
s.flushToDB() // 停止前刷盘
|
||||
s.logger.Info("AnalyticsService final data flush completed.")
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) eventListener() {
|
||||
defer s.wg.Done()
|
||||
sub, err := s.store.Subscribe(models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
return
|
||||
}
|
||||
defer sub.Close()
|
||||
s.logger.Info("AnalyticsService subscribed to request events.")
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal analytics event: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleAnalyticsEvent(&event)
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("AnalyticsService stopping event listener.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) {
|
||||
if event.GroupID == 0 {
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
|
||||
fieldPrefix := fmt.Sprintf("%d:%s", event.GroupID, event.ModelName)
|
||||
|
||||
pipe := s.store.Pipeline()
|
||||
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
|
||||
if event.IsSuccess {
|
||||
pipe.HIncrBy(key, fieldPrefix+":success", 1)
|
||||
}
|
||||
if event.PromptTokens > 0 {
|
||||
pipe.HIncrBy(key, fieldPrefix+":prompt", int64(event.PromptTokens))
|
||||
}
|
||||
if event.CompletionTokens > 0 {
|
||||
pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.CompletionTokens))
|
||||
}
|
||||
|
||||
if err := pipe.Exec(); err != nil {
|
||||
s.logger.Warnf("[%s] Failed to record analytics event to store for group %d: %v", event.CorrelationID, event.GroupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) flushLoop() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(flushLoopInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.flushToDB()
|
||||
case <-s.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) flushToDB() {
|
||||
now := time.Now().UTC()
|
||||
keysToFlush := []string{
|
||||
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
|
||||
fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")),
|
||||
}
|
||||
|
||||
for _, key := range keysToFlush {
|
||||
data, err := s.store.HGetAll(key)
|
||||
if err != nil || len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
statsToFlush, parsedFields := s.parseStatsFromHash(now.Truncate(time.Hour), data)
|
||||
|
||||
if len(statsToFlush) > 0 {
|
||||
upsertClause := s.dialect.OnConflictUpdateAll(
|
||||
[]string{"time", "group_id", "model_name"}, // conflict columns
|
||||
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns
|
||||
)
|
||||
err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
|
||||
} else {
|
||||
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
|
||||
_ = s.store.HDel(key, parsedFields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
|
||||
tempAggregator := make(map[string]*models.StatsHourly)
|
||||
var parsedFields []string
|
||||
for field, valueStr := range data {
|
||||
parts := strings.Split(field, ":")
|
||||
if len(parts) != 3 {
|
||||
continue
|
||||
}
|
||||
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
|
||||
|
||||
aggKey := groupIDStr + ":" + modelName
|
||||
if _, ok := tempAggregator[aggKey]; !ok {
|
||||
gid, err := strconv.Atoi(groupIDStr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
tempAggregator[aggKey] = &models.StatsHourly{
|
||||
Time: t,
|
||||
GroupID: uint(gid),
|
||||
ModelName: modelName,
|
||||
}
|
||||
}
|
||||
val, _ := strconv.ParseInt(valueStr, 10, 64)
|
||||
switch counterType {
|
||||
case "requests":
|
||||
tempAggregator[aggKey].RequestCount = val
|
||||
case "success":
|
||||
tempAggregator[aggKey].SuccessCount = val
|
||||
case "prompt":
|
||||
tempAggregator[aggKey].PromptTokens = val
|
||||
case "completion":
|
||||
tempAggregator[aggKey].CompletionTokens = val
|
||||
}
|
||||
parsedFields = append(parsedFields, field)
|
||||
}
|
||||
var result []models.StatsHourly
|
||||
for _, stats := range tempAggregator {
|
||||
if stats.RequestCount > 0 {
|
||||
result = append(result, *stats)
|
||||
}
|
||||
}
|
||||
return result, parsedFields
|
||||
}
|
||||
857
internal/service/apikey_service.go
Normal file
857
internal/service/apikey_service.go
Normal file
@@ -0,0 +1,857 @@
|
||||
// Filename: internal/service/apikey_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/channel"
|
||||
CustomErrors "gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/task"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
TaskTypeRestoreAllBannedInGroup = "restore_all_banned_in_group"
|
||||
TaskTypeRestoreSpecificKeys = "restore_specific_keys_in_group"
|
||||
TaskTypeUpdateStatusByFilter = "update_status_by_filter"
|
||||
)
|
||||
|
||||
// DTOs & Constants
|
||||
const (
|
||||
TestEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
)
|
||||
|
||||
type BatchRestoreResult struct {
|
||||
RestoredCount int `json:"restored_count"`
|
||||
SkippedCount int `json:"skipped_count"`
|
||||
SkippedKeys []SkippedKeyInfo `json:"skipped_keys"`
|
||||
}
|
||||
|
||||
type SkippedKeyInfo struct {
|
||||
KeyID uint `json:"key_id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type PaginatedAPIKeys struct {
|
||||
Items []*models.APIKeyDetails `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
type KeyTestResult struct {
|
||||
Key string `json:"key"`
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type APIKeyService struct {
|
||||
db *gorm.DB
|
||||
keyRepo repository.KeyRepository
|
||||
channel channel.ChannelProxy
|
||||
store store.Store
|
||||
SettingsManager *settings.SettingsManager
|
||||
taskService task.Reporter
|
||||
logger *logrus.Entry
|
||||
stopChan chan struct{}
|
||||
validationService *KeyValidationService
|
||||
groupManager *GroupManager
|
||||
}
|
||||
|
||||
func NewAPIKeyService(
|
||||
db *gorm.DB,
|
||||
repo repository.KeyRepository,
|
||||
ch channel.ChannelProxy,
|
||||
s store.Store,
|
||||
sm *settings.SettingsManager,
|
||||
ts task.Reporter,
|
||||
vs *KeyValidationService,
|
||||
gm *GroupManager,
|
||||
logger *logrus.Logger,
|
||||
) *APIKeyService {
|
||||
logger.Debugf("[FORENSIC PROBE | INJECTION | APIKeyService] Received a KeyRepository. Fingerprint: %p", repo)
|
||||
return &APIKeyService{
|
||||
db: db,
|
||||
keyRepo: repo,
|
||||
channel: ch,
|
||||
store: s,
|
||||
SettingsManager: sm,
|
||||
taskService: ts,
|
||||
logger: logger.WithField("component", "APIKeyService🔑"),
|
||||
stopChan: make(chan struct{}),
|
||||
validationService: vs,
|
||||
groupManager: gm,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) Start() {
|
||||
requestSub, err := s.store.Subscribe(models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
return
|
||||
}
|
||||
masterKeySub, err := s.store.Subscribe(models.TopicMasterKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicMasterKeyStatusChanged, err)
|
||||
return
|
||||
}
|
||||
keyStatusSub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
|
||||
return
|
||||
}
|
||||
importSub, err := s.store.Subscribe(models.TopicImportGroupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicImportGroupCompleted, err)
|
||||
return
|
||||
}
|
||||
s.logger.Info("Started and subscribed to request, master key, health check, and import events.")
|
||||
|
||||
go func() {
|
||||
defer requestSub.Close()
|
||||
defer masterKeySub.Close()
|
||||
defer keyStatusSub.Close()
|
||||
defer importSub.Close()
|
||||
for {
|
||||
select {
|
||||
case msg := <-requestSub.Channel():
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal event for key status update: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleKeyUsageEvent(&event)
|
||||
|
||||
case msg := <-masterKeySub.Channel():
|
||||
var event models.MasterKeyStatusChangedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal MasterKeyStatusChangedEvent: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleMasterKeyStatusChangeEvent(&event)
|
||||
case msg := <-keyStatusSub.Channel():
|
||||
var event models.KeyStatusChangedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleKeyStatusChangeEvent(&event)
|
||||
|
||||
case msg := <-importSub.Channel():
|
||||
var event models.ImportGroupCompletedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to unmarshal ImportGroupCompletedEvent.")
|
||||
continue
|
||||
}
|
||||
s.logger.Infof("Received ImportGroupCompletedEvent for group %d, triggering validation for %d keys.", event.GroupID, len(event.KeyIDs))
|
||||
|
||||
go s.handlePostImportValidation(&event)
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping event listener.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *APIKeyService) Stop() {
|
||||
close(s.stopChan)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handleKeyUsageEvent(event *models.RequestFinishedEvent) {
|
||||
if event.KeyID == 0 || event.GroupID == 0 {
|
||||
return
|
||||
}
|
||||
// Handle success case: key recovery and timestamp update.
|
||||
if event.IsSuccess {
|
||||
mapping, err := s.keyRepo.GetMapping(event.GroupID, event.KeyID)
|
||||
if err != nil {
|
||||
// Log if mapping is not found, but don't proceed.
|
||||
s.logger.Warnf("[%s] Could not find mapping for G:%d K:%d on successful request: %v", event.CorrelationID, event.GroupID, event.KeyID, err)
|
||||
return
|
||||
}
|
||||
|
||||
needsUpdate := false
|
||||
oldStatus := mapping.Status
|
||||
|
||||
// If status was not active, it's a recovery.
|
||||
if mapping.Status != models.StatusActive {
|
||||
mapping.Status = models.StatusActive
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
needsUpdate = true
|
||||
}
|
||||
// Always update LastUsedAt timestamp.
|
||||
now := time.Now()
|
||||
mapping.LastUsedAt = &now
|
||||
needsUpdate = true
|
||||
|
||||
if needsUpdate {
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
s.logger.Errorf("[%s] Failed to update mapping for G:%d K:%d after successful request: %v", event.CorrelationID, event.GroupID, event.KeyID, err)
|
||||
} else if oldStatus != models.StatusActive {
|
||||
// Only publish event if status actually changed.
|
||||
go s.publishStatusChangeEvent(event.GroupID, event.KeyID, oldStatus, models.StatusActive, "key_recovered_after_use")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// Handle failure case: delegate to the centralized judgment function.
|
||||
if event.Error != nil {
|
||||
s.judgeKeyErrors(
|
||||
event.CorrelationID,
|
||||
event.GroupID,
|
||||
event.KeyID,
|
||||
event.Error,
|
||||
event.IsPreciseRouting,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handleKeyStatusChangeEvent(event *models.KeyStatusChangedEvent) {
|
||||
log := s.logger.WithFields(logrus.Fields{
|
||||
"group_id": event.GroupID,
|
||||
"key_id": event.KeyID,
|
||||
"new_status": event.NewStatus,
|
||||
"reason": event.ChangeReason,
|
||||
})
|
||||
log.Info("Received KeyStatusChangedEvent, will update polling caches.")
|
||||
s.keyRepo.HandleCacheUpdateEvent(event.GroupID, event.KeyID, event.NewStatus)
|
||||
log.Info("Polling caches updated based on health check event.")
|
||||
}
|
||||
|
||||
func (s *APIKeyService) publishStatusChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
changeEvent := models.KeyStatusChangedEvent{
|
||||
KeyID: keyID,
|
||||
GroupID: groupID,
|
||||
OldStatus: oldStatus,
|
||||
NewStatus: newStatus,
|
||||
ChangeReason: reason,
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
eventData, _ := json.Marshal(changeEvent)
|
||||
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.Errorf("Failed to publish key status changed event for group %d: %v", groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) ListAPIKeys(params *models.APIKeyQueryParams) (*PaginatedAPIKeys, error) {
|
||||
// --- Path 1: High-performance DB pagination (no keyword) ---
|
||||
if params.Keyword == "" {
|
||||
items, total, err := s.keyRepo.GetPaginatedKeysAndMappingsByGroup(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
totalPages := 0
|
||||
if total > 0 && params.PageSize > 0 {
|
||||
totalPages = int(math.Ceil(float64(total) / float64(params.PageSize)))
|
||||
}
|
||||
return &PaginatedAPIKeys{
|
||||
Items: items,
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
TotalPages: totalPages,
|
||||
}, nil
|
||||
}
|
||||
// --- Path 2: In-memory search (keyword present) ---
|
||||
s.logger.Infof("Performing heavy in-memory search for group %d with keyword '%s'", params.KeyGroupID, params.Keyword)
|
||||
// To get all keys, we fetch all IDs first, then get their full details.
|
||||
var statusesToFilter []string
|
||||
if params.Status != "" {
|
||||
statusesToFilter = append(statusesToFilter, params.Status)
|
||||
} else {
|
||||
statusesToFilter = append(statusesToFilter, "all") // "all" gets every status
|
||||
}
|
||||
allKeyIDs, err := s.keyRepo.FindKeyIDsByStatus(params.KeyGroupID, statusesToFilter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch all key IDs for search: %w", err)
|
||||
}
|
||||
if len(allKeyIDs) == 0 {
|
||||
return &PaginatedAPIKeys{Items: []*models.APIKeyDetails{}, Total: 0, Page: 1, PageSize: params.PageSize, TotalPages: 0}, nil
|
||||
}
|
||||
|
||||
// This is the heavy operation: getting all keys and decrypting them.
|
||||
allKeys, err := s.keyRepo.GetKeysByIDs(allKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch all key details for in-memory search: %w", err)
|
||||
}
|
||||
// We also need mappings to build the final `APIKeyDetails`.
|
||||
var allMappings []models.GroupAPIKeyMapping
|
||||
err = s.db.Where("key_group_id = ? AND api_key_id IN ?", params.KeyGroupID, allKeyIDs).Find(&allMappings).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch mappings for in-memory search: %w", err)
|
||||
}
|
||||
mappingMap := make(map[uint]*models.GroupAPIKeyMapping)
|
||||
for i := range allMappings {
|
||||
mappingMap[allMappings[i].APIKeyID] = &allMappings[i]
|
||||
}
|
||||
// Filter the results in memory.
|
||||
var filteredItems []*models.APIKeyDetails
|
||||
for _, key := range allKeys {
|
||||
if strings.Contains(key.APIKey, params.Keyword) {
|
||||
if mapping, ok := mappingMap[key.ID]; ok {
|
||||
filteredItems = append(filteredItems, &models.APIKeyDetails{
|
||||
ID: key.ID,
|
||||
CreatedAt: key.CreatedAt,
|
||||
UpdatedAt: key.UpdatedAt,
|
||||
APIKey: key.APIKey,
|
||||
MasterStatus: key.MasterStatus,
|
||||
Status: mapping.Status,
|
||||
LastError: mapping.LastError,
|
||||
ConsecutiveErrorCount: mapping.ConsecutiveErrorCount,
|
||||
LastUsedAt: mapping.LastUsedAt,
|
||||
CooldownUntil: mapping.CooldownUntil,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sort the filtered results to ensure consistent pagination (by ID descending).
|
||||
sort.Slice(filteredItems, func(i, j int) bool {
|
||||
return filteredItems[i].ID > filteredItems[j].ID
|
||||
})
|
||||
// Manually paginate the filtered results.
|
||||
total := int64(len(filteredItems))
|
||||
start := (params.Page - 1) * params.PageSize
|
||||
end := start + params.PageSize
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
if start >= len(filteredItems) {
|
||||
return &PaginatedAPIKeys{
|
||||
Items: []*models.APIKeyDetails{},
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
TotalPages: int(math.Ceil(float64(total) / float64(params.PageSize))),
|
||||
}, nil
|
||||
}
|
||||
if end > len(filteredItems) {
|
||||
end = len(filteredItems)
|
||||
}
|
||||
paginatedItems := filteredItems[start:end]
|
||||
return &PaginatedAPIKeys{
|
||||
Items: paginatedItems,
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
TotalPages: int(math.Ceil(float64(total) / float64(params.PageSize))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) UpdateAPIKey(key *models.APIKey) error {
|
||||
go func() {
|
||||
var oldKey models.APIKey
|
||||
if err := s.db.First(&oldKey, key.ID).Error; err != nil {
|
||||
s.logger.Errorf("Failed to find old key state for ID %d before update: %v", key.ID, err)
|
||||
return
|
||||
}
|
||||
if err := s.keyRepo.Update(key); err != nil {
|
||||
s.logger.Errorf("Failed to asynchronously update key ID %d: %v", key.ID, err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) HardDeleteAPIKeyByID(id uint) error {
|
||||
// Get all associated groups before deletion to publish correct events
|
||||
groups, err := s.keyRepo.GetGroupsForKey(id)
|
||||
if err != nil {
|
||||
s.logger.Warnf("Could not get groups for key ID %d before hard deletion (key might be orphaned): %v", id, err)
|
||||
}
|
||||
|
||||
err = s.keyRepo.HardDeleteByID(id)
|
||||
if err == nil {
|
||||
// Publish events for each group the key was a part of
|
||||
for _, groupID := range groups {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
KeyID: id,
|
||||
GroupID: groupID,
|
||||
ChangeReason: "key_hard_deleted",
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
go s.store.Publish(models.TopicKeyStatusChanged, eventData)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *APIKeyService) UpdateMappingStatus(groupID, keyID uint, newStatus models.APIKeyStatus) (*models.GroupAPIKeyMapping, error) {
|
||||
key, err := s.keyRepo.GetKeyByID(keyID)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
|
||||
if key.MasterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
|
||||
return nil, CustomErrors.ErrStateConflictMasterRevoked
|
||||
}
|
||||
mapping, err := s.keyRepo.GetMapping(groupID, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oldStatus := mapping.Status
|
||||
if oldStatus == newStatus {
|
||||
return mapping, nil
|
||||
}
|
||||
mapping.Status = newStatus
|
||||
if newStatus == models.StatusActive {
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
}
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "manual_update")
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handleMasterKeyStatusChangeEvent(event *models.MasterKeyStatusChangedEvent) {
|
||||
s.logger.Infof("Received MasterKeyStatusChangedEvent for Key ID %d: %s -> %s", event.KeyID, event.OldMasterStatus, event.NewMasterStatus)
|
||||
if event.NewMasterStatus != models.MasterStatusRevoked {
|
||||
return
|
||||
}
|
||||
affectedGroupIDs, err := s.keyRepo.GetGroupsForKey(event.KeyID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to get groups for key ID %d to propagate revoked status.", event.KeyID)
|
||||
return
|
||||
}
|
||||
if len(affectedGroupIDs) == 0 {
|
||||
s.logger.Infof("Key ID %d is revoked, but it's not associated with any group. No action needed.", event.KeyID)
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Propagating REVOKED astatus for Key ID %d to %d groups.", event.KeyID, len(affectedGroupIDs))
|
||||
for _, groupID := range affectedGroupIDs {
|
||||
_, err := s.UpdateMappingStatus(groupID, event.KeyID, models.StatusBanned)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.logger.WithError(err).Errorf("Failed to update mapping status to BANNED for Key ID %d in Group ID %d.", event.KeyID, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) StartRestoreKeysTask(groupID uint, keyIDs []uint) (*task.Status, error) {
|
||||
if len(keyIDs) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No key IDs provided for restoration.")
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeRestoreSpecificKeys, resourceID, len(keyIDs), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runRestoreKeysTask(taskStatus.ID, resourceID, groupID, keyIDs)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) runRestoreKeysTask(taskID string, resourceID string, groupID uint, keyIDs []uint) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.logger.Errorf("Panic recovered in runRestoreKeysTask: %v", r)
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in restore keys task: %v", r))
|
||||
}
|
||||
}()
|
||||
var mappingsToProcess []models.GroupAPIKeyMapping
|
||||
err := s.db.Preload("APIKey").
|
||||
Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).
|
||||
Find(&mappingsToProcess).Error
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
result := &BatchRestoreResult{
|
||||
SkippedKeys: make([]SkippedKeyInfo, 0),
|
||||
}
|
||||
var successfulMappings []*models.GroupAPIKeyMapping
|
||||
processedCount := 0
|
||||
for _, mapping := range mappingsToProcess {
|
||||
processedCount++
|
||||
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
|
||||
if mapping.APIKey == nil {
|
||||
result.SkippedCount++
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "Associated APIKey entity not found."})
|
||||
continue
|
||||
}
|
||||
if mapping.APIKey.MasterStatus != models.MasterStatusActive {
|
||||
result.SkippedCount++
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: fmt.Sprintf("Master status is '%s'.", mapping.APIKey.MasterStatus)})
|
||||
continue
|
||||
}
|
||||
oldStatus := mapping.Status
|
||||
if oldStatus != models.StatusActive {
|
||||
mapping.Status = models.StatusActive
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
// Use the version that doesn't trigger individual cache updates.
|
||||
if err := s.keyRepo.UpdateMappingWithoutCache(&mapping); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to update mapping in DB during restore task.")
|
||||
result.SkippedCount++
|
||||
result.SkippedKeys = append(result.SkippedKeys, SkippedKeyInfo{KeyID: mapping.APIKeyID, Reason: "DB update failed."})
|
||||
} else {
|
||||
result.RestoredCount++
|
||||
successfulMappings = append(successfulMappings, &mapping) // Collect for batch cache update.
|
||||
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, models.StatusActive, "batch_restore")
|
||||
}
|
||||
} else {
|
||||
result.RestoredCount++ // Already active, count as success.
|
||||
}
|
||||
}
|
||||
// After the loop, perform one single, efficient cache update.
|
||||
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to perform batch cache update after restore task.")
|
||||
// This is not a task-fatal error, so we just log it and continue.
|
||||
}
|
||||
// Account for keys that were requested but not found in the initial DB query.
|
||||
result.SkippedCount += (len(keyIDs) - len(mappingsToProcess))
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) StartRestoreAllBannedTask(groupID uint) (*task.Status, error) {
|
||||
var bannedKeyIDs []uint
|
||||
err := s.db.Model(&models.GroupAPIKeyMapping{}).
|
||||
Where("key_group_id = ? AND status = ?", groupID, models.StatusBanned).
|
||||
Pluck("api_key_id", &bannedKeyIDs).Error
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
if len(bannedKeyIDs) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "No banned keys to restore in this group.")
|
||||
}
|
||||
return s.StartRestoreKeysTask(groupID, bannedKeyIDs)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) handlePostImportValidation(event *models.ImportGroupCompletedEvent) {
|
||||
group, ok := s.groupManager.GetGroupByID(event.GroupID)
|
||||
if !ok {
|
||||
s.logger.Errorf("Group with id %d not found during post-import validation, aborting.", event.GroupID)
|
||||
return
|
||||
}
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to build operational config for group %d, aborting validation: %v", event.GroupID, err)
|
||||
return
|
||||
}
|
||||
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(event.GroupID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to build key check endpoint for group %d, aborting validation: %v", event.GroupID, err)
|
||||
return
|
||||
}
|
||||
globalSettings := s.SettingsManager.GetSettings()
|
||||
concurrency := globalSettings.BaseKeyCheckConcurrency
|
||||
if opConfig.KeyCheckConcurrency != nil {
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
}
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10 // Safety fallback
|
||||
}
|
||||
timeout := time.Duration(globalSettings.KeyCheckTimeoutSeconds) * time.Second
|
||||
keysToValidate, err := s.keyRepo.GetKeysByIDs(event.KeyIDs)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to get key models for validation in group %d: %v", event.GroupID, err)
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Validating %d keys for group %d with concurrency %d against endpoint %s", len(keysToValidate), event.GroupID, concurrency, endpoint)
|
||||
var wg sync.WaitGroup
|
||||
jobs := make(chan models.APIKey, len(keysToValidate))
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for key := range jobs {
|
||||
validationErr := s.validationService.ValidateSingleKey(&key, timeout, endpoint)
|
||||
if validationErr == nil {
|
||||
s.logger.Infof("Key ID %d PASSED validation. Status -> ACTIVE.", key.ID)
|
||||
if _, err := s.UpdateMappingStatus(event.GroupID, key.ID, models.StatusActive); err != nil {
|
||||
s.logger.Errorf("Failed to update status to ACTIVE for Key ID %d in group %d: %v", key.ID, event.GroupID, err)
|
||||
}
|
||||
} else {
|
||||
var apiErr *CustomErrors.APIError
|
||||
if !CustomErrors.As(validationErr, &apiErr) {
|
||||
apiErr = &CustomErrors.APIError{Message: validationErr.Error()}
|
||||
}
|
||||
s.judgeKeyErrors("", event.GroupID, key.ID, apiErr, false)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
for _, key := range keysToValidate {
|
||||
jobs <- key
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
s.logger.Infof("Finished post-import validation for group %d.", event.GroupID)
|
||||
}
|
||||
|
||||
// [NEW] StartUpdateStatusByFilterTask starts a background job to update the status of keys
|
||||
// that match a specific set of source statuses within a group.
|
||||
func (s *APIKeyService) StartUpdateStatusByFilterTask(groupID uint, sourceStatuses []string, newStatus models.APIKeyStatus) (*task.Status, error) {
|
||||
s.logger.Infof("Starting task to update status to '%s' for keys in group %d matching statuses: %v", newStatus, groupID, sourceStatuses)
|
||||
|
||||
// 1. Find key IDs using the new repository method. Using IDs is more efficient for updates.
|
||||
keyIDs, err := s.keyRepo.FindKeyIDsByStatus(groupID, sourceStatuses)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
if len(keyIDs) == 0 {
|
||||
now := time.Now()
|
||||
return &task.Status{
|
||||
IsRunning: false, // The "task" is not running.
|
||||
Processed: 0,
|
||||
Total: 0,
|
||||
Result: map[string]string{ // We use the flexible Result field to pass the message.
|
||||
"message": "没有找到任何符合当前过滤条件的Key可供操作。",
|
||||
},
|
||||
Error: "", // There is no error.
|
||||
StartedAt: now,
|
||||
FinishedAt: &now, // It started and finished at the same time.
|
||||
}, nil // Return nil for the error, signaling a 200 OK.
|
||||
}
|
||||
// 2. Start a new task using the TaskService, following existing patterns.
|
||||
resourceID := fmt.Sprintf("group-%d-status-update", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeUpdateStatusByFilter, resourceID, len(keyIDs), 30*time.Minute)
|
||||
if err != nil {
|
||||
return nil, err // Pass up errors like "task already in progress".
|
||||
}
|
||||
|
||||
// 3. Run the core logic in a separate goroutine.
|
||||
go s.runUpdateStatusByFilterTask(taskStatus.ID, resourceID, groupID, keyIDs, newStatus)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// [NEW] runUpdateStatusByFilterTask is the private worker function for the above task.
|
||||
func (s *APIKeyService) runUpdateStatusByFilterTask(taskID, resourceID string, groupID uint, keyIDs []uint, newStatus models.APIKeyStatus) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.logger.Errorf("Panic recovered in runUpdateStatusByFilterTask: %v", r)
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("panic in update status task: %v", r))
|
||||
}
|
||||
}()
|
||||
type BatchUpdateResult struct {
|
||||
UpdatedCount int `json:"updated_count"`
|
||||
SkippedCount int `json:"skipped_count"`
|
||||
}
|
||||
result := &BatchUpdateResult{}
|
||||
var successfulMappings []*models.GroupAPIKeyMapping
|
||||
// 1. Fetch all key master statuses in one go. This is efficient.
|
||||
keys, err := s.keyRepo.GetKeysByIDs(keyIDs)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to fetch keys for bulk status update, aborting task.")
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
masterStatusMap := make(map[uint]models.MasterAPIKeyStatus)
|
||||
for _, key := range keys {
|
||||
masterStatusMap[key.ID] = key.MasterStatus
|
||||
}
|
||||
// 2. [THE REFINEMENT] Fetch all relevant mappings directly using s.db,
|
||||
// avoiding the need for a new repository method. This pattern is
|
||||
// already used in other parts of this service.
|
||||
var mappings []*models.GroupAPIKeyMapping
|
||||
if err := s.db.Where("key_group_id = ? AND api_key_id IN ?", groupID, keyIDs).Find(&mappings).Error; err != nil {
|
||||
s.logger.WithError(err).Error("Failed to fetch mappings for bulk status update, aborting task.")
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
processedCount := 0
|
||||
for _, mapping := range mappings {
|
||||
processedCount++
|
||||
// The progress update should reflect the number of items *being processed*, not the final count.
|
||||
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
|
||||
masterStatus, ok := masterStatusMap[mapping.APIKeyID]
|
||||
if !ok {
|
||||
result.SkippedCount++
|
||||
continue
|
||||
}
|
||||
if masterStatus != models.MasterStatusActive && newStatus == models.StatusActive {
|
||||
result.SkippedCount++
|
||||
continue
|
||||
}
|
||||
oldStatus := mapping.Status
|
||||
if oldStatus != newStatus {
|
||||
mapping.Status = newStatus
|
||||
if newStatus == models.StatusActive {
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
}
|
||||
if err := s.keyRepo.UpdateMappingWithoutCache(mapping); err != nil {
|
||||
result.SkippedCount++
|
||||
} else {
|
||||
result.UpdatedCount++
|
||||
successfulMappings = append(successfulMappings, mapping)
|
||||
go s.publishStatusChangeEvent(groupID, mapping.APIKeyID, oldStatus, newStatus, "batch_status_update")
|
||||
}
|
||||
} else {
|
||||
result.UpdatedCount++ // Already in desired state, count as success.
|
||||
}
|
||||
}
|
||||
result.SkippedCount += (len(keyIDs) - len(mappings))
|
||||
if err := s.keyRepo.HandleCacheUpdateEventBatch(successfulMappings); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to perform batch cache update after status update task.")
|
||||
}
|
||||
s.logger.Infof("Finished bulk status update task '%s' for group %d. Updated: %d, Skipped: %d.", taskID, groupID, result.UpdatedCount, result.SkippedCount)
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) HandleRequestResult(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *CustomErrors.APIError) {
|
||||
if success {
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go s.keyRepo.UpdateKeyUsageTimestamp(group.ID, key.ID)
|
||||
}
|
||||
return
|
||||
}
|
||||
if apiErr == nil {
|
||||
s.logger.Warnf("Request failed for KeyID %d in GroupID %d but no specific API error was provided. No action taken.", key.ID, group.ID)
|
||||
return
|
||||
}
|
||||
errMsg := apiErr.Message
|
||||
if CustomErrors.IsPermanentUpstreamError(errMsg) || CustomErrors.IsTemporaryUpstreamError(errMsg) {
|
||||
s.logger.Warnf("Request for KeyID %d in GroupID %d failed with a key-specific error: %s. Issuing order to move to cooldown.", key.ID, group.ID, errMsg)
|
||||
go s.keyRepo.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
|
||||
} else {
|
||||
s.logger.Infof("Request for KeyID %d in GroupID %d failed with a non-key-specific error. No punitive action taken. Error: %s", key.ID, group.ID, errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeForLog truncates long error messages and removes JSON-like content for cleaner Info/Error logs.
|
||||
func sanitizeForLog(errMsg string) string {
|
||||
// Find the start of any potential JSON blob or detailed structure.
|
||||
jsonStartIndex := strings.Index(errMsg, "{")
|
||||
var cleanMsg string
|
||||
if jsonStartIndex != -1 {
|
||||
// If a '{' is found, take everything before it as the summary
|
||||
// and append a simple placeholder.
|
||||
cleanMsg = strings.TrimSpace(errMsg[:jsonStartIndex]) + " {...}"
|
||||
} else {
|
||||
// If no JSON-like structure is found, use the original message.
|
||||
cleanMsg = errMsg
|
||||
}
|
||||
// Always apply a final length truncation as a safeguard against extremely long non-JSON errors.
|
||||
const maxLen = 250
|
||||
if len(cleanMsg) > maxLen {
|
||||
return cleanMsg[:maxLen] + "..."
|
||||
}
|
||||
return cleanMsg
|
||||
}
|
||||
|
||||
func (s *APIKeyService) judgeKeyErrors(
|
||||
correlationID string,
|
||||
groupID, keyID uint,
|
||||
apiErr *CustomErrors.APIError,
|
||||
isPreciseRouting bool,
|
||||
) {
|
||||
logger := s.logger.WithFields(logrus.Fields{"group_id": groupID, "key_id": keyID, "correlation_id": correlationID})
|
||||
mapping, err := s.keyRepo.GetMapping(groupID, keyID)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Cannot apply consequences for error: mapping not found.")
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
mapping.LastUsedAt = &now
|
||||
errorMessage := apiErr.Message
|
||||
if CustomErrors.IsPermanentUpstreamError(errorMessage) {
|
||||
logger.Errorf("Permanent error detected. Banning mapping and revoking master key. Reason: %s", sanitizeForLog(errorMessage))
|
||||
logger.WithField("full_error_details", errorMessage).Debug("Full details of the permanent error.")
|
||||
if mapping.Status != models.StatusBanned {
|
||||
oldStatus := mapping.Status
|
||||
mapping.Status = models.StatusBanned
|
||||
mapping.LastError = errorMessage
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
logger.WithError(err).Error("Failed to update mapping status to BANNED.")
|
||||
} else {
|
||||
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, mapping.Status, "permanent_error_banned")
|
||||
go s.revokeMasterKey(keyID, "permanent_upstream_error")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if CustomErrors.IsTemporaryUpstreamError(errorMessage) {
|
||||
mapping.LastError = errorMessage
|
||||
mapping.ConsecutiveErrorCount++
|
||||
var threshold int
|
||||
if isPreciseRouting {
|
||||
group, ok := s.groupManager.GetGroupByID(groupID)
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if !ok || err != nil {
|
||||
logger.Warnf("Could not build operational config for group %d in Precise Routing mode. Falling back to global settings.", groupID)
|
||||
threshold = s.SettingsManager.GetSettings().BlacklistThreshold
|
||||
} else {
|
||||
threshold = *opConfig.KeyBlacklistThreshold
|
||||
}
|
||||
} else {
|
||||
threshold = s.SettingsManager.GetSettings().BlacklistThreshold
|
||||
}
|
||||
logger.Warnf("Temporary error detected. Incrementing error count. New count: %d (Threshold: %d). Reason: %s", mapping.ConsecutiveErrorCount, threshold, sanitizeForLog(errorMessage))
|
||||
logger.WithField("full_error_details", errorMessage).Debug("Full details of the temporary error.")
|
||||
oldStatus := mapping.Status
|
||||
newStatus := oldStatus
|
||||
if mapping.ConsecutiveErrorCount >= threshold && oldStatus == models.StatusActive {
|
||||
newStatus = models.StatusCooldown
|
||||
logger.Errorf("Putting mapping into COOLDOWN due to reaching temporary error threshold (%d)", threshold)
|
||||
}
|
||||
if oldStatus != newStatus {
|
||||
mapping.Status = newStatus
|
||||
}
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
logger.WithError(err).Error("Failed to update mapping after temporary error.")
|
||||
return
|
||||
}
|
||||
if oldStatus != newStatus {
|
||||
go s.publishStatusChangeEvent(groupID, keyID, oldStatus, newStatus, "error_threshold_reached")
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.Infof("Ignored truly ignorable upstream error. Only updating LastUsedAt. Reason: %s", sanitizeForLog(errorMessage))
|
||||
logger.WithField("full_error_details", errorMessage).Debug("Full details of the ignorable error.")
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
logger.WithError(err).Error("Failed to update LastUsedAt for ignorable error.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) revokeMasterKey(keyID uint, reason string) {
|
||||
key, err := s.keyRepo.GetKeyByID(keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
s.logger.Warnf("Attempted to revoke non-existent key ID %d.", keyID)
|
||||
} else {
|
||||
s.logger.Errorf("Failed to get key by ID %d for master status revocation: %v", keyID, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if key.MasterStatus == models.MasterStatusRevoked {
|
||||
return
|
||||
}
|
||||
oldMasterStatus := key.MasterStatus
|
||||
newMasterStatus := models.MasterStatusRevoked
|
||||
if err := s.keyRepo.UpdateMasterStatusByID(keyID, newMasterStatus); err != nil {
|
||||
s.logger.Errorf("Failed to update master status to REVOKED for key ID %d: %v", keyID, err)
|
||||
return
|
||||
}
|
||||
masterKeyEvent := models.MasterKeyStatusChangedEvent{
|
||||
KeyID: keyID,
|
||||
OldMasterStatus: oldMasterStatus,
|
||||
NewMasterStatus: newMasterStatus,
|
||||
ChangeReason: reason,
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
eventData, _ := json.Marshal(masterKeyEvent)
|
||||
_ = s.store.Publish(models.TopicMasterKeyStatusChanged, eventData)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) GetAPIKeyStringsForExport(groupID uint, statuses []string) ([]string, error) {
|
||||
return s.keyRepo.GetKeyStringsByGroupAndStatus(groupID, statuses)
|
||||
}
|
||||
315
internal/service/dashboard_query_service.go
Normal file
315
internal/service/dashboard_query_service.go
Normal file
@@ -0,0 +1,315 @@
|
||||
// Filename: internal/service/dashboard_query_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const overviewCacheChannel = "syncer:cache:dashboard_overview"
|
||||
|
||||
// DashboardQueryService 负责所有面向前端的仪表盘数据查询。
|
||||
|
||||
type DashboardQueryService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
|
||||
logger *logrus.Entry
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) {
|
||||
qs := &DashboardQueryService{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "DashboardQueryService"),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
loader := qs.loadOverviewData
|
||||
overviewSyncer, err := syncer.NewCacheSyncer(loader, s, overviewCacheChannel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
|
||||
}
|
||||
qs.overviewSyncer = overviewSyncer
|
||||
return qs, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) Start() {
|
||||
go s.eventListener()
|
||||
s.logger.Info("DashboardQueryService started and listening for invalidation events.")
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) Stop() {
|
||||
close(s.stopChan)
|
||||
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) {
|
||||
statsKey := fmt.Sprintf("stats:group:%d", groupID)
|
||||
keyStatsMap, err := s.store.HGetAll(statsKey)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
|
||||
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
|
||||
}
|
||||
keyStats := make(map[string]int64)
|
||||
for k, v := range keyStatsMap {
|
||||
val, _ := strconv.ParseInt(v, 10, 64)
|
||||
keyStats[k] = val
|
||||
}
|
||||
now := time.Now()
|
||||
oneHourAgo := now.Add(-1 * time.Hour)
|
||||
twentyFourHoursAgo := now.Add(-24 * time.Hour)
|
||||
type requestStatsResult struct {
|
||||
TotalRequests int64
|
||||
SuccessRequests int64
|
||||
}
|
||||
var last1Hour, last24Hours requestStatsResult
|
||||
s.db.Model(&models.StatsHourly{}).
|
||||
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
|
||||
Scan(&last1Hour)
|
||||
s.db.Model(&models.StatsHourly{}).
|
||||
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
|
||||
Scan(&last24Hours)
|
||||
failureRate1h := 0.0
|
||||
if last1Hour.TotalRequests > 0 {
|
||||
failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100
|
||||
}
|
||||
failureRate24h := 0.0
|
||||
if last24Hours.TotalRequests > 0 {
|
||||
failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100
|
||||
}
|
||||
last1HourStats := map[string]any{
|
||||
"total_requests": last1Hour.TotalRequests,
|
||||
"success_requests": last1Hour.SuccessRequests,
|
||||
"failure_rate": failureRate1h,
|
||||
}
|
||||
last24HoursStats := map[string]any{
|
||||
"total_requests": last24Hours.TotalRequests,
|
||||
"success_requests": last24Hours.SuccessRequests,
|
||||
"failure_rate": failureRate24h,
|
||||
}
|
||||
result := map[string]any{
|
||||
"key_stats": keyStats,
|
||||
"last_1_hour": last1HourStats,
|
||||
"last_24_hours": last24HoursStats,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) eventListener() {
|
||||
keyStatusSub, _ := s.store.Subscribe(models.TopicKeyStatusChanged)
|
||||
upstreamStatusSub, _ := s.store.Subscribe(models.TopicUpstreamHealthChanged)
|
||||
defer keyStatusSub.Close()
|
||||
defer upstreamStatusSub.Close()
|
||||
for {
|
||||
select {
|
||||
case <-keyStatusSub.Channel():
|
||||
s.logger.Info("Received key status changed event, invalidating overview cache...")
|
||||
_ = s.InvalidateOverviewCache()
|
||||
case <-upstreamStatusSub.Channel():
|
||||
s.logger.Info("Received upstream status changed event, invalidating overview cache...")
|
||||
_ = s.InvalidateOverviewCache()
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping dashboard event listener.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetDashboardOverviewData 从 Syncer 缓存中高速获取仪表盘概览数据。
|
||||
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
cachedDataPtr := s.overviewSyncer.Get()
|
||||
if cachedDataPtr == nil {
|
||||
return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing")
|
||||
}
|
||||
return cachedDataPtr, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) InvalidateOverviewCache() error {
|
||||
return s.overviewSyncer.Invalidate()
|
||||
}
|
||||
|
||||
// QueryHistoricalChart 查询历史图表数据。
|
||||
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
|
||||
type ChartPoint struct {
|
||||
TimeLabel string `gorm:"column:time_label"`
|
||||
ModelName string `gorm:"column:model_name"`
|
||||
TotalRequests int64 `gorm:"column:total_requests"`
|
||||
}
|
||||
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
|
||||
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
|
||||
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
|
||||
query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
|
||||
if groupID != nil && *groupID > 0 {
|
||||
query = query.Where("group_id = ?", *groupID)
|
||||
}
|
||||
var points []ChartPoint
|
||||
if err := query.Find(&points).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
datasets := make(map[string]map[string]int64)
|
||||
for _, p := range points {
|
||||
if _, ok := datasets[p.ModelName]; !ok {
|
||||
datasets[p.ModelName] = make(map[string]int64)
|
||||
}
|
||||
datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
|
||||
}
|
||||
var labels []string
|
||||
for t := sevenDaysAgo; t.Before(time.Now()); t = t.Add(time.Hour) {
|
||||
labels = append(labels, t.Format(goFormat))
|
||||
}
|
||||
chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)}
|
||||
colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"}
|
||||
colorIndex := 0
|
||||
for modelName, dataPoints := range datasets {
|
||||
dataArray := make([]int64, len(labels))
|
||||
for i, label := range labels {
|
||||
dataArray[i] = dataPoints[label]
|
||||
}
|
||||
chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
|
||||
Label: modelName,
|
||||
Data: dataArray,
|
||||
Color: colorPalette[colorIndex%len(colorPalette)],
|
||||
})
|
||||
colorIndex++
|
||||
}
|
||||
return chartData, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
|
||||
startTime := time.Now()
|
||||
resp := &models.DashboardStatsResponse{
|
||||
KeyStatusCount: make(map[models.APIKeyStatus]int64),
|
||||
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
|
||||
KeyCount: models.StatCard{}, // 确保KeyCount是一个空的结构体,而不是nil
|
||||
RequestCount24h: models.StatCard{}, // 同上
|
||||
TokenCount: make(map[string]any),
|
||||
UpstreamHealthStatus: make(map[string]string),
|
||||
RPM: models.StatCard{},
|
||||
RequestCounts: make(map[string]int64),
|
||||
}
|
||||
// --- 1. Aggregate Operational Status from Mappings ---
|
||||
type MappingStatusResult struct {
|
||||
Status models.APIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
var mappingStatusResults []MappingStatusResult
|
||||
if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
|
||||
}
|
||||
for _, res := range mappingStatusResults {
|
||||
resp.KeyStatusCount[res.Status] = res.Count
|
||||
}
|
||||
|
||||
// --- 2. Aggregate Master Status from APIKeys ---
|
||||
type MasterStatusResult struct {
|
||||
Status models.MasterAPIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
var masterStatusResults []MasterStatusResult
|
||||
if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query master status stats: %w", err)
|
||||
}
|
||||
var totalKeys, invalidKeys int64
|
||||
for _, res := range masterStatusResults {
|
||||
resp.MasterStatusCount[res.Status] = res.Count
|
||||
totalKeys += res.Count
|
||||
if res.Status != models.MasterStatusActive {
|
||||
invalidKeys += res.Count
|
||||
}
|
||||
}
|
||||
resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "非活跃身份密钥数"}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 1. RPM (1分钟), RPH (1小时), RPD (今日): 从“瞬时记忆”(request_logs)中精确查询
|
||||
var count1m, count1h, count1d int64
|
||||
// RPM: 从此刻倒推1分钟
|
||||
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
|
||||
// RPH: 从此刻倒推1小时
|
||||
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
|
||||
|
||||
// RPD: 从今天零点 (UTC) 到此刻
|
||||
year, month, day := now.UTC().Date()
|
||||
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
|
||||
// 2. RP30D (30天): 从“长期记忆”(stats_hourly)中高效查询,以保证性能
|
||||
var count30d int64
|
||||
s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
|
||||
|
||||
resp.RequestCounts["1m"] = count1m
|
||||
resp.RequestCounts["1h"] = count1h
|
||||
resp.RequestCounts["1d"] = count1d
|
||||
resp.RequestCounts["30d"] = count30d
|
||||
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if err := s.db.Find(&upstreams).Error; err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
|
||||
} else {
|
||||
for _, u := range upstreams {
|
||||
resp.UpstreamHealthStatus[u.URL] = u.Status
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) {
|
||||
var startTime time.Time
|
||||
now := time.Now()
|
||||
switch period {
|
||||
case "1m":
|
||||
startTime = now.Add(-1 * time.Minute)
|
||||
case "1h":
|
||||
startTime = now.Add(-1 * time.Hour)
|
||||
case "1d":
|
||||
year, month, day := now.UTC().Date()
|
||||
startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid period specified: %s", period)
|
||||
}
|
||||
var result struct {
|
||||
Total int64
|
||||
Success int64
|
||||
}
|
||||
|
||||
err := s.db.Model(&models.RequestLog{}).
|
||||
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
|
||||
Where("request_time >= ?", startTime).
|
||||
Scan(&result).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return gin.H{
|
||||
"total": result.Total,
|
||||
"success": result.Success,
|
||||
"failure": result.Total - result.Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
|
||||
dialect := s.db.Dialector.Name()
|
||||
switch dialect {
|
||||
case "mysql":
|
||||
return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
|
||||
case "sqlite":
|
||||
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
|
||||
default:
|
||||
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
|
||||
}
|
||||
}
|
||||
149
internal/service/db_log_writer_service.go
Normal file
149
internal/service/db_log_writer_service.go
Normal file
@@ -0,0 +1,149 @@
|
||||
// Filename: internal/service/db_log_writer_service.go (全新文件)
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DBLogWriterService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
logBuffer chan *models.RequestLog
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
SettingsManager *settings.SettingsManager
|
||||
}
|
||||
|
||||
func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.SettingsManager, logger *logrus.Logger) *DBLogWriterService {
|
||||
cfg := settings.GetSettings()
|
||||
bufferCapacity := cfg.LogBufferCapacity
|
||||
if bufferCapacity <= 0 {
|
||||
bufferCapacity = 1000
|
||||
}
|
||||
return &DBLogWriterService{
|
||||
db: db,
|
||||
store: s,
|
||||
SettingsManager: settings,
|
||||
logger: logger.WithField("component", "DBLogWriter📝"),
|
||||
// 使用配置值来创建缓冲区
|
||||
logBuffer: make(chan *models.RequestLog, bufferCapacity),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) Start() {
|
||||
s.wg.Add(2) // 一个用于事件监听,一个用于数据库写入
|
||||
|
||||
// 启动事件监听器
|
||||
go s.eventListenerLoop()
|
||||
// 启动数据库写入器
|
||||
go s.dbWriterLoop()
|
||||
|
||||
s.logger.Info("DBLogWriterService started.")
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) Stop() {
|
||||
s.logger.Info("DBLogWriterService stopping...")
|
||||
close(s.stopChan) // 通知所有goroutine停止
|
||||
s.wg.Wait() // 等待所有goroutine完成
|
||||
s.logger.Info("DBLogWriterService stopped.")
|
||||
}
|
||||
|
||||
// eventListenerLoop 负责从store接收事件并放入内存缓冲区
|
||||
func (s *DBLogWriterService) eventListenerLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
sub, err := s.store.Subscribe(models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
return
|
||||
}
|
||||
defer sub.Close()
|
||||
|
||||
s.logger.Info("Subscribed to request events for database logging.")
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 将事件中的日志部分放入缓冲区
|
||||
select {
|
||||
case s.logBuffer <- &event.RequestLog:
|
||||
default:
|
||||
s.logger.Warn("Log buffer is full. A log message might be dropped.")
|
||||
}
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Event listener loop stopping.")
|
||||
// 关闭缓冲区,以通知dbWriterLoop处理完剩余日志后退出
|
||||
close(s.logBuffer)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dbWriterLoop 负责从内存缓冲区批量读取日志并写入数据库
|
||||
func (s *DBLogWriterService) dbWriterLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// 在启动时获取一次配置
|
||||
cfg := s.SettingsManager.GetSettings()
|
||||
batchSize := cfg.LogFlushBatchSize
|
||||
if batchSize <= 0 {
|
||||
batchSize = 100
|
||||
}
|
||||
|
||||
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
|
||||
if flushTimeout <= 0 {
|
||||
flushTimeout = 5 * time.Second
|
||||
}
|
||||
batch := make([]*models.RequestLog, 0, batchSize)
|
||||
ticker := time.NewTicker(flushTimeout)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case logEntry, ok := <-s.logBuffer:
|
||||
if !ok {
|
||||
if len(batch) > 0 {
|
||||
s.flushBatch(batch)
|
||||
}
|
||||
s.logger.Info("DB writer loop finished.")
|
||||
return
|
||||
}
|
||||
batch = append(batch, logEntry)
|
||||
if len(batch) >= batchSize { // 使用配置的批次大小
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
case <-ticker.C:
|
||||
if len(batch) > 0 {
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flushBatch 将一个批次的日志写入数据库
|
||||
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
|
||||
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
|
||||
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")
|
||||
} else {
|
||||
s.logger.Infof("Successfully flushed %d logs to database.", len(batch))
|
||||
}
|
||||
}
|
||||
596
internal/service/group_manager.go
Normal file
596
internal/service/group_manager.go
Normal file
@@ -0,0 +1,596 @@
|
||||
// Filename: internal/service/group_manager.go (Syncer升级版)
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/pkg/reflectutil"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"gemini-balancer/internal/utils"
|
||||
"net/url"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const GroupUpdateChannel = "groups:cache_invalidation"
|
||||
|
||||
type GroupManagerCacheData struct {
|
||||
Groups []*models.KeyGroup
|
||||
GroupsByName map[string]*models.KeyGroup
|
||||
GroupsByID map[uint]*models.KeyGroup
|
||||
KeyCounts map[uint]int64 // GroupID -> Total Key Count
|
||||
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 // GroupID -> Status -> Count
|
||||
}
|
||||
|
||||
type GroupManager struct {
|
||||
db *gorm.DB
|
||||
keyRepo repository.KeyRepository
|
||||
groupRepo repository.GroupRepository
|
||||
settingsManager *settings.SettingsManager
|
||||
syncer *syncer.CacheSyncer[GroupManagerCacheData]
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
type UpdateOrderPayload struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
Order int `json:"order"`
|
||||
}
|
||||
|
||||
func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
|
||||
return func() (GroupManagerCacheData, error) {
|
||||
logger.Debugf("[GML-LOG 1/5] ---> Entering NewGroupManagerLoader...")
|
||||
var groups []*models.KeyGroup
|
||||
logger.Debugf("[GML-LOG 2/5] About to execute DB query with Preloads...")
|
||||
|
||||
if err := db.Preload("AllowedUpstreams").
|
||||
Preload("AllowedModels").
|
||||
Preload("Settings").
|
||||
Preload("RequestConfig").
|
||||
Find(&groups).Error; err != nil {
|
||||
logger.Errorf("[GML-LOG] CRITICAL: DB query for groups failed: %v", err)
|
||||
return GroupManagerCacheData{}, fmt.Errorf("failed to load key groups for cache: %w", err)
|
||||
}
|
||||
logger.Debugf("[GML-LOG 2.1/5] DB query for groups finished. Found %d group records.", len(groups))
|
||||
|
||||
var allMappings []*models.GroupAPIKeyMapping
|
||||
if err := db.Find(&allMappings).Error; err != nil {
|
||||
logger.Errorf("[GML-LOG] CRITICAL: DB query for mappings failed: %v", err)
|
||||
return GroupManagerCacheData{}, fmt.Errorf("failed to load key mappings for cache: %w", err)
|
||||
}
|
||||
logger.Debugf("[GML-LOG 2.2/5] DB query for mappings finished. Found %d total mapping records.", len(allMappings))
|
||||
|
||||
mappingsByGroupID := make(map[uint][]*models.GroupAPIKeyMapping)
|
||||
for i := range allMappings {
|
||||
mapping := allMappings[i] // Avoid pointer issues with range
|
||||
mappingsByGroupID[mapping.KeyGroupID] = append(mappingsByGroupID[mapping.KeyGroupID], mapping)
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if mappings, ok := mappingsByGroupID[group.ID]; ok {
|
||||
group.Mappings = mappings
|
||||
}
|
||||
}
|
||||
logger.Debugf("[GML-LOG 3/5] Finished manually associating mappings to groups.")
|
||||
|
||||
keyCounts := make(map[uint]int64, len(groups))
|
||||
keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups))
|
||||
|
||||
for _, group := range groups {
|
||||
keyCounts[group.ID] = int64(len(group.Mappings))
|
||||
statusCounts := make(map[models.APIKeyStatus]int64)
|
||||
for _, mapping := range group.Mappings {
|
||||
statusCounts[mapping.Status]++
|
||||
}
|
||||
keyStatusCounts[group.ID] = statusCounts
|
||||
}
|
||||
groupsByName := make(map[string]*models.KeyGroup, len(groups))
|
||||
groupsByID := make(map[uint]*models.KeyGroup, len(groups))
|
||||
|
||||
logger.Debugf("[GML-LOG 4/5] Starting to process group records into maps...")
|
||||
for i, group := range groups {
|
||||
if group == nil {
|
||||
logger.Debugf("[GML] CRITICAL: Found a 'nil' group pointer at index %d! This is the most likely cause of the panic.", i)
|
||||
} else {
|
||||
groupsByName[group.Name] = group
|
||||
groupsByID[group.ID] = group
|
||||
}
|
||||
}
|
||||
logger.Debugf("[GML-LOG 5/5] Finished processing records. Building final cache data...")
|
||||
return GroupManagerCacheData{
|
||||
Groups: groups,
|
||||
GroupsByName: groupsByName,
|
||||
GroupsByID: groupsByID,
|
||||
KeyCounts: keyCounts,
|
||||
KeyStatusCounts: keyStatusCounts,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func NewGroupManager(
|
||||
db *gorm.DB,
|
||||
keyRepo repository.KeyRepository,
|
||||
groupRepo repository.GroupRepository,
|
||||
sm *settings.SettingsManager,
|
||||
syncer *syncer.CacheSyncer[GroupManagerCacheData],
|
||||
logger *logrus.Logger,
|
||||
) *GroupManager {
|
||||
return &GroupManager{
|
||||
db: db,
|
||||
keyRepo: keyRepo,
|
||||
groupRepo: groupRepo,
|
||||
settingsManager: sm,
|
||||
syncer: syncer,
|
||||
logger: logger.WithField("component", "GroupManager"),
|
||||
}
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.Groups) == 0 {
|
||||
return []*models.KeyGroup{}
|
||||
}
|
||||
groupsToOrder := cache.Groups
|
||||
sort.Slice(groupsToOrder, func(i, j int) bool {
|
||||
if groupsToOrder[i].Order != groupsToOrder[j].Order {
|
||||
return groupsToOrder[i].Order < groupsToOrder[j].Order
|
||||
}
|
||||
return groupsToOrder[i].ID < groupsToOrder[j].ID
|
||||
})
|
||||
return groupsToOrder
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.KeyCounts) == 0 {
|
||||
return 0
|
||||
}
|
||||
count := cache.KeyCounts[groupID]
|
||||
return count
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.KeyStatusCounts) == 0 {
|
||||
return make(map[models.APIKeyStatus]int64)
|
||||
}
|
||||
if counts, ok := cache.KeyStatusCounts[groupID]; ok {
|
||||
return counts
|
||||
}
|
||||
return make(map[models.APIKeyStatus]int64)
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.GroupsByName) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
group, ok := cache.GroupsByName[name]
|
||||
return group, ok
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.GroupsByID) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
group, ok := cache.GroupsByID[id]
|
||||
return group, ok
|
||||
}
|
||||
|
||||
func (gm *GroupManager) Stop() {
|
||||
gm.syncer.Stop()
|
||||
}
|
||||
|
||||
func (gm *GroupManager) Invalidate() error {
|
||||
return gm.syncer.Invalidate()
|
||||
}
|
||||
|
||||
// --- Write Operations ---
|
||||
|
||||
// CreateKeyGroup creates a new key group, including its operational settings, and invalidates the cache.
|
||||
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
|
||||
if !utils.IsValidGroupName(group.Name) {
|
||||
return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
|
||||
}
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 1. Create the group itself to get an ID
|
||||
if err := tx.Create(group).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. If settings are provided, create the associated GroupSettings record
|
||||
if settings != nil {
|
||||
// Only marshal non-nil fields to keep the JSON clean
|
||||
settingsToMarshal := make(map[string]interface{})
|
||||
if settings.EnableKeyCheck != nil {
|
||||
settingsToMarshal["enable_key_check"] = settings.EnableKeyCheck
|
||||
}
|
||||
if settings.KeyCheckIntervalMinutes != nil {
|
||||
settingsToMarshal["key_check_interval_minutes"] = settings.KeyCheckIntervalMinutes
|
||||
}
|
||||
if settings.KeyBlacklistThreshold != nil {
|
||||
settingsToMarshal["key_blacklist_threshold"] = settings.KeyBlacklistThreshold
|
||||
}
|
||||
if settings.KeyCooldownMinutes != nil {
|
||||
settingsToMarshal["key_cooldown_minutes"] = settings.KeyCooldownMinutes
|
||||
}
|
||||
if settings.KeyCheckConcurrency != nil {
|
||||
settingsToMarshal["key_check_concurrency"] = settings.KeyCheckConcurrency
|
||||
}
|
||||
if settings.KeyCheckEndpoint != nil {
|
||||
settingsToMarshal["key_check_endpoint"] = settings.KeyCheckEndpoint
|
||||
}
|
||||
if settings.KeyCheckModel != nil {
|
||||
settingsToMarshal["key_check_model"] = settings.KeyCheckModel
|
||||
}
|
||||
if settings.MaxRetries != nil {
|
||||
settingsToMarshal["max_retries"] = settings.MaxRetries
|
||||
}
|
||||
if settings.EnableSmartGateway != nil {
|
||||
settingsToMarshal["enable_smart_gateway"] = settings.EnableSmartGateway
|
||||
}
|
||||
if len(settingsToMarshal) > 0 {
|
||||
settingsJSON, err := json.Marshal(settingsToMarshal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal group settings: %w", err)
|
||||
}
|
||||
groupSettings := models.GroupSettings{
|
||||
GroupID: group.ID,
|
||||
SettingsJSON: datatypes.JSON(settingsJSON),
|
||||
}
|
||||
if err := tx.Create(&groupSettings).Error; err != nil {
|
||||
return fmt.Errorf("failed to save group settings: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go gm.Invalidate()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateKeyGroup updates an existing key group, its settings, and associations, then invalidates the cache.
|
||||
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
|
||||
if !utils.IsValidGroupName(group.Name) {
|
||||
return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens")
|
||||
}
|
||||
uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
|
||||
uniqueModelNames := uniqueStrings(modelNames)
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
// --- 1. Update AllowedUpstreams (M:N relationship) ---
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if len(uniqueUpstreamURLs) > 0 {
|
||||
if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(uniqueModelNames) > 0 {
|
||||
var newMappings []models.GroupModelMapping
|
||||
for _, name := range uniqueModelNames {
|
||||
newMappings = append(newMappings, models.GroupModelMapping{ModelName: name})
|
||||
}
|
||||
if err := tx.Model(group).Association("AllowedModels").Append(newMappings); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Model(group).Updates(group).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var existingSettings models.GroupSettings
|
||||
if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
}
|
||||
var currentSettingsData models.KeyGroupSettings
|
||||
if len(existingSettings.SettingsJSON) > 0 {
|
||||
if err := json.Unmarshal(existingSettings.SettingsJSON, ¤tSettingsData); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal existing group settings: %w", err)
|
||||
}
|
||||
}
|
||||
if err := reflectutil.MergeNilFields(¤tSettingsData, newSettings); err != nil {
|
||||
return fmt.Errorf("failed to merge group settings: %w", err)
|
||||
}
|
||||
updatedJSON, err := json.Marshal(currentSettingsData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal updated group settings: %w", err)
|
||||
}
|
||||
existingSettings.GroupID = group.ID
|
||||
existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
|
||||
return tx.Save(&existingSettings).Error
|
||||
})
|
||||
if err == nil {
|
||||
go gm.Invalidate()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteKeyGroup deletes a key group and subsequently cleans up any keys that have become orphans.
|
||||
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
gm.logger.Infof("Starting transaction to delete KeyGroup ID: %d", id)
|
||||
// Step 1: First, retrieve the group object we are about to delete.
|
||||
var group models.KeyGroup
|
||||
if err := tx.First(&group, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
gm.logger.Warnf("Attempted to delete a non-existent KeyGroup with ID: %d", id)
|
||||
return nil // Don't treat as an error, the group is already gone.
|
||||
}
|
||||
gm.logger.WithError(err).Errorf("Failed to find KeyGroup with ID: %d for deletion", id)
|
||||
return err
|
||||
}
|
||||
// Step 2: Clear all many-to-many and one-to-many associations using GORM's safe methods.
|
||||
if err := tx.Model(&group).Association("AllowedUpstreams").Clear(); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedUpstreams' association for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&group).Association("AllowedModels").Clear(); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedModels' association for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&group).Association("Mappings").Clear(); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to clear 'Mappings' (API Key associations) for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
// Also clear settings if they exist to maintain data integrity
|
||||
if err := tx.Model(&group).Association("Settings").Delete(group.Settings); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to delete 'Settings' association for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
// Step 3: Delete the KeyGroup itself.
|
||||
if err := tx.Delete(&group).Error; err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to delete KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
gm.logger.Infof("KeyGroup ID %d associations cleared and entity deleted. Triggering orphan key cleanup.", id)
|
||||
// Step 4: Trigger the orphan key cleanup (this logic remains the same and is correct).
|
||||
deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
|
||||
if err != nil {
|
||||
gm.logger.WithError(err).Error("Failed to clean up orphan keys after deleting group.")
|
||||
return err
|
||||
}
|
||||
if deletedCount > 0 {
|
||||
gm.logger.Infof("Successfully cleaned up %d orphan keys.", deletedCount)
|
||||
}
|
||||
gm.logger.Infof("Transaction for deleting KeyGroup ID: %d completed successfully.", id)
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
go gm.Invalidate()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
var originalGroup models.KeyGroup
|
||||
if err := gm.db.
|
||||
Preload("RequestConfig").
|
||||
Preload("Mappings").
|
||||
Preload("AllowedUpstreams").
|
||||
Preload("AllowedModels").
|
||||
First(&originalGroup, id).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to find original group with id %d: %w", id, err)
|
||||
}
|
||||
newGroup := originalGroup
|
||||
timestamp := time.Now().Unix()
|
||||
newGroup.ID = 0
|
||||
newGroup.Name = fmt.Sprintf("%s-clone-%d", originalGroup.Name, timestamp)
|
||||
newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
|
||||
newGroup.CreatedAt = time.Time{}
|
||||
newGroup.UpdatedAt = time.Time{}
|
||||
|
||||
newGroup.RequestConfigID = nil
|
||||
newGroup.RequestConfig = nil
|
||||
newGroup.Mappings = nil
|
||||
newGroup.AllowedUpstreams = nil
|
||||
newGroup.AllowedModels = nil
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
|
||||
if err := tx.Create(&newGroup).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if originalGroup.RequestConfig != nil {
|
||||
newRequestConfig := *originalGroup.RequestConfig
|
||||
newRequestConfig.ID = 0 // Mark as new record
|
||||
|
||||
if err := tx.Create(&newRequestConfig).Error; err != nil {
|
||||
return fmt.Errorf("failed to clone request config: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
|
||||
return fmt.Errorf("failed to link new group to cloned request config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var originalSettings models.GroupSettings
|
||||
err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
|
||||
if err == nil && len(originalSettings.SettingsJSON) > 0 {
|
||||
newSettings := models.GroupSettings{
|
||||
GroupID: newGroup.ID,
|
||||
SettingsJSON: originalSettings.SettingsJSON,
|
||||
}
|
||||
if err := tx.Create(&newSettings).Error; err != nil {
|
||||
return fmt.Errorf("failed to clone group settings: %w", err)
|
||||
}
|
||||
} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("failed to query original group settings: %w", err)
|
||||
}
|
||||
|
||||
if len(originalGroup.Mappings) > 0 {
|
||||
newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
|
||||
for i, oldMapping := range originalGroup.Mappings {
|
||||
newMappings[i] = models.GroupAPIKeyMapping{
|
||||
KeyGroupID: newGroup.ID,
|
||||
APIKeyID: oldMapping.APIKeyID,
|
||||
Status: oldMapping.Status,
|
||||
LastError: oldMapping.LastError,
|
||||
ConsecutiveErrorCount: oldMapping.ConsecutiveErrorCount,
|
||||
LastUsedAt: oldMapping.LastUsedAt,
|
||||
CooldownUntil: oldMapping.CooldownUntil,
|
||||
}
|
||||
}
|
||||
if err := tx.Create(&newMappings).Error; err != nil {
|
||||
return fmt.Errorf("failed to clone key group mappings: %w", err)
|
||||
}
|
||||
}
|
||||
if len(originalGroup.AllowedUpstreams) > 0 {
|
||||
if err := tx.Model(&newGroup).Association("AllowedUpstreams").Append(originalGroup.AllowedUpstreams); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(originalGroup.AllowedModels) > 0 {
|
||||
if err := tx.Model(&newGroup).Association("AllowedModels").Append(originalGroup.AllowedModels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go gm.Invalidate()
|
||||
|
||||
var finalClonedGroup models.KeyGroup
|
||||
if err := gm.db.
|
||||
Preload("RequestConfig").
|
||||
Preload("Mappings").
|
||||
Preload("AllowedUpstreams").
|
||||
Preload("AllowedModels").
|
||||
First(&finalClonedGroup, newGroup.ID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &finalClonedGroup, nil
|
||||
}
|
||||
|
||||
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
|
||||
globalSettings := gm.settingsManager.GetSettings()
|
||||
s := "gemini-1.5-flash" // Per user feedback for default model
|
||||
opConfig := &models.KeyGroupSettings{
|
||||
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
|
||||
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
|
||||
KeyCheckIntervalMinutes: &globalSettings.BaseKeyCheckIntervalMinutes,
|
||||
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
|
||||
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
|
||||
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
|
||||
KeyCheckModel: &s,
|
||||
MaxRetries: &globalSettings.MaxRetries,
|
||||
EnableSmartGateway: &globalSettings.EnableSmartGateway,
|
||||
}
|
||||
|
||||
if group == nil {
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
var groupSettingsRecord models.GroupSettings
|
||||
err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
gm.logger.WithError(err).Errorf("Failed to query group settings for group ID %d", group.ID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(groupSettingsRecord.SettingsJSON) == 0 {
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
var groupSpecificSettings models.KeyGroupSettings
|
||||
if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil {
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings JSON.")
|
||||
return opConfig, err
|
||||
}
|
||||
|
||||
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group-specific settings over defaults.")
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
|
||||
group, ok := gm.GetGroupByID(groupID)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("group with id %d not found", groupID)
|
||||
}
|
||||
opConfig, err := gm.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build operational config for group %d: %w", groupID, err)
|
||||
}
|
||||
globalSettings := gm.settingsManager.GetSettings()
|
||||
baseURL := globalSettings.DefaultUpstreamURL
|
||||
if opConfig.KeyCheckEndpoint != nil && *opConfig.KeyCheckEndpoint != "" {
|
||||
baseURL = *opConfig.KeyCheckEndpoint
|
||||
}
|
||||
if baseURL == "" {
|
||||
return "", fmt.Errorf("no key check endpoint or default upstream URL is configured for group %d", groupID)
|
||||
}
|
||||
modelName := globalSettings.BaseKeyCheckModel
|
||||
if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" {
|
||||
modelName = *opConfig.KeyCheckModel
|
||||
}
|
||||
parsedURL, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse base URL '%s': %w", baseURL, err)
|
||||
}
|
||||
cleanedPath := parsedURL.Path
|
||||
cleanedPath = strings.TrimSuffix(cleanedPath, "/")
|
||||
cleanedPath = strings.TrimSuffix(cleanedPath, "/v1beta")
|
||||
parsedURL.Path = path.Join(cleanedPath, "v1beta", "models", modelName)
|
||||
finalEndpoint := parsedURL.String()
|
||||
return finalEndpoint, nil
|
||||
}
|
||||
|
||||
func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
|
||||
ordersMap := make(map[uint]int, len(payload))
|
||||
for _, item := range payload {
|
||||
ordersMap[item.ID] = item.Order
|
||||
}
|
||||
if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
|
||||
gm.logger.WithError(err).Error("Failed to update group order in transaction")
|
||||
return fmt.Errorf("database transaction failed: %w", err)
|
||||
}
|
||||
gm.logger.Info("Group order updated successfully, invalidating cache...")
|
||||
go gm.Invalidate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func uniqueStrings(slice []string) []string {
|
||||
keys := make(map[string]struct{})
|
||||
list := []string{}
|
||||
for _, entry := range slice {
|
||||
if _, value := keys[entry]; !value {
|
||||
keys[entry] = struct{}{}
|
||||
list = append(list, entry)
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
624
internal/service/healthcheck_service.go
Normal file
624
internal/service/healthcheck_service.go
Normal file
@@ -0,0 +1,624 @@
|
||||
// Filename: internal/service/healthcheck_service.go (最终校准版)
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"gemini-balancer/internal/channel"
|
||||
CustomErrors "gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProxyCheckTargetURL = "https://www.google.com/generate_204"
|
||||
DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
StatusActive = "active"
|
||||
StatusInactive = "inactive"
|
||||
)
|
||||
|
||||
type HealthCheckServiceLogger struct{ *logrus.Entry }
|
||||
|
||||
type HealthCheckService struct {
|
||||
db *gorm.DB
|
||||
SettingsManager *settings.SettingsManager
|
||||
store store.Store
|
||||
keyRepo repository.KeyRepository
|
||||
groupManager *GroupManager
|
||||
channel channel.ChannelProxy
|
||||
keyValidationService *KeyValidationService
|
||||
logger *logrus.Entry
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
lastResultsMutex sync.RWMutex
|
||||
lastResults map[string]string
|
||||
groupCheckTimeMutex sync.Mutex
|
||||
groupNextCheckTime map[uint]time.Time
|
||||
}
|
||||
|
||||
func NewHealthCheckService(
|
||||
db *gorm.DB,
|
||||
ss *settings.SettingsManager,
|
||||
s store.Store,
|
||||
repo repository.KeyRepository,
|
||||
gm *GroupManager,
|
||||
ch channel.ChannelProxy,
|
||||
kvs *KeyValidationService,
|
||||
logger *logrus.Logger,
|
||||
) *HealthCheckService {
|
||||
return &HealthCheckService{
|
||||
db: db,
|
||||
SettingsManager: ss,
|
||||
store: s,
|
||||
keyRepo: repo,
|
||||
groupManager: gm,
|
||||
channel: ch,
|
||||
keyValidationService: kvs,
|
||||
logger: logger.WithField("component", "HealthCheck🩺"),
|
||||
stopChan: make(chan struct{}),
|
||||
lastResults: make(map[string]string),
|
||||
groupNextCheckTime: make(map[uint]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) Start() {
|
||||
s.logger.Info("Starting HealthCheckService with independent check loops...")
|
||||
s.wg.Add(4) // Now four loops
|
||||
go s.runKeyCheckLoop()
|
||||
go s.runUpstreamCheckLoop()
|
||||
go s.runProxyCheckLoop()
|
||||
go s.runBaseKeyCheckLoop()
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) Stop() {
|
||||
s.logger.Info("Stopping HealthCheckService...")
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
s.logger.Info("HealthCheckService stopped gracefully.")
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) GetLastHealthCheckResults() map[string]string {
|
||||
s.lastResultsMutex.RLock()
|
||||
defer s.lastResultsMutex.RUnlock()
|
||||
resultsCopy := make(map[string]string, len(s.lastResults))
|
||||
for k, v := range s.lastResults {
|
||||
resultsCopy[k] = v
|
||||
}
|
||||
return resultsCopy
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) runKeyCheckLoop() {
|
||||
defer s.wg.Done()
|
||||
s.logger.Info("Key check dynamic scheduler loop started.")
|
||||
|
||||
// 主调度循环,每分钟检查一次任务
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.scheduleKeyChecks()
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Key check scheduler loop stopped.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) scheduleKeyChecks() {
|
||||
groups := s.groupManager.GetAllGroups()
|
||||
now := time.Now()
|
||||
|
||||
s.groupCheckTimeMutex.Lock()
|
||||
defer s.groupCheckTimeMutex.Unlock()
|
||||
|
||||
for _, group := range groups {
|
||||
// 获取特定于组的运营配置
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build operational config for group, skipping health check scheduling.")
|
||||
continue
|
||||
}
|
||||
|
||||
if opConfig.EnableKeyCheck == nil || !*opConfig.EnableKeyCheck {
|
||||
continue // 跳过禁用了健康检查的组
|
||||
}
|
||||
|
||||
var intervalMinutes int
|
||||
if opConfig.KeyCheckIntervalMinutes != nil {
|
||||
intervalMinutes = *opConfig.KeyCheckIntervalMinutes
|
||||
}
|
||||
interval := time.Duration(intervalMinutes) * time.Minute
|
||||
if interval <= 0 {
|
||||
continue // 跳过无效的检查周期
|
||||
}
|
||||
|
||||
if nextCheckTime, ok := s.groupNextCheckTime[group.ID]; !ok || now.After(nextCheckTime) {
|
||||
s.logger.Infof("Scheduling key check for group '%s' (ID: %d)", group.Name, group.ID)
|
||||
go s.performKeyChecksForGroup(group, opConfig)
|
||||
s.groupNextCheckTime[group.ID] = now.Add(interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) runUpstreamCheckLoop() {
|
||||
defer s.wg.Done()
|
||||
s.logger.Info("Upstream check loop started.")
|
||||
if s.SettingsManager.GetSettings().EnableUpstreamCheck {
|
||||
s.performUpstreamChecks()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().UpstreamCheckIntervalSeconds) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if s.SettingsManager.GetSettings().EnableUpstreamCheck {
|
||||
s.logger.Debug("Upstream check ticker fired.")
|
||||
s.performUpstreamChecks()
|
||||
}
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Upstream check loop stopped.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) runProxyCheckLoop() {
|
||||
defer s.wg.Done()
|
||||
s.logger.Info("Proxy check loop started.")
|
||||
if s.SettingsManager.GetSettings().EnableProxyCheck {
|
||||
s.performProxyChecks()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(time.Duration(s.SettingsManager.GetSettings().ProxyCheckIntervalSeconds) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if s.SettingsManager.GetSettings().EnableProxyCheck {
|
||||
s.logger.Debug("Proxy check ticker fired.")
|
||||
s.performProxyChecks()
|
||||
}
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Proxy check loop stopped.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) performKeyChecksForGroup(group *models.KeyGroup, opConfig *models.KeyGroupSettings) {
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
|
||||
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(group.ID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to build key check endpoint for group, skipping check cycle.")
|
||||
return
|
||||
}
|
||||
|
||||
log := s.logger.WithFields(logrus.Fields{"group_id": group.ID, "group_name": group.Name})
|
||||
|
||||
log.Infof("Starting key health check cycle.")
|
||||
|
||||
var mappingsToCheck []models.GroupAPIKeyMapping
|
||||
err = s.db.Model(&models.GroupAPIKeyMapping{}).
|
||||
Joins("JOIN api_keys ON api_keys.id = group_api_key_mappings.api_key_id").
|
||||
Where("group_api_key_mappings.key_group_id = ?", group.ID).
|
||||
Where("api_keys.master_status = ?", models.MasterStatusActive).
|
||||
Where("group_api_key_mappings.status IN ?", []models.APIKeyStatus{models.StatusActive, models.StatusDisabled, models.StatusCooldown}).
|
||||
Preload("APIKey").
|
||||
Find(&mappingsToCheck).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to fetch key mappings for health check.")
|
||||
return
|
||||
}
|
||||
if len(mappingsToCheck) == 0 {
|
||||
log.Info("No key mappings to check for this group.")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Starting health check for %d key mappings.", len(mappingsToCheck))
|
||||
jobs := make(chan models.GroupAPIKeyMapping, len(mappingsToCheck))
|
||||
var wg sync.WaitGroup
|
||||
var concurrency int
|
||||
if opConfig.KeyCheckConcurrency != nil {
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
}
|
||||
if concurrency <= 0 {
|
||||
concurrency = 1 // 保证至少有一个 worker
|
||||
}
|
||||
for w := 1; w <= concurrency; w++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
for mapping := range jobs {
|
||||
s.checkAndProcessMapping(&mapping, timeout, endpoint)
|
||||
}
|
||||
}(w)
|
||||
}
|
||||
for _, m := range mappingsToCheck {
|
||||
jobs <- m
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
log.Info("Finished key health check cycle.")
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) checkAndProcessMapping(mapping *models.GroupAPIKeyMapping, timeout time.Duration, endpoint string) {
|
||||
if mapping.APIKey == nil {
|
||||
s.logger.Warnf("Skipping check for mapping (G:%d, K:%d) because associated APIKey is nil.", mapping.KeyGroupID, mapping.APIKeyID)
|
||||
return
|
||||
}
|
||||
validationErr := s.keyValidationService.ValidateSingleKey(mapping.APIKey, timeout, endpoint)
|
||||
// --- 诊断一:验证成功 (健康) ---
|
||||
if validationErr == nil {
|
||||
if mapping.Status != models.StatusActive {
|
||||
s.activateMapping(mapping)
|
||||
}
|
||||
return
|
||||
}
|
||||
errorString := validationErr.Error()
|
||||
// --- 诊断二:永久性错误 ---
|
||||
if CustomErrors.IsPermanentUpstreamError(errorString) {
|
||||
s.revokeMapping(mapping, validationErr)
|
||||
return
|
||||
}
|
||||
// --- 诊断三:暂时性错误 ---
|
||||
if CustomErrors.IsTemporaryUpstreamError(errorString) {
|
||||
// Log with a higher level (WARN) since this is an actionable, proactive finding.
|
||||
s.logger.Warnf("Health check for key %d failed with temporary error, applying penalty. Reason: %v", mapping.APIKeyID, validationErr)
|
||||
s.penalizeMapping(mapping, validationErr)
|
||||
return
|
||||
}
|
||||
// --- 诊断四:其他未知或上游服务错误 ---
|
||||
|
||||
s.logger.Errorf("Health check for key %d failed with a transient or unknown upstream error: %v. Mapping will not be penalized by health check.", mapping.APIKeyID, validationErr)
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) activateMapping(mapping *models.GroupAPIKeyMapping) {
|
||||
oldStatus := mapping.Status
|
||||
mapping.Status = models.StatusActive
|
||||
mapping.ConsecutiveErrorCount = 0
|
||||
mapping.LastError = ""
|
||||
if err := s.keyRepo.UpdateMapping(mapping); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to activate mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Mapping (G:%d, K:%d) successfully activated from %s to ACTIVE.", mapping.KeyGroupID, mapping.APIKeyID, oldStatus)
|
||||
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) penalizeMapping(mapping *models.GroupAPIKeyMapping, err error) {
|
||||
// Re-fetch group-specific operational config to get the correct thresholds
|
||||
group, ok := s.groupManager.GetGroupByID(mapping.KeyGroupID)
|
||||
if !ok {
|
||||
s.logger.Errorf("Could not find group with ID %d to apply penalty.", mapping.KeyGroupID)
|
||||
return
|
||||
}
|
||||
opConfig, buildErr := s.groupManager.BuildOperationalConfig(group)
|
||||
if buildErr != nil {
|
||||
s.logger.WithError(buildErr).Errorf("Failed to build operational config for group %d during penalty.", mapping.KeyGroupID)
|
||||
return
|
||||
}
|
||||
oldStatus := mapping.Status
|
||||
mapping.LastError = err.Error()
|
||||
mapping.ConsecutiveErrorCount++
|
||||
// Use the group-specific threshold
|
||||
threshold := *opConfig.KeyBlacklistThreshold
|
||||
if mapping.ConsecutiveErrorCount >= threshold {
|
||||
mapping.Status = models.StatusCooldown
|
||||
cooldownDuration := time.Duration(*opConfig.KeyCooldownMinutes) * time.Minute
|
||||
cooldownTime := time.Now().Add(cooldownDuration)
|
||||
mapping.CooldownUntil = &cooldownTime
|
||||
s.logger.Warnf("Mapping (G:%d, K:%d) reached error threshold (%d) during health check and is now in COOLDOWN for %v.", mapping.KeyGroupID, mapping.APIKeyID, threshold, cooldownDuration)
|
||||
}
|
||||
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
|
||||
s.logger.WithError(errDb).Errorf("Failed to penalize mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
|
||||
return
|
||||
}
|
||||
if oldStatus != mapping.Status {
|
||||
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) revokeMapping(mapping *models.GroupAPIKeyMapping, err error) {
|
||||
oldStatus := mapping.Status
|
||||
if oldStatus == models.StatusBanned {
|
||||
return // Already banned, do nothing.
|
||||
}
|
||||
|
||||
mapping.Status = models.StatusBanned
|
||||
mapping.LastError = "Definitive error: " + err.Error()
|
||||
mapping.ConsecutiveErrorCount = 0 // Reset counter as this is a final state for this group
|
||||
|
||||
if errDb := s.keyRepo.UpdateMapping(mapping); errDb != nil {
|
||||
s.logger.WithError(errDb).Errorf("Failed to revoke mapping (G:%d, K:%d).", mapping.KeyGroupID, mapping.APIKeyID)
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Warnf("Mapping (G:%d, K:%d) has been BANNED due to a definitive error: %v", mapping.KeyGroupID, mapping.APIKeyID, err)
|
||||
s.publishKeyStatusChangedEvent(mapping.KeyGroupID, mapping.APIKeyID, oldStatus, mapping.Status)
|
||||
|
||||
s.logger.Infof("Triggering MasterStatus update for definitively failed key ID %d.", mapping.APIKeyID)
|
||||
if err := s.keyRepo.UpdateAPIKeyStatus(mapping.APIKeyID, models.MasterStatusRevoked); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to update master status for key ID %d after group-level ban.", mapping.APIKeyID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) performUpstreamChecks() {
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.UpstreamCheckTimeoutSeconds) * time.Second
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if err := s.db.Find(&upstreams).Error; err != nil {
|
||||
s.logger.WithError(err).Error("Failed to retrieve upstreams.")
|
||||
return
|
||||
}
|
||||
if len(upstreams) == 0 {
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Starting validation for %d upstreams.", len(upstreams))
|
||||
var wg sync.WaitGroup
|
||||
for _, u := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(upstream *models.UpstreamEndpoint) {
|
||||
defer wg.Done()
|
||||
oldStatus := upstream.Status
|
||||
isAlive := s.checkEndpoint(upstream.URL, timeout)
|
||||
newStatus := StatusInactive
|
||||
if isAlive {
|
||||
newStatus = StatusActive
|
||||
}
|
||||
s.lastResultsMutex.Lock()
|
||||
s.lastResults[upstream.URL] = newStatus
|
||||
s.lastResultsMutex.Unlock()
|
||||
if oldStatus != newStatus {
|
||||
s.logger.Infof("Upstream '%s' status changed from %s -> %s", upstream.URL, oldStatus, newStatus)
|
||||
if err := s.db.Model(upstream).Update("status", newStatus).Error; err != nil {
|
||||
s.logger.WithError(err).WithField("upstream_id", upstream.ID).Error("Failed to update upstream status.")
|
||||
} else {
|
||||
s.publishUpstreamHealthChangedEvent(upstream, oldStatus, newStatus)
|
||||
}
|
||||
}
|
||||
}(u)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) checkEndpoint(urlStr string, timeout time.Duration) bool {
|
||||
client := http.Client{Timeout: timeout}
|
||||
resp, err := client.Head(urlStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode < http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) performProxyChecks() {
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.ProxyCheckTimeoutSeconds) * time.Second
|
||||
var proxies []*models.ProxyConfig
|
||||
if err := s.db.Find(&proxies).Error; err != nil {
|
||||
s.logger.WithError(err).Error("Failed to retrieve proxies.")
|
||||
return
|
||||
}
|
||||
if len(proxies) == 0 {
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Starting validation for %d proxies.", len(proxies))
|
||||
var wg sync.WaitGroup
|
||||
for _, p := range proxies {
|
||||
wg.Add(1)
|
||||
go func(proxyCfg *models.ProxyConfig) {
|
||||
defer wg.Done()
|
||||
isAlive := s.checkProxy(proxyCfg, timeout)
|
||||
newStatus := StatusInactive
|
||||
if isAlive {
|
||||
newStatus = StatusActive
|
||||
}
|
||||
s.lastResultsMutex.Lock()
|
||||
s.lastResults[proxyCfg.Address] = newStatus
|
||||
s.lastResultsMutex.Unlock()
|
||||
if proxyCfg.Status != newStatus {
|
||||
s.logger.Infof("Proxy '%s' status changed from %s -> %s", proxyCfg.Address, proxyCfg.Status, newStatus)
|
||||
if err := s.db.Model(proxyCfg).Update("status", newStatus).Error; err != nil {
|
||||
s.logger.WithError(err).WithField("proxy_id", proxyCfg.ID).Error("Failed to update proxy status.")
|
||||
}
|
||||
}
|
||||
}(p)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) checkProxy(proxyCfg *models.ProxyConfig, timeout time.Duration) bool {
|
||||
transport := &http.Transport{}
|
||||
switch proxyCfg.Protocol {
|
||||
case "http", "https":
|
||||
proxyUrl, err := url.Parse(proxyCfg.Protocol + "://" + proxyCfg.Address)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Invalid proxy URL format.")
|
||||
return false
|
||||
}
|
||||
transport.Proxy = http.ProxyURL(proxyUrl)
|
||||
case "socks5":
|
||||
dialer, err := proxy.SOCKS5("tcp", proxyCfg.Address, nil, proxy.Direct)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).WithField("proxy_address", proxyCfg.Address).Warn("Failed to create SOCKS5 dialer.")
|
||||
return false
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
s.logger.WithField("protocol", proxyCfg.Protocol).Warn("Unsupported proxy protocol.")
|
||||
return false
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: timeout,
|
||||
}
|
||||
resp, err := client.Get(ProxyCheckTargetURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) publishKeyStatusChangedEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
KeyID: keyID,
|
||||
GroupID: groupID,
|
||||
OldStatus: oldStatus,
|
||||
NewStatus: newStatus,
|
||||
ChangeReason: "health_check",
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to marshal KeyStatusChangedEvent for group %d.", groupID)
|
||||
return
|
||||
}
|
||||
if err := s.store.Publish(models.TopicKeyStatusChanged, payload); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to publish KeyStatusChangedEvent for group %d.", groupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) publishUpstreamHealthChangedEvent(upstream *models.UpstreamEndpoint, oldStatus, newStatus string) {
|
||||
event := models.UpstreamHealthChangedEvent{
|
||||
UpstreamID: upstream.ID,
|
||||
UpstreamURL: upstream.URL,
|
||||
OldStatus: oldStatus,
|
||||
NewStatus: newStatus,
|
||||
Latency: 0,
|
||||
Reason: "health_check",
|
||||
CheckedAt: time.Now(),
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to marshal UpstreamHealthChangedEvent.")
|
||||
return
|
||||
}
|
||||
if err := s.store.Publish(models.TopicUpstreamHealthChanged, payload); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to publish UpstreamHealthChangedEvent.")
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Global Base Key Check (New Logic)
|
||||
// =========================================================================
|
||||
|
||||
func (s *HealthCheckService) runBaseKeyCheckLoop() {
|
||||
defer s.wg.Done()
|
||||
s.logger.Info("Global base key check loop started.")
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
|
||||
if !settings.EnableBaseKeyCheck {
|
||||
s.logger.Info("Global base key check is disabled.")
|
||||
return
|
||||
}
|
||||
|
||||
// Perform an initial check on startup
|
||||
s.performBaseKeyChecks()
|
||||
|
||||
interval := time.Duration(settings.BaseKeyCheckIntervalMinutes) * time.Minute
|
||||
if interval <= 0 {
|
||||
s.logger.Warnf("Invalid BaseKeyCheckIntervalMinutes: %d. Disabling base key check loop.", settings.BaseKeyCheckIntervalMinutes)
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.performBaseKeyChecks()
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Global base key check loop stopped.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HealthCheckService) performBaseKeyChecks() {
|
||||
s.logger.Info("Starting global base key check cycle.")
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
endpoint := settings.BaseKeyCheckEndpoint
|
||||
concurrency := settings.BaseKeyCheckConcurrency
|
||||
keys, err := s.keyRepo.GetActiveMasterKeys()
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to fetch active master keys for base check.")
|
||||
return
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
s.logger.Info("No active master keys to perform base check on.")
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Performing base check on %d active master keys.", len(keys))
|
||||
jobs := make(chan *models.APIKey, len(keys))
|
||||
var wg sync.WaitGroup
|
||||
if concurrency <= 0 {
|
||||
concurrency = 5 // Safe default
|
||||
}
|
||||
for w := 0; w < concurrency; w++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for key := range jobs {
|
||||
err := s.keyValidationService.ValidateSingleKey(key, timeout, endpoint)
|
||||
if err != nil && CustomErrors.IsPermanentUpstreamError(err.Error()) {
|
||||
oldStatus := key.MasterStatus
|
||||
s.logger.Warnf("Key ID %d (%s...) failed definitive base check. Setting MasterStatus to REVOKED. Reason: %v", key.ID, key.APIKey[:4], err)
|
||||
if updateErr := s.keyRepo.UpdateAPIKeyStatus(key.ID, models.MasterStatusRevoked); updateErr != nil {
|
||||
s.logger.WithError(updateErr).Errorf("Failed to update master status for key ID %d.", key.ID)
|
||||
} else {
|
||||
s.publishMasterKeyStatusChangedEvent(key.ID, oldStatus, models.MasterStatusRevoked)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
for _, key := range keys {
|
||||
jobs <- key
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
s.logger.Info("Global base key check cycle finished.")
|
||||
}
|
||||
|
||||
// 事件发布辅助函数
|
||||
func (s *HealthCheckService) publishMasterKeyStatusChangedEvent(keyID uint, oldStatus, newStatus models.MasterAPIKeyStatus) {
|
||||
event := models.MasterKeyStatusChangedEvent{
|
||||
KeyID: keyID,
|
||||
OldMasterStatus: oldStatus,
|
||||
NewMasterStatus: newStatus,
|
||||
ChangeReason: "base_health_check",
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to marshal MasterKeyStatusChangedEvent for key %d.", keyID)
|
||||
return
|
||||
}
|
||||
if err := s.store.Publish(models.TopicMasterKeyStatusChanged, payload); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to publish MasterKeyStatusChangedEvent for key %d.", keyID)
|
||||
}
|
||||
}
|
||||
397
internal/service/key_import_service.go
Normal file
397
internal/service/key_import_service.go
Normal file
@@ -0,0 +1,397 @@
|
||||
// Filename: internal/service/key_import_service.go
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/task"
|
||||
"gemini-balancer/internal/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
TaskTypeAddKeysToGroup = "add_keys_to_group"
|
||||
TaskTypeUnlinkKeysFromGroup = "unlink_keys_from_group"
|
||||
TaskTypeHardDeleteKeys = "hard_delete_keys"
|
||||
TaskTypeRestoreKeys = "restore_keys"
|
||||
chunkSize = 500
|
||||
)
|
||||
|
||||
type KeyImportService struct {
|
||||
taskService task.Reporter
|
||||
keyRepo repository.KeyRepository
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
apiKeyService *APIKeyService
|
||||
}
|
||||
|
||||
func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.Store, as *APIKeyService, logger *logrus.Logger) *KeyImportService {
|
||||
return &KeyImportService{
|
||||
taskService: ts,
|
||||
keyRepo: kr,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "KeyImportService🚀"),
|
||||
apiKeyService: as,
|
||||
}
|
||||
}
|
||||
|
||||
// --- 通用的 Panic-Safe 任務執行器 ---
|
||||
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
|
||||
s.logger.Error(err)
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||||
}
|
||||
}()
|
||||
taskFunc()
|
||||
}
|
||||
|
||||
// --- Public Task Starters ---
|
||||
|
||||
func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found in input text")
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport)
|
||||
})
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys)
|
||||
})
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := "global_hard_delete" // Global lock
|
||||
taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := "global_restore_keys" // Global lock
|
||||
taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runRestoreKeysTask(taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// --- Private Task Runners ---
|
||||
|
||||
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
// 步骤 1: 对输入的原始 key 列表进行去重。
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeyStrings []string
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeyStrings = append(uniqueKeyStrings, kStr)
|
||||
}
|
||||
}
|
||||
if len(uniqueKeyStrings) == 0 {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
|
||||
return
|
||||
}
|
||||
// 步骤 2: 确保所有 Key 在主表中存在(创建或恢复),并获取它们完整的实体。
|
||||
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
|
||||
for i, keyStr := range uniqueKeyStrings {
|
||||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||||
}
|
||||
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
return
|
||||
}
|
||||
// 步骤 3: 找出在这些 Key 中,哪些【已经】被链接到了当前分组。
|
||||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
|
||||
return
|
||||
}
|
||||
alreadyLinkedIDSet := make(map[uint]struct{})
|
||||
for _, key := range alreadyLinkedModels {
|
||||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||||
}
|
||||
// 步骤 4: 确定【真正需要】被链接到当前分组的 key 列表 (我们的"工作集")。
|
||||
var keysToLink []models.APIKey
|
||||
for _, key := range allKeyModels {
|
||||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||||
keysToLink = append(keysToLink, key)
|
||||
}
|
||||
}
|
||||
// 步骤 5: 更新任务的 Total 总量为精确的 "工作集" 大小。
|
||||
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
// 步骤 6: 分块处理【链接Key到组】的操作,并实时更新进度。
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
}
|
||||
for i := 0; i < len(idsToLink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToLink) {
|
||||
end = len(idsToLink)
|
||||
}
|
||||
chunk := idsToLink[i:end]
|
||||
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
|
||||
return
|
||||
}
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
// 步骤 7: 准备最终结果并结束任务。
|
||||
result := gin.H{
|
||||
"newly_linked_count": len(keysToLink),
|
||||
"already_linked_count": len(alreadyLinkedIDSet),
|
||||
"total_linked_count": len(allKeyModels),
|
||||
}
|
||||
// 步骤 8: 根据 `validateOnImport` 标志, 发布事件或直接激活 (只对新链接的keys操作)。
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
}
|
||||
if validateOnImport {
|
||||
s.publishImportGroupCompletedEvent(groupID, idsToLink)
|
||||
for _, keyID := range idsToLink {
|
||||
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
}
|
||||
} else {
|
||||
for _, keyID := range idsToLink {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil {
|
||||
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
// runUnlinkKeysTask
|
||||
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeys []string
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeys = append(uniqueKeys, kStr)
|
||||
}
|
||||
}
|
||||
// 步骤 1: 一次性找出所有输入 Key 中,实际存在于本组的 Key 实体。这是我们的"工作集"。
|
||||
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keysToUnlink) == 0 {
|
||||
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
return
|
||||
}
|
||||
idsToUnlink := make([]uint, len(keysToUnlink))
|
||||
for i, key := range keysToUnlink {
|
||||
idsToUnlink[i] = key.ID
|
||||
}
|
||||
// 步骤 2: 更新任务的 Total 总量为精确的 "工作集" 大小。
|
||||
if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
var totalUnlinked int64
|
||||
// 步骤 3: 分块处理【解绑Key】的操作,并上报进度。
|
||||
for i := 0; i < len(idsToUnlink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToUnlink) {
|
||||
end = len(idsToUnlink)
|
||||
}
|
||||
chunk := idsToUnlink[i:end]
|
||||
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
|
||||
return
|
||||
}
|
||||
totalUnlinked += unlinked
|
||||
|
||||
for _, keyID := range chunk {
|
||||
s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked")
|
||||
}
|
||||
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
}
|
||||
|
||||
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
|
||||
}
|
||||
result := gin.H{
|
||||
"unlinked_count": totalUnlinked,
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) {
|
||||
var totalDeleted int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
chunk := keys[i:end]
|
||||
|
||||
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
|
||||
return
|
||||
}
|
||||
totalDeleted += deleted
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": int64(len(keys)) - totalDeleted,
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(0, "keys_hard_deleted") // Global event
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) {
|
||||
var restoredCount int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
chunk := keys[i:end]
|
||||
|
||||
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
|
||||
return
|
||||
}
|
||||
restoredCount += count
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
}
|
||||
result := gin.H{
|
||||
"restored_count": restoredCount,
|
||||
"not_found_count": int64(len(keys)) - restoredCount,
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(0, "keys_bulk_restored") // Global event
|
||||
}
|
||||
|
||||
func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
KeyID: keyID,
|
||||
OldStatus: oldStatus,
|
||||
NewStatus: newStatus,
|
||||
ChangeReason: reason,
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithError(err).WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
"reason": reason,
|
||||
}).Error("Failed to publish single key change event.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
ChangeReason: reason,
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = s.store.Publish(models.TopicKeyStatusChanged, eventData)
|
||||
}
|
||||
|
||||
func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) {
|
||||
if len(keyIDs) == 0 {
|
||||
return
|
||||
}
|
||||
event := models.ImportGroupCompletedEvent{
|
||||
GroupID: groupID,
|
||||
KeyIDs: keyIDs,
|
||||
CompletedAt: time.Now(),
|
||||
}
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
|
||||
return
|
||||
}
|
||||
if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
|
||||
} else {
|
||||
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
|
||||
}
|
||||
}
|
||||
|
||||
// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter.
|
||||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
|
||||
// 1. [New] Find the keys to operate on.
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return nil, fmt.Errorf("no keys found matching the provided filter")
|
||||
}
|
||||
// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
|
||||
return s.StartUnlinkKeysTask(groupID, keysAsText)
|
||||
}
|
||||
217
internal/service/key_validation_service.go
Normal file
217
internal/service/key_validation_service.go
Normal file
@@ -0,0 +1,217 @@
|
||||
// Filename: internal/service/key_validation_service.go
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/channel"
|
||||
CustomErrors "gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/task"
|
||||
"gemini-balancer/internal/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
TaskTypeTestKeys = "test_keys"
|
||||
)
|
||||
|
||||
type KeyValidationService struct {
|
||||
taskService task.Reporter
|
||||
channel channel.ChannelProxy
|
||||
db *gorm.DB
|
||||
SettingsManager *settings.SettingsManager
|
||||
groupManager *GroupManager
|
||||
store store.Store
|
||||
keyRepo repository.KeyRepository
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService {
|
||||
return &KeyValidationService{
|
||||
taskService: ts,
|
||||
channel: ch,
|
||||
db: db,
|
||||
SettingsManager: ss,
|
||||
groupManager: gm,
|
||||
store: st,
|
||||
keyRepo: kr,
|
||||
logger: logger.WithField("component", "KeyValidationService🧐"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
|
||||
if err := s.keyRepo.Decrypt(key); err != nil {
|
||||
return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err)
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// This is a network-level error (e.g., timeout, DNS issue)
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil // Success
|
||||
}
|
||||
|
||||
// Read the body for more error details
|
||||
bodyBytes, readErr := io.ReadAll(resp.Body)
|
||||
var errorMsg string
|
||||
if readErr != nil {
|
||||
errorMsg = "Failed to read error response body"
|
||||
} else {
|
||||
errorMsg = string(bodyBytes)
|
||||
}
|
||||
|
||||
// This is a validation failure with a specific HTTP status code
|
||||
return &CustomErrors.APIError{
|
||||
HTTPStatus: resp.StatusCode,
|
||||
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
|
||||
Code: "VALIDATION_FAILED",
|
||||
}
|
||||
}
|
||||
|
||||
// --- 异步任务方法 (全面适配新task包) ---
|
||||
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
|
||||
keyStrings := utils.ParseKeysFromText(keysText)
|
||||
if len(keyStrings) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
|
||||
}
|
||||
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
if len(apiKeyModels) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
|
||||
}
|
||||
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.")
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
|
||||
}
|
||||
group, ok := s.groupManager.GetGroupByID(groupID)
|
||||
if !ok {
|
||||
// [FIX] Correctly use the NewAPIError constructor for a missing group.
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
|
||||
}
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err // Pass up the error from task service (e.g., "task already running")
|
||||
}
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails
|
||||
return nil, err
|
||||
}
|
||||
var concurrency int
|
||||
if opConfig.KeyCheckConcurrency != nil {
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
} else {
|
||||
concurrency = settings.BaseKeyCheckConcurrency
|
||||
}
|
||||
go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
finalResults := make([]models.KeyTestResult, len(keys))
|
||||
processedCount := 0
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10
|
||||
}
|
||||
type job struct {
|
||||
Index int
|
||||
Value models.APIKey
|
||||
}
|
||||
jobs := make(chan job, len(keys))
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := range jobs {
|
||||
apiKeyModel := j.Value
|
||||
keyToValidate := apiKeyModel
|
||||
validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint)
|
||||
|
||||
var currentResult models.KeyTestResult
|
||||
event := models.RequestFinishedEvent{
|
||||
GroupID: groupID,
|
||||
KeyID: apiKeyModel.ID,
|
||||
}
|
||||
if validationErr == nil {
|
||||
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."}
|
||||
event.IsSuccess = true
|
||||
} else {
|
||||
var apiErr *CustomErrors.APIError
|
||||
if CustomErrors.As(validationErr, &apiErr) {
|
||||
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
|
||||
event.Error = apiErr
|
||||
} else {
|
||||
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
|
||||
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
|
||||
}
|
||||
event.IsSuccess = false
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
finalResults[j.Index] = currentResult
|
||||
processedCount++
|
||||
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
for i, k := range keys {
|
||||
jobs <- job{Index: i, Value: k}
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil)
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
|
||||
}
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
|
||||
return s.StartTestKeysTask(groupID, keysAsText)
|
||||
}
|
||||
65
internal/service/log_service.go
Normal file
65
internal/service/log_service.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/models"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LogService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewLogService(db *gorm.DB) *LogService {
|
||||
return &LogService{db: db}
|
||||
}
|
||||
|
||||
// Record 记录一条日志到数据库 (TODO 暂时保留简单实现,后续再重构为异步)
|
||||
func (s *LogService) Record(log *models.RequestLog) error {
|
||||
return s.db.Create(log).Error
|
||||
}
|
||||
|
||||
func (s *LogService) GetLogs(c *gin.Context) ([]models.RequestLog, error) {
|
||||
var logs []models.RequestLog
|
||||
|
||||
query := s.db.Model(&models.RequestLog{}).Scopes(s.filtersScope(c)).Order("request_time desc")
|
||||
|
||||
// 简单的分页 ( TODO 后续可以做得更复杂)
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// 执行查询
|
||||
err := query.Limit(pageSize).Offset(offset).Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (s *LogService) filtersScope(c *gin.Context) func(db *gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
if modelName := c.Query("model_name"); modelName != "" {
|
||||
db = db.Where("model_name = ?", modelName)
|
||||
}
|
||||
if isSuccessStr := c.Query("is_success"); isSuccessStr != "" {
|
||||
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
|
||||
db = db.Where("is_success = ?", isSuccess)
|
||||
}
|
||||
}
|
||||
if statusCodeStr := c.Query("status_code"); statusCodeStr != "" {
|
||||
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
|
||||
db = db.Where("status_code = ?", statusCode)
|
||||
}
|
||||
}
|
||||
if keyIDStr := c.Query("key_id"); keyIDStr != "" {
|
||||
if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
|
||||
db = db.Where("key_id = ?", keyID)
|
||||
}
|
||||
}
|
||||
return db
|
||||
}
|
||||
}
|
||||
267
internal/service/resource_service.go
Normal file
267
internal/service/resource_service.go
Normal file
@@ -0,0 +1,267 @@
|
||||
// Filename: internal/service/resource_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
apperrors "gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/settings"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoResourceAvailable = errors.New("no available resource found for the request")
|
||||
)
|
||||
|
||||
type RequestResources struct {
|
||||
KeyGroup *models.KeyGroup
|
||||
APIKey *models.APIKey
|
||||
UpstreamEndpoint *models.UpstreamEndpoint
|
||||
ProxyConfig *models.ProxyConfig
|
||||
RequestConfig *models.RequestConfig
|
||||
}
|
||||
|
||||
type ResourceService struct {
|
||||
settingsManager *settings.SettingsManager
|
||||
groupManager *GroupManager
|
||||
keyRepo repository.KeyRepository
|
||||
apiKeyService *APIKeyService
|
||||
logger *logrus.Entry
|
||||
initOnce sync.Once
|
||||
}
|
||||
|
||||
func NewResourceService(
|
||||
sm *settings.SettingsManager,
|
||||
gm *GroupManager,
|
||||
kr repository.KeyRepository,
|
||||
aks *APIKeyService,
|
||||
logger *logrus.Logger,
|
||||
) *ResourceService {
|
||||
logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr)
|
||||
rs := &ResourceService{
|
||||
settingsManager: sm,
|
||||
groupManager: gm,
|
||||
keyRepo: kr,
|
||||
apiKeyService: aks,
|
||||
logger: logger.WithField("component", "ResourceService📦️"),
|
||||
}
|
||||
|
||||
rs.initOnce.Do(func() {
|
||||
go rs.preWarmCache(logger)
|
||||
})
|
||||
return rs
|
||||
|
||||
}
|
||||
|
||||
// --- [模式一:智能聚合模式] ---
|
||||
func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) {
|
||||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
|
||||
log.Debug("Entering BasePool resource acquisition.")
|
||||
// 1.筛选出所有符合条件的候选组,并按优先级排序
|
||||
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
|
||||
if len(candidateGroups) == 0 {
|
||||
log.Warn("No candidate groups found for BasePool construction.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
}
|
||||
// 2.从 BasePool中,根据系统全局策略选择一个Key
|
||||
basePool := &repository.BasePool{
|
||||
CandidateGroups: candidateGroups,
|
||||
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
|
||||
}
|
||||
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to select a key from the BasePool.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
}
|
||||
// 3. 组装最终资源
|
||||
// [关键] 在此模式下,RequestConfig 永远是空的,以保证透明性。
|
||||
resources, err := s.assembleRequestResources(selectedGroup, apiKey)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
|
||||
return nil, err
|
||||
}
|
||||
resources.RequestConfig = &models.RequestConfig{} // 强制为空
|
||||
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
// --- [模式二:精确路由模式] ---
|
||||
func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) {
|
||||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
|
||||
log.Debug("Entering PreciseRoute resource acquisition.")
|
||||
|
||||
targetGroup, ok := s.groupManager.GetGroupByName(groupName)
|
||||
|
||||
if !ok {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
|
||||
}
|
||||
|
||||
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
|
||||
}
|
||||
|
||||
apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
}
|
||||
|
||||
resources, err := s.assembleRequestResources(targetGroup, apiKey)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to assemble resources for precise route.")
|
||||
return nil, err
|
||||
}
|
||||
resources.RequestConfig = targetGroup.RequestConfig
|
||||
|
||||
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
|
||||
allGroups := s.groupManager.GetAllGroups()
|
||||
if len(allGroups) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
allowedModelsSet := make(map[string]struct{})
|
||||
if authToken.IsAdmin {
|
||||
for _, group := range allGroups {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
for _, ag := range authToken.AllowedGroups {
|
||||
allowedGroupIDs[ag.ID] = true
|
||||
}
|
||||
for _, group := range allGroups {
|
||||
if _, ok := allowedGroupIDs[group.ID]; ok {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result := make([]string, 0, len(allowedModelsSet))
|
||||
for modelName := range allowedModelsSet {
|
||||
result = append(result, modelName)
|
||||
}
|
||||
sort.Strings(result)
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
|
||||
selectedUpstream := s.selectUpstreamForGroup(group)
|
||||
if selectedUpstream == nil {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
|
||||
}
|
||||
var proxyConfig *models.ProxyConfig
|
||||
// [注意] 代理逻辑需要一个 proxyModule 实例,我们暂时置空。后续需要重新注入依赖。
|
||||
// if group.EnableProxy && s.proxyModule != nil {
|
||||
// var err error
|
||||
// proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey)
|
||||
// if err != nil {
|
||||
// s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID)
|
||||
// }
|
||||
// }
|
||||
return &RequestResources{
|
||||
KeyGroup: group,
|
||||
APIKey: apiKey,
|
||||
UpstreamEndpoint: selectedUpstream,
|
||||
ProxyConfig: proxyConfig,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
|
||||
if len(group.AllowedUpstreams) > 0 {
|
||||
return group.AllowedUpstreams[0]
|
||||
}
|
||||
globalSettings := s.settingsManager.GetSettings()
|
||||
if globalSettings.DefaultUpstreamURL != "" {
|
||||
return &models.UpstreamEndpoint{URL: globalSettings.DefaultUpstreamURL, Status: "active"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
|
||||
time.Sleep(2 * time.Second)
|
||||
s.logger.Info("Performing initial key cache pre-warming...")
|
||||
if err := s.keyRepo.LoadAllKeysToStore(); err != nil {
|
||||
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
|
||||
return err
|
||||
}
|
||||
s.logger.Info("Initial key cache pre-warming completed successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
|
||||
return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
|
||||
}
|
||||
|
||||
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
|
||||
allGroupsFromCache := s.groupManager.GetAllGroups()
|
||||
var candidateGroups []*models.KeyGroup
|
||||
// 1. 确定权限范围
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
isTokenRestricted := len(allowedGroupsFromToken) > 0
|
||||
if isTokenRestricted {
|
||||
for _, ag := range allowedGroupsFromToken {
|
||||
allowedGroupIDs[ag.ID] = true
|
||||
}
|
||||
}
|
||||
// 2. 筛选
|
||||
for _, group := range allGroupsFromCache {
|
||||
// 检查Token权限
|
||||
if isTokenRestricted && !allowedGroupIDs[group.ID] {
|
||||
continue
|
||||
}
|
||||
// 检查模型是否被允许
|
||||
isModelAllowed := false
|
||||
if len(group.AllowedModels) == 0 { // 如果组不限制模型,则允许
|
||||
isModelAllowed = true
|
||||
} else {
|
||||
for _, m := range group.AllowedModels {
|
||||
if m.ModelName == modelName {
|
||||
isModelAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if isModelAllowed {
|
||||
candidateGroups = append(candidateGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
// 3.按 Order 字段升序排序
|
||||
sort.SliceStable(candidateGroups, func(i, j int) bool {
|
||||
return candidateGroups[i].Order < candidateGroups[j].Order
|
||||
})
|
||||
return candidateGroups
|
||||
}
|
||||
|
||||
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
|
||||
if authToken.IsAdmin {
|
||||
return true
|
||||
}
|
||||
for _, allowedGroup := range authToken.AllowedGroups {
|
||||
if allowedGroup.ID == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
|
||||
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
|
||||
return
|
||||
}
|
||||
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
|
||||
}
|
||||
83
internal/service/security_service.go
Normal file
83
internal/service/security_service.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// Filename: internal/service/security_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256" // [NEW] Import crypto library for hashing
|
||||
"encoding/hex" // [NEW] Import hex encoding
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository" // [NEW] Import repository
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const loginAttemptsKey = "security:login_attempts"
|
||||
|
||||
type SecurityService struct {
|
||||
repo repository.AuthTokenRepository
|
||||
store store.Store
|
||||
SettingsManager *settings.SettingsManager
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewSecurityService signature updated to accept the repository.
|
||||
func NewSecurityService(repo repository.AuthTokenRepository, store store.Store, settingsManager *settings.SettingsManager, logger *logrus.Logger) *SecurityService {
|
||||
return &SecurityService{
|
||||
repo: repo,
|
||||
store: store,
|
||||
SettingsManager: settingsManager,
|
||||
logger: logger.WithField("component", "SecurityService🛡️"),
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticateToken is now secure and efficient.
|
||||
func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToken, error) {
|
||||
if tokenValue == "" {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
// [REFACTORED]
|
||||
// 1. Hash the incoming plaintext token.
|
||||
hash := sha256.Sum256([]byte(tokenValue))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
// 2. Delegate the lookup to the repository using the hash.
|
||||
return s.repo.GetTokenByHashedValue(tokenHash)
|
||||
}
|
||||
|
||||
// IsIPBanned
|
||||
func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) {
|
||||
banKey := fmt.Sprintf("banned_ip:%s", ip)
|
||||
return s.store.Exists(banKey)
|
||||
}
|
||||
|
||||
// RecordFailedLoginAttempt
|
||||
func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip string) error {
|
||||
if !s.SettingsManager.IsIPBanEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
maxAttempts := s.SettingsManager.GetMaxLoginAttempts()
|
||||
if count >= int64(maxAttempts) {
|
||||
banDuration := s.SettingsManager.GetIPBanDuration()
|
||||
banKey := fmt.Sprintf("banned_ip:%s", ip)
|
||||
|
||||
if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration)
|
||||
|
||||
s.store.HDel(loginAttemptsKey, ip)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
196
internal/service/stats_service.go
Normal file
196
internal/service/stats_service.go
Normal file
@@ -0,0 +1,196 @@
|
||||
// Filename: internal/service/stats_service.go
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/store"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type StatsService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
keyRepo repository.KeyRepository
|
||||
logger *logrus.Entry
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository, logger *logrus.Logger) *StatsService {
|
||||
return &StatsService{
|
||||
db: db,
|
||||
store: s,
|
||||
keyRepo: repo,
|
||||
logger: logger.WithField("component", "StatsService"),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatsService) Start() {
|
||||
s.logger.Info("Starting event listener for stats maintenance.")
|
||||
sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer sub.Close()
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.KeyStatusChangedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleKeyStatusChange(&event)
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping stats event listener.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *StatsService) Stop() {
|
||||
close(s.stopChan)
|
||||
}
|
||||
|
||||
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
|
||||
if event.GroupID == 0 {
|
||||
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
|
||||
return
|
||||
}
|
||||
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
|
||||
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
|
||||
|
||||
switch event.ChangeReason {
|
||||
case "key_unlinked", "key_hard_deleted":
|
||||
if event.OldStatus != "" {
|
||||
s.store.HIncrBy(statsKey, "total_keys", -1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
} else {
|
||||
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
}
|
||||
case "key_linked":
|
||||
if event.NewStatus != "" {
|
||||
s.store.HIncrBy(statsKey, "total_keys", 1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
} else {
|
||||
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
}
|
||||
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
default:
|
||||
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
|
||||
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
|
||||
var results []struct {
|
||||
Status models.APIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
if err := s.db.Model(&models.GroupAPIKeyMapping{}).
|
||||
Where("key_group_id = ?", groupID).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Scan(&results).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
statsKey := fmt.Sprintf("stats:group:%d", groupID)
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
totalKeys := int64(0)
|
||||
for _, res := range results {
|
||||
updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
|
||||
totalKeys += res.Count
|
||||
}
|
||||
updates["total_keys"] = totalKeys
|
||||
|
||||
if err := s.store.Del(statsKey); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
|
||||
}
|
||||
if err := s.store.HSet(statsKey, updates); err != nil {
|
||||
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
|
||||
}
|
||||
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
|
||||
// TODO 逻辑:
|
||||
// 1. 从Redis中获取所有分组的Key统计 (HGetAll)
|
||||
// 2. 从 stats_hourly 表中获取过去24小时的请求数和错误率
|
||||
// 3. 组合成 DashboardStatsResponse
|
||||
// ... 这个方法的具体实现,我们可以在DashboardQueryService中完成,
|
||||
// 这里我们先确保StatsService的核心职责(维护缓存)已经完成。
|
||||
// 为了编译通过,我们先返回一个空对象。
|
||||
|
||||
// 伪代码:
|
||||
// keyCounts, _ := s.store.HGetAll("stats:global:keys")
|
||||
// ...
|
||||
|
||||
return &models.DashboardStatsResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *StatsService) AggregateHourlyStats() error {
|
||||
s.logger.Info("Starting aggregation of the last hour's request data...")
|
||||
now := time.Now()
|
||||
endTime := now.Truncate(time.Hour) // 例如:15:23 -> 15:00
|
||||
startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00
|
||||
|
||||
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
|
||||
type aggregationResult struct {
|
||||
GroupID uint
|
||||
ModelName string
|
||||
RequestCount int64
|
||||
SuccessCount int64
|
||||
PromptTokens int64
|
||||
CompletionTokens int64
|
||||
}
|
||||
var results []aggregationResult
|
||||
err := s.db.Model(&models.RequestLog{}).
|
||||
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
|
||||
Where("request_time >= ? AND request_time < ?", startTime, endTime).
|
||||
Group("group_id, model_name").
|
||||
Scan(&results).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query aggregation data from request_logs: %w", err)
|
||||
}
|
||||
if len(results) == 0 {
|
||||
s.logger.Info("No request logs found in the last hour to aggregate. Skipping.")
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Infof("Found %d aggregated data rows to insert/update.", len(results))
|
||||
|
||||
var hourlyStats []models.StatsHourly
|
||||
for _, res := range results {
|
||||
hourlyStats = append(hourlyStats, models.StatsHourly{
|
||||
Time: startTime, // 所有记录的时间戳都是该小时的起点
|
||||
GroupID: res.GroupID,
|
||||
ModelName: res.ModelName,
|
||||
RequestCount: res.RequestCount,
|
||||
SuccessCount: res.SuccessCount,
|
||||
PromptTokens: res.PromptTokens,
|
||||
CompletionTokens: res.CompletionTokens,
|
||||
})
|
||||
}
|
||||
|
||||
return s.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
|
||||
}).Create(&hourlyStats).Error
|
||||
}
|
||||
72
internal/service/token_manager.go
Normal file
72
internal/service/token_manager.go
Normal file
@@ -0,0 +1,72 @@
|
||||
// Filename: internal/service/token_manager.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const TopicTokenChanged = "events:token_changed"
|
||||
|
||||
type TokenManager struct {
|
||||
repo repository.AuthTokenRepository
|
||||
syncer *syncer.CacheSyncer[[]*models.AuthToken]
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewTokenManager's signature is updated to accept the new repository.
|
||||
func NewTokenManager(repo repository.AuthTokenRepository, store store.Store, logger *logrus.Logger) (*TokenManager, error) {
|
||||
tm := &TokenManager{
|
||||
repo: repo,
|
||||
logger: logger.WithField("component", "TokenManager🔐"),
|
||||
}
|
||||
|
||||
tokenLoader := func() ([]*models.AuthToken, error) {
|
||||
tm.logger.Info("Loading all auth tokens via repository...")
|
||||
tokens, err := tm.repo.GetAllTokensWithGroups()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load auth tokens from repo: %w", err)
|
||||
}
|
||||
tm.logger.Infof("Successfully loaded and decrypted %d auth tokens into cache.", len(tokens))
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token manager syncer: %w", err)
|
||||
}
|
||||
tm.syncer = s
|
||||
|
||||
return tm, nil
|
||||
}
|
||||
|
||||
func (tm *TokenManager) GetAllTokens() []*models.AuthToken {
|
||||
return tm.syncer.Get()
|
||||
}
|
||||
|
||||
// BatchUpdateTokens is now a thin wrapper around the repository method.
|
||||
func (tm *TokenManager) BatchUpdateTokens(incomingTokens []*models.TokenUpdateRequest) error {
|
||||
tm.logger.Info("Delegating BatchUpdateTokens to repository...")
|
||||
|
||||
if err := tm.repo.BatchUpdateTokens(incomingTokens); err != nil {
|
||||
tm.logger.Errorf("Repository failed to batch update tokens: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
tm.logger.Info("BatchUpdateTokens finished successfully. Invalidating cache.")
|
||||
return tm.Invalidate()
|
||||
}
|
||||
|
||||
func (tm *TokenManager) Invalidate() error {
|
||||
return tm.syncer.Invalidate()
|
||||
}
|
||||
|
||||
func (tm *TokenManager) Stop() {
|
||||
tm.syncer.Stop()
|
||||
}
|
||||
Reference in New Issue
Block a user