Files
gemini-banlancer/internal/repository/auth_token.go
2025-11-20 12:24:05 +08:00

207 lines
6.9 KiB
Go

// Filename: internal/repository/auth_token.go
package repository
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/models"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// AuthTokenRepository defines the interface for AuthToken data access.
// The GORM-backed implementation in this file is gormAuthTokenRepository.
type AuthTokenRepository interface {
// GetAllTokensWithGroups returns every stored auth token. The implementation
// in this file preloads each token's AllowedGroups and attempts to decrypt
// the stored value into the token's Token field.
GetAllTokensWithGroups() ([]*models.AuthToken, error)
// BatchUpdateTokens applies a set of token update requests. The
// implementation in this file runs them inside one database transaction and
// encrypts and hashes token values before storing them.
BatchUpdateTokens(updates []*models.TokenUpdateRequest) error
// GetTokenByHashedValue looks up a token by its hash. The implementation in
// this file only matches rows whose status is 'active' and decrypts the
// token before returning it.
GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error)
// SeedAdminToken stores the initial admin token from an already encrypted
// value and its hash. The implementation in this file creates the row only
// when no admin token exists yet.
SeedAdminToken(encryptedToken, tokenHash string) error
}
// gormAuthTokenRepository is the GORM-backed AuthTokenRepository returned by
// NewAuthTokenRepository.
type gormAuthTokenRepository struct {
// db is the GORM handle used for every query and transaction.
db *gorm.DB
// crypto encrypts token values before they are stored and decrypts them
// after they are loaded.
crypto *crypto.Service
// logger is the log entry used for repository diagnostics.
logger *logrus.Entry
}
// NewAuthTokenRepository builds the GORM-backed AuthTokenRepository.
//
// db is the database handle, crypto is the service used to encrypt and decrypt
// token values, and logger is the base logger from which a component-tagged
// entry is derived.
func NewAuthTokenRepository(db *gorm.DB, crypto *crypto.Service, logger *logrus.Logger) AuthTokenRepository {
	// Tag every log line emitted by this repository with its component name.
	entry := logger.WithField("component", "repository.authToken🔐")

	repo := &gormAuthTokenRepository{
		logger: entry,
		crypto: crypto,
		db:     db,
	}
	return repo
}
// GetAllTokensWithGroups loads every auth token together with its
// AllowedGroups and decrypts the stored values for use in services.
//
// A query failure is returned as-is. A decryption failure is only logged: the
// loaded tokens are still returned so that the usable ones remain available.
func (r *gormAuthTokenRepository) GetAllTokensWithGroups() ([]*models.AuthToken, error) {
	var loaded []*models.AuthToken

	result := r.db.Preload("AllowedGroups").Find(&loaded)
	if result.Error != nil {
		return nil, result.Error
	}

	// [CRITICAL] Callers expect plaintext tokens, so decrypt before returning.
	decErr := r.decryptTokens(loaded)
	if decErr != nil {
		// Best effort: report the problem but hand back whatever was decrypted.
		r.logger.WithError(decErr).Error("Batch decryption failed for some auth tokens.")
	}

	return loaded, nil
}
// BatchUpdateTokens synchronises the stored auth tokens with updates inside a
// single database transaction, encrypting and hashing token values on the way.
//
// Processing order:
//  1. Requests are split into the admin request (if several have IsAdmin set,
//     the last one wins) and user requests. Nil entries are ignored.
//  2. If the admin request carries a non-empty Token, the row(s) with
//     is_admin = true receive the new encrypted value and SHA-256 hash.
//  3. Every user request with a non-empty Token is encrypted, hashed and
//     upserted, and its AllowedGroups links are made to match AllowedGroupIDs
//     exactly. Requests with an empty Token are not written.
//  4. Non-admin tokens that existed before the upsert and whose ID is not
//     named by any user request are deleted along with their AllowedGroups
//     links.
//
// Any error is wrapped with context and returned from the transaction
// callback, which makes GORM roll the whole batch back.
func (r *gormAuthTokenRepository) BatchUpdateTokens(updates []*models.TokenUpdateRequest) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// 1. Separate admin and user tokens from the request.
		var adminUpdate *models.TokenUpdateRequest
		userUpdates := make([]*models.TokenUpdateRequest, 0, len(updates))
		for _, u := range updates {
			switch {
			case u == nil:
				// Tolerate holes in the request slice instead of panicking.
			case u.IsAdmin:
				adminUpdate = u
			default:
				userUpdates = append(userUpdates, u)
			}
		}

		// 2. Handle the admin token update.
		if adminUpdate != nil && adminUpdate.Token != "" {
			encryptedToken, err := r.crypto.Encrypt(adminUpdate.Token)
			if err != nil {
				return fmt.Errorf("failed to encrypt admin token: %w", err)
			}
			hash := sha256.Sum256([]byte(adminUpdate.Token))
			// Update both the encrypted value and the lookup hash together so
			// they can never disagree.
			updateData := map[string]interface{}{
				"encrypted_token": encryptedToken,
				"token_hash":      hex.EncodeToString(hash[:]),
			}
			if err := tx.Model(&models.AuthToken{}).Where("is_admin = ?", true).Updates(updateData).Error; err != nil {
				return fmt.Errorf("failed to update admin token in db: %w", err)
			}
		}

		// 3. Handle the user token upsert.
		// Snapshot the existing user tokens BEFORE writing, so rows created by
		// this batch are never considered for deletion in step 4.
		var existingTokens []*models.AuthToken
		if err := tx.Where("is_admin = ?", false).Find(&existingTokens).Error; err != nil {
			return fmt.Errorf("failed to fetch existing user tokens: %w", err)
		}

		var tokensToUpsert []models.AuthToken
		// requestedGroups[i] is the exact group set requested for tokensToUpsert[i].
		var requestedGroups [][]*models.KeyGroup
		for _, req := range userUpdates {
			if req.Token == "" {
				continue // Skip tokens with empty values
			}
			encryptedToken, err := r.crypto.Encrypt(req.Token)
			if err != nil {
				return fmt.Errorf("failed to encrypt token for upsert (ID: %d): %w", req.ID, err)
			}
			hash := sha256.Sum256([]byte(req.Token))

			var groups []*models.KeyGroup
			if len(req.AllowedGroupIDs) > 0 {
				if err := tx.Find(&groups, req.AllowedGroupIDs).Error; err != nil {
					return fmt.Errorf("failed to find key groups for token %d: %w", req.ID, err)
				}
			}

			tokensToUpsert = append(tokensToUpsert, models.AuthToken{
				ID:             req.ID,
				EncryptedToken: encryptedToken,
				TokenHash:      hex.EncodeToString(hash[:]),
				Description:    req.Description,
				Tag:            req.Tag,
				Status:         req.Status,
				IsAdmin:        false,
				AllowedGroups:  groups,
			})
			requestedGroups = append(requestedGroups, groups)
		}

		if len(tokensToUpsert) > 0 {
			if err := tx.Save(&tokensToUpsert).Error; err != nil {
				return fmt.Errorf("failed to upsert user tokens: %w", err)
			}
			// Save only inserts missing many-to-many join rows; it never removes
			// links that are no longer requested. Without this pass a group
			// dropped from AllowedGroupIDs would silently stay attached to the
			// token, i.e. access would never be revoked.
			for i := range tokensToUpsert {
				assoc := tx.Model(&tokensToUpsert[i]).Association("AllowedGroups")
				var err error
				if groups := requestedGroups[i]; len(groups) > 0 {
					err = assoc.Replace(groups)
				} else {
					err = assoc.Clear()
				}
				if err != nil {
					return fmt.Errorf("failed to sync allowed groups for token %d: %w", tokensToUpsert[i].ID, err)
				}
			}
		}

		// 4. Handle deletions.
		// Every non-zero ID named in the request is kept, including requests
		// that were skipped above because their Token was empty.
		incomingUserTokenIDs := make(map[uint]bool, len(userUpdates))
		for _, u := range userUpdates {
			if u.ID != 0 {
				incomingUserTokenIDs[u.ID] = true
			}
		}

		var idsToDelete []uint
		var tokensToDelete []models.AuthToken
		for _, t := range existingTokens {
			if incomingUserTokenIDs[t.ID] {
				continue
			}
			idsToDelete = append(idsToDelete, t.ID)
			tokensToDelete = append(tokensToDelete, models.AuthToken{ID: t.ID})
		}

		if len(idsToDelete) > 0 {
			// Association mode resolves its targets from the primary keys of the
			// model value, not from Where clauses. Passing an empty
			// &models.AuthToken{} therefore fails with ErrPrimaryKeyRequired, so
			// hand it models that carry the IDs to clear.
			if err := tx.Model(&tokensToDelete).Association("AllowedGroups").Clear(); err != nil {
				return fmt.Errorf("failed to clear associations for tokens to be deleted: %w", err)
			}
			if err := tx.Where("id IN ?", idsToDelete).Delete(&models.AuthToken{}).Error; err != nil {
				return fmt.Errorf("failed to delete user tokens: %w", err)
			}
		}
		return nil
	})
}
// --- Crypto Helper Functions ---
// decryptToken fills token.Token with the plaintext of token.EncryptedToken.
//
// It is a no-op (returning nil) when token is nil, when there is no encrypted
// value, or when Token is already populated. A decryption failure is returned
// wrapped with the token's ID and leaves the token unchanged.
func (r *gormAuthTokenRepository) decryptToken(token *models.AuthToken) error {
	if token == nil {
		return nil // Nothing to work on.
	}
	if token.EncryptedToken == "" || token.Token != "" {
		return nil // Nothing to decrypt or already done.
	}

	decrypted, decErr := r.crypto.Decrypt(token.EncryptedToken)
	if decErr == nil {
		token.Token = decrypted
		return nil
	}
	return fmt.Errorf("failed to decrypt auth token ID %d: %w", token.ID, decErr)
}
// decryptTokens decrypts every token in tokens in place.
//
// Decryption is best effort: a failure on one token is logged and the
// remaining tokens are still processed. When at least one token failed, a
// summary error stating how many failed is returned so the caller can tell
// that the batch is only partially decrypted; it returns nil only when every
// token was handled successfully.
func (r *gormAuthTokenRepository) decryptTokens(tokens []*models.AuthToken) error {
	failed := 0
	for _, token := range tokens {
		if err := r.decryptToken(token); err != nil {
			r.logger.Error(err) // Log error but continue for other tokens
			failed++
		}
	}
	if failed > 0 {
		return fmt.Errorf("%d of %d auth tokens could not be decrypted", failed, len(tokens))
	}
	return nil
}
// GetTokenByHashedValue looks up the active token whose stored hash equals
// tokenHash, preloading its AllowedGroups, and returns it with the plaintext
// token decrypted.
//
// The query error (including GORM's not-found error from First) or the
// decryption error is returned unchanged, with a nil token.
func (r *gormAuthTokenRepository) GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error) {
	found := new(models.AuthToken)

	// Only tokens with status 'active' may authenticate.
	query := r.db.Where("token_hash = ? AND status = 'active'", tokenHash).Preload("AllowedGroups")
	if err := query.First(found).Error; err != nil {
		return nil, err
	}

	// [CRITICAL] The service layer needs the full, usable object, so the
	// plaintext must be restored before the token leaves the repository.
	if err := r.decryptToken(found); err != nil {
		return nil, err
	}
	return found, nil
}
// SeedAdminToken inserts the initial administrator token for the seeder.
//
// encryptedToken and tokenHash are stored as given; no encryption or hashing
// happens here. The insert goes through FirstOrCreate keyed on IsAdmin, so
// when an admin token already exists nothing new is created.
func (r *gormAuthTokenRepository) SeedAdminToken(encryptedToken, tokenHash string) error {
	seed := models.AuthToken{
		IsAdmin:        true,
		Status:         "active", // The seeded token must be usable immediately.
		Tag:            "SYSTEM_ADMIN",
		Description:    "Default Administrator Token",
		TokenHash:      tokenHash,
		EncryptedToken: encryptedToken,
	}

	// Match on "is an admin token" only, which keeps the call idempotent.
	lookup := models.AuthToken{IsAdmin: true}
	result := r.db.Where(lookup).FirstOrCreate(&seed)
	return result.Error
}