248 lines
7.8 KiB
Go
248 lines
7.8 KiB
Go
// Filename: internal/service/resource_service.go
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
apperrors "gemini-balancer/internal/errors"
|
||
"gemini-balancer/internal/models"
|
||
"gemini-balancer/internal/repository"
|
||
"gemini-balancer/internal/settings"
|
||
"sort"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// ErrNoResourceAvailable reports that no resource could be found to serve a
// request.
//
// NOTE(review): nothing in this file returns it; the lookup methods here
// return apperrors.ErrNoKeysAvailable instead. Check for external users
// before relying on or removing it.
var ErrNoResourceAvailable = errors.New("no available resource found for the request")
|
||
|
||
type RequestResources struct {
|
||
KeyGroup *models.KeyGroup
|
||
APIKey *models.APIKey
|
||
UpstreamEndpoint *models.UpstreamEndpoint
|
||
ProxyConfig *models.ProxyConfig
|
||
RequestConfig *models.RequestConfig
|
||
}
|
||
|
||
type ResourceService struct {
|
||
settingsManager *settings.SettingsManager
|
||
groupManager *GroupManager
|
||
keyRepo repository.KeyRepository
|
||
apiKeyService *APIKeyService
|
||
logger *logrus.Entry
|
||
initOnce sync.Once
|
||
}
|
||
|
||
func NewResourceService(
|
||
sm *settings.SettingsManager,
|
||
gm *GroupManager,
|
||
kr repository.KeyRepository,
|
||
aks *APIKeyService,
|
||
logger *logrus.Logger,
|
||
) *ResourceService {
|
||
rs := &ResourceService{
|
||
settingsManager: sm,
|
||
groupManager: gm,
|
||
keyRepo: kr,
|
||
apiKeyService: aks,
|
||
logger: logger.WithField("component", "ResourceService📦️"),
|
||
}
|
||
|
||
rs.initOnce.Do(func() {
|
||
go rs.preWarmCache(logger)
|
||
})
|
||
return rs
|
||
}
|
||
|
||
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
|
||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
|
||
log.Debug("Entering BasePool resource acquisition.")
|
||
|
||
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
|
||
if len(candidateGroups) == 0 {
|
||
log.Warn("No candidate groups found for BasePool construction.")
|
||
return nil, apperrors.ErrNoKeysAvailable
|
||
}
|
||
|
||
basePool := &repository.BasePool{
|
||
CandidateGroups: candidateGroups,
|
||
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
|
||
}
|
||
|
||
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool)
|
||
if err != nil {
|
||
log.WithError(err).Warn("Failed to select a key from the BasePool.")
|
||
return nil, apperrors.ErrNoKeysAvailable
|
||
}
|
||
|
||
resources, err := s.assembleRequestResources(selectedGroup, apiKey)
|
||
if err != nil {
|
||
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
|
||
return nil, err
|
||
}
|
||
resources.RequestConfig = &models.RequestConfig{}
|
||
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
|
||
return resources, nil
|
||
}
|
||
|
||
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
|
||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
|
||
log.Debug("Entering PreciseRoute resource acquisition.")
|
||
|
||
targetGroup, ok := s.groupManager.GetGroupByName(groupName)
|
||
|
||
if !ok {
|
||
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
|
||
}
|
||
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
|
||
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
|
||
}
|
||
|
||
apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
|
||
if err != nil {
|
||
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
|
||
return nil, apperrors.ErrNoKeysAvailable
|
||
}
|
||
|
||
resources, err := s.assembleRequestResources(targetGroup, apiKey)
|
||
if err != nil {
|
||
log.WithError(err).Error("Failed to assemble resources for precise route.")
|
||
return nil, err
|
||
}
|
||
resources.RequestConfig = targetGroup.RequestConfig
|
||
|
||
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
|
||
return resources, nil
|
||
}
|
||
|
||
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
|
||
allGroups := s.groupManager.GetAllGroups()
|
||
if len(allGroups) == 0 {
|
||
return []string{}
|
||
}
|
||
allowedModelsSet := make(map[string]struct{})
|
||
if authToken.IsAdmin {
|
||
for _, group := range allGroups {
|
||
for _, modelMapping := range group.AllowedModels {
|
||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||
}
|
||
}
|
||
} else {
|
||
allowedGroupIDs := make(map[uint]bool)
|
||
for _, ag := range authToken.AllowedGroups {
|
||
allowedGroupIDs[ag.ID] = true
|
||
}
|
||
for _, group := range allGroups {
|
||
if _, ok := allowedGroupIDs[group.ID]; ok {
|
||
for _, modelMapping := range group.AllowedModels {
|
||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
result := make([]string, 0, len(allowedModelsSet))
|
||
for modelName := range allowedModelsSet {
|
||
result = append(result, modelName)
|
||
}
|
||
sort.Strings(result)
|
||
return result
|
||
}
|
||
|
||
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
|
||
selectedUpstream := s.selectUpstreamForGroup(group)
|
||
if selectedUpstream == nil {
|
||
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
|
||
}
|
||
var proxyConfig *models.ProxyConfig
|
||
return &RequestResources{
|
||
KeyGroup: group,
|
||
APIKey: apiKey,
|
||
UpstreamEndpoint: selectedUpstream,
|
||
ProxyConfig: proxyConfig,
|
||
}, nil
|
||
}
|
||
|
||
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
|
||
if len(group.AllowedUpstreams) > 0 {
|
||
return group.AllowedUpstreams[0]
|
||
}
|
||
globalSettings := s.settingsManager.GetSettings()
|
||
if globalSettings.DefaultUpstreamURL != "" {
|
||
return &models.UpstreamEndpoint{URL: globalSettings.DefaultUpstreamURL, Status: "active"}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
|
||
time.Sleep(2 * time.Second)
|
||
s.logger.Info("Performing initial key cache pre-warming...")
|
||
if err := s.keyRepo.LoadAllKeysToStore(context.Background()); err != nil {
|
||
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
|
||
return err
|
||
}
|
||
s.logger.Info("Initial key cache pre-warming completed successfully.")
|
||
return nil
|
||
}
|
||
|
||
func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
|
||
return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
|
||
}
|
||
|
||
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
|
||
allGroupsFromCache := s.groupManager.GetAllGroups()
|
||
var candidateGroups []*models.KeyGroup
|
||
allowedGroupIDs := make(map[uint]bool)
|
||
isTokenRestricted := len(allowedGroupsFromToken) > 0
|
||
if isTokenRestricted {
|
||
for _, ag := range allowedGroupsFromToken {
|
||
allowedGroupIDs[ag.ID] = true
|
||
}
|
||
}
|
||
for _, group := range allGroupsFromCache {
|
||
if isTokenRestricted && !allowedGroupIDs[group.ID] {
|
||
continue
|
||
}
|
||
isModelAllowed := false
|
||
if len(group.AllowedModels) == 0 {
|
||
isModelAllowed = true
|
||
} else {
|
||
for _, m := range group.AllowedModels {
|
||
if m.ModelName == modelName {
|
||
isModelAllowed = true
|
||
break
|
||
}
|
||
}
|
||
}
|
||
if isModelAllowed {
|
||
candidateGroups = append(candidateGroups, group)
|
||
}
|
||
}
|
||
sort.SliceStable(candidateGroups, func(i, j int) bool {
|
||
return candidateGroups[i].Order < candidateGroups[j].Order
|
||
})
|
||
return candidateGroups
|
||
}
|
||
|
||
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
|
||
if authToken.IsAdmin {
|
||
return true
|
||
}
|
||
for _, allowedGroup := range authToken.AllowedGroups {
|
||
if allowedGroup.ID == groupID {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
|
||
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
|
||
return
|
||
}
|
||
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
|
||
}
|