Fix basepool & 优化 repo
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
// Filename: internal/service/resource_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"gemini-balancer/internal/domain/proxy"
|
||||
apperrors "gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
@@ -16,10 +15,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoResourceAvailable = errors.New("no available resource found for the request")
|
||||
)
|
||||
|
||||
// RequestResources 封装了一次成功请求所需的所有资源。
|
||||
type RequestResources struct {
|
||||
KeyGroup *models.KeyGroup
|
||||
APIKey *models.APIKey
|
||||
@@ -28,41 +24,51 @@ type RequestResources struct {
|
||||
RequestConfig *models.RequestConfig
|
||||
}
|
||||
|
||||
// ResourceService 负责根据请求参数和业务规则,动态地选择和分配API密钥及相关资源。
|
||||
type ResourceService struct {
|
||||
settingsManager *settings.SettingsManager
|
||||
groupManager *GroupManager
|
||||
keyRepo repository.KeyRepository
|
||||
authTokenRepo repository.AuthTokenRepository
|
||||
apiKeyService *APIKeyService
|
||||
proxyManager *proxy.Module
|
||||
logger *logrus.Entry
|
||||
initOnce sync.Once
|
||||
}
|
||||
|
||||
// NewResourceService 创建并初始化一个新的 ResourceService 实例。
|
||||
func NewResourceService(
|
||||
sm *settings.SettingsManager,
|
||||
gm *GroupManager,
|
||||
kr repository.KeyRepository,
|
||||
atr repository.AuthTokenRepository,
|
||||
aks *APIKeyService,
|
||||
pm *proxy.Module,
|
||||
logger *logrus.Logger,
|
||||
) *ResourceService {
|
||||
rs := &ResourceService{
|
||||
settingsManager: sm,
|
||||
groupManager: gm,
|
||||
keyRepo: kr,
|
||||
authTokenRepo: atr,
|
||||
apiKeyService: aks,
|
||||
proxyManager: pm,
|
||||
logger: logger.WithField("component", "ResourceService📦️"),
|
||||
}
|
||||
|
||||
// 使用 sync.Once 确保预热任务在服务生命周期内仅执行一次
|
||||
rs.initOnce.Do(func() {
|
||||
go rs.preWarmCache(logger)
|
||||
go rs.preWarmCache()
|
||||
})
|
||||
return rs
|
||||
}
|
||||
|
||||
// GetResourceFromBasePool 使用智能聚合池模式获取资源。
|
||||
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
|
||||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
|
||||
log.Debug("Entering BasePool resource acquisition.")
|
||||
|
||||
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
|
||||
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken)
|
||||
if len(candidateGroups) == 0 {
|
||||
log.Warn("No candidate groups found for BasePool construction.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
@@ -84,17 +90,18 @@ func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken
|
||||
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
|
||||
return nil, err
|
||||
}
|
||||
resources.RequestConfig = &models.RequestConfig{}
|
||||
resources.RequestConfig = &models.RequestConfig{} // BasePool 模式使用默认请求配置
|
||||
|
||||
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
// GetResourceFromGroup 使用精确路由模式(指定密钥组)获取资源。
|
||||
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
|
||||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
|
||||
log.Debug("Entering PreciseRoute resource acquisition.")
|
||||
|
||||
targetGroup, ok := s.groupManager.GetGroupByName(groupName)
|
||||
|
||||
if !ok {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
|
||||
}
|
||||
@@ -113,37 +120,39 @@ func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *m
|
||||
log.WithError(err).Error("Failed to assemble resources for precise route.")
|
||||
return nil, err
|
||||
}
|
||||
resources.RequestConfig = targetGroup.RequestConfig
|
||||
resources.RequestConfig = targetGroup.RequestConfig // 精确路由使用该组的特定请求配置
|
||||
|
||||
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
// GetAllowedModelsForToken 获取指定认证令牌有权访问的所有模型名称列表。
|
||||
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
|
||||
allGroups := s.groupManager.GetAllGroups()
|
||||
if len(allGroups) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
allowedModelsSet := make(map[string]struct{})
|
||||
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
if authToken.IsAdmin {
|
||||
for _, group := range allGroups {
|
||||
allowedGroupIDs[group.ID] = true
|
||||
}
|
||||
} else {
|
||||
for _, ag := range authToken.AllowedGroups {
|
||||
allowedGroupIDs[ag.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
allowedModelsSet := make(map[string]struct{})
|
||||
for _, group := range allGroups {
|
||||
if allowedGroupIDs[group.ID] {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
for _, ag := range authToken.AllowedGroups {
|
||||
allowedGroupIDs[ag.ID] = true
|
||||
}
|
||||
for _, group := range allGroups {
|
||||
if _, ok := allowedGroupIDs[group.ID]; ok {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(allowedModelsSet))
|
||||
for modelName := range allowedModelsSet {
|
||||
result = append(result, modelName)
|
||||
@@ -152,12 +161,52 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
|
||||
return result
|
||||
}
|
||||
|
||||
// ReportRequestResult 向 APIKeyService 报告请求的最终结果,以便更新密钥状态。
|
||||
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
|
||||
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
|
||||
return
|
||||
}
|
||||
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
|
||||
}
|
||||
|
||||
// --- 私有辅助方法 ---
|
||||
|
||||
// preWarmCache 在后台执行一次性的缓存预热任务。
|
||||
func (s *ResourceService) preWarmCache() {
|
||||
time.Sleep(2 * time.Second) // 等待其他服务组件可能完成初始化
|
||||
s.logger.Info("Performing initial key cache pre-warming...")
|
||||
|
||||
// 强制加载 GroupManager 缓存
|
||||
s.logger.Info("Pre-warming GroupManager cache...")
|
||||
_ = s.groupManager.GetAllGroups()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // 给予更长的超时
|
||||
defer cancel()
|
||||
|
||||
if err := s.keyRepo.LoadAllKeysToStore(ctx); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
|
||||
} else {
|
||||
s.logger.Info("Initial key cache pre-warming completed successfully.")
|
||||
}
|
||||
}
|
||||
|
||||
// assembleRequestResources 根据密钥组和API密钥组装最终的资源对象。
|
||||
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
|
||||
selectedUpstream := s.selectUpstreamForGroup(group)
|
||||
if selectedUpstream == nil {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
|
||||
}
|
||||
var proxyConfig *models.ProxyConfig
|
||||
var err error
|
||||
// 只有在组明确启用代理时,才为其分配代理
|
||||
if group.EnableProxy {
|
||||
proxyConfig, err = s.proxyManager.AssignProxyIfNeeded(apiKey)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID)
|
||||
// 根据业务需求,这里必须返回错误,因为代理是该组的强制要求
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.")
|
||||
}
|
||||
}
|
||||
return &RequestResources{
|
||||
KeyGroup: group,
|
||||
APIKey: apiKey,
|
||||
@@ -166,8 +215,10 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
|
||||
}, nil
|
||||
}
|
||||
|
||||
// selectUpstreamForGroup 为指定的密钥组选择一个上游端点。
|
||||
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
|
||||
if len(group.AllowedUpstreams) > 0 {
|
||||
// (未来可扩展负载均衡逻辑)
|
||||
return group.AllowedUpstreams[0]
|
||||
}
|
||||
globalSettings := s.settingsManager.GetSettings()
|
||||
@@ -177,56 +228,39 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
|
||||
time.Sleep(2 * time.Second)
|
||||
s.logger.Info("Performing initial key cache pre-warming...")
|
||||
if err := s.keyRepo.LoadAllKeysToStore(context.Background()); err != nil {
|
||||
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
|
||||
return err
|
||||
}
|
||||
s.logger.Info("Initial key cache pre-warming completed successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
|
||||
return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
|
||||
}
|
||||
|
||||
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
|
||||
// filterAndSortCandidateGroups 根据模型名称和令牌权限,筛选并排序出合格的候选密钥组。
|
||||
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup {
|
||||
allGroupsFromCache := s.groupManager.GetAllGroups()
|
||||
var candidateGroups []*models.KeyGroup
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
isTokenRestricted := len(allowedGroupsFromToken) > 0
|
||||
if isTokenRestricted {
|
||||
for _, ag := range allowedGroupsFromToken {
|
||||
allowedGroupIDs[ag.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, group := range allGroupsFromCache {
|
||||
if isTokenRestricted && !allowedGroupIDs[group.ID] {
|
||||
// 检查令牌权限
|
||||
if !s.isTokenAllowedForGroup(authToken, group.ID) {
|
||||
continue
|
||||
}
|
||||
isModelAllowed := false
|
||||
if len(group.AllowedModels) == 0 {
|
||||
isModelAllowed = true
|
||||
} else {
|
||||
for _, m := range group.AllowedModels {
|
||||
if m.ModelName == modelName {
|
||||
isModelAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if isModelAllowed {
|
||||
// 检查模型支持情况 (如果组内未限制模型,则默认支持所有模型)
|
||||
if len(group.AllowedModels) == 0 || s.groupSupportsModel(group, modelName) {
|
||||
candidateGroups = append(candidateGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
sort.SliceStable(candidateGroups, func(i, j int) bool {
|
||||
return candidateGroups[i].Order < candidateGroups[j].Order
|
||||
})
|
||||
return candidateGroups
|
||||
}
|
||||
|
||||
// groupSupportsModel 检查指定的密钥组是否支持给定的模型名称。
|
||||
func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool {
|
||||
for _, m := range group.AllowedModels {
|
||||
if m.ModelName == modelName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTokenAllowedForGroup 检查指定的认证令牌是否有权访问给定的密钥组。
|
||||
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
|
||||
if authToken.IsAdmin {
|
||||
return true
|
||||
@@ -238,10 +272,3 @@ func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, gr
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
|
||||
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
|
||||
return
|
||||
}
|
||||
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user