207 lines
6.9 KiB
Go
207 lines
6.9 KiB
Go
// Filename: internal/repository/auth_token.go
|
|
|
|
package repository
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"gemini-balancer/internal/crypto"
|
|
"gemini-balancer/internal/models"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// AuthTokenRepository defines the interface for AuthToken data access.
|
|
type AuthTokenRepository interface {
|
|
GetAllTokensWithGroups() ([]*models.AuthToken, error)
|
|
BatchUpdateTokens(updates []*models.TokenUpdateRequest) error
|
|
GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error) // <-- Add this line
|
|
SeedAdminToken(encryptedToken, tokenHash string) error // <-- And this line for the seeder
|
|
}
|
|
|
|
type gormAuthTokenRepository struct {
|
|
db *gorm.DB
|
|
crypto *crypto.Service
|
|
logger *logrus.Entry
|
|
}
|
|
|
|
func NewAuthTokenRepository(db *gorm.DB, crypto *crypto.Service, logger *logrus.Logger) AuthTokenRepository {
|
|
return &gormAuthTokenRepository{
|
|
db: db,
|
|
crypto: crypto,
|
|
logger: logger.WithField("component", "repository.authToken🔐"),
|
|
}
|
|
}
|
|
|
|
// GetAllTokensWithGroups fetches all tokens and decrypts them for use in services.
|
|
func (r *gormAuthTokenRepository) GetAllTokensWithGroups() ([]*models.AuthToken, error) {
|
|
var tokens []*models.AuthToken
|
|
if err := r.db.Preload("AllowedGroups").Find(&tokens).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
// [CRITICAL] Decrypt all tokens before returning them.
|
|
if err := r.decryptTokens(tokens); err != nil {
|
|
// Log the error but return the partially decrypted data, as some might be usable.
|
|
r.logger.WithError(err).Error("Batch decryption failed for some auth tokens.")
|
|
}
|
|
return tokens, nil
|
|
}
|
|
|
|
// BatchUpdateTokens provides a transactional way to update all tokens, handling encryption.
|
|
func (r *gormAuthTokenRepository) BatchUpdateTokens(updates []*models.TokenUpdateRequest) error {
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
// 1. Separate admin and user tokens from the request
|
|
var adminUpdate *models.TokenUpdateRequest
|
|
var userUpdates []*models.TokenUpdateRequest
|
|
for _, u := range updates {
|
|
if u.IsAdmin {
|
|
adminUpdate = u
|
|
} else {
|
|
userUpdates = append(userUpdates, u)
|
|
}
|
|
}
|
|
|
|
// 2. Handle Admin Token Update
|
|
if adminUpdate != nil && adminUpdate.Token != "" {
|
|
encryptedToken, err := r.crypto.Encrypt(adminUpdate.Token)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt admin token: %w", err)
|
|
}
|
|
hash := sha256.Sum256([]byte(adminUpdate.Token))
|
|
tokenHash := hex.EncodeToString(hash[:])
|
|
|
|
// Update both encrypted value and the hash
|
|
updateData := map[string]interface{}{
|
|
"encrypted_token": encryptedToken,
|
|
"token_hash": tokenHash,
|
|
}
|
|
if err := tx.Model(&models.AuthToken{}).Where("is_admin = ?", true).Updates(updateData).Error; err != nil {
|
|
return fmt.Errorf("failed to update admin token in db: %w", err)
|
|
}
|
|
}
|
|
|
|
// 3. Handle User Tokens Upsert
|
|
var existingTokens []*models.AuthToken
|
|
if err := tx.Where("is_admin = ?", false).Find(&existingTokens).Error; err != nil {
|
|
return fmt.Errorf("failed to fetch existing user tokens: %w", err)
|
|
}
|
|
existingTokenMap := make(map[uint]bool)
|
|
for _, t := range existingTokens {
|
|
existingTokenMap[t.ID] = true
|
|
}
|
|
|
|
var tokensToUpsert []models.AuthToken
|
|
for _, req := range userUpdates {
|
|
if req.Token == "" {
|
|
continue // Skip tokens with empty values
|
|
}
|
|
encryptedToken, err := r.crypto.Encrypt(req.Token)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt token for upsert (ID: %d): %w", req.ID, err)
|
|
}
|
|
hash := sha256.Sum256([]byte(req.Token))
|
|
tokenHash := hex.EncodeToString(hash[:])
|
|
|
|
var groups []*models.KeyGroup
|
|
if len(req.AllowedGroupIDs) > 0 {
|
|
if err := tx.Find(&groups, req.AllowedGroupIDs).Error; err != nil {
|
|
return fmt.Errorf("failed to find key groups for token %d: %w", req.ID, err)
|
|
}
|
|
}
|
|
tokensToUpsert = append(tokensToUpsert, models.AuthToken{
|
|
ID: req.ID,
|
|
EncryptedToken: encryptedToken,
|
|
TokenHash: tokenHash,
|
|
Description: req.Description,
|
|
Tag: req.Tag,
|
|
Status: req.Status,
|
|
IsAdmin: false,
|
|
AllowedGroups: groups,
|
|
})
|
|
}
|
|
if len(tokensToUpsert) > 0 {
|
|
if err := tx.Save(&tokensToUpsert).Error; err != nil {
|
|
return fmt.Errorf("failed to upsert user tokens: %w", err)
|
|
}
|
|
}
|
|
|
|
// 4. Handle Deletions
|
|
incomingUserTokenIDs := make(map[uint]bool)
|
|
for _, u := range userUpdates {
|
|
if u.ID != 0 {
|
|
incomingUserTokenIDs[u.ID] = true
|
|
}
|
|
}
|
|
var idsToDelete []uint
|
|
for id := range existingTokenMap {
|
|
if !incomingUserTokenIDs[id] {
|
|
idsToDelete = append(idsToDelete, id)
|
|
}
|
|
}
|
|
if len(idsToDelete) > 0 {
|
|
if err := tx.Model(&models.AuthToken{}).Where("id IN ?", idsToDelete).Association("AllowedGroups").Clear(); err != nil {
|
|
return fmt.Errorf("failed to clear associations for tokens to be deleted: %w", err)
|
|
}
|
|
if err := tx.Where("id IN ?", idsToDelete).Delete(&models.AuthToken{}).Error; err != nil {
|
|
return fmt.Errorf("failed to delete user tokens: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// --- Crypto Helper Functions ---
|
|
|
|
func (r *gormAuthTokenRepository) decryptToken(token *models.AuthToken) error {
|
|
if token == nil || token.EncryptedToken == "" || token.Token != "" {
|
|
return nil // Nothing to decrypt or already done
|
|
}
|
|
plaintext, err := r.crypto.Decrypt(token.EncryptedToken)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decrypt auth token ID %d: %w", token.ID, err)
|
|
}
|
|
token.Token = plaintext
|
|
return nil
|
|
}
|
|
|
|
func (r *gormAuthTokenRepository) decryptTokens(tokens []*models.AuthToken) error {
|
|
for i := range tokens {
|
|
if err := r.decryptToken(tokens[i]); err != nil {
|
|
r.logger.Error(err) // Log error but continue for other tokens
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetTokenByHashedValue finds a token by its SHA256 hash for authentication.
|
|
func (r *gormAuthTokenRepository) GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error) {
|
|
var authToken models.AuthToken
|
|
// Find the active token by its hash. This is the core of our secure authentication.
|
|
err := r.db.Where("token_hash = ? AND status = 'active'", tokenHash).Preload("AllowedGroups").First(&authToken).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// [CRITICAL] Decrypt the token before returning it to the service layer.
|
|
// This ensures that subsequent logic (like in ResourceService) gets the full, usable object.
|
|
if err := r.decryptToken(&authToken); err != nil {
|
|
return nil, err
|
|
}
|
|
return &authToken, nil
|
|
}
|
|
|
|
// SeedAdminToken is a special-purpose function for the seeder to insert the initial admin token.
|
|
func (r *gormAuthTokenRepository) SeedAdminToken(encryptedToken, tokenHash string) error {
|
|
adminToken := models.AuthToken{
|
|
EncryptedToken: encryptedToken,
|
|
TokenHash: tokenHash,
|
|
Description: "Default Administrator Token",
|
|
Tag: "SYSTEM_ADMIN",
|
|
IsAdmin: true,
|
|
Status: "active", // Ensure the seeded token is active
|
|
}
|
|
// Using FirstOrCreate to be idempotent. If an admin token already exists, it does nothing.
|
|
return r.db.Where(models.AuthToken{IsAdmin: true}).FirstOrCreate(&adminToken).Error
|
|
}
|