Files
gemini-banlancer/internal/service/resource_service.go
2025-11-23 22:42:58 +08:00

275 lines
9.6 KiB
Go
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Filename: internal/service/resource_service.go
package service
import (
"context"
"gemini-balancer/internal/domain/proxy"
apperrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/settings"
"sort"
"sync"
"time"
"github.com/sirupsen/logrus"
)
// RequestResources bundles every resource selected for serving a single
// upstream request.
type RequestResources struct {
	// KeyGroup is the key group the API key was selected from.
	KeyGroup *models.KeyGroup
	// APIKey is the key chosen to authenticate the upstream call.
	APIKey *models.APIKey
	// UpstreamEndpoint is the upstream to send the request to: either one of
	// the group's allowed upstreams or a synthetic endpoint built from the
	// global default upstream URL.
	UpstreamEndpoint *models.UpstreamEndpoint
	// ProxyConfig is the proxy to route through. ResourceService only
	// requests one when the key group has EnableProxy set; otherwise it
	// stays nil.
	ProxyConfig *models.ProxyConfig
	// RequestConfig holds per-request settings. ResourceService sets it after
	// assembly: the group's own config in precise-routing mode, or an empty
	// default in BasePool mode.
	RequestConfig *models.RequestConfig
}
// ResourceService dynamically selects and allocates API keys and their related
// resources (key group, upstream endpoint, proxy, request config) for incoming
// requests, based on the request parameters and the configured business rules.
type ResourceService struct {
	// settingsManager supplies global settings such as the polling strategy
	// and the default upstream URL.
	settingsManager *settings.SettingsManager
	// groupManager provides key-group lookups (by name and as a full list).
	groupManager *GroupManager
	// keyRepo selects active keys and loads the key store during pre-warm.
	keyRepo repository.KeyRepository
	// authTokenRepo is stored on construction.
	// NOTE(review): not referenced by the methods visible in this file.
	authTokenRepo repository.AuthTokenRepository
	// apiKeyService receives request outcomes via ReportRequestResult.
	apiKeyService *APIKeyService
	// proxyManager assigns proxies to keys of groups with EnableProxy set.
	proxyManager *proxy.Module
	// logger is the component-scoped logger created in NewResourceService.
	logger *logrus.Entry
	// initOnce is used by NewResourceService to launch the background cache
	// pre-warm goroutine.
	initOnce sync.Once
}
// NewResourceService wires a ResourceService up from its collaborators and
// starts a background cache pre-warm.
//
// The pre-warm goroutine is launched through the instance's initOnce. The
// instance is freshly constructed here, so the Once has never fired and the
// goroutine is always started exactly once per constructed service. The
// goroutine is neither tracked nor cancellable from this function.
func NewResourceService(
	sm *settings.SettingsManager,
	gm *GroupManager,
	kr repository.KeyRepository,
	atr repository.AuthTokenRepository,
	aks *APIKeyService,
	pm *proxy.Module,
	logger *logrus.Logger,
) *ResourceService {
	componentLogger := logger.WithField("component", "ResourceService📦")

	svc := &ResourceService{
		logger:          componentLogger,
		proxyManager:    pm,
		apiKeyService:   aks,
		authTokenRepo:   atr,
		keyRepo:         kr,
		groupManager:    gm,
		settingsManager: sm,
	}

	startPreWarm := func() {
		go svc.preWarmCache()
	}
	svc.initOnce.Do(startPreWarm)

	return svc
}
// GetResourceFromBasePool acquires resources in the aggregated "BasePool" mode.
//
// Every key group that authToken may use and that supports modelName becomes a
// candidate. The key repository then picks one active key across those
// candidates according to the globally configured polling strategy.
//
// If there are no candidates, or key selection fails, the function returns
// apperrors.ErrNoKeysAvailable; the underlying selection error is only logged.
// Errors from assembling the remaining resources are returned unchanged. On
// success the resources carry an empty (default) RequestConfig.
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
	entry := s.logger.WithFields(logrus.Fields{
		"token_id":   authToken.ID,
		"model_name": modelName,
		"mode":       "BasePool",
	})
	entry.Debug("Entering BasePool resource acquisition.")

	candidates := s.filterAndSortCandidateGroups(modelName, authToken)
	if len(candidates) == 0 {
		entry.Warn("No candidate groups found for BasePool construction.")
		return nil, apperrors.ErrNoKeysAvailable
	}

	pool := &repository.BasePool{
		CandidateGroups: candidates,
		PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
	}
	key, group, selectErr := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, pool)
	if selectErr != nil {
		entry.WithError(selectErr).Warn("Failed to select a key from the BasePool.")
		return nil, apperrors.ErrNoKeysAvailable
	}

	res, assembleErr := s.assembleRequestResources(group, key)
	if assembleErr != nil {
		entry.WithError(assembleErr).Error("Failed to assemble resources after selecting key from BasePool.")
		return nil, assembleErr
	}

	// BasePool mode has no single owning group config, so it uses the default
	// (empty) request configuration.
	res.RequestConfig = &models.RequestConfig{}

	entry.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", key.ID, group.ID)
	return res, nil
}
// GetResourceFromGroup acquires resources in precise-routing mode: the key is
// taken from the single key group named groupName rather than from an
// aggregated pool.
//
// Errors:
//   - the APIError built from apperrors.ErrGroupNotFound if no group has that
//     name;
//   - the APIError built from apperrors.ErrPermissionDenied if authToken may
//     not use the group;
//   - apperrors.ErrNoKeysAvailable if no active key could be selected (the
//     underlying repository error is only logged);
//   - any error from assembling the remaining resources, returned unchanged.
//
// The returned RequestConfig is the group's own configuration. When the group
// has none (nil), the same empty default used by GetResourceFromBasePool is
// substituted. Callers therefore always receive a non-nil RequestConfig,
// whichever routing mode produced the resources.
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
	log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
	log.Debug("Entering PreciseRoute resource acquisition.")

	targetGroup, ok := s.groupManager.GetGroupByName(groupName)
	if !ok {
		return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
	}
	if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
		return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
	}

	apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
	if err != nil {
		log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
		return nil, apperrors.ErrNoKeysAvailable
	}

	resources, err := s.assembleRequestResources(targetGroup, apiKey)
	if err != nil {
		log.WithError(err).Error("Failed to assemble resources for precise route.")
		return nil, err
	}

	// Precise routing applies the target group's own request config. A group
	// without one falls back to the empty default that BasePool mode uses, so
	// downstream code never has to special-case a nil config for this mode.
	resources.RequestConfig = targetGroup.RequestConfig
	if resources.RequestConfig == nil {
		resources.RequestConfig = &models.RequestConfig{}
	}

	log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
	return resources, nil
}
// GetAllowedModelsForToken returns, sorted ascending and without duplicates,
// every model name explicitly listed by a key group that authToken can access.
//
// Admin tokens can access every group. Other tokens can access only the groups
// in their AllowedGroups list. Groups with an empty AllowedModels list
// contribute no names. The result is always non-nil; it is an empty slice when
// nothing matches.
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
	groups := s.groupManager.GetAllGroups()
	if len(groups) == 0 {
		return []string{}
	}

	// Only non-admin tokens need an explicit group whitelist; reading the nil
	// map for admins is never reached because of the IsAdmin short-circuit.
	var permitted map[uint]bool
	if !authToken.IsAdmin {
		permitted = make(map[uint]bool, len(authToken.AllowedGroups))
		for _, granted := range authToken.AllowedGroups {
			permitted[granted.ID] = true
		}
	}

	unique := make(map[string]struct{})
	for _, group := range groups {
		if !authToken.IsAdmin && !permitted[group.ID] {
			continue
		}
		for _, mapping := range group.AllowedModels {
			unique[mapping.ModelName] = struct{}{}
		}
	}

	names := make([]string, 0, len(unique))
	for name := range unique {
		names = append(names, name)
	}
	// Map iteration order is random; sort for a deterministic result.
	sort.Strings(names)
	return names
}
// ReportRequestResult forwards the final outcome of a request to the
// APIKeyService, so it can update the state of the key that served it.
//
// It does nothing when resources, its KeyGroup, or its APIKey is nil.
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
	complete := resources != nil && resources.KeyGroup != nil && resources.APIKey != nil
	if complete {
		s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
	}
}
// --- 私有辅助方法 ---
// preWarmCache performs the one-off background cache warm-up.
//
// It first waits a fixed 2 seconds, intended to let other service components
// finish initialising. It then forces the GroupManager cache to load by calling
// GetAllGroups. Finally it asks the key repository to load all keys into its
// store, bounded by a 60-second timeout. Failures are logged, not returned.
func (s *ResourceService) preWarmCache() {
	const (
		startupDelay = 2 * time.Second
		loadTimeout  = 60 * time.Second
	)

	time.Sleep(startupDelay)
	s.logger.Info("Performing initial key cache pre-warming...")

	s.logger.Info("Pre-warming GroupManager cache...")
	// The groups themselves are not needed; the call only forces the cache load.
	_ = s.groupManager.GetAllGroups()

	ctx, cancel := context.WithTimeout(context.Background(), loadTimeout)
	defer cancel()

	err := s.keyRepo.LoadAllKeysToStore(ctx)
	if err != nil {
		s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
		return
	}
	s.logger.Info("Initial key cache pre-warming completed successfully.")
}
// assembleRequestResources builds the RequestResources for an already chosen
// key group and API key.
//
// The upstream comes from selectUpstreamForGroup. If none is available, the
// function returns the APIError built from apperrors.ErrConfigurationError.
// When the group has EnableProxy set, a proxy is requested from the proxy
// module for the key. A failure there is logged and returned as the APIError
// built from apperrors.ErrProxyNotAvailable, because the proxy is mandatory for
// such groups. RequestConfig is left unset; callers fill it in.
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
	upstream := s.selectUpstreamForGroup(group)
	if upstream == nil {
		return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
	}

	res := &RequestResources{
		KeyGroup:         group,
		APIKey:           apiKey,
		UpstreamEndpoint: upstream,
	}

	// Proxies are only assigned to groups that explicitly opt in.
	if !group.EnableProxy {
		return res, nil
	}

	assigned, err := s.proxyManager.AssignProxyIfNeeded(apiKey)
	if err != nil {
		s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID)
		// The proxy is a hard requirement for this group, so the request must fail.
		return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.")
	}
	res.ProxyConfig = assigned
	return res, nil
}
// selectUpstreamForGroup picks the upstream endpoint for group.
//
// It returns the first entry of group.AllowedUpstreams when the group has any.
// Otherwise it falls back to a synthetic "active" endpoint built from the
// global DefaultUpstreamURL setting. It returns nil when neither is available.
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
	// Always the first allowed upstream for now; load balancing across the
	// group's upstreams could be added here in the future.
	if upstreams := group.AllowedUpstreams; len(upstreams) != 0 {
		return upstreams[0]
	}

	defaultURL := s.settingsManager.GetSettings().DefaultUpstreamURL
	if defaultURL == "" {
		return nil
	}
	return &models.UpstreamEndpoint{URL: defaultURL, Status: "active"}
}
// filterAndSortCandidateGroups returns the key groups that authToken may use
// for modelName, ordered by ascending Order. The sort is stable, so groups with
// equal Order keep the order in which GetAllGroups returned them.
//
// A group qualifies when the token may access it and it either restricts no
// models (an empty AllowedModels list means every model is allowed) or lists
// modelName explicitly. The result is a newly built slice, and nil when nothing
// qualifies.
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup {
	var eligible []*models.KeyGroup
	for _, g := range s.groupManager.GetAllGroups() {
		if !s.isTokenAllowedForGroup(authToken, g.ID) {
			continue
		}
		unrestricted := len(g.AllowedModels) == 0
		if !unrestricted && !s.groupSupportsModel(g, modelName) {
			continue
		}
		eligible = append(eligible, g)
	}

	sort.SliceStable(eligible, func(a, b int) bool {
		return eligible[a].Order < eligible[b].Order
	})
	return eligible
}
// groupSupportsModel reports whether modelName appears in group.AllowedModels.
//
// It returns false for a group with no AllowedModels entries. Callers that
// treat an empty list as "all models allowed" must check for that themselves.
func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool {
	allowed := group.AllowedModels
	for i := range allowed {
		if allowed[i].ModelName == modelName {
			return true
		}
	}
	return false
}
// isTokenAllowedForGroup reports whether authToken may use the key group with
// ID groupID. Admin tokens may use every group; other tokens may use only the
// groups in their AllowedGroups list.
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
	granted := authToken.IsAdmin
	for i := 0; !granted && i < len(authToken.AllowedGroups); i++ {
		granted = authToken.AllowedGroups[i].ID == groupID
	}
	return granted
}