122 lines
2.7 KiB
Go
122 lines
2.7 KiB
Go
// Filename: internal/middleware/auth.go
|
||
|
||
package middleware
|
||
|
||
import (
|
||
"gemini-balancer/internal/service"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// ErrorResponse is the JSON body the authentication middlewares in this file
// send when they abort a request.
type ErrorResponse struct {
	// Error is a human-readable description of why the request was rejected.
	Error string `json:"error"`
	// Code is a machine-readable identifier such as "AUTH_MISSING"; it is
	// omitted from the JSON output when empty.
	Code string `json:"code,omitempty"`
}
|
||
|
||
// APIAdminAuthMiddleware authenticates requests to the admin API.
//
// The request must carry an "Authorization: Bearer <token>" header. The token
// is validated with securityService.AuthenticateToken and must report
// IsAdmin. On success the authenticated token is stored in the gin context
// under "adminUser" and the handler chain continues. Otherwise the request is
// aborted with a JSON ErrorResponse:
//   - 401 AUTH_MISSING: no bearer token was supplied
//   - 401 AUTH_INVALID: the token was rejected (the error is logged at Warn)
//   - 403 AUTH_FORBIDDEN: the token is valid but not an admin token
func APIAdminAuthMiddleware(
	securityService *service.SecurityService,
	logger *logrus.Logger,
) gin.HandlerFunc {
	return func(c *gin.Context) {
		reject := func(status int, message, code string) {
			c.AbortWithStatusJSON(status, ErrorResponse{Error: message, Code: code})
		}

		bearer := extractBearerToken(c)
		if bearer == "" {
			reject(http.StatusUnauthorized, "Authentication required", "AUTH_MISSING")
			return
		}

		// AuthenticateToken takes only the raw token value (no request context).
		principal, err := securityService.AuthenticateToken(bearer)
		switch {
		case err != nil:
			logger.WithError(err).Warn("Authentication failed")
			reject(http.StatusUnauthorized, "Invalid authentication", "AUTH_INVALID")
		case !principal.IsAdmin:
			reject(http.StatusForbidden, "Admin access required", "AUTH_FORBIDDEN")
		default:
			c.Set("adminUser", principal)
			c.Next()
		}
	}
}
|
||
|
||
// ProxyAuthMiddleware authenticates proxied API requests.
//
// The client key is located by extractProxyToken. It checks the Authorization
// bearer token, then the X-Api-Key and X-Goog-Api-Key headers, then the "key"
// query parameter. The key is validated with securityService.AuthenticateToken.
// On success the authenticated token is stored in the gin context under
// "authToken" and the handler chain continues. Otherwise the request is
// aborted with a 401 JSON ErrorResponse:
//   - KEY_MISSING: no key was supplied
//   - KEY_INVALID: the key was rejected (the error is logged at Warn)
func ProxyAuthMiddleware(
	securityService *service.SecurityService,
	logger *logrus.Logger,
) gin.HandlerFunc {
	return func(c *gin.Context) {
		tokenValue := extractProxyToken(c)
		if tokenValue == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
				Error: "API key required",
				Code:  "KEY_MISSING",
			})
			return
		}

		// AuthenticateToken takes only the raw token value (no request context).
		authToken, err := securityService.AuthenticateToken(tokenValue)
		if err != nil {
			// Record the rejection so invalid keys are visible in the logs, the
			// same way APIAdminAuthMiddleware reports failed authentication.
			logger.WithError(err).Warn("Proxy authentication failed")
			c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
				Error: "Invalid API key",
				Code:  "KEY_INVALID",
			})
			return
		}

		c.Set("authToken", authToken)
		c.Next()
	}
}
|
||
|
||
// extractProxyToken 按优先级提取 token
|
||
func extractProxyToken(c *gin.Context) string {
|
||
// 优先级 1: Authorization Header
|
||
if token := extractBearerToken(c); token != "" {
|
||
return token
|
||
}
|
||
|
||
// 优先级 2: X-Api-Key
|
||
if key := c.GetHeader("X-Api-Key"); key != "" {
|
||
return key
|
||
}
|
||
|
||
// 优先级 3: X-Goog-Api-Key
|
||
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
|
||
return key
|
||
}
|
||
|
||
// 优先级 4: Query 参数(不推荐)
|
||
return c.Query("key")
|
||
}
|
||
|
||
// extractBearerToken returns the credential from an "Authorization: Bearer
// <token>" request header, with surrounding whitespace trimmed.
//
// The "Bearer" scheme name is matched case-insensitively, because
// authentication schemes are case-insensitive (RFC 7235 §2.1). A header such
// as "bearer <token>" is therefore accepted. The function returns "" when the
// header is absent, uses a different scheme, or carries no token.
func extractBearerToken(c *gin.Context) string {
	const prefix = "Bearer "

	authHeader := c.GetHeader("Authorization")
	// Only the scheme part is compared case-insensitively. The token itself is
	// returned verbatim apart from surrounding whitespace.
	if len(authHeader) < len(prefix) || !strings.EqualFold(authHeader[:len(prefix)], prefix) {
		return ""
	}
	return strings.TrimSpace(authHeader[len(prefix):])
}