275 lines
9.6 KiB
Go
275 lines
9.6 KiB
Go
// Filename: internal/service/resource_service.go
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"gemini-balancer/internal/domain/proxy"
|
||
apperrors "gemini-balancer/internal/errors"
|
||
"gemini-balancer/internal/models"
|
||
"gemini-balancer/internal/repository"
|
||
"gemini-balancer/internal/settings"
|
||
"sort"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// RequestResources 封装了一次成功请求所需的所有资源。
|
||
type RequestResources struct {
|
||
KeyGroup *models.KeyGroup
|
||
APIKey *models.APIKey
|
||
UpstreamEndpoint *models.UpstreamEndpoint
|
||
ProxyConfig *models.ProxyConfig
|
||
RequestConfig *models.RequestConfig
|
||
}
|
||
|
||
// ResourceService 负责根据请求参数和业务规则,动态地选择和分配API密钥及相关资源。
|
||
type ResourceService struct {
|
||
settingsManager *settings.SettingsManager
|
||
groupManager *GroupManager
|
||
keyRepo repository.KeyRepository
|
||
authTokenRepo repository.AuthTokenRepository
|
||
apiKeyService *APIKeyService
|
||
proxyManager *proxy.Module
|
||
logger *logrus.Entry
|
||
initOnce sync.Once
|
||
}
|
||
|
||
// NewResourceService 创建并初始化一个新的 ResourceService 实例。
|
||
func NewResourceService(
|
||
sm *settings.SettingsManager,
|
||
gm *GroupManager,
|
||
kr repository.KeyRepository,
|
||
atr repository.AuthTokenRepository,
|
||
aks *APIKeyService,
|
||
pm *proxy.Module,
|
||
logger *logrus.Logger,
|
||
) *ResourceService {
|
||
rs := &ResourceService{
|
||
settingsManager: sm,
|
||
groupManager: gm,
|
||
keyRepo: kr,
|
||
authTokenRepo: atr,
|
||
apiKeyService: aks,
|
||
proxyManager: pm,
|
||
logger: logger.WithField("component", "ResourceService📦️"),
|
||
}
|
||
|
||
// 使用 sync.Once 确保预热任务在服务生命周期内仅执行一次
|
||
rs.initOnce.Do(func() {
|
||
go rs.preWarmCache()
|
||
})
|
||
return rs
|
||
}
|
||
|
||
// GetResourceFromBasePool 使用智能聚合池模式获取资源。
|
||
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
|
||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
|
||
log.Debug("Entering BasePool resource acquisition.")
|
||
|
||
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken)
|
||
if len(candidateGroups) == 0 {
|
||
log.Warn("No candidate groups found for BasePool construction.")
|
||
return nil, apperrors.ErrNoKeysAvailable
|
||
}
|
||
|
||
basePool := &repository.BasePool{
|
||
CandidateGroups: candidateGroups,
|
||
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
|
||
}
|
||
|
||
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool)
|
||
if err != nil {
|
||
log.WithError(err).Warn("Failed to select a key from the BasePool.")
|
||
return nil, apperrors.ErrNoKeysAvailable
|
||
}
|
||
|
||
resources, err := s.assembleRequestResources(selectedGroup, apiKey)
|
||
if err != nil {
|
||
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
|
||
return nil, err
|
||
}
|
||
resources.RequestConfig = &models.RequestConfig{} // BasePool 模式使用默认请求配置
|
||
|
||
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
|
||
return resources, nil
|
||
}
|
||
|
||
// GetResourceFromGroup 使用精确路由模式(指定密钥组)获取资源。
|
||
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
|
||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
|
||
log.Debug("Entering PreciseRoute resource acquisition.")
|
||
|
||
targetGroup, ok := s.groupManager.GetGroupByName(groupName)
|
||
if !ok {
|
||
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
|
||
}
|
||
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
|
||
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
|
||
}
|
||
|
||
apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
|
||
if err != nil {
|
||
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
|
||
return nil, apperrors.ErrNoKeysAvailable
|
||
}
|
||
|
||
resources, err := s.assembleRequestResources(targetGroup, apiKey)
|
||
if err != nil {
|
||
log.WithError(err).Error("Failed to assemble resources for precise route.")
|
||
return nil, err
|
||
}
|
||
resources.RequestConfig = targetGroup.RequestConfig // 精确路由使用该组的特定请求配置
|
||
|
||
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
|
||
return resources, nil
|
||
}
|
||
|
||
// GetAllowedModelsForToken 获取指定认证令牌有权访问的所有模型名称列表。
|
||
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
|
||
allGroups := s.groupManager.GetAllGroups()
|
||
if len(allGroups) == 0 {
|
||
return []string{}
|
||
}
|
||
|
||
allowedGroupIDs := make(map[uint]bool)
|
||
if authToken.IsAdmin {
|
||
for _, group := range allGroups {
|
||
allowedGroupIDs[group.ID] = true
|
||
}
|
||
} else {
|
||
for _, ag := range authToken.AllowedGroups {
|
||
allowedGroupIDs[ag.ID] = true
|
||
}
|
||
}
|
||
|
||
allowedModelsSet := make(map[string]struct{})
|
||
for _, group := range allGroups {
|
||
if allowedGroupIDs[group.ID] {
|
||
for _, modelMapping := range group.AllowedModels {
|
||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||
}
|
||
}
|
||
}
|
||
|
||
result := make([]string, 0, len(allowedModelsSet))
|
||
for modelName := range allowedModelsSet {
|
||
result = append(result, modelName)
|
||
}
|
||
sort.Strings(result)
|
||
return result
|
||
}
|
||
|
||
// ReportRequestResult 向 APIKeyService 报告请求的最终结果,以便更新密钥状态。
|
||
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
|
||
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
|
||
return
|
||
}
|
||
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
|
||
}
|
||
|
||
// --- Private helper methods ---
|
||
|
||
// preWarmCache 在后台执行一次性的缓存预热任务。
|
||
func (s *ResourceService) preWarmCache() {
|
||
time.Sleep(2 * time.Second) // 等待其他服务组件可能完成初始化
|
||
s.logger.Info("Performing initial key cache pre-warming...")
|
||
|
||
// 强制加载 GroupManager 缓存
|
||
s.logger.Info("Pre-warming GroupManager cache...")
|
||
_ = s.groupManager.GetAllGroups()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // 给予更长的超时
|
||
defer cancel()
|
||
|
||
if err := s.keyRepo.LoadAllKeysToStore(ctx); err != nil {
|
||
s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
|
||
} else {
|
||
s.logger.Info("Initial key cache pre-warming completed successfully.")
|
||
}
|
||
}
|
||
|
||
// assembleRequestResources 根据密钥组和API密钥组装最终的资源对象。
|
||
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
|
||
selectedUpstream := s.selectUpstreamForGroup(group)
|
||
if selectedUpstream == nil {
|
||
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
|
||
}
|
||
var proxyConfig *models.ProxyConfig
|
||
var err error
|
||
// 只有在组明确启用代理时,才为其分配代理
|
||
if group.EnableProxy {
|
||
proxyConfig, err = s.proxyManager.AssignProxyIfNeeded(apiKey)
|
||
if err != nil {
|
||
s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID)
|
||
// 根据业务需求,这里必须返回错误,因为代理是该组的强制要求
|
||
return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.")
|
||
}
|
||
}
|
||
return &RequestResources{
|
||
KeyGroup: group,
|
||
APIKey: apiKey,
|
||
UpstreamEndpoint: selectedUpstream,
|
||
ProxyConfig: proxyConfig,
|
||
}, nil
|
||
}
|
||
|
||
// selectUpstreamForGroup 为指定的密钥组选择一个上游端点。
|
||
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
|
||
if len(group.AllowedUpstreams) > 0 {
|
||
// (未来可扩展负载均衡逻辑)
|
||
return group.AllowedUpstreams[0]
|
||
}
|
||
globalSettings := s.settingsManager.GetSettings()
|
||
if globalSettings.DefaultUpstreamURL != "" {
|
||
return &models.UpstreamEndpoint{URL: globalSettings.DefaultUpstreamURL, Status: "active"}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// filterAndSortCandidateGroups 根据模型名称和令牌权限,筛选并排序出合格的候选密钥组。
|
||
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup {
|
||
allGroupsFromCache := s.groupManager.GetAllGroups()
|
||
var candidateGroups []*models.KeyGroup
|
||
|
||
for _, group := range allGroupsFromCache {
|
||
// 检查令牌权限
|
||
if !s.isTokenAllowedForGroup(authToken, group.ID) {
|
||
continue
|
||
}
|
||
// 检查模型支持情况 (如果组内未限制模型,则默认支持所有模型)
|
||
if len(group.AllowedModels) == 0 || s.groupSupportsModel(group, modelName) {
|
||
candidateGroups = append(candidateGroups, group)
|
||
}
|
||
}
|
||
|
||
sort.SliceStable(candidateGroups, func(i, j int) bool {
|
||
return candidateGroups[i].Order < candidateGroups[j].Order
|
||
})
|
||
return candidateGroups
|
||
}
|
||
|
||
// groupSupportsModel 检查指定的密钥组是否支持给定的模型名称。
|
||
func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool {
|
||
for _, m := range group.AllowedModels {
|
||
if m.ModelName == modelName {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// isTokenAllowedForGroup 检查指定的认证令牌是否有权访问给定的密钥组。
|
||
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
|
||
if authToken.IsAdmin {
|
||
return true
|
||
}
|
||
for _, allowedGroup := range authToken.AllowedGroups {
|
||
if allowedGroup.ID == groupID {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|