Compare commits


11 Commits

117 changed files with 36720 additions and 6391 deletions

View File

@@ -3,10 +3,12 @@ package main
import (
"gemini-balancer/internal/app"
"gemini-balancer/internal/container"
"gemini-balancer/internal/logging"
"log"
)
func main() {
defer logging.Close()
cont, err := container.BuildContainer()
if err != nil {
log.Fatalf("FATAL: Failed to build dependency container: %v", err)

View File

@@ -14,6 +14,12 @@ server:
log:
level: "debug"
# Log rotation settings
max_size: 100 # MB
max_backups: 7 # number of rotated files to keep
max_age: 30 # retention in days
compress: true # compress rotated logs
redis:
dsn: "redis://localhost:6379/0"

View File

@@ -5,4 +5,5 @@ esbuild ./frontend/js/main.js \
--outdir=./web/static/js \
--splitting \
--format=esm \
--loader:.css=css \
--watch=forever
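
The same build can be expressed through esbuild's JS API; a sketch for reference, assuming esbuild >= 0.17 (where watch mode moved to context objects) and that --bundle appears earlier in the script, outside this hunk:

// build.mjs — hypothetical equivalent of the CLI invocation above
import * as esbuild from 'esbuild';

const ctx = await esbuild.context({
  entryPoints: ['./frontend/js/main.js'],
  outdir: './web/static/js',
  bundle: true,              // implied by --splitting, which requires bundling
  splitting: true,
  format: 'esm',
  loader: { '.css': 'css' }, // lets JS modules import .css files directly
});
await ctx.watch();           // keep rebuilding on change, like --watch=forever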

frontend/css/flatpickr.min.css (vendored normal file, 13 lines)

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,6 @@
/* static/css/input.css */
@import "tailwindcss";
/* @import "./css/flatpickr.min.css"; */
/* =================================================================== */
/* [Core] Define the shadcn/ui design-system variables */
/* =================================================================== */
@@ -19,7 +19,7 @@
--secondary: theme(colors.zinc.200);
--secondary-foreground: theme(colors.zinc.900);
--destructive: theme(colors.red.600);
--destructive: theme(colors.red.500);
--destructive-foreground: theme(colors.white);
--accent: theme(colors.zinc.100);
--accent-foreground: theme(colors.zinc.900);
@@ -69,10 +69,10 @@
@apply bg-primary text-primary-foreground hover:bg-primary/90;
}
.btn-secondary {
@apply bg-secondary text-secondary-foreground hover:bg-secondary/80;
@apply bg-secondary text-secondary-foreground border border-zinc-500/30 hover:bg-secondary/80;
}
.btn-destructive {
@apply bg-destructive text-destructive-foreground hover:bg-destructive/90;
@apply bg-destructive text-destructive-foreground border border-zinc-500/30 hover:bg-destructive/90;
}
.btn-outline {
@apply border border-input bg-background hover:bg-accent hover:text-accent-foreground;
@@ -83,7 +83,9 @@
.btn-link {
@apply text-primary underline-offset-4 hover:underline;
}
.btn-group-item.active {
@apply bg-primary text-primary-foreground;
}
/* Button size variants */
.btn-lg { @apply h-11 rounded-md px-8; }
.btn-md { @apply h-10 px-4 py-2; }
@@ -97,6 +99,155 @@
focus-visible:outline-none focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-[var(--ring)]
disabled:cursor-not-allowed disabled:opacity-50;
}
/* ------------------------------------------------ */
/* Custom Flatpickr Theme using shadcn/ui variables */
/* ------------------------------------------------ */
.flatpickr-calendar {
/* --- Theme styles --- */
@apply bg-background text-foreground rounded-lg shadow-lg border border-zinc-500/30 w-auto font-sans;
animation: var(--animation-panel-in);
width: 200px;
/* --- Core structural styles --- */
display: none;
position: absolute;
visibility: hidden;
opacity: 0;
padding: 0;
z-index: 999;
box-sizing: border-box;
transition: opacity 0.15s ease-out, visibility 0.15s ease-out;
}
.flatpickr-calendar.open {
opacity: 1;
visibility: visible;
display: inline-block;
}
.flatpickr-calendar.not-ready {
top: 0;
left: 0;
visibility: hidden;
}
.flatpickr-calendar.static {
position: relative;
top: auto;
left: auto;
display: block;
visibility: visible;
opacity: 1;
}
/* Month navigation area */
.flatpickr-months {
@apply flex items-center bg-transparent p-0 border-b border-zinc-500/30;
}
.flatpickr-month { @apply h-auto pt-2 pb-1; }
.flatpickr-current-month {
@apply flex flex-1 items-center justify-center text-foreground font-semibold text-sm h-auto;
}
.flatpickr-current-month .numInputWrapper {
@apply ml-0;
}
.flatpickr-current-month .numInputWrapper input.numInput {
@apply w-14 text-center font-semibold bg-transparent border-0 p-0 text-sm text-foreground;
/* Remove default browser styling and the focus outline */
@apply appearance-none focus:outline-none focus:ring-0;
-moz-appearance: textfield;
}
/* Force-remove the number input's spin buttons in WebKit browsers (Chrome, Safari, Edge) */
.flatpickr-current-month .numInputWrapper input.numInput::-webkit-outer-spin-button,
.flatpickr-current-month .numInputWrapper input.numInput::-webkit-inner-spin-button {
-webkit-appearance: none;
margin: 0;
}
.flatpickr-current-month .cur-month { @apply font-semibold; }
.flatpickr-current-month .flatpickr-monthDropdown-months {
@apply w-22 font-semibold bg-transparent border-0 p-0 text-sm text-foreground text-right;
@apply appearance-none focus:outline-none focus:ring-0;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap; /* keep the month name on one line */
}
option.flatpickr-monthDropdown-month {
@apply bg-background text-foreground border-0;
@apply dark:bg-zinc-800 dark:text-zinc-200;
}
.flatpickr-current-month .flatpickr-monthDropdown-months,
.flatpickr-current-month .numInputWrapper input.numInput {
@apply text-sm pl-0;
}
/* Navigation arrows */
.flatpickr-prev-month,
.flatpickr-next-month {
@apply inline-flex items-center justify-center whitespace-nowrap rounded-md pt-2 pb-1 text-sm font-medium transition-colors
focus-visible:outline-none focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-[var(--ring)]
disabled:pointer-events-none disabled:opacity-50
hover:text-accent-foreground
h-7 w-7 shrink-0;
position: relative;
}
.flatpickr-prev-month svg,
.flatpickr-next-month svg {
@apply h-3 w-3 hover:h-4 hover:w-4 fill-zinc-500;
}
/* Weekday headers */
.flatpickr-weekdaycontainer {
@apply flex justify-around p-1;
}
span.flatpickr-weekday {
@apply flex-1 text-center text-muted-foreground font-medium;
font-size: 0.7rem;
}
/* Day grid */
.dayContainer {
@apply flex flex-wrap p-1 pt-0;
box-sizing: border-box;
}
.flatpickr-day {
@apply w-4 h-6.5 flex items-center justify-center rounded-full border-0 text-foreground transition-colors shrink-0; /* scaled down from w-9 h-9 */
flex-basis: 14.2857%;
line-height: 1;
cursor: pointer;
font-size: 0.7rem; /* between text-xs and text-sm */
}
.flatpickr-day:hover,
.flatpickr-day:focus { @apply bg-accent text-accent-foreground outline-none; }
.flatpickr-day.today { @apply border border-primary; }
.flatpickr-day.selected,
.flatpickr-day.startRange,
.flatpickr-day.endRange,
.flatpickr-day.selected:hover,
.flatpickr-day.startRange:hover,
.flatpickr-day.endRange:hover {
@apply bg-primary text-primary-foreground;
}
.flatpickr-day.inRange { @apply bg-accent rounded-none shadow-none; }
.flatpickr-day.startRange { @apply rounded-l-full; }
.flatpickr-day.endRange { @apply rounded-r-full; }
.flatpickr-day.disabled,
.flatpickr-day.disabled:hover { @apply bg-transparent text-muted-foreground/50 cursor-not-allowed; }
.flatpickr-day.nextMonthDay, .flatpickr-day.prevMonthDay { @apply text-muted-foreground/50; }
.flatpickr-day.nextMonthDay:hover, .flatpickr-day.prevMonthDay:hover { @apply bg-accent; }
/* Clear button */
.flatpickr-calendar .flatpickr-clear-button {
@apply h-6 py-3 inline-flex items-center justify-center whitespace-nowrap rounded-md text-xs font-medium transition-colors
focus-visible:outline-none focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-[var(--ring)]
disabled:pointer-events-none disabled:opacity-50
text-primary underline-offset-4 hover:text-sm
w-full rounded-t-none border-t border-zinc-500/20;
}
}
@custom-variant dark (&:where(.dark, .dark *));
@@ -214,8 +365,8 @@
@apply w-[1.2rem] text-center;
@apply transition-all duration-300 ease-in-out;
/* Hover and active states */
@apply group-hover:text-[#60a5fa] group-hover:[filter:drop-shadow(0_0_5px_rgba(59,130,246,0.5))];
@apply group-data-[active='true']:text-[#60a5fa] group-data-[active='true']:[filter:drop-shadow(0_0_5px_rgba(59,130,246,0.7))];
@apply group-hover:text-[#60a5fa] group-hover:filter-[drop-shadow(0_0_5px_rgba(59,130,246,0.5))];
@apply group-data-[active='true']:text-[#60a5fa] group-data-[active='true']:filter-[drop-shadow(0_0_5px_rgba(59,130,246,0.7))];
}
/* 4. Indicator */
.nav-indicator {
@@ -243,13 +394,13 @@
@apply flex items-start p-3 w-full rounded-lg shadow-lg ring-1 ring-black/5 dark:ring-white/10 bg-white/80 dark:bg-zinc-800/80 backdrop-blur-md pointer-events-auto;
}
.toast-icon {
@apply flex-shrink-0 w-8 h-8 rounded-full flex items-center justify-center text-white mr-3;
@apply shrink-0 w-8 h-8 rounded-full flex items-center justify-center text-white mr-3;
}
.toast-icon-loading {@apply bg-blue-500;}
.toast-icon-success {@apply bg-green-500;}
.toast-icon-error {@apply bg-red-500;}
.toast-content {
@apply flex-grow;
@apply grow;
}
.toast-title {
@apply font-semibold text-sm text-zinc-800 dark:text-zinc-100;
@@ -273,7 +424,7 @@
/* --- Task item main content area (left column) --- */
.task-item-main {
@apply flex items-center justify-between flex-grow gap-1; /* flex-grow makes it fill the available space */
@apply flex items-center justify-between grow gap-1; /* grow makes it fill the available space */
}
/* 2. Task item header: holds the title and timestamp */
.task-item-header {
@@ -285,7 +436,7 @@
}
.task-item-timestamp {
/* merges your original font styles */
@apply text-xs self-start pt-1.5 pl-2 text-zinc-400 dark:text-zinc-500 flex-shrink-0;
@apply text-xs self-start pt-1.5 pl-2 text-zinc-400 dark:text-zinc-500 shrink-0;
}
/* 3. [New] Core container for the stage animation */
@@ -298,13 +449,13 @@
@apply flex items-center gap-2 p-1.5 rounded-md transition-all duration-300 ease-in-out relative;
}
.task-stage-icon {
@apply w-4 h-4 relative flex-shrink-0 text-zinc-400;
@apply w-4 h-4 relative shrink-0 text-zinc-400;
}
.task-stage-icon i {
@apply absolute top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 opacity-0 transition-opacity duration-200;
}
.task-stage-content {
@apply flex-grow flex justify-between items-baseline text-xs;
@apply grow flex justify-between items-baseline text-xs;
}
.task-stage-name {
@apply text-zinc-600 dark:text-zinc-400;
@@ -373,7 +524,7 @@
}
/* --- 4. Collapse/expand chevron icon --- */
.task-toggle-icon {
@apply transition-transform duration-300 ease-in-out text-zinc-400 flex-shrink-0 ml-2;
@apply transition-transform duration-300 ease-in-out text-zinc-400 shrink-0 ml-2;
}
/* --- 5. Icon rotation in the expanded state --- */
/*
@@ -470,7 +621,7 @@
* 2. [New] Styles for the mobile first-screen "current group" selector
*/
.mobile-group-selector {
@apply flex-grow flex items-center justify-between p-3 border border-zinc-200 dark:border-zinc-700 rounded-lg;
@apply grow flex items-center justify-between p-3 border border-zinc-200 dark:border-zinc-700 rounded-lg;
}
/* Mobile group dropdown menu styles */
.mobile-group-menu-active {
@@ -621,7 +772,7 @@
/* Tag Input Component */
.tag-input-container {
@apply flex flex-wrap items-center gap-2 mt-1 w-full rounded-md bg-white dark:bg-zinc-700 border border-zinc-300 dark:border-zinc-600 p-2 min-h-[40px] focus-within:border-blue-500 focus-within:ring-1 focus-within:ring-blue-500;
@apply flex flex-wrap items-center gap-2 mt-1 w-full rounded-md bg-white dark:bg-zinc-700 border border-zinc-300 dark:border-zinc-600 p-2 min-h-10 focus-within:border-blue-500 focus-within:ring-1 focus-within:ring-blue-500;
}
.tag-item {
@apply flex items-center gap-x-1.5 bg-blue-100 dark:bg-blue-500/20 text-blue-800 dark:text-blue-200 text-sm font-medium rounded-full px-2.5 py-0.5;
@@ -631,7 +782,7 @@
}
.tag-input-new {
/* vertically center it in the container; feels nicer */
@apply flex-grow bg-transparent focus:outline-none text-sm self-center;
@apply grow bg-transparent focus:outline-none text-sm self-center;
}
/* Base styles for the copy button */
@@ -702,7 +853,7 @@
}
/* .tooltip-text is now dynamically generated by JS */
.global-tooltip {
@apply fixed z-[9999] w-max max-w-xs whitespace-normal rounded-lg bg-zinc-800 px-3 py-2 text-sm font-medium text-white shadow-lg transition-opacity duration-200;
@apply fixed z-9999 w-max max-w-xs whitespace-normal rounded-lg bg-zinc-800 px-3 py-2 text-sm font-medium text-white shadow-lg transition-opacity duration-200;
}
}
@@ -752,14 +903,18 @@
@apply text-4xl;
}
/* =================================================================== */
/* Custom table styles */
/* =================================================================== */
@layer components {
/* --- [New] Reusable table component styles --- */
/* --- Reusable table component styles --- */
.table {
@apply w-full caption-bottom text-sm;
}
.table-header {
/* Semantic colors; adapts to dark mode automatically */
@apply sticky top-0 z-10 border-b border-border bg-muted/50;
@apply sticky top-0 z-10 border-b border-border bg-zinc-200 dark:bg-zinc-900;
}
.table-header .table-row {
/* A header row's hover effect usually differs from data rows, or is absent */
@@ -769,7 +924,7 @@
@apply [&_tr:last-child]:border-0;
}
.table-row {
@apply border-b border-border transition-colors hover:bg-muted/80;
@apply border-b border-border transition-colors hover:bg-muted;
}
.table-head-cell {
@apply h-12 px-4 text-left align-middle font-medium text-muted-foreground;
@@ -777,4 +932,5 @@
.table-cell {
@apply p-4 align-middle;
}
}
}
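
The theme above only styles the calendar; for the range classes (.startRange, .inRange, .endRange) to appear, flatpickr has to be initialized in range mode. A minimal sketch — the import path and input selector are hypothetical:

// Wire up a styled range picker (e.g. for a logs date filter)
import flatpickr from './vendor/flatpickr.esm.js'; // hypothetical vendored build

flatpickr('#log-date-range', {
  mode: 'range',     // yields the startRange/inRange/endRange day states
  dateFormat: 'Y-m-d',
});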

View File

@@ -0,0 +1,129 @@
// Filename: frontend/js/components/customSelectV2.js
import { createPopper } from '../vendor/popper.esm.min.js';
export default class CustomSelectV2 {
constructor(container) {
this.container = container;
this.trigger = this.container.querySelector('.custom-select-trigger');
this.nativeSelect = this.container.querySelector('select');
this.template = this.container.querySelector('.custom-select-panel-template');
if (!this.trigger || !this.nativeSelect || !this.template) {
console.warn('CustomSelectV2 cannot initialize: missing required elements.', this.container);
return;
}
this.panel = null;
this.popperInstance = null;
this.isOpen = false;
this.triggerText = this.trigger.querySelector('span');
if (typeof CustomSelectV2.openInstance === 'undefined') {
CustomSelectV2.openInstance = null;
CustomSelectV2.initGlobalListener();
}
this.updateTriggerText();
this.bindEvents();
}
static initGlobalListener() {
document.addEventListener('click', (event) => {
const instance = CustomSelectV2.openInstance;
if (instance && !instance.container.contains(event.target) && (!instance.panel || !instance.panel.contains(event.target))) {
instance.close();
}
});
}
createPanel() {
const panelFragment = this.template.content.cloneNode(true);
this.panel = panelFragment.querySelector('.custom-select-panel');
document.body.appendChild(this.panel);
this.panel.innerHTML = '';
Array.from(this.nativeSelect.options).forEach(option => {
const item = document.createElement('a');
item.href = '#';
item.className = 'custom-select-option block w-full text-left px-3 py-1.5 text-sm text-zinc-700 hover:bg-zinc-100 dark:text-zinc-200 dark:hover:bg-zinc-700';
item.textContent = option.textContent;
item.dataset.value = option.value;
if (option.selected) { item.classList.add('is-selected'); }
this.panel.appendChild(item);
});
this.panel.addEventListener('click', (event) => {
event.preventDefault();
const optionEl = event.target.closest('.custom-select-option');
if (optionEl) { this.selectOption(optionEl); }
});
}
bindEvents() {
this.trigger.addEventListener('click', (event) => {
event.stopPropagation();
if (CustomSelectV2.openInstance && CustomSelectV2.openInstance !== this) {
CustomSelectV2.openInstance.close();
}
this.toggle();
});
}
selectOption(optionEl) {
const selectedValue = optionEl.dataset.value;
if (this.nativeSelect.value !== selectedValue) {
this.nativeSelect.value = selectedValue;
this.nativeSelect.dispatchEvent(new Event('change', { bubbles: true }));
}
this.updateTriggerText();
this.close();
}
updateTriggerText() {
const selectedOption = this.nativeSelect.options[this.nativeSelect.selectedIndex];
if (selectedOption) {
this.triggerText.textContent = selectedOption.textContent;
}
}
toggle() { this.isOpen ? this.close() : this.open(); }
open() {
if (this.isOpen) return;
this.isOpen = true;
if (!this.panel) { this.createPanel(); }
this.panel.style.display = 'block';
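// Reading offsetHeight forces a synchronous reflow, so the panel has final layout before Popper measures it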
this.panel.offsetHeight;
this.popperInstance = createPopper(this.trigger, this.panel, {
placement: 'top-start',
modifiers: [
{ name: 'offset', options: { offset: [0, 8] } },
{ name: 'flip', options: { fallbackPlacements: ['bottom-start'] } }
],
});
CustomSelectV2.openInstance = this;
}
close() {
if (!this.isOpen) return;
this.isOpen = false;
if (this.popperInstance) {
this.popperInstance.destroy();
this.popperInstance = null;
}
if (this.panel) {
this.panel.remove();
this.panel = null;
}
if (CustomSelectV2.openInstance === this) {
CustomSelectV2.openInstance = null;
}
}
}
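
A usage sketch, assuming markup that provides the three elements the constructor queries for (a .custom-select-trigger with an inner <span>, a native <select>, and a .custom-select-panel-template <template> containing a .custom-select-panel); the container class name is hypothetical:

// Enhance every custom select on the page
import CustomSelectV2 from './components/customSelectV2.js';

document.querySelectorAll('.custom-select-container')
  .forEach(el => new CustomSelectV2(el));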

View File

@@ -0,0 +1,95 @@
// Filename: frontend/js/components/filterPopover.js
import { createPopper } from '../vendor/popper.esm.min.js';
export default class FilterPopover {
constructor(triggerElement, options, title) {
if (!triggerElement || typeof createPopper !== 'function') {
console.error('FilterPopover: Trigger element or Popper.js not found.');
return;
}
this.triggerElement = triggerElement;
this.options = options;
this.title = title;
this.selectedValues = new Set();
this._createPopoverHTML();
this.popperInstance = createPopper(this.triggerElement, this.popoverElement, {
placement: 'bottom-start',
modifiers: [{ name: 'offset', options: { offset: [0, 8] } }],
});
this._bindEvents();
}
_createPopoverHTML() {
this.popoverElement = document.createElement('div');
this.popoverElement.className = 'hidden z-50 min-w-[12rem] rounded-md border-1 border-zinc-500/30 bg-popover bg-white dark:bg-zinc-900 p-2 text-popover-foreground shadow-md';
this.popoverElement.innerHTML = `
<div class="px-2 py-1.5 text-sm font-semibold">${this.title}</div>
<div class="space-y-1 p-1">
${this.options.map(option => `
<label class="flex items-center space-x-2 px-2 py-1.5 rounded-md hover:bg-accent cursor-pointer">
<input type="checkbox" value="${option.value}" class="h-4 w-4 rounded border-zinc-300 text-blue-600 focus:ring-blue-500">
<span class="text-sm">${option.label}</span>
</label>
`).join('')}
</div>
<div class="border-t border-border mt-2 pt-2 px-2 flex justify-end space-x-2">
<button data-action="clear" class="btn btn-ghost h-7 px-2 text-xs">清空</button>
<button data-action="apply" class="btn btn-primary h-7 px-2 text-xs">应用</button>
</div>
`;
document.body.appendChild(this.popoverElement);
}
_bindEvents() {
this.triggerElement.addEventListener('click', () => this.toggle());
document.addEventListener('click', (event) => {
if (!this.popoverElement.contains(event.target) && !this.triggerElement.contains(event.target)) {
this.hide();
}
});
this.popoverElement.addEventListener('click', (event) => {
const target = event.target.closest('button');
if (!target) return;
const action = target.dataset.action;
if (action === 'clear') this._handleClear();
if (action === 'apply') this._handleApply();
});
}
_handleClear() {
this.popoverElement.querySelectorAll('input[type="checkbox"]').forEach(cb => cb.checked = false);
this.selectedValues.clear();
this._handleApply();
}
_handleApply() {
this.selectedValues.clear();
this.popoverElement.querySelectorAll('input:checked').forEach(cb => {
this.selectedValues.add(cb.value);
});
const filterChangeEvent = new CustomEvent('filter-change', {
detail: {
filterKey: this.triggerElement.id,
selected: this.selectedValues
}
});
this.triggerElement.dispatchEvent(filterChangeEvent);
this.hide();
}
toggle() {
this.popoverElement.classList.toggle('hidden');
this.popperInstance.update();
}
hide() {
this.popoverElement.classList.add('hidden');
}
}
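
A usage sketch: the component dispatches a 'filter-change' CustomEvent on its trigger, with the trigger's id as detail.filterKey and the checked values as a Set in detail.selected. The trigger id and options below are hypothetical:

import FilterPopover from './components/filterPopover.js';

const trigger = document.getElementById('status-code-filter');
new FilterPopover(trigger, [
  { value: '200', label: '200 OK' },
  { value: '500', label: '500 Server Error' },
], '状态码');

trigger.addEventListener('filter-change', (e) => {
  console.log(e.detail.filterKey, [...e.detail.selected]);
});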

View File

@@ -326,6 +326,41 @@ class UIPatterns {
});
}
}
/**
* Sets a button to a loading state by disabling it and showing a spinner.
* It stores the button's original content to be restored later.
* @param {HTMLButtonElement} button - The button element to modify.
*/
setButtonLoading(button) {
if (!button) return;
// Store original content if it hasn't been stored already
if (!button.dataset.originalContent) {
button.dataset.originalContent = button.innerHTML;
}
button.disabled = true;
button.innerHTML = '<i class="fas fa-spinner fa-spin"></i>';
}
/**
* Restores a button from its loading state to its original content and enables it.
* @param {HTMLButtonElement} button - The button element to restore.
*/
clearButtonLoading(button) {
if (!button) return;
if (button.dataset.originalContent) {
button.innerHTML = button.dataset.originalContent;
// Clean up the data attribute
delete button.dataset.originalContent;
}
button.disabled = false;
}
/**
* Returns the HTML for a streaming text cursor animation.
* This is used as a placeholder in the chat UI while waiting for an assistant's response.
* @returns {string} The HTML string for the loader.
*/
renderStreamingLoader() {
return '<span class="streaming-cursor animate-pulse">▋</span>';
}
}
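
A usage sketch pairing the two loading helpers around an async call, so the button is always restored even on failure; the uiPatterns singleton and the endpoint are assumptions:

async function onSaveClick(button) {
  uiPatterns.setButtonLoading(button);     // disables the button and swaps in a spinner
  try {
    await apiFetch('/admin/settings', { method: 'POST' }); // hypothetical endpoint
  } finally {
    uiPatterns.clearButtonLoading(button); // restores the original innerHTML and re-enables
  }
}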
/**

View File

@@ -1,5 +1,7 @@
// Filename: frontend/js/main.js
import Swal from './vendor/sweetalert2.esm.js';
import './vendor/sweetalert2.min.css';
import anime from './vendor/anime.esm.js';
// === 1. Import shared components (potentially used by every page, so keep them as static imports) ===
import SlidingTabs from './components/slidingTabs.js';
import CustomSelect from './components/customSelect.js';
@@ -14,6 +16,7 @@ const pageModules = {
'dashboard': () => import('./pages/dashboard.js'),
'keys': () => import('./pages/keys/index.js'),
'logs': () => import('./pages/logs/index.js'),
'chat': () => import('./pages/chat/index.js'),
// 'settings': () => import('./pages/settings.js'), // enable the settings page in the future
// New pages only need one mapping line here; esbuild handles the rest automatically
};
@@ -48,3 +51,5 @@ window.modalManager = modalManager;
window.taskCenterManager = taskCenterManager;
window.toastManager = toastManager;
window.uiPatterns = uiPatterns;
window.Swal = Swal;
window.anime = anime;
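
A sketch of how the pageModules map is presumably consumed — the dispatch code sits outside this hunk, but chat/index.js below confirms that pages key off a [data-page-id] attribute and export a default init function:

const pageEl = document.querySelector('[data-page-id]');
const loadPage = pageEl && pageModules[pageEl.dataset.pageId];
if (loadPage) {
  loadPage().then(mod => mod.default());
}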

View File

@@ -0,0 +1,178 @@
// Filename: frontend/js/pages/chat/SessionManager.js
import { nanoid } from '../../vendor/nanoid.js';
const LOCAL_STORAGE_KEY = 'gemini_chat_state';
/**
* Manages the state and persistence of chat sessions.
* This class handles loading from/saving to localStorage,
* and all operations like creating, switching, and deleting sessions.
*/
export class SessionManager {
constructor() {
this.state = null;
}
/**
* Initializes the manager by loading state from localStorage or creating a default state.
*/
init() {
this._loadState();
}
// --- Public API for state access ---
getSessions() {
return this.state.sessions;
}
getCurrentSessionId() {
return this.state.currentSessionId;
}
getCurrentSession() {
return this.state.sessions.find(s => s.id === this.state.currentSessionId);
}
// --- Public API for state mutation ---
/**
* Creates a new, empty session and sets it as the current one.
*/
createSession() {
const newSessionId = nanoid();
const newSession = {
id: newSessionId,
name: '新会话',
systemPrompt: '',
messages: [],
modelConfig: { model: 'gemini-2.0-flash-lite' },
params: { temperature: 0.7 }
};
this.state.sessions.unshift(newSession);
this.state.currentSessionId = newSessionId;
this._saveState();
}
/**
* Switches the current session to the one with the given ID.
* @param {string} sessionId The ID of the session to switch to.
*/
switchSession(sessionId) {
if (this.state.currentSessionId === sessionId) return;
this.state.currentSessionId = sessionId;
this._saveState();
}
/**
* Deletes a session by its ID.
* @param {string} sessionId The ID of the session to delete.
*/
deleteSession(sessionId) {
this.state.sessions = this.state.sessions.filter(s => s.id !== sessionId);
if (this.state.currentSessionId === sessionId) {
this.state.currentSessionId = this.state.sessions[0]?.id || null;
if (!this.state.currentSessionId) {
this._createInitialState(); // Create a new one if all are deleted
}
}
this._saveState();
}
/**
* [NEW] Clears all messages from the currently active session.
*/
clearCurrentSession() {
const currentSession = this.getCurrentSession();
if (currentSession) {
currentSession.messages = [];
this._saveState();
}
}
/**
* Adds a message to the current session and updates the session name if it's the first message.
* @param {object} message The message object to add.
* @returns {object} The session that was updated.
*/
addMessage(message) {
const currentSession = this.getCurrentSession();
if (currentSession) {
if (currentSession.messages.length === 0 && message.role === 'user') {
currentSession.name = message.content.substring(0, 30);
}
currentSession.messages.push(message);
this._saveState();
return currentSession;
}
return null;
}
deleteMessage(messageId) {
const currentSession = this.getCurrentSession();
if (currentSession) {
const messageIndex = currentSession.messages.findIndex(m => m.id === messageId);
if (messageIndex > -1) {
currentSession.messages.splice(messageIndex, 1);
this._saveState();
console.log(`Message ${messageId} deleted.`);
}
}
}
truncateMessagesAfter(messageId) {
const currentSession = this.getCurrentSession();
if (currentSession) {
const messageIndex = currentSession.messages.findIndex(m => m.id === messageId);
// Ensure the message exists and it's not already the last one
if (messageIndex > -1 && messageIndex < currentSession.messages.length - 1) {
currentSession.messages.splice(messageIndex + 1);
this._saveState();
console.log(`Truncated messages after ${messageId}.`);
}
}
}
// --- Private persistence methods ---
_saveState() {
try {
localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(this.state));
} catch (error) {
console.error("Failed to save session state:", error);
}
}
_loadState() {
try {
const stateString = localStorage.getItem(LOCAL_STORAGE_KEY);
if (stateString) {
this.state = JSON.parse(stateString);
} else {
this._createInitialState();
}
} catch (error) {
console.error("Failed to load session state, creating initial state:", error);
this._createInitialState();
}
}
_createInitialState() {
const initialSessionId = nanoid();
this.state = {
sessions: [{
id: initialSessionId,
name: '新会话',
systemPrompt: '',
messages: [],
modelConfig: { model: 'gemini-2.0-flash-lite' },
params: { temperature: 0.7 }
}],
currentSessionId: initialSessionId,
settings: {}
};
this._saveState();
}
}
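
A usage sketch of the manager's lifecycle; note that the first user message becomes the session name (truncated to 30 characters):

const sessions = new SessionManager();
sessions.init();                 // load from localStorage or create a default session
sessions.addMessage({ id: 'm1', role: 'user', content: 'Hello there' });
console.log(sessions.getCurrentSession().name); // "Hello there"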

View File

@@ -0,0 +1,113 @@
// Filename: frontend/js/pages/chat/chatSettings.js
export class ChatSettings {
constructor(elements) {
// [MODIFIED] Store the root elements passed from ChatPage
this.elements = {};
this.elements.root = elements; // Keep a reference to all elements
// [MODIFIED] Query for specific elements this class controls, relative to their panels
this._initScopedDOMElements();
// Initialize panel states to ensure they are collapsed on load
this.elements.quickSettingsPanel.style.gridTemplateRows = '0fr';
this.elements.sessionParamsPanel.style.gridTemplateRows = '0fr';
}
// [NEW] A dedicated method to find elements within their specific panels
_initScopedDOMElements() {
this.elements.quickSettingsPanel = this.elements.root.quickSettingsPanel;
this.elements.sessionParamsPanel = this.elements.root.sessionParamsPanel;
this.elements.toggleQuickSettingsBtn = this.elements.root.toggleQuickSettingsBtn;
this.elements.toggleSessionParamsBtn = this.elements.root.toggleSessionParamsBtn;
// Query elements within the quick settings panel
this.elements.btnGroups = this.elements.quickSettingsPanel.querySelectorAll('.btn-group');
this.elements.directRoutingOptions = this.elements.quickSettingsPanel.querySelector('#direct-routing-options');
// Query elements within the session params panel
this.elements.temperatureSlider = this.elements.sessionParamsPanel.querySelector('#temperature-slider');
this.elements.temperatureValue = this.elements.sessionParamsPanel.querySelector('#temperature-value');
this.elements.contextSlider = this.elements.sessionParamsPanel.querySelector('#context-slider');
this.elements.contextValue = this.elements.sessionParamsPanel.querySelector('#context-value');
}
init() {
if (!this.elements.toggleQuickSettingsBtn) {
console.warn("ChatSettings: Aborting initialization, required elements not found.");
return;
}
this._initPanelToggleListeners();
this._initButtonGroupListeners();
this._initSliderListeners();
}
_initPanelToggleListeners() {
this.elements.toggleQuickSettingsBtn.addEventListener('click', () =>
this._togglePanel(this.elements.quickSettingsPanel, this.elements.toggleQuickSettingsBtn)
);
this.elements.toggleSessionParamsBtn.addEventListener('click', () =>
this._togglePanel(this.elements.sessionParamsPanel, this.elements.toggleSessionParamsBtn)
);
}
_initButtonGroupListeners() {
// [FIXED] This logic is now guaranteed to work with the correctly scoped elements.
this.elements.btnGroups.forEach(group => {
group.addEventListener('click', (e) => {
const button = e.target.closest('.btn-group-item');
if (!button) return;
group.querySelectorAll('.btn-group-item').forEach(btn => btn.removeAttribute('data-active'));
button.setAttribute('data-active', 'true');
if (button.dataset.group === 'routing-mode') {
this._handleRoutingModeChange(button.dataset.value);
}
});
});
}
_initSliderListeners() {
// [FIXED] Add null checks for robustness, now that elements are queried scoped.
if (this.elements.temperatureSlider) {
this.elements.temperatureSlider.addEventListener('input', () => {
this.elements.temperatureValue.textContent = parseFloat(this.elements.temperatureSlider.value).toFixed(1);
});
}
if (this.elements.contextSlider) {
this.elements.contextSlider.addEventListener('input', () => {
this.elements.contextValue.textContent = `${this.elements.contextSlider.value}k`;
});
}
}
_handleRoutingModeChange(selectedValue) {
// [FIXED] This logic now correctly targets the scoped element.
if (this.elements.directRoutingOptions) {
if (selectedValue === 'direct') {
this.elements.directRoutingOptions.classList.remove('hidden');
} else {
this.elements.directRoutingOptions.classList.add('hidden');
}
}
}
_togglePanel(panel, button) {
const isExpanded = panel.hasAttribute('data-expanded');
this.elements.quickSettingsPanel.removeAttribute('data-expanded');
this.elements.sessionParamsPanel.removeAttribute('data-expanded');
this.elements.toggleQuickSettingsBtn.removeAttribute('data-active');
this.elements.toggleSessionParamsBtn.removeAttribute('data-active');
this.elements.quickSettingsPanel.style.gridTemplateRows = '0fr';
this.elements.sessionParamsPanel.style.gridTemplateRows = '0fr';
if (!isExpanded) {
panel.setAttribute('data-expanded', 'true');
button.setAttribute('data-active', 'true');
panel.style.gridTemplateRows = '1fr';
}
}
}
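
_togglePanel relies on the CSS grid collapse trick: the panel is a one-row grid whose row track animates between '0fr' and '1fr'. A sketch of the moving parts, assuming the panel's inner wrapper has min-height: 0 and overflow: hidden so the 0fr state actually clips its content:

function setPanelExpanded(panel, expanded) {
  panel.style.display = 'grid';                            // one implicit row
  panel.style.transition = 'grid-template-rows 300ms ease';
  panel.style.gridTemplateRows = expanded ? '1fr' : '0fr'; // fr tracks interpolate smoothly
}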

View File

@@ -0,0 +1,545 @@
// Filename: frontend/js/pages/chat/index.js
import { nanoid } from '../../vendor/nanoid.js';
import { uiPatterns } from '../../components/ui.js';
import { apiFetch } from '../../services/api.js';
import { marked } from '../../vendor/marked.min.js';
import { SessionManager } from './SessionManager.js';
import { ChatSettings } from './chatSettings.js';
import CustomSelectV2 from '../../components/customSelectV2.js';
marked.use({ breaks: true, gfm: true });
function getCookie(name) {
let cookieValue = null;
if (document.cookie && document.cookie !== '') {
const cookies = document.cookie.split(';');
for (let i = 0; i < cookies.length; i++) {
const cookie = cookies[i].trim();
if (cookie.substring(0, name.length + 1) === (name + '=')) {
cookieValue = decodeURIComponent(cookie.substring(name.length + 1));
break;
}
}
}
return cookieValue;
}
class ChatPage {
constructor() {
this.sessionManager = new SessionManager();
this.isStreaming = false;
this.elements = {};
this.initialized = false;
this.searchTerm = '';
this.settingsManager = null;
}
init() {
if (!document.querySelector('[data-page-id="chat"]')) { return; }
this.sessionManager.init();
this.initialized = true;
this._initDOMElements();
this._initComponents();
this._initEventListeners();
this._render();
console.log("ChatPage initialized. Session management is delegated.", this.sessionManager.state);
}
_initDOMElements() {
this.elements.chatScrollContainer = document.getElementById('chat-scroll-container');
this.elements.chatMessagesContainer = document.getElementById('chat-messages-container');
this.elements.messageForm = document.getElementById('message-form');
this.elements.messageInput = document.getElementById('message-input');
this.elements.sendBtn = document.getElementById('send-btn');
this.elements.newSessionBtn = document.getElementById('new-session-btn');
this.elements.sessionListContainer = document.getElementById('session-list-container');
this.elements.chatHeaderTitle = document.querySelector('.chat-header-title');
this.elements.clearSessionBtn = document.getElementById('clear-session-btn');
this.elements.sessionSearchInput = document.getElementById('session-search-input');
this.elements.toggleQuickSettingsBtn = document.getElementById('toggle-quick-settings');
this.elements.toggleSessionParamsBtn = document.getElementById('toggle-session-params');
this.elements.quickSettingsPanel = document.getElementById('quick-settings-panel');
this.elements.sessionParamsPanel = document.getElementById('session-params-panel');
this.elements.directRoutingOptions = document.getElementById('direct-routing-options');
this.elements.btnGroups = document.querySelectorAll('.btn-group');
this.elements.temperatureSlider = document.getElementById('temperature-slider');
this.elements.temperatureValue = document.getElementById('temperature-value');
this.elements.contextSlider = document.getElementById('context-slider');
this.elements.contextValue = document.getElementById('context-value');
this.elements.groupSelectContainer = document.getElementById('group-select-container');
}
// [NEW] A dedicated method for initializing complex UI components
_initComponents() {
if (this.elements.groupSelectContainer) {
new CustomSelectV2(this.elements.groupSelectContainer);
}
// In the future, we will initialize the model select component here as well
}
_initEventListeners() {
// --- Initialize Settings Manager First ---
this.settingsManager = new ChatSettings(this.elements);
this.settingsManager.init();
// --- Core Chat Event Listeners ---
this.elements.messageForm.addEventListener('submit', (e) => { e.preventDefault(); this._handleSendMessage(); });
this.elements.messageInput.addEventListener('keydown', (e) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); this._handleSendMessage(); } });
this.elements.messageInput.addEventListener('input', () => this._autoResizeTextarea());
this.elements.newSessionBtn.addEventListener('click', () => {
this.sessionManager.createSession();
this._render();
this.elements.messageInput.focus();
});
this.elements.sessionListContainer.addEventListener('click', (e) => {
const sessionItem = e.target.closest('[data-session-id]');
const deleteBtn = e.target.closest('.delete-session-btn');
if (deleteBtn) {
e.preventDefault();
const sessionId = deleteBtn.closest('[data-session-id]').dataset.sessionId;
this._handleDeleteSession(sessionId);
} else if (sessionItem) {
e.preventDefault();
const sessionId = sessionItem.dataset.sessionId;
this.sessionManager.switchSession(sessionId);
this._render();
this.elements.messageInput.focus();
}
});
this.elements.clearSessionBtn.addEventListener('click', () => this._handleClearSession());
this.elements.sessionSearchInput.addEventListener('input', (e) => {
this.searchTerm = e.target.value.trim();
this._renderSessionList();
});
this.elements.chatMessagesContainer.addEventListener('click', (e) => {
const messageElement = e.target.closest('[data-message-id]');
if (!messageElement) return;
const messageId = messageElement.dataset.messageId;
const copyBtn = e.target.closest('.action-copy');
const deleteBtn = e.target.closest('.action-delete');
const retryBtn = e.target.closest('.action-retry');
if (copyBtn) {
this._handleCopyMessage(messageId);
} else if (deleteBtn) {
this._handleDeleteMessage(messageId, e.target);
} else if (retryBtn) {
this._handleRetryMessage(messageId);
}
});
}
_handleCopyMessage(messageId) {
const currentSession = this.sessionManager.getCurrentSession();
if (!currentSession) return;
const message = currentSession.messages.find(m => m.id === messageId);
if (!message || !message.content) {
console.error("Message content not found for copying.");
return;
}
// Handle cases where content might be HTML (like error messages)
// by stripping tags to get plain text.
let textToCopy = message.content;
if (textToCopy.includes('<') && textToCopy.includes('>')) {
const tempDiv = document.createElement('div');
tempDiv.innerHTML = textToCopy;
textToCopy = tempDiv.textContent || tempDiv.innerText || '';
}
navigator.clipboard.writeText(textToCopy)
.then(() => {
Swal.fire({
toast: true,
position: 'top-end',
icon: 'success',
title: '已复制',
showConfirmButton: false,
timer: 1500,
customClass: {
popup: `${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}`
}
});
})
.catch(err => {
console.error('Failed to copy text: ', err);
Swal.fire({
toast: true,
position: 'top-end',
icon: 'error',
title: '复制失败',
showConfirmButton: false,
timer: 1500,
customClass: {
popup: `${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}`
}
});
});
}
_handleDeleteMessage(messageId, targetElement) {
// Remove any existing popover first to prevent duplicates
const existingPopover = document.getElementById('delete-confirmation-popover');
if (existingPopover) {
existingPopover.remove();
}
// Create the popover element with your specified dimensions.
const popover = document.createElement('div');
popover.id = 'delete-confirmation-popover';
// [MODIFIED] - Using your w-36, and adding flexbox classes for centering.
popover.className = 'absolute z-50 p-3 w-45 border border-border rounded-md shadow-lg bg-background text-popover-foreground flex flex-col items-center';
// [MODIFIED] - Added an icon and classes for horizontal centering.
popover.innerHTML = `
<div class="flex items-center gap-2 mb-2">
<i class="fas fa-exclamation-circle text-yellow-500"></i>
<p class="text-sm">确认删除此消息吗?</p>
</div>
<div class="flex translate-x-12 gap-2 w-full">
<button class="btn btn-secondary rounded-xs w-12 btn-xs popover-cancel">取消</button>
<button class="btn btn-destructive rounded-xs w-12 btn-xs popover-confirm">确认</button>
</div>
`;
document.body.appendChild(popover);
// Position the popover above the clicked icon
const iconRect = targetElement.closest('button').getBoundingClientRect();
const popoverRect = popover.getBoundingClientRect();
popover.style.top = `${window.scrollY + iconRect.top - popoverRect.height - 8}px`;
popover.style.left = `${window.scrollX + iconRect.left + (iconRect.width / 2) - (popoverRect.width / 2)}px`;
// Event listener to close the popover if clicked outside
const outsideClickListener = (event) => {
if (!popover.contains(event.target) && event.target !== targetElement) {
popover.remove();
document.removeEventListener('click', outsideClickListener);
}
};
setTimeout(() => document.addEventListener('click', outsideClickListener), 0);
// Event listeners for the buttons inside the popover
popover.querySelector('.popover-confirm').addEventListener('click', () => {
this.sessionManager.deleteMessage(messageId);
this._renderChatMessages();
this._renderSessionList();
popover.remove();
document.removeEventListener('click', outsideClickListener);
});
popover.querySelector('.popover-cancel').addEventListener('click', () => {
popover.remove();
document.removeEventListener('click', outsideClickListener);
});
}
_handleRetryMessage(messageId) {
if (this.isStreaming) return; // Prevent retrying while a response is already generating
const currentSession = this.sessionManager.getCurrentSession();
if (!currentSession) return;
const message = currentSession.messages.find(m => m.id === messageId);
if (!message) return;
if (message.role === 'user') {
// Logic for retrying from a user's prompt
this.sessionManager.truncateMessagesAfter(messageId);
} else if (message.role === 'assistant') {
// Logic for regenerating an assistant's response (must be the last one)
this.sessionManager.deleteMessage(messageId);
}
// After data manipulation, update the UI and trigger a new response
this._renderChatMessages();
this._renderSessionList();
this._getAssistantResponse();
}
_autoResizeTextarea() {
const el = this.elements.messageInput;
el.style.height = 'auto';
el.style.height = (el.scrollHeight) + 'px';
}
_handleSendMessage() {
if (this.isStreaming) return;
const content = this.elements.messageInput.value.trim();
if (!content) return;
const userMessage = { id: nanoid(), role: 'user', content: content };
this.sessionManager.addMessage(userMessage);
this._renderChatMessages();
this._renderSessionList();
this.elements.messageInput.value = '';
this._autoResizeTextarea();
this.elements.messageInput.focus();
this._getAssistantResponse();
}
async _getAssistantResponse() {
this.isStreaming = true;
this._setLoadingState(true);
const currentSession = this.sessionManager.getCurrentSession();
const assistantMessageId = nanoid();
let finalAssistantMessage = { id: assistantMessageId, role: 'assistant', content: '' };
// Step 1: Create and render a temporary UI placeholder for streaming.
// [MODIFIED] The placeholder now uses the three-dot animation.
const placeholderHtml = `
<div class="flex items-start gap-4" data-message-id="${assistantMessageId}">
<span class="flex h-8 w-8 shrink-0 items-center justify-center rounded-full bg-primary text-primary-foreground">
<i class="fas fa-robot"></i>
</span>
<div class="flex-1 space-y-2">
<div class="relative group rounded-lg p-5 bg-primary/10 border/20">
<div class="prose prose-sm max-w-none text-foreground message-content">
<div class="flex items-center gap-1">
<span class="h-2 w-2 bg-foreground/50 rounded-full animate-bounce" style="animation-delay: 0s;"></span>
<span class="h-2 w-2 bg-foreground/50 rounded-full animate-bounce" style="animation-delay: 0.1s;"></span>
<span class="h-2 w-2 bg-foreground/50 rounded-full animate-bounce" style="animation-delay: 0.2s;"></span>
</div>
</div>
</div>
</div>
</div>`;
this.elements.chatMessagesContainer.insertAdjacentHTML('beforeend', placeholderHtml);
this._scrollToBottom();
const assistantMessageContentEl = this.elements.chatMessagesContainer.querySelector(`[data-message-id="${assistantMessageId}"] .message-content`);
try {
const token = getCookie('gemini_admin_session');
const headers = { 'Authorization': `Bearer ${token}` };
const response = await apiFetch('/v1/chat/completions', {
method: 'POST',
headers,
body: JSON.stringify({
model: currentSession.modelConfig.model,
messages: currentSession.messages.filter(m => m.content).map(({ role, content }) => ({ role, content })),
stream: true,
})
});
if (!response.body) throw new Error("Response body is null.");
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
const { value, done } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n').filter(line => line.trim().startsWith('data:'));
for (const line of lines) {
const dataStr = line.replace(/^data: /, '').trim();
if (dataStr !== '[DONE]') {
try {
const data = JSON.parse(dataStr);
const deltaContent = data.choices[0]?.delta?.content;
if (deltaContent) {
finalAssistantMessage.content += deltaContent;
assistantMessageContentEl.innerHTML = marked.parse(finalAssistantMessage.content);
this._scrollToBottom();
}
} catch (e) { /* ignore malformed JSON */ }
}
}
}
} catch (error) {
console.error('Fetch stream error:', error);
const errorMessage = error.rawMessageFromServer || error.message;
finalAssistantMessage.content = `<span class="text-red-500">请求失败: ${errorMessage}</span>`;
} finally {
this.sessionManager.addMessage(finalAssistantMessage);
this._renderChatMessages();
this._renderSessionList();
this.isStreaming = false;
this._setLoadingState(false);
this.elements.messageInput.focus();
}
}
_renderMessage(message, replace = false, isLastMessage = false) {
let contentHtml;
if (message.role === 'user') {
const escapedContent = message.content.replace(/</g, "&lt;").replace(/>/g, "&gt;");
contentHtml = `<p class="text-sm text-foreground message-content">${escapedContent.replace(/\n/g, '<br>')}</p>`;
} else {
// [FIXED] Simplified logic: if it's an assistant message, it either has real content or error HTML.
// The isStreamingPlaceholder case is now handled differently and removed from here.
const isErrorHtml = message.content && message.content.includes('<span class="text-red-500">');
contentHtml = isErrorHtml ?
`<div class="message-content">${message.content}</div>` :
`<div class="prose prose-sm max-w-none text-foreground message-content">${marked.parse(message.content || '')}</div>`;
}
// [FIXED] No special handling for streaming placeholders needed anymore.
// If a message has content, it gets actions. An error message has content, so it will get actions.
let actionsHtml = '';
let retryButton = '';
if (message.role === 'user') {
retryButton = `
<button class="btn btn-ghost btn-icon w-6 h-6 hover:text-sky-500 action-retry rounded-full" title="从此消息重新生成">
<i class="fas fa-redo text-xs"></i>
</button>`;
} else if (message.role === 'assistant' && isLastMessage) {
// This now correctly applies to final error messages too.
retryButton = `
<button class="btn btn-ghost btn-icon w-6 h-6 hover:text-sky-500 action-retry rounded-full" title="重新生成回答">
<i class="fas fa-redo text-xs"></i>
</button>`;
}
const toolbarBaseClasses = "message-actions flex items-center gap-1 transition-opacity duration-200";
const toolbarPositionClass = isLastMessage ? "mt-2" : "absolute bottom-2.5 right-2.5";
const visibilityClass = isLastMessage ? "" : "opacity-0 group-hover:opacity-100";
actionsHtml = `
<div class="${toolbarBaseClasses} ${toolbarPositionClass} ${visibilityClass}">
${retryButton}
<button class="btn btn-ghost btn-icon w-6 h-6 hover:text-sky-500 action-copy rounded-full" title="复制">
<i class="far fa-copy text-xs"></i>
</button>
<button class="btn btn-ghost btn-icon w-6 h-6 hover:text-sky-500 action-delete rounded-full" title="删除">
<i class="far fa-trash-alt text-xs"></i>
</button>
</div>
`;
const messageBubbleClasses = `relative group rounded-lg p-5 ${message.role === 'user' ? 'bg-muted' : 'bg-primary/10 border/20'}`;
const messageHtml = `
<div class="flex items-start gap-4" data-message-id="${message.id}">
<span class="flex h-8 w-8 shrink-0 items-center justify-center rounded-full ${message.role === 'user' ? 'bg-secondary text-secondary-foreground' : 'bg-primary text-primary-foreground'}">
<i class="fas ${message.role === 'user' ? 'fa-user' : 'fa-robot'}"></i>
</span>
<div class="flex-1 space-y-2">
<div class="${messageBubbleClasses}">
${contentHtml}
${actionsHtml}
</div>
</div>
</div>`;
const existingElement = this.elements.chatMessagesContainer.querySelector(`[data-message-id="${message.id}"]`);
if (replace && existingElement) {
existingElement.outerHTML = messageHtml;
} else if (!existingElement) {
this.elements.chatMessagesContainer.insertAdjacentHTML('beforeend', messageHtml);
}
if (!replace) { this._scrollToBottom(); }
}
_scrollToBottom() {
this.elements.chatScrollContainer.scrollTop = this.elements.chatScrollContainer.scrollHeight;
}
_render() {
this._renderSessionList();
this._renderChatMessages();
this._renderChatHeader();
}
_renderSessionList() {
let sessions = this.sessionManager.getSessions();
const currentSessionId = this.sessionManager.getCurrentSessionId();
if (this.searchTerm) {
const lowerCaseSearchTerm = this.searchTerm.toLowerCase();
sessions = sessions.filter(session => {
const titleMatch = session.name.toLowerCase().includes(lowerCaseSearchTerm);
const messageMatch = session.messages.some(message =>
message.content && message.content.toLowerCase().includes(lowerCaseSearchTerm)
);
return titleMatch || messageMatch;
});
}
this.elements.sessionListContainer.innerHTML = sessions.map(session => {
const isActive = session.id === currentSessionId;
const lastMessage = session.messages.length > 0 ? session.messages[session.messages.length - 1].content : '新会话';
return `
<div class="relative group flex items-center" data-session-id="${session.id}">
<a href="#" class="grow flex flex-col items-start gap-2 rounded-lg p-3 text-left text-sm transition-all hover:bg-accent ${isActive ? 'bg-accent' : ''} min-w-0">
<div class="w-full font-semibold truncate pr-2">${session.name}</div>
<div class="w-full text-xs text-muted-foreground line-clamp-2">${session.messages.length > 0 ? (lastMessage.includes('<span class="text-red-500">') ? '[请求失败]' : lastMessage) : '新会话'}</div>
</a>
<button class="delete-session-btn absolute top-1/2 -translate-y-1/2 right-2 w-6 h-6 flex items-center justify-center rounded-full bg-muted text-muted-foreground opacity-0 group-hover:opacity-100 hover:bg-destructive/80 hover:text-destructive-foreground transition-all" aria-label="删除会话">
<i class="fas fa-times"></i>
</button>
</div>
`;
}).join('');
}
_renderChatMessages() {
this.elements.chatMessagesContainer.innerHTML = '';
const currentSession = this.sessionManager.getCurrentSession();
if (currentSession) {
const messages = currentSession.messages;
const lastMessageIndex = messages.length > 0 ? messages.length - 1 : -1;
messages.forEach((message, index) => {
const isLastMessage = (index === lastMessageIndex);
this._renderMessage(message, false, isLastMessage);
});
}
}
_renderChatHeader() {
const currentSession = this.sessionManager.getCurrentSession();
if (currentSession && this.elements.chatHeaderTitle) {
this.elements.chatHeaderTitle.textContent = currentSession.name;
}
}
_setLoadingState(isLoading) {
this.elements.messageInput.disabled = isLoading;
this.elements.sendBtn.disabled = isLoading;
if (isLoading) {
uiPatterns.setButtonLoading(this.elements.sendBtn);
} else {
uiPatterns.clearButtonLoading(this.elements.sendBtn);
}
}
_handleClearSession() {
Swal.fire({
width: '22rem',
backdrop: `rgba(0,0,0,0.5)`,
heightAuto: false,
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
title: '确定要清空会话吗?',
text: '当前会话的所有聊天记录将被删除,但会话本身会保留。',
showCancelButton: true,
confirmButtonText: '确认清空',
cancelButtonText: '取消',
reverseButtons: false,
confirmButtonColor: '#ef4444',
cancelButtonColor: '#6b7280',
focusConfirm: false,
focusCancel: true,
}).then((result) => {
if (result.isConfirmed) {
this.sessionManager.clearCurrentSession();
this._render();
}
});
}
_handleDeleteSession(sessionId) {
Swal.fire({
width: '22rem',
backdrop: `rgba(0,0,0,0.5)`,
heightAuto: false,
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
title: '确定要删除吗?',
text: '此会话的所有记录将被永久删除,无法撤销。',
showCancelButton: true,
confirmButtonText: '确认删除',
cancelButtonText: '取消',
reverseButtons: false,
confirmButtonColor: '#ef4444',
cancelButtonColor: '#6b7280',
focusConfirm: false,
focusCancel: true,
}).then((result) => {
if (result.isConfirmed) {
this.sessionManager.deleteSession(sessionId);
this._render();
}
});
}
}
export default function() {
const page = new ChatPage();
page.init();
}
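
One caveat in _getAssistantResponse: each decoded chunk is split on '\n' independently, so a data: line that straddles two reads can be dropped or mis-parsed. A chunk-boundary-safe sketch (not what the file currently does; handleDelta is a hypothetical callback):

async function readSseStream(response, handleDelta) {
  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let buffer = '';
  while (true) {
    const { value, done } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split('\n');
    buffer = lines.pop(); // keep the trailing partial line for the next read
    for (const line of lines) {
      if (!line.startsWith('data:')) continue;
      const dataStr = line.slice(5).trim();
      if (dataStr !== '[DONE]') handleDelta(JSON.parse(dataStr));
    }
  }
}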

View File

@@ -427,7 +427,7 @@ class ApiKeyList {
contentHtml = `
<div class="task-item-main">
<div class="task-item-icon-summary text-red-500"><i class="fas fa-exclamation-triangle"></i></div>
<div class="task-item-content flex-grow">
<div class="task-item-content grow">
<p class="task-item-title">验证任务出错: ${maskedKey}</p>
<p class="task-item-status text-red-500 truncate" title="${safeError}">${safeError}</p>
</div>
@@ -443,7 +443,7 @@ class ApiKeyList {
contentHtml = `
<div class="task-item-main">
<div class="task-item-icon-summary"><i class="${iconClass}"></i></div>
<div class="task-item-content flex-grow">
<div class="task-item-content grow">
<p class="task-item-title">${title}: ${maskedKey}</p>
<p class="task-item-status truncate" title="${safeMessage}">${safeMessage}</p>
</div>
@@ -455,7 +455,7 @@ class ApiKeyList {
contentHtml = `
<div class="task-item-main gap-3">
<div class="task-item-icon task-item-icon-running"><i class="fas fa-spinner animate-spin"></i></div>
<div class="task-item-content flex-grow">
<div class="task-item-content grow">
<p class="task-item-title">正在验证: ${maskedKey}</p>
<p class="task-item-status">运行中... (${data.processed}/${data.total})</p>
</div>
@@ -495,7 +495,7 @@ class ApiKeyList {
data-mapping-id="${mappingId}">
<input type="checkbox" class="api-key-checkbox h-4 w-4 rounded border-zinc-300 text-blue-600 focus:ring-blue-500 shrink-0">
<span data-status-indicator class="w-2 h-2 rounded-full shrink-0"></span>
<div class="flex-grow min-w-0">
<div class="grow min-w-0">
<p class="font-mono text-xs font-semibold truncate">${maskedKey}</p>
<p class="text-xs text-zinc-400 mt-1">失败: ${errorCount} 次</p>
</div>
@@ -854,7 +854,7 @@ class ApiKeyList {
contentHtml = `
<div class="task-item-main gap-3">
<div class="task-item-icon task-item-icon-running"><i class="fas fa-spinner animate-spin"></i></div>
<div class="task-item-content flex-grow">
<div class="task-item-content grow">
<p class="task-item-title">批量验证 ${data.total} 个Key</p>
<p class="task-item-status">运行中... (${data.processed}/${data.total})</p>
</div>
@@ -867,7 +867,7 @@ class ApiKeyList {
contentHtml = `
<div class="task-item-main">
<div class="task-item-icon-summary text-red-500"><i class="fas fa-exclamation-triangle"></i></div>
<div class="task-item-content flex-grow">
<div class="task-item-content grow">
<p class="task-item-title">批量验证任务出错</p>
<p class="task-item-status text-red-500 truncate" title="${data.error}">${data.error}</p>
</div>
@@ -893,7 +893,7 @@ class ApiKeyList {
return `
<div class="flex items-start text-xs">
<i class="fas fa-check-circle text-green-500 mt-0.5 mr-2"></i>
<div class="flex-grow">
<div class="grow">
<p class="font-mono">${maskedKey}</p>
<p class="text-zinc-400">${safeMessage}</p>
</div>
@@ -902,7 +902,7 @@ class ApiKeyList {
return `
<div class="flex items-start text-xs">
<i class="fas fa-times-circle text-red-500 mt-0.5 mr-2"></i>
<div class="flex-grow">
<div class="grow">
<p class="font-mono">${maskedKey}</p>
<p class="text-zinc-400">${safeMessage}</p>
</div>
@@ -913,7 +913,7 @@ class ApiKeyList {
contentHtml = `
<div class="task-item-main">
<div class="task-item-icon-summary"><i class="${overallIconClass}"></i></div>
<div class="task-item-content flex-grow">
<div class="task-item-content grow">
<div class="flex justify-between items-center cursor-pointer" data-task-toggle>
<p class="task-item-title">${summaryTitle}</p>
<i class="fas fa-chevron-down task-toggle-icon"></i>
@@ -1135,6 +1135,8 @@ class ApiKeyList {
const actionConfig = this._getQuickActionConfig(action);
if (actionConfig && actionConfig.requiresConfirm) {
const result = await Swal.fire({
backdrop: `rgba(0,0,0,0.5)`,
heightAuto: false,
target: '#main-content-wrapper',
title: '请确认操作',
html: actionConfig.confirmText,
@@ -1246,7 +1248,7 @@ class ApiKeyList {
contentHtml = `
<div class="task-item-main gap-3">
<div class="task-item-icon task-item-icon-running"><i class="fas fa-spinner animate-spin"></i></div>
<div class="task-item-content flex-grow">
<div class="task-item-content grow">
<p class="task-item-title">${title}</p>
<p class="task-item-status">运行中... (${data.processed}/${data.total})</p>
</div>
@@ -1256,7 +1258,7 @@ class ApiKeyList {
contentHtml = `
<div class="task-item-main">
<div class="task-item-icon-summary text-red-500"><i class="fas fa-exclamation-triangle"></i></div>
<div class="task-item-content flex-grow">
<div class="task-item-content grow">
<p class="task-item-title">${title}任务出错</p>
<p class="task-item-status text-red-500 truncate" title="${safeError}">${safeError}</p>
</div>
@@ -1293,7 +1295,7 @@ class ApiKeyList {
contentHtml = `
<div class="task-item-main">
<div class="task-item-icon-summary"><i class="${iconClass}"></i></div>
<div class="task-item-content flex-grow">
<div class="task-item-content grow">
<p class="task-item-title">${title}</p>
<p class="task-item-status truncate" title="${safeSummary}">${safeSummary}</p>
</div>

View File

@@ -0,0 +1,151 @@
// Filename: frontend/js/pages/logs/batchActions.js
import { apiFetchJson } from '../../services/api.js';
// Reference to the LogsPage instance
let logsPageInstance = null;
// References to the DOM elements
const elements = {
batchActionsBtn: null,
batchActionsMenu: null,
deleteSelectedLogsBtn: null,
clearAllLogsBtn: null,
};
// Bound handler that closes the menu on document clicks
const handleDocumentClick = (event) => {
if (!elements.batchActionsMenu.contains(event.target) && !elements.batchActionsBtn.contains(event.target)) {
closeBatchActionsMenu();
}
};
// Close-menu logic
function closeBatchActionsMenu() {
if (elements.batchActionsMenu && !elements.batchActionsMenu.classList.contains('hidden')) {
elements.batchActionsMenu.classList.remove('opacity-100', 'scale-100');
elements.batchActionsMenu.classList.add('opacity-0', 'scale-95');
setTimeout(() => {
elements.batchActionsMenu.classList.add('hidden');
}, 100);
document.removeEventListener('click', handleDocumentClick);
}
}
// Toggle the menu open/closed
function handleBatchActionsToggle(event) {
event.stopPropagation();
const isHidden = elements.batchActionsMenu.classList.contains('hidden');
if (isHidden) {
elements.batchActionsMenu.classList.remove('hidden', 'opacity-0', 'scale-95');
elements.batchActionsMenu.classList.add('opacity-100', 'scale-100');
document.addEventListener('click', handleDocumentClick);
} else {
closeBatchActionsMenu();
}
}
// Handle deleting the selected logs
async function handleDeleteSelectedLogs() {
closeBatchActionsMenu();
const selectedIds = Array.from(logsPageInstance.state.selectedLogIds);
if (selectedIds.length === 0) return;
Swal.fire({
width: '20rem',
backdrop: `rgba(0,0,0,0.5)`,
heightAuto: false,
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
title: '确认删除',
text: `您确定要删除选定的 ${selectedIds.length} 条日志吗?此操作不可撤销。`,
icon: 'warning',
showCancelButton: true,
confirmButtonText: '确认删除',
cancelButtonText: '取消',
reverseButtons: false,
confirmButtonColor: '#ef4444',
cancelButtonColor: '#6b7280',
focusCancel: true,
target: '#main-content-wrapper',
}).then(async (result) => {
if (result.isConfirmed) {
try {
const idsQueryString = selectedIds.join(',');
const url = `/admin/logs?ids=${idsQueryString}`;
const { success, message } = await apiFetchJson(url, { method: 'DELETE' });
if (success) {
Swal.fire({ toast: true, position: 'top-end', icon: 'success', title: '删除成功', showConfirmButton: false, timer: 2000, timerProgressBar: true });
logsPageInstance.loadAndRenderLogs(); // refresh the list via the instance
} else {
throw new Error(message || '删除失败,请稍后重试。');
}
} catch (error) {
Swal.fire({ icon: 'error', title: '操作失败', text: error.message, target: '#main-content-wrapper' });
}
}
});
}
// [NEW] Handle clearing all logs
async function handleClearAllLogs() {
closeBatchActionsMenu();
Swal.fire({
width: '20rem',
backdrop: `rgba(0,0,0,0.5)`,
heightAuto: false,
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
title: '危险操作确认',
html: `您确定要<strong class="text-red-500">清空全部</strong>日志吗?<br>此操作不可撤销!`,
icon: 'warning',
showCancelButton: true,
confirmButtonText: '确认清空',
cancelButtonText: '取消',
reverseButtons: false,
confirmButtonColor: '#ef4444',
cancelButtonColor: '#6b7280',
focusCancel: true,
target: '#main-content-wrapper',
}).then(async (result) => {
if (result.isConfirmed) {
try {
const url = `/admin/logs/all`; // backend endpoint that clears all logs
const { success, message } = await apiFetchJson(url, { method: 'DELETE' });
if (success) {
Swal.fire({ toast: true, position: 'top-end', icon: 'success', title: '清空成功', showConfirmButton: false, timer: 2000, timerProgressBar: true });
logsPageInstance.loadAndRenderLogs(); // refresh the list
} else {
throw new Error(message || '清空失败,请稍后重试。');
}
} catch (error) {
Swal.fire({ icon: 'error', title: '操作失败', text: error.message, target: '#main-content-wrapper' });
}
}
});
}
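// Hedged sketch of the API contract both handlers above assume (inferred from
// the calls in this file; the backend routes themselves are not in this diff):
//   DELETE /admin/logs?ids=1,2,3  ->  { success: boolean, message?: string }
//   DELETE /admin/logs/all        ->  { success: boolean, message?: string }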
/**
 * Initialize the batch actions module
 * @param {object} logsPage - Instance of the LogsPage class
 */
export function initBatchActions(logsPage) {
logsPageInstance = logsPage;
// Query all required DOM elements
elements.batchActionsBtn = document.getElementById('batch-actions-btn');
elements.batchActionsMenu = document.getElementById('batch-actions-menu');
elements.deleteSelectedLogsBtn = document.getElementById('delete-selected-logs-btn');
elements.clearAllLogsBtn = document.getElementById('clear-all-logs-btn'); // [NEW] query the new button
if (!elements.batchActionsBtn) return; // do nothing if the main button is missing
// Bind event listeners
elements.batchActionsBtn.addEventListener('click', handleBatchActionsToggle);
if (elements.deleteSelectedLogsBtn) {
elements.deleteSelectedLogsBtn.addEventListener('click', handleDeleteSelectedLogs);
}
if (elements.clearAllLogsBtn) { // [NEW] bind the new button's handler
elements.clearAllLogsBtn.addEventListener('click', handleClearAllLogs);
}
}
// [NEW] - END
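A minimal wiring sketch for review; the LogsPage surface it relies on is an assumption taken from the usages above:

// Hypothetical host object; real wiring passes the LogsPage instance.
import { initBatchActions } from './batchActions.js';

const page = {
  state: { selectedLogIds: new Set([1, 2]) }, // ids of the checked rows
  loadAndRenderLogs() { /* re-fetch and re-render the table */ },
};
initBatchActions(page); // binds #batch-actions-btn, #delete-selected-logs-btn, #clear-all-logs-btn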

View File

@@ -1,13 +1,19 @@
// Filename: frontend/js/pages/logs/index.js
import { apiFetchJson } from '../../services/api.js';
import LogList from './logList.js';
import CustomSelectV2 from '../../components/customSelectV2.js';
import { debounce } from '../../utils/utils.js';
import FilterPopover from '../../components/filterPopover.js';
import { STATIC_ERROR_MAP, STATUS_CODE_MAP } from './logList.js';
import SystemLogTerminal from './systemLog.js';
import { initBatchActions } from './batchActions.js';
import flatpickr from '../../vendor/flatpickr.js';
import LogSettingsModal from './logSettingsModal.js';
// [Final] A shared data store used to cache Groups and Keys
const dataStore = {
groups: new Map(),
keys: new Map(),
};
groups: new Map(),
keys: new Map(),
};
class LogsPage {
constructor() {
@@ -15,32 +21,597 @@ class LogsPage {
logs: [],
pagination: { page: 1, pages: 1, total: 0, page_size: 20 },
isLoading: true,
filters: { page: 1, page_size: 20 }
filters: {
page: 1,
page_size: 20,
q: '',
key_ids: new Set(),
group_ids: new Set(),
error_types: new Set(),
status_codes: new Set(),
start_date: null,
end_date: null,
},
selectedLogIds: new Set(),
currentView: 'error',
};
this.elements = {
tableBody: document.getElementById('logs-table-body'),
tabsContainer: document.querySelector('[data-sliding-tabs-container]'),
contentContainer: document.getElementById('log-content-container'),
errorFilters: document.getElementById('error-logs-filters'),
systemControls: document.getElementById('system-logs-controls'),
errorTemplate: document.getElementById('error-logs-template'),
systemTemplate: document.getElementById('system-logs-template'),
settingsBtn: document.querySelector('button[aria-label="日志设置"]'),
};
this.initialized = !!this.elements.tableBody;
this.initialized = !!this.elements.contentContainer;
if (this.initialized) {
this.logList = new LogList(this.elements.tableBody, dataStore);
this.logList = null;
this.systemLogTerminal = null;
this.debouncedLoadAndRender = debounce(() => this.loadAndRenderLogs(), 300);
this.fp = null;
this.themeObserver = null;
this.settingsModal = null;
this.currentSettings = {};
}
}
async init() {
if (!this.initialized) return;
this.initEventListeners();
// Page initialization: load groups first, then logs
this._initPermanentEventListeners();
await this.loadCurrentSettings();
this._initSettingsModal();
await this.loadGroupsOnce();
await this.loadAndRenderLogs();
this.state.currentView = null;
this.switchToView('error');
}
initEventListeners() { /* event listeners for pagination and filtering */ }
_initSettingsModal() {
if (!this.elements.settingsBtn) return;
this.settingsModal = new LogSettingsModal({
onSave: this.handleSaveSettings.bind(this)
});
this.elements.settingsBtn.addEventListener('click', () => {
const settingsForModal = {
log_level: this.currentSettings.log_level,
auto_cleanup: {
enabled: this.currentSettings.log_auto_cleanup_enabled,
retention_days: this.currentSettings.log_auto_cleanup_retention_days,
exec_time: this.currentSettings.log_auto_cleanup_time,
interval: 'daily',
}
};
this.settingsModal.open(settingsForModal);
});
}
async loadCurrentSettings() {
try {
const { success, data } = await apiFetchJson('/admin/settings');
if (success) {
this.currentSettings = data;
} else {
console.error('Failed to load settings from server.');
this.currentSettings = { log_auto_cleanup_time: '04:05' };
}
} catch (error) {
console.error('Failed to load log settings:', error);
this.currentSettings = { log_auto_cleanup_time: '04:05' };
}
}
async handleSaveSettings(settingsData) {
const partialPayload = {
"log_level": settingsData.log_level,
"log_auto_cleanup_enabled": settingsData.auto_cleanup.enabled,
"log_auto_cleanup_time": settingsData.auto_cleanup.exec_time,
};
if (settingsData.auto_cleanup.enabled) {
let retentionDays = settingsData.auto_cleanup.retention_days;
if (retentionDays === null || retentionDays <= 0) {
retentionDays = 30;
}
partialPayload.log_auto_cleanup_retention_days = retentionDays;
}
console.log('Sending PARTIAL settings update to /admin/settings:', partialPayload);
try {
const { success, message } = await apiFetchJson('/admin/settings', {
method: 'PUT',
body: JSON.stringify(partialPayload)
});
if (!success) {
throw new Error(message || 'Failed to save settings');
}
Object.assign(this.currentSettings, partialPayload);
} catch (error) {
console.error('Error saving log settings:', error);
throw error;
}
}
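// Illustrative example of the partial payload sent above (values are made up):
// {
//   "log_level": "INFO",
//   "log_auto_cleanup_enabled": true,
//   "log_auto_cleanup_time": "04:05",
//   "log_auto_cleanup_retention_days": 30
// }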
_initPermanentEventListeners() {
this.elements.tabsContainer.addEventListener('click', (event) => {
const tabItem = event.target.closest('[data-tab-target]');
if (!tabItem) return;
event.preventDefault();
const viewName = tabItem.dataset.tabTarget;
if (viewName) {
this.switchToView(viewName);
}
});
}
switchToView(viewName) {
if (this.state.currentView === viewName && this.elements.contentContainer.innerHTML !== '') return;
if (this.systemLogTerminal) {
this.systemLogTerminal.disconnect();
this.systemLogTerminal = null;
}
if (this.fp) {
this.fp.destroy();
this.fp = null;
}
if (this.themeObserver) {
this.themeObserver.disconnect();
this.themeObserver = null;
}
this.state.currentView = viewName;
this.elements.contentContainer.innerHTML = '';
const isErrorView = viewName === 'error';
this.elements.errorFilters.style.display = isErrorView ? 'flex' : 'none';
this.elements.systemControls.style.display = isErrorView ? 'none' : 'flex';
if (isErrorView) {
const template = this.elements.errorTemplate.content.cloneNode(true);
this.elements.contentContainer.appendChild(template);
requestAnimationFrame(() => {
this._initErrorLogView();
});
} else if (viewName === 'system') {
const template = this.elements.systemTemplate.content.cloneNode(true);
this.elements.contentContainer.appendChild(template);
requestAnimationFrame(() => {
this._initSystemLogView();
});
}
}
_initErrorLogView() {
this.elements.tableBody = document.getElementById('logs-table-body');
this.elements.selectedCount = document.querySelector('.flex-1.text-sm span.font-semibold:nth-child(1)');
this.elements.totalCount = document.querySelector('.flex-1.text-sm span:last-child');
this.elements.pageSizeSelect = document.querySelector('[data-component="custom-select-v2"] select');
this.elements.pageInfo = document.querySelector('.flex.w-\\[100px\\]');
this.elements.paginationBtns = document.querySelectorAll('[data-pagination-controls] button');
this.elements.selectAllCheckbox = document.querySelector('thead .table-head-cell input[type="checkbox"]');
this.elements.searchInput = document.getElementById('log-search-input');
this.elements.errorTypeFilterBtn = document.getElementById('filter-error-type-btn');
this.elements.errorCodeFilterBtn = document.getElementById('filter-error-code-btn');
this.elements.dateRangeFilterBtn = document.getElementById('filter-date-range-btn');
this.logList = new LogList(this.elements.tableBody, dataStore);
const selectContainer = document.querySelector('[data-component="custom-select-v2"]');
if (selectContainer) { new CustomSelectV2(selectContainer); }
this.initFilterPopovers();
this.initDateRangePicker();
this.initEventListeners();
this._observeThemeChanges();
initBatchActions(this);
this.loadAndRenderLogs();
}
_observeThemeChanges() {
const applyTheme = () => {
if (!this.fp || !this.fp.calendarContainer) return;
if (document.documentElement.classList.contains('dark')) {
this.fp.calendarContainer.classList.add('dark');
} else {
this.fp.calendarContainer.classList.remove('dark');
}
};
this.themeObserver = new MutationObserver((mutationsList) => {
for (const mutation of mutationsList) {
if (mutation.type === 'attributes' && mutation.attributeName === 'class') {
applyTheme();
}
}
});
this.themeObserver.observe(document.documentElement, { attributes: true });
applyTheme();
}
_initSystemLogView() {
this.systemLogTerminal = new SystemLogTerminal(
this.elements.contentContainer,
this.elements.systemControls
);
Swal.fire({
width: '20rem',
backdrop: `rgba(0,0,0,0.5)`,
heightAuto: false,
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
title: '系统终端日志',
text: '您即将连接到实时系统日志流窗口。',
showCancelButton: true,
confirmButtonText: '确认',
cancelButtonText: '取消',
reverseButtons: false,
confirmButtonColor: 'rgba(31, 102, 255, 0.8)',
cancelButtonColor: '#6b7280',
focusConfirm: false,
focusCancel: false,
target: '#main-content-wrapper',
}).then((result) => {
if (result.isConfirmed) {
this.systemLogTerminal.connect();
} else {
const errorLogTab = Array.from(this.elements.tabsContainer.querySelectorAll('[data-tab-target="error"]'))[0];
if (errorLogTab) errorLogTab.click();
}
});
}
initFilterPopovers() {
const errorTypeOptions = [
...Object.values(STATUS_CODE_MAP).map(v => ({ value: v.type, label: v.type })),
...Object.values(STATIC_ERROR_MAP).map(v => ({ value: v.type, label: v.type }))
];
const uniqueErrorTypeOptions = Array.from(new Map(errorTypeOptions.map(item => [item.value, item])).values());
if (this.elements.errorTypeFilterBtn) {
new FilterPopover(this.elements.errorTypeFilterBtn, uniqueErrorTypeOptions, '筛选错误类型');
}
const statusCodeOptions = Object.keys(STATUS_CODE_MAP).map(code => ({ value: code, label: code }));
if (this.elements.errorCodeFilterBtn) {
new FilterPopover(this.elements.errorCodeFilterBtn, statusCodeOptions, '筛选状态码');
}
}
initDateRangePicker() {
if (!this.elements.dateRangeFilterBtn) return;
const buttonTextSpan = this.elements.dateRangeFilterBtn.querySelector('span');
const originalText = buttonTextSpan.textContent;
this.fp = flatpickr(this.elements.dateRangeFilterBtn, {
mode: 'range',
dateFormat: 'Y-m-d',
onClose: (selectedDates) => {
if (selectedDates.length === 2) {
const [start, end] = selectedDates;
end.setHours(23, 59, 59, 999);
this.state.filters.start_date = start.toISOString();
this.state.filters.end_date = end.toISOString();
const startDateStr = start.toISOString().split('T')[0];
const endDateStr = end.toISOString().split('T')[0];
buttonTextSpan.textContent = `${startDateStr} ~ ${endDateStr}`;
this.elements.dateRangeFilterBtn.classList.add('!border-primary', '!text-primary');
this.state.filters.page = 1;
this.loadAndRenderLogs();
}
},
onReady: (selectedDates, dateStr, instance) => {
if (document.documentElement.classList.contains('dark')) {
instance.calendarContainer.classList.add('dark');
}
const clearButton = document.createElement('button');
clearButton.textContent = '清除';
clearButton.className = 'button flatpickr-button flatpickr-clear-button';
clearButton.addEventListener('click', (e) => {
e.preventDefault();
instance.clear();
this.state.filters.start_date = null;
this.state.filters.end_date = null;
buttonTextSpan.textContent = originalText;
this.elements.dateRangeFilterBtn.classList.remove('!border-primary', '!text-primary');
this.state.filters.page = 1;
this.loadAndRenderLogs();
instance.close();
});
instance.calendarContainer.appendChild(clearButton);
const nativeMonthSelect = instance.monthsDropdownContainer;
if (!nativeMonthSelect) return;
const monthYearContainer = nativeMonthSelect.parentElement;
const wrapper = document.createElement('div');
wrapper.className = 'custom-select-v2-container relative inline-block text-left';
wrapper.innerHTML = `
<button type="button" class="custom-select-trigger inline-flex justify-center items-center w-22 gap-x-1.5 rounded-md bg-transparent px-1 py-1 text-xs font-semibold text-foreground shadow-sm ring-0 ring-inset ring-input hover:bg-accent focus:outline-none focus:ring-1 focus:ring-offset-1 focus:ring-offset-background focus:ring-primary" aria-haspopup="true">
<span class="truncate"></span>
</button>
`;
const template = document.createElement('template');
template.className = 'custom-select-panel-template';
template.innerHTML = `
<div class="custom-select-panel absolute z-1000 my-2 w-24 origin-top-right rounded-md bg-popover dark:bg-zinc-900 shadow-lg ring-1 ring-zinc-500/30 ring-opacity-5 focus:outline-none" role="menu" aria-orientation="vertical" tabindex="-1">
</div>
`;
nativeMonthSelect.classList.add('hidden');
wrapper.appendChild(nativeMonthSelect);
wrapper.appendChild(template);
monthYearContainer.prepend(wrapper);
const customSelect = new CustomSelectV2(wrapper);
instance.customMonthSelect = customSelect;
},
onMonthChange: (selectedDates, dateStr, instance) => {
if (instance.customMonthSelect) {
instance.customMonthSelect.updateTriggerText();
}
},
});
}
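// Note on onClose above: the end date is folded to 23:59:59.999 local time
// before toISOString(), so the submitted range covers the entire end day,
// e.g. end = new Date('2024-05-10'); end.setHours(23, 59, 59, 999);
// end.toISOString() then marks the end of that local day as a UTC instant.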
initEventListeners() {
if (this.elements.pageSizeSelect) {
this.elements.pageSizeSelect.addEventListener('change', (e) => this.changePageSize(parseInt(e.target.value, 10)));
}
if (this.elements.paginationBtns.length >= 4) {
this.elements.paginationBtns[0].addEventListener('click', () => this.goToPage(1));
this.elements.paginationBtns[1].addEventListener('click', () => this.goToPage(this.state.pagination.page - 1));
this.elements.paginationBtns[2].addEventListener('click', () => this.goToPage(this.state.pagination.page + 1));
this.elements.paginationBtns[3].addEventListener('click', () => this.goToPage(this.state.pagination.pages));
}
if (this.elements.selectAllCheckbox) {
this.elements.selectAllCheckbox.addEventListener('change', (event) => this.handleSelectAllChange(event));
}
if (this.elements.tableBody) {
this.elements.tableBody.addEventListener('click', (event) => {
const checkbox = event.target.closest('input[type="checkbox"]');
const actionButton = event.target.closest('button[data-action]');
if (checkbox) {
this.handleSelectionChange(checkbox);
} else if (actionButton) {
this._handleLogRowAction(actionButton);
}
});
}
if (this.elements.searchInput) {
this.elements.searchInput.addEventListener('input', (event) => this.handleSearchInput(event));
}
if (this.elements.errorTypeFilterBtn) {
this.elements.errorTypeFilterBtn.addEventListener('filter-change', (e) => this.handleFilterChange(e));
}
if (this.elements.errorCodeFilterBtn) {
this.elements.errorCodeFilterBtn.addEventListener('filter-change', (e) => this.handleFilterChange(e));
}
}
handleFilterChange(event) {
const { filterKey, selected } = event.detail;
if (filterKey === 'filter-error-type-btn') {
this.state.filters.error_types = selected;
} else if (filterKey === 'filter-error-code-btn') {
this.state.filters.status_codes = selected;
}
this.state.filters.page = 1;
this.loadAndRenderLogs();
}
handleSearchInput(event) {
const searchTerm = event.target.value.trim().toLowerCase();
this.state.filters.page = 1;
this.state.filters.q = '';
this.state.filters.key_ids = new Set();
this.state.filters.group_ids = new Set();
if (searchTerm === '') {
this.debouncedLoadAndRender();
return;
}
const matchedGroupIds = new Set();
dataStore.groups.forEach(group => {
if (group.display_name.toLowerCase().includes(searchTerm)) {
matchedGroupIds.add(group.id);
}
});
const matchedKeyIds = new Set();
dataStore.keys.forEach(key => {
if (key.APIKey && key.APIKey.toLowerCase().includes(searchTerm)) {
matchedKeyIds.add(key.ID);
}
});
if (matchedGroupIds.size > 0) this.state.filters.group_ids = matchedGroupIds;
if (matchedKeyIds.size > 0) this.state.filters.key_ids = matchedKeyIds;
if (matchedGroupIds.size === 0 && matchedKeyIds.size === 0) {
this.state.filters.q = searchTerm;
}
this.debouncedLoadAndRender();
}
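// The debounce helper comes from utils/utils.js and is outside this diff;
// the semantics assumed here are roughly (hypothetical sketch):
//   const debounce = (fn, wait) => {
//     let t;
//     return (...args) => { clearTimeout(t); t = setTimeout(() => fn(...args), wait); };
//   };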
handleSelectionChange(checkbox) {
const row = checkbox.closest('.table-row');
if (!row) return;
const logId = parseInt(row.dataset.logId, 10);
if (isNaN(logId)) return;
if (checkbox.checked) {
this.state.selectedLogIds.add(logId);
} else {
this.state.selectedLogIds.delete(logId);
}
this.syncSelectionUI();
}
handleSelectAllChange(event) {
const isChecked = event.target.checked;
this.state.logs.forEach(log => {
if (isChecked) {
this.state.selectedLogIds.add(log.ID);
} else {
this.state.selectedLogIds.delete(log.ID);
}
});
this.syncRowCheckboxes();
this.syncSelectionUI();
}
syncRowCheckboxes() {
const isAllChecked = this.elements.selectAllCheckbox.checked;
this.elements.tableBody.querySelectorAll('input[type="checkbox"]').forEach(cb => {
cb.checked = isAllChecked;
});
}
syncSelectionUI() {
if (!this.elements.selectAllCheckbox || !this.elements.selectedCount) return;
const selectedCount = this.state.selectedLogIds.size;
const visibleLogsCount = this.state.logs.length;
if (selectedCount === 0) {
this.elements.selectAllCheckbox.checked = false;
this.elements.selectAllCheckbox.indeterminate = false;
} else if (selectedCount < visibleLogsCount) {
this.elements.selectAllCheckbox.checked = false;
this.elements.selectAllCheckbox.indeterminate = true;
} else if (selectedCount === visibleLogsCount && visibleLogsCount > 0) {
this.elements.selectAllCheckbox.checked = true;
this.elements.selectAllCheckbox.indeterminate = false;
}
this.elements.selectedCount.textContent = selectedCount;
const hasSelection = selectedCount > 0;
const deleteSelectedBtn = document.getElementById('delete-selected-logs-btn');
if (deleteSelectedBtn) {
deleteSelectedBtn.disabled = !hasSelection;
}
}
async _handleLogRowAction(button) {
const action = button.dataset.action;
const row = button.closest('.table-row');
const isDarkMode = document.documentElement.classList.contains('dark');
if (!row) return;
const logId = parseInt(row.dataset.logId, 10);
const log = this.state.logs.find(l => l.ID === logId);
if (!log) {
Swal.fire({ toast: true, position: 'top-end', icon: 'error', title: '找不到日志数据', showConfirmButton: false, timer: 2000 });
return;
}
switch (action) {
case 'view-log-details': {
const detailsHtml = `
<div class="space-y-3 text-left text-sm p-2">
<div class="flex"><p class="w-24 font-semibold text-zinc-500 shrink-0">状态码</p><p class="font-mono text-zinc-800 dark:text-zinc-200">${log.StatusCode || 'N/A'}</p></div>
<div class="flex"><p class="w-24 font-semibold text-zinc-500 shrink-0">状态</p><p class="font-mono text-zinc-800 dark:text-zinc-200">${log.Status || 'N/A'}</p></div>
<div class="flex"><p class="w-24 font-semibold text-zinc-500 shrink-0">模型</p><p class="font-mono text-zinc-800 dark:text-zinc-200">${log.ModelName || 'N/A'}</p></div>
<div class="border-t border-zinc-200 dark:border-zinc-700 my-2"></div>
<div>
<p class="font-semibold text-zinc-500 mb-1">错误消息</p>
<div class="max-h-40 overflow-y-auto bg-zinc-100 dark:bg-zinc-800 p-2 rounded-md text-zinc-700 dark:text-zinc-300 wrap-break-word text-xs">
${log.ErrorMessage ? log.ErrorMessage.replace(/\n/g, '<br>') : '无错误消息。'}
</div>
</div>
</div>
`;
Swal.fire({
target: '#main-content-wrapper',
width: '32rem',
backdrop: `rgba(0,0,0,0.5)`,
heightAuto: false,
customClass: {
popup: `swal2-custom-style rounded-xl ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}`,
title: 'text-lg font-bold',
htmlContainer: 'm-0 text-left',
},
title: '日志详情',
html: detailsHtml,
showCloseButton: false,
showConfirmButton: false,
});
break;
}
case 'copy-api-key': {
const key = dataStore.keys.get(log.KeyID);
if (key && key.APIKey) {
navigator.clipboard.writeText(key.APIKey).then(() => {
Swal.fire({ toast: true, position: 'top-end', customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` }, icon: 'success', title: 'API Key 已复制', showConfirmButton: false, timer: 1500 });
}).catch(err => {
Swal.fire({ toast: true, position: 'top-end', icon: 'error', title: '复制失败', text: err.message, showConfirmButton: false, timer: 2000 });
});
} else {
Swal.fire({ toast: true, position: 'top-end', icon: 'warning', title: '未找到完整的API Key', showConfirmButton: false, timer: 2000 });
return;
}
if (navigator.clipboard && window.isSecureContext) {
navigator.clipboard.writeText(key.APIKey).then(() => {
Swal.fire({ toast: true, position: 'top-end', icon: 'success', title: 'API Key 已复制', showConfirmButton: false, timer: 1500 });
}).catch(err => {
Swal.fire({ toast: true, position: 'top-end', icon: 'error', title: '复制失败', text: err.message, showConfirmButton: false, timer: 2000 });
});
} else {
// Otherwise, show an explicit error message
Swal.fire({
icon: 'error',
title: '复制失败',
text: '此功能需要安全连接 (HTTPS) 或在 localhost 环境下使用。',
target: '#main-content-wrapper',
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
});
}
break;
}
case 'delete-log': {
Swal.fire({
width: '20rem',
backdrop: `rgba(0,0,0,0.5)`,
heightAuto: false,
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
title: '确认删除',
text: `您确定要删除这条日志吗?此操作不可撤销。`,
showCancelButton: true,
confirmButtonText: '确认删除',
cancelButtonText: '取消',
reverseButtons: false,
confirmButtonColor: '#ef4444',
cancelButtonColor: '#6b7280',
focusCancel: true,
target: '#main-content-wrapper',
}).then(async (result) => {
if (result.isConfirmed) {
try {
const url = `/admin/logs?ids=${logId}`;
const { success, message } = await apiFetchJson(url, { method: 'DELETE' });
if (success) {
Swal.fire({ toast: true, position: 'top-end', icon: 'success', title: '删除成功', showConfirmButton: false, timer: 2000, timerProgressBar: true });
this.loadAndRenderLogs();
} else {
throw new Error(message || '删除失败,请稍后重试。');
}
} catch (error) {
Swal.fire({ icon: 'error', title: '操作失败', text: error.message, target: '#main-content-wrapper' });
}
}
});
break;
}
}
}
changePageSize(newSize) {
this.state.filters.page_size = newSize;
this.state.filters.page = 1;
this.loadAndRenderLogs();
}
goToPage(page) {
if (page < 1 || page > this.state.pagination.pages || this.state.isLoading) return;
this.state.filters.page = page;
this.loadAndRenderLogs();
}
updatePaginationUI() {
const { page, pages, total } = this.state.pagination;
if (this.elements.pageInfo) {
this.elements.pageInfo.textContent = `${page} / ${pages}`;
}
if (this.elements.totalCount) {
this.elements.totalCount.textContent = total;
}
if (this.elements.paginationBtns.length >= 4) {
const isFirstPage = page === 1;
const isLastPage = page === pages || pages === 0;
this.elements.paginationBtns[0].disabled = isFirstPage;
this.elements.paginationBtns[1].disabled = isFirstPage;
this.elements.paginationBtns[2].disabled = isLastPage;
this.elements.paginationBtns[3].disabled = isLastPage;
}
}
async loadGroupsOnce() {
if (dataStore.groups.size > 0) return; // prevent duplicate loading
if (dataStore.groups.size > 0) return;
try {
const { success, data } = await apiFetchJson("/admin/keygroups");
if (success && Array.isArray(data)) {
@@ -53,33 +624,81 @@ class LogsPage {
async loadAndRenderLogs() {
this.state.isLoading = true;
this.logList.renderLoading();
this.state.selectedLogIds.clear();
this.logList.renderLoading();
this.updatePaginationUI();
this.syncSelectionUI();
try {
const query = new URLSearchParams(this.state.filters);
const { success, data } = await apiFetchJson(`/admin/logs?${query.toString()}`);
const finalParams = {};
const { filters } = this.state;
Object.keys(filters).forEach(key => {
if (!(filters[key] instanceof Set)) {
finalParams[key] = filters[key];
}
});
// --- [MODIFIED] START: Combine all error-related filters into a single parameter for OR logic ---
const allErrorCodes = new Set();
const allStatusCodes = new Set(filters.status_codes);
if (filters.error_types.size > 0) {
filters.error_types.forEach(type => {
// Find matching static error codes (e.g., 'API_KEY_INVALID')
for (const [code, obj] of Object.entries(STATIC_ERROR_MAP)) {
if (obj.type === type) {
allErrorCodes.add(code);
}
}
// Find matching status codes (e.g., 400, 401)
for (const [code, obj] of Object.entries(STATUS_CODE_MAP)) {
if (obj.type === type) {
allStatusCodes.add(code);
}
}
});
}
if (success && typeof data === 'object') {
// Pass the combined codes to the backend. The backend will handle the OR logic.
if (allErrorCodes.size > 0) finalParams.error_codes = [...allErrorCodes].join(',');
if (allStatusCodes.size > 0) finalParams.status_codes = [...allStatusCodes].join(',');
// --- [MODIFIED] END ---
if (filters.key_ids.size > 0) finalParams.key_ids = [...filters.key_ids].join(',');
if (filters.group_ids.size > 0) finalParams.group_ids = [...filters.group_ids].join(',');
Object.keys(finalParams).forEach(key => {
if (finalParams[key] === '' || finalParams[key] === null || finalParams[key] === undefined) {
delete finalParams[key];
}
});
const query = new URLSearchParams(finalParams);
const { success, data } = await apiFetchJson(
`/admin/logs?${query.toString()}`,
{ cache: 'no-cache', noCache: true }
);
if (success && typeof data === 'object' && data.items) {
const { items, total, page, page_size } = data;
this.state.logs = items;
this.state.pagination = { page, page_size, total, pages: Math.ceil(total / page_size) };
// [Core] Before rendering, batch-load any uncached Key info required by this page of logs
const totalPages = Math.ceil(total / page_size);
this.state.pagination = { page, page_size, total, pages: totalPages > 0 ? totalPages : 1 };
await this.enrichLogsWithKeyNames(items);
// Call render; dataStore now contains all the data it needs
this.logList.render(this.state.logs, this.state.pagination);
this.logList.render(this.state.logs, this.state.pagination, this.state.selectedLogIds);
} else {
this.state.logs = [];
this.state.pagination = { ...this.state.pagination, total: 0, pages: 1, page: 1 };
this.logList.render([], this.state.pagination);
}
} catch (error) {
console.error("Failed to load logs:", error);
this.state.logs = [];
this.state.pagination = { ...this.state.pagination, total: 0, pages: 1, page: 1 };
this.logList.render([], this.state.pagination);
} finally {
this.state.isLoading = false;
this.updatePaginationUI();
this.syncSelectionUI();
}
}
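// Worked example of the flattening above (illustrative values): with
// filters = { page: 1, status_codes: new Set([401, 429]) }, finalParams becomes
// { page: 1, status_codes: '401,429' }, so the request URL is
// /admin/logs?page=1&status_codes=401%2C429 (URLSearchParams encodes the comma).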
async enrichLogsWithKeyNames(logs) {
const missingKeyIds = [...new Set(
logs.filter(log => log.KeyID && !dataStore.keys.has(log.KeyID)).map(log => log.KeyID)

View File

@@ -1,7 +1,7 @@
// Filename: frontend/js/pages/logs/logList.js
import { escapeHTML } from '../../utils/utils.js';
const STATIC_ERROR_MAP = {
export const STATIC_ERROR_MAP = {
'API_KEY_INVALID': { type: '密钥无效', style: 'red' },
'INVALID_ARGUMENT': { type: '参数无效', style: 'red' },
'PERMISSION_DENIED': { type: '权限不足', style: 'red' },
@@ -13,8 +13,8 @@ const STATIC_ERROR_MAP = {
'INTERNAL': { type: 'Google内部错误', style: 'yellow' },
'UNAVAILABLE': { type: '服务不可用', style: 'yellow' },
};
// --- [Updated] Dynamic map from HTTP status code to type and style ---
const STATUS_CODE_MAP = {
// --- Dynamic map from HTTP status code to type and style ---
export const STATUS_CODE_MAP = {
400: { type: '错误请求', style: 'red' },
401: { type: '认证失败', style: 'red' },
403: { type: '禁止访问', style: 'red' },
@@ -55,7 +55,7 @@ class LogList {
this.container.innerHTML = `<tr><td colspan="9" class="p-8 text-center text-muted-foreground"><i class="fas fa-spinner fa-spin mr-2"></i> 加载日志中...</td></tr>`;
}
render(logs, pagination) {
render(logs, pagination, selectedLogIds) {
if (!this.container) return;
if (!logs || logs.length === 0) {
this.container.innerHTML = `<tr><td colspan="9" class="p-8 text-center text-muted-foreground">没有找到相关的日志记录。</td></tr>`;
@@ -63,7 +63,10 @@ class LogList {
}
const { page, page_size } = pagination;
const startIndex = (page - 1) * page_size;
const logsHtml = logs.map((log, index) => this.createLogRowHtml(log, startIndex + index + 1)).join('');
const logsHtml = logs.map((log, index) => {
const isChecked = selectedLogIds.has(log.ID);
return this.createLogRowHtml(log, startIndex + index + 1, isChecked);
}).join('');
this.container.innerHTML = logsHtml;
}
@@ -75,7 +78,7 @@ class LogList {
statusCodeHtml: `<span class="inline-flex items-center rounded-md bg-green-500/10 px-2 py-1 text-xs font-medium text-green-600">成功</span>`
};
}
// 2. [New] Special cases are checked first (combining ErrorCode and ErrorMessage)
// 2. Special cases are checked first (combining ErrorCode and ErrorMessage)
const codeMatch = log.ErrorCode ? log.ErrorCode.match(errorCodeRegex) : null;
if (codeMatch && codeMatch[1] && log.ErrorMessage) {
const code = parseInt(codeMatch[1], 10);
@@ -125,7 +128,7 @@ class LogList {
return `<div class="inline-block rounded bg-zinc-100 dark:bg-zinc-800 px-2 py-0.5"><span class="font-quinquefive text-xs tracking-wider ${styleClass}">${modelName}</span></div>`;
}
createLogRowHtml(log, index) {
createLogRowHtml(log, index, isChecked) {
const group = this.dataStore.groups.get(log.GroupID);
const groupName = group ? group.display_name : (log.GroupID ? `Group #${log.GroupID}` : 'N/A');
const key = this.dataStore.keys.get(log.KeyID);
@@ -140,9 +143,13 @@ class LogList {
const modelNameFormatted = this._formatModelName(log.ModelName);
const errorMessageAttr = log.ErrorMessage ? `data-error-message="${escape(log.ErrorMessage)}"` : '';
const requestTime = new Date(log.RequestTime).toLocaleString();
const checkedAttr = isChecked ? 'checked' : '';
return `
<tr class="table-row" data-log-id="${log.ID}" ${errorMessageAttr}>
<td class="table-cell"><input type="checkbox" class="h-4 w-4 rounded border-zinc-300 text-blue-600 focus:ring-blue-500"></td>
<tr class="table-row group even:bg-zinc-200/30 dark:even:bg-black/10" data-log-id="${log.ID}" ${errorMessageAttr}>
<td class="table-cell">
<input type="checkbox" class="h-4 w-4 rounded border-zinc-300 text-blue-600 focus:ring-blue-500" ${checkedAttr}>
</td>
<td class="table-cell font-mono text-muted-foreground">${index}</td>
<td class="table-cell font-medium font-mono">${apiKeyDisplay}</td>
<td class="table-cell">${groupName}</td>
@@ -150,14 +157,29 @@ class LogList {
<td class="table-cell">${errorInfo.statusCodeHtml}</td>
<td class="table-cell">${modelNameFormatted}</td>
<td class="table-cell text-muted-foreground text-xs">${requestTime}</td>
<td class="table-cell">
<button class="btn btn-ghost btn-icon btn-sm" aria-label="查看详情">
<i class="fas fa-ellipsis-h h-4 w-4"></i>
</button>
<td class="table-cell relative">
<!-- [MODIFIED] - 2. Replace the original button with a hover action menu -->
<div class="flex items-center justify-center">
<!-- Icon shown by default -->
<span class="text-zinc-400 group-hover:opacity-0 transition-opacity">
<i class="fas fa-ellipsis-h h-4 w-4"></i>
</span>
<!-- Action buttons shown on hover -->
<div class="absolute right-2 top-1/2 -translate-y-1/2 flex items-center bg-zinc-100 dark:bg-zinc-700 rounded-full shadow-md opacity-0 group-hover:opacity-100 transition-opacity duration-200 z-10">
<button class="px-2 py-1 text-zinc-500 hover:text-blue-500" data-action="view-log-details" title="查看详情">
<i class="fas fa-eye"></i>
</button>
<button class="px-2 py-1 text-zinc-500 hover:text-green-500" data-action="copy-api-key" title="复制APIKey">
<i class="fas fa-copy"></i>
</button>
<button class="px-2 py-1 text-zinc-500 hover:text-red-500" data-action="delete-log" title="删除日志">
<i class="fas fa-trash-alt"></i>
</button>
</div>
</div>
</td>
</tr>
`;
}
}
export default LogList;
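The escapeHTML helper imported above is outside this diff; a minimal sketch of the behavior this file assumes (hypothetical, for review context only):

function escapeHTML(str) {
  return String(str)
    .replace(/&/g, '&amp;')   // must run first so later entities are not double-escaped
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#39;');
}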

View File

@@ -0,0 +1,148 @@
// Filename: frontend/js/pages/logs/logSettingsModal.js
import { modalManager } from '../../components/ui.js';
export default class LogSettingsModal {
constructor({ onSave }) {
this.modalId = 'log-settings-modal';
this.onSave = onSave;
const modal = document.getElementById(this.modalId);
if (!modal) {
throw new Error(`Modal with id "${this.modalId}" not found.`);
}
this.elements = {
modal: modal,
title: document.getElementById('log-settings-modal-title'),
saveBtn: document.getElementById('log-settings-save-btn'),
logLevelSelect: document.getElementById('log-level-select'),
cleanupEnableToggle: document.getElementById('log-cleanup-enable'),
cleanupSettingsPanel: document.getElementById('log-cleanup-settings'),
cleanupRetentionInput: document.getElementById('log-cleanup-retention-days'),
retentionDaysGroup: document.getElementById('retention-days-group'),
retentionPresetBtns: document.querySelectorAll('#retention-days-group button[data-days]'),
cleanupExecTimeInput: document.getElementById('log-cleanup-exec-time'), // [NEW] time picker element
};
this.activePresetClasses = ['!bg-primary', '!text-primary-foreground', '!border-primary', 'hover:!bg-primary/90'];
this.inactivePresetClasses = ['modal-btn-secondary'];
this._initEventListeners();
}
open(settingsData = {}) {
this._populateForm(settingsData);
modalManager.show(this.modalId);
}
close() {
modalManager.hide(this.modalId);
}
_initEventListeners() {
this.elements.saveBtn.addEventListener('click', this._handleSave.bind(this));
this.elements.cleanupEnableToggle.addEventListener('change', (e) => {
this.elements.cleanupSettingsPanel.classList.toggle('hidden', !e.target.checked);
});
this._initRetentionPresets();
const closeAction = () => this.close();
const closeTriggers = this.elements.modal.querySelectorAll(`[data-modal-close="${this.modalId}"]`);
closeTriggers.forEach(trigger => trigger.addEventListener('click', closeAction));
this.elements.modal.addEventListener('click', (event) => {
if (event.target === this.elements.modal) closeAction();
});
}
_initRetentionPresets() {
this.elements.retentionPresetBtns.forEach(btn => {
btn.addEventListener('click', () => {
const days = btn.dataset.days;
this.elements.cleanupRetentionInput.value = days;
this._updateActivePresetButton(days);
});
});
this.elements.cleanupRetentionInput.addEventListener('input', (e) => {
this._updateActivePresetButton(e.target.value);
});
}
_updateActivePresetButton(currentValue) {
this.elements.retentionPresetBtns.forEach(btn => {
if (btn.dataset.days === currentValue) {
btn.classList.remove(...this.inactivePresetClasses);
btn.classList.add(...this.activePresetClasses);
} else {
btn.classList.remove(...this.activePresetClasses);
btn.classList.add(...this.inactivePresetClasses);
}
});
}
async _handleSave() {
const data = this._collectFormData();
if (data.auto_cleanup.enabled && (!data.auto_cleanup.retention_days || data.auto_cleanup.retention_days <= 0)) {
alert('启用自动清理时保留天数必须是大于0的数字。');
return;
}
if (this.onSave) {
this.elements.saveBtn.disabled = true;
this.elements.saveBtn.textContent = '保存中...';
try {
await this.onSave(data);
this.close();
} catch (error) {
console.error("Failed to save log settings:", error);
// A UI hint, e.g. a toast, could be added here
} finally {
this.elements.saveBtn.disabled = false;
this.elements.saveBtn.textContent = '保存设置';
}
}
}
// [MODIFIED] - Updated to populate the new time picker
_populateForm(data) {
this.elements.logLevelSelect.value = data.log_level || 'INFO';
const cleanup = data.auto_cleanup || {};
const isCleanupEnabled = cleanup.enabled || false;
this.elements.cleanupEnableToggle.checked = isCleanupEnabled;
this.elements.cleanupSettingsPanel.classList.toggle('hidden', !isCleanupEnabled);
const retentionDays = cleanup.retention_days || '';
this.elements.cleanupRetentionInput.value = retentionDays;
this._updateActivePresetButton(retentionDays.toString());
// [NEW] Populate the execution time with a safe default
this.elements.cleanupExecTimeInput.value = cleanup.exec_time || '04:05';
}
// [MODIFIED] - Updated to collect the new time value
_collectFormData() {
const parseIntOrNull = (value) => {
const trimmed = value.trim();
if (trimmed === '') return null;
const num = parseInt(trimmed, 10);
return isNaN(num) ? null : num;
};
const isCleanupEnabled = this.elements.cleanupEnableToggle.checked;
const formData = {
log_level: this.elements.logLevelSelect.value,
auto_cleanup: {
enabled: isCleanupEnabled,
interval: isCleanupEnabled ? 'daily' : null,
retention_days: isCleanupEnabled ? parseIntOrNull(this.elements.cleanupRetentionInput.value) : null,
exec_time: isCleanupEnabled ? this.elements.cleanupExecTimeInput.value : '04:05', // [NEW] collect the time value
},
};
return formData;
}
}
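A minimal usage sketch (assumes a #log-settings-modal element and the ids queried in the constructor exist in the DOM):

const modal = new LogSettingsModal({
  onSave: async (data) => {
    // persist the settings; throwing here keeps the modal open
    console.log(data.log_level, data.auto_cleanup.retention_days, data.auto_cleanup.exec_time);
  },
});
modal.open({ log_level: 'INFO', auto_cleanup: { enabled: true, retention_days: 30, exec_time: '04:05' } });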

View File

@@ -0,0 +1,176 @@
// Filename: frontend/js/pages/logs/systemLog.js
export default class SystemLogTerminal {
constructor(container, controlsContainer) {
this.container = container;
this.controlsContainer = controlsContainer;
this.ws = null;
this.isPaused = false;
this.shouldAutoScroll = true;
this.reconnectAttempts = 0;
this.maxReconnectAttempts = 5;
this.isConnected = false;
this.elements = {
output: this.container.querySelector('#log-terminal-output'),
statusIndicator: this.controlsContainer.querySelector('#terminal-status-indicator'),
clearBtn: this.controlsContainer.querySelector('[data-action="clear-terminal"]'),
pauseBtn: this.controlsContainer.querySelector('[data-action="toggle-pause-terminal"]'),
scrollBtn: this.controlsContainer.querySelector('[data-action="toggle-scroll-terminal"]'),
connectBtn: this.controlsContainer.querySelector('[data-action="toggle-connect-terminal"]'),
settingsBtn: this.controlsContainer.querySelector('[data-action="terminal-settings"]'),
};
this._initEventListeners();
}
_initEventListeners() {
this.elements.clearBtn.addEventListener('click', () => this.clear());
this.elements.pauseBtn.addEventListener('click', () => this.togglePause());
this.elements.scrollBtn.addEventListener('click', () => this.toggleAutoScroll());
this.elements.connectBtn.addEventListener('click', () => this.toggleConnect());
this.elements.settingsBtn.addEventListener('click', () => this.openSettings());
}
toggleConnect() {
if (this.isConnected) {
this.disconnect();
} else {
this.connect();
}
}
connect() {
this.clear();
this._appendMessage('info', '正在连接到实时日志流...');
this._updateStatus('connecting', '连接中...');
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${protocol}//${window.location.host}/ws/system-logs`;
this.ws = new WebSocket(wsUrl);
this.ws.onopen = () => {
this._appendMessage('info', '✓ 已连接到系统日志流');
this._updateStatus('connected', '已连接');
this.reconnectAttempts = 0;
this.isConnected = true;
this.elements.connectBtn.title = '断开';
this.elements.connectBtn.querySelector('i').classList.replace('fa-plug', 'fa-minus-circle');
};
this.ws.onmessage = (event) => {
if (this.isPaused) return;
try {
const data = JSON.parse(event.data);
const levelColors = {
'error': 'text-red-500',
'warning': 'text-yellow-400',
'info': 'text-green-400',
'debug': 'text-zinc-400'
};
const color = levelColors[data.level] || 'text-zinc-200';
const timestamp = new Date(data.timestamp).toLocaleTimeString();
const msg = `[${timestamp}] [${data.level.toUpperCase()}] ${data.message}`;
this._appendMessage(color, msg);
} catch (e) {
this._appendMessage('text-zinc-200', event.data);
}
};
this.ws.onerror = (error) => {
this._appendMessage('error', `✗ WebSocket 错误`);
this._updateStatus('error', '连接错误');
};
this.ws.onclose = () => {
this._appendMessage('error', '✗ 连接已断开');
this._updateStatus('disconnected', '未连接');
this.isConnected = false;
this.elements.connectBtn.title = '连接';
this.elements.connectBtn.querySelector('i').classList.replace('fa-minus-circle', 'fa-plug');
if (this.reconnectAttempts < this.maxReconnectAttempts) {
this.reconnectAttempts++;
setTimeout(() => {
this._appendMessage('info', `尝试重新连接 (${this.reconnectAttempts}/${this.maxReconnectAttempts})...`);
this.connect();
}, 3000);
}
};
}
disconnect() {
if (this.ws) {
this.ws.close();
this.ws = null;
}
this.reconnectAttempts = this.maxReconnectAttempts;
this.isConnected = false;
this._updateStatus('disconnected', '未连接');
this.elements.connectBtn.title = '连接';
this.elements.connectBtn.querySelector('i').classList.replace('fa-minus-circle', 'fa-plug');
}
clear() {
if(this.elements.output) {
this.elements.output.innerHTML = '';
}
}
togglePause() {
this.isPaused = !this.isPaused;
const icon = this.elements.pauseBtn.querySelector('i');
if (this.isPaused) {
this.elements.pauseBtn.title = '继续';
icon.classList.replace('fa-pause', 'fa-play');
} else {
this.elements.pauseBtn.title = '暂停';
icon.classList.replace('fa-play', 'fa-pause');
}
}
toggleAutoScroll() {
this.shouldAutoScroll = !this.shouldAutoScroll;
this.elements.scrollBtn.title = this.shouldAutoScroll ? '自动滚动' : '手动滚动';
}
openSettings() {
// Settings feature not implemented yet
console.log('Open settings');
}
_appendMessage(colorClass, text) {
if (!this.elements.output) return;
const p = document.createElement('p');
p.className = colorClass;
p.textContent = text;
this.elements.output.appendChild(p);
if (this.shouldAutoScroll) {
this.elements.output.scrollTop = this.elements.output.scrollHeight;
}
}
_updateStatus(status, text) {
const indicator = this.elements.statusIndicator.querySelector('span.relative');
const statusText = this.elements.statusIndicator.childNodes[2];
const colors = {
'connecting': 'bg-yellow-500',
'connected': 'bg-green-500',
'disconnected': 'bg-zinc-500',
'error': 'bg-red-500'
};
indicator.querySelectorAll('span').forEach(span => {
span.className = span.className.replace(/bg-\w+-\d+/g, colors[status] || colors.disconnected);
});
if (statusText) {
statusText.textContent = ` ${text}`;
}
}
}
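The frame shape the onmessage handler above expects is an inference from this diff, roughly:

// { "timestamp": "2024-01-01T12:00:00Z", "level": "info", "message": "server started" }
// where level is one of error | warning | info | debug; non-JSON frames are
// rendered verbatim with the default text-zinc-200 style.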

1311
frontend/js/vendor/anime.esm.js vendored Normal file

File diff suppressed because it is too large Load Diff

2
frontend/js/vendor/flatpickr.js vendored Normal file

File diff suppressed because one or more lines are too long

72
frontend/js/vendor/marked.min.js vendored Normal file

File diff suppressed because one or more lines are too long

1
frontend/js/vendor/nanoid.js vendored Normal file
View File

@@ -0,0 +1 @@
let a="useandom-26T198340PX75pxJACKVERYMINDBUSHWOLF_GQZbfghjklqvwyzrict";export let nanoid=(e=21)=>{let t="",r=crypto.getRandomValues(new Uint8Array(e));for(let n=0;n<e;n++)t+=a[63&r[n]];return t};

6
frontend/js/vendor/popper.esm.min.js vendored Normal file

File diff suppressed because one or more lines are too long

4611
frontend/js/vendor/sweetalert2.esm.js vendored Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

3
go.mod
View File

@@ -10,6 +10,7 @@ require (
github.com/go-co-op/gocron v1.37.0
github.com/go-sql-driver/mysql v1.9.3
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.7.5
github.com/microcosm-cc/bluemonday v1.0.27
github.com/redis/go-redis/v9 v9.3.0
@@ -17,6 +18,8 @@ require (
github.com/spf13/viper v1.20.1
go.uber.org/dig v1.19.0
golang.org/x/net v0.42.0
golang.org/x/time v0.14.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gorm.io/datatypes v1.0.5
gorm.io/driver/mysql v1.6.0
gorm.io/driver/postgres v1.6.0

6
go.sum
View File

@@ -92,6 +92,8 @@ github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -311,6 +313,8 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
@@ -325,6 +329,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@@ -28,12 +28,6 @@ type GeminiChannel struct {
httpClient *http.Client
}
// Local struct for safely extracting request metadata
type requestMetadata struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -47,39 +41,58 @@ func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *Gemini
logger: logger,
httpClient: &http.Client{
Transport: transport,
Timeout: 0,
Timeout: 0, // Timeout is handled by the request context
},
}
}
// TransformRequest
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
modelName = ch.ExtractModel(c, requestBody)
return requestBody, modelName, nil
}
// ExtractModel
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
return ch.extractModelFromRequest(c, bodyBytes)
}
// Unified model-extraction logic: parse the request body first; if that fails, fall back to parsing the URL path.
func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
var p struct {
Model string `json:"model"`
}
_ = json.Unmarshal(requestBody, &p)
modelName = strings.TrimPrefix(p.Model, "models/")
_ = json.Unmarshal(bodyBytes, &p)
modelName := strings.TrimPrefix(p.Model, "models/")
if modelName == "" {
modelName = ch.extractModelFromPath(c.Request.URL.Path)
}
return requestBody, modelName, nil
return modelName
}
func (ch *GeminiChannel) extractModelFromPath(path string) string {
parts := strings.Split(path, "/")
for _, part := range parts {
// Cover more model-name formats
if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
modelPart := strings.Split(part, ":")[0]
return modelPart
return strings.Split(part, ":")[0]
}
}
return ""
}
// IsOpenAICompatibleRequest decides by pure string matching and does not depend on gin.Context.
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
path := c.Request.URL.Path
return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
return ch.isOpenAIPath(c.Request.URL.Path)
}
func (ch *GeminiChannel) isOpenAIPath(path string) bool {
return strings.Contains(path, "/v1/chat/completions") ||
strings.Contains(path, "/v1/completions") ||
strings.Contains(path, "/v1/embeddings") ||
strings.Contains(path, "/v1/models") ||
strings.Contains(path, "/v1/audio/")
}
func (ch *GeminiChannel) ValidateKey(
@@ -88,25 +101,28 @@ func (ch *GeminiChannel) ValidateKey(
targetURL string,
timeout time.Duration,
) *CustomErrors.APIError {
client := &http.Client{
Timeout: timeout,
}
client := &http.Client{Timeout: timeout}
req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
if err != nil {
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
}
ch.ModifyRequest(req, apiKey)
resp, err := client.Do(req)
if err != nil {
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error())
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
errorBody, _ := io.ReadAll(resp.Body)
parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
parsedMessage, _ := CustomErrors.ParseUpstreamError(errorBody)
return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode,
Code: fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
@@ -115,10 +131,6 @@ func (ch *GeminiChannel) ValidateKey(
}
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
// TODO: [Future Refactoring] Decouple auth logic from URL path.
// The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
// of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
// This would make the channel more generic and adaptable to new upstream provider types.
if strings.Contains(req.URL.Path, "/v1beta/openai/") {
req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
} else {
@@ -133,24 +145,22 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
return true
}
var meta requestMetadata
if err := json.Unmarshal(bodyBytes, &meta); err == nil {
var meta struct {
Stream bool `json:"stream"`
}
if json.Unmarshal(bodyBytes, &meta) == nil {
return meta.Stream
}
return false
}
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
_, modelName, _ := ch.TransformRequest(c, bodyBytes)
return modelName
}
// RewritePath uses url.JoinPath to guarantee correct path joining.
func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
var rewrittenSegment string
if ch.IsOpenAICompatibleRequest(tempCtx) {
var apiEndpoint string
if ch.isOpenAIPath(originalPath) {
v1Index := strings.LastIndex(originalPath, "/v1/")
var apiEndpoint string
if v1Index != -1 {
apiEndpoint = originalPath[v1Index+len("/v1/"):]
} else {
@@ -158,69 +168,76 @@ func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
}
rewrittenSegment = "v1beta/openai/" + apiEndpoint
} else {
tempPath := originalPath
if strings.HasPrefix(tempPath, "/v1/") {
tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
if strings.HasPrefix(originalPath, "/v1/") {
rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/")
} else {
rewrittenSegment = strings.TrimPrefix(originalPath, "/")
}
rewrittenSegment = strings.TrimPrefix(tempPath, "/")
}
trimmedBasePath := strings.TrimSuffix(basePath, "/")
pathToJoin := rewrittenSegment
trimmedBasePath := strings.TrimSuffix(basePath, "/")
// Prevent duplicated version segments, e.g. when basePath is /v1beta and the rewritten segment also starts with v1beta/.
versionPrefixes := []string{"v1beta", "v1"}
for _, prefix := range versionPrefixes {
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") {
rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/")
break
}
}
finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment)
if err != nil {
return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
// Fall back to plain string concatenation
return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/")
}
return finalPath
}
func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
// This is a stub implementation; no logic is needed yet.
return nil
return nil // stub implementation
}
func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
// This is a stub implementation; no logic is needed yet.
// stub implementation
}
// ==========================================================
// ================ Core engine of "smart routing" ================
// ==========================================================
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
log := ch.logger.WithField("correlation_id", params.CorrelationID)
targetURL, err := url.Parse(params.UpstreamURL)
if err != nil {
log.WithError(err).Error("Failed to parse upstream URL")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
log.WithError(err).Error("Invalid upstream URL")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL"))
return
}
targetURL.Path = c.Request.URL.Path
targetURL.RawQuery = c.Request.URL.RawQuery
initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
if err != nil {
log.WithError(err).Error("Failed to create initial smart request")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
log.WithError(err).Error("Failed to create initial smart stream request")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request"))
return
}
ch.ModifyRequest(initialReq, params.APIKey)
initialReq.Header.Del("Authorization")
resp, err := ch.httpClient.Do(initialReq)
if err != nil {
log.WithError(err).Error("Initial smart request failed")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
log.WithError(err).Error("Initial smart stream request failed")
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed"))
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
defer standardizedResp.Body.Close()
c.Writer.WriteHeader(standardizedResp.StatusCode)
for key, values := range standardizedResp.Header {
for _, value := range values {
@@ -228,45 +245,71 @@ func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartR
}
}
io.Copy(c.Writer, standardizedResp.Body)
params.EventLogger.IsSuccess = false
params.EventLogger.StatusCode = resp.StatusCode
return
}
ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
}
func (ch *GeminiChannel) processStreamAndRetry(
c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
c *gin.Context,
initialRequestHeaders http.Header,
initialReader io.ReadCloser,
params SmartRequestParams,
log *logrus.Entry,
) {
defer initialReader.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
flusher, _ := c.Writer.(http.Flusher)
var accumulatedText strings.Builder
consecutiveRetryCount := 0
currentReader := initialReader
maxRetries := params.MaxRetries
retryDelay := params.RetryDelay
log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
for {
if c.Request.Context().Err() != nil {
log.Info("Client disconnected, stopping stream processing.")
return
}
var interruptionReason string
scanner := bufio.NewScanner(currentReader)
for scanner.Scan() {
if c.Request.Context().Err() != nil {
log.Info("Client disconnected during scan.")
return
}
line := scanner.Text()
if line == "" {
continue
}
fmt.Fprintf(c.Writer, "%s\n\n", line)
flusher.Flush()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var payload models.GeminiSSEPayload
if err := json.Unmarshal([]byte(data), &payload); err != nil {
continue
}
if len(payload.Candidates) > 0 {
candidate := payload.Candidates[0]
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
@@ -285,52 +328,71 @@ func (ch *GeminiChannel) processStreamAndRetry(
}
}
currentReader.Close()
if interruptionReason == "" {
if err := scanner.Err(); err != nil {
log.WithError(err).Warn("Stream scanner encountered an error.")
interruptionReason = "SCANNER_ERROR"
} else {
log.Warn("Stream dropped unexpectedly without a finish reason.")
log.Warn("Stream connection dropped without a finish reason.")
interruptionReason = "CONNECTION_DROP"
}
}
if consecutiveRetryCount >= maxRetries {
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
errData, _ := json.Marshal(map[string]interface{}{
"error": map[string]interface{}{
"code": http.StatusGatewayTimeout,
"status": "DEADLINE_EXCEEDED",
"message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
},
})
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
flusher.Flush()
return
}
consecutiveRetryCount++
params.EventLogger.Retries = consecutiveRetryCount
log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
time.Sleep(retryDelay)
retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
retryBodyBytes, _ := json.Marshal(retryBody)
retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
if err != nil {
log.WithError(err).Error("Failed to create retry request")
continue
}
retryReq.Header = initialRequestHeaders
ch.ModifyRequest(retryReq, params.APIKey)
retryReq.Header.Del("Authorization")
retryResp, err := ch.httpClient.Do(retryReq)
if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
if err != nil {
log.WithError(err).Errorf("Retry request failed.")
} else {
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
if retryResp.Body != nil {
retryResp.Body.Close()
}
}
if err != nil {
log.WithError(err).Error("Retry request failed")
continue
}
if retryResp.StatusCode != http.StatusOK {
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
retryResp.Body.Close()
continue
}
currentReader = retryResp.Body
}
}
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
// buildRetryRequestBody correctly handles context insertion for multi-turn conversations.
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
retryBody := originalBody
// Find the index of the last message with the 'user' role
lastUserIndex := -1
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
if retryBody.Contents[i].Role == "user" {
@@ -338,25 +400,26 @@ func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText st
break
}
}
history := []models.GeminiContent{
{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
}
if lastUserIndex != -1 {
// If a 'user' message was found, insert the history right after it
newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
newContents = append(newContents, history...)
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
retryBody.Contents = newContents
} else {
// If there is no 'user' message (which should not happen in practice), append the history directly
retryBody.Contents = append(retryBody.Contents, history...)
}
return retryBody, nil
}
// ===============================================
// ========= Helper functions (inherited and hardened) =========
// ===============================================
return retryBody
}
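// Worked example (illustrative; literals simplified): for contents
// [user:"Q1", model:"A1", user:"Q2"] and accumulatedText "partial A2",
// lastUserIndex is 2, so the rebuilt slice becomes
// [user:"Q1", model:"A1", user:"Q2", model:"partial A2", user:SmartRetryPrompt]:
// the partial answer is replayed as model output, and SmartRetryPrompt
// (presumably a continue-from-here instruction) becomes the newest user turn.
//
//	req := models.GeminiRequest{Contents: []models.GeminiContent{
//		{Role: "user", Parts: []models.Part{{Text: "Q1"}}},
//		{Role: "model", Parts: []models.Part{{Text: "A1"}}},
//		{Role: "user", Parts: []models.Part{{Text: "Q2"}}},
//	}}
//	out := buildRetryRequestBody(req, "partial A2")
//	// len(out.Contents) == 5; out.Contents[3].Parts[0].Text == "partial A2"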
type googleAPIError struct {
Error struct {
@@ -397,25 +460,28 @@ func truncate(s string, n int) string {
return s
}
// standardizeError
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
log.WithError(err).Error("Failed to read upstream error body")
bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
bodyBytes = []byte("Failed to read upstream error body")
}
resp.Body.Close()
log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit))
var standardizedPayload googleAPIError
// Even if parsing fails, build a standard error structure
if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
standardizedPayload.Error.Code = resp.StatusCode
standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
standardizedPayload.Error.Details = []interface{}{map[string]string{
"@type": "proxy.upstream.error",
"@type": "proxy.upstream.unparsed.error",
"body": truncate(string(bodyBytes), truncateLimit),
}}
}
newBodyBytes, _ := json.Marshal(standardizedPayload)
newResp := &http.Response{
StatusCode: resp.StatusCode,
@@ -425,10 +491,13 @@ func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int
}
newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
newResp.Header.Set("Access-Control-Allow-Origin", "*")
return newResp
}
// errToJSON
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
if c.IsAborted() {
return
}
c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
}

View File

@@ -13,9 +13,10 @@ type Config struct {
Database DatabaseConfig
Server ServerConfig
Log LogConfig
Redis RedisConfig `mapstructure:"redis"`
SessionSecret string `mapstructure:"session_secret"`
EncryptionKey string `mapstructure:"encryption_key"`
Redis RedisConfig `mapstructure:"redis"`
SessionSecret string `mapstructure:"session_secret"`
EncryptionKey string `mapstructure:"encryption_key"`
Repository RepositoryConfig `mapstructure:"repository"`
}
// DatabaseConfig stores database connection information
@@ -28,7 +29,9 @@ type DatabaseConfig struct {
// ServerConfig stores the HTTP server configuration
type ServerConfig struct {
Port string `mapstructure:"port"`
Port string `mapstructure:"port"`
Host string `mapstructure:"host"`
CORSOrigins []string `mapstructure:"cors_origins"`
}
// LogConfig stores logging configuration
@@ -37,25 +40,36 @@ type LogConfig struct {
Format string `mapstructure:"format" json:"format"`
EnableFile bool `mapstructure:"enable_file" json:"enable_file"`
FilePath string `mapstructure:"file_path" json:"file_path"`
// Log rotation settings (optional)
MaxSize int `mapstructure:"max_size"` // MB, default 100
MaxBackups int `mapstructure:"max_backups"` // files to keep, default 7
MaxAge int `mapstructure:"max_age"` // days, default 30
Compress bool `mapstructure:"compress"` // default true
}
type RedisConfig struct {
DSN string `mapstructure:"dsn"`
}
type RepositoryConfig struct {
BasePoolTTLMinutes int `mapstructure:"base_pool_ttl_minutes"`
BasePoolTTIMinutes int `mapstructure:"base_pool_tti_minutes"`
}
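// Illustrative sketch: with the mapstructure tags above plus the env key
// replacer and AutomaticEnv set in LoadConfig below, every key can be
// overridden through the environment (the values here are hypothetical):
//
//	// SERVER_PORT=9001 REPOSITORY_BASE_POOL_TTL_MINUTES=120 ./gemini-balancer
//	cfg, err := config.LoadConfig()
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(cfg.Server.Port) // "9001": env beats config file and defaults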
// LoadConfig loads configuration from file and environment variables
func LoadConfig() (*Config, error) {
// Set the config file name and search paths
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
viper.AddConfigPath("/etc/gemini-balancer/") // for production
// Allow reading from environment variables
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.AutomaticEnv()
// Set defaults
viper.SetDefault("server.port", "8080")
viper.SetDefault("server.port", "9000")
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", "text")
viper.SetDefault("log.enable_file", false)
@@ -67,6 +81,9 @@ func LoadConfig() (*Config, error) {
viper.SetDefault("database.conn_max_lifetime", "1h")
viper.SetDefault("encryption_key", "")
viper.SetDefault("repository.base_pool_ttl_minutes", 60)
viper.SetDefault("repository.base_pool_tti_minutes", 10)
// Read the config file
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {

View File

@@ -2,14 +2,11 @@
package container
import (
"fmt"
"gemini-balancer/internal/app"
"gemini-balancer/internal/channel"
"gemini-balancer/internal/config"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/db"
"gemini-balancer/internal/db/dialect"
"gemini-balancer/internal/db/migrations"
"gemini-balancer/internal/domain/proxy"
"gemini-balancer/internal/domain/upstream"
"gemini-balancer/internal/handlers"
@@ -21,13 +18,10 @@ import (
"gemini-balancer/internal/service"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"gemini-balancer/internal/syncer"
"gemini-balancer/internal/task"
"gemini-balancer/internal/webhandlers"
"github.com/sirupsen/logrus"
"go.uber.org/dig"
"gorm.io/gorm"
)
func BuildContainer() (*dig.Container, error) {
@@ -35,20 +29,9 @@ func BuildContainer() (*dig.Container, error) {
// =========== Phase 1: Infrastructure layer ===========
container.Provide(config.LoadConfig)
container.Provide(func(cfg *config.Config, logger *logrus.Logger) (*gorm.DB, dialect.DialectAdapter, error) {
gormDB, adapter, err := db.NewDB(cfg, logger)
if err != nil {
return nil, nil, err
}
// Run migrations
if err := migrations.RunVersionedMigrations(gormDB, cfg, logger); err != nil {
return nil, nil, fmt.Errorf("failed to run versioned migrations: %w", err)
}
return gormDB, adapter, nil
})
container.Provide(logging.NewLoggerWithWebSocket)
container.Provide(db.NewDBWithMigrations)
container.Provide(store.NewStore)
container.Provide(logging.NewLogger)
container.Provide(crypto.NewService)
container.Provide(repository.NewAuthTokenRepository)
container.Provide(repository.NewGroupRepository)
@@ -85,14 +68,10 @@ func BuildContainer() (*dig.Container, error) {
// --- Syncer & Loader for GroupManager ---
container.Provide(service.NewGroupManagerLoader)
// Configure the Syncer for GroupManager
container.Provide(func(loader syncer.LoaderFunc[service.GroupManagerCacheData], store store.Store, logger *logrus.Logger) (*syncer.CacheSyncer[service.GroupManagerCacheData], error) {
const groupUpdateChannel = "groups:cache_invalidation"
return syncer.NewCacheSyncer(loader, store, groupUpdateChannel)
})
container.Provide(service.NewGroupManagerSyncer)
// =========== Phase 3: Adapters & handlers layer ===========
// Provide dependencies for Channel (Logger and the *models.SystemSettings data socket)
container.Provide(channel.NewGeminiChannel)
container.Provide(func(ch *channel.GeminiChannel) channel.ChannelProxy { return ch })

View File

@@ -2,8 +2,10 @@
package db
import (
"fmt"
"gemini-balancer/internal/config"
"gemini-balancer/internal/db/dialect"
"gemini-balancer/internal/db/migrations"
stdlog "log"
"os"
"path/filepath"
@@ -86,3 +88,16 @@ func NewDB(cfg *config.Config, appLogger *logrus.Logger) (*gorm.DB, dialect.Dial
Logger.Info("Database connection established successfully.")
return db, adapter, nil
}
func NewDBWithMigrations(cfg *config.Config, logger *logrus.Logger) (*gorm.DB, dialect.DialectAdapter, error) {
gormDB, adapter, err := NewDB(cfg, logger)
if err != nil {
return nil, nil, err
}
if err := migrations.RunVersionedMigrations(gormDB, cfg, logger); err != nil {
return nil, nil, fmt.Errorf("failed to run versioned migrations: %w", err)
}
return gormDB, adapter, nil
}
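// Illustrative sketch of how the container consumes this constructor: dig
// registers both return values from a single Provide call (mirroring what
// BuildContainer does), so consumers can depend on either one, and the
// migrations are guaranteed to have run before any consumer fires.
//
//	c := dig.New()
//	_ = c.Provide(config.LoadConfig)
//	_ = c.Provide(logging.NewLoggerWithWebSocket)
//	_ = c.Provide(db.NewDBWithMigrations)
//	err := c.Invoke(func(gdb *gorm.DB, adapter dialect.DialectAdapter) {
//		// safe to use gdb here: NewDBWithMigrations already migrated
//	})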

View File

@@ -2,6 +2,7 @@
package proxy
import (
"context"
"encoding/json"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -49,7 +50,6 @@ func (h *handler) registerRoutes(rg *gin.RouterGroup) {
}
}
// --- Request DTOs ---
type CreateProxyConfigRequest struct {
Address string `json:"address" binding:"required"`
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
@@ -64,12 +64,10 @@ type UpdateProxyConfigRequest struct {
Description *string `json:"description"`
}
// Request body for a single check (aligned with the frontend JS)
type CheckSingleProxyRequest struct {
Proxy string `json:"proxy" binding:"required"`
}
// Request body for batch checks
type CheckAllProxiesRequest struct {
Proxies []string `json:"proxies" binding:"required"`
}
@@ -84,7 +82,7 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
}
if req.Status == "" {
req.Status = "active" // 默认状态
req.Status = "active"
}
proxyConfig := models.ProxyConfig{
@@ -98,7 +96,6 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
response.Error(c, errors.ParseDBError(err))
return
}
// After a write, publish the event and invalidate the cache
h.publishAndInvalidate(proxyConfig.ID, "created")
response.Created(c, proxyConfig)
}
@@ -199,17 +196,16 @@ func (h *handler) DeleteProxyConfig(c *gin.Context) {
response.NoContent(c)
}
// publishAndInvalidate unifies event publishing and cache invalidation
func (h *handler) publishAndInvalidate(proxyID uint, action string) {
go h.manager.invalidate()
go func() {
ctx := context.Background()
event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action}
eventData, _ := json.Marshal(event)
_ = h.store.Publish(models.TopicProxyStatusChanged, eventData)
_ = h.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
}()
}
// New handler methods and DTOs
type SyncProxiesRequest struct {
Proxies []string `json:"proxies"`
}
@@ -220,14 +216,12 @@ func (h *handler) SyncProxies(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.manager.SyncProxiesInBackground(req.Proxies)
taskStatus, err := h.manager.SyncProxiesInBackground(c.Request.Context(), req.Proxies)
if err != nil {
if errors.Is(err, ErrTaskConflict) {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
} else {
response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
}
return
@@ -262,7 +256,7 @@ func (h *handler) CheckAllProxies(c *gin.Context) {
concurrency := cfg.ProxyCheckConcurrency
if concurrency <= 0 {
concurrency = 5 // fall back to a safe default when the configured value is invalid
concurrency = 5
}
results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency)
response.Success(c, results)

View File

@@ -2,14 +2,13 @@
package proxy
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"gemini-balancer/internal/syncer"
"gemini-balancer/internal/task"
"context"
"net"
"net/http"
"net/url"
@@ -25,7 +24,7 @@ import (
const (
TaskTypeProxySync = "proxy_sync"
proxyChunkSize = 200 // batch size for proxy sync
proxyChunkSize = 200
)
type ProxyCheckResult struct {
@@ -35,13 +34,11 @@ type ProxyCheckResult struct {
ErrorMessage string `json:"error_message"`
}
// managerCacheData
type managerCacheData struct {
ActiveProxies []*models.ProxyConfig
ProxiesByID map[uint]*models.ProxyConfig
}
// manager struct
type manager struct {
db *gorm.DB
syncer *syncer.CacheSyncer[managerCacheData]
@@ -80,21 +77,21 @@ func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskR
}
}
func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) {
func (m *manager) SyncProxiesInBackground(ctx context.Context, proxyStrings []string) (*task.Status, error) {
resourceID := "global_proxy_sync"
taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
taskStatus, err := m.task.StartTask(ctx, 0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
if err != nil {
return nil, ErrTaskConflict
}
go m.runProxySyncTask(taskStatus.ID, proxyStrings)
go m.runProxySyncTask(context.Background(), taskStatus.ID, proxyStrings)
return taskStatus, nil
}
func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
func (m *manager) runProxySyncTask(ctx context.Context, taskID string, finalProxyStrings []string) {
resourceID := "global_proxy_sync"
var allProxies []models.ProxyConfig
if err := m.db.Find(&allProxies).Error; err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
return
}
currentProxyMap := make(map[string]uint)
@@ -125,19 +122,19 @@ func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
}
if len(idsToDelete) > 0 {
if err := m.bulkDeleteByIDs(idsToDelete); err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
return
}
}
if len(proxiesToAdd) > 0 {
if err := m.bulkAdd(proxiesToAdd); err != nil {
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
return
}
}
result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)}
m.task.EndTaskByID(taskID, resourceID, result, nil)
m.publishChangeEvent("proxies_synced")
m.task.EndTaskByID(ctx, taskID, resourceID, result, nil)
m.publishChangeEvent(ctx, "proxies_synced")
go m.invalidate()
}
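// Illustrative sketch of the set-diff performed above (a hypothetical
// helper extracting the same logic): the desired state is compared against
// the current rows so that only the delta touches the database.
func diffProxies(current map[string]uint, desired []string) (toAdd []string, toDelete []uint) {
	want := make(map[string]struct{}, len(desired))
	for _, p := range desired {
		if _, seen := want[p]; seen {
			continue // ignore duplicate entries in the desired list
		}
		want[p] = struct{}{}
		if _, ok := current[p]; !ok {
			toAdd = append(toAdd, p) // desired but missing from the DB
		}
	}
	for addr, id := range current {
		if _, ok := want[addr]; !ok {
			toDelete = append(toDelete, id) // in the DB but no longer desired
		}
	}
	return toAdd, toDelete
}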
@@ -184,14 +181,15 @@ func (m *manager) bulkDeleteByIDs(ids []uint) error {
}
return nil
}
func (m *manager) bulkAdd(proxies []models.ProxyConfig) error {
return m.db.CreateInBatches(proxies, proxyChunkSize).Error
}
func (m *manager) publishChangeEvent(reason string) {
func (m *manager) publishChangeEvent(ctx context.Context, reason string) {
event := models.ProxyStatusChangedEvent{Action: reason}
eventData, _ := json.Marshal(event)
_ = m.store.Publish(models.TopicProxyStatusChanged, eventData)
_ = m.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
}
func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {
@@ -313,3 +311,8 @@ func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout t
defer resp.Body.Close()
return true
}
type Manager interface {
AssignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error)
// ... other methods exposed to external services
}

View File

@@ -20,7 +20,7 @@ type Module struct {
func NewModule(gormDB *gorm.DB, store store.Store, settingsManager *settings.SettingsManager, taskReporter task.Reporter, logger *logrus.Logger) (*Module, error) {
loader := newManagerLoader(gormDB)
cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation")
cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation", logger)
if err != nil {
return nil, err
}

View File

@@ -15,6 +15,7 @@ import (
type APIError struct {
HTTPStatus int
Code string
Status string `json:"status,omitempty"`
Message string
}
@@ -36,6 +37,7 @@ var (
ErrForbidden = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"}
ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"}
ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"}
ErrGatewayTimeout = &APIError{HTTPStatus: http.StatusGatewayTimeout, Code: "GATEWAY_TIMEOUT", Message: "Gateway timeout"}
ErrNoActiveKeys = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_ACTIVE_KEYS", Message: "No active API keys available for this group"}
ErrMaxRetriesExceeded = &APIError{HTTPStatus: http.StatusBadGateway, Code: "MAX_RETRIES_EXCEEDED", Message: "Request failed after maximum retries"}
ErrNoKeysAvailable = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_KEYS_AVAILABLE", Message: "No API keys available to process the request"}
@@ -44,6 +46,7 @@ var (
ErrGroupNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "GROUP_NOT_FOUND", Message: "The specified group was not found."}
ErrPermissionDenied = &APIError{HTTPStatus: http.StatusForbidden, Code: "PERMISSION_DENIED", Message: "Permission denied for this operation."}
ErrConfigurationError = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "CONFIGURATION_ERROR", Message: "A configuration error prevents this request from being processed."}
ErrProxyNotAvailable = &APIError{HTTPStatus: http.StatusNotFound, Code: "PROXY_ERROR", Message: "Required proxy is not available for this request."}
ErrStateConflictMasterRevoked = &APIError{HTTPStatus: http.StatusConflict, Code: "STATE_CONFLICT_MASTER_REVOKED", Message: "Cannot perform this operation on a revoked key."}
ErrNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"}
@@ -60,11 +63,13 @@ func NewAPIError(base *APIError, message string) *APIError {
}
// NewAPIErrorWithUpstream creates a new APIError specifically for wrapping raw upstream errors.
func NewAPIErrorWithUpstream(statusCode int, code string, upstreamMessage string) *APIError {
func NewAPIErrorWithUpstream(statusCode int, code string, bodyBytes []byte) *APIError {
msg, status := ParseUpstreamError(bodyBytes)
return &APIError{
HTTPStatus: statusCode,
Code: code,
Message: upstreamMessage,
Message: msg,
Status: status,
}
}

View File

@@ -37,6 +37,7 @@ var permanentErrorSubstrings = []string{
"permission_denied", // Catches the 'status' field in Google's JSON error, e.g., "status": "PERMISSION_DENIED".
"service_disabled", // Catches the 'reason' field for disabled APIs, e.g., "reason": "SERVICE_DISABLED".
"api has not been used",
"reported as leaked", // Leaked
}
// --- 2. Temporary Errors ---
@@ -44,8 +45,10 @@ var permanentErrorSubstrings = []string{
// Action: Increment consecutive error count, potentially disable the key.
var temporaryErrorSubstrings = []string{
"quota",
"Quota exceeded",
"limit reached",
"insufficient",
"request limit",
"billing",
"exceeded",
"too many requests",
@@ -71,6 +74,25 @@ var clientNetworkErrorSubstrings = []string{
"broken pipe",
"use of closed network connection",
"request canceled",
"invalid query parameters", // 参数解析错误,归类为客户端错误
}
// --- 5. Retryable Network/Gateway Errors ---
// Errors that indicate temporary network or gateway issues, should retry with same or different key.
// Action: Retry the request.
var retryableNetworkErrorSubstrings = []string{
"bad gateway",
"service unavailable",
"gateway timeout",
"connection refused",
"connection reset",
"stream transmission interrupted", // ✅ 新增:流式传输中断
"failed to establish stream", // ✅ 新增:流式连接建立失败
"upstream connect error",
"no healthy upstream",
"502",
"503",
"504",
}
// IsPermanentUpstreamError checks if an upstream error indicates the key is permanently invalid.
@@ -96,6 +118,11 @@ func IsClientNetworkError(err error) bool {
return containsSubstring(err.Error(), clientNetworkErrorSubstrings)
}
// IsRetryableNetworkError checks if an error is a temporary network/gateway issue.
func IsRetryableNetworkError(msg string) bool {
return containsSubstring(msg, retryableNetworkErrorSubstrings)
}
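// Illustrative dispatch using the classifiers in this file (the policy
// shown is a sketch, not the proxy's actual retry loop):
//
//	if err != nil && IsClientNetworkError(err) {
//		return // the client went away; nothing is worth retrying
//	}
//	if IsRetryableNetworkError(upstreamMsg) {
//		// 502/503/504, connection resets, interrupted streams:
//		// retry with the same or another key
//	}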
// containsSubstring is a helper function to avoid code repetition.
func containsSubstring(s string, substrings []string) bool {
if s == "" {

View File

@@ -2,7 +2,6 @@ package errors
import (
"encoding/json"
"strings"
)
const (
@@ -10,67 +9,37 @@ const (
maxErrorBodyLength = 2048
)
// standardErrorResponse matches formats like: {"error": {"message": "..."}}
type standardErrorResponse struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
// vendorErrorResponse matches formats like: {"error_msg": "..."}
type vendorErrorResponse struct {
ErrorMsg string `json:"error_msg"`
}
// simpleErrorResponse matches formats like: {"error": "..."}
type simpleErrorResponse struct {
Error string `json:"error"`
}
// rootMessageErrorResponse matches formats like: {"message": "..."}
type rootMessageErrorResponse struct {
type upstreamErrorDetail struct {
Message string `json:"message"`
Status string `json:"status"`
}
type upstreamErrorPayload struct {
Error upstreamErrorDetail `json:"error"`
}
// ParseUpstreamError attempts to parse a structured error message from an upstream response body
func ParseUpstreamError(body []byte) string {
// 1. Attempt to parse the standard OpenAI/Gemini format.
var stdErr standardErrorResponse
if err := json.Unmarshal(body, &stdErr); err == nil {
if msg := strings.TrimSpace(stdErr.Error.Message); msg != "" {
return truncateString(msg, maxErrorBodyLength)
}
func ParseUpstreamError(body []byte) (message string, status string) {
if len(body) == 0 {
return "Upstream returned an empty error body", ""
}
// 2. Attempt to parse vendor-specific format (e.g., Baidu).
var vendorErr vendorErrorResponse
if err := json.Unmarshal(body, &vendorErr); err == nil {
if msg := strings.TrimSpace(vendorErr.ErrorMsg); msg != "" {
return truncateString(msg, maxErrorBodyLength)
}
// Priority 1: try to parse the `[{"error": {...}}]` array format returned by OpenAI-compatible endpoints
var arrayPayload []upstreamErrorPayload
if err := json.Unmarshal(body, &arrayPayload); err == nil && len(arrayPayload) > 0 {
detail := arrayPayload[0].Error
return truncateString(detail.Message, maxErrorBodyLength), detail.Status
}
// 3. Attempt to parse simple error format.
var simpleErr simpleErrorResponse
if err := json.Unmarshal(body, &simpleErr); err == nil {
if msg := strings.TrimSpace(simpleErr.Error); msg != "" {
return truncateString(msg, maxErrorBodyLength)
}
// Priority 2: try to parse the single-object `{"error": {...}}` format the native Gemini API may return
var singlePayload upstreamErrorPayload
if err := json.Unmarshal(body, &singlePayload); err == nil && singlePayload.Error.Message != "" {
detail := singlePayload.Error
return truncateString(detail.Message, maxErrorBodyLength), detail.Status
}
// 4. Attempt to parse root-level message format.
var rootMsgErr rootMessageErrorResponse
if err := json.Unmarshal(body, &rootMsgErr); err == nil {
if msg := strings.TrimSpace(rootMsgErr.Message); msg != "" {
return truncateString(msg, maxErrorBodyLength)
}
}
// 5. Graceful Degradation: If all parsing fails, return the raw (but safe) body.
return truncateString(string(body), maxErrorBodyLength)
// Final fallback: unrecognized JSON or plain-text errors
return truncateString(string(body), maxErrorBodyLength), ""
}
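// Illustrative inputs for the priority order above:
//
//	msg, status := ParseUpstreamError([]byte(`[{"error":{"message":"quota exceeded","status":"RESOURCE_EXHAUSTED"}}]`))
//	// msg == "quota exceeded", status == "RESOURCE_EXHAUSTED" (array form, priority 1)
//
//	msg, status = ParseUpstreamError([]byte(`{"error":{"message":"API key not valid","status":"INVALID_ARGUMENT"}}`))
//	// msg == "API key not valid", status == "INVALID_ARGUMENT" (single object, priority 2)
//
//	msg, status = ParseUpstreamError([]byte(`upstream exploded`))
//	// msg == "upstream exploded", status == "" (raw fallback, truncated to maxErrorBodyLength)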
// truncateString ensures a string does not exceed a maximum length.
// truncateString remains unchanged.
func truncateString(s string, maxLength int) string {
if len(s) > maxLength {
return s[:maxLength]

View File

@@ -31,11 +31,10 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImpo
}
}
// DTOs for API requests
type BulkAddKeysToGroupRequest struct {
KeyGroupID uint `json:"key_group_id" binding:"required"`
Keys string `json:"keys" binding:"required"`
ValidateOnImport bool `json:"validate_on_import"` // OmitEmpty/default is false
ValidateOnImport bool `json:"validate_on_import"`
}
type BulkUnlinkKeysFromGroupRequest struct {
@@ -72,11 +71,11 @@ type BulkTestKeysForGroupRequest struct {
}
type BulkActionFilter struct {
Status []string `json:"status"` // Changed to slice to accept multiple statuses
Status []string `json:"status"`
}
type BulkActionRequest struct {
Action string `json:"action" binding:"required,oneof=revalidate set_status delete"`
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For 'set_status' action
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"`
Filter BulkActionFilter `json:"filter" binding:"required"`
}
@@ -89,7 +88,7 @@ func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport)
taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -104,7 +103,7 @@ func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys)
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -119,7 +118,7 @@ func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys)
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -134,7 +133,7 @@ func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys)
taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -148,7 +147,7 @@ func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys)
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -172,7 +171,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
}
}
if len(ids) > 0 {
keys, err := h.apiKeyService.GetKeysByIds(ids)
keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids)
if err != nil {
response.Error(c, &errors.APIError{
HTTPStatus: http.StatusInternalServerError,
@@ -191,7 +190,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
if params.PageSize <= 0 {
params.PageSize = 20
}
result, err := h.apiKeyService.ListAPIKeys(&params)
result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return
@@ -201,19 +200,16 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
// ListKeysForGroup handles the GET /keygroups/:id/keys request.
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
// 1. Manually handle the path parameter.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return
}
// 2. Bind query parameters using the correctly tagged struct.
var params models.APIKeyQueryParams
if err := c.ShouldBindQuery(&params); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
return
}
// 3. Set server-side defaults and the path parameter.
if params.Page <= 0 {
params.Page = 1
}
@@ -221,15 +217,11 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
params.PageSize = 20
}
params.KeyGroupID = uint(groupID)
// 4. Call the service layer.
paginatedResult, err := h.apiKeyService.ListAPIKeys(&params)
paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), &params)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return
}
// 5. [THE FIX] Return a successful response using the standard `response.Success`
// and a gin.H map, as confirmed to exist in your project.
response.Success(c, gin.H{
"items": paginatedResult.Items,
"total": paginatedResult.Total,
@@ -239,20 +231,17 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
}
func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
// Group ID is now correctly sourced from the URL path.
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
return
}
// The request body is now simpler, only needing the keys.
var req BulkTestKeysForGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// Call the same underlying service, but with unambiguous context.
taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys)
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
return
@@ -267,7 +256,6 @@ func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
}
// UpdateGroupAPIKeyMapping handles updating a key's status within a specific group.
// Route: PUT /keygroups/:id/apikeys/:keyId
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -284,8 +272,7 @@ func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// Directly use the service to handle the logic
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status)
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(groupID), uint(keyID), req.Status)
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -305,7 +292,7 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
return
}
if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil {
if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil {
response.Error(c, errors.ParseDBError(err))
return
}
@@ -313,7 +300,6 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
}
// RestoreKeysInGroup restores the specified keys within a group
// POST /keygroups/:id/apikeys/restore
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -325,7 +311,7 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs)
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(groupID), req.KeyIDs)
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -339,14 +325,13 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
}
// RestoreAllBannedInGroup restores all banned keys in a group with one click
// POST /keygroups/:id/apikeys/restore-all-banned
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID))
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID))
if err != nil {
var apiErr *errors.APIError
if errors.As(err, &apiErr) {
@@ -360,48 +345,39 @@ func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
}
// HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
// Route: POST /keygroups/:id/bulk-actions
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
// 1. Parse GroupID from URL
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
// 2. Bind the JSON payload to our new DTO
var req BulkActionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
return
}
// 3. Central logic: based on the action, call the appropriate service method.
var task *task.Status
var apiErr *errors.APIError
switch req.Action {
case "revalidate":
// Assume keyValidationService has a method that accepts a filter
task, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status)
task, err = h.keyValidationService.StartTestKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
case "set_status":
if req.NewStatus == "" {
apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action")
break
}
// Assume apiKeyService has a method to update status by filter
targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to your model's type
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus)
targetStatus := models.APIKeyStatus(req.NewStatus)
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status, targetStatus)
case "delete":
// Assume keyImportService has a method to unlink by filter
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status)
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
default:
apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action)
}
// 4. Handle errors from the switch block
if apiErr != nil {
response.Error(c, apiErr)
return
}
if err != nil {
// Attempt to parse it as a known APIError, otherwise, wrap it.
var parsedErr *errors.APIError
if errors.As(err, &parsedErr) {
response.Error(c, parsedErr)
@@ -410,21 +386,18 @@ func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
}
return
}
// 5. Return the task status on success
response.Success(c, task)
}
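// Illustrative request bodies for POST /keygroups/:id/bulk-actions, using
// the action and status values permitted by the DTO bindings above:
//
//	{"action": "revalidate", "filter": {"status": ["cooldown", "banned"]}}
//	{"action": "set_status", "new_status": "disabled", "filter": {"status": ["active"]}}
//	{"action": "delete", "filter": {"status": ["banned"]}}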
// ExportKeysForGroup handles requests to export all keys for a group based on status filters.
// Route: GET /keygroups/:id/apikeys/export
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
return
}
// Use QueryArray to correctly parse `status[]=active&status[]=cooldown`
statuses := c.QueryArray("status")
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses)
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(groupID), statuses)
if err != nil {
response.Error(c, errors.ParseDBError(err))
return

View File

@@ -30,7 +30,7 @@ func (h *DashboardHandler) GetOverview(c *gin.Context) {
c.JSON(http.StatusOK, stats)
}
// GetChart returns the dashboard chart data
// GetChart
func (h *DashboardHandler) GetChart(c *gin.Context) {
var groupID *uint
if groupIDStr := c.Query("groupId"); groupIDStr != "" {
@@ -40,7 +40,7 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
}
}
chartData, err := h.queryService.QueryHistoricalChart(groupID)
chartData, err := h.queryService.QueryHistoricalChart(c.Request.Context(), groupID)
if err != nil {
apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr)
@@ -49,10 +49,10 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
c.JSON(http.StatusOK, chartData)
}
// GetRequestStats handles requests for the period call overview
// GetRequestStats
func (h *DashboardHandler) GetRequestStats(c *gin.Context) {
period := c.Param("period") // 从 URL 路径中获取 period
stats, err := h.queryService.GetRequestStatsForPeriod(period)
period := c.Param("period")
stats, err := h.queryService.GetRequestStatsForPeriod(c.Request.Context(), period)
if err != nil {
apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error())
c.JSON(apiErr.HTTPStatus, apiErr)

View File

@@ -2,6 +2,7 @@
package handlers
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/errors"
@@ -31,7 +32,6 @@ func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.Das
}
}
// DTOs & 辅助函数
func isValidGroupName(name string) bool {
if name == "" {
return false
@@ -40,7 +40,6 @@ func isValidGroupName(name string) bool {
return match
}
// KeyGroupOperationalSettings defines the shared operational settings for a key group.
type KeyGroupOperationalSettings struct {
EnableKeyCheck *bool `json:"enable_key_check"`
KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"`
@@ -52,7 +51,6 @@ type KeyGroupOperationalSettings struct {
MaxRetries *int `json:"max_retries"`
EnableSmartGateway *bool `json:"enable_smart_gateway"`
}
type CreateKeyGroupRequest struct {
Name string `json:"name" binding:"required"`
DisplayName string `json:"display_name"`
@@ -60,11 +58,8 @@ type CreateKeyGroupRequest struct {
PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy bool `json:"enable_proxy"`
ChannelType string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings
}
type UpdateKeyGroupRequest struct {
Name *string `json:"name"`
DisplayName *string `json:"display_name"`
@@ -72,15 +67,10 @@ type UpdateKeyGroupRequest struct {
PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
EnableProxy *bool `json:"enable_proxy"`
ChannelType *string `json:"channel_type"`
// Embed shared operational settings
KeyGroupOperationalSettings
// M:N associations
AllowedUpstreams []string `json:"allowed_upstreams"`
AllowedModels []string `json:"allowed_models"`
}
type KeyGroupResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
@@ -96,36 +86,30 @@ type KeyGroupResponse struct {
AllowedModels []string `json:"allowed_models"`
AllowedUpstreams []string `json:"allowed_upstreams"`
}
// [NEW] Define the detailed response structure for a single group.
type KeyGroupDetailsResponse struct {
KeyGroupResponse
Settings *models.GroupSettings `json:"settings,omitempty"`
RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
}
// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
modelNames := make([]string, 0, len(mappings))
for _, mapping := range mappings {
if mapping != nil { // Safety check
if mapping != nil {
modelNames = append(modelNames, mapping.ModelName)
}
}
return modelNames
}
// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
urls := make([]string, 0, len(upstreams))
for _, upstream := range upstreams {
if upstream != nil { // Safety check
if upstream != nil {
urls = append(urls, upstream.URL)
}
}
return urls
}
func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
return KeyGroupResponse{
ID: group.ID,
@@ -139,13 +123,10 @@ func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount i
CreatedAt: group.CreatedAt,
UpdatedAt: group.UpdatedAt,
Order: group.Order,
AllowedModels: transformModelsToStrings(group.AllowedModels), // Call the new helper
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper
AllowedModels: transformModelsToStrings(group.AllowedModels),
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams),
}
}
// packGroupSettings is a helper to convert request-level operational settings
// into the model-level settings struct.
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
return &models.KeyGroupSettings{
EnableKeyCheck: settings.EnableKeyCheck,
@@ -159,7 +140,6 @@ func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSet
EnableSmartGateway: settings.EnableSmartGateway,
}
}
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -171,7 +151,6 @@ func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup,
}
return group, nil
}
func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
if req.Name != nil {
group.Name = *req.Name
@@ -197,9 +176,10 @@ func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGrou
// publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event.
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
go func() {
ctx := context.Background()
event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
eventData, _ := json.Marshal(event)
h.store.Publish(models.TopicKeyStatusChanged, eventData)
_ = h.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
}()
}
@@ -216,7 +196,6 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
return
}
// The core logic remains, as it's specific to creation.
p := bluemonday.StripTagsPolicy()
sanitizedDisplayName := p.Sanitize(req.DisplayName)
sanitizedDescription := p.Sanitize(req.Description)
@@ -244,11 +223,9 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
}
// A unified handler covers both cases:
// 1. GET /keygroups - returns the list of all groups
// 2. GET /keygroups/:id - returns the single group with the given ID
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
// Case 1: Get a single group
if idStr := c.Param("id"); idStr != "" {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -265,7 +242,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, detailedResponse)
return
}
// Case 2: Get all groups
allGroups := h.groupManager.GetAllGroups()
responses := make([]KeyGroupResponse, 0, len(allGroups))
for _, group := range allGroups {
@@ -275,7 +251,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
response.Success(c, responses)
}
// UpdateKeyGroup
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -304,7 +279,6 @@ func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
}
// DeleteKeyGroup
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
@@ -320,14 +294,14 @@ func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)})
}
// GetKeyGroupStats
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
group, apiErr := h.getGroupFromContext(c)
if apiErr != nil {
response.Error(c, apiErr)
return
}
stats, err := h.queryService.GetGroupStats(group.ID)
stats, err := h.queryService.GetGroupStats(c.Request.Context(), group.ID)
if err != nil {
response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
return
@@ -350,7 +324,6 @@ func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
}
// Update group ordering
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
var payload []service.UpdateOrderPayload
if err := c.ShouldBindJSON(&payload); err != nil {

View File

@@ -5,7 +5,9 @@ import (
"gemini-balancer/internal/errors"
"gemini-balancer/internal/response"
"gemini-balancer/internal/service"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
@@ -19,22 +21,81 @@ func NewLogHandler(logService *service.LogService) *LogHandler {
}
func (h *LogHandler) GetLogs(c *gin.Context) {
// Call the new service function, which returns the log list and total count
logs, total, err := h.logService.GetLogs(c)
queryParams := make(map[string]string)
for key, values := range c.Request.URL.Query() {
if len(values) > 0 {
queryParams[key] = values[0]
}
}
params, err := service.ParseLogQueryParams(queryParams)
if err != nil {
response.Error(c, errors.ErrBadRequest)
return
}
logs, total, err := h.logService.GetLogs(c.Request.Context(), params)
if err != nil {
response.Error(c, errors.ErrDatabase)
return
}
response.Success(c, gin.H{
"items": logs,
"total": total,
"page": params.Page,
"page_size": params.PageSize,
})
}
// DeleteLogs deletes the selected logs. DELETE /admin/logs?ids=1,2,3
func (h *LogHandler) DeleteLogs(c *gin.Context) {
idsStr := c.Query("ids")
if idsStr == "" {
response.Error(c, errors.ErrBadRequest)
return
}
var ids []uint
for _, idStr := range strings.Split(idsStr, ",") {
if id, err := strconv.ParseUint(strings.TrimSpace(idStr), 10, 32); err == nil {
ids = append(ids, uint(id))
}
}
if len(ids) == 0 {
response.Error(c, errors.ErrBadRequest)
return
}
if err := h.logService.DeleteLogs(c.Request.Context(), ids); err != nil {
response.Error(c, errors.ErrDatabase)
return
}
response.Success(c, gin.H{"deleted": len(ids)})
}
// DeleteAllLogs deletes all logs. DELETE /admin/logs/all
func (h *LogHandler) DeleteAllLogs(c *gin.Context) {
if err := h.logService.DeleteAllLogs(c.Request.Context()); err != nil {
response.Error(c, errors.ErrDatabase)
return
}
response.Success(c, gin.H{"message": "all logs deleted"})
}
// DeleteOldLogs deletes logs older than the given number of days. DELETE /admin/logs/old?days=30
func (h *LogHandler) DeleteOldLogs(c *gin.Context) {
daysStr := c.Query("days")
days, err := strconv.Atoi(daysStr)
if err != nil || days <= 0 {
response.Error(c, errors.ErrBadRequest)
return
}
deleted, err := h.logService.DeleteOldLogs(c.Request.Context(), days)
if err != nil {
response.Error(c, errors.ErrDatabase)
return
}
// Parse pagination parameters for the response body
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
// Use the standard paginated response structure
response.Success(c, gin.H{
"items": logs,
"total": total,
"page": page,
"page_size": pageSize,
})
response.Success(c, gin.H{"deleted": deleted, "days": days})
}
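// Illustrative route wiring implied by the handler comments above (the
// actual registration lives in the router setup; the group path is assumed):
//
//	admin.GET("/logs", logHandler.GetLogs)              // ?page=1&page_size=20
//	admin.DELETE("/logs", logHandler.DeleteLogs)        // ?ids=1,2,3
//	admin.DELETE("/logs/all", logHandler.DeleteAllLogs)
//	admin.DELETE("/logs/old", logHandler.DeleteOldLogs) // ?days=30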

File diff suppressed because it is too large

View File

@@ -7,6 +7,7 @@ import (
"gemini-balancer/internal/settings"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
type SettingHandler struct {
@@ -23,16 +24,35 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
func (h *SettingHandler) UpdateSettings(c *gin.Context) {
var newSettingsMap map[string]interface{}
if err := c.ShouldBindJSON(&newSettingsMap); err != nil {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
logrus.WithError(err).Error("Failed to bind JSON in UpdateSettings")
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid JSON: "+err.Error()))
return
}
if err := h.settingsManager.UpdateSettings(newSettingsMap); err != nil {
// TODO: return a more specific error based on the error type
logrus.Debugf("Received settings update: %+v", newSettingsMap)
validKeys := make(map[string]interface{})
for key, value := range newSettingsMap {
if _, exists := h.settingsManager.IsValidKey(key); exists {
validKeys[key] = value
} else {
logrus.Warnf("Invalid key received: %s", key)
}
}
if len(validKeys) == 0 {
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, "No valid settings keys provided"))
return
}
logrus.Debugf("Valid keys to update: %+v", validKeys)
if err := h.settingsManager.UpdateSettings(validKeys); err != nil {
logrus.WithError(err).Error("Failed to update settings")
response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
return
}
response.Success(c, gin.H{"message": "Settings update request processed successfully."})
response.Success(c, gin.H{"message": "Settings updated successfully"})
}
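// Illustrative sketch of the whitelist pass above (the setting names are
// hypothetical; IsValidKey's two-value return is inferred from its use here):
//
//	input := map[string]interface{}{"log_level": "debug", "not_a_setting": 1}
//	valid := make(map[string]interface{})
//	for k, v := range input {
//		if _, ok := settingsManager.IsValidKey(k); ok {
//			valid[k] = v // known key: keep it
//		}
//	}
//	// valid == {"log_level": "debug"}; "not_a_setting" is warned about and dropped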
// ResetSettingsToDefaults resets all settings to their default values

View File

@@ -33,7 +33,7 @@ func (h *TaskHandler) GetTaskStatus(c *gin.Context) {
return
}
taskStatus, err := h.taskService.GetStatus(taskID)
taskStatus, err := h.taskService.GetStatus(c.Request.Context(), taskID)
if err != nil {
// TODO: handle specific error types from the service layer with finer granularity
response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error()))

View File

@@ -0,0 +1,68 @@
// Filename: internal/handlers/websocket_handler.go
package handlers
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
)
type connWrapper struct {
conn *websocket.Conn
mu sync.Mutex
}
type WebSocketHandler struct {
logger *logrus.Logger
clients sync.Map
upgrader websocket.Upgrader
}
func NewWebSocketHandler(logger *logrus.Logger) *WebSocketHandler {
return &WebSocketHandler{
logger: logger,
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
},
}
}
func (h *WebSocketHandler) HandleSystemLogs(c *gin.Context) {
conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
h.logger.WithError(err).Error("Failed to upgrade websocket")
return
}
defer conn.Close()
clientID := time.Now().UnixNano()
h.clients.Store(clientID, &connWrapper{conn: conn})
defer h.clients.Delete(clientID)
for {
if _, _, err := conn.ReadMessage(); err != nil {
break
}
}
}
func (h *WebSocketHandler) BroadcastLog(entry *logrus.Entry) {
msg := map[string]interface{}{
"timestamp": entry.Time.Format(time.RFC3339),
"level": entry.Level.String(),
"message": entry.Message,
"fields": entry.Data,
}
h.clients.Range(func(key, value interface{}) bool {
wrapper := value.(*connWrapper)
wrapper.mu.Lock()
wrapper.conn.WriteJSON(msg)
wrapper.mu.Unlock()
return true
})
}
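// Illustrative client for the endpoint above, using the same
// gorilla/websocket package (URL and route are hypothetical). The per-conn
// mutex in BroadcastLog exists because gorilla connections do not allow
// concurrent writers.
//
//	conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:8080/admin/ws/logs", nil)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer conn.Close()
//	for {
//		var entry map[string]interface{}
//		if err := conn.ReadJSON(&entry); err != nil {
//			return // server closed the socket or the network failed
//		}
//		fmt.Printf("[%s] %v %v\n", entry["level"], entry["message"], entry["fields"])
//	}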

View File

@@ -9,20 +9,25 @@ import (
"path/filepath"
"github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
)
// Package-level variable that stores the log rotator
var logRotator *lumberjack.Logger
// NewLogger returns a standard *logrus.Logger (compatible with Fx dependency injection)
func NewLogger(cfg *config.Config) *logrus.Logger {
logger := logrus.New()
// 1. Set the log level
// Set the log level
level, err := logrus.ParseLevel(cfg.Log.Level)
if err != nil {
logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level specified, defaulting to 'info'.")
logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level, defaulting to 'info'")
level = logrus.InfoLevel
}
logger.SetLevel(level)
// 2. Set the log format
// Set the log format
if cfg.Log.Format == "json" {
logger.SetFormatter(&logrus.JSONFormatter{
TimestampFormat: "2006-01-02T15:04:05.000Z07:00",
@@ -39,36 +44,57 @@ func NewLogger(cfg *config.Config) *logrus.Logger {
})
}
// 3. Set the log output
// Add global fields
hostname, _ := os.Hostname()
logger = logger.WithFields(logrus.Fields{
"service": "gemini-balancer",
"hostname": hostname,
}).Logger
// Set the log output
if cfg.Log.EnableFile {
if cfg.Log.FilePath == "" {
logger.Warn("Log file is enabled but no file path is specified. Logging to console only.")
logger.Warn("Log file enabled but no path specified. Logging to console only")
logger.SetOutput(os.Stdout)
return logger
}
logDir := filepath.Dir(cfg.Log.FilePath)
if err := os.MkdirAll(logDir, 0755); err != nil {
logger.WithError(err).Warn("Failed to create log directory. Logging to console only.")
if err := os.MkdirAll(logDir, 0750); err != nil {
logger.WithError(err).Warn("Failed to create log directory. Logging to console only")
logger.SetOutput(os.Stdout)
return logger
}
logFile, err := os.OpenFile(cfg.Log.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
logger.WithError(err).Warn("Failed to open log file. Logging to console only.")
logger.SetOutput(os.Stdout)
return logger
// Configure log rotation (stored in the package-level variable)
logRotator = &lumberjack.Logger{
Filename: cfg.Log.FilePath,
MaxSize: getOrDefault(cfg.Log.MaxSize, 100),
MaxBackups: getOrDefault(cfg.Log.MaxBackups, 7),
MaxAge: getOrDefault(cfg.Log.MaxAge, 30),
Compress: cfg.Log.Compress,
}
// Write to both console and file
logger.SetOutput(io.MultiWriter(os.Stdout, logFile))
logger.WithField("log_file_path", cfg.Log.FilePath).Info("Logging is now configured to output to both console and file.")
logger.SetOutput(io.MultiWriter(os.Stdout, logRotator))
logger.WithField("log_file", cfg.Log.FilePath).Info("Logging to both console and file")
} else {
// Console output only
logger.SetOutput(os.Stdout)
}
logger.Info("Root logger initialized.")
logger.Info("Logger initialized successfully")
return logger
}
// Close closes the log rotator (called from main.go)
func Close() {
if logRotator != nil {
logRotator.Close()
}
}
func getOrDefault(value, defaultValue int) int {
if value <= 0 {
return defaultValue
}
return value
}

View File

@@ -0,0 +1,36 @@
// Filename: internal/logging/websocket_hook.go
package logging
import (
"gemini-balancer/internal/config"
"gemini-balancer/internal/handlers"
"github.com/sirupsen/logrus"
)
type WebSocketHook struct {
broadcaster func(*logrus.Entry)
}
func NewWebSocketHook(broadcaster func(*logrus.Entry)) *WebSocketHook {
return &WebSocketHook{broadcaster: broadcaster}
}
func (h *WebSocketHook) Levels() []logrus.Level {
return logrus.AllLevels
}
func (h *WebSocketHook) Fire(entry *logrus.Entry) error {
if h.broadcaster != nil {
go h.broadcaster(entry)
}
return nil
}
func NewLoggerWithWebSocket(cfg *config.Config) (*logrus.Logger, *handlers.WebSocketHandler) {
logger := NewLogger(cfg)
wsHandler := handlers.NewWebSocketHandler(logger)
hook := NewWebSocketHook(wsHandler.BroadcastLog)
logger.AddHook(hook)
return logger, wsHandler
}
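Note: Fire runs for every log entry, and the broadcast is dispatched in a goroutine so logging never blocks on slow WebSocket clients. A minimal wiring sketch (route registration elided; only the names above are from the file):

    logger, wsHandler := logging.NewLoggerWithWebSocket(cfg)
    logger.Info("hello") // fans out to connected clients via wsHandler.BroadcastLog
    _ = wsHandler        // mount its WebSocket endpoint on the admin router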

View File

@@ -1,4 +1,5 @@
// Filename: internal/middleware/auth.go
package middleware
import (
@@ -7,76 +8,115 @@ import (
"strings"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// === Admin API authentication pipeline (/admin/* API routes) ===
type ErrorResponse struct {
Error string `json:"error"`
Code string `json:"code,omitempty"`
}
func APIAdminAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
// APIAdminAuthMiddleware authenticates admin-console API requests
func APIAdminAuthMiddleware(
securityService *service.SecurityService,
logger *logrus.Logger,
) gin.HandlerFunc {
return func(c *gin.Context) {
tokenValue := extractBearerToken(c)
if tokenValue == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization token is missing"})
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "Authentication required",
Code: "AUTH_MISSING",
})
return
}
// Pass only the token parameter (context removed)
authToken, err := securityService.AuthenticateToken(tokenValue)
if err != nil || !authToken.IsAdmin {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-admin token"})
if err != nil {
logger.WithError(err).Warn("Authentication failed")
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "Invalid authentication",
Code: "AUTH_INVALID",
})
return
}
if !authToken.IsAdmin {
c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse{
Error: "Admin access required",
Code: "AUTH_FORBIDDEN",
})
return
}
c.Set("adminUser", authToken)
c.Next()
}
}
// === /v1 proxy authentication ===
func ProxyAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
// ProxyAuthMiddleware authenticates proxy requests
func ProxyAuthMiddleware(
securityService *service.SecurityService,
logger *logrus.Logger,
) gin.HandlerFunc {
return func(c *gin.Context) {
tokenValue := extractProxyToken(c)
if tokenValue == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "API key is missing from request"})
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "API key required",
Code: "KEY_MISSING",
})
return
}
// Pass only the token parameter (context removed)
authToken, err := securityService.AuthenticateToken(tokenValue)
if err != nil {
// Generic message to avoid leaking detail
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or inactive token provided"})
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
Error: "Invalid API key",
Code: "KEY_INVALID",
})
return
}
c.Set("authToken", authToken)
c.Next()
}
}
// extractProxyToken extracts the token in priority order
func extractProxyToken(c *gin.Context) string {
if key := c.Query("key"); key != "" {
return key
}
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimPrefix(authHeader, "Bearer ")
}
// Priority 1: Authorization header
if token := extractBearerToken(c); token != "" {
return token
}
// Priority 2: X-Api-Key
if key := c.GetHeader("X-Api-Key"); key != "" {
return key
}
// Priority 3: X-Goog-Api-Key
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
return key
}
return ""
// Priority 4: query parameter (not recommended)
return c.Query("key")
}
// === Helpers ===
// extractBearerToken extracts a Bearer token
func extractBearerToken(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
return ""
}
parts := strings.Split(authHeader, " ")
if len(parts) == 2 && parts[0] == "Bearer" {
return parts[1]
const prefix = "Bearer "
if !strings.HasPrefix(authHeader, prefix) {
return ""
}
return ""
}
return strings.TrimSpace(authHeader[len(prefix):])
}
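Note: a routing sketch for the two pipelines (group paths illustrative; securityService and logger are assumed to come from the container):

    admin := r.Group("/admin", middleware.APIAdminAuthMiddleware(securityService, logger))
    proxy := r.Group("/v1", middleware.ProxyAuthMiddleware(securityService, logger))
    _, _ = admin, proxy // register concrete routes on each group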

View File

@@ -0,0 +1,90 @@
// Filename: internal/middleware/cors.go
package middleware
import (
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
ExposedHeaders []string
AllowCredentials bool
MaxAge int
}
func CORSMiddleware(config CORSConfig) gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
// Check whether the origin is allowed
if origin != "" && isOriginAllowed(origin, config.AllowedOrigins) {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
}
if config.AllowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
if len(config.ExposedHeaders) > 0 {
c.Writer.Header().Set("Access-Control-Expose-Headers",
strings.Join(config.ExposedHeaders, ", "))
}
// Handle preflight requests
if c.Request.Method == http.MethodOptions {
if len(config.AllowedMethods) > 0 {
c.Writer.Header().Set("Access-Control-Allow-Methods",
strings.Join(config.AllowedMethods, ", "))
}
if len(config.AllowedHeaders) > 0 {
c.Writer.Header().Set("Access-Control-Allow-Headers",
strings.Join(config.AllowedHeaders, ", "))
}
if config.MaxAge > 0 {
// strconv.Itoa, not string(rune(...)): the latter yields a Unicode
// code point rather than the decimal seconds value
c.Writer.Header().Set("Access-Control-Max-Age",
strconv.Itoa(config.MaxAge))
}
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
func isOriginAllowed(origin string, allowedOrigins []string) bool {
for _, allowed := range allowedOrigins {
if allowed == "*" || allowed == origin {
return true
}
// Support wildcard subdomains; require a "." boundary so that
// "*.example.com" does not match "evilexample.com"
if strings.HasPrefix(allowed, "*.") {
domain := strings.TrimPrefix(allowed, "*.")
if strings.HasSuffix(origin, "."+domain) {
return true
}
}
}
return false
}
// Usage example
func SetupCORS(r *gin.Engine) {
r.Use(CORSMiddleware(CORSConfig{
AllowedOrigins: []string{"https://yourdomain.com", "*.yourdomain.com"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Authorization", "Content-Type", "X-Api-Key"},
ExposedHeaders: []string{"X-Request-Id"},
AllowCredentials: true,
MaxAge: 3600,
}))
}
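Note: with the "." boundary enforced above, the matcher behaves as follows (values illustrative):

    // isOriginAllowed("https://app.example.com", []string{"*.example.com"}) == true
    // isOriginAllowed("https://example.com",     []string{"*.example.com"}) == false (list the bare origin explicitly)
    // isOriginAllowed("https://any.site",        []string{"*"})             == true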

View File

@@ -1,84 +1,213 @@
// Filename: internal/middleware/log_redaction.go
// Filename: internal/middleware/logging.go
package middleware
import (
"bytes"
"io"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
const RedactedBodyKey = "redactedBody"
const RedactedAuthHeaderKey = "redactedAuthHeader"
const RedactedValue = `"[REDACTED]"`
const (
RedactedBodyKey = "redactedBody"
RedactedAuthHeaderKey = "redactedAuthHeader"
RedactedValue = `"[REDACTED]"`
)
// Pre-compiled regular expressions (package-level, for performance)
var (
// Redacts sensitive JSON fields
jsonSensitiveKeys = regexp.MustCompile(`("(?i:api_key|apikey|token|password|secret|authorization|key|keys|auth)"\s*:\s*)"[^"]*"`)
// Redacts Bearer tokens
bearerTokenPattern = regexp.MustCompile(`^(Bearer\s+)\S+$`)
// Redacts key parameters in URLs
queryKeyPattern = regexp.MustCompile(`([?&](?i:key|token|apikey)=)[^&\s]+`)
)
// RedactionMiddleware redacts sensitive request data
func RedactionMiddleware() gin.HandlerFunc {
// Pre-compile regex for efficiency
jsonKeyPattern := regexp.MustCompile(`("api_key"|"keys")\s*:\s*"[^"]*"`)
bearerTokenPattern := regexp.MustCompile(`^(Bearer\s+)\S+$`)
return func(c *gin.Context) {
// --- 1. Redact Request Body ---
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" {
if bodyBytes, err := io.ReadAll(c.Request.Body); err == nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
bodyString := string(bodyBytes)
redactedBody := jsonKeyPattern.ReplaceAllString(bodyString, `$1:`+RedactedValue)
c.Set(RedactedBodyKey, redactedBody)
}
}
// --- 2. Redact Authorization Header ---
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if bearerTokenPattern.MatchString(authHeader) {
redactedHeader := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
c.Set(RedactedAuthHeaderKey, redactedHeader)
}
// 1. Redact the request body
if shouldRedactBody(c) {
redactRequestBody(c)
}
// 2. Redact the auth headers
redactAuthHeader(c)
// 3. Redact URL query parameters
redactQueryParams(c)
c.Next()
}
}
// LogrusLogger is a Gin middleware that logs requests using a Logrus logger.
// It consumes redacted data prepared by the RedactionMiddleware.
// shouldRedactBody reports whether the request body needs redaction
func shouldRedactBody(c *gin.Context) bool {
method := c.Request.Method
contentType := c.GetHeader("Content-Type")
// Only JSON-bearing POST/PUT/PATCH/DELETE requests are processed
return (method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE") &&
strings.Contains(contentType, "application/json")
}
// redactRequestBody redacts the request body
func redactRequestBody(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
return
}
// Restore the body for downstream handlers
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// Redact sensitive fields
bodyString := string(bodyBytes)
redactedBody := jsonSensitiveKeys.ReplaceAllString(bodyString, `$1`+RedactedValue)
c.Set(RedactedBodyKey, redactedBody)
}
// redactAuthHeader redacts authentication headers
func redactAuthHeader(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
return
}
if bearerTokenPattern.MatchString(authHeader) {
redacted := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
c.Set(RedactedAuthHeaderKey, redacted)
} else {
// Non-Bearer tokens are redacted entirely
c.Set(RedactedAuthHeaderKey, "[REDACTED]")
}
// Also flag the other sensitive headers
sensitiveHeaders := []string{"X-Api-Key", "X-Goog-Api-Key", "Api-Key"}
for _, header := range sensitiveHeaders {
if value := c.GetHeader(header); value != "" {
c.Set("redacted_"+header, "[REDACTED]")
}
}
}
// redactQueryParams redacts URL query parameters
func redactQueryParams(c *gin.Context) {
rawQuery := c.Request.URL.RawQuery
if rawQuery == "" {
return
}
redacted := queryKeyPattern.ReplaceAllString(rawQuery, `${1}[REDACTED]`)
if redacted != rawQuery {
c.Set("redactedQuery", redacted)
}
}
// LogrusLogger is a Gin request-logging middleware backed by Logrus
func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
method := c.Request.Method
// Process request
// Process the request
c.Next()
// After request, gather data and log
// Compute latency
latency := time.Since(start)
statusCode := c.Writer.Status()
clientIP := c.ClientIP()
entry := logger.WithFields(logrus.Fields{
"status_code": statusCode,
"latency_ms": latency.Milliseconds(),
"client_ip": c.ClientIP(),
"method": c.Request.Method,
"path": path,
})
// Build the log fields
fields := logrus.Fields{
"status": statusCode,
"latency_ms": latency.Milliseconds(),
"ip": clientIP,
"method": method,
"path": path,
}
// Attach the request ID, if present
if requestID := getRequestID(c); requestID != "" {
fields["request_id"] = requestID
}
// Attach the redacted data
if redactedBody, exists := c.Get(RedactedBodyKey); exists {
entry = entry.WithField("body", redactedBody)
fields["body"] = redactedBody
}
if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists {
entry = entry.WithField("authorization", redactedAuth)
fields["authorization"] = redactedAuth
}
if redactedQuery, exists := c.Get("redactedQuery"); exists {
fields["query"] = redactedQuery
}
// Attach user info when authenticated
if user := getAuthenticatedUser(c); user != "" {
fields["user"] = user
}
// Choose the log level by status code
entry := logger.WithFields(fields)
if len(c.Errors) > 0 {
entry.Error(c.Errors.String())
// WithField, not fields["errors"]: WithFields has already copied the map,
// so mutating fields at this point would be silently dropped
entry = entry.WithField("errors", c.Errors.String())
entry.Error("Request failed")
} else {
entry.Info("request handled")
switch {
case statusCode >= 500:
entry.Error("Server error")
case statusCode >= 400:
entry.Warn("Client error")
case statusCode >= 300:
entry.Info("Redirect")
default:
// Log successful requests only at debug level
if logger.Level >= logrus.DebugLevel {
entry.Debug("Request completed")
}
}
}
}
}
// getRequestID returns the request ID
func getRequestID(c *gin.Context) string {
if id, exists := c.Get("request_id"); exists {
if requestID, ok := id.(string); ok {
return requestID
}
}
return ""
}
// getAuthenticatedUser returns the authenticated user's identifier
func getAuthenticatedUser(c *gin.Context) string {
// Try each source of user info in turn
if user, exists := c.Get("adminUser"); exists {
if authToken, ok := user.(interface{ GetID() string }); ok {
return authToken.GetID()
}
}
if user, exists := c.Get("authToken"); exists {
if authToken, ok := user.(interface{ GetID() string }); ok {
return authToken.GetID()
}
}
return ""
}
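Note: ordering matters here, since the logger consumes context keys that the other middlewares populate. A minimal sketch:

    r.Use(
        middleware.RequestIDMiddleware(), // provides request_id for the log fields
        middleware.RedactionMiddleware(), // fills the redacted* context keys
        middleware.LogrusLogger(logger),  // reads both after the handler returns
    )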

View File

@@ -0,0 +1,86 @@
// Filename: internal/middleware/rate_limit.go
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type RateLimiter struct {
limiters map[string]*rate.Limiter
mu sync.RWMutex
r rate.Limit // requests per second
b int // burst capacity
}
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
return &RateLimiter{
limiters: make(map[string]*rate.Limiter),
r: r,
b: b,
}
}
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
limiter, exists := rl.limiters[key]
if !exists {
limiter = rate.NewLimiter(rl.r, rl.b)
rl.limiters[key] = limiter
}
return limiter
}
// Periodically clean up inactive limiters
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
// Simple strategy: clear everything periodically (production should use an LRU)
rl.limiters = make(map[string]*rate.Limiter)
rl.mu.Unlock()
}
}
func RateLimitMiddleware(limiter *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
// Rate-limit by IP
key := c.ClientIP()
// If an auth token is present, limit by token instead (more precise)
if authToken, exists := c.Get("authToken"); exists {
if token, ok := authToken.(interface{ GetID() string }); ok {
key = "token:" + token.GetID()
}
}
l := limiter.getLimiter(key)
if !l.Allow() {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "Rate limit exceeded",
"code": "RATE_LIMIT",
})
return
}
c.Next()
}
}
// Usage example
func SetupRateLimit(r *gin.Engine) {
limiter := NewRateLimiter(10, 20) // 10 req/s, burst of 20
go limiter.cleanup()
r.Use(RateLimitMiddleware(limiter))
}
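Note on the token-bucket parameters: rate.NewLimiter(10, 20) admits an initial burst of up to 20 requests per key, then refills at 10 tokens per second; Allow() consumes one token and returns false once the bucket is empty.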

View File

@@ -0,0 +1,39 @@
// Filename: internal/middleware/request_id.go
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// RequestIDMiddleware provides request ID tracing
func RequestIDMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 1. Reuse an existing Request ID from the header
requestID := c.GetHeader("X-Request-Id")
// 2. Generate a new one if absent
if requestID == "" {
requestID = uuid.New().String()
}
// 3. Store it on the context
c.Set("request_id", requestID)
// 4. Echo it back to the client for tracing
c.Writer.Header().Set("X-Request-Id", requestID)
c.Next()
}
}
// GetRequestID returns the current request's Request ID
func GetRequestID(c *gin.Context) string {
if id, exists := c.Get("request_id"); exists {
if requestID, ok := id.(string); ok {
return requestID
}
}
return ""
}
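Note: a self-contained check of the echo behavior (imports net/http and net/http/httptest assumed):

    r := gin.New()
    r.Use(middleware.RequestIDMiddleware())
    r.GET("/ping", func(c *gin.Context) { c.String(http.StatusOK, "%s", middleware.GetRequestID(c)) })
    w := httptest.NewRecorder()
    r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/ping", nil))
    // w.Header().Get("X-Request-Id") now equals w.Body.String(): generated once, echoed everywhere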

View File

@@ -1,4 +1,4 @@
// Filename: internal/middleware/security.go
// Filename: internal/middleware/security.go (simplified)
package middleware
@@ -6,26 +6,136 @@ import (
"gemini-balancer/internal/service"
"gemini-balancer/internal/settings"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
func IPBanMiddleware(securityService *service.SecurityService, settingsManager *settings.SettingsManager) gin.HandlerFunc {
// A single cache entry
type cacheItem struct {
value bool
expiration int64
}
// A simple TTL cache
type IPBanCache struct {
items map[string]*cacheItem
mu sync.RWMutex
ttl time.Duration
}
func NewIPBanCache() *IPBanCache {
cache := &IPBanCache{
items: make(map[string]*cacheItem),
ttl: 1 * time.Minute,
}
// Start the cleanup goroutine
go cache.cleanup()
return cache
}
func (c *IPBanCache) Get(key string) (bool, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
item, exists := c.items[key]
if !exists {
return false, false
}
// Check expiration
if time.Now().UnixNano() > item.expiration {
return false, false
}
return item.value, true
}
func (c *IPBanCache) Set(key string, value bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.items[key] = &cacheItem{
value: value,
expiration: time.Now().Add(c.ttl).UnixNano(),
}
}
func (c *IPBanCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.items, key)
}
func (c *IPBanCache) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
c.mu.Lock()
now := time.Now().UnixNano()
for key, item := range c.items {
if now > item.expiration {
delete(c.items, key)
}
}
c.mu.Unlock()
}
}
func IPBanMiddleware(
securityService *service.SecurityService,
settingsManager *settings.SettingsManager,
banCache *IPBanCache,
logger *logrus.Logger,
) gin.HandlerFunc {
return func(c *gin.Context) {
if !settingsManager.IsIPBanEnabled() {
c.Next()
return
}
ip := c.ClientIP()
isBanned, err := securityService.IsIPBanned(c.Request.Context(), ip)
if err != nil {
// Check the cache first
if isBanned, exists := banCache.Get(ip); exists {
if isBanned {
logger.WithField("ip", ip).Debug("IP blocked (cached)")
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "Access denied",
})
return
}
c.Next()
return
}
if isBanned {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "您的IP已被暂时封禁请稍后再试"})
// Fall through to the database
ctx := c.Request.Context()
isBanned, err := securityService.IsIPBanned(ctx, ip)
if err != nil {
logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status")
// Degradation strategy: allow access on lookup failure
c.Next()
return
}
// Update the cache
banCache.Set(ip, isBanned)
if isBanned {
logger.WithField("ip", ip).Info("IP blocked")
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"error": "Access denied",
})
return
}
c.Next()
}
}
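Note: a wiring sketch for the cached variant (dependencies assumed to come from the container):

    banCache := middleware.NewIPBanCache() // starts its own cleanup goroutine
    r.Use(middleware.IPBanMiddleware(securityService, settingsManager, banCache, logger))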

View File

@@ -0,0 +1,52 @@
// Filename: internal/middleware/timeout.go
package middleware
import (
"context"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
// Create a context with a timeout
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// Swap it into the request
c.Request = c.Request.WithContext(ctx)
// Wait for the handler via a channel
finished := make(chan struct{})
go func() {
c.Next()
close(finished)
}()
select {
case <-finished:
// Request completed normally
return
case <-ctx.Done():
// Timed out. Note: the handler goroutine may still be running, so
// handlers must honor ctx cancellation before writing a response
c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
"error": "Request timeout",
"code": "TIMEOUT",
})
}
}
}
// Usage example
func SetupTimeout(r *gin.Engine) {
// Apply a 30-second timeout to API routes
api := r.Group("/api")
api.Use(TimeoutMiddleware(30 * time.Second))
{
// ... API routes
}
}

View File

@@ -1,23 +1,151 @@
// Filename: internal/middleware/web.go
package middleware
import (
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/service"
"log"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
const (
AdminSessionCookie = "gemini_admin_session"
SessionMaxAge = 3600 * 24 * 7 // 7 days
CacheTTL = 5 * time.Minute
CleanupInterval = 10 * time.Minute // lower cleanup frequency
SessionRefreshTime = 30 * time.Minute
)
// ==================== Cache layer ====================
type authCacheEntry struct {
Token interface{}
ExpiresAt time.Time
}
type authCache struct {
mu sync.RWMutex
cache map[string]*authCacheEntry
ttl time.Duration
}
var webAuthCache = newAuthCache(CacheTTL)
func newAuthCache(ttl time.Duration) *authCache {
c := &authCache{
cache: make(map[string]*authCacheEntry),
ttl: ttl,
}
go c.cleanupLoop()
return c
}
func (c *authCache) get(key string) (interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, exists := c.cache[key]
if !exists || time.Now().After(entry.ExpiresAt) {
return nil, false
}
return entry.Token, true
}
func (c *authCache) set(key string, token interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache[key] = &authCacheEntry{
Token: token,
ExpiresAt: time.Now().Add(c.ttl),
}
}
func (c *authCache) delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.cache, key)
}
func (c *authCache) cleanupLoop() {
ticker := time.NewTicker(CleanupInterval)
defer ticker.Stop()
for range ticker.C {
c.cleanup()
}
}
func (c *authCache) cleanup() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
count := 0
for key, entry := range c.cache {
if now.After(entry.ExpiresAt) {
delete(c.cache, key)
count++
}
}
if count > 0 {
logrus.Debugf("[AuthCache] Cleaned up %d expired entries", count)
}
}
// ==================== Session refresh cache ====================
var sessionRefreshCache = struct {
sync.RWMutex
timestamps map[string]time.Time
}{
timestamps: make(map[string]time.Time),
}
// Periodically prune refresh timestamps
func init() {
go func() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
sessionRefreshCache.Lock()
now := time.Now()
for key, ts := range sessionRefreshCache.timestamps {
if now.Sub(ts) > 2*time.Hour {
delete(sessionRefreshCache.timestamps, key)
}
}
sessionRefreshCache.Unlock()
}
}()
}
// ==================== Cookie operations ====================
func SetAdminSessionCookie(c *gin.Context, adminToken string) {
c.SetCookie(AdminSessionCookie, adminToken, 3600*24*7, "/", "", false, true)
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, adminToken, SessionMaxAge, "/", "", secure, true)
}
func SetAdminSessionCookieWithAge(c *gin.Context, adminToken string, maxAge int) {
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, adminToken, maxAge, "/", "", secure, true)
}
func ClearAdminSessionCookie(c *gin.Context) {
c.SetSameSite(http.SameSiteStrictMode)
c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true)
}
@@ -29,26 +157,258 @@ func ExtractTokenFromCookie(c *gin.Context) string {
return cookie
}
// ==================== Auth middleware ====================
func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc {
logger := logrus.New()
logger.SetLevel(getLogLevel())
return func(c *gin.Context) {
cookie := ExtractTokenFromCookie(c)
log.Printf("[WebAuth_Guard] Intercepting request for: %s", c.Request.URL.Path)
log.Printf("[WebAuth_Guard] Found session cookie value: '%s'", cookie)
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
log.Printf("[WebAuth_Guard] FATAL: AuthenticateToken FAILED. Error: %v. Redirecting to /login.", err)
} else if !authToken.IsAdmin {
log.Printf("[WebAuth_Guard] FATAL: Token validated, but IsAdmin is FALSE. Redirecting to /login.")
} else {
log.Printf("[WebAuth_Guard] SUCCESS: Token validated and IsAdmin is TRUE. Allowing access.")
}
if err != nil || !authToken.IsAdmin {
if cookie == "" {
logger.Debug("[WebAuth] No session cookie found")
ClearAdminSessionCookie(c)
c.Redirect(http.StatusFound, "/login")
c.Abort()
redirectToLogin(c)
return
}
cacheKey := hashToken(cookie)
if cachedToken, found := webAuthCache.get(cacheKey); found {
logger.Debug("[WebAuth] Using cached token")
c.Set("adminUser", cachedToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
return
}
logger.Debug("[WebAuth] Cache miss, authenticating...")
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
logger.WithError(err).Warn("[WebAuth] Authentication failed")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
if !authToken.IsAdmin {
logger.Warn("[WebAuth] User is not admin")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
webAuthCache.set(cacheKey, authToken)
logger.Debug("[WebAuth] Authentication success, token cached")
c.Set("adminUser", authToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
}
}
func WebAdminAuthMiddlewareWithLogger(authService *service.SecurityService, logger *logrus.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
cookie := ExtractTokenFromCookie(c)
if cookie == "" {
logger.Debug("No session cookie found")
ClearAdminSessionCookie(c)
redirectToLogin(c)
return
}
cacheKey := hashToken(cookie)
if cachedToken, found := webAuthCache.get(cacheKey); found {
c.Set("adminUser", cachedToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
return
}
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
logger.WithError(err).Warn("Token authentication failed")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
if !authToken.IsAdmin {
logger.Warn("Token valid but user is not admin")
ClearAdminSessionCookie(c)
webAuthCache.delete(cacheKey)
redirectToLogin(c)
return
}
webAuthCache.set(cacheKey, authToken)
c.Set("adminUser", authToken)
refreshSessionIfNeeded(c, cookie)
c.Next()
}
}
// ==================== Helpers ====================
func hashToken(token string) string {
h := sha256.Sum256([]byte(token))
return hex.EncodeToString(h[:])
}
func redirectToLogin(c *gin.Context) {
if isAjaxRequest(c) {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Session expired",
"code": "AUTH_REQUIRED",
})
c.Abort()
return
}
originalPath := c.Request.URL.Path
if originalPath != "/" && originalPath != "/login" {
c.Redirect(http.StatusFound, "/login?redirect="+originalPath)
} else {
c.Redirect(http.StatusFound, "/login")
}
c.Abort()
}
func isAjaxRequest(c *gin.Context) bool {
// Check Content-Type
contentType := c.GetHeader("Content-Type")
if strings.Contains(contentType, "application/json") {
return true
}
// Check Accept (prefer JSON)
accept := c.GetHeader("Accept")
if strings.Contains(accept, "application/json") &&
!strings.Contains(accept, "text/html") {
return true
}
// Fall back to legacy XMLHttpRequest detection
return c.GetHeader("X-Requested-With") == "XMLHttpRequest"
}
func refreshSessionIfNeeded(c *gin.Context, token string) {
tokenHash := hashToken(token)
sessionRefreshCache.RLock()
lastRefresh, exists := sessionRefreshCache.timestamps[tokenHash]
sessionRefreshCache.RUnlock()
if !exists || time.Since(lastRefresh) > SessionRefreshTime {
SetAdminSessionCookie(c, token)
sessionRefreshCache.Lock()
sessionRefreshCache.timestamps[tokenHash] = time.Now()
sessionRefreshCache.Unlock()
}
}
func getLogLevel() logrus.Level {
level := os.Getenv("LOG_LEVEL")
switch strings.ToLower(level) {
case "debug":
return logrus.DebugLevel
case "warn":
return logrus.WarnLevel
case "error":
return logrus.ErrorLevel
default:
return logrus.InfoLevel
}
}
// ==================== Utilities ====================
func GetAdminUserFromContext(c *gin.Context) (interface{}, bool) {
return c.Get("adminUser")
}
func InvalidateTokenCache(token string) {
tokenHash := hashToken(token)
webAuthCache.delete(tokenHash)
// Also prune the refresh timestamp
sessionRefreshCache.Lock()
delete(sessionRefreshCache.timestamps, tokenHash)
sessionRefreshCache.Unlock()
}
func ClearAllAuthCache() {
webAuthCache.mu.Lock()
webAuthCache.cache = make(map[string]*authCacheEntry)
webAuthCache.mu.Unlock()
sessionRefreshCache.Lock()
sessionRefreshCache.timestamps = make(map[string]time.Time)
sessionRefreshCache.Unlock()
}
// ==================== Debug helpers ====================
type SessionInfo struct {
HasCookie bool `json:"has_cookie"`
IsValid bool `json:"is_valid"`
IsAdmin bool `json:"is_admin"`
IsCached bool `json:"is_cached"`
LastActivity string `json:"last_activity"`
}
func GetSessionInfo(c *gin.Context, authService *service.SecurityService) SessionInfo {
info := SessionInfo{
HasCookie: false,
IsValid: false,
IsAdmin: false,
IsCached: false,
LastActivity: time.Now().Format(time.RFC3339),
}
cookie := ExtractTokenFromCookie(c)
if cookie == "" {
return info
}
info.HasCookie = true
cacheKey := hashToken(cookie)
if _, found := webAuthCache.get(cacheKey); found {
info.IsCached = true
}
authToken, err := authService.AuthenticateToken(cookie)
if err != nil {
return info
}
info.IsValid = true
info.IsAdmin = authToken.IsAdmin
return info
}
func GetCacheStats() map[string]interface{} {
webAuthCache.mu.RLock()
cacheSize := len(webAuthCache.cache)
webAuthCache.mu.RUnlock()
sessionRefreshCache.RLock()
refreshSize := len(sessionRefreshCache.timestamps)
sessionRefreshCache.RUnlock()
return map[string]interface{}{
"auth_cache_entries": cacheSize,
"refresh_cache_entries": refreshSize,
"ttl_seconds": int(webAuthCache.ttl.Seconds()),
"cleanup_interval": int(CleanupInterval.Seconds()),
"session_refresh_time": int(SessionRefreshTime.Seconds()),
}
}
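Note: a minimal logout sketch using the helpers above (route path illustrative):

    r.POST("/logout", func(c *gin.Context) {
        if token := middleware.ExtractTokenFromCookie(c); token != "" {
            middleware.InvalidateTokenCache(token) // evict auth cache and refresh timestamp
        }
        middleware.ClearAdminSessionCookie(c)
        c.Redirect(http.StatusFound, "/login")
    })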

View File

@@ -77,3 +77,9 @@ type APIKeyDetails struct {
CooldownUntil *time.Time `json:"cooldown_until"`
EncryptedKey string
}
// SettingsManager defines the abstract interface for the system settings manager.
type SettingsManager interface {
GetSettings() *SystemSettings
}

View File

@@ -101,6 +101,7 @@ type RequestLog struct {
LatencyMs int
IsSuccess bool
StatusCode int
Status string `gorm:"type:varchar(100);index"`
ModelName string `gorm:"type:varchar(100);index"`
GroupID *uint `gorm:"index"`
KeyID *uint `gorm:"index"`

View File

@@ -11,6 +11,7 @@ type SystemSettings struct {
BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"拉黑阈值" category:"密钥设置" desc:"一个Key连续失败多少次后进入冷却状态。"`
KeyCooldownMinutes int `json:"key_cooldown_minutes" default:"10" name:"密钥冷却时长(分钟)" category:"密钥设置" desc:"一个Key进入冷却状态后需要等待的时间单位为分钟。"`
LogFlushIntervalSeconds int `json:"log_flush_interval_seconds" default:"10" name:"日志刷新间隔(秒)" category:"日志设置" desc:"异步日志写入数据库的间隔时间(秒)。"`
MaxRequestBodySizeMB int `json:"max_request_body_size_mb" default:"10" name:"最大请求体大小 (MB)" category:"请求设置" desc:"允许代理接收的最大请求体大小单位为MB。超过此大小的请求将被拒绝。"`
PollingStrategy PollingStrategy `json:"polling_strategy" default:"random" name:"全局轮询策略" category:"调度设置" desc:"智能聚合模式下,从所有可用密钥中选择一个的默认策略。可选值: sequential(顺序), random(随机), weighted(加权)。"`
@@ -26,6 +27,8 @@ type SystemSettings struct {
BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"`
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-2.0-flash-lite" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`
KeyCheckSchedulerIntervalSeconds int `json:"key_check_scheduler_interval_seconds" default:"60" name:"Key检查调度器间隔(秒)" category:"健康检查" desc:"动态调度器检查各组是否需要执行健康检查的周期。"`
EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"`
UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"`
@@ -41,6 +44,10 @@ type SystemSettings struct {
MaxLoginAttempts int `json:"max_login_attempts" default:"5" name:"最大登录失败次数" category:"安全设置" desc:"在一个IP被封禁前允许的连续登录失败次数。"`
IPBanDurationMinutes int `json:"ip_ban_duration_minutes" default:"15" name:"IP封禁时长(分钟)" category:"安全设置" desc:"IP被封禁的时长单位为分钟。"`
// BasePool-related configuration
// BasePoolTTLMinutes int `json:"base_pool_ttl_minutes" default:"30" name:"基础资源池最大生存时间(分钟)" category:"基础资源池" desc:"一个动态构建的基础资源池BasePool在Redis中的最大生存时间。到期后即使仍在活跃使用也会被强制重建。"`
// BasePoolTTIMinutes int `json:"base_pool_tti_minutes" default:"10" name:"基础资源池空闲超时(分钟)" category:"基础资源池" desc:"一个基础资源池BasePool在连续无请求后自动销毁的空闲等待时间。"`
// Smart gateway
LogTruncationLimit int `json:"log_truncation_limit" default:"8000" name:"日志截断长度" category:"日志设置" desc:"在日志中记录上游响应或错误时保留的最大字符数。0表示不截断。"`
EnableSmartGateway bool `json:"enable_smart_gateway" default:"false" name:"启用智能网关" category:"代理设置" desc:"开启后,系统将对流式请求进行智能中断续传、错误标准化等优化。关闭后,系统将作为一个纯净、无干扰的透明代理。"`
@@ -64,6 +71,10 @@ type SystemSettings struct {
LogBufferCapacity int `json:"log_buffer_capacity" default:"1000" name:"日志缓冲区容量" category:"日志设置" desc:"内存中日志缓冲区的最大容量,超过则可能丢弃日志。"`
LogFlushBatchSize int `json:"log_flush_batch_size" default:"100" name:"日志刷新批次大小" category:"日志设置" desc:"每次向数据库批量写入日志的最大数量。"`
LogAutoCleanupEnabled bool `json:"log_auto_cleanup_enabled" default:"false" name:"开启请求日志自动清理" category:"日志配置" desc:"启用后,系统将每日定时删除旧的请求日志。"`
LogAutoCleanupRetentionDays int `json:"log_auto_cleanup_retention_days" default:"30" name:"日志保留天数" category:"日志配置" desc:"自动清理任务将保留最近 N 天的日志。"`
LogAutoCleanupTime string `json:"log_auto_cleanup_time" default:"04:05" name:"每日清理执行时间" category:"日志配置" desc:"自动清理任务执行的时间点24小时制例如 04:05。"`
// --- API配置 ---
CustomHeaders map[string]string `json:"custom_headers" name:"自定义Headers" category:"API配置" ` // 默认为nil

View File

@@ -1,63 +1,96 @@
// Filename: internal/pongo/renderer.go
package pongo
import (
"fmt"
"net/http"
"sync"
"github.com/flosch/pongo2/v6"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/render"
"github.com/sirupsen/logrus"
)
type Renderer struct {
Context pongo2.Context
tplSet *pongo2.TemplateSet
mu sync.RWMutex
globalContext pongo2.Context
tplSet *pongo2.TemplateSet
logger *logrus.Logger
}
func New(directory string, isDebug bool) *Renderer {
func New(directory string, isDebug bool, logger *logrus.Logger) *Renderer {
loader := pongo2.MustNewLocalFileSystemLoader(directory)
tplSet := pongo2.NewSet("gin-pongo-templates", loader)
tplSet.Debug = isDebug
return &Renderer{Context: make(pongo2.Context), tplSet: tplSet}
return &Renderer{
globalContext: make(pongo2.Context),
tplSet: tplSet,
logger: logger,
}
}
// Instance returns a new render.HTML instance for a single request.
func (p *Renderer) Instance(name string, data interface{}) render.Render {
var glob pongo2.Context
if p.Context != nil {
glob = p.Context
}
// SetGlobalContext sets a global context value thread-safely
func (p *Renderer) SetGlobalContext(key string, value interface{}) {
p.mu.Lock()
defer p.mu.Unlock()
p.globalContext[key] = value
}
// Warmup preloads templates into the cache
func (p *Renderer) Warmup(templateNames ...string) error {
for _, name := range templateNames {
if _, err := p.tplSet.FromCache(name); err != nil {
return fmt.Errorf("failed to warmup template '%s': %w", name, err)
}
}
p.logger.WithField("count", len(templateNames)).Info("Templates warmed up")
return nil
}
func (p *Renderer) Instance(name string, data interface{}) render.Render {
// Take a safe copy of the global context
p.mu.RLock()
glob := make(pongo2.Context, len(p.globalContext))
for k, v := range p.globalContext {
glob[k] = v
}
p.mu.RUnlock()
// Resolve the request data
var context pongo2.Context
if data != nil {
if ginContext, ok := data.(gin.H); ok {
context = pongo2.Context(ginContext)
} else if pongoContext, ok := data.(pongo2.Context); ok {
context = pongoContext
} else if m, ok := data.(map[string]interface{}); ok {
context = m
} else {
switch v := data.(type) {
case gin.H:
context = pongo2.Context(v)
case pongo2.Context:
context = v
case map[string]interface{}:
context = v
default:
context = make(pongo2.Context)
}
} else {
context = make(pongo2.Context)
}
// Merge contexts (request data wins)
for k, v := range glob {
if _, ok := context[k]; !ok {
if _, exists := context[k]; !exists {
context[k] = v
}
}
// Load the template
tpl, err := p.tplSet.FromCache(name)
if err != nil {
panic(fmt.Sprintf("Failed to load template '%s': %v", name, err))
p.logger.WithError(err).WithField("template", name).Error("Failed to load template")
return &ErrorHTML{
StatusCode: http.StatusInternalServerError,
Error: fmt.Errorf("template load error: %s", name),
}
}
return &HTML{
p: p,
Template: tpl,
Name: name,
Data: context,
@@ -65,7 +98,6 @@ func (p *Renderer) Instance(name string, data interface{}) render.Render {
}
type HTML struct {
p *Renderer
Template *pongo2.Template
Name string
Data pongo2.Context
@@ -82,15 +114,31 @@ func (h *HTML) Render(w http.ResponseWriter) error {
}
func (h *HTML) WriteContentType(w http.ResponseWriter) {
header := w.Header()
if val := header["Content-Type"]; len(val) == 0 {
header["Content-Type"] = []string{"text/html; charset=utf-8"}
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
}
}
// ErrorHTML renders a plain-text error response
type ErrorHTML struct {
StatusCode int
Error error
}
func (e *ErrorHTML) Render(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(e.StatusCode)
_, err := w.Write([]byte(e.Error.Error()))
return err
}
func (e *ErrorHTML) WriteContentType(w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
}
// C gets or creates the pongo2 context
func C(ctx *gin.Context) pongo2.Context {
p, exists := ctx.Get("pongo2")
if exists {
if p, exists := ctx.Get("pongo2"); exists {
if pCtx, ok := p.(pongo2.Context); ok {
return pCtx
}
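Note: a wiring sketch for the renderer (directory and template names illustrative):

    renderer := pongo.New("web/templates", gin.IsDebugging(), logger)
    if err := renderer.Warmup("index.html", "login.html"); err != nil {
        logger.WithError(err).Warn("template warmup failed; templates will lazy-load")
    }
    r.HTMLRender = renderer // gin calls Instance(name, data) per request
    renderer.SetGlobalContext("app_name", "gemini-balancer")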

View File

@@ -17,8 +17,8 @@ import (
type AuthTokenRepository interface {
GetAllTokensWithGroups() ([]*models.AuthToken, error)
BatchUpdateTokens(updates []*models.TokenUpdateRequest) error
GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error) // <-- Add this line
SeedAdminToken(encryptedToken, tokenHash string) error // <-- And this line for the seeder
GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error)
SeedAdminToken(encryptedToken, tokenHash string) error
}
type gormAuthTokenRepository struct {

View File

@@ -1,13 +1,15 @@
// Filename: internal/repository/key_cache.go
// Filename: internal/repository/key_cache.go (final)
package repository
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
"strconv"
)
// --- Redis key constants ---
const (
KeyGroup = "group:%d:keys:active"
KeyDetails = "key:%d:details"
@@ -22,13 +24,16 @@ const (
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
)
func (r *gormKeyRepository) LoadAllKeysToStore() error {
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
// LoadAllKeysToStore loads all keys and mappings from the database and fully rebuilds the Redis cache.
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
r.logger.Info("Starting full cache rebuild for all keys and polling structures.")
var allMappings []*models.GroupAPIKeyMapping
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err)
}
// 1. Batch-decrypt every key involved
keyMap := make(map[uint]*models.APIKey)
for _, m := range allMappings {
if m.APIKey != nil {
@@ -40,16 +45,16 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
keysToDecrypt = append(keysToDecrypt, *k)
}
if err := r.decryptKeys(keysToDecrypt); err != nil {
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.")
// Even if decryption fails, keep loading whatever is unencrypted or already decrypted
}
decryptedKeyMap := make(map[uint]models.APIKey)
for _, k := range keysToDecrypt {
decryptedKeyMap[k.ID] = k
}
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
pipe := r.store.Pipeline()
detailsToSet := make(map[string][]byte)
// 2. Clear the old polling structures of every group
pipe := r.store.Pipeline(ctx)
var allGroups []*models.KeyGroup
if err := r.db.Find(&allGroups).Error; err == nil {
for _, group := range allGroups {
@@ -62,26 +67,41 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
)
}
} else {
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.")
}
// 3. Prepare the batch-update data
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
detailsToSet := make(map[string]any)
for _, mapping := range allMappings {
if mapping.APIKey == nil {
continue
}
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
if !ok {
continue
continue // skip keys that failed to decrypt
}
// Stage the MSet payloads for KeyDetails and KeyMapping
keyJSON, _ := json.Marshal(decryptedKey)
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
mappingJSON, _ := json.Marshal(mapping)
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
if mapping.Status == models.StatusActive {
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
}
}
// 4. Batch-write the detail and mapping caches via MSet
if len(detailsToSet) > 0 {
if err := r.store.MSet(ctx, detailsToSet); err != nil {
r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.")
}
}
// 5. Rebuild every group's polling structures inside the pipeline
for groupID, activeMappings := range activeKeysByGroup {
if len(activeMappings) == 0 {
continue
@@ -100,22 +120,19 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
}
// 6. Execute the pipeline
if err := pipe.Exec(); err != nil {
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
}
for key, value := range detailsToSet {
if err := r.store.Set(key, value, 0); err != nil {
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
}
return fmt.Errorf("pipeline execution for polling structures failed: %w", err)
}
r.logger.Info("Cache rebuild complete, including all polling structures.")
r.logger.Info("Full cache rebuild completed successfully.")
return nil
}
// updateStoreCacheForKey refreshes a single APIKey's detail cache (K-V).
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err := r.decryptKey(key); err != nil {
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
@@ -124,81 +141,104 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
if err != nil {
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
}
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
}
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(key.ID)
// removeStoreCacheForKey removes an APIKey from every cache structure.
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
if err != nil {
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID)
}
pipe := r.store.Pipeline()
pipe := r.store.Pipeline(ctx)
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
for _, groupID := range groupIDs {
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
}
return pipe.Exec()
}
// updateStoreCacheForMapping atomically updates every related cache structure from a single mapping's status.
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
pipe := r.store.Pipeline()
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
groupID := mapping.KeyGroupID
ctx := context.Background()
pipe := r.store.Pipeline(ctx)
// Unconditionally remove from every polling structure first to guarantee a clean state
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
// If the new status is Active, re-add to every polling structure
if mapping.Status == models.StatusActive {
pipe.LPush(activeKeyListKey, mapping.APIKeyID)
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
}
// Refresh the mapping's K-V detail cache regardless of status
mappingJSON, err := json.Marshal(mapping)
if err != nil {
return fmt.Errorf("failed to marshal mapping: %w", err)
}
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
return pipe.Exec()
}
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
// HandleCacheUpdateEventBatch atomically updates the caches of multiple mappings in one batch.
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
if len(mappings) == 0 {
return nil
}
groupUpdates := make(map[uint]struct {
ToAdd []interface{}
ToRemove []interface{}
})
pipe := r.store.Pipeline(ctx)
for _, mapping := range mappings {
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
update, ok := groupUpdates[mapping.KeyGroupID]
if !ok {
update = struct {
ToAdd []interface{}
ToRemove []interface{}
}{}
}
groupID := mapping.KeyGroupID
// For every mapping in the batch, run the full remove-then-add logic
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
if mapping.Status == models.StatusActive {
update.ToRemove = append(update.ToRemove, keyIDStr)
update.ToAdd = append(update.ToAdd, keyIDStr)
} else {
update.ToRemove = append(update.ToRemove, keyIDStr)
}
groupUpdates[mapping.KeyGroupID] = update
}
pipe := r.store.Pipeline()
var pipelineError error
for groupID, updates := range groupUpdates {
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
if len(updates.ToRemove) > 0 {
for _, keyID := range updates.ToRemove {
pipe.LRem(activeKeyListKey, 0, keyID)
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
}
if len(updates.ToAdd) > 0 {
pipe.LPush(activeKeyListKey, updates.ToAdd...)
}
mappingJSON, _ := json.Marshal(mapping) // ignore individual marshal errors so the rest of the batch still succeeds
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
}
if err := pipe.Exec(); err != nil {
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
}
return pipelineError
return pipe.Exec()
}
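Note: a startup sketch for the full rebuild above (repo construction elided; timeout illustrative):

    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()
    if err := repo.LoadAllKeysToStore(ctx); err != nil {
        logger.WithError(err).Fatal("initial key cache rebuild failed")
    }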

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"gemini-balancer/internal/models"
"context"
"math/rand"
"strings"
"time"
@@ -22,7 +23,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
keyHashes := make([]string, len(keys))
keyValueToHashMap := make(map[string]string)
for i, k := range keys {
// All incoming keys must have plaintext APIKey
if k.APIKey == "" {
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
}
@@ -34,7 +34,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
var finalKeys []models.APIKey
err := r.db.Transaction(func(tx *gorm.DB) error {
var existingKeys []models.APIKey
// [MODIFIED] Query by hash to find existing keys.
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
return err
}
@@ -68,24 +67,20 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
}
}
if len(keysToCreate) > 0 {
// [MODIFIED] Create now only provides encrypted data and hash.
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
return err
}
}
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
return err
}
// [CRITICAL] Decrypt all keys before returning them to the service layer.
return r.decryptKeys(finalKeys)
})
return finalKeys, err
}
func (r *gormKeyRepository) Update(key *models.APIKey) error {
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
// This indicates a potential change that needs to be re-encrypted.
if key.APIKey != "" {
encryptedKey, err := r.crypto.Encrypt(key.APIKey)
if err != nil {
@@ -97,16 +92,16 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
key.APIKeyHash = hex.EncodeToString(hash[:])
}
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
return tx.Save(key).Error
})
if err != nil {
return err
}
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
if err := r.decryptKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
return nil // Continue without cache update if decryption fails.
return nil
}
if err := r.updateStoreCacheForKey(key); err != nil {
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
@@ -115,7 +110,7 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
}
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
key, err := r.GetKeyByID(id) // This now returns a decrypted key
key, err := r.GetKeyByID(id)
if err != nil {
return err
}
@@ -125,7 +120,7 @@ func (r *gormKeyRepository) HardDeleteByID(id uint) error {
if err != nil {
return err
}
if err := r.removeStoreCacheForKey(key); err != nil {
if err := r.removeStoreCacheForKey(context.Background(), key); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
}
return nil
@@ -140,16 +135,13 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
hash := sha256.Sum256([]byte(v))
hashes[i] = hex.EncodeToString(hash[:])
}
// Find the full key objects first to update the cache later.
var keysToDelete []models.APIKey
// [MODIFIED] Find by hash.
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
return 0, err
}
if len(keysToDelete) == 0 {
return 0, nil
}
// Decrypt them to ensure cache has plaintext if needed.
if err := r.decryptKeys(keysToDelete); err != nil {
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
}
@@ -167,7 +159,7 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
return 0, err
}
for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
}
}
@@ -194,7 +186,6 @@ func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
if err != nil {
return nil, err
}
// [CRITICAL] Decrypt before returning.
return keys, r.decryptKeys(keys)
}

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"gemini-balancer/internal/models"
@@ -110,13 +111,13 @@ func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
}
result := db.Delete(&models.APIKey{}, orphanKeyIDs)
//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
if result.Error != nil {
return 0, result.Error
}
for i := range keysToDelete {
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
// [Fix] Invoke the updated cache cleanup with context.Background()
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
}
}
@@ -144,7 +145,7 @@ func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
return keys, nil
}
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
result := tx.Model(&models.APIKey{}).
Where("id = ?", keyID).
@@ -160,7 +161,7 @@ func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterA
if err == nil {
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
go func() {
if err := r.LoadAllKeysToStore(); err != nil {
if err := r.LoadAllKeysToStore(context.Background()); err != nil {
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
}
}()

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -14,7 +15,7 @@ import (
"gorm.io/gorm/clause"
)
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
if len(keyIDs) == 0 {
return nil
}
@@ -34,12 +35,12 @@ func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
}
for _, keyID := range keyIDs {
r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
}
return nil
}
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
if len(keyIDs) == 0 {
return 0, nil
}
@@ -63,16 +64,16 @@ func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (in
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
for _, keyID := range keyIDs {
r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID)))
}
return unlinkedCount, nil
}
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
cacheKey := fmt.Sprintf("key:%d:groups", keyID)
strGroupIDs, err := r.store.SMembers(cacheKey)
strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
if err != nil || len(strGroupIDs) == 0 {
var groupIDs []uint
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
@@ -84,7 +85,7 @@ func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
for _, id := range groupIDs {
interfaceSlice = append(interfaceSlice, id)
}
r.store.SAdd(cacheKey, interfaceSlice...)
r.store.SAdd(context.Background(), cacheKey, interfaceSlice...)
}
return groupIDs, nil
}
@@ -103,7 +104,7 @@ func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKey
return &mapping, err
}
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
return tx.Save(mapping).Error
})

View File

@@ -2,6 +2,7 @@
package repository
import (
"context"
"crypto/sha1"
"encoding/json"
"errors"
@@ -17,40 +18,40 @@ import (
)
const (
CacheTTL = 5 * time.Minute
EmptyPoolPlaceholder = "EMPTY_POOL"
EmptyCacheTTL = 1 * time.Minute
CacheTTL = 5 * time.Minute
EmptyCacheTTL = 1 * time.Minute
)
// SelectOneActiveKey efficiently selects one available API key from the cache by the configured polling strategy.
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
// SelectOneActiveKey selects one available API key from a single key group's cache by the configured polling strategy.
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
if group == nil {
return nil, nil, fmt.Errorf("group cannot be nil")
}
var keyIDStr string
var err error
switch group.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
}
err = zerr
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
default: // fall back to basic random selection when the strategy is unset or unknown
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
}
if err != nil {
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound
@@ -58,65 +59,70 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
return nil, nil, err
}
if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
if parseErr != nil {
r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
if err != nil {
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
// TODO: add retry logic here by calling SelectOneActiveKey(group) again
r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID)
return nil, nil, err
}
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
go func() {
updateCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID))
}()
}
return apiKey, mapping, nil
}
// SelectOneActiveKeyFromBasePool is a brand-new poller designed for the smart aggregation mode
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
// Generate a unique pool ID so the polling state of different request combinations stays isolated
poolID := generatePoolID(pool.CandidateGroups)
// SelectOneActiveKeyFromBasePool selects one available key from the smart aggregation pool
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
if pool == nil || len(pool.CandidateGroups) == 0 {
return nil, nil, fmt.Errorf("invalid or empty base pool configuration")
}
poolID := r.generatePoolID(pool.CandidateGroups)
log := r.logger.WithField("pool_id", poolID)
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
log.WithError(err).Error("Failed to ensure BasePool cache exists")
}
return nil, nil, err
}
var keyIDStr string
var err error
switch pool.PollingStrategy {
case models.StrategySequential:
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.Rotate(sequentialKey)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
case models.StrategyWeighted:
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
results, zerr := r.store.ZRange(lruKey, 0, 0)
if zerr == nil && len(results) > 0 {
keyIDStr = results[0]
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
if zerr == nil {
if len(results) > 0 {
keyIDStr = results[0]
} else {
zerr = gorm.ErrRecordNotFound
}
}
err = zerr
case models.StrategyRandom:
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
default: // the default strategy should be handled in ensureCache; kept here as a degradation path
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
default:
log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
}
if err != nil {
if errors.Is(err, store.ErrNotFound) {
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, gorm.ErrRecordNotFound
}
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
@@ -125,153 +131,266 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
if keyIDStr == "" {
return nil, nil, gorm.ErrRecordNotFound
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
for _, group := range pool.CandidateGroups {
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
if cacheErr == nil && apiKey != nil && mapping != nil {
if pool.PollingStrategy == models.StrategyWeighted {
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
}
return apiKey, group, nil
go func() {
bgCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.refreshBasePoolHeartbeat(bgCtx, poolID)
}()
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr)
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
}
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr)
if err != nil {
log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID)
return nil, nil, errors.New("cache inconsistency: key has no origin group mapping")
}
groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64)
if parseErr != nil {
log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr)
return nil, nil, errors.New("cache inconsistency: invalid group id in mapping")
}
apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID))
if err != nil {
log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID)
return nil, nil, err
}
var originGroup *models.KeyGroup
for _, g := range pool.CandidateGroups {
if g.ID == uint(groupID) {
originGroup = g
break
}
}
log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
if originGroup == nil {
log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID)
return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list")
}
if pool.PollingStrategy == models.StrategyWeighted {
go func() {
bgCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID))
}()
}
return apiKey, originGroup, nil
}
// ensureBasePoolCacheExists dynamically creates the BasePool's Redis structures
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
listKey := fmt.Sprintf(BasePoolSequential, poolID)
// --- [Logic refinement] Handle the "poison pill" up front to keep the flow clearer ---
exists, err := r.store.Exists(listKey)
if err != nil {
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
return err // surface the read error directly
// ensureBasePoolCacheExists dynamically creates or validates the BasePool's Redis cache structures
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID)
// Pre-check for fast failure
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
return gorm.ErrRecordNotFound
}
if exists {
val, err := r.store.LIndex(listKey, 0)
if err != nil {
// If even LIndex fails, the cache may be corrupted; allow a rebuild
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
} else {
if val == EmptyPoolPlaceholder {
return gorm.ErrRecordNotFound // known to be empty, return immediately
}
return nil // cache is valid, return immediately
}
}
// --- [Lock refinement] Add a distributed lock to prevent a thundering herd during concurrent builds ---
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10-second lock timeout
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
return err
}
if !acquired {
// Lock not acquired; wait briefly and retry so the goroutine holding the lock can finish the build
time.Sleep(100 * time.Millisecond)
return r.ensureBasePoolCacheExists(pool, poolID)
}
defer r.store.Del(lockKey) // make sure the lock is released when the function exits
// Double-check in case another goroutine completed the build while we were acquiring the lock
if exists, _ := r.store.Exists(listKey); exists {
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
return nil
}
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
var allActiveKeyIDs []string
lruMembers := make(map[string]float64)
// Acquire the distributed lock
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
if err := r.acquireLock(ctx, lockKey); err != nil {
return err // acquireLock already logs and returns a descriptive error
}
defer r.releaseLock(context.Background(), lockKey)
// Double-checked locking
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
return gorm.ErrRecordNotFound
}
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
return nil
}
// Before the heavy work, check one last time whether the context has been cancelled
select {
case <-ctx.Done():
return ctx.Err()
default:
}
r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID)
// Manually aggregate all keys while building the key-to-group map at the same time
keyToGroupMap := make(map[string]any)
allKeyIDsSet := make(map[string]struct{})
for _, group := range pool.CandidateGroups {
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
// --- [Core fix] ---
// This was the root of the whole problem: we must never silently `continue` on a read failure.
// Any failure to read source data must be treated as a total failure of the build and abort it immediately.
groupKeySet := fmt.Sprintf(KeyGroup, group.ID)
groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet)
if err != nil {
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
// Return this transient error. The current request fails, but no "poison pill" is ever written,
// giving the next request a fresh chance to succeed.
return err
r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID)
continue
}
// Proceed only when SMembers succeeded
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
for _, keyIDStr := range groupKeyIDs {
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
if err == nil && mapping != nil {
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
lruMembers[keyIDStr] = score
groupIDStr := strconv.FormatUint(uint64(group.ID), 10)
for _, keyID := range groupKeyIDs {
if _, exists := allKeyIDsSet[keyID]; !exists {
allKeyIDsSet[keyID] = struct{}{}
keyToGroupMap[keyID] = groupIDStr
}
}
}
// --- [Logic fix] ---
// Writing the "poison pill" is allowed only when we successfully read all the data
// and found that the data itself is empty.
if len(allActiveKeyIDs) == 0 {
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
pipe := r.store.Pipeline()
pipe.LPush(listKey, EmptyPoolPlaceholder)
pipe.Expire(listKey, EmptyCacheTTL)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
// Handle the empty-pool case
if len(allKeyIDsSet) == 0 {
emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2
if emptyCacheTTL < time.Minute {
emptyCacheTTL = time.Minute
}
r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID)
if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil {
r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID)
}
return gorm.ErrRecordNotFound
}
// Populate all polling structures through a pipeline
pipe := r.store.Pipeline()
// 1. Sequential
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// 2. Random
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
// Set a reasonable expiry (e.g. 5 minutes) to prevent orphaned data
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet))
for keyID := range allKeyIDsSet {
allActiveKeyIDs = append(allActiveKeyIDs, keyID)
}
// Build all cache structures atomically with a pipeline
basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
pipe := r.store.Pipeline(ctx)
pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey)
pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...)
pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...)
if len(keyToGroupMap) > 0 {
pipe.HSet(keyToGroupMapKey, keyToGroupMap)
pipe.Expire(keyToGroupMapKey, basePoolTTL)
}
pipe.Expire(mainPoolKey, basePoolTTL)
pipe.Expire(sequentialKey, basePoolTTL)
pipe.Expire(cooldownKey, basePoolTTL)
pipe.Expire(lruKey, basePoolTTL)
pipe.Set(heartbeatKey, []byte("1"), basePoolTTI)
if err := pipe.Exec(); err != nil {
r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID)
cleanupCtx, cancel := r.withTimeout(5 * time.Second)
defer cancel()
r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey)
return err
}
if len(lruMembers) > 0 {
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
}
// Asynchronously populate the LRU cache, passing in the already-built map
go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap)
return nil
}
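The build path above is easiest to follow with the storage details stripped away. A minimal sketch of the same check, lock, re-check, build ordering (the store is stubbed out with plain functions; this is not the project's actual Store API):

package main

import (
    "errors"
    "fmt"
)

var errNotFound = errors.New("record not found")

// ensure mirrors the control flow of ensureBasePoolCacheExists: fail fast on a
// known-empty pool, return early on a live heartbeat, and only re-check both
// markers after taking the lock, so concurrent callers never rebuild twice.
func ensure(isEmpty, isAlive func() bool, lock func() error, build func() error) error {
    if isEmpty() {
        return errNotFound
    }
    if isAlive() {
        return nil
    }
    if err := lock(); err != nil {
        return err
    }
    // Double-check: another goroutine may have built (or emptied) the pool
    // while we were waiting for the lock.
    if isEmpty() {
        return errNotFound
    }
    if isAlive() {
        return nil
    }
    return build()
}

func main() {
    err := ensure(
        func() bool { return false },
        func() bool { return false },
        func() error { return nil },
        func() error { fmt.Println("building pool cache"); return nil },
    )
    fmt.Println(err) // <nil>
}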
// --- Helper methods ---
// acquireLock wraps distributed-lock acquisition with retries and exponential backoff.
func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error {
const (
lockTTL = 30 * time.Second
lockMaxRetries = 5
lockBaseBackoff = 50 * time.Millisecond
)
for i := 0; i < lockMaxRetries; i++ {
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL)
if err != nil {
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock")
return err
}
if acquired {
return nil
}
time.Sleep(lockBaseBackoff * (1 << i))
}
return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries)
}
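As a sanity check on the retry budget, a standalone sketch (standard library only, constants copied from acquireLock above) prints the delays this schedule produces:

package main

import (
    "fmt"
    "time"
)

// Five attempts with a 50ms base delay doubling each time:
// 50ms, 100ms, 200ms, 400ms, 800ms — about 1.55s of waiting in the worst case.
func main() {
    const (
        lockMaxRetries  = 5
        lockBaseBackoff = 50 * time.Millisecond
    )
    total := time.Duration(0)
    for i := 0; i < lockMaxRetries; i++ {
        d := lockBaseBackoff * (1 << i)
        total += d
        fmt.Printf("attempt %d fails -> sleep %v\n", i+1, d)
    }
    fmt.Println("worst case before giving up:", total)
}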
// releaseLock wraps the release of the distributed lock.
func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) {
if err := r.store.Del(ctx, lockKey); err != nil {
r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey)
}
}
// withTimeout is a thin wrapper around context.WithTimeout, kept separate for easier testing and mocking.
func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), duration)
}
// refreshBasePoolHeartbeat asynchronously refreshes the heartbeat key's TTI
func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) {
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
// Refresh via the EXPIRE command; if the key does not exist this is a no-op, which is safe
if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil {
if ctx.Err() == nil { // avoid logging spurious errors after the context has been cancelled
r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID)
}
}
}
// populateBasePoolLRUCache asynchronously populates the BasePool's LRU cache structure
func (r *gormKeyRepository) populateBasePoolLRUCache(
parentCtx context.Context,
currentPoolID string,
keys []string,
keyToGroupMap map[string]any,
) {
lruMembers := make(map[string]float64, len(keys))
for _, keyIDStr := range keys {
select {
case <-parentCtx.Done():
return
default:
}
groupIDStr, ok := keyToGroupMap[keyIDStr].(string)
if !ok {
continue
}
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
groupID, _ := strconv.ParseUint(groupIDStr, 10, 64)
mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID)
data, err := r.store.Get(parentCtx, mappingKey)
if err != nil {
continue
}
var mapping models.GroupAPIKeyMapping
if json.Unmarshal(data, &mapping) == nil {
var score float64
if mapping.LastUsedAt != nil {
score = float64(mapping.LastUsedAt.UnixMilli())
}
lruMembers[keyIDStr] = score
}
}
if len(lruMembers) > 0 {
lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID)
if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil {
if parentCtx.Err() == nil {
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID)
}
}
}
}
// updateKeyUsageTimestampForPool updates the BasePool's LRU ZSET
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
r.store.ZAdd(lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
strconv.FormatUint(uint64(keyID), 10): r.nowMilli(),
})
if err != nil {
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
}
}
// generatePoolID derives a stable, unique string ID from the list of candidate group IDs
func generatePoolID(groups []*models.KeyGroup) string {
func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string {
ids := make([]int, len(groups))
for i, g := range groups {
ids[i] = int(g.ID)
}
sort.Ints(ids)
h := sha1.New()
io.WriteString(h, fmt.Sprintf("%v", ids))
return fmt.Sprintf("%x", h.Sum(nil))
}
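Because the IDs are sorted before hashing, the derived ID is independent of group order, so equivalent candidate sets share one polling state. A self-contained sketch of the same derivation:

package main

import (
    "crypto/sha1"
    "fmt"
    "io"
    "sort"
)

// poolID reproduces the derivation above: sort the group IDs, then SHA-1
// the formatted slice, so {3,1,2} and {1,2,3} map to the same pool.
func poolID(ids []int) string {
    sort.Ints(ids)
    h := sha1.New()
    io.WriteString(h, fmt.Sprintf("%v", ids))
    return fmt.Sprintf("%x", h.Sum(nil))
}

func main() {
    fmt.Println(poolID([]int{3, 1, 2}) == poolID([]int{1, 2, 3})) // true
}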
// toInterfaceSlice is a type-conversion helper
func toInterfaceSlice(slice []string) []interface{} {
func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} {
result := make([]interface{}, len(slice))
for i, v := range slice {
result[i] = v
@@ -280,13 +399,13 @@ func toInterfaceSlice(slice []string) []interface{} {
}
// nowMilli returns the current Unix timestamp in milliseconds, used by the LRU/Weighted strategies
func nowMilli() float64 {
func (r *gormKeyRepository) nowMilli() float64 {
return float64(time.Now().UnixMilli())
}
// getKeyDetailsFromCache fetches the key's and the mapping's JSON data from the cache.
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
}
@@ -295,7 +414,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
}
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
if err != nil {
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
}

View File

@@ -1,7 +1,9 @@
// Filename: internal/repository/key_writer.go
package repository
import (
"context"
"fmt"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -9,7 +11,7 @@ import (
"time"
)
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) {
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
timestamp := float64(time.Now().UnixMilli())
@@ -17,52 +19,51 @@ func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
strconv.FormatUint(uint64(keyID), 10): timestamp,
}
if err := r.store.ZAdd(lruKey, members); err != nil {
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
}
}
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus)
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
}
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
r.updatePollingCachesLogic(groupID, keyID, newStatus)
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
}
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
func (r *gormKeyRepository) updatePollingCachesLogic(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
keyIDStr := strconv.FormatUint(uint64(keyID), 10)
sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
_ = r.store.LRem(sequentialKey, 0, keyIDStr)
_ = r.store.ZRem(lruKey, keyIDStr)
_ = r.store.SRem(mainPoolKey, keyIDStr)
_ = r.store.SRem(cooldownPoolKey, keyIDStr)
_ = r.store.LRem(ctx, sequentialKey, 0, keyIDStr)
_ = r.store.ZRem(ctx, lruKey, keyIDStr)
_ = r.store.SRem(ctx, mainPoolKey, keyIDStr)
_ = r.store.SRem(ctx, cooldownPoolKey, keyIDStr)
if newStatus == models.StatusActive {
if err := r.store.LPush(sequentialKey, keyIDStr); err != nil {
if err := r.store.LPush(ctx, sequentialKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
}
members := map[string]float64{keyIDStr: 0}
if err := r.store.ZAdd(lruKey, members); err != nil {
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
}
if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil {
if err := r.store.SAdd(ctx, mainPoolKey, keyIDStr); err != nil {
r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
}
}
}
// UpdateKeyStatusAfterRequest is the new central hub for handling feedback.
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
if success {
if group.PollingStrategy == models.StrategyWeighted {
go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, key.ID)
}
return
}
@@ -72,6 +73,5 @@ func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup,
}
r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
// This call is correct. It uses the synchronous, direct method.
r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
r.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
}

View File

@@ -2,6 +2,8 @@
package repository
import (
"context"
"gemini-balancer/internal/config"
"gemini-balancer/internal/crypto"
"gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
@@ -22,8 +24,8 @@ type BasePool struct {
type KeyRepository interface {
// --- Core selection and scheduling --- key_selector
SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
// --- Encryption and decryption --- key_crud
Decrypt(key *models.APIKey) error
@@ -37,16 +39,16 @@ type KeyRepository interface {
GetKeyByID(id uint) (*models.APIKey, error)
GetKeyByValue(keyValue string) (*models.APIKey, error)
GetKeysByValues(keyValues []string) ([]models.APIKey, error)
GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [NEW] Batch-fetch keys by a set of primary-key IDs
GetKeysByIDs(ids []uint) ([]models.APIKey, error)
GetKeysByGroup(groupID uint) ([]models.APIKey, error)
CountByGroup(groupID uint) (int64, error)
// --- Many-to-many relationship management --- key_mapping
LinkKeysToGroup(groupID uint, keyIDs []uint) error
UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
GetGroupsForKey(keyID uint) ([]uint, error)
LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error
UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error)
GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
UpdateMapping(mapping *models.GroupAPIKeyMapping) error
UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error
GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
@@ -55,8 +57,8 @@ type KeyRepository interface {
UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
// --- Cache management --- key_cache
LoadAllKeysToStore() error
HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error
LoadAllKeysToStore(ctx context.Context) error
HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error
// --- Maintenance and background tasks --- key_maintenance
StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
@@ -65,16 +67,14 @@ type KeyRepository interface {
DeleteOrphanKeys() (int64, error)
DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
GetActiveMasterKeys() ([]*models.APIKey, error)
UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error
HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
// --- "Write" operations for the polling strategies --- key_writer
UpdateKeyUsageTimestamp(groupID, keyID uint)
// Synchronous cache update, used by core business paths
SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
// Asynchronous cache update, used by event subscribers
HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint)
SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
}
type GroupRepository interface {
@@ -88,18 +88,20 @@ type gormKeyRepository struct {
store store.Store
logger *logrus.Entry
crypto *crypto.Service
config *config.Config
}
type gormGroupRepository struct {
db *gorm.DB
}
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository {
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service, cfg *config.Config) KeyRepository {
return &gormKeyRepository{
db: db,
store: s,
logger: logger.WithField("component", "repository.key🔗"),
crypto: crypto,
config: cfg,
}
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
func NewRouter(
@@ -35,6 +36,7 @@ func NewRouter(
settingHandler *handlers.SettingHandler,
dashboardHandler *handlers.DashboardHandler,
taskHandler *handlers.TaskHandler,
wsHandler *handlers.WebSocketHandler,
// Web Page Handlers
webAuthHandler *webhandlers.WebAuthHandler,
pageHandler *webhandlers.PageHandler,
@@ -42,70 +44,215 @@ func NewRouter(
upstreamModule *upstream.Module,
proxyModule *proxy.Module,
) *gin.Engine {
// === 1. Create the global logger (centrally managed) ===
logger := createLogger(cfg)
// === 2. Set the Gin run mode ===
if cfg.Log.Level != "debug" {
gin.SetMode(gin.ReleaseMode)
}
router := gin.Default()
router.Static("/static", "./web/static")
// CORS configuration
config := cors.Config{
// Allowed front-end origins. In production this must be changed to the real domain
AllowOrigins: []string{"http://localhost:9000"},
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization"},
ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true,
MaxAge: 12 * time.Hour,
}
router.Use(cors.New(config))
isDebug := gin.Mode() != gin.ReleaseMode
router.HTMLRender = pongo.New("web/templates", isDebug)
// === 3. Create the router (gin.New() for full control over middleware) ===
router := gin.New()
// --- Infrastructure ---
router.GET("/", func(c *gin.Context) { c.Redirect(http.StatusMovedPermanently, "/dashboard") })
router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) })
// --- Unified authentication pipeline ---
apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService)
// === 4. Register global middleware (in execution order) ===
setupGlobalMiddleware(router, logger)
// === 5. Configure static files and templates ===
setupStaticAndTemplates(router, logger)
// === 6. Configure CORS ===
setupCORS(router, cfg)
// === 7. Register basic routes ===
setupBasicRoutes(router)
// === 8. Create authentication middleware ===
apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService, logger)
webAdminAuth := middleware.WebAdminAuthMiddleware(securityService)
router.Use(gin.RecoveryWithWriter(os.Stdout))
// --- Pass the right dependencies and middleware pipeline down ---
registerProxyRoutes(router, proxyHandler, securityService)
registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler, logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager)
// === 9. Register business routes (grouped by function) ===
registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager, logger)
registerWebRoutes(router, webAdminAuth, webAuthHandler, pageHandler)
registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler,
logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
registerWebSocketRoutes(router, wsHandler)
registerProxyRoutes(router, proxyHandler, securityService, logger)
return router
}
// ==================== Helper functions ====================
// createLogger creates and configures the global logger
func createLogger(cfg *config.Config) *logrus.Logger {
logger := logrus.New()
// Set the log format
if cfg.Log.Format == "json" {
logger.SetFormatter(&logrus.JSONFormatter{
TimestampFormat: time.RFC3339,
})
} else {
logger.SetFormatter(&logrus.TextFormatter{
FullTimestamp: true,
TimestampFormat: "2006-01-02 15:04:05",
})
}
// Set the log level
switch cfg.Log.Level {
case "debug":
logger.SetLevel(logrus.DebugLevel)
case "info":
logger.SetLevel(logrus.InfoLevel)
case "warn":
logger.SetLevel(logrus.WarnLevel)
case "error":
logger.SetLevel(logrus.ErrorLevel)
default:
logger.SetLevel(logrus.InfoLevel)
}
// Set the output (optionally to a file)
logger.SetOutput(os.Stdout)
return logger
}
// setupGlobalMiddleware registers the global middleware
func setupGlobalMiddleware(router *gin.Engine, logger *logrus.Logger) {
// 1. Request ID middleware (for request tracing)
router.Use(middleware.RequestIDMiddleware())
// 2. Redaction middleware (runs before logging)
router.Use(middleware.RedactionMiddleware())
// 3. Logging middleware
router.Use(middleware.LogrusLogger(logger))
// 4. Panic recovery middleware
router.Use(gin.RecoveryWithWriter(os.Stdout))
}
// setupStaticAndTemplates configures static files and templates
func setupStaticAndTemplates(router *gin.Engine, logger *logrus.Logger) {
router.Static("/static", "./web/static")
isDebug := gin.Mode() != gin.ReleaseMode
router.HTMLRender = pongo.New("web/templates", isDebug, logger)
}
// setupCORS configures CORS
func setupCORS(router *gin.Engine, cfg *config.Config) {
corsConfig := cors.Config{
AllowOrigins: getCORSOrigins(cfg),
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Request-Id"},
ExposeHeaders: []string{"Content-Length", "X-Request-Id"},
AllowCredentials: true,
MaxAge: 12 * time.Hour,
}
router.Use(cors.New(corsConfig))
}
// getCORSOrigins returns the origins that CORS should allow
func getCORSOrigins(cfg *config.Config) []string {
// Default value
origins := []string{"http://localhost:9000"}
// Read from config (fix: removed the nil check)
if len(cfg.Server.CORSOrigins) > 0 {
origins = cfg.Server.CORSOrigins
}
return origins
}
// setupBasicRoutes registers the basic routes
func setupBasicRoutes(router *gin.Engine) {
// Root path redirect
router.GET("/", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/dashboard")
})
// Health check
router.GET("/health", handleHealthCheck)
// Version info (optional)
router.GET("/version", handleVersion)
}
// handleHealthCheck is the health-check handler
func handleHealthCheck(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "ok",
"time": time.Now().Unix(),
})
}
// handleVersion is the version-info handler
func handleVersion(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"version": "1.0.0", // could be read from config or a build-time variable
"build": "latest",
})
}
// ==================== Route registration functions ====================
// registerProxyRoutes registers the proxy routes
func registerProxyRoutes(
router *gin.Engine, proxyHandler *handlers.ProxyHandler, securityService *service.SecurityService,
router *gin.Engine,
proxyHandler *handlers.ProxyHandler,
securityService *service.SecurityService,
logger *logrus.Logger,
) {
// Shared proxy authentication middleware
proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService)
// --- Mode 1: smart aggregation mode (root paths) ---
// The /v1 and /v1beta paths act as the default entry, serving the BasePool aggregation logic
// Create the proxy authentication middleware
proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService, logger)
// Mode 1: smart aggregation mode (default entry)
registerAggregateProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
// Mode 2: exact routing mode (routed by group name)
registerGroupProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
}
// registerAggregateProxyRoutes registers the aggregate proxy routes
func registerAggregateProxyRoutes(
router *gin.Engine,
proxyHandler *handlers.ProxyHandler,
authMiddleware gin.HandlerFunc,
) {
// /v1 route group
v1 := router.Group("/v1")
v1.Use(proxyAuthMiddleware)
v1.Use(authMiddleware)
{
v1.Any("/*path", proxyHandler.HandleProxy)
}
// /v1beta route group
v1beta := router.Group("/v1beta")
v1beta.Use(proxyAuthMiddleware)
v1beta.Use(authMiddleware)
{
v1beta.Any("/*path", proxyHandler.HandleProxy)
}
// --- Mode 2: exact routing mode (/proxy/:group_name) ---
// Create a new, physically isolated route group for exact routing by group name
}
// registerGroupProxyRoutes registers the per-group proxy routes
func registerGroupProxyRoutes(
router *gin.Engine,
proxyHandler *handlers.ProxyHandler,
authMiddleware gin.HandlerFunc,
) {
proxyGroup := router.Group("/proxy/:group_name")
proxyGroup.Use(proxyAuthMiddleware)
proxyGroup.Use(authMiddleware)
{
// Capture all sub-paths (e.g. /v1/chat/completions) and hand them all to the same ProxyHandler.
proxyGroup.Any("/*path", proxyHandler.HandleProxy)
}
}
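Taken together, the two modes resolve like this (illustrative paths only; the group name "gemini" is a made-up example):

package main

import "fmt"

// Both modes end in the same ProxyHandler; only the group resolution differs.
func main() {
    fmt.Println(`POST /v1/chat/completions         -> aggregate mode (BasePool selection)`)
    fmt.Println(`POST /v1beta/<any sub-path>       -> aggregate mode (BasePool selection)`)
    fmt.Println(`POST /proxy/gemini/<any sub-path> -> exact mode, key group "gemini"`)
}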
// registerAdminRoutes
// registerAdminRoutes registers the admin backend API routes
func registerAdminRoutes(
router *gin.Engine,
authMiddleware gin.HandlerFunc,
@@ -121,74 +268,115 @@ func registerAdminRoutes(
) {
admin := router.Group("/admin", authMiddleware)
{
// --- KeyGroup Base Routes ---
admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
// --- KeyGroup Specific Routes (by :id) ---
admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
// --- APIKey Sub-resource Routes under a KeyGroup ---
keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
{
keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
}
// KeyGroup routes
registerKeyGroupRoutes(admin, keyGroupHandler, apiKeyHandler)
// Global key operations
admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
// admin.PUT("/apikeys/:id", apiKeyHandler.UpdateAPIKey) // DEPRECATED: Status is now contextual
admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys) // Test keys globally
admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey) // Hard delete a single key
admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys) // Hard delete multiple keys
admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys) // Restore multiple keys globally
// Global APIKey routes
registerAPIKeyRoutes(admin, apiKeyHandler)
// --- Global Routes ---
admin.GET("/tokens", tokensHandler.GetAllTokens)
admin.PUT("/tokens", tokensHandler.UpdateTokens)
admin.GET("/logs", logHandler.GetLogs)
admin.GET("/settings", settingHandler.GetSettings)
admin.PUT("/settings", settingHandler.UpdateSettings)
admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
// System administration routes
registerSystemRoutes(admin, tokensHandler, logHandler, settingHandler, taskHandler)
// For querying the status of asynchronous tasks
admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
// Dashboard routes
registerDashboardRoutes(admin, dashboardHandler)
// Domain modules
// Domain module routes
upstreamModule.RegisterRoutes(admin)
proxyModule.RegisterRoutes(admin)
// --- Global dashboard routes ---
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/overview", dashboardHandler.GetOverview)
dashboard.GET("/chart", dashboardHandler.GetChart)
dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats) // 点击详情
}
}
}
// registerWebRoutes
// registerKeyGroupRoutes registers the KeyGroup-related routes
func registerKeyGroupRoutes(
admin *gin.RouterGroup,
keyGroupHandler *handlers.KeyGroupHandler,
apiKeyHandler *handlers.APIKeyHandler,
) {
// Base routes
admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
// Routes for a specific KeyGroup
admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
// APIKey sub-resources of a KeyGroup
keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
{
keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
}
}
// registerAPIKeyRoutes registers the global APIKey routes
func registerAPIKeyRoutes(admin *gin.RouterGroup, apiKeyHandler *handlers.APIKeyHandler) {
admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys)
admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey)
admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys)
admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys)
}
// registerSystemRoutes registers the system administration routes
func registerSystemRoutes(
admin *gin.RouterGroup,
tokensHandler *handlers.TokensHandler,
logHandler *handlers.LogHandler,
settingHandler *handlers.SettingHandler,
taskHandler *handlers.TaskHandler,
) {
// Token management
admin.GET("/tokens", tokensHandler.GetAllTokens)
admin.PUT("/tokens", tokensHandler.UpdateTokens)
// Log management
admin.GET("/logs", logHandler.GetLogs)
admin.DELETE("/logs", logHandler.DeleteLogs) // delete selected
admin.DELETE("/logs/all", logHandler.DeleteAllLogs) // delete all
admin.DELETE("/logs/old", logHandler.DeleteOldLogs) // delete old logs
// Settings management
admin.GET("/settings", settingHandler.GetSettings)
admin.PUT("/settings", settingHandler.UpdateSettings)
admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
// Task management
admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
}
// registerDashboardRoutes registers the dashboard routes
func registerDashboardRoutes(admin *gin.RouterGroup, dashboardHandler *handlers.DashboardHandler) {
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/overview", dashboardHandler.GetOverview)
dashboard.GET("/chart", dashboardHandler.GetChart)
dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats)
}
}
// registerWebRoutes registers the web page routes
func registerWebRoutes(
router *gin.Engine,
authMiddleware gin.HandlerFunc,
webAuthHandler *webhandlers.WebAuthHandler,
pageHandler *webhandlers.PageHandler,
) {
// Public authentication routes
router.GET("/login", webAuthHandler.ShowLoginPage)
router.POST("/login", webAuthHandler.HandleLogin)
router.GET("/logout", webAuthHandler.HandleLogout)
// For testing only: router.Run("127.0.0.1:9000")
// Protected admin web UI
// Protected admin interface
webGroup := router.Group("/", authMiddleware)
webGroup.Use(authMiddleware)
{
webGroup.GET("/keys", pageHandler.ShowKeysPage)
webGroup.GET("/settings", pageHandler.ShowConfigEditorPage)
@@ -197,14 +385,35 @@ func registerWebRoutes(
webGroup.GET("/tasks", pageHandler.ShowTasksPage)
webGroup.GET("/chat", pageHandler.ShowChatPage)
}
}
// registerPublicAPIRoutes registers public API routes that require no admin login
func registerPublicAPIRoutes(router *gin.Engine, apiAuthHandler *handlers.APIAuthHandler, securityService *service.SecurityService, settingsManager *settings.SettingsManager) {
ipBanMiddleware := middleware.IPBanMiddleware(securityService, settingsManager)
publicAPIGroup := router.Group("/api")
// registerPublicAPIRoutes registers the public API routes
func registerPublicAPIRoutes(
router *gin.Engine,
apiAuthHandler *handlers.APIAuthHandler,
securityService *service.SecurityService,
settingsManager *settings.SettingsManager,
logger *logrus.Logger,
) {
// Create the IP ban middleware
ipBanCache := middleware.NewIPBanCache()
ipBanMiddleware := middleware.IPBanMiddleware(
securityService,
settingsManager,
ipBanCache,
logger,
)
// Public API route group
publicAPI := router.Group("/api")
{
publicAPIGroup.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
publicAPI.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
// Other public API routes can be added here
// publicAPI.POST("/register", ipBanMiddleware, apiAuthHandler.HandleRegister)
// publicAPI.POST("/forgot-password", ipBanMiddleware, apiAuthHandler.HandleForgotPassword)
}
}
func registerWebSocketRoutes(router *gin.Engine, wsHandler *handlers.WebSocketHandler) {
router.GET("/ws/system-logs", wsHandler.HandleSystemLogs)
}

View File

@@ -1,42 +1,68 @@
// Filename: internal/scheduler/scheduler.go
// [REVISED] - fully replaced with this smarter version
package scheduler
import (
"context"
"fmt" // [NEW] 导入 fmt
"gemini-balancer/internal/repository"
"gemini-balancer/internal/service"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"strconv" // [NEW] 导入 strconv
"strings" // [NEW] 导入 strings
"sync"
"time"
"github.com/go-co-op/gocron"
"github.com/sirupsen/logrus"
)
// ... (Scheduler struct and NewScheduler unchanged) ...
const LogCleanupTaskTag = "log-cleanup-task"
type Scheduler struct {
gocronScheduler *gocron.Scheduler
logger *logrus.Entry
statsService *service.StatsService
logService *service.LogService
settingsManager *settings.SettingsManager
keyRepo repository.KeyRepository
// healthCheckService *service.HealthCheckService // reserved for the health-check task
store store.Store
stopChan chan struct{}
wg sync.WaitGroup
}
func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyRepository, logger *logrus.Logger) *Scheduler {
func NewScheduler(
statsSvc *service.StatsService,
logSvc *service.LogService,
keyRepo repository.KeyRepository,
settingsMgr *settings.SettingsManager,
store store.Store,
logger *logrus.Logger,
) *Scheduler {
s := gocron.NewScheduler(time.UTC)
s.TagsUnique()
return &Scheduler{
gocronScheduler: s,
logger: logger.WithField("component", "Scheduler📆"),
statsService: statsSvc,
logService: logSvc,
settingsManager: settingsMgr,
keyRepo: keyRepo,
store: store,
stopChan: make(chan struct{}),
}
}
// ... (Start and listenForSettingsUpdates unchanged) ...
func (s *Scheduler) Start() {
s.logger.Info("Starting scheduler and registering jobs...")
// --- Task registration ---
// Use a CRON expression to run at exactly minute 5 of every hour
// --- Register static scheduled tasks ---
_, err := s.gocronScheduler.Cron("5 * * * *").Tag("stats-aggregation").Do(func() {
s.logger.Info("Executing hourly request stats aggregation...")
if err := s.statsService.AggregateHourlyStats(); err != nil {
ctx := context.Background()
if err := s.statsService.AggregateHourlyStats(ctx); err != nil {
s.logger.WithError(err).Error("Hourly stats aggregation failed.")
} else {
s.logger.Info("Hourly stats aggregation completed successfully.")
@@ -45,26 +71,9 @@ func (s *Scheduler) Start() {
if err != nil {
s.logger.Errorf("Failed to schedule [stats-aggregation]: %v", err)
}
// Task 2: (reserved) automatic health checks (e.g. every 10 minutes)
/*
_, err = s.gocronScheduler.Every(10).Minutes().Tag("auto-health-check").Do(func() {
s.logger.Info("Executing periodic health check for all groups...")
// s.healthCheckService.StartGlobalCheckTask() // pseudocode
})
if err != nil {
s.logger.Errorf("Failed to schedule [auto-health-check]: %v", err)
}
*/
// [NEW] --- Task 3: clean up soft-deleted API keys ---
// Executes once daily at 3:15 AM UTC.
_, err = s.gocronScheduler.Cron("15 3 * * *").Tag("cleanup-soft-deleted-keys").Do(func() {
s.logger.Info("Executing daily cleanup of soft-deleted API keys...")
// Let's assume a retention period of 7 days for now.
// In a real scenario, this should come from settings.
const retentionDays = 7
count, err := s.keyRepo.HardDeleteSoftDeletedBefore(time.Now().AddDate(0, 0, -retentionDays))
if err != nil {
s.logger.WithError(err).Error("Daily cleanup of soft-deleted keys failed.")
@@ -77,14 +86,125 @@ func (s *Scheduler) Start() {
if err != nil {
s.logger.Errorf("Failed to schedule [cleanup-soft-deleted-keys]: %v", err)
}
// --- End of task registration ---
s.gocronScheduler.StartAsync() // start asynchronously, without blocking the application's main thread
// --- Dynamic task initialization ---
if err := s.UpdateLogCleanupTask(); err != nil {
s.logger.WithError(err).Error("Failed to initialize log cleanup task on startup.")
}
// --- Start the background listener and the scheduler ---
s.wg.Add(1)
go s.listenForSettingsUpdates()
s.gocronScheduler.StartAsync()
s.logger.Info("Scheduler started.")
}
func (s *Scheduler) listenForSettingsUpdates() {
defer s.wg.Done()
s.logger.Info("Starting listener for system settings updates...")
for {
select {
case <-s.stopChan:
s.logger.Info("Stopping settings update listener.")
return
default:
}
ctx, cancel := context.WithCancel(context.Background())
subscription, err := s.store.Subscribe(ctx, settings.SettingsUpdateChannel)
if err != nil {
s.logger.WithError(err).Warnf("Failed to subscribe to settings channel, retrying in 5s...")
cancel()
time.Sleep(5 * time.Second)
continue
}
s.logger.Infof("Successfully subscribed to channel '%s'.", settings.SettingsUpdateChannel)
listenLoop:
for {
select {
case msg, ok := <-subscription.Channel():
if !ok {
s.logger.Warn("Subscription channel closed by publisher. Re-subscribing...")
break listenLoop
}
s.logger.Infof("Received settings update notification: %s", string(msg.Payload))
if err := s.UpdateLogCleanupTask(); err != nil {
s.logger.WithError(err).Error("Failed to update log cleanup task after notification.")
}
case <-s.stopChan:
s.logger.Info("Stopping settings update listener.")
subscription.Close()
cancel()
return
}
}
subscription.Close()
cancel()
}
}
// [MODIFIED] - UpdateLogCleanupTask now generates the cron expression dynamically
func (s *Scheduler) UpdateLogCleanupTask() error {
if err := s.gocronScheduler.RemoveByTag(LogCleanupTaskTag); err != nil {
// Not an error: it just means the job didn't exist yet
}
settings := s.settingsManager.GetSettings()
if !settings.LogAutoCleanupEnabled || settings.LogAutoCleanupRetentionDays <= 0 {
s.logger.Info("Log auto-cleanup is disabled. Task removed or not scheduled.")
return nil
}
days := settings.LogAutoCleanupRetentionDays
// [NEW] Parse the configured time and generate a cron expression
cronSpec, err := parseTimeToCron(settings.LogAutoCleanupTime)
if err != nil {
s.logger.WithError(err).Warnf("Invalid cleanup time format '%s'. Falling back to default '04:05'.", settings.LogAutoCleanupTime)
cronSpec = "5 4 * * *" // 安全回退
}
s.logger.Infof("Scheduling/updating daily log cleanup task to retain last %d days of logs, using cron spec: '%s'", days, cronSpec)
_, err = s.gocronScheduler.Cron(cronSpec).Tag(LogCleanupTaskTag).Do(func() {
s.logger.Infof("Executing daily log cleanup, deleting logs older than %d days...", days)
ctx := context.Background()
deletedCount, err := s.logService.DeleteOldLogs(ctx, days)
if err != nil {
s.logger.WithError(err).Error("Daily log cleanup task failed.")
} else {
s.logger.Infof("Daily log cleanup task completed. Deleted %d old logs.", deletedCount)
}
})
if err != nil {
s.logger.WithError(err).Error("Failed to schedule new log cleanup task.")
return err
}
s.logger.Info("Log cleanup task updated successfully.")
return nil
}
// [NEW] - helper that parses an "HH:mm" time string into a cron expression
func parseTimeToCron(timeStr string) (string, error) {
parts := strings.Split(timeStr, ":")
if len(parts) != 2 {
return "", fmt.Errorf("invalid time format, expected HH:mm")
}
hour, err := strconv.Atoi(parts[0])
if err != nil || hour < 0 || hour > 23 {
return "", fmt.Errorf("invalid hour value: %s", parts[0])
}
minute, err := strconv.Atoi(parts[1])
if err != nil || minute < 0 || minute > 59 {
return "", fmt.Errorf("invalid minute value: %s", parts[1])
}
return fmt.Sprintf("%d %d * * *", minute, hour), nil
}
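A runnable harness for the helper (function body copied verbatim from above) shows the minute-first ordering of the generated spec and the validation behavior:

package main

import (
    "fmt"
    "strconv"
    "strings"
)

func parseTimeToCron(timeStr string) (string, error) {
    parts := strings.Split(timeStr, ":")
    if len(parts) != 2 {
        return "", fmt.Errorf("invalid time format, expected HH:mm")
    }
    hour, err := strconv.Atoi(parts[0])
    if err != nil || hour < 0 || hour > 23 {
        return "", fmt.Errorf("invalid hour value: %s", parts[0])
    }
    minute, err := strconv.Atoi(parts[1])
    if err != nil || minute < 0 || minute > 59 {
        return "", fmt.Errorf("invalid minute value: %s", parts[1])
    }
    return fmt.Sprintf("%d %d * * *", minute, hour), nil
}

func main() {
    fmt.Println(parseTimeToCron("23:30")) // 30 23 * * * <nil>
    fmt.Println(parseTimeToCron("04:05")) // 5 4 * * * <nil>
    _, err := parseTimeToCron("25:00")
    fmt.Println(err) // invalid hour value: 25
}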
func (s *Scheduler) Stop() {
s.logger.Info("Stopping scheduler...")
close(s.stopChan)
s.gocronScheduler.Stop()
s.logger.Info("Scheduler stopped.")
s.wg.Wait()
s.logger.Info("Scheduler stopped gracefully.")
}

View File

@@ -1,96 +1,183 @@
// Filename: internal/service/analytics_service.go
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/db/dialect"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"gemini-balancer/internal/db/dialect"
"gemini-balancer/internal/models"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
const (
flushLoopInterval = 1 * time.Minute
defaultFlushInterval = 1 * time.Minute
maxRetryAttempts = 3
retryDelay = 5 * time.Second
)
type AnalyticsServiceLogger struct{ *logrus.Entry }
type AnalyticsService struct {
db *gorm.DB
store store.Store
logger *logrus.Entry
db *gorm.DB
store store.Store
logger *logrus.Entry
dialect dialect.DialectAdapter
settingsManager *settings.SettingsManager
stopChan chan struct{}
wg sync.WaitGroup
dialect dialect.DialectAdapter
ctx context.Context
cancel context.CancelFunc
// Metrics
eventsReceived atomic.Uint64
eventsProcessed atomic.Uint64
eventsFailed atomic.Uint64
flushCount atomic.Uint64
recordsFlushed atomic.Uint64
flushErrors atomic.Uint64
lastFlushTime time.Time
lastFlushMutex sync.RWMutex
// Runtime configuration
flushInterval time.Duration
configMutex sync.RWMutex
}
func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d dialect.DialectAdapter) *AnalyticsService {
func NewAnalyticsService(
db *gorm.DB,
s store.Store,
logger *logrus.Logger,
d dialect.DialectAdapter,
settingsManager *settings.SettingsManager,
) *AnalyticsService {
ctx, cancel := context.WithCancel(context.Background())
return &AnalyticsService{
db: db,
store: s,
logger: logger.WithField("component", "Analytics📊"),
stopChan: make(chan struct{}),
dialect: d,
db: db,
store: s,
logger: logger.WithField("component", "Analytics📊"),
dialect: d,
settingsManager: settingsManager,
stopChan: make(chan struct{}),
ctx: ctx,
cancel: cancel,
flushInterval: defaultFlushInterval,
lastFlushTime: time.Now(),
}
}
func (s *AnalyticsService) Start() {
s.wg.Add(2) // 2 (flushLoop, eventListener)
go s.flushLoop()
s.wg.Add(3)
go s.eventListener()
s.logger.Info("AnalyticsService (Command Side) started.")
go s.flushLoop()
go s.metricsReporter()
s.logger.WithFields(logrus.Fields{
"flush_interval": s.flushInterval,
}).Info("AnalyticsService started")
}
func (s *AnalyticsService) Stop() {
s.logger.Info("AnalyticsService stopping...")
close(s.stopChan)
s.cancel()
s.wg.Wait()
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
s.flushToDB() // flush to storage before stopping
s.logger.Info("AnalyticsService final data flush completed.")
s.logger.Info("Performing final data flush...")
s.flushToDB()
// Emit final statistics
s.logger.WithFields(logrus.Fields{
"events_received": s.eventsReceived.Load(),
"events_processed": s.eventsProcessed.Load(),
"events_failed": s.eventsFailed.Load(),
"flush_count": s.flushCount.Load(),
"records_flushed": s.recordsFlushed.Load(),
"flush_errors": s.flushErrors.Load(),
}).Info("AnalyticsService stopped")
}
// Event listening loop
func (s *AnalyticsService) eventListener() {
defer s.wg.Done()
sub, err := s.store.Subscribe(models.TopicRequestFinished)
sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled")
return
}
defer sub.Close()
s.logger.Info("AnalyticsService subscribed to request events.")
defer func() {
if err := sub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close subscription")
}
}()
s.logger.Info("Subscribed to request events for analytics")
for {
select {
case msg := <-sub.Channel():
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal analytics event: %v", err)
continue
}
s.handleAnalyticsEvent(&event)
s.handleMessage(msg)
case <-s.stopChan:
s.logger.Info("AnalyticsService stopping event listener.")
s.logger.Info("Event listener stopping")
return
case <-s.ctx.Done():
s.logger.Info("Event listener context cancelled")
return
}
}
}
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) {
if event.RequestLog.GroupID == nil {
// handleMessage processes a single message
func (s *AnalyticsService) handleMessage(msg *store.Message) {
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.WithError(err).Error("Failed to unmarshal analytics event")
s.eventsFailed.Add(1)
return
}
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
s.eventsReceived.Add(1)
if err := s.handleAnalyticsEvent(&event); err != nil {
s.eventsFailed.Add(1)
s.logger.WithFields(logrus.Fields{
"correlation_id": event.CorrelationID,
"group_id": event.RequestLog.GroupID,
}).WithError(err).Warn("Failed to process analytics event")
} else {
s.eventsProcessed.Add(1)
}
}
// handleAnalyticsEvent processes one analytics event
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error {
if event.RequestLog.GroupID == nil {
return nil // skip events without a GroupID
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
now := time.Now().UTC()
key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
pipe := s.store.Pipeline()
pipe := s.store.Pipeline(ctx)
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
if event.RequestLog.IsSuccess {
pipe.HIncrBy(key, fieldPrefix+":success", 1)
}
@@ -100,79 +187,213 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
if event.RequestLog.CompletionTokens > 0 {
pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens))
}
// Set the expiry (retain for 48 hours)
pipe.Expire(key, 48*time.Hour)
if err := pipe.Exec(); err != nil {
s.logger.Warnf("[%s] Failed to record analytics event to store for group %d: %v", event.CorrelationID, *event.RequestLog.GroupID, err)
return fmt.Errorf("redis pipeline failed: %w", err)
}
return nil
}
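For reference, a small sketch of the hash layout this writes: one Redis hash per UTC hour, with fields namespaced as <groupID>:<model>:<counter>. Group ID 7 and model name "gemini-pro" are made-up sample values; the counter names are the ones used above:

package main

import (
    "fmt"
    "time"
)

// Reconstructs the key and field names handleAnalyticsEvent increments.
func main() {
    now := time.Date(2025, 1, 2, 15, 42, 0, 0, time.UTC)
    key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
    fieldPrefix := fmt.Sprintf("%d:%s", 7, "gemini-pro")
    fmt.Println(key)                         // analytics:hourly:2025-01-02T15
    fmt.Println(fieldPrefix + ":requests")   // 7:gemini-pro:requests
    fmt.Println(fieldPrefix + ":success")    // 7:gemini-pro:success
    fmt.Println(fieldPrefix + ":completion") // 7:gemini-pro:completion
}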
// Flush loop
func (s *AnalyticsService) flushLoop() {
defer s.wg.Done()
ticker := time.NewTicker(flushLoopInterval)
s.configMutex.RLock()
interval := s.flushInterval
s.configMutex.RUnlock()
ticker := time.NewTicker(interval)
defer ticker.Stop()
s.logger.WithField("interval", interval).Info("Flush loop started")
for {
select {
case <-ticker.C:
s.flushToDB()
case <-s.stopChan:
s.logger.Info("Flush loop stopping")
return
case <-s.ctx.Done():
return
}
}
}
// Flush to the database
func (s *AnalyticsService) flushToDB() {
start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
now := time.Now().UTC()
keysToFlush := []string{
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")),
}
keysToFlush := s.generateFlushKeys(now)
totalRecords := 0
totalErrors := 0
for _, key := range keysToFlush {
data, err := s.store.HGetAll(key)
if err != nil || len(data) == 0 {
continue
records, err := s.flushSingleKey(ctx, key, now)
if err != nil {
s.logger.WithError(err).WithField("key", key).Error("Failed to flush key")
totalErrors++
s.flushErrors.Add(1)
} else {
totalRecords += records
}
}
statsToFlush, parsedFields := s.parseStatsFromHash(now.Truncate(time.Hour), data)
s.recordsFlushed.Add(uint64(totalRecords))
s.flushCount.Add(1)
if len(statsToFlush) > 0 {
upsertClause := s.dialect.OnConflictUpdateAll(
[]string{"time", "group_id", "model_name"}, // conflict columns
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns
)
err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error
if err != nil {
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
} else {
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
_ = s.store.HDel(key, parsedFields...)
}
}
s.lastFlushMutex.Lock()
s.lastFlushTime = time.Now()
s.lastFlushMutex.Unlock()
duration := time.Since(start)
if totalRecords > 0 || totalErrors > 0 {
s.logger.WithFields(logrus.Fields{
"records_flushed": totalRecords,
"keys_processed": len(keysToFlush),
"errors": totalErrors,
"duration": duration,
}).Info("Analytics data flush completed")
} else {
s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)")
}
}
// 生成需要刷新的 Redis 键
func (s *AnalyticsService) generateFlushKeys(now time.Time) []string {
keys := make([]string, 0, 4)
// 当前小时
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")))
// 前3个小时处理延迟和时区问题
for i := 1; i <= 3; i++ {
pastHour := now.Add(-time.Duration(i) * time.Hour)
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15")))
}
return keys
}
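// Quick sanity sketch (illustrative; relies only on generateFlushKeys not
// touching service state, as above): at 03:30 UTC the flush covers the current
// hour plus the three preceding ones, which is what absorbs late events and
// retried flushes.
func exampleFlushKeys() []string {
now := time.Date(2025, 1, 2, 3, 30, 0, 0, time.UTC) // hypothetical instant
return (&AnalyticsService{}).generateFlushKeys(now)
// -> analytics:hourly:2025-01-02T03, T02, T01, T00
}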
// 刷写单个 Redis 键
func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) {
data, err := s.store.HGetAll(ctx, key)
if err != nil {
return 0, fmt.Errorf("failed to get hash data: %w", err)
}
if len(data) == 0 {
return 0, nil // 无数据,跳过
}
// 解析时间戳
hourStr := strings.TrimPrefix(key, "analytics:hourly:")
recordTime, err := time.Parse("2006-01-02T15", hourStr)
if err != nil {
s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key")
recordTime = baseTime.Truncate(time.Hour)
}
statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data)
if len(statsToFlush) == 0 {
return 0, nil
}
// 使用事务 + 重试机制
var dbErr error
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush)
if dbErr == nil {
break
}
if attempt < maxRetryAttempts {
s.logger.WithFields(logrus.Fields{
"attempt": attempt,
"key": key,
}).WithError(dbErr).Warn("Database upsert failed, retrying...")
time.Sleep(retryDelay)
}
}
if dbErr != nil {
return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr)
}
// 删除已处理的字段
if len(parsedFields) > 0 {
if err := s.store.HDel(ctx, key, parsedFields...); err != nil {
s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis")
}
}
return len(statsToFlush), nil
}
// 使用事务批量 upsert
func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error {
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
upsertClause := s.dialect.OnConflictUpdateAll(
[]string{"time", "group_id", "model_name"},
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
)
return tx.Clauses(upsertClause).Create(&stats).Error
})
}
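// For reference — an assumption about what the dialect helper produces, not
// its actual code: on PostgreSQL or SQLite, the OnConflictUpdateAll call above
// should be equivalent to this hand-built GORM clause (import "gorm.io/gorm/clause").
var exampleUpsertClause = clause.OnConflict{
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
DoUpdates: clause.AssignmentColumns([]string{
"request_count", "success_count", "prompt_tokens", "completion_tokens",
}),
}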
// 解析 Redis Hash 数据
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
tempAggregator := make(map[string]*models.StatsHourly)
parsedFields := make([]string, 0, len(data))
for field, valueStr := range data {
parts := strings.Split(field, ":")
if len(parts) != 3 {
s.logger.WithField("field", field).Warn("Invalid field format")
continue
}
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
aggKey := groupIDStr + ":" + modelName
if _, ok := tempAggregator[aggKey]; !ok {
gid, err := strconv.Atoi(groupIDStr)
if err != nil {
s.logger.WithFields(logrus.Fields{
"field": field,
"group_id": groupIDStr,
}).Warn("Invalid group ID")
continue
}
tempAggregator[aggKey] = &models.StatsHourly{
Time: t,
GroupID: uint(gid),
ModelName: modelName,
}
}
val, err := strconv.ParseInt(valueStr, 10, 64)
if err != nil {
s.logger.WithFields(logrus.Fields{
"field": field,
"value": valueStr,
}).Warn("Invalid counter value")
continue
}
switch counterType {
case "requests":
tempAggregator[aggKey].RequestCount = val
@@ -182,14 +403,92 @@ func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]strin
tempAggregator[aggKey].PromptTokens = val
case "completion":
tempAggregator[aggKey].CompletionTokens = val
default:
s.logger.WithField("counter_type", counterType).Warn("Unknown counter type")
continue
}
parsedFields = append(parsedFields, field)
}
result := make([]models.StatsHourly, 0, len(tempAggregator))
for _, stats := range tempAggregator {
if stats.RequestCount > 0 {
result = append(result, *stats)
}
}
return result, parsedFields
}
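// Round-trip sketch with hypothetical values: one hour's hash aggregates into
// a single StatsHourly row, and every recognized field is returned so it can
// be HDEL'd after a successful upsert.
func exampleParseStats(s *AnalyticsService) {
data := map[string]string{
"7:gemini-1.5-pro:requests":   "42",
"7:gemini-1.5-pro:success":    "40",
"7:gemini-1.5-pro:prompt":     "91234",
"7:gemini-1.5-pro:completion": "55678",
}
rows, fields := s.parseStatsFromHash(time.Now().UTC().Truncate(time.Hour), data)
fmt.Printf("rows=%d fieldsToDelete=%d\n", len(rows), len(fields)) // rows=1 fieldsToDelete=4
}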
// 定期输出统计信息
func (s *AnalyticsService) metricsReporter() {
defer s.wg.Done()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.reportMetrics()
case <-s.stopChan:
return
case <-s.ctx.Done():
return
}
}
}
func (s *AnalyticsService) reportMetrics() {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
received := s.eventsReceived.Load()
processed := s.eventsProcessed.Load()
failed := s.eventsFailed.Load()
var successRate float64
if received > 0 {
successRate = float64(processed) / float64(received) * 100
}
s.logger.WithFields(logrus.Fields{
"events_received": received,
"events_processed": processed,
"events_failed": failed,
"success_rate": fmt.Sprintf("%.2f%%", successRate),
"flush_count": s.flushCount.Load(),
"records_flushed": s.recordsFlushed.Load(),
"flush_errors": s.flushErrors.Load(),
"last_flush_ago": time.Since(lastFlush).Round(time.Second),
}).Info("Analytics metrics")
}
// GetMetrics 返回当前统计指标(供监控使用)
func (s *AnalyticsService) GetMetrics() map[string]interface{} {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
received := s.eventsReceived.Load()
processed := s.eventsProcessed.Load()
var successRate float64
if received > 0 {
successRate = float64(processed) / float64(received) * 100
}
return map[string]interface{}{
"events_received": received,
"events_processed": processed,
"events_failed": s.eventsFailed.Load(),
"success_rate": successRate,
"flush_count": s.flushCount.Load(),
"records_flushed": s.recordsFlushed.Load(),
"flush_errors": s.flushErrors.Load(),
"last_flush_ago": time.Since(lastFlush).Seconds(),
"flush_interval": s.flushInterval.Seconds(),
}
}
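// Monitoring hook sketch (where this gets wired up is an assumption, not part
// of this change): GetMetrics returns plain values, so it can be polled and
// serialized by any admin endpoint or exporter.
func examplePollAnalyticsMetrics(svc *AnalyticsService) {
m := svc.GetMetrics()
fmt.Printf("analytics: received=%v processed=%v success_rate=%.2f%%\n",
m["events_received"], m["events_processed"], m["success_rate"])
}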

File diff suppressed because it is too large

View File

@@ -1,164 +1,300 @@
// Filename: internal/service/dashboard_query_service.go
package service
import (
"context"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"gemini-balancer/internal/models"
"gemini-balancer/internal/store"
"gemini-balancer/internal/syncer"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
const overviewCacheChannel = "syncer:cache:dashboard_overview"
const (
overviewCacheChannel = "syncer:cache:dashboard_overview"
defaultChartDays = 7
cacheLoadTimeout = 30 * time.Second
)
var (
// 图表颜色调色板
chartColorPalette = []string{
"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0",
"#9966FF", "#FF9F40", "#C9CBCF", "#4D5360",
}
)
// DashboardQueryService 负责所有面向前端的仪表盘数据查询。
type DashboardQueryService struct {
db *gorm.DB
store store.Store
overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
logger *logrus.Entry
stopChan chan struct{}
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
// 统计指标
queryCount atomic.Uint64
cacheHits atomic.Uint64
cacheMisses atomic.Uint64
overviewLoadCount atomic.Uint64
lastQueryTime time.Time
lastQueryMutex sync.RWMutex
}
func NewDashboardQueryService(
db *gorm.DB,
s store.Store,
logger *logrus.Logger,
) (*DashboardQueryService, error) {
ctx, cancel := context.WithCancel(context.Background())
service := &DashboardQueryService{
db: db,
store: s,
logger: logger.WithField("component", "DashboardQuery📈"),
stopChan: make(chan struct{}),
ctx: ctx,
cancel: cancel,
lastQueryTime: time.Now(),
}
// 创建 CacheSyncer
overviewSyncer, err := syncer.NewCacheSyncer(
service.loadOverviewData,
s,
overviewCacheChannel,
logger,
)
if err != nil {
cancel()
return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
}
service.overviewSyncer = overviewSyncer
return service, nil
}
func (s *DashboardQueryService) Start() {
s.wg.Add(2)
go s.eventListener()
s.logger.Info("DashboardQueryService started and listening for invalidation events.")
go s.metricsReporter()
s.logger.Info("DashboardQueryService started")
}
func (s *DashboardQueryService) Stop() {
s.logger.Info("DashboardQueryService stopping...")
close(s.stopChan)
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
s.cancel()
s.wg.Wait()
// 输出最终统计
s.logger.WithFields(logrus.Fields{
"total_queries": s.queryCount.Load(),
"cache_hits": s.cacheHits.Load(),
"cache_misses": s.cacheMisses.Load(),
"overview_loads": s.overviewLoadCount.Load(),
}).Info("DashboardQueryService stopped")
}
// ==================== 核心查询方法 ====================
// GetDashboardOverviewData 获取仪表盘概览数据(带缓存)
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
s.queryCount.Add(1)
cachedDataPtr := s.overviewSyncer.Get()
if cachedDataPtr == nil {
s.cacheMisses.Add(1)
s.logger.Warn("Overview cache is empty, attempting to load...")
// 触发立即加载
if err := s.overviewSyncer.Invalidate(); err != nil {
return nil, fmt.Errorf("failed to trigger cache reload: %w", err)
}
// 等待加载完成最多30秒
ctx, cancel := context.WithTimeout(context.Background(), cacheLoadTimeout)
defer cancel()
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if data := s.overviewSyncer.Get(); data != nil {
s.cacheHits.Add(1)
return data, nil
}
case <-ctx.Done():
return nil, fmt.Errorf("timeout waiting for overview cache to load")
}
}
}
s.cacheHits.Add(1)
return cachedDataPtr, nil
}
// InvalidateOverviewCache 手动失效概览缓存
func (s *DashboardQueryService) InvalidateOverviewCache() error {
s.logger.Info("Manually invalidating overview cache")
return s.overviewSyncer.Invalidate()
}
// GetGroupStats 获取指定分组的统计数据
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
s.queryCount.Add(1)
s.updateLastQueryTime()
start := time.Now()
// 1. 从 Redis 获取 Key 统计
statsKey := fmt.Sprintf("stats:group:%d", groupID)
keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
s.logger.WithError(err).Errorf("Failed to get key stats for group %d", groupID)
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
}
keyStats := make(map[string]int64)
for k, v := range keyStatsMap {
val, _ := strconv.ParseInt(v, 10, 64)
keyStats[k] = val
}
// 2. 查询请求统计(使用 UTC 时间)
now := time.Now().UTC()
oneHourAgo := now.Add(-1 * time.Hour)
twentyFourHoursAgo := now.Add(-24 * time.Hour)
type requestStatsResult struct {
TotalRequests int64
SuccessRequests int64
}
var last1Hour, last24Hours requestStatsResult
s.db.Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
Scan(&last1Hour)
s.db.Model(&models.StatsHourly{}).
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
Scan(&last24Hours)
failureRate1h := 0.0
if last1Hour.TotalRequests > 0 {
failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100
}
failureRate24h := 0.0
if last24Hours.TotalRequests > 0 {
failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100
}
last1HourStats := map[string]any{
"total_requests": last1Hour.TotalRequests,
"success_requests": last1Hour.SuccessRequests,
"failure_rate": failureRate1h,
}
last24HoursStats := map[string]any{
"total_requests": last24Hours.TotalRequests,
"success_requests": last24Hours.SuccessRequests,
"failure_rate": failureRate24h,
// 并发查询优化
var wg sync.WaitGroup
errChan := make(chan error, 2)
wg.Add(2)
// 查询最近1小时
go func() {
defer wg.Done()
if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
Scan(&last1Hour).Error; err != nil {
errChan <- fmt.Errorf("failed to query 1h stats: %w", err)
}
}()
// 查询最近24小时
go func() {
defer wg.Done()
if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
Scan(&last24Hours).Error; err != nil {
errChan <- fmt.Errorf("failed to query 24h stats: %w", err)
}
}()
wg.Wait()
close(errChan)
// 检查错误
for err := range errChan {
if err != nil {
return nil, err
}
}
// 3. 计算失败率
failureRate1h := s.calculateFailureRate(last1Hour.TotalRequests, last1Hour.SuccessRequests)
failureRate24h := s.calculateFailureRate(last24Hours.TotalRequests, last24Hours.SuccessRequests)
result := map[string]any{
"key_stats": keyStats,
"last_1_hour": last1HourStats,
"last_24_hours": last24HoursStats,
"key_stats": keyStats,
"last_1_hour": map[string]any{
"total_requests": last1Hour.TotalRequests,
"success_requests": last1Hour.SuccessRequests,
"failed_requests": last1Hour.TotalRequests - last1Hour.SuccessRequests,
"failure_rate": failureRate1h,
},
"last_24_hours": map[string]any{
"total_requests": last24Hours.TotalRequests,
"success_requests": last24Hours.SuccessRequests,
"failed_requests": last24Hours.TotalRequests - last24Hours.SuccessRequests,
"failure_rate": failureRate24h,
},
}
duration := time.Since(start)
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"duration": duration,
}).Debug("Group stats query completed")
return result, nil
}
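// Consumer sketch (hypothetical handler, not part of this change): the map
// returned by GetGroupStats serializes directly into a JSON response.
func exampleGroupStatsHandler(svc *DashboardQueryService) gin.HandlerFunc {
return func(c *gin.Context) {
stats, err := svc.GetGroupStats(c.Request.Context(), 1) // group ID 1 is illustrative
if err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
c.JSON(200, stats) // {"key_stats": ..., "last_1_hour": ..., "last_24_hours": ...}
}
}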
// QueryHistoricalChart 查询历史图表数据
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
s.queryCount.Add(1)
s.updateLastQueryTime()
start := time.Now()
type ChartPoint struct {
TimeLabel string `gorm:"column:time_label"`
ModelName string `gorm:"column:model_name"`
TotalRequests int64 `gorm:"column:total_requests"`
}
// 查询最近7天数据使用 UTC
sevenDaysAgo := time.Now().UTC().AddDate(0, 0, -defaultChartDays).Truncate(time.Hour)
// 根据数据库类型构建时间格式化子句
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
selectClause := fmt.Sprintf(
"%s as time_label, model_name, COALESCE(SUM(request_count), 0) as total_requests",
sqlFormat,
)
// 构建查询
query := s.db.WithContext(ctx).
Model(&models.StatsHourly{}).
Select(selectClause).
Where("time >= ?", sevenDaysAgo).
Group("time_label, model_name").
Order("time_label ASC")
if groupID != nil && *groupID > 0 {
query = query.Where("group_id = ?", *groupID)
}
var points []ChartPoint
if err := query.Find(&points).Error; err != nil {
return nil, fmt.Errorf("failed to query chart data: %w", err)
}
// 构建数据集
datasets := make(map[string]map[string]int64)
for _, p := range points {
if _, ok := datasets[p.ModelName]; !ok {
@@ -166,150 +302,495 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
}
datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
}
// 生成时间标签(按小时)
var labels []string
for t := sevenDaysAgo; t.Before(time.Now().UTC()); t = t.Add(time.Hour) {
labels = append(labels, t.Format(goFormat))
}
chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)}
colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"}
// 构建图表数据
chartData := &models.ChartData{
Labels: labels,
Datasets: make([]models.ChartDataset, 0, len(datasets)),
}
colorIndex := 0
for modelName, dataPoints := range datasets {
dataArray := make([]int64, len(labels))
for i, label := range labels {
dataArray[i] = dataPoints[label]
}
chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
Label: modelName,
Data: dataArray,
Color: chartColorPalette[colorIndex%len(chartColorPalette)],
})
colorIndex++
}
duration := time.Since(start)
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"points": len(points),
"datasets": len(chartData.Datasets),
"duration": duration,
}).Debug("Historical chart query completed")
return chartData, nil
}
// GetRequestStatsForPeriod 获取指定时间段的请求统计
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
s.queryCount.Add(1)
s.updateLastQueryTime()
var startTime time.Time
now := time.Now().UTC()
switch period {
case "1m":
startTime = now.Add(-1 * time.Minute)
case "1h":
startTime = now.Add(-1 * time.Hour)
case "1d":
year, month, day := now.Date()
startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
default:
return nil, fmt.Errorf("invalid period specified: %s", period)
return nil, fmt.Errorf("invalid period specified: %s (must be 1m, 1h, or 1d)", period)
}
var result struct {
Total int64
Success int64
}
err := s.db.Model(&models.RequestLog{}).
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("COUNT(*) as total, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success").
Where("request_time >= ?", startTime).
Scan(&result).Error
if err != nil {
return nil, fmt.Errorf("failed to query request stats: %w", err)
}
return gin.H{
"period": period,
"total": result.Total,
"success": result.Success,
"failure": result.Total - result.Success,
}, nil
}
// ==================== 内部方法 ====================
// loadOverviewData 加载仪表盘概览数据(供 CacheSyncer 调用)
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
s.overviewLoadCount.Add(1)
startTime := time.Now()
s.logger.Info("Starting to load dashboard overview data...")
resp := &models.DashboardStatsResponse{
KeyStatusCount: make(map[models.APIKeyStatus]int64),
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
KeyCount: models.StatCard{},
RequestCount24h: models.StatCard{},
TokenCount: make(map[string]any),
UpstreamHealthStatus: make(map[string]string),
RPM: models.StatCard{},
RequestCounts: make(map[string]int64),
}
var loadErr error
var wg sync.WaitGroup
errChan := make(chan error, 10)
// 1. 并发加载 Key 映射状态统计
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadMappingStatusStats(ctx, resp); err != nil {
errChan <- fmt.Errorf("mapping stats: %w", err)
}
}()
// 2. 并发加载 Master Key 状态统计
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadMasterStatusStats(ctx, resp); err != nil {
errChan <- fmt.Errorf("master stats: %w", err)
}
}()
// 3. 并发加载请求统计
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadRequestCounts(ctx, resp); err != nil {
errChan <- fmt.Errorf("request counts: %w", err)
}
}()
// 4. 并发加载上游健康状态
wg.Add(1)
go func() {
defer wg.Done()
if err := s.loadUpstreamHealth(ctx, resp); err != nil {
// 上游健康状态失败不阻塞整体加载
s.logger.WithError(err).Warn("Failed to load upstream health status")
}
}()
// 等待所有加载任务完成
wg.Wait()
close(errChan)
// 收集错误
for err := range errChan {
if err != nil {
loadErr = err
break
}
}
if loadErr != nil {
s.logger.WithError(loadErr).Error("Failed to load overview data")
return nil, loadErr
}
duration := time.Since(startTime)
s.logger.WithFields(logrus.Fields{
"duration": duration,
"total_keys": resp.KeyCount.Value,
"requests_1d": resp.RequestCounts["1d"],
"upstreams": len(resp.UpstreamHealthStatus),
}).Info("Successfully loaded dashboard overview data")
return resp, nil
}
// loadMappingStatusStats 加载 Key 映射状态统计
func (s *DashboardQueryService) loadMappingStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
type MappingStatusResult struct {
Status models.APIKeyStatus
Count int64
}
var results []MappingStatusResult
if err := s.db.WithContext(ctx).
Model(&models.GroupAPIKeyMapping{}).
Select("status, COUNT(*) as count").
Group("status").
Find(&results).Error; err != nil {
return err
}
for _, res := range results {
resp.KeyStatusCount[res.Status] = res.Count
}
return nil
}
// loadMasterStatusStats 加载 Master Key 状态统计
func (s *DashboardQueryService) loadMasterStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
type MasterStatusResult struct {
Status models.MasterAPIKeyStatus
Count int64
}
var results []MasterStatusResult
if err := s.db.WithContext(ctx).
Model(&models.APIKey{}).
Select("master_status as status, COUNT(*) as count").
Group("master_status").
Find(&results).Error; err != nil {
return err
}
var totalKeys, invalidKeys int64
for _, res := range results {
resp.MasterStatusCount[res.Status] = res.Count
totalKeys += res.Count
if res.Status != models.MasterStatusActive {
invalidKeys += res.Count
}
}
resp.KeyCount = models.StatCard{
Value: float64(totalKeys),
SubValue: invalidKeys,
SubValueTip: "非活跃身份密钥数",
}
return nil
}
// loadRequestCounts 加载请求计数统计
func (s *DashboardQueryService) loadRequestCounts(ctx context.Context, resp *models.DashboardStatsResponse) error {
now := time.Now().UTC()
// 使用 RequestLog 表查询短期数据
var count1m, count1h int64
// 最近1分钟
if err := s.db.WithContext(ctx).
Model(&models.RequestLog{}).
Where("request_time >= ?", now.Add(-1*time.Minute)).
Count(&count1m).Error; err != nil {
return fmt.Errorf("1m count: %w", err)
}
// 最近1小时
if err := s.db.WithContext(ctx).
Model(&models.RequestLog{}).
Where("request_time >= ?", now.Add(-1*time.Hour)).
Count(&count1h).Error; err != nil {
return fmt.Errorf("1h count: %w", err)
}
// 今天UTC
year, month, day := now.Date()
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
var count1d int64
if err := s.db.WithContext(ctx).
Model(&models.RequestLog{}).
Where("request_time >= ?", startOfDay).
Count(&count1d).Error; err != nil {
return fmt.Errorf("1d count: %w", err)
}
// 最近30天使用聚合表
var count30d int64
if err := s.db.WithContext(ctx).
Model(&models.StatsHourly{}).
Where("time >= ?", now.AddDate(0, 0, -30)).
Select("COALESCE(SUM(request_count), 0)").
Scan(&count30d).Error; err != nil {
return fmt.Errorf("30d count: %w", err)
}
resp.RequestCounts["1m"] = count1m
resp.RequestCounts["1h"] = count1h
resp.RequestCounts["1d"] = count1d
resp.RequestCounts["30d"] = count30d
return nil
}
// loadUpstreamHealth 加载上游健康状态
func (s *DashboardQueryService) loadUpstreamHealth(ctx context.Context, resp *models.DashboardStatsResponse) error {
var upstreams []*models.UpstreamEndpoint
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
return err
}
for _, u := range upstreams {
resp.UpstreamHealthStatus[u.URL] = u.Status
}
return nil
}
// ==================== 事件监听 ====================
// eventListener 监听缓存失效事件
func (s *DashboardQueryService) eventListener() {
defer s.wg.Done()
// 订阅事件
keyStatusSub, err1 := s.store.Subscribe(s.ctx, models.TopicKeyStatusChanged)
upstreamStatusSub, err2 := s.store.Subscribe(s.ctx, models.TopicUpstreamHealthChanged)
// 错误处理
if err1 != nil {
s.logger.WithError(err1).Error("Failed to subscribe to key status events")
keyStatusSub = nil
}
if err2 != nil {
s.logger.WithError(err2).Error("Failed to subscribe to upstream status events")
upstreamStatusSub = nil
}
// 如果全部失败,直接返回
if keyStatusSub == nil && upstreamStatusSub == nil {
s.logger.Error("All event subscriptions failed, listener disabled")
return
}
// 安全关闭订阅
defer func() {
if keyStatusSub != nil {
if err := keyStatusSub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close key status subscription")
}
}
if upstreamStatusSub != nil {
if err := upstreamStatusSub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close upstream status subscription")
}
}
}()
s.logger.WithFields(logrus.Fields{
"key_status_sub": keyStatusSub != nil,
"upstream_status_sub": upstreamStatusSub != nil,
}).Info("Event listener started")
// 注意:这里必须用 nil 通道占位——对 nil channel 的接收会永远阻塞,正好用来禁用 select 分支;
// 已关闭的通道恰恰相反,接收会立即返回零值,导致下面的循环空转。
var neverReady <-chan *store.Message
for {
// 动态选择有效的 channel
var keyStatusChan <-chan *store.Message = neverReady
if keyStatusSub != nil {
keyStatusChan = keyStatusSub.Channel()
}
var upstreamStatusChan <-chan *store.Message = neverReady
if upstreamStatusSub != nil {
upstreamStatusChan = upstreamStatusSub.Channel()
}
select {
case _, ok := <-keyStatusChan:
if !ok {
s.logger.Warn("Key status channel closed")
keyStatusSub = nil
continue
}
s.logger.Debug("Received key status changed event")
if err := s.InvalidateOverviewCache(); err != nil {
s.logger.WithError(err).Warn("Failed to invalidate cache on key status change")
}
case _, ok := <-upstreamStatusChan:
if !ok {
s.logger.Warn("Upstream status channel closed")
upstreamStatusSub = nil
continue
}
s.logger.Debug("Received upstream status changed event")
if err := s.InvalidateOverviewCache(); err != nil {
s.logger.WithError(err).Warn("Failed to invalidate cache on upstream status change")
}
case <-s.stopChan:
s.logger.Info("Event listener stopping (stopChan)")
return
case <-s.ctx.Done():
s.logger.Info("Event listener stopping (context cancelled)")
return
}
}
}
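// Side note on the pattern above (generic sketch, not from this diff): a nil
// channel is the idiomatic way to park a select case, because receives on nil
// channels block forever; the dead case simply never fires.
func exampleDisabledCase(live <-chan int, enabled bool) {
var ch <-chan int // nil: this case stays dormant until enabled
if enabled {
ch = live
}
select {
case v := <-ch:
fmt.Println("got", v)
default:
fmt.Println("case disabled or nothing ready")
}
}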
// ==================== 监控指标 ====================
// metricsReporter 定期输出统计信息
func (s *DashboardQueryService) metricsReporter() {
defer s.wg.Done()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.reportMetrics()
case <-s.stopChan:
return
case <-s.ctx.Done():
return
}
}
}
func (s *DashboardQueryService) reportMetrics() {
s.lastQueryMutex.RLock()
lastQuery := s.lastQueryTime
s.lastQueryMutex.RUnlock()
totalQueries := s.queryCount.Load()
hits := s.cacheHits.Load()
misses := s.cacheMisses.Load()
var cacheHitRate float64
if hits+misses > 0 {
cacheHitRate = float64(hits) / float64(hits+misses) * 100
}
s.logger.WithFields(logrus.Fields{
"total_queries": totalQueries,
"cache_hits": hits,
"cache_misses": misses,
"cache_hit_rate": fmt.Sprintf("%.2f%%", cacheHitRate),
"overview_loads": s.overviewLoadCount.Load(),
"last_query_ago": time.Since(lastQuery).Round(time.Second),
}).Info("DashboardQuery metrics")
}
// GetMetrics 返回当前统计指标(供监控使用)
func (s *DashboardQueryService) GetMetrics() map[string]interface{} {
s.lastQueryMutex.RLock()
lastQuery := s.lastQueryTime
s.lastQueryMutex.RUnlock()
hits := s.cacheHits.Load()
misses := s.cacheMisses.Load()
var cacheHitRate float64
if hits+misses > 0 {
cacheHitRate = float64(hits) / float64(hits+misses) * 100
}
return map[string]interface{}{
"total_queries": s.queryCount.Load(),
"cache_hits": hits,
"cache_misses": misses,
"cache_hit_rate": cacheHitRate,
"overview_loads": s.overviewLoadCount.Load(),
"last_query_ago": time.Since(lastQuery).Seconds(),
}
}
// ==================== 辅助方法 ====================
// calculateFailureRate 计算失败率
func (s *DashboardQueryService) calculateFailureRate(total, success int64) float64 {
if total == 0 {
return 0.0
}
return float64(total-success) / float64(total) * 100
}
// updateLastQueryTime 更新最后查询时间
func (s *DashboardQueryService) updateLastQueryTime() {
s.lastQueryMutex.Lock()
s.lastQueryTime = time.Now()
s.lastQueryMutex.Unlock()
}
// buildTimeFormatSelectClause 根据数据库类型构建时间格式化子句
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
dialect := s.db.Dialector.Name()
switch dialect {
case "mysql":
return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
case "postgres":
return "TO_CHAR(time, 'YYYY-MM-DD HH24:00:00')", "2006-01-02 15:00:00"
case "sqlite":
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
default:
s.logger.WithField("dialect", dialect).Warn("Unknown database dialect, using SQLite format")
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
}
}
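// Alignment sketch: the SQL expression above labels rows on the database side,
// while the paired Go layout renders identical hourly labels for the chart
// axis, so the map lookups in QueryHistoricalChart line up exactly.
func exampleLabelAlignment() {
goFormat := "2006-01-02 15:00:00"
t := time.Date(2025, 1, 2, 15, 0, 0, 0, time.UTC) // hypothetical bucket
fmt.Println(t.Format(goFormat))                   // "2025-01-02 15:00:00", matching DATE_FORMAT/TO_CHAR/strftime
}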

View File

@@ -1,14 +1,16 @@
// Filename: internal/service/db_log_writer_service.go
package service
import (
"context"
"encoding/json"
"sync"
"sync/atomic"
"time"
"gemini-balancer/internal/models"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"sync"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
@@ -18,132 +20,324 @@ type DBLogWriterService struct {
db *gorm.DB
store store.Store
logger *logrus.Entry
settingsManager *settings.SettingsManager
logBuffer chan *models.RequestLog
stopChan chan struct{}
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
// 统计指标
totalReceived atomic.Uint64
totalFlushed atomic.Uint64
totalDropped atomic.Uint64
flushCount atomic.Uint64
lastFlushTime time.Time
lastFlushMutex sync.RWMutex
}
func NewDBLogWriterService(
db *gorm.DB,
s store.Store,
settingsManager *settings.SettingsManager,
logger *logrus.Logger,
) *DBLogWriterService {
cfg := settingsManager.GetSettings()
bufferCapacity := cfg.LogBufferCapacity
if bufferCapacity <= 0 {
bufferCapacity = 1000
}
ctx, cancel := context.WithCancel(context.Background())
return &DBLogWriterService{
db: db,
store: s,
settingsManager: settingsManager,
logger: logger.WithField("component", "DBLogWriter📝"),
logBuffer: make(chan *models.RequestLog, bufferCapacity),
stopChan: make(chan struct{}),
ctx: ctx,
cancel: cancel,
lastFlushTime: time.Now(),
}
}
func (s *DBLogWriterService) Start() {
s.wg.Add(2)
go s.eventListenerLoop()
// 启动数据库写入器
go s.dbWriterLoop()
s.logger.Info("DBLogWriterService started.")
// 定期输出统计信息
s.wg.Add(1)
go s.metricsReporter()
s.logger.WithFields(logrus.Fields{
"buffer_capacity": cap(s.logBuffer),
}).Info("DBLogWriterService started")
}
func (s *DBLogWriterService) Stop() {
s.logger.Info("DBLogWriterService stopping...")
close(s.stopChan)
s.cancel() // 取消上下文
s.wg.Wait()
// 输出最终统计
s.logger.WithFields(logrus.Fields{
"total_received": s.totalReceived.Load(),
"total_flushed": s.totalFlushed.Load(),
"total_dropped": s.totalDropped.Load(),
"flush_count": s.flushCount.Load(),
}).Info("DBLogWriterService stopped")
}
// 事件监听循环
func (s *DBLogWriterService) eventListenerLoop() {
defer s.wg.Done()
sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
s.logger.WithError(err).Error("Failed to subscribe to request events, log writing disabled")
return
}
defer func() {
if err := sub.Close(); err != nil {
s.logger.WithError(err).Warn("Failed to close subscription")
}
}()
s.logger.Info("Subscribed to request events for database logging.")
s.logger.Info("Subscribed to request events for database logging")
for {
select {
case msg := <-sub.Channel():
s.handleMessage(msg)
case <-s.stopChan:
s.logger.Info("Event listener loop stopping.")
// 关闭缓冲区以通知dbWriterLoop处理完剩余日志后退出
s.logger.Info("Event listener loop stopping")
close(s.logBuffer)
return
case <-s.ctx.Done():
s.logger.Info("Event listener context cancelled")
close(s.logBuffer)
return
}
}
}
// 处理单条消息
func (s *DBLogWriterService) handleMessage(msg *store.Message) {
var event models.RequestFinishedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.WithError(err).Error("Failed to unmarshal request event")
return
}
s.totalReceived.Add(1)
select {
case s.logBuffer <- &event.RequestLog:
// 成功入队
default:
// 缓冲区满,丢弃日志
dropped := s.totalDropped.Add(1)
if dropped%100 == 1 { // 每100条丢失输出一次警告
s.logger.WithFields(logrus.Fields{
"total_dropped": dropped,
"buffer_capacity": cap(s.logBuffer),
"buffer_len": len(s.logBuffer),
}).Warn("Log buffer full, messages being dropped")
}
}
}
// 数据库写入循环
func (s *DBLogWriterService) dbWriterLoop() {
defer s.wg.Done()
cfg := s.settingsManager.GetSettings()
batchSize := cfg.LogFlushBatchSize
if batchSize <= 0 {
batchSize = 100
}
flushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if flushInterval <= 0 {
flushInterval = 5 * time.Second
}
s.logger.WithFields(logrus.Fields{
"batch_size": batchSize,
"flush_interval": flushInterval,
}).Info("DB writer loop started")
batch := make([]*models.RequestLog, 0, batchSize)
ticker := time.NewTicker(flushInterval)
defer ticker.Stop()
// 配置热更新检查(每分钟)
configTicker := time.NewTicker(1 * time.Minute)
defer configTicker.Stop()
for {
select {
case logEntry, ok := <-s.logBuffer:
if !ok {
// 通道关闭,刷新剩余日志
if len(batch) > 0 {
s.flushBatch(batch)
}
s.logger.Info("DB writer loop finished.")
s.logger.Info("DB writer loop finished")
return
}
batch = append(batch, logEntry)
if len(batch) >= batchSize {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
case <-ticker.C:
if len(batch) > 0 {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
case <-configTicker.C:
// 热更新配置
cfg := s.settingsManager.GetSettings()
newBatchSize := cfg.LogFlushBatchSize
if newBatchSize <= 0 {
newBatchSize = 100
}
newFlushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
if newFlushInterval <= 0 {
newFlushInterval = 5 * time.Second
}
if newBatchSize != batchSize {
s.logger.WithFields(logrus.Fields{
"old": batchSize,
"new": newBatchSize,
}).Info("Batch size updated")
batchSize = newBatchSize
if len(batch) >= batchSize {
s.flushBatch(batch)
batch = make([]*models.RequestLog, 0, batchSize)
}
}
if newFlushInterval != flushInterval {
s.logger.WithFields(logrus.Fields{
"old": flushInterval,
"new": newFlushInterval,
}).Info("Flush interval updated")
flushInterval = newFlushInterval
ticker.Reset(flushInterval)
}
}
}
}
// 批量刷写到数据库
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
if len(batch) == 0 {
return
}
start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err := s.db.WithContext(ctx).CreateInBatches(batch, len(batch)).Error
duration := time.Since(start)
s.lastFlushMutex.Lock()
s.lastFlushTime = time.Now()
s.lastFlushMutex.Unlock()
if err != nil {
s.logger.WithFields(logrus.Fields{
"batch_size": len(batch),
"duration": duration,
}).WithError(err).Error("Failed to flush log batch to database")
} else {
s.logger.Infof("Successfully flushed %d logs to database.", len(batch))
flushed := s.totalFlushed.Add(uint64(len(batch)))
flushCount := s.flushCount.Add(1)
// 只在慢写入或大批量时输出日志
if duration > 1*time.Second || len(batch) > 500 {
s.logger.WithFields(logrus.Fields{
"batch_size": len(batch),
"duration": duration,
"total_flushed": flushed,
"flush_count": flushCount,
}).Info("Log batch flushed to database")
} else {
s.logger.WithFields(logrus.Fields{
"batch_size": len(batch),
"duration": duration,
}).Debug("Log batch flushed to database")
}
}
}
// 定期输出统计信息
func (s *DBLogWriterService) metricsReporter() {
defer s.wg.Done()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.reportMetrics()
case <-s.stopChan:
return
case <-s.ctx.Done():
return
}
}
}
func (s *DBLogWriterService) reportMetrics() {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
received := s.totalReceived.Load()
flushed := s.totalFlushed.Load()
dropped := s.totalDropped.Load()
pending := uint64(len(s.logBuffer))
var successRate float64
if received > 0 {
successRate = float64(flushed) / float64(received) * 100
}
s.logger.WithFields(logrus.Fields{
"received": received,
"flushed": flushed,
"dropped": dropped,
"pending": pending,
"flush_count": s.flushCount.Load(),
"last_flush": time.Since(lastFlush).Round(time.Second),
"buffer_usage": float64(pending) / float64(cap(s.logBuffer)) * 100,
"success_rate": successRate,
}).Info("DBLogWriter metrics")
}
// GetMetrics 返回当前统计指标(供监控使用)
func (s *DBLogWriterService) GetMetrics() map[string]interface{} {
s.lastFlushMutex.RLock()
lastFlush := s.lastFlushTime
s.lastFlushMutex.RUnlock()
return map[string]interface{}{
"total_received": s.totalReceived.Load(),
"total_flushed": s.totalFlushed.Load(),
"total_dropped": s.totalDropped.Load(),
"flush_count": s.flushCount.Load(),
"buffer_pending": len(s.logBuffer),
"buffer_capacity": cap(s.logBuffer),
"last_flush_ago": time.Since(lastFlush).Seconds(),
}
}

View File

@@ -1,5 +1,4 @@
// Filename: internal/service/group_manager.go
package service
import (
@@ -10,6 +9,7 @@ import (
"gemini-balancer/internal/pkg/reflectutil"
"gemini-balancer/internal/repository"
"gemini-balancer/internal/settings"
"gemini-balancer/internal/store"
"gemini-balancer/internal/syncer"
"gemini-balancer/internal/utils"
"net/url"
@@ -29,10 +29,9 @@ type GroupManagerCacheData struct {
Groups []*models.KeyGroup
GroupsByName map[string]*models.KeyGroup
GroupsByID map[uint]*models.KeyGroup
KeyCounts map[uint]int64
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64
}
type GroupManager struct {
db *gorm.DB
keyRepo repository.KeyRepository
@@ -41,7 +40,6 @@ type GroupManager struct {
syncer *syncer.CacheSyncer[GroupManagerCacheData]
logger *logrus.Entry
}
type UpdateOrderPayload struct {
ID uint `json:"id" binding:"required"`
Order int `json:"order"`
@@ -49,43 +47,19 @@ type UpdateOrderPayload struct {
func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
return func() (GroupManagerCacheData, error) {
logger.Debugf("[GML-LOG 1/5] ---> Entering NewGroupManagerLoader...")
var groups []*models.KeyGroup
logger.Debugf("[GML-LOG 2/5] About to execute DB query with Preloads...")
if err := db.Preload("AllowedUpstreams").
Preload("AllowedModels").
Preload("Settings").
Preload("RequestConfig").
Preload("Mappings").
Find(&groups).Error; err != nil {
logger.Errorf("[GML-LOG] CRITICAL: DB query for groups failed: %v", err)
return GroupManagerCacheData{}, fmt.Errorf("failed to load key groups for cache: %w", err)
return GroupManagerCacheData{}, fmt.Errorf("failed to load groups: %w", err)
}
logger.Debugf("[GML-LOG 2.1/5] DB query for groups finished. Found %d group records.", len(groups))
var allMappings []*models.GroupAPIKeyMapping
if err := db.Find(&allMappings).Error; err != nil {
logger.Errorf("[GML-LOG] CRITICAL: DB query for mappings failed: %v", err)
return GroupManagerCacheData{}, fmt.Errorf("failed to load key mappings for cache: %w", err)
}
logger.Debugf("[GML-LOG 2.2/5] DB query for mappings finished. Found %d total mapping records.", len(allMappings))
mappingsByGroupID := make(map[uint][]*models.GroupAPIKeyMapping)
for i := range allMappings {
mapping := allMappings[i] // Avoid pointer issues with range
mappingsByGroupID[mapping.KeyGroupID] = append(mappingsByGroupID[mapping.KeyGroupID], mapping)
}
for _, group := range groups {
if mappings, ok := mappingsByGroupID[group.ID]; ok {
group.Mappings = mappings
}
}
logger.Debugf("[GML-LOG 3/5] Finished manually associating mappings to groups.")
keyCounts := make(map[uint]int64, len(groups))
keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups))
groupsByName := make(map[string]*models.KeyGroup, len(groups))
groupsByID := make(map[uint]*models.KeyGroup, len(groups))
for _, group := range groups {
keyCounts[group.ID] = int64(len(group.Mappings))
statusCounts := make(map[models.APIKeyStatus]int64)
@@ -93,20 +67,9 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
statusCounts[mapping.Status]++
}
keyStatusCounts[group.ID] = statusCounts
groupsByName[group.Name] = group
groupsByID[group.ID] = group
}
return GroupManagerCacheData{
Groups: groups,
GroupsByName: groupsByName,
@@ -116,7 +79,6 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
}, nil
}
}
func NewGroupManager(
db *gorm.DB,
keyRepo repository.KeyRepository,
@@ -134,138 +96,67 @@ func NewGroupManager(
logger: logger.WithField("component", "GroupManager"),
}
}
func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
groups := gm.syncer.Get().Groups
sort.Slice(groups, func(i, j int) bool {
if groups[i].Order != groups[j].Order {
return groups[i].Order < groups[j].Order
}
return groups[i].ID < groups[j].ID
})
return groups
}
func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
return gm.syncer.Get().KeyCounts[groupID]
}
func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
if counts, ok := gm.syncer.Get().KeyStatusCounts[groupID]; ok {
return counts
}
return make(map[models.APIKeyStatus]int64)
}
func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
group, ok := gm.syncer.Get().GroupsByName[name]
return group, ok
}
func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
group, ok := gm.syncer.Get().GroupsByID[id]
return group, ok
}
func (gm *GroupManager) Stop() {
gm.syncer.Stop()
}
func (gm *GroupManager) Invalidate() error {
return gm.syncer.Invalidate()
}
// --- Write Operations ---
// CreateKeyGroup creates a new key group, including its operational settings, and invalidates the cache.
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
if !utils.IsValidGroupName(group.Name) {
return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
}
err := gm.db.Transaction(func(tx *gorm.DB) error {
// 1. Create the group itself to get an ID
if err := tx.Create(group).Error; err != nil {
return err
}
// 2. If settings are provided, create the associated GroupSettings record
if settings != nil {
settingsJSON, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("failed to marshal settings: %w", err)
}
groupSettings := models.GroupSettings{
GroupID: group.ID,
SettingsJSON: datatypes.JSON(settingsJSON),
}
if err := tx.Create(&groupSettings).Error; err != nil {
return err
}
}
return nil
})
if err == nil {
go gm.Invalidate()
}
return err
}
// UpdateKeyGroup updates an existing key group, its settings, and associations, then invalidates the cache.
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
if !utils.IsValidGroupName(group.Name) {
return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens")
@@ -273,7 +164,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
uniqueModelNames := uniqueStrings(modelNames)
err := gm.db.Transaction(func(tx *gorm.DB) error {
// --- 1. Update AllowedUpstreams (M:N relationship) ---
var upstreams []*models.UpstreamEndpoint
if len(uniqueUpstreamURLs) > 0 {
if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
@@ -283,7 +173,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
return err
}
if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
return err
}
@@ -296,11 +185,9 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
return err
}
}
if err := tx.Model(group).Updates(group).Error; err != nil {
return err
}
var existingSettings models.GroupSettings
if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound {
return err
@@ -308,15 +195,15 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
var currentSettingsData models.KeyGroupSettings
if len(existingSettings.SettingsJSON) > 0 {
if err := json.Unmarshal(existingSettings.SettingsJSON, &currentSettingsData); err != nil {
return fmt.Errorf("failed to unmarshal existing group settings: %w", err)
return fmt.Errorf("failed to unmarshal existing settings: %w", err)
}
}
if err := reflectutil.MergeNilFields(&currentSettingsData, newSettings); err != nil {
return fmt.Errorf("failed to merge group settings: %w", err)
return fmt.Errorf("failed to merge settings: %w", err)
}
updatedJSON, err := json.Marshal(currentSettingsData)
if err != nil {
return fmt.Errorf("failed to marshal updated group settings: %w", err)
return fmt.Errorf("failed to marshal updated settings: %w", err)
}
existingSettings.GroupID = group.ID
existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
@@ -327,55 +214,25 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
}
return err
}
// DeleteKeyGroup deletes a key group and subsequently cleans up any keys that have become orphans.
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
err := gm.db.Transaction(func(tx *gorm.DB) error {
gm.logger.Infof("Starting transaction to delete KeyGroup ID: %d", id)
// Step 1: First, retrieve the group object we are about to delete.
var group models.KeyGroup
if err := tx.First(&group, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
gm.logger.Warnf("Attempted to delete a non-existent KeyGroup with ID: %d", id)
return nil // Don't treat as an error, the group is already gone.
return nil
}
gm.logger.WithError(err).Errorf("Failed to find KeyGroup with ID: %d for deletion", id)
return err
}
// Step 2: Clear all many-to-many and one-to-many associations using GORM's safe methods.
if err := tx.Model(&group).Association("AllowedUpstreams").Clear(); err != nil {
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedUpstreams' association for KeyGroup ID: %d", id)
if err := tx.Select("AllowedUpstreams", "AllowedModels", "Mappings", "Settings").Delete(&group).Error; err != nil {
return err
}
if err := tx.Model(&group).Association("AllowedModels").Clear(); err != nil {
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedModels' association for KeyGroup ID: %d", id)
return err
}
if err := tx.Model(&group).Association("Mappings").Clear(); err != nil {
gm.logger.WithError(err).Errorf("Failed to clear 'Mappings' (API Key associations) for KeyGroup ID: %d", id)
return err
}
// Also clear settings if they exist to maintain data integrity
if err := tx.Model(&group).Association("Settings").Delete(group.Settings); err != nil {
gm.logger.WithError(err).Errorf("Failed to delete 'Settings' association for KeyGroup ID: %d", id)
return err
}
// Step 3: Delete the KeyGroup itself.
if err := tx.Delete(&group).Error; err != nil {
gm.logger.WithError(err).Errorf("Failed to delete KeyGroup ID: %d", id)
return err
}
gm.logger.Infof("KeyGroup ID %d associations cleared and entity deleted. Triggering orphan key cleanup.", id)
// Step 4: Trigger the orphan key cleanup (this logic remains the same and is correct).
deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
if err != nil {
gm.logger.WithError(err).Error("Failed to clean up orphan keys after deleting group.")
return err
}
if deletedCount > 0 {
gm.logger.Infof("Successfully cleaned up %d orphan keys.", deletedCount)
gm.logger.Infof("Cleaned up %d orphan keys after deleting group %d", deletedCount, id)
}
gm.logger.Infof("Transaction for deleting KeyGroup ID: %d completed successfully.", id)
return nil
})
if err == nil {
@@ -383,7 +240,6 @@ func (gm *GroupManager) DeleteKeyGroup(id uint) error {
}
return err
}
func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
var originalGroup models.KeyGroup
if err := gm.db.
@@ -392,7 +248,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
Preload("AllowedUpstreams").
Preload("AllowedModels").
First(&originalGroup, id).Error; err != nil {
return nil, fmt.Errorf("failed to find original group with id %d: %w", id, err)
return nil, fmt.Errorf("failed to find original group %d: %w", id, err)
}
newGroup := originalGroup
timestamp := time.Now().Unix()
@@ -401,31 +257,25 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
newGroup.CreatedAt = time.Time{}
newGroup.UpdatedAt = time.Time{}
newGroup.RequestConfigID = nil
newGroup.RequestConfig = nil
newGroup.Mappings = nil
newGroup.AllowedUpstreams = nil
newGroup.AllowedModels = nil
err := gm.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Create(&newGroup).Error; err != nil {
return err
}
if originalGroup.RequestConfig != nil {
newRequestConfig := *originalGroup.RequestConfig
newRequestConfig.ID = 0 // Mark as new record
newRequestConfig.ID = 0
if err := tx.Create(&newRequestConfig).Error; err != nil {
return fmt.Errorf("failed to clone request config: %w", err)
}
if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
return fmt.Errorf("failed to link new group to cloned request config: %w", err)
return fmt.Errorf("failed to link cloned request config: %w", err)
}
}
var originalSettings models.GroupSettings
err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
if err == nil && len(originalSettings.SettingsJSON) > 0 {
@@ -434,12 +284,11 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
SettingsJSON: originalSettings.SettingsJSON,
}
if err := tx.Create(&newSettings).Error; err != nil {
return fmt.Errorf("failed to clone group settings: %w", err)
return fmt.Errorf("failed to clone settings: %w", err)
}
} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("failed to query original group settings: %w", err)
return fmt.Errorf("failed to query original settings: %w", err)
}
if len(originalGroup.Mappings) > 0 {
newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
for i, oldMapping := range originalGroup.Mappings {
@@ -454,7 +303,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
}
}
if err := tx.Create(&newMappings).Error; err != nil {
return fmt.Errorf("failed to clone key group mappings: %w", err)
return fmt.Errorf("failed to clone mappings: %w", err)
}
}
if len(originalGroup.AllowedUpstreams) > 0 {
@@ -469,13 +318,10 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
}
return nil
})
if err != nil {
return nil, err
}
go gm.Invalidate()
var finalClonedGroup models.KeyGroup
if err := gm.db.
Preload("RequestConfig").
@@ -487,10 +333,8 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
}
return &finalClonedGroup, nil
}
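
The clone relies on a standard GORM idiom: copy the struct, zero its primary key and timestamps, and Create then inserts a fresh row instead of updating the original. A minimal sketch (a hypothetical Widget model standing in for KeyGroup):

package main

import (
	"fmt"
	"time"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// Widget stands in for KeyGroup in this sketch.
type Widget struct {
	ID        uint
	Name      string
	CreatedAt time.Time
	UpdatedAt time.Time
}

func main() {
	db, _ := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	_ = db.AutoMigrate(&Widget{})
	db.Create(&Widget{Name: "original"})

	var original Widget
	db.First(&original, 1)

	clone := original
	clone.ID = 0                  // zero the PK so Create inserts a new row
	clone.CreatedAt = time.Time{} // let GORM stamp fresh timestamps
	clone.UpdatedAt = time.Time{}
	clone.Name = fmt.Sprintf("%s-clone-%d", original.Name, time.Now().Unix())
	db.Create(&clone)
}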
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
globalSettings := gm.settingsManager.GetSettings()
s := "gemini-1.5-flash" // Per user feedback for default model
opConfig := &models.KeyGroupSettings{
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
@@ -498,52 +342,43 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
KeyCheckModel: &s,
KeyCheckModel: &globalSettings.BaseKeyCheckModel,
MaxRetries: &globalSettings.MaxRetries,
EnableSmartGateway: &globalSettings.EnableSmartGateway,
}
if group == nil {
return opConfig, nil
}
var groupSettingsRecord models.GroupSettings
err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return opConfig, nil
}
gm.logger.WithError(err).Errorf("Failed to query group settings for group ID %d", group.ID)
return nil, err
}
if len(groupSettingsRecord.SettingsJSON) == 0 {
return opConfig, nil
}
var groupSpecificSettings models.KeyGroupSettings
if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil {
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings JSON.")
return opConfig, err
}
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group-specific settings over defaults.")
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings")
return opConfig, nil
}
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group settings")
return opConfig, nil
}
return opConfig, nil
}
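
reflectutil.MergeNilFields is project-internal; judging from its use here, it overlays every field the group actually set (a non-nil pointer) onto the defaults derived from global settings. A rough standalone sketch of that assumed behavior, an illustration rather than the real implementation:

package main

import (
	"fmt"
	"reflect"
)

// overlayNonNil copies every non-nil pointer field of src onto dst.
// Assumption: this mirrors what reflectutil.MergeNilFields does for
// KeyGroupSettings (group overrides win over global defaults).
func overlayNonNil(dst, src interface{}) {
	d := reflect.ValueOf(dst).Elem()
	s := reflect.ValueOf(src).Elem()
	for i := 0; i < d.NumField(); i++ {
		if f := s.Field(i); f.Kind() == reflect.Ptr && !f.IsNil() {
			d.Field(i).Set(f)
		}
	}
}

type Settings struct {
	Retries *int
	Model   *string
}

func main() {
	three, flash := 3, "gemini-1.5-flash"
	defaults := Settings{Retries: &three, Model: &flash}

	ten := 10
	group := Settings{Retries: &ten} // the group only overrides Retries

	overlayNonNil(&defaults, &group)
	fmt.Println(*defaults.Retries, *defaults.Model) // 10 gemini-1.5-flash
}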
func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
group, ok := gm.GetGroupByID(groupID)
if !ok {
return "", fmt.Errorf("group with id %d not found", groupID)
return "", fmt.Errorf("group %d not found", groupID)
}
opConfig, err := gm.BuildOperationalConfig(group)
if err != nil {
return "", fmt.Errorf("failed to build operational config for group %d: %w", groupID, err)
return "", err
}
globalSettings := gm.settingsManager.GetSettings()
baseURL := globalSettings.DefaultUpstreamURL
@@ -551,7 +386,7 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
baseURL = *opConfig.KeyCheckEndpoint
}
if baseURL == "" {
return "", fmt.Errorf("no key check endpoint or default upstream URL is configured for group %d", groupID)
return "", fmt.Errorf("no endpoint configured for group %d", groupID)
}
modelName := globalSettings.BaseKeyCheckModel
if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" {
@@ -559,38 +394,41 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
}
parsedURL, err := url.Parse(baseURL)
if err != nil {
return "", fmt.Errorf("failed to parse base URL '%s': %w", baseURL, err)
return "", fmt.Errorf("invalid URL '%s': %w", baseURL, err)
}
cleanedPath := parsedURL.Path
cleanedPath = strings.TrimSuffix(cleanedPath, "/")
cleanedPath = strings.TrimSuffix(cleanedPath, "/v1beta")
parsedURL.Path = path.Join(cleanedPath, "v1beta", "models", modelName)
finalEndpoint := parsedURL.String()
return finalEndpoint, nil
cleanedPath := strings.TrimSuffix(strings.TrimSuffix(parsedURL.Path, "/"), "/v1beta")
parsedURL.Path = path.Join(cleanedPath, "v1beta/models", modelName)
return parsedURL.String(), nil
}
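
The double TrimSuffix makes the join idempotent: a base URL ending in nothing, a trailing slash, or an existing /v1beta all normalize to the same endpoint. A self-contained check (the URLs are illustrative placeholders):

package main

import (
	"fmt"
	"net/url"
	"path"
	"strings"
)

// buildEndpoint reproduces the normalization above: strip a trailing
// slash and any existing /v1beta, then join the canonical path.
func buildEndpoint(base, model string) (string, error) {
	u, err := url.Parse(base)
	if err != nil {
		return "", err
	}
	cleaned := strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1beta")
	u.Path = path.Join(cleaned, "v1beta/models", model)
	return u.String(), nil
}

func main() {
	for _, base := range []string{
		"https://example.com",
		"https://example.com/",
		"https://example.com/v1beta",
	} {
		endpoint, _ := buildEndpoint(base, "gemini-1.5-flash")
		fmt.Println(endpoint) // all three print the same normalized endpoint
	}
}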
func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
ordersMap := make(map[uint]int, len(payload))
for _, item := range payload {
ordersMap[item.ID] = item.Order
}
if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
gm.logger.WithError(err).Error("Failed to update group order in transaction")
return fmt.Errorf("database transaction failed: %w", err)
return fmt.Errorf("failed to update order: %w", err)
}
gm.logger.Info("Group order updated successfully, invalidating cache...")
go gm.Invalidate()
return nil
}
func uniqueStrings(slice []string) []string {
keys := make(map[string]struct{})
list := []string{}
for _, entry := range slice {
if _, value := keys[entry]; !value {
keys[entry] = struct{}{}
list = append(list, entry)
seen := make(map[string]struct{}, len(slice))
result := make([]string, 0, len(slice))
for _, s := range slice {
if _, exists := seen[s]; !exists {
seen[s] = struct{}{}
result = append(result, s)
}
}
return list
return result
}
// NewGroupManagerSyncer wires the GroupManager cache to the invalidation syncer.
func NewGroupManagerSyncer(
loader syncer.LoaderFunc[GroupManagerCacheData],
store store.Store,
logger *logrus.Logger,
) (*syncer.CacheSyncer[GroupManagerCacheData], error) {
const groupUpdateChannel = "groups:cache_invalidation"
return syncer.NewCacheSyncer(loader, store, groupUpdateChannel, logger)
}

File diff suppressed because it is too large

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -22,6 +23,10 @@ const (
TaskTypeHardDeleteKeys = "hard_delete_keys"
TaskTypeRestoreKeys = "restore_keys"
chunkSize = 500
// Task timeouts as named constants
defaultTaskTimeout = 15 * time.Minute
longTaskTimeout = time.Hour
)
type KeyImportService struct {
@@ -42,295 +47,425 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
}
}
// --- Generic panic-safe task runner ---
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
// runTaskWithRecovery is the shared panic-recovery wrapper for background tasks.
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
s.logger.Error(err)
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
s.logger.WithField("task_id", taskID).Error(err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
}
}()
taskFunc()
}
// --- Public Task Starters ---
func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
// StartAddKeysTask starts a batch key-import task.
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in input text")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), defaultTaskTimeout)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
})
return taskStatus, nil
}
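
Note the deliberate context.Background() for the spawned goroutine rather than the caller's ctx: an HTTP request context is cancelled the moment the handler returns, which would kill the background task mid-flight. The trade-off is that the task no longer observes caller cancellation. A tiny illustration of the failure mode being avoided (a sketch, not project code):

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

func startTask(w http.ResponseWriter, r *http.Request) {
	// Passing r.Context() here instead would cancel the task as soon
	// as this handler returns; a detached context lets it finish.
	go func(ctx context.Context) {
		time.Sleep(2 * time.Second)
		fmt.Println("task finished, ctx.Err() =", ctx.Err()) // <nil>
	}(context.Background())
	w.WriteHeader(http.StatusAccepted)
}

func main() {
	http.HandleFunc("/start", startTask)
	_ = http.ListenAndServe(":8080", nil)
}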
func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) {
// StartUnlinkKeysTask starts a batch key-unlink task.
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), longTaskTimeout)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
})
return taskStatus, nil
}
func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) {
// StartHardDeleteKeysTask starts a hard-delete task.
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_hard_delete" // Global lock
taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
resourceID := "global_hard_delete"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), longTaskTimeout)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) {
// StartRestoreKeysTask starts a key-restore task.
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
keys := utils.ParseKeysFromText(keysText)
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found")
}
resourceID := "global_restore_keys" // Global lock
taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
resourceID := "global_restore_keys"
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), longTaskTimeout)
if err != nil {
return nil, err
}
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(taskStatus.ID, resourceID, keys)
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
})
return taskStatus, nil
}
// --- Private Task Runners ---
// StartUnlinkKeysByFilterTask unlinks keys in bulk by status filter.
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
// Step 1: de-duplicate the raw input key list.
uniqueKeysMap := make(map[string]struct{})
var uniqueKeyStrings []string
for _, kStr := range keys {
if _, exists := uniqueKeysMap[kStr]; !exists {
uniqueKeysMap[kStr] = struct{}{}
uniqueKeyStrings = append(uniqueKeyStrings, kStr)
}
}
if len(uniqueKeyStrings) == 0 {
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
return
}
// Step 2: ensure every key exists in the master table (create or restore) and fetch the full entities.
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
for i, keyStr := range uniqueKeyStrings {
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
}
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
}
if len(keyValues) == 0 {
return nil, fmt.Errorf("no keys found matching the provided filter")
}
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
}
// ==================== Core task execution ====================
// runAddKeysTask performs the batch key import.
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
// 1. De-duplicate
uniqueKeys := s.deduplicateKeys(keys)
if len(uniqueKeys) == 0 {
s.endTaskWithResult(ctx, taskID, resourceID, gin.H{
"newly_linked_count": 0,
"already_linked_count": 0,
}, nil)
return
}
// Step 3: find which of these keys are already linked to the current group.
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
// 2. Ensure all keys exist in the database (idempotent)
allKeyModels, err := s.ensureKeysExist(uniqueKeys)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
return
}
alreadyLinkedIDSet := make(map[uint]struct{})
for _, key := range alreadyLinkedModels {
alreadyLinkedIDSet[key.ID] = struct{}{}
// 3. Filter out keys already linked
keysToLink, alreadyLinkedCount, err := s.filterNewKeys(allKeyModels, groupID, uniqueKeys)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check linked keys: %w", err))
return
}
// Step 4: determine the keys that actually need linking to this group (our "working set").
var keysToLink []models.APIKey
for _, key := range allKeyModels {
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
keysToLink = append(keysToLink, key)
}
}
// Step 5: update the task total to the exact working-set size.
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil {
// 4. Update the task total to the actual number to process
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
// Step 6: link keys to the group in chunks, updating progress in real time.
// 5. Link keys to the group in batches
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
idsToLink[i] = key.ID
}
for i := 0; i < len(idsToLink); i += chunkSize {
end := i + chunkSize
if end > len(idsToLink) {
end = len(idsToLink)
}
chunk := idsToLink[i:end]
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
return
}
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
if err := s.linkKeysInChunks(ctx, taskID, groupID, keysToLink); err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
}
// Step 7: prepare the final result and end the task.
// 6. Handle key status according to the validation flag
if len(keysToLink) > 0 {
s.processNewlyLinkedKeys(ctx, groupID, keysToLink, validateOnImport)
}
// 7. Return the result
result := gin.H{
"newly_linked_count": len(keysToLink),
"already_linked_count": len(alreadyLinkedIDSet),
"already_linked_count": alreadyLinkedCount,
"total_linked_count": len(allKeyModels),
}
// Step 8: per the `validateOnImport` flag, publish events or activate directly (newly linked keys only).
if len(keysToLink) > 0 {
idsToLink := make([]uint, len(keysToLink))
for i, key := range keysToLink {
idsToLink[i] = key.ID
}
if validateOnImport {
s.publishImportGroupCompletedEvent(groupID, idsToLink)
for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked")
}
} else {
for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil {
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
}
}
}
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
}
// runUnlinkKeysTask
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
uniqueKeysMap := make(map[string]struct{})
var uniqueKeys []string
for _, kStr := range keys {
if _, exists := uniqueKeysMap[kStr]; !exists {
uniqueKeysMap[kStr] = struct{}{}
uniqueKeys = append(uniqueKeys, kStr)
}
}
// Step 1: in one pass, find the key entities that actually exist in this group (our "working set").
// runUnlinkKeysTask performs the batch unlink.
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
// 1. De-duplicate
uniqueKeys := s.deduplicateKeys(keys)
// 2. Find the keys to unlink
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
return
}
if len(keysToUnlink) == 0 {
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
result := gin.H{
"unlinked_count": 0,
"hard_deleted_count": 0,
"not_found_count": len(uniqueKeys),
}
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
return
}
idsToUnlink := make([]uint, len(keysToUnlink))
for i, key := range keysToUnlink {
idsToUnlink[i] = key.ID
}
// Step 2: update the task total to the exact working-set size.
if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil {
// 3. Extract key IDs
idsToUnlink := s.extractKeyIDs(keysToUnlink)
// 4. Update the task total
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
}
var totalUnlinked int64
// Step 3: unlink keys in chunks, reporting progress.
for i := 0; i < len(idsToUnlink); i += chunkSize {
end := i + chunkSize
if end > len(idsToUnlink) {
end = len(idsToUnlink)
}
chunk := idsToUnlink[i:end]
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
return
}
totalUnlinked += unlinked
for _, keyID := range chunk {
s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked")
}
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
// 5. Unlink in batches
totalUnlinked, err := s.unlinkKeysInChunks(ctx, taskID, groupID, idsToUnlink)
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
// 6. Clean up orphan keys
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
if err != nil {
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
}
// 7. Return the result
result := gin.H{
"unlinked_count": totalUnlinked,
"hard_deleted_count": totalDeleted,
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
}
func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) {
var totalDeleted int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
if end > len(keys) {
end = len(keys)
}
chunk := keys[i:end]
// runHardDeleteKeysTask performs the hard delete.
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
totalDeleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
return s.keyRepo.HardDeleteByValues(chunk)
})
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
return
}
totalDeleted += deleted
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
result := gin.H{
"hard_deleted_count": totalDeleted,
"not_found_count": int64(len(keys)) - totalDeleted,
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_hard_deleted") // Global event
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
}
func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) {
var restoredCount int64
for i := 0; i < len(keys); i += chunkSize {
end := i + chunkSize
if end > len(keys) {
end = len(keys)
}
chunk := keys[i:end]
// runRestoreKeysTask performs the key restore.
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
restoredCount, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
return s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
})
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
if err != nil {
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
return
}
restoredCount += count
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
if err != nil {
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
return
}
result := gin.H{
"restored_count": restoredCount,
"not_found_count": int64(len(keys)) - restoredCount,
}
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
s.publishChangeEvent(0, "keys_bulk_restored") // Global event
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
}
func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
// ==================== Helpers ====================
// deduplicateKeys removes duplicates from a key list.
func (s *KeyImportService) deduplicateKeys(keys []string) []string {
uniqueKeysMap := make(map[string]struct{}, len(keys))
uniqueKeys := make([]string, 0, len(keys))
for _, kStr := range keys {
if _, exists := uniqueKeysMap[kStr]; !exists {
uniqueKeysMap[kStr] = struct{}{}
uniqueKeys = append(uniqueKeys, kStr)
}
}
return uniqueKeys
}
// ensureKeysExist guarantees every key exists in the database.
func (s *KeyImportService) ensureKeysExist(keys []string) ([]models.APIKey, error) {
keysToEnsure := make([]models.APIKey, len(keys))
for i, keyStr := range keys {
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
}
return s.keyRepo.AddKeys(keysToEnsure)
}
// filterNewKeys drops keys already linked to the group and returns the ones to add.
func (s *KeyImportService) filterNewKeys(allKeyModels []models.APIKey, groupID uint, uniqueKeys []string) ([]models.APIKey, int, error) {
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
if err != nil {
return nil, 0, err
}
alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels))
for _, key := range alreadyLinkedModels {
alreadyLinkedIDSet[key.ID] = struct{}{}
}
keysToLink := make([]models.APIKey, 0, len(allKeyModels)-len(alreadyLinkedIDSet))
for _, key := range allKeyModels {
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
keysToLink = append(keysToLink, key)
}
}
return keysToLink, len(alreadyLinkedIDSet), nil
}
// extractKeyIDs collects the IDs of a key slice.
func (s *KeyImportService) extractKeyIDs(keys []models.APIKey) []uint {
ids := make([]uint, len(keys))
for i, key := range keys {
ids[i] = key.ID
}
return ids
}
// linkKeysInChunks links keys to the group in chunks.
func (s *KeyImportService) linkKeysInChunks(ctx context.Context, taskID string, groupID uint, keysToLink []models.APIKey) error {
idsToLink := s.extractKeyIDs(keysToLink)
for i := 0; i < len(idsToLink); i += chunkSize {
end := min(i+chunkSize, len(idsToLink))
chunk := idsToLink[i:end]
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
return fmt.Errorf("chunk failed to link keys: %w", err)
}
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
}
return nil
}
// unlinkKeysInChunks unlinks keys in chunks.
func (s *KeyImportService) unlinkKeysInChunks(ctx context.Context, taskID string, groupID uint, idsToUnlink []uint) (int64, error) {
var totalUnlinked int64
for i := 0; i < len(idsToUnlink); i += chunkSize {
end := min(i+chunkSize, len(idsToUnlink))
chunk := idsToUnlink[i:end]
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
if err != nil {
return 0, fmt.Errorf("chunk failed: could not unlink keys: %w", err)
}
totalUnlinked += unlinked
// Publish unlink events
for _, keyID := range chunk {
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
}
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
}
return totalUnlinked, nil
}
// processKeysInChunks is the generic chunked key-processing loop.
func (s *KeyImportService) processKeysInChunks(
ctx context.Context,
taskID string,
keys []string,
processFunc func(chunk []string) (int64, error),
) (int64, error) {
var totalProcessed int64
for i := 0; i < len(keys); i += chunkSize {
end := min(i+chunkSize, len(keys))
chunk := keys[i:end]
count, err := processFunc(chunk)
if err != nil {
return 0, fmt.Errorf("failed to process chunk: %w", err)
}
totalProcessed += count
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
}
return totalProcessed, nil
}
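
The runners above now share a single chunking shape. Stripped of the task-progress reporting, the loop reduces to the following standalone sketch (it uses the min built-in from Go 1.21; the file defines its own for older toolchains):

package main

import "fmt"

const chunkSize = 500

// processInChunks walks items in fixed-size windows, delegating each
// window to fn and summing the per-chunk counts, the same shape as
// processKeysInChunks above minus the progress reporting.
func processInChunks(items []string, fn func([]string) (int64, error)) (int64, error) {
	var total int64
	for i := 0; i < len(items); i += chunkSize {
		end := min(i+chunkSize, len(items))
		count, err := fn(items[i:end])
		if err != nil {
			return 0, err
		}
		total += count
	}
	return total, nil
}

func main() {
	items := make([]string, 1234)
	total, _ := processInChunks(items, func(chunk []string) (int64, error) {
		fmt.Println("processing", len(chunk)) // 500, then 500, then 234
		return int64(len(chunk)), nil
	})
	fmt.Println("total:", total) // 1234
}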
// processNewlyLinkedKeys handles newly linked keys (validate or activate directly).
func (s *KeyImportService) processNewlyLinkedKeys(ctx context.Context, groupID uint, keysToLink []models.APIKey, validateOnImport bool) {
idsToLink := s.extractKeyIDs(keysToLink)
if validateOnImport {
// Publish the bulk-import-completed event to trigger validation
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
// Publish a status-change event per key
for _, keyID := range idsToLink {
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
}
} else {
// Activate keys directly, skipping validation
for _, keyID := range idsToLink {
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
}).Errorf("Failed to directly activate key: %v", err)
}
}
}
}
// endTaskWithResult is the shared task-completion handler.
func (s *KeyImportService) endTaskWithResult(ctx context.Context, taskID, resourceID string, result gin.H, err error) {
if err != nil {
s.logger.WithFields(logrus.Fields{
"task_id": taskID,
"resource_id": resourceID,
}).WithError(err).Error("Task failed")
}
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, err)
}
// ==================== Event publishing ====================
// publishSingleKeyChangeEvent publishes a status-change event for a single key.
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
KeyID: keyID,
@@ -339,59 +474,88 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
ChangeReason: reason,
ChangedAt: time.Now(),
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{
eventData, err := json.Marshal(event)
if err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
"reason": reason,
}).Error("Failed to publish single key change event.")
}).WithError(err).Error("Failed to marshal key change event")
return
}
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_id": keyID,
"reason": reason,
}).WithError(err).Error("Failed to publish single key change event")
}
}
func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) {
// publishChangeEvent publishes a generic change event.
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
event := models.KeyStatusChangedEvent{
GroupID: groupID,
ChangeReason: reason,
}
eventData, _ := json.Marshal(event)
_ = s.store.Publish(models.TopicKeyStatusChanged, eventData)
eventData, err := json.Marshal(event)
if err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"reason": reason,
}).WithError(err).Error("Failed to marshal change event")
return
}
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"reason": reason,
}).WithError(err).Error("Failed to publish change event")
}
}
func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) {
// publishImportGroupCompletedEvent publishes the bulk-import-completed event.
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
if len(keyIDs) == 0 {
return
}
event := models.ImportGroupCompletedEvent{
GroupID: groupID,
KeyIDs: keyIDs,
CompletedAt: time.Now(),
}
eventData, err := json.Marshal(event)
if err != nil {
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_count": len(keyIDs),
}).WithError(err).Error("Failed to marshal ImportGroupCompletedEvent")
return
}
if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil {
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_count": len(keyIDs),
}).WithError(err).Error("Failed to publish ImportGroupCompletedEvent")
} else {
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"key_count": len(keyIDs),
}).Info("Published ImportGroupCompletedEvent")
}
}
// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter.
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
// 1. [New] Find the keys to operate on.
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
// min returns the smaller of two integers.
func min(a, b int) int {
if a < b {
return a
}
if len(keyValues) == 0 {
return nil, fmt.Errorf("no keys found matching the provided filter")
}
// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
return s.StartUnlinkKeysTask(groupID, keysAsText)
}
return b
}

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/channel"
@@ -24,26 +25,38 @@ import (
)
const (
TaskTypeTestKeys = "test_keys"
TaskTypeTestKeys = "test_keys"
defaultConcurrency = 10
maxValidationConcurrency = 100
validationTaskTimeout = time.Hour
)
type KeyValidationService struct {
taskService task.Reporter
channel channel.ChannelProxy
db *gorm.DB
SettingsManager *settings.SettingsManager
settingsManager *settings.SettingsManager
groupManager *GroupManager
store store.Store
keyRepo repository.KeyRepository
logger *logrus.Entry
}
func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService {
func NewKeyValidationService(
ts task.Reporter,
ch channel.ChannelProxy,
db *gorm.DB,
ss *settings.SettingsManager,
gm *GroupManager,
st store.Store,
kr repository.KeyRepository,
logger *logrus.Logger,
) *KeyValidationService {
return &KeyValidationService{
taskService: ts,
channel: ch,
db: db,
SettingsManager: ss,
settingsManager: ss,
groupManager: gm,
store: st,
keyRepo: kr,
@@ -51,53 +64,54 @@ func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm
}
}
// ==================== Public interface ====================
// ValidateSingleKey validates a single key.
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
// 1. Decrypt the key
if err := s.keyRepo.Decrypt(key); err != nil {
return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
}
// 2. Create the HTTP client and request
client := &http.Client{Timeout: timeout}
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err)
s.logger.WithFields(logrus.Fields{
"key_id": key.ID,
"endpoint": endpoint,
}).Error("Failed to create validation request")
return fmt.Errorf("failed to create request: %w", err)
}
s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request
// 3. Modify the request (attach key auth headers)
s.channel.ModifyRequest(req, key)
// 4. Execute the request
resp, err := client.Do(req)
if err != nil {
// This is a network-level error (e.g., timeout, DNS issue)
return fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
// 5. Check the response status
if resp.StatusCode == http.StatusOK {
return nil // Success
return nil
}
// Read the body for more error details
bodyBytes, readErr := io.ReadAll(resp.Body)
var errorMsg string
if readErr != nil {
errorMsg = "Failed to read error response body"
} else {
errorMsg = string(bodyBytes)
}
// This is a validation failure with a specific HTTP status code
return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode,
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
Code: "VALIDATION_FAILED",
}
// 6. Handle the error response
return s.buildValidationError(resp)
}
// --- Async task methods (fully adapted to the new task package) ---
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
// StartTestKeysTask starts a batch key-test task.
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
// 1. Parse and validate input
keyStrings := utils.ParseKeysFromText(keysText)
if len(keyStrings) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
}
// 2. Look up key models
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
@@ -105,116 +119,345 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
if len(apiKeyModels) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
}
// 3. Batch-decrypt keys
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.")
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task")
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
}
// 4. Fetch group configuration
group, ok := s.groupManager.GetGroupByID(groupID)
if !ok {
// [FIX] Correctly use the NewAPIError constructor for a missing group.
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
}
opConfig, err := s.groupManager.BuildOperationalConfig(group)
if err != nil {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
}
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
if err != nil {
return nil, err // Pass up the error from task service (e.g., "task already running")
}
settings := s.SettingsManager.GetSettings()
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
// 5. Build the validation endpoint
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
if err != nil {
s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err))
}
// 6. Create the task
resourceID := fmt.Sprintf("group-%d", groupID)
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), validationTaskTimeout)
if err != nil {
return nil, err
}
var concurrency int
if opConfig.KeyCheckConcurrency != nil {
concurrency = *opConfig.KeyCheckConcurrency
} else {
concurrency = settings.BaseKeyCheckConcurrency
}
go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
// 7. Prepare task parameters
params := s.buildValidationParams(opConfig)
// 8. Launch the async validation task
go s.runTestKeysTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, params, endpoint)
return taskStatus, nil
}
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
var wg sync.WaitGroup
var mu sync.Mutex
finalResults := make([]models.KeyTestResult, len(keys))
processedCount := 0
if concurrency <= 0 {
concurrency = 10
}
type job struct {
Index int
Value models.APIKey
}
jobs := make(chan job, len(keys))
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := range jobs {
apiKeyModel := j.Value
keyToValidate := apiKeyModel
validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint)
// StartTestKeysByFilterTask starts a batch test task by status filter.
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
s.logger.WithFields(logrus.Fields{
"group_id": groupID,
"statuses": statuses,
}).Info("Starting test task with status filter")
var currentResult models.KeyTestResult
event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{
// GroupID and KeyID are pointers on the RequestLog model, so take their addresses
GroupID: &groupID,
KeyID: &apiKeyModel.ID,
},
}
if validationErr == nil {
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."}
event.RequestLog.IsSuccess = true
} else {
var apiErr *CustomErrors.APIError
if CustomErrors.As(validationErr, &apiErr) {
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
event.Error = apiErr
} else {
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
}
event.RequestLog.IsSuccess = false
}
eventData, _ := json.Marshal(event)
if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
}
mu.Lock()
finalResults[j.Index] = currentResult
processedCount++
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
mu.Unlock()
}
}()
}
for i, k := range keys {
jobs <- job{Index: i, Value: k}
}
close(jobs)
wg.Wait()
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil)
}
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
// 1. Query keys matching the filter
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
if err != nil {
return nil, CustomErrors.ParseDBError(err)
}
if len(keyValues) == 0 {
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria")
}
// 2. Convert to text form and start the task
keysAsText := strings.Join(keyValues, "\n")
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
return s.StartTestKeysTask(groupID, keysAsText)
s.logger.Infof("Found %d keys to validate for group %d", len(keyValues), groupID)
return s.StartTestKeysTask(ctx, groupID, keysAsText)
}
// ==================== Core task execution ====================
// validationParams bundles the validation parameters.
type validationParams struct {
timeout time.Duration
concurrency int
}
// buildValidationParams assembles the validation parameters.
func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams {
settings := s.settingsManager.GetSettings()
// Read the timeout from configuration (not hard-coded)
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second // default only when the configured value is invalid
}
// Read concurrency from configuration (precedence: group > global > fallback default)
var concurrency int
if opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0 {
concurrency = *opConfig.KeyCheckConcurrency
} else if settings.BaseKeyCheckConcurrency > 0 {
concurrency = settings.BaseKeyCheckConcurrency
} else {
concurrency = defaultConcurrency // fallback default
}
// Cap the concurrency (safety guard)
if concurrency > maxValidationConcurrency {
concurrency = maxValidationConcurrency
}
return validationParams{
timeout: timeout,
concurrency: concurrency,
}
}
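
For concreteness: a group-level setting of 25 beats a global setting of 40; with no group setting, the global 40 applies; if both are unset or non-positive, the fallback is defaultConcurrency (10); and a group setting of 250 is clamped to maxValidationConcurrency (100).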
// runTestKeysTaskWithRecovery wraps the task run with panic recovery.
func (s *KeyValidationService) runTestKeysTaskWithRecovery(
ctx context.Context,
taskID string,
resourceID string,
groupID uint,
keys []models.APIKey,
params validationParams,
endpoint string,
) {
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic recovered in validation task %s: %v", taskID, r)
s.logger.WithField("task_id", taskID).Error(err)
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
}
}()
s.runTestKeysTask(ctx, taskID, resourceID, groupID, keys, params, endpoint)
}
// runTestKeysTask executes the batch key-validation task.
func (s *KeyValidationService) runTestKeysTask(
ctx context.Context,
taskID string,
resourceID string,
groupID uint,
keys []models.APIKey,
params validationParams,
endpoint string,
) {
s.logger.WithFields(logrus.Fields{
"task_id": taskID,
"group_id": groupID,
"key_count": len(keys),
"concurrency": params.concurrency,
"timeout": params.timeout,
}).Info("Starting validation task")
// 1. Initialize result collection
results := make([]models.KeyTestResult, len(keys))
// 2. Create the job dispatcher
dispatcher := newValidationDispatcher(
keys,
params.concurrency,
s,
ctx,
taskID,
groupID,
endpoint,
params.timeout,
)
// 3. Run the concurrent validation
dispatcher.run(results)
// 4. Finish the task
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil)
s.logger.WithFields(logrus.Fields{
"task_id": taskID,
"group_id": groupID,
"processed": len(results),
}).Info("Validation task completed")
}
// ==================== Validation dispatcher ====================
// validationJob is a single validation work item.
type validationJob struct {
index int
key models.APIKey
}
// validationDispatcher fans validation jobs out to workers.
type validationDispatcher struct {
keys []models.APIKey
concurrency int
service *KeyValidationService
ctx context.Context
taskID string
groupID uint
endpoint string
timeout time.Duration
mu sync.Mutex
processedCount int
}
// newValidationDispatcher constructs a dispatcher.
func newValidationDispatcher(
keys []models.APIKey,
concurrency int,
service *KeyValidationService,
ctx context.Context,
taskID string,
groupID uint,
endpoint string,
timeout time.Duration,
) *validationDispatcher {
return &validationDispatcher{
keys: keys,
concurrency: concurrency,
service: service,
ctx: ctx,
taskID: taskID,
groupID: groupID,
endpoint: endpoint,
timeout: timeout,
}
}
// run executes the validation concurrently.
func (d *validationDispatcher) run(results []models.KeyTestResult) {
var wg sync.WaitGroup
jobs := make(chan validationJob, len(d.keys))
// Start the worker pool
for i := 0; i < d.concurrency; i++ {
wg.Add(1)
go d.worker(&wg, jobs, results)
}
// Dispatch jobs
for i, key := range d.keys {
jobs <- validationJob{index: i, key: key}
}
close(jobs)
// Wait for all workers to finish
wg.Wait()
}
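
The dispatcher is a textbook bounded worker pool: the job channel is buffered to the full job count so the producer loop never blocks, close signals end-of-input, and the WaitGroup joins the workers. The same shape in miniature (each result slot is written by exactly one worker, so only shared counters would need the mutex):

package main

import (
	"fmt"
	"sync"
)

func main() {
	jobs := []int{1, 2, 3, 4, 5, 6, 7, 8}
	results := make([]int, len(jobs))
	ch := make(chan int, len(jobs)) // buffered: the producer never blocks

	var wg sync.WaitGroup
	const workers = 3
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for idx := range ch {
				results[idx] = jobs[idx] * jobs[idx] // each index written once: no race
			}
		}()
	}
	for i := range jobs {
		ch <- i
	}
	close(ch) // no more jobs; workers exit once the channel drains
	wg.Wait()
	fmt.Println(results) // [1 4 9 16 25 36 49 64]
}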
// worker is a validation worker goroutine.
func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) {
defer wg.Done()
for job := range jobs {
result := d.validateKey(job.key)
d.mu.Lock()
results[job.index] = result
d.processedCount++
_ = d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount)
d.mu.Unlock()
}
}
// validateKey validates one key and returns its result.
func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult {
// 1. Run the validation
validationErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint)
// 2. Build the result and event
result, event := d.buildResultAndEvent(key, validationErr)
// 3. Publish the validation event
d.publishValidationEvent(key.ID, event)
return result
}
// buildResultAndEvent builds the test result and the request-finished event.
func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) {
event := models.RequestFinishedEvent{
RequestLog: models.RequestLog{
GroupID: &d.groupID,
KeyID: &key.ID,
},
}
if validationErr == nil {
// Validation succeeded
event.RequestLog.IsSuccess = true
return models.KeyTestResult{
Key: key.APIKey,
Status: "valid",
Message: "Validation successful",
}, event
}
// Validation failed
event.RequestLog.IsSuccess = false
var apiErr *CustomErrors.APIError
if CustomErrors.As(validationErr, &apiErr) {
event.Error = apiErr
return models.KeyTestResult{
Key: key.APIKey,
Status: "invalid",
Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message),
}, event
}
// Other errors (e.g. network failures)
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
return models.KeyTestResult{
Key: key.APIKey,
Status: "error",
Message: "Validation check failed: " + validationErr.Error(),
}, event
}
// publishValidationEvent publishes the validation event.
func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) {
eventData, err := json.Marshal(event)
if err != nil {
d.service.logger.WithFields(logrus.Fields{
"key_id": keyID,
"group_id": d.groupID,
}).WithError(err).Error("Failed to marshal validation event")
return
}
if err := d.service.store.Publish(d.ctx, models.TopicRequestFinished, eventData); err != nil {
d.service.logger.WithFields(logrus.Fields{
"key_id": keyID,
"group_id": d.groupID,
}).WithError(err).Error("Failed to publish validation event")
}
}
// ==================== Helpers ====================
// buildValidationError builds a validation error from an HTTP response.
func (s *KeyValidationService) buildValidationError(resp *http.Response) error {
bodyBytes, readErr := io.ReadAll(resp.Body)
var errorMsg string
if readErr != nil {
errorMsg = "Failed to read error response body"
s.logger.WithError(readErr).Warn("Failed to read validation error response")
} else {
errorMsg = string(bodyBytes)
}
return &CustomErrors.APIError{
HTTPStatus: resp.StatusCode,
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
Code: "VALIDATION_FAILED",
}
}

View File

@@ -2,77 +2,196 @@
package service
import (
"context"
"fmt"
"gemini-balancer/internal/models"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
type LogService struct {
db *gorm.DB
db *gorm.DB
logger *logrus.Entry
}
func NewLogService(db *gorm.DB) *LogService {
return &LogService{db: db}
func NewLogService(db *gorm.DB, logger *logrus.Logger) *LogService {
return &LogService{
db: db,
logger: logger.WithField("component", "LogService"),
}
}
func (s *LogService) Record(log *models.RequestLog) error {
return s.db.Create(log).Error
func (s *LogService) Record(ctx context.Context, log *models.RequestLog) error {
return s.db.WithContext(ctx).Create(log).Error
}
func (s *LogService) GetLogs(c *gin.Context) ([]models.RequestLog, int64, error) {
// LogQueryParams decouples the service from Gin by passing parameters as a struct.
type LogQueryParams struct {
Page int
PageSize int
ModelName string
IsSuccess *bool // pointer distinguishes "unset" from "false"
StatusCode *int
KeyIDs []string
GroupIDs []string
Q string
ErrorCodes []string
StatusCodes []string
}
func (s *LogService) GetLogs(ctx context.Context, params LogQueryParams) ([]models.RequestLog, int64, error) {
// Parameter validation
if params.Page < 1 {
params.Page = 1
}
if params.PageSize < 1 || params.PageSize > 100 {
params.PageSize = 20
}
var logs []models.RequestLog
var total int64
query := s.db.Model(&models.RequestLog{}).Scopes(s.filtersScope(c))
// Count first
// Build the base query
query := s.db.WithContext(ctx).Model(&models.RequestLog{})
query = s.applyFilters(query, params)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to count logs: %w", err)
}
if total == 0 {
return []models.RequestLog{}, 0, nil
}
// Then run the paginated query
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
offset := (page - 1) * pageSize
err := query.Order("request_time desc").Limit(pageSize).Offset(offset).Find(&logs).Error
if err != nil {
return nil, 0, err
offset := (params.Page - 1) * params.PageSize
if err := query.Order("request_time DESC").
Limit(params.PageSize).
Offset(offset).
Find(&logs).Error; err != nil {
return nil, 0, fmt.Errorf("failed to query logs: %w", err)
}
return logs, total, nil
}
func (s *LogService) filtersScope(c *gin.Context) func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
if modelName := c.Query("model_name"); modelName != "" {
db = db.Where("model_name = ?", modelName)
}
if isSuccessStr := c.Query("is_success"); isSuccessStr != "" {
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
db = db.Where("is_success = ?", isSuccess)
}
}
if statusCodeStr := c.Query("status_code"); statusCodeStr != "" {
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
db = db.Where("status_code = ?", statusCode)
}
}
if keyIDStr := c.Query("key_id"); keyIDStr != "" {
if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
db = db.Where("key_id = ?", keyID)
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
db = db.Where("group_id = ?", groupID)
}
}
return db
func (s *LogService) applyFilters(query *gorm.DB, params LogQueryParams) *gorm.DB {
if params.IsSuccess != nil {
query = query.Where("is_success = ?", *params.IsSuccess)
} else {
query = query.Where("is_success = ?", false)
}
if params.ModelName != "" {
query = query.Where("model_name = ?", params.ModelName)
}
if params.StatusCode != nil {
query = query.Where("status_code = ?", *params.StatusCode)
}
if len(params.KeyIDs) > 0 {
query = query.Where("key_id IN (?)", params.KeyIDs)
}
if len(params.GroupIDs) > 0 {
query = query.Where("group_id IN (?)", params.GroupIDs)
}
hasErrorCodes := len(params.ErrorCodes) > 0
hasStatusCodes := len(params.StatusCodes) > 0
if hasErrorCodes && hasStatusCodes {
query = query.Where(
s.db.Where("error_code IN (?)", params.ErrorCodes).
Or("status_code IN (?)", params.StatusCodes),
)
} else if hasErrorCodes {
query = query.Where("error_code IN (?)", params.ErrorCodes)
} else if hasStatusCodes {
query = query.Where("status_code IN (?)", params.StatusCodes)
}
if params.Q != "" {
searchQuery := "%" + params.Q + "%"
query = query.Where(
"model_name LIKE ? OR error_code LIKE ? OR error_message LIKE ? OR CAST(status_code AS CHAR) LIKE ?",
searchQuery, searchQuery, searchQuery, searchQuery,
)
}
return query
}
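
One behavior worth calling out: when is_success is absent, applyFilters pins is_success = false, so the listing defaults to failed requests only. As a worked example, the hypothetical query string ?page=2&group_ids=1,2&status_codes=429,503&q=quota returns the second page of failed requests in groups 1 or 2, restricted to status codes 429 or 503, whose model name, error code, error message, or status code contains "quota", newest first.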
// ParseLogQueryParams is called from the handler layer to parse Gin query parameters.
func ParseLogQueryParams(queryParams map[string]string) (LogQueryParams, error) {
params := LogQueryParams{
Page: 1,
PageSize: 20,
}
if pageStr, ok := queryParams["page"]; ok {
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
params.Page = page
}
}
if pageSizeStr, ok := queryParams["page_size"]; ok {
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
params.PageSize = pageSize
}
}
if modelName, ok := queryParams["model_name"]; ok {
params.ModelName = modelName
}
if isSuccessStr, ok := queryParams["is_success"]; ok {
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
params.IsSuccess = &isSuccess
} else {
return params, fmt.Errorf("invalid is_success parameter: %s", isSuccessStr)
}
}
if statusCodeStr, ok := queryParams["status_code"]; ok {
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
params.StatusCode = &statusCode
} else {
return params, fmt.Errorf("invalid status_code parameter: %s", statusCodeStr)
}
}
if keyIDsStr, ok := queryParams["key_ids"]; ok {
params.KeyIDs = strings.Split(keyIDsStr, ",")
}
if groupIDsStr, ok := queryParams["group_ids"]; ok {
params.GroupIDs = strings.Split(groupIDsStr, ",")
}
if errorCodesStr, ok := queryParams["error_codes"]; ok {
params.ErrorCodes = strings.Split(errorCodesStr, ",")
}
if statusCodesStr, ok := queryParams["status_codes"]; ok {
params.StatusCodes = strings.Split(statusCodesStr, ",")
}
if q, ok := queryParams["q"]; ok {
params.Q = q
}
return params, nil
}
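
A sketch of the handler-side glue this parser enables; the handler name, route wiring, and response shape are illustrative assumptions, not the project's actual handler:

package handler

import (
	"net/http"

	"gemini-balancer/internal/service"

	"github.com/gin-gonic/gin"
)

// ListLogs flattens Gin's url.Values into the map ParseLogQueryParams
// expects (first value per key), then delegates to the Gin-free service.
func ListLogs(svc *service.LogService) gin.HandlerFunc {
	return func(c *gin.Context) {
		raw := make(map[string]string)
		for k, values := range c.Request.URL.Query() {
			if len(values) > 0 {
				raw[k] = values[0]
			}
		}
		params, err := service.ParseLogQueryParams(raw)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		logs, total, err := svc.GetLogs(c.Request.Context(), params)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusOK, gin.H{"items": logs, "total": total})
	}
}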
// DeleteLogs deletes the logs with the given IDs.
func (s *LogService) DeleteLogs(ctx context.Context, ids []uint) error {
if len(ids) == 0 {
return fmt.Errorf("no log IDs provided")
}
return s.db.WithContext(ctx).Delete(&models.RequestLog{}, ids).Error
}
// DeleteAllLogs deletes all logs.
func (s *LogService) DeleteAllLogs(ctx context.Context) error {
return s.db.WithContext(ctx).Where("1 = 1").Delete(&models.RequestLog{}).Error
}
// DeleteOldLogs deletes logs older than the given number of days.
func (s *LogService) DeleteOldLogs(ctx context.Context, days int) (int64, error) {
if days <= 0 {
return 0, fmt.Errorf("days must be positive")
}
result := s.db.WithContext(ctx).
Where("request_time < DATE_SUB(NOW(), INTERVAL ? DAY)", days).
Delete(&models.RequestLog{})
return result.RowsAffected, result.Error
}
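
DATE_SUB(NOW(), INTERVAL ? DAY) is MySQL dialect; SQLite and Postgres will reject it. If portability matters (an assumption; the deployment may be MySQL-only), the cutoff can be computed in Go so the WHERE clause stays plain SQL. A hypothetical variant (requires adding "time" to the imports):

// DeleteOldLogsPortable is a dialect-neutral sketch of DeleteOldLogs:
// the cutoff timestamp is computed in Go rather than in SQL.
func (s *LogService) DeleteOldLogsPortable(ctx context.Context, days int) (int64, error) {
	if days <= 0 {
		return 0, fmt.Errorf("days must be positive")
	}
	cutoff := time.Now().AddDate(0, 0, -days)
	result := s.db.WithContext(ctx).
		Where("request_time < ?", cutoff).
		Delete(&models.RequestLog{})
	return result.RowsAffected, result.Error
}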

View File

@@ -1,9 +1,9 @@
// Filename: internal/service/resource_service.go
package service
import (
"errors"
"context"
"gemini-balancer/internal/domain/proxy"
apperrors "gemini-balancer/internal/errors"
"gemini-balancer/internal/models"
"gemini-balancer/internal/repository"
@@ -15,10 +15,7 @@ import (
"github.com/sirupsen/logrus"
)
var (
ErrNoResourceAvailable = errors.New("no available resource found for the request")
)
// RequestResources bundles all resources needed for a successful request.
type RequestResources struct {
KeyGroup *models.KeyGroup
APIKey *models.APIKey
@@ -27,86 +24,92 @@ type RequestResources struct {
RequestConfig *models.RequestConfig
}
// ResourceService dynamically selects and allocates API keys and related resources based on request parameters and business rules.
type ResourceService struct {
settingsManager *settings.SettingsManager
groupManager *GroupManager
keyRepo repository.KeyRepository
authTokenRepo repository.AuthTokenRepository
apiKeyService *APIKeyService
proxyManager *proxy.Module
logger *logrus.Entry
initOnce sync.Once
}
// NewResourceService creates and initializes a ResourceService instance.
func NewResourceService(
sm *settings.SettingsManager,
gm *GroupManager,
kr repository.KeyRepository,
atr repository.AuthTokenRepository,
aks *APIKeyService,
pm *proxy.Module,
logger *logrus.Logger,
) *ResourceService {
logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr)
rs := &ResourceService{
settingsManager: sm,
groupManager: gm,
keyRepo: kr,
authTokenRepo: atr,
apiKeyService: aks,
proxyManager: pm,
logger: logger.WithField("component", "ResourceService📦"),
}
// sync.Once guarantees the cache pre-warm runs only once per service lifetime
rs.initOnce.Do(func() {
go rs.preWarmCache(logger)
go rs.preWarmCache()
})
return rs
}
// --- [Mode 1: smart aggregation] ---
func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) {
// GetResourceFromBasePool acquires resources via smart aggregation.
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
log.Debug("Entering BasePool resource acquisition.")
// 1. Filter out all eligible candidate groups and sort them by priority
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken)
if len(candidateGroups) == 0 {
log.Warn("No candidate groups found for BasePool construction.")
return nil, apperrors.ErrNoKeysAvailable
}
// 2. Select one key from the BasePool using the global polling strategy
basePool := &repository.BasePool{
CandidateGroups: candidateGroups,
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
}
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool)
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool)
if err != nil {
log.WithError(err).Warn("Failed to select a key from the BasePool.")
return nil, apperrors.ErrNoKeysAvailable
}
// 3. Assemble the final resources
// [Key point] In this mode RequestConfig is always empty, to keep the proxy transparent.
resources, err := s.assembleRequestResources(selectedGroup, apiKey)
if err != nil {
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
return nil, err
}
resources.RequestConfig = &models.RequestConfig{} // 强制为空
resources.RequestConfig = &models.RequestConfig{} // BasePool mode uses the default request config
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
return resources, nil
}
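For orientation, a caller in the proxy layer would presumably drive both halves of this contract roughly as below (a sketch; the variable names and surrounding handler are illustrative, not taken from this diff):

// Acquire a key for the incoming request, then report the outcome so
// key state and statistics stay accurate.
resources, err := resourceService.GetResourceFromBasePool(ctx, authToken, modelName)
if err != nil {
	return err // no usable key in any candidate group
}
// ... forward the request upstream using resources.APIKey ...
var apiErr *apperrors.APIError // populated if the upstream call failed
resourceService.ReportRequestResult(resources, apiErr == nil, apiErr)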
// --- [Mode 2: precise routing] ---
func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) {
// GetResourceFromGroup acquires resources via precise routing (an explicitly named key group).
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
log.Debug("Entering PreciseRoute resource acquisition.")
targetGroup, ok := s.groupManager.GetGroupByName(groupName)
if !ok {
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
}
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
}
apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup)
apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
if err != nil {
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
return nil, apperrors.ErrNoKeysAvailable
@@ -117,39 +120,39 @@ func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, grou
log.WithError(err).Error("Failed to assemble resources for precise route.")
return nil, err
}
resources.RequestConfig = targetGroup.RequestConfig
resources.RequestConfig = targetGroup.RequestConfig // precise routing uses the group's own request config
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
return resources, nil
}
// GetAllowedModelsForToken returns the names of all models the given auth token may access.
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
allGroups := s.groupManager.GetAllGroups()
if len(allGroups) == 0 {
return []string{}
}
allowedModelsSet := make(map[string]struct{})
allowedGroupIDs := make(map[uint]bool)
if authToken.IsAdmin {
for _, group := range allGroups {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
allowedGroupIDs[group.ID] = true
}
} else {
for _, ag := range authToken.AllowedGroups {
allowedGroupIDs[ag.ID] = true
}
for _, group := range allGroups {
if allowedGroupIDs[group.ID] {
for _, modelMapping := range group.AllowedModels {
allowedModelsSet[modelMapping.ModelName] = struct{}{}
}
}
}
}
result := make([]string, 0, len(allowedModelsSet))
for modelName := range allowedModelsSet {
result = append(result, modelName)
@@ -158,20 +161,52 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
return result
}
// ReportRequestResult reports a request's final outcome to APIKeyService so key state can be updated.
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
return
}
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
}
// --- Private helpers ---
// preWarmCache runs the one-off cache pre-warming task in the background.
func (s *ResourceService) preWarmCache() {
time.Sleep(2 * time.Second) // give other service components a moment to finish initializing
s.logger.Info("Performing initial key cache pre-warming...")
// Force-load the GroupManager cache
s.logger.Info("Pre-warming GroupManager cache...")
_ = s.groupManager.GetAllGroups()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // allow a generous timeout
defer cancel()
if err := s.keyRepo.LoadAllKeysToStore(ctx); err != nil {
s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
} else {
s.logger.Info("Initial key cache pre-warming completed successfully.")
}
}
// assembleRequestResources builds the final resource object from a key group and an API key.
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
selectedUpstream := s.selectUpstreamForGroup(group)
if selectedUpstream == nil {
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
}
var proxyConfig *models.ProxyConfig
// [Note] The proxy logic needs a proxyModule instance; left nil for now. The dependency needs re-injecting later.
// if group.EnableProxy && s.proxyModule != nil {
// var err error
// proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey)
// if err != nil {
// s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID)
// }
// }
var err error
// Assign a proxy only when the group explicitly enables one
if group.EnableProxy {
proxyConfig, err = s.proxyManager.AssignProxyIfNeeded(apiKey)
if err != nil {
s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID)
// Business rule: this must return an error, because the proxy is mandatory for this group
return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.")
}
}
return &RequestResources{
KeyGroup: group,
APIKey: apiKey,
@@ -180,8 +215,10 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
}, nil
}
// selectUpstreamForGroup picks an upstream endpoint for the given key group.
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
if len(group.AllowedUpstreams) > 0 {
// (load-balancing logic could be added here later)
return group.AllowedUpstreams[0]
}
globalSettings := s.settingsManager.GetSettings()
@@ -191,62 +228,39 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
return nil
}
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
time.Sleep(2 * time.Second)
s.logger.Info("Performing initial key cache pre-warming...")
if err := s.keyRepo.LoadAllKeysToStore(); err != nil {
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
return err
}
s.logger.Info("Initial key cache pre-warming completed successfully.")
return nil
}
func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
}
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
// filterAndSortCandidateGroups filters and sorts eligible candidate key groups by model name and token permissions.
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup {
allGroupsFromCache := s.groupManager.GetAllGroups()
var candidateGroups []*models.KeyGroup
// 1. Determine the permission scope
allowedGroupIDs := make(map[uint]bool)
isTokenRestricted := len(allowedGroupsFromToken) > 0
if isTokenRestricted {
for _, ag := range allowedGroupsFromToken {
allowedGroupIDs[ag.ID] = true
}
}
// 2. Filter
for _, group := range allGroupsFromCache {
// Check token permissions
if isTokenRestricted && !allowedGroupIDs[group.ID] {
// Check the token's permissions
if !s.isTokenAllowedForGroup(authToken, group.ID) {
continue
}
// Check whether the model is allowed
isModelAllowed := false
if len(group.AllowedModels) == 0 { // a group with no model restrictions allows everything
isModelAllowed = true
} else {
for _, m := range group.AllowedModels {
if m.ModelName == modelName {
isModelAllowed = true
break
}
}
}
if isModelAllowed {
// Check model support (a group with no model restrictions supports every model by default)
if len(group.AllowedModels) == 0 || s.groupSupportsModel(group, modelName) {
candidateGroups = append(candidateGroups, group)
}
}
// 3. Sort ascending by the Order field
sort.SliceStable(candidateGroups, func(i, j int) bool {
return candidateGroups[i].Order < candidateGroups[j].Order
})
return candidateGroups
}
// groupSupportsModel reports whether the given key group supports the given model name.
func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool {
for _, m := range group.AllowedModels {
if m.ModelName == modelName {
return true
}
}
return false
}
// isTokenAllowedForGroup reports whether the given auth token may access the given key group.
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
if authToken.IsAdmin {
return true
@@ -258,10 +272,3 @@ func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, gr
}
return false
}
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
return
}
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
}

View File

@@ -52,7 +52,7 @@ func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToke
// IsIPBanned
func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) {
banKey := fmt.Sprintf("banned_ip:%s", ip)
return s.store.Exists(banKey)
return s.store.Exists(ctx, banKey)
}
// RecordFailedLoginAttempt
@@ -61,7 +61,7 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
return nil
}
count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1)
count, err := s.store.HIncrBy(ctx, loginAttemptsKey, ip, 1)
if err != nil {
return err
}
@@ -71,12 +71,12 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
banDuration := s.SettingsManager.GetIPBanDuration()
banKey := fmt.Sprintf("banned_ip:%s", ip)
if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil {
if err := s.store.Set(ctx, banKey, []byte("1"), banDuration); err != nil {
return err
}
s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration)
s.store.HDel(loginAttemptsKey, ip)
s.store.HDel(ctx, loginAttemptsKey, ip)
}
return nil

View File

@@ -2,6 +2,7 @@
package service
import (
"context"
"encoding/json"
"fmt"
"gemini-balancer/internal/models"
@@ -34,75 +35,121 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
func (s *StatsService) Start() {
s.logger.Info("Starting event listener for stats maintenance.")
sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
if err != nil {
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
return
}
go func() {
defer sub.Close()
for {
select {
case msg := <-sub.Channel():
var event models.KeyStatusChangedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
continue
}
s.handleKeyStatusChange(&event)
case <-s.stopChan:
s.logger.Info("Stopping stats event listener.")
return
}
}
}()
go s.listenForEvents()
}
func (s *StatsService) Stop() {
close(s.stopChan)
}
func (s *StatsService) listenForEvents() {
for {
select {
case <-s.stopChan:
s.logger.Info("Stopping stats event listener.")
return
default:
}
ctx, cancel := context.WithCancel(context.Background())
sub, err := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
if err != nil {
s.logger.Errorf("Failed to subscribe: %v, retrying in 5s", err)
cancel()
time.Sleep(5 * time.Second)
continue
}
s.logger.Info("Subscribed to key status changes")
s.handleSubscription(sub, cancel)
}
}
func (s *StatsService) handleSubscription(sub store.Subscription, cancel context.CancelFunc) {
defer sub.Close()
defer cancel()
for {
select {
case msg := <-sub.Channel():
var event models.KeyStatusChangedEvent
if err := json.Unmarshal(msg.Payload, &event); err != nil {
s.logger.Errorf("Failed to unmarshal event: %v", err)
continue
}
s.handleKeyStatusChange(&event)
case <-s.stopChan:
return
}
}
}
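The producer side of this contract is not shown in the diff; a hedged sketch of what a publisher might look like, using the event fields the handler below consumes (the status strings and surrounding variables are illustrative):

// Whoever flips a key's status publishes a KeyStatusChangedEvent on the
// same topic this listener subscribes to.
payload, _ := json.Marshal(models.KeyStatusChangedEvent{
	GroupID:      groupID,
	KeyID:        keyID,
	OldStatus:    "active",
	NewStatus:    "error",
	ChangeReason: "error_threshold_reached",
})
_ = st.Publish(ctx, models.TopicKeyStatusChanged, payload) // st is a store.Store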
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
if event.GroupID == 0 {
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
return
}
ctx := context.Background()
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
switch event.ChangeReason {
case "key_unlinked", "key_hard_deleted":
if event.OldStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", -1); err != nil {
s.logger.WithError(err).Errorf("Failed to decrement total_keys for group %d", event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
} else {
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "key_linked":
if event.NewStatus != "" {
s.store.HIncrBy(statsKey, "total_keys", 1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", 1); err != nil {
s.logger.WithError(err).Errorf("Failed to increment total_keys for group %d", event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
} else {
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
return
}
default:
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
s.RecalculateGroupKeyStats(event.GroupID)
s.RecalculateGroupKeyStats(ctx, event.GroupID)
}
}
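For reference, the per-group counters maintained above all live in a single Redis hash; after a few events a group's stats key looks roughly like this (illustrative values):

// HGETALL stats:group:42
//   total_keys    -> 12
//   active_keys   -> 9
//   error_keys    -> 2
//   disabled_keys -> 1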
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
var results []struct {
Status models.APIKeyStatus
Count int64
}
if err := s.db.Model(&models.GroupAPIKeyMapping{}).
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
Where("key_group_id = ?", groupID).
Select("status, COUNT(*) as count").
Group("status").
@@ -111,45 +158,36 @@ func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
}
statsKey := fmt.Sprintf("stats:group:%d", groupID)
updates := make(map[string]interface{})
totalKeys := int64(0)
updates := map[string]interface{}{
"active_keys": int64(0),
"disabled_keys": int64(0),
"error_keys": int64(0),
"total_keys": int64(0),
}
for _, res := range results {
updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
totalKeys += res.Count
updates["total_keys"] = updates["total_keys"].(int64) + res.Count
}
updates["total_keys"] = totalKeys
if err := s.store.Del(statsKey); err != nil {
if err := s.store.Del(ctx, statsKey); err != nil {
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
}
if err := s.store.HSet(statsKey, updates); err != nil {
if err := s.store.HSet(ctx, statsKey, updates); err != nil {
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
}
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
return nil
}
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
// TODO logic:
// 1. Fetch every group's key statistics from Redis (HGetAll)
// 2. Fetch the last 24 hours of request counts and error rates from the stats_hourly table
// 3. Combine them into a DashboardStatsResponse
// ... the concrete implementation can be completed in DashboardQueryService;
// here we only make sure StatsService's core duty, cache maintenance, is covered.
// Return an empty object for now so the code compiles.
// Pseudocode:
// keyCounts, _ := s.store.HGetAll("stats:global:keys")
// ...
func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
return &models.DashboardStatsResponse{}, nil
}
func (s *StatsService) AggregateHourlyStats() error {
func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
s.logger.Info("Starting aggregation of the last hour's request data...")
now := time.Now()
endTime := now.Truncate(time.Hour) // e.g. 15:23 -> 15:00
startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00
endTime := now.Truncate(time.Hour)
startTime := endTime.Add(-1 * time.Hour)
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
type aggregationResult struct {
@@ -161,7 +199,8 @@ func (s *StatsService) AggregateHourlyStats() error {
CompletionTokens int64
}
var results []aggregationResult
err := s.db.Model(&models.RequestLog{}).
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
Where("request_time >= ? AND request_time < ?", startTime, endTime).
Group("group_id, model_name").
@@ -179,7 +218,7 @@ func (s *StatsService) AggregateHourlyStats() error {
var hourlyStats []models.StatsHourly
for _, res := range results {
hourlyStats = append(hourlyStats, models.StatsHourly{
Time: startTime, // every record is stamped with the start of its hour
Time: startTime,
GroupID: res.GroupID,
ModelName: res.ModelName,
RequestCount: res.RequestCount,
@@ -189,8 +228,18 @@ func (s *StatsService) AggregateHourlyStats() error {
})
}
return s.db.Clauses(clause.OnConflict{
if err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
}).Create(&hourlyStats).Error
}).Create(&hourlyStats).Error; err != nil {
return err
}
if err := s.db.WithContext(ctx).
Where("request_time >= ? AND request_time < ?", startTime, endTime).
Delete(&models.RequestLog{}).Error; err != nil {
s.logger.WithError(err).Warn("Failed to delete aggregated request logs")
}
return nil
}
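The diff does not show who invokes AggregateHourlyStats; a typical driver would be an hourly ticker along these lines (the scheduling code is an assumption, not part of this change):

// Run the aggregation once per hour with a bounded context.
go func() {
	ticker := time.NewTicker(time.Hour)
	defer ticker.Stop()
	for range ticker.C {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
		if err := statsService.AggregateHourlyStats(ctx); err != nil {
			logger.WithError(err).Error("Hourly aggregation failed")
		}
		cancel()
	}
}()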

View File

@@ -37,7 +37,7 @@ func NewTokenManager(repo repository.AuthTokenRepository, store store.Store, log
return tokens, nil
}
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged)
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged, logger)
if err != nil {
return nil, fmt.Errorf("failed to create token manager syncer: %w", err)
}

View File

@@ -1,4 +1,4 @@
// file: gemini-balancer\internal\settings\settings.go
// Filename: gemini-balancer/internal/settings/settings.go (final audited fix)
package settings
import (
@@ -19,7 +19,9 @@ import (
const SettingsUpdateChannel = "system_settings:updated"
const DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
// SettingsManager [core fix] the syncer now caches the correct "blueprint" type
var _ models.SettingsManager = (*SettingsManager)(nil)
// SettingsManager manages the system's dynamic settings: loading from the database, cache syncing, and updates.
type SettingsManager struct {
db *gorm.DB
syncer *syncer.CacheSyncer[*models.SystemSettings]
@@ -27,13 +29,14 @@ type SettingsManager struct {
jsonToFieldType map[string]reflect.Type // maps JSON field names to Go types
}
// NewSettingsManager creates a new SettingsManager instance.
func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (*SettingsManager, error) {
sm := &SettingsManager{
db: db,
logger: logger.WithField("component", "SettingsManager⚙"),
jsonToFieldType: make(map[string]reflect.Type),
}
// settingsLoader's job: read the "bricks", assemble and return the "blueprint"
settingsType := reflect.TypeOf(models.SystemSettings{})
for i := 0; i < settingsType.NumField(); i++ {
field := settingsType.Field(i)
@@ -42,102 +45,89 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (
sm.jsonToFieldType[jsonTag] = field.Type
}
}
// settingsLoader's job: read the "bricks" and intelligently assemble the "blueprint"
settingsLoader := func() (*models.SystemSettings, error) {
sm.logger.Info("Loading system settings from database...")
var dbRecords []models.Setting
if err := sm.db.Find(&dbRecords).Error; err != nil {
return nil, fmt.Errorf("failed to load system settings from db: %w", err)
}
settingsMap := make(map[string]string)
for _, record := range dbRecords {
settingsMap[record.Key] = record.Value
}
// Start from a "blueprint" pre-filled with all factory defaults
settings := defaultSystemSettings()
v := reflect.ValueOf(settings).Elem()
t := v.Type()
// [Smart unloading]
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
for i := 0; i < v.NumField(); i++ {
field := v.Type().Field(i)
fieldValue := v.Field(i)
jsonTag := field.Tag.Get("json")
if dbValue, ok := settingsMap[jsonTag]; ok {
if dbValue, ok := settingsMap[jsonTag]; ok {
if err := parseAndSetField(fieldValue, dbValue); err != nil {
sm.logger.Warnf("Failed to set config field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
sm.logger.Warnf("Failed to set field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
}
}
}
if settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "" {
if settings.DefaultUpstreamURL != "" {
// If the global upstream URL is set, derive the new check endpoint from it.
originalEndpoint := settings.BaseKeyCheckEndpoint
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
settings.BaseKeyCheckEndpoint = derivedEndpoint
sm.logger.Infof(
"BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL. Original: '%s', New: '%s'",
originalEndpoint, derivedEndpoint,
)
}
} else {
// [Review note] The derived logic matches the original version exactly in behavior and logging.
if (settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "") && settings.DefaultUpstreamURL != "" {
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
sm.logger.Infof("BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL: %s", derivedEndpoint)
settings.BaseKeyCheckEndpoint = derivedEndpoint
} else if settings.BaseKeyCheckEndpoint != DefaultGeminiEndpoint && settings.BaseKeyCheckEndpoint != "" {
// Keep the else log so users know a custom override is in use.
sm.logger.Infof("BaseKeyCheckEndpoint is using a user-defined override: %s", settings.BaseKeyCheckEndpoint)
}
sm.logger.Info("System settings loaded and cached.")
sm.DisplaySettings(settings)
return settings, nil
}
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel, logger)
if err != nil {
return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
}
sm.syncer = s
go sm.ensureSettingsInitialized()
if err := sm.ensureSettingsInitialized(); err != nil {
return nil, fmt.Errorf("failed to ensure system settings are initialized: %w", err)
}
return sm, nil
}
// GetSettings [core fix] now correctly returns the "blueprint" we need
// GetSettings returns the currently cached system settings.
func (sm *SettingsManager) GetSettings() *models.SystemSettings {
return sm.syncer.Get()
}
// UpdateSettings [core fix] receives updates and converts them into "bricks" stored in the database
// UpdateSettings updates one or more system settings.
func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) error {
var settingsToUpdate []models.Setting
for key, value := range settingsMap {
fieldType, ok := sm.jsonToFieldType[key]
if !ok {
sm.logger.Warnf("Received update for unknown setting key '%s', ignoring.", key)
continue
}
var dbValue string
// [Smart packing]
// If the field is a slice or map, marshal the incoming interface{} into a JSON string
kind := fieldType.Kind()
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, marshalErr := json.Marshal(value)
if marshalErr != nil {
// [Real error handling] If packing fails, log it and skip this broken "container".
sm.logger.Warnf("Failed to marshal setting '%s' to JSON: %v, skipping update.", key, marshalErr)
continue // skip it and move on to the next key
}
dbValue = string(jsonBytes)
} else if kind == reflect.Bool {
if b, ok := value.(bool); ok {
dbValue = strconv.FormatBool(b)
} else {
dbValue = "false"
}
} else {
dbValue = fmt.Sprintf("%v", value)
dbValue, err := sm.convertToDBValue(key, value, fieldType)
if err != nil {
sm.logger.Warnf("Failed to convert value for setting '%s': %v. Skipping update.", key, err)
continue
}
settingsToUpdate = append(settingsToUpdate, models.Setting{
Key: key,
Value: dbValue,
})
}
if len(settingsToUpdate) > 0 {
err := sm.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
@@ -147,83 +137,20 @@ func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) er
return fmt.Errorf("failed to update settings in db: %w", err)
}
}
return sm.syncer.Invalidate()
}
// ensureSettingsInitialized [core fix] ensures the DB holds a definition for every "brick"
func (sm *SettingsManager) ensureSettingsInitialized() {
defaults := defaultSystemSettings()
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
var existing models.Setting
if err := sm.db.Where("key = ?", key).First(&existing).Error; err == gorm.ErrRecordNotFound {
var defaultValue string
kind := fieldValue.Kind()
// [Smart initialization]
if kind == reflect.Slice || kind == reflect.Map {
// For complex types, generate an "empty" JSON string such as "[]" or "{}"
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
setting := models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"), // the default in the metadata always comes from the tag
}
if err := sm.db.Create(&setting).Error; err != nil {
sm.logger.Errorf("Failed to initialize setting '%s': %v", key, err)
}
}
if err := sm.syncer.Invalidate(); err != nil {
sm.logger.Errorf("CRITICAL: Database settings updated, but cache invalidation failed: %v", err)
return fmt.Errorf("settings updated but cache invalidation failed, system may be inconsistent: %w", err)
}
return nil
}
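Callers hand UpdateSettings a map keyed by JSON tags; unknown keys are logged and dropped, and slices/maps are marshalled to JSON before hitting the database. A usage sketch (the setting keys here are illustrative guesses at the JSON tags, not confirmed by this diff):

err := settingsManager.UpdateSettings(map[string]interface{}{
	"polling_strategy":     "round_robin",             // scalar: stored via fmt.Sprintf
	"default_upstream_url": "https://example.invalid", // string stored as-is
	"allowed_origins":      []string{"a", "b"},        // slice: marshalled to JSON
})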
// ResetAndSaveSettings [new] resets every setting to the value defined in its 'default' tag.
// ResetAndSaveSettings resets all settings to their default values.
func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error) {
defaults := defaultSystemSettings()
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
var settingsToSave []models.Setting
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
settingsToSave := sm.buildSettingsFromDefaults(defaults)
var defaultValue string
kind := fieldValue.Kind()
// [Smart reset]
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
setting := models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"),
}
settingsToSave = append(settingsToSave, setting)
}
if len(settingsToSave) > 0 {
err := sm.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
@@ -233,8 +160,99 @@ func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error
return nil, fmt.Errorf("failed to reset settings in db: %w", err)
}
}
if err := sm.syncer.Invalidate(); err != nil {
sm.logger.Errorf("Failed to invalidate settings cache after reset: %v", err)
sm.logger.Errorf("CRITICAL: Database settings reset, but cache invalidation failed: %v", err)
return nil, fmt.Errorf("settings reset but cache invalidation failed: %w", err)
}
return defaults, nil
}
// --- Private helpers ---
func (sm *SettingsManager) ensureSettingsInitialized() error {
defaults := defaultSystemSettings()
settingsToCreate := sm.buildSettingsFromDefaults(defaults)
for _, setting := range settingsToCreate {
var existing models.Setting
err := sm.db.Where("key = ?", setting.Key).First(&existing).Error
if err == gorm.ErrRecordNotFound {
sm.logger.Infof("Initializing new setting '%s'", setting.Key)
if createErr := sm.db.Create(&setting).Error; createErr != nil {
return fmt.Errorf("failed to create initial setting '%s': %w", setting.Key, createErr)
}
} else if err != nil {
return fmt.Errorf("failed to check for existing setting '%s': %w", setting.Key, err)
}
}
return nil
}
func (sm *SettingsManager) buildSettingsFromDefaults(defaults *models.SystemSettings) []models.Setting {
v := reflect.ValueOf(defaults).Elem()
t := v.Type()
var settings []models.Setting
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
key := field.Tag.Get("json")
if key == "" || key == "-" {
continue
}
var defaultValue string
kind := fieldValue.Kind()
if kind == reflect.Slice || kind == reflect.Map {
jsonBytes, _ := json.Marshal(fieldValue.Interface())
defaultValue = string(jsonBytes)
} else {
defaultValue = field.Tag.Get("default")
}
settings = append(settings, models.Setting{
Key: key,
Value: defaultValue,
Name: field.Tag.Get("name"),
Description: field.Tag.Get("desc"),
Category: field.Tag.Get("category"),
DefaultValue: field.Tag.Get("default"),
})
}
return settings
}
// [Fix] Use the blank identifier `_` to silence the "unused parameter" warning.
func (sm *SettingsManager) convertToDBValue(_ string, value interface{}, fieldType reflect.Type) (string, error) {
kind := fieldType.Kind()
switch kind {
case reflect.Slice, reflect.Map:
jsonBytes, err := json.Marshal(value)
if err != nil {
return "", fmt.Errorf("failed to marshal to JSON: %w", err)
}
return string(jsonBytes), nil
case reflect.Bool:
b, ok := value.(bool)
if !ok {
return "", fmt.Errorf("expected bool, but got %T", value)
}
return strconv.FormatBool(b), nil
default:
return fmt.Sprintf("%v", value), nil
}
}
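Concretely, the conversion rules above produce (illustrative inputs):

// []string{"a", "b"}     -> `["a","b"]`   (JSON)
// map[string]int{"x": 1} -> `{"x":1}`     (JSON)
// true                   -> "true"        (strconv.FormatBool)
// 42, 3.5, "text"        -> "42", "3.5", "text" (fmt.Sprintf)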
// IsValidKey checks whether the given JSON key is a valid settings field
func (sm *SettingsManager) IsValidKey(key string) (reflect.Type, bool) {
fieldType, ok := sm.jsonToFieldType[key]
return fieldType, ok
}

View File

@@ -1,3 +1,4 @@
// Filename: internal/store/factory.go
package store
import (
@@ -11,7 +12,6 @@ import (
// NewStore creates a new store based on the application configuration.
func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
// Check whether Redis is configured
if cfg.Redis.DSN != "" {
opts, err := redis.ParseURL(cfg.Redis.DSN)
if err != nil {
@@ -20,10 +20,10 @@ func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
client := redis.NewClient(opts)
if err := client.Ping(context.Background()).Err(); err != nil {
logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err)
return NewMemoryStore(logger), nil // on connection failure, fall back to the in-memory store without returning an error
return NewMemoryStore(logger), nil
}
logger.Info("Successfully connected to Redis. Using Redis as store.")
return NewRedisStore(client), nil
return NewRedisStore(client, logger), nil
}
logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.")
return NewMemoryStore(logger), nil

View File

@@ -1,17 +1,20 @@
// Filename: internal/store/memory_store.go (final version after peer review)
// Filename: internal/store/memory_store.go
package store
import (
"context"
"fmt"
"math/rand"
"sort"
"strconv"
"sync"
"time"
"github.com/sirupsen/logrus"
)
// ensure memoryStore implements Store interface
var _ Store = (*memoryStore)(nil)
type memoryStoreItem struct {
@@ -32,7 +35,6 @@ type memoryStore struct {
items map[string]*memoryStoreItem
pubsub map[string][]chan *Message
mu sync.RWMutex
// [USER SUGGESTION APPLIED] Use a mutex-guarded random source for concurrency safety
rng *rand.Rand
rngMu sync.Mutex
logger *logrus.Entry
@@ -42,7 +44,6 @@ func NewMemoryStore(logger *logrus.Logger) Store {
store := &memoryStore{
items: make(map[string]*memoryStoreItem),
pubsub: make(map[string][]chan *Message),
// Seed a fresh random source with the current time
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
logger: logger.WithField("component", "store.memory 🗱"),
}
@@ -50,13 +51,12 @@ func NewMemoryStore(logger *logrus.Logger) Store {
return store
}
// [USER SUGGESTION INCORPORATED] Fix #1: use now := time.Now() for an atomic expiry check
func (s *memoryStore) startGCollector() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
s.mu.Lock()
now := time.Now() // avoid calling it repeatedly inside the loop
now := time.Now()
for key, item := range s.items {
if !item.expireAt.IsZero() && now.After(item.expireAt) {
delete(s.items, key)
@@ -66,92 +66,10 @@ func (s *memoryStore) startGCollector() {
}
}
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: fixed the fatal nil-check and type-assertion issues
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
// --- Every method signature gains a context.Context parameter to match the interface ---
// --- The in-memory implementation may ignore it, receiving it as _ ---
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
// Perform the type assertion safely
mainSet, mainOk = mainItem.value.(map[string]struct{})
// Make sure the assertion succeeded and the set is non-empty
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
// Perform the type assertion safely
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
// Handle the cooldown pool safely
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
// SRandMember [concurrency fix] uses the mutex-guarded rng
func (s *memoryStore) SRandMember(key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
// --- The remaining functions below are final versions that all follow the same safe, atomic locking strategy ---
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
var expireAt time.Time
@@ -162,7 +80,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
return nil
}
func (s *memoryStore) Get(key string) ([]byte, error) {
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -175,7 +93,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) {
return nil, ErrNotFound
}
func (s *memoryStore) Del(keys ...string) error {
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
for _, key := range keys {
@@ -184,14 +102,25 @@ func (s *memoryStore) Del(keys ...string) error {
return nil
}
func (s *memoryStore) Exists(key string) (bool, error) {
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
return ok && !item.isExpired(), nil
}
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
if !ok {
return ErrNotFound
}
item.expireAt = time.Now().Add(expiration)
return nil
}
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -208,7 +137,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool,
func (s *memoryStore) Close() error { return nil }
func (s *memoryStore) HDel(key string, fields ...string) error {
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -223,7 +152,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
return nil
}
func (s *memoryStore) HSet(key string, values map[string]any) error {
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -242,7 +171,22 @@ func (s *memoryStore) HSet(key string, values map[string]any) error {
return nil
}
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
if hash, ok := item.value.(map[string]string); ok {
if value, exists := hash[field]; exists {
return value, nil
}
}
return "", ErrNotFound
}
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -259,7 +203,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
return make(map[string]string), nil
}
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -281,7 +225,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
return newVal, nil
}
func (s *memoryStore) LPush(key string, values ...any) error {
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -301,7 +245,7 @@ func (s *memoryStore) LPush(key string, values ...any) error {
return nil
}
func (s *memoryStore) LRem(key string, count int64, value any) error {
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -326,7 +270,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error {
return nil
}
func (s *memoryStore) SAdd(key string, members ...any) error {
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -345,7 +289,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error {
return nil
}
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -375,7 +319,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
return popped, nil
}
func (s *memoryStore) SMembers(key string) ([]string, error) {
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -393,7 +337,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) {
return members, nil
}
func (s *memoryStore) SRem(key string, members ...any) error {
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -410,7 +354,51 @@ func (s *memoryStore) SRem(key string, members ...any) error {
return nil
}
func (s *memoryStore) Rotate(key string) (string, error) {
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
if !ok || item.isExpired() {
return "", ErrNotFound
}
set, ok := item.value.(map[string]struct{})
if !ok || len(set) == 0 {
return "", ErrNotFound
}
members := make([]string, 0, len(set))
for member := range set {
members = append(members, member)
}
if len(members) == 0 {
return "", ErrNotFound
}
s.rngMu.Lock()
n := s.rng.Intn(len(members))
s.rngMu.Unlock()
return members[n], nil
}
func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
unionSet := make(map[string]struct{})
for _, key := range keys {
item, ok := s.items[key]
if !ok || item.isExpired() {
continue
}
if set, ok := item.value.(map[string]struct{}); ok {
for member := range set {
unionSet[member] = struct{}{}
}
}
}
destItem := &memoryStoreItem{value: unionSet}
s.items[destination] = destItem
return int64(len(unionSet)), nil
}
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -426,7 +414,7 @@ func (s *memoryStore) Rotate(key string) (string, error) {
return val, nil
}
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -447,8 +435,17 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
return list[index], nil
}
// Zset methods... (ZAdd, ZRange, ZRem)
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
func (s *memoryStore) MSet(ctx context.Context, values map[string]any) error {
s.mu.Lock()
defer s.mu.Unlock()
for key, value := range values {
// The in-memory store does not support per-key TTLs here, so entries are assumed never to expire
s.items[key] = &memoryStoreItem{value: value, expireAt: time.Time{}}
}
return nil
}
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -471,8 +468,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
for val, score := range membersMap {
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
}
// NOTE: This ZSet implementation is simple but not performant for large sets.
// A production implementation would use a skip list or a balanced tree.
sort.Slice(newZSet, func(i, j int) bool {
if newZSet[i].Score == newZSet[j].Score {
return newZSet[i].Value < newZSet[j].Value
@@ -482,7 +477,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
item.value = newZSet
return nil
}
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
item, ok := s.items[key]
@@ -515,7 +510,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
}
return result, nil
}
func (s *memoryStore) ZRem(key string, members ...any) error {
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[key]
@@ -540,13 +535,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
return nil
}
// Pipeline implementation
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
mainItem, mainOk := s.items[mainKey]
var mainSet map[string]struct{}
if mainOk && !mainItem.isExpired() {
mainSet, mainOk = mainItem.value.(map[string]struct{})
mainOk = mainOk && len(mainSet) > 0
} else {
mainOk = false
}
if !mainOk {
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
return "", ErrNotFound
}
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
if !cooldownSetOk || len(cooldownSet) == 0 {
return "", ErrNotFound
}
s.items[mainKey] = cooldownItem
delete(s.items, cooldownKey)
mainSet = cooldownSet
}
var popped string
for k := range mainSet {
popped = k
break
}
delete(mainSet, popped)
cooldownItem, cooldownOk := s.items[cooldownKey]
if !cooldownOk || cooldownItem.isExpired() {
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
s.items[cooldownKey] = cooldownItem
}
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
if !ok {
cooldownSet = make(map[string]struct{})
cooldownItem.value = cooldownSet
}
cooldownSet[popped] = struct{}{}
return popped, nil
}
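The two-set rotation is easiest to see end to end; a hedged usage sketch (the key names are illustrative, and st is assumed to be a store.Store):

// Each call pops one member from the main set and parks it in the cooldown
// set; once the main set runs dry, the cooldown set is swapped back in
// wholesale, so members cycle fairly.
member, err := st.PopAndCycleSetMember(ctx, "keys:group:1:available", "keys:group:1:cooldown")
if errors.Is(err, store.ErrNotFound) {
	// both pools are empty
}
_ = member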
type memoryPipeliner struct {
store *memoryStore
ops []func()
}
func (s *memoryStore) Pipeline() Pipeliner {
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
return &memoryPipeliner{store: s}
}
func (p *memoryPipeliner) Exec() error {
@@ -559,7 +597,6 @@ func (p *memoryPipeliner) Exec() error {
}
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
capturedKey := key
p.ops = append(p.ops, func() {
if item, ok := p.store.items[capturedKey]; ok {
@@ -576,6 +613,22 @@ func (p *memoryPipeliner) Del(keys ...string) {
}
})
}
func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) {
capturedKey := key
capturedValue := value
p.ops = append(p.ops, func() {
var expireAt time.Time
if expiration > 0 {
expireAt = time.Now().Add(expiration)
}
p.store.items[capturedKey] = &memoryStoreItem{
value: capturedValue,
expireAt: expireAt,
}
})
}
func (p *memoryPipeliner) SAdd(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
@@ -615,7 +668,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) {
}
})
}
func (p *memoryPipeliner) LPush(key string, values ...any) {
capturedKey := key
capturedValues := make([]any, len(values))
@@ -637,11 +689,150 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
item.value = append(stringValues, list...)
})
}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
capturedKey := key
capturedValue := fmt.Sprintf("%v", value)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
list, ok := item.value.([]string)
if !ok {
return
}
newList := make([]string, 0, len(list))
removed := int64(0)
for _, v := range list {
shouldRemove := v == capturedValue && (count == 0 || removed < count)
if shouldRemove {
removed++
} else {
newList = append(newList, v)
}
}
item.value = newList
})
}
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
capturedKey := key
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
for field, value := range capturedValues {
hash[field] = fmt.Sprintf("%v", value)
}
})
}
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
capturedKey := key
capturedField := field
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make(map[string]string)}
p.store.items[capturedKey] = item
}
hash, ok := item.value.(map[string]string)
if !ok {
hash = make(map[string]string)
item.value = hash
}
current, _ := strconv.ParseInt(hash[capturedField], 10, 64)
hash[capturedField] = strconv.FormatInt(current+incr, 10)
})
}
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
capturedKey := key
capturedMembers := make(map[string]float64, len(members))
for k, v := range members {
capturedMembers[k] = v
}
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
item = &memoryStoreItem{value: make([]zsetMember, 0)}
p.store.items[capturedKey] = item
}
zset, ok := item.value.([]zsetMember)
if !ok {
zset = make([]zsetMember, 0)
}
membersMap := make(map[string]float64, len(zset))
for _, z := range zset {
membersMap[z.Value] = z.Score
}
for memberVal, score := range capturedMembers {
membersMap[memberVal] = score
}
newZSet := make([]zsetMember, 0, len(membersMap))
for val, score := range membersMap {
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
}
sort.Slice(newZSet, func(i, j int) bool {
if newZSet[i].Score == newZSet[j].Score {
return newZSet[i].Value < newZSet[j].Value
}
return newZSet[i].Score < newZSet[j].Score
})
item.value = newZSet
})
}
func (p *memoryPipeliner) ZRem(key string, members ...any) {
capturedKey := key
capturedMembers := make([]any, len(members))
copy(capturedMembers, members)
p.ops = append(p.ops, func() {
item, ok := p.store.items[capturedKey]
if !ok || item.isExpired() {
return
}
zset, ok := item.value.([]zsetMember)
if !ok {
return
}
membersToRemove := make(map[string]struct{}, len(capturedMembers))
for _, m := range capturedMembers {
membersToRemove[fmt.Sprintf("%v", m)] = struct{}{}
}
newZSet := make([]zsetMember, 0, len(zset))
for _, z := range zset {
if _, exists := membersToRemove[z.Value]; !exists {
newZSet = append(newZSet, z)
}
}
item.value = newZSet
})
}
func (p *memoryPipeliner) MSet(values map[string]any) {
capturedValues := make(map[string]any, len(values))
for k, v := range values {
capturedValues[k] = v
}
p.ops = append(p.ops, func() {
for key, value := range capturedValues {
p.store.items[key] = &memoryStoreItem{
value: value,
expireAt: time.Time{}, // pipelined MSet likewise assumes no expiry
}
}
})
}
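Taken together: the memory pipeliner queues closures and replays them when Exec is called, mirroring a Redis pipeline. Usage follows the usual batch pattern (a sketch with illustrative keys):

p := st.Pipeline(ctx)
p.HIncrBy("stats:group:1", "total_keys", 1)
p.SAdd("keys:group:1:available", "key-123")
p.Expire("keys:group:1:available", time.Hour)
if err := p.Exec(); err != nil {
	// the in-memory Exec is expected to always succeed; Redis may return an error
}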
// --- Pub/Sub implementation (remains unchanged) ---
type memorySubscription struct {
store *memoryStore
channelName string
@@ -649,10 +840,11 @@ type memorySubscription struct {
}
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
func (ms *memorySubscription) ChannelName() string { return ms.channelName }
func (ms *memorySubscription) Close() error {
return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
}
func (s *memoryStore) Publish(channel string, message []byte) error {
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
s.mu.RLock()
defer s.mu.RUnlock()
subscribers, ok := s.pubsub[channel]
@@ -669,7 +861,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error {
}
return nil
}
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
s.mu.Lock()
defer s.mu.Unlock()
msgChan := make(chan *Message, 10)

View File

@@ -1,3 +1,5 @@
// Filename: internal/store/redis_store.go
package store
import (
@@ -8,22 +10,20 @@ import (
"time"
"github.com/redis/go-redis/v9"
"github.com/sirupsen/logrus"
)
// ensure RedisStore implements Store interface
var _ Store = (*RedisStore)(nil)
// RedisStore is a Redis-backed key-value store.
type RedisStore struct {
client *redis.Client
popAndCycleScript *redis.Script
logger *logrus.Entry
}
// NewRedisStore creates a new RedisStore instance.
func NewRedisStore(client *redis.Client) Store {
// Lua script for atomic pop-and-cycle operation.
// KEYS[1]: main set key
// KEYS[2]: cooldown set key
func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store {
const script = `
if redis.call('SCARD', KEYS[1]) == 0 then
if redis.call('SCARD', KEYS[2]) == 0 then
@@ -36,15 +36,16 @@ func NewRedisStore(client *redis.Client) Store {
return &RedisStore{
client: client,
popAndCycleScript: redis.NewScript(script),
logger: logger.WithField("component", "store.redis 🗄️"),
}
}
func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error {
return s.client.Set(context.Background(), key, value, ttl).Err()
func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return s.client.Set(ctx, key, value, ttl).Err()
}
func (s *RedisStore) Get(key string) ([]byte, error) {
val, err := s.client.Get(context.Background(), key).Bytes()
func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) {
val, err := s.client.Get(ctx, key).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrNotFound
@@ -54,53 +55,67 @@ func (s *RedisStore) Get(key string) ([]byte, error) {
return val, nil
}
func (s *RedisStore) Del(keys ...string) error {
func (s *RedisStore) Del(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
return s.client.Del(context.Background(), keys...).Err()
return s.client.Del(ctx, keys...).Err()
}
func (s *RedisStore) Exists(key string) (bool, error) {
val, err := s.client.Exists(context.Background(), key).Result()
func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) {
val, err := s.client.Exists(ctx, key).Result()
return val > 0, err
}
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
return s.client.SetNX(context.Background(), key, value, ttl).Result()
func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
return s.client.SetNX(ctx, key, value, ttl).Result()
}
func (s *RedisStore) Close() error {
return s.client.Close()
}
func (s *RedisStore) HSet(key string, values map[string]any) error {
return s.client.HSet(context.Background(), key, values).Err()
func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error {
return s.client.Expire(ctx, key, expiration).Err()
}
func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
return s.client.HGetAll(context.Background(), key).Result()
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
return s.client.HSet(ctx, key, values).Err()
}
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
return s.client.HIncrBy(context.Background(), key, field, incr).Result()
func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) {
val, err := s.client.HGet(ctx, key, field).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
return "", err
}
return val, nil
}
func (s *RedisStore) HDel(key string, fields ...string) error {
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
return s.client.HGetAll(ctx, key).Result()
}
func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) {
return s.client.HIncrBy(ctx, key, field, incr).Result()
}
func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error {
if len(fields) == 0 {
return nil
}
return s.client.HDel(context.Background(), key, fields...).Err()
return s.client.HDel(ctx, key, fields...).Err()
}
func (s *RedisStore) LPush(key string, values ...any) error {
return s.client.LPush(context.Background(), key, values...).Err()
func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error {
return s.client.LPush(ctx, key, values...).Err()
}
func (s *RedisStore) LRem(key string, count int64, value any) error {
return s.client.LRem(context.Background(), key, count, value).Err()
func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error {
return s.client.LRem(ctx, key, count, value).Err()
}
func (s *RedisStore) Rotate(key string) (string, error) {
val, err := s.client.RPopLPush(context.Background(), key, key).Result()
func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
val, err := s.client.RPopLPush(ctx, key, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
@@ -110,29 +125,40 @@ func (s *RedisStore) Rotate(key string) (string, error) {
return val, nil
}
func (s *RedisStore) SAdd(key string, members ...any) error {
return s.client.SAdd(context.Background(), key, members...).Err()
func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error {
if len(values) == 0 {
return nil
}
// The Redis MSET command expects a flat slice of [key1, value1, key2, value2, ...]
pairs := make([]interface{}, 0, len(values)*2)
for k, v := range values {
pairs = append(pairs, k, v)
}
return s.client.MSet(ctx, pairs...).Err()
}
func (s *RedisStore) SPopN(key string, count int64) ([]string, error) {
return s.client.SPopN(context.Background(), key, count).Result()
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
return s.client.SAdd(ctx, key, members...).Err()
}
func (s *RedisStore) SMembers(key string) ([]string, error) {
return s.client.SMembers(context.Background(), key).Result()
func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) {
return s.client.SPopN(ctx, key, count).Result()
}
func (s *RedisStore) SRem(key string, members ...any) error {
func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) {
return s.client.SMembers(ctx, key).Result()
}
func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 {
return nil
}
return s.client.SRem(context.Background(), key, members...).Err()
return s.client.SRem(ctx, key, members...).Err()
}
func (s *RedisStore) SRandMember(key string) (string, error) {
member, err := s.client.SRandMember(context.Background(), key).Result()
func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) {
member, err := s.client.SRandMember(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
@@ -141,81 +167,50 @@ func (s *RedisStore) SRandMember(key string) (string, error) {
return member, nil
}
// === New method implementations ===
func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) {
if len(keys) == 0 {
return 0, nil
}
return s.client.SUnionStore(ctx, destination, keys...).Result()
}
func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
if len(members) == 0 {
return nil
}
redisMembers := make([]redis.Z, 0, len(members))
redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members {
redisMembers = append(redisMembers, redis.Z{Score: score, Member: member})
redisMembers[i] = redis.Z{Score: score, Member: member}
i++
}
return s.client.ZAdd(context.Background(), key, redisMembers...).Err()
return s.client.ZAdd(ctx, key, redisMembers...).Err()
}
func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) {
return s.client.ZRange(context.Background(), key, start, stop).Result()
func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
return s.client.ZRange(ctx, key, start, stop).Result()
}
func (s *RedisStore) ZRem(key string, members ...any) error {
func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error {
if len(members) == 0 {
return nil
}
return s.client.ZRem(context.Background(), key, members...).Err()
return s.client.ZRem(ctx, key, members...).Err()
}
func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result()
func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) {
val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
}
return "", err
}
// Lua script returns a string, so we need to type assert
if str, ok := val.(string); ok {
return str, nil
}
return "", ErrNotFound // This happens if both sets were empty and the script returned nil
return "", ErrNotFound
}
type redisPipeliner struct{ pipe redis.Pipeliner }
func (p *redisPipeliner) HSet(key string, values map[string]any) {
p.pipe.HSet(context.Background(), key, values)
}
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(context.Background(), key, field, incr)
}
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(context.Background())
return err
}
func (p *redisPipeliner) Del(keys ...string) {
if len(keys) > 0 {
p.pipe.Del(context.Background(), keys...)
}
}
func (p *redisPipeliner) SAdd(key string, members ...any) {
p.pipe.SAdd(context.Background(), key, members...)
}
func (p *redisPipeliner) SRem(key string, members ...any) {
if len(members) > 0 {
p.pipe.SRem(context.Background(), key, members...)
}
}
func (p *redisPipeliner) LPush(key string, values ...any) {
p.pipe.LPush(context.Background(), key, values...)
}
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(context.Background(), key, count, value)
}
func (s *RedisStore) LIndex(key string, index int64) (string, error) {
val, err := s.client.LIndex(context.Background(), key, index).Result()
func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) {
val, err := s.client.LIndex(ctx, key, index).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", ErrNotFound
@@ -225,47 +220,131 @@ func (s *RedisStore) LIndex(key string, index int64) (string, error) {
return val, nil
}
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
p.pipe.Expire(context.Background(), key, expiration)
type redisPipeliner struct {
pipe redis.Pipeliner
ctx context.Context
}
func (s *RedisStore) Pipeline() Pipeliner {
return &redisPipeliner{pipe: s.client.Pipeline()}
func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner {
return &redisPipeliner{
pipe: s.client.Pipeline(),
ctx: ctx,
}
}
func (p *redisPipeliner) Exec() error {
_, err := p.pipe.Exec(p.ctx)
return err
}
func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) }
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
p.pipe.Expire(p.ctx, key, expiration)
}
func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) }
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
p.pipe.HIncrBy(p.ctx, key, field, incr)
}
func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) }
func (p *redisPipeliner) LRem(key string, count int64, value any) {
p.pipe.LRem(p.ctx, key, count, value)
}
func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) {
p.pipe.Set(p.ctx, key, value, expiration)
}
func (p *redisPipeliner) MSet(values map[string]any) {
if len(values) == 0 {
return
}
p.pipe.MSet(p.ctx, values)
}
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {
if len(members) == 0 {
return
}
redisMembers := make([]redis.Z, len(members))
i := 0
for member, score := range members {
redisMembers[i] = redis.Z{Score: score, Member: member}
i++
}
p.pipe.ZAdd(p.ctx, key, redisMembers...)
}
func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) }
type redisSubscription struct {
pubsub *redis.PubSub
msgChan chan *Message
once sync.Once
pubsub *redis.PubSub
msgChan chan *Message
logger *logrus.Entry
wg sync.WaitGroup
close context.CancelFunc
channelName string
}
func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) {
pubsub := s.client.Subscribe(ctx, channel)
_, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
}
subCtx, cancel := context.WithCancel(context.Background())
sub := &redisSubscription{
pubsub: pubsub,
msgChan: make(chan *Message, 10),
logger: s.logger,
close: cancel,
channelName: channel,
}
sub.wg.Add(1)
go sub.bridge(subCtx)
return sub, nil
}
func (rs *redisSubscription) bridge(ctx context.Context) {
defer rs.wg.Done()
defer close(rs.msgChan)
redisCh := rs.pubsub.Channel()
for {
select {
case <-ctx.Done():
return
case redisMsg, ok := <-redisCh:
if !ok {
return
}
msg := &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
select {
case rs.msgChan <- msg:
default:
rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName)
}
}
}
}
func (rs *redisSubscription) Channel() <-chan *Message {
rs.once.Do(func() {
rs.msgChan = make(chan *Message)
go func() {
defer close(rs.msgChan)
for redisMsg := range rs.pubsub.Channel() {
rs.msgChan <- &Message{
Channel: redisMsg.Channel,
Payload: []byte(redisMsg.Payload),
}
}
}()
})
return rs.msgChan
}
func (rs *redisSubscription) Close() error { return rs.pubsub.Close() }
func (s *RedisStore) Publish(channel string, message []byte) error {
return s.client.Publish(context.Background(), channel, message).Err()
func (rs *redisSubscription) ChannelName() string {
return rs.channelName
}
func (s *RedisStore) Subscribe(channel string) (Subscription, error) {
pubsub := s.client.Subscribe(context.Background(), channel)
_, err := pubsub.Receive(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
}
return &redisSubscription{pubsub: pubsub}, nil
func (rs *redisSubscription) Close() error {
rs.close()
err := rs.pubsub.Close()
rs.wg.Wait()
return err
}
func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error {
return s.client.Publish(ctx, channel, message).Err()
}
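A minimal caller sketch for the refactored, context-carrying pipeliner (the function, key names, and timeout are illustrative, not part of this changeset): the context is captured once in Pipeline(ctx) and reused by every queued command and by Exec.

// Hypothetical caller; assumes an initialized store.Store `st`.
func warmStats(ctx context.Context, st store.Store) error {
	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()
	pipe := st.Pipeline(ctx) // ctx captured once, reused by every queued command
	pipe.HSet("stats:g1", map[string]any{"hits": 1})
	pipe.HIncrBy("stats:g1", "hits", 1)
	pipe.Expire("stats:g1", 10*time.Minute)
	return pipe.Exec() // flushes the whole batch in a single round trip
}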

View File

@@ -1,6 +1,9 @@
// Filename: internal/store/store.go
package store
import (
"context"
"errors"
"time"
)
@@ -17,6 +20,7 @@ type Message struct {
// Subscription represents an active subscription to a pub/sub channel.
type Subscription interface {
Channel() <-chan *Message
ChannelName() string
Close() error
}
@@ -31,6 +35,8 @@ type Pipeliner interface {
HIncrBy(key, field string, incr int64)
// SET
MSet(values map[string]any)
Set(key string, value []byte, expiration time.Duration)
SAdd(key string, members ...any)
SRem(key string, members ...any)
@@ -38,6 +44,10 @@ type Pipeliner interface {
LPush(key string, values ...any)
LRem(key string, count int64, value any)
// ZSET
ZAdd(key string, members map[string]float64)
ZRem(key string, members ...any)
// Execution
Exec() error
}
@@ -45,44 +55,48 @@ type Pipeliner interface {
// Store is the master interface for our cache service.
type Store interface {
// Basic K/V operations
Set(key string, value []byte, ttl time.Duration) error
Get(key string) ([]byte, error)
Del(keys ...string) error
Exists(key string) (bool, error)
SetNX(key string, value []byte, ttl time.Duration) (bool, error)
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
Get(ctx context.Context, key string) ([]byte, error)
Del(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
MSet(ctx context.Context, values map[string]any) error
// HASH operations
HSet(key string, values map[string]any) error
HGetAll(key string) (map[string]string, error)
HIncrBy(key, field string, incr int64) (int64, error)
HDel(key string, fields ...string) error // [new]
HSet(ctx context.Context, key string, values map[string]any) error
HGet(ctx context.Context, key, field string) (string, error)
HGetAll(ctx context.Context, key string) (map[string]string, error)
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
HDel(ctx context.Context, key string, fields ...string) error
// LIST operations
LPush(key string, values ...any) error
LRem(key string, count int64, value any) error
Rotate(key string) (string, error)
LIndex(key string, index int64) (string, error)
LPush(ctx context.Context, key string, values ...any) error
LRem(ctx context.Context, key string, count int64, value any) error
Rotate(ctx context.Context, key string) (string, error)
LIndex(ctx context.Context, key string, index int64) (string, error)
Expire(ctx context.Context, key string, expiration time.Duration) error
// SET operations
SAdd(key string, members ...any) error
SPopN(key string, count int64) ([]string, error)
SMembers(key string) ([]string, error)
SRem(key string, members ...any) error
SRandMember(key string) (string, error)
SAdd(ctx context.Context, key string, members ...any) error
SPopN(ctx context.Context, key string, count int64) ([]string, error)
SMembers(ctx context.Context, key string) ([]string, error)
SRem(ctx context.Context, key string, members ...any) error
SRandMember(ctx context.Context, key string) (string, error)
SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error)
// Pub/Sub operations
Publish(channel string, message []byte) error
Subscribe(channel string) (Subscription, error)
Publish(ctx context.Context, channel string, message []byte) error
Subscribe(ctx context.Context, channel string) (Subscription, error)
// Pipeline (optional) - implemented for Redis; the in-memory version does not implement it yet
Pipeline() Pipeliner
// Pipeline
Pipeline(ctx context.Context) Pipeliner
// Close closes the store and releases any underlying resources.
Close() error
// === New methods supporting the rotation strategy ===
ZAdd(key string, members map[string]float64) error
ZRange(key string, start, stop int64) ([]string, error)
ZRem(key string, members ...any) error
PopAndCycleSetMember(mainKey, cooldownKey string) (string, error)
// ZSET operations
ZAdd(ctx context.Context, key string, members map[string]float64) error
ZRange(ctx context.Context, key string, start, stop int64) ([]string, error)
ZRem(ctx context.Context, key string, members ...any) error
PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error)
}
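For orientation, a hedged sketch of how a caller threads a request-scoped context through the new interface and handles ErrNotFound (the function and key are illustrative assumptions):

// Illustrative caller; assumes `st store.Store` and a request-scoped ctx.
func loadOrInit(ctx context.Context, st store.Store, key string) ([]byte, error) {
	val, err := st.Get(ctx, key)
	if errors.Is(err, store.ErrNotFound) {
		// Cache miss: seed a default value with a one-hour TTL.
		val = []byte("{}")
		if setErr := st.Set(ctx, key, val, time.Hour); setErr != nil {
			return nil, setErr
		}
		return val, nil
	}
	return val, err
}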

View File

@@ -1,48 +1,57 @@
package syncer
import (
"context"
"fmt"
"gemini-balancer/internal/store"
"log"
"sync"
"time"
"github.com/sirupsen/logrus"
)
const (
ReconnectDelay = 5 * time.Second
ReloadTimeout = 30 * time.Second
)
// LoaderFunc
type LoaderFunc[T any] func() (T, error)
// CacheSyncer
type CacheSyncer[T any] struct {
mu sync.RWMutex
cache T
loader LoaderFunc[T]
store store.Store
channelName string
logger *logrus.Entry
stopChan chan struct{}
wg sync.WaitGroup
}
// NewCacheSyncer
func NewCacheSyncer[T any](
loader LoaderFunc[T],
store store.Store,
channelName string,
logger *logrus.Logger,
) (*CacheSyncer[T], error) {
s := &CacheSyncer[T]{
loader: loader,
store: store,
channelName: channelName,
logger: logger.WithField("component", fmt.Sprintf("CacheSyncer[%s]", channelName)),
stopChan: make(chan struct{}),
}
if err := s.reload(); err != nil {
return nil, fmt.Errorf("initial load for %s failed: %w", channelName, err)
return nil, fmt.Errorf("initial load failed: %w", err)
}
s.wg.Add(1)
go s.listenForUpdates()
return s, nil
}
// Get, Invalidate, Stop, reload methods.
func (s *CacheSyncer[T]) Get() T {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -50,33 +59,60 @@ func (s *CacheSyncer[T]) Get() T {
}
func (s *CacheSyncer[T]) Invalidate() error {
log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName)
return s.store.Publish(s.channelName, []byte("reload"))
s.logger.Info("Publishing invalidation notification")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.store.Publish(ctx, s.channelName, []byte("reload")); err != nil {
s.logger.WithError(err).Error("Failed to publish invalidation")
return err
}
return nil
}
func (s *CacheSyncer[T]) Stop() {
close(s.stopChan)
s.wg.Wait()
log.Printf("INFO: CacheSyncer for channel '%s' stopped.", s.channelName)
s.logger.Info("CacheSyncer stopped")
}
func (s *CacheSyncer[T]) reload() error {
log.Printf("INFO: Reloading cache for channel '%s'...", s.channelName)
newData, err := s.loader()
if err != nil {
log.Printf("ERROR: Failed to reload cache for '%s': %v", s.channelName, err)
return err
s.logger.Info("Reloading cache...")
ctx, cancel := context.WithTimeout(context.Background(), ReloadTimeout)
defer cancel()
type result struct {
data T
err error
}
resultChan := make(chan result, 1)
go func() {
data, err := s.loader()
resultChan <- result{data, err}
}()
select {
case res := <-resultChan:
if res.err != nil {
s.logger.WithError(res.err).Error("Failed to reload cache")
return res.err
}
s.mu.Lock()
s.cache = res.data
s.mu.Unlock()
s.logger.Info("Cache reloaded successfully")
return nil
case <-ctx.Done():
s.logger.Error("Cache reload timeout")
return fmt.Errorf("reload timeout after %v", ReloadTimeout)
}
s.mu.Lock()
s.cache = newData
s.mu.Unlock()
log.Printf("INFO: Cache for channel '%s' reloaded successfully.", s.channelName)
return nil
}
// listenForUpdates ...
func (s *CacheSyncer[T]) listenForUpdates() {
defer s.wg.Done()
for {
select {
case <-s.stopChan:
@@ -84,31 +120,39 @@ func (s *CacheSyncer[T]) listenForUpdates() {
default:
}
subscription, err := s.store.Subscribe(s.channelName)
if err != nil {
log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
time.Sleep(5 * time.Second)
continue
}
log.Printf("INFO: Subscribed to channel '%s' for cache invalidation.", s.channelName)
subscriberLoop:
for {
if err := s.subscribeAndListen(); err != nil {
s.logger.WithError(err).Warnf("Subscription error, retrying in %v", ReconnectDelay)
select {
case _, ok := <-subscription.Channel():
if !ok {
log.Printf("WARN: Subscription channel '%s' closed, will re-subscribe.", s.channelName)
break subscriberLoop
}
log.Printf("INFO: Received invalidation notification on '%s', reloading cache.", s.channelName)
if err := s.reload(); err != nil {
log.Printf("ERROR: Failed to reload cache for '%s' after notification: %v", s.channelName, err)
}
case <-time.After(ReconnectDelay):
case <-s.stopChan:
subscription.Close()
return
}
}
subscription.Close()
}
}
func (s *CacheSyncer[T]) subscribeAndListen() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
subscription, err := s.store.Subscribe(ctx, s.channelName)
if err != nil {
return fmt.Errorf("failed to subscribe: %w", err)
}
defer subscription.Close()
s.logger.Info("Subscribed to channel")
for {
select {
case msg, ok := <-subscription.Channel():
if !ok {
return fmt.Errorf("subscription channel closed")
}
s.logger.WithField("message", string(msg.Payload)).Info("Received invalidation notification")
if err := s.reload(); err != nil {
s.logger.WithError(err).Error("Failed to reload after notification")
}
case <-s.stopChan:
return nil
}
}
}
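A usage sketch for the reworked syncer (the cached type, channel name, and loader body are assumptions): Get serves lock-protected snapshot reads, while Invalidate fans a reload out to every subscribed instance.

// Illustrative wiring; not part of this changeset.
loader := func() (map[string]string, error) {
	// e.g., rebuild the settings map from the database
	return map[string]string{"strategy": "round_robin"}, nil
}
cs, err := syncer.NewCacheSyncer(loader, st, "settings:invalidate", logrusLogger)
if err != nil {
	log.Fatalf("initial cache load failed: %v", err)
}
defer cs.Stop()
_ = cs.Get()        // read the current snapshot under RLock
_ = cs.Invalidate() // publish "reload"; every instance re-runs the loader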

View File

@@ -1,7 +1,7 @@
// Filename: internal/task/task.go (final calibrated version)
package task
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -12,18 +12,18 @@ import (
)
const (
ResultTTL = 60 * time.Minute
ResultTTL = 60 * time.Minute
DefaultTimeout = 24 * time.Hour
LockTTL = 30 * time.Minute
)
// Reporter defines how domain code interacts with the task service.
type Reporter interface {
StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
EndTaskByID(taskID, resourceID string, result any, taskErr error)
UpdateProgressByID(taskID string, processed int) error
UpdateTotalByID(taskID string, total int) error
StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
UpdateProgressByID(ctx context.Context, taskID string, processed int) error
UpdateTotalByID(ctx context.Context, taskID string, total int) error
}
// Status represents the complete state of a background task
type Status struct {
ID string `json:"id"`
TaskType string `json:"task_type"`
@@ -38,13 +38,11 @@ type Status struct {
DurationSeconds float64 `json:"duration_seconds,omitempty"`
}
// Task is the core task-management service
type Task struct {
store store.Store
logger *logrus.Entry
}
// NewTask constructs a Task
func NewTask(store store.Store, logger *logrus.Logger) *Task {
return &Task{
store: store,
@@ -62,21 +60,27 @@ func (s *Task) getTaskDataKey(taskID string) string {
return fmt.Sprintf("task:data:%s", taskID)
}
// --- New helper for building the key of the atomic running flag ---
func (s *Task) getIsRunningFlagKey(taskID string) string {
return fmt.Sprintf("task:running:%s", taskID)
}
func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
lockKey := s.getResourceLockKey(resourceID)
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 {
locked, err := s.store.SetNX(ctx, lockKey, []byte(taskID), LockTTL)
if err != nil {
return nil, fmt.Errorf("failed to acquire task lock: %w", err)
}
if !locked {
existingTaskID, _ := s.store.Get(ctx, lockKey)
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
}
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
taskKey := s.getTaskDataKey(taskID)
runningFlagKey := s.getIsRunningFlagKey(taskID)
if timeout == 0 {
timeout = DefaultTimeout
}
status := &Status{
ID: taskID,
TaskType: taskType,
@@ -85,81 +89,71 @@ func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int
Total: total,
StartedAt: time.Now(),
}
statusBytes, err := json.Marshal(status)
if err != nil {
return nil, fmt.Errorf("failed to serialize new task status: %w", err)
if err := s.saveStatus(ctx, taskID, status, timeout); err != nil {
_ = s.store.Del(ctx, lockKey)
return nil, fmt.Errorf("failed to save task status: %w", err)
}
if timeout == 0 {
timeout = ResultTTL * 24
runningFlagKey := s.getIsRunningFlagKey(taskID)
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
_ = s.store.Del(ctx, lockKey)
_ = s.store.Del(ctx, s.getTaskDataKey(taskID))
return nil, fmt.Errorf("failed to set running flag: %w", err)
}
if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil {
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
}
if err := s.store.Set(taskKey, statusBytes, timeout); err != nil {
_ = s.store.Del(lockKey)
return nil, fmt.Errorf("failed to set new task data in store: %w", err)
}
// Create a standalone "running" flag whose existence can be checked atomically
if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil {
_ = s.store.Del(lockKey)
_ = s.store.Del(taskKey)
return nil, fmt.Errorf("failed to set task running flag: %w", err)
}
return status, nil
}
func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) {
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
lockKey := s.getResourceLockKey(resourceID)
defer func() {
if err := s.store.Del(lockKey); err != nil {
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
}
}()
runningFlagKey := s.getIsRunningFlagKey(taskID)
_ = s.store.Del(runningFlagKey)
status, err := s.GetStatus(taskID)
if err != nil {
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
defer func() {
_ = s.store.Del(ctx, lockKey)
_ = s.store.Del(ctx, runningFlagKey)
}()
status, err := s.GetStatus(ctx, taskID)
if err != nil {
s.logger.WithError(err).Errorf("Failed to get task status for %s during EndTask", taskID)
return
}
if !status.IsRunning {
s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID)
s.logger.Warnf("EndTaskByID called for already finished task: %s", taskID)
return
}
now := time.Now()
status.IsRunning = false
status.FinishedAt = &now
status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
if taskErr != nil {
status.Error = taskErr.Error()
} else {
status.Result = resultData
}
updatedTaskBytes, _ := json.Marshal(status)
taskKey := s.getTaskDataKey(taskID)
if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil {
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
if err := s.saveStatus(ctx, taskID, status, ResultTTL); err != nil {
s.logger.WithError(err).Errorf("Failed to save final status for task %s", taskID)
}
}
// GetStatus fetches a task's status by ID, for external callers such as API handlers
func (s *Task) GetStatus(taskID string) (*Status, error) {
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
taskKey := s.getTaskDataKey(taskID)
statusBytes, err := s.store.Get(taskKey)
statusBytes, err := s.store.Get(ctx, taskKey)
if err != nil {
if errors.Is(err, store.ErrNotFound) {
return nil, errors.New("task not found")
}
return nil, fmt.Errorf("failed to get task status from store: %w", err)
return nil, fmt.Errorf("failed to get task status: %w", err)
}
var status Status
if err := json.Unmarshal(statusBytes, &status); err != nil {
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
return nil, fmt.Errorf("corrupted task data for ID %s: %w", taskID, err)
}
if !status.IsRunning && status.FinishedAt != nil {
@@ -169,46 +163,51 @@ func (s *Task) GetStatus(taskID string) (*Status, error) {
return &status, nil
}
// UpdateProgressByID updates a task's progress by ID
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
runningFlagKey := s.getIsRunningFlagKey(taskID)
if _, err := s.store.Get(runningFlagKey); err != nil {
// The task has finished; returning silently is the expected behavior.
return nil
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
if errors.Is(err, store.ErrNotFound) {
return errors.New("task is not running")
}
return fmt.Errorf("failed to check running flag: %w", err)
}
status, err := s.GetStatus(taskID)
status, err := s.GetStatus(ctx, taskID)
if err != nil {
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
return nil
return fmt.Errorf("failed to get task status: %w", err)
}
if !status.IsRunning {
return nil
return errors.New("task is not running")
}
// Apply the supplied updater function to mutate the status
updater(status)
statusBytes, marshalErr := json.Marshal(status)
if marshalErr != nil {
s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID)
return nil
}
taskKey := s.getTaskDataKey(taskID)
// Use a longer TTL so a running task does not expire prematurely
if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil {
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
}
return nil
return s.saveStatus(ctx, taskID, status, DefaultTimeout)
}
// [REFACTORED] UpdateProgressByID is now a thin wrapper around the generic updater.
func (s *Task) UpdateProgressByID(taskID string, processed int) error {
return s.updateTask(taskID, func(status *Status) {
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
return s.updateTask(ctx, taskID, func(status *Status) {
status.Processed = processed
})
}
// [REFACTORED] UpdateTotalByID is likewise a thin wrapper around the generic updater.
func (s *Task) UpdateTotalByID(taskID string, total int) error {
return s.updateTask(taskID, func(status *Status) {
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
return s.updateTask(ctx, taskID, func(status *Status) {
status.Total = total
})
}
func (s *Task) saveStatus(ctx context.Context, taskID string, status *Status, ttl time.Duration) error {
statusBytes, err := json.Marshal(status)
if err != nil {
return fmt.Errorf("failed to serialize status: %w", err)
}
taskKey := s.getTaskDataKey(taskID)
if err := s.store.Set(ctx, taskKey, statusBytes, ttl); err != nil {
return fmt.Errorf("failed to save status: %w", err)
}
return nil
}
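End-to-end, the refactored lock flow looks roughly like this from a worker's perspective (the group ID, task type, resource ID, and counts are made up for illustration):

// Hypothetical worker; assumes `tasks := task.NewTask(st, logrusLogger)`.
func runValidation(ctx context.Context, tasks *task.Task) error {
	status, err := tasks.StartTask(ctx, 42, "validate_keys", "group:42", 100, 0)
	if err != nil {
		return err // SetNX lost the race: a task already owns this resource
	}
	for i := 1; i <= 100; i++ {
		// ... validate one key ...
		_ = tasks.UpdateProgressByID(ctx, status.ID, i)
	}
	// Releases the resource lock and running flag, persists the result for ResultTTL.
	tasks.EndTaskByID(ctx, status.ID, "group:42", map[string]int{"checked": 100}, nil)
	return nil
}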

View File

@@ -1,43 +1,118 @@
// Filename: internal/webhandlers/auth_handler.go (final modernized version)
// Filename: internal/webhandlers/auth_handler.go
package webhandlers
import (
"gemini-balancer/internal/middleware"
"gemini-balancer/internal/service" // [核心改造] 依赖service层
"gemini-balancer/internal/service"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// WebAuthHandler [core refactor] dependency cleanup: SecurityService is injected
// WebAuthHandler is the web authentication handler
type WebAuthHandler struct {
securityService *service.SecurityService
logger *logrus.Logger
}
// NewWebAuthHandler [core refactor] updated constructor
// NewWebAuthHandler creates a WebAuthHandler
func NewWebAuthHandler(securityService *service.SecurityService) *WebAuthHandler {
logger := logrus.New()
logger.SetLevel(logrus.InfoLevel)
return &WebAuthHandler{
securityService: securityService,
logger: logger,
}
}
// ShowLoginPage unchanged
// ShowLoginPage renders the login page
func (h *WebAuthHandler) ShowLoginPage(c *gin.Context) {
errMsg := c.Query("error")
from := c.Query("from") // 可以从登录失败的页面返回
// 验证重定向路径(防止开放重定向攻击)
redirectPath := h.validateRedirectPath(c.Query("redirect"))
// If already logged in, redirect immediately
if cookie := middleware.ExtractTokenFromCookie(c); cookie != "" {
if _, err := h.securityService.AuthenticateToken(cookie); err == nil {
c.Redirect(http.StatusFound, redirectPath)
return
}
}
c.HTML(http.StatusOK, "auth.html", gin.H{
"error": errMsg,
"from": from,
"error": errMsg,
"redirect": redirectPath,
})
}
// HandleLogin [core refactor] authentication fully delegated to SecurityService
// HandleLogin is deprecated (the project has no username system)
func (h *WebAuthHandler) HandleLogin(c *gin.Context) {
c.Redirect(http.StatusFound, "/login?error=DEPRECATED_LOGIN_METHOD")
}
// HandleLogout unchanged
// HandleLogout handles the logout request
func (h *WebAuthHandler) HandleLogout(c *gin.Context) {
cookie := middleware.ExtractTokenFromCookie(c)
if cookie != "" {
// Try to resolve token info for logging
authToken, err := h.securityService.AuthenticateToken(cookie)
if err == nil {
h.logger.WithFields(logrus.Fields{
"token_id": authToken.ID,
"client_ip": c.ClientIP(),
}).Info("User logged out")
} else {
h.logger.WithField("client_ip", c.ClientIP()).Warn("Logout with invalid token")
}
// Invalidate the token cache
middleware.InvalidateTokenCache(cookie)
} else {
h.logger.WithField("client_ip", c.ClientIP()).Debug("Logout without session cookie")
}
// Clear the cookie
middleware.ClearAdminSessionCookie(c)
// Redirect to the login page
c.Redirect(http.StatusFound, "/login")
}
// validateRedirectPath validates the redirect path (prevents open-redirect attacks)
func (h *WebAuthHandler) validateRedirectPath(path string) string {
defaultPath := "/dashboard"
if path == "" {
return defaultPath
}
// Only allow internal paths
if !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") {
h.logger.WithField("path", path).Warn("Invalid redirect path blocked")
return defaultPath
}
// Whitelist check
allowedPaths := []string{
"/dashboard",
"/keys",
"/settings",
"/logs",
"/tasks",
"/chat",
}
for _, allowed := range allowedPaths {
if strings.HasPrefix(path, allowed) {
return path
}
}
return defaultPath
}
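A hedged test sketch for the open-redirect guard (cases and expectations are illustrative; it assumes the standard testing package and that passing a nil SecurityService is acceptable here, since validateRedirectPath only touches the logger):

func TestValidateRedirectPath(t *testing.T) {
	h := NewWebAuthHandler(nil)
	cases := map[string]string{
		"":                 "/dashboard", // empty falls back to the default
		"//evil.com":       "/dashboard", // protocol-relative URLs are blocked
		"https://evil.com": "/dashboard", // absolute URLs lack the leading "/"
		"/keys?page=2":     "/keys?page=2",
		"/unknown":         "/dashboard", // not on the whitelist
	}
	for in, want := range cases {
		if got := h.validateRedirectPath(in); got != want {
			t.Errorf("validateRedirectPath(%q) = %q, want %q", in, got, want)
		}
	}
}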

View File

@@ -8,6 +8,12 @@ module.exports = {
'./web/templates/**/*.html',
'./web/static/js/**/*.js',
],
safelist: [
'grid-rows-[1]',
{
pattern: /data-\[(expanded|active)\]/,
},
],
theme: {
extend: {
// Define semantic colors

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,30 @@
var __create = Object.create;
var __defProp = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __getProtoOf = Object.getPrototypeOf;
var __hasOwnProp = Object.prototype.hasOwnProperty;
var __commonJS = (cb, mod) => function __require() {
return mod || (0, cb[__getOwnPropNames(cb)[0]])((mod = { exports: {} }).exports, mod), mod.exports;
};
var __copyProps = (to, from, except, desc) => {
if (from && typeof from === "object" || typeof from === "function") {
for (let key of __getOwnPropNames(from))
if (!__hasOwnProp.call(to, key) && key !== except)
__defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
}
return to;
};
var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__getProtoOf(mod)) : {}, __copyProps(
// If the importer is in node compatibility mode or this is not an ESM
// file that has been converted to a CommonJS file using a Babel-
// compatible transform (i.e. "__esModule" has not been set), then set
// "default" to the CommonJS "module.exports" for node compatibility.
isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target,
mod
));
export {
__commonJS,
__toESM
};

View File

@@ -1,83 +0,0 @@
// frontend/js/services/api.js
var APIClientError = class extends Error {
constructor(message, status, code, rawMessageFromServer) {
super(message);
this.name = "APIClientError";
this.status = status;
this.code = code;
this.rawMessageFromServer = rawMessageFromServer;
}
};
var apiPromiseCache = /* @__PURE__ */ new Map();
async function apiFetch(url, options = {}) {
const isGetRequest = !options.method || options.method.toUpperCase() === "GET";
const cacheKey = isGetRequest && !options.noCache ? url : null;
if (cacheKey && apiPromiseCache.has(cacheKey)) {
return apiPromiseCache.get(cacheKey);
}
const token = localStorage.getItem("bearerToken");
const headers = {
"Content-Type": "application/json",
...options.headers
};
if (token) {
headers["Authorization"] = `Bearer ${token}`;
}
const requestPromise = (async () => {
try {
const response = await fetch(url, { ...options, headers });
if (response.status === 401) {
if (cacheKey) apiPromiseCache.delete(cacheKey);
localStorage.removeItem("bearerToken");
if (window.location.pathname !== "/login") {
window.location.href = "/login?error=\u4F1A\u8BDD\u5DF2\u8FC7\u671F\uFF0C\u8BF7\u91CD\u65B0\u767B\u5F55\u3002";
}
throw new APIClientError("Unauthorized", 401, "UNAUTHORIZED", "Session expired or token is invalid.");
}
if (!response.ok) {
let errorData = null;
let rawMessage = "";
try {
rawMessage = await response.text();
if (rawMessage) {
errorData = JSON.parse(rawMessage);
}
} catch (e) {
errorData = { error: { code: "UNKNOWN_FORMAT", message: rawMessage || response.statusText } };
}
const code = errorData?.error?.code || "UNKNOWN_ERROR";
const messageFromServer = errorData?.error?.message || rawMessage || "No message provided by server.";
const error = new APIClientError(
`API request failed: ${response.status}`,
response.status,
code,
messageFromServer
);
throw error;
}
return response;
} catch (error) {
if (cacheKey) apiPromiseCache.delete(cacheKey);
throw error;
}
})();
if (cacheKey) {
apiPromiseCache.set(cacheKey, requestPromise);
}
return requestPromise;
}
async function apiFetchJson(url, options = {}) {
try {
const response = await apiFetch(url, options);
const clonedResponse = response.clone();
const jsonData = await clonedResponse.json();
return jsonData;
} catch (error) {
throw error;
}
}
export {
apiFetch,
apiFetchJson
};

View File

@@ -0,0 +1,606 @@
import {
__commonJS,
__toESM
} from "./chunk-JSBRDJBE.js";
// frontend/js/vendor/popper.esm.min.js
var require_popper_esm_min = __commonJS({
"frontend/js/vendor/popper.esm.min.js"(exports, module) {
!(function(e, t) {
"object" == typeof exports && "undefined" != typeof module ? t(exports) : "function" == typeof define && define.amd ? define(["exports"], t) : t((e = "undefined" != typeof globalThis ? globalThis : e || self).Popper = {});
})(exports, (function(e) {
"use strict";
function t(e2) {
if (null == e2) return window;
if ("[object Window]" !== e2.toString()) {
var t2 = e2.ownerDocument;
return t2 && t2.defaultView || window;
}
return e2;
}
function n(e2) {
return e2 instanceof t(e2).Element || e2 instanceof Element;
}
function r(e2) {
return e2 instanceof t(e2).HTMLElement || e2 instanceof HTMLElement;
}
function o(e2) {
return "undefined" != typeof ShadowRoot && (e2 instanceof t(e2).ShadowRoot || e2 instanceof ShadowRoot);
}
var i = Math.max, a = Math.min, s = Math.round;
function f() {
var e2 = navigator.userAgentData;
return null != e2 && e2.brands && Array.isArray(e2.brands) ? e2.brands.map((function(e3) {
return e3.brand + "/" + e3.version;
})).join(" ") : navigator.userAgent;
}
function c() {
return !/^((?!chrome|android).)*safari/i.test(f());
}
function p(e2, o2, i2) {
void 0 === o2 && (o2 = false), void 0 === i2 && (i2 = false);
var a2 = e2.getBoundingClientRect(), f2 = 1, p2 = 1;
o2 && r(e2) && (f2 = e2.offsetWidth > 0 && s(a2.width) / e2.offsetWidth || 1, p2 = e2.offsetHeight > 0 && s(a2.height) / e2.offsetHeight || 1);
var u2 = (n(e2) ? t(e2) : window).visualViewport, l2 = !c() && i2, d2 = (a2.left + (l2 && u2 ? u2.offsetLeft : 0)) / f2, h2 = (a2.top + (l2 && u2 ? u2.offsetTop : 0)) / p2, m2 = a2.width / f2, v2 = a2.height / p2;
return { width: m2, height: v2, top: h2, right: d2 + m2, bottom: h2 + v2, left: d2, x: d2, y: h2 };
}
function u(e2) {
var n2 = t(e2);
return { scrollLeft: n2.pageXOffset, scrollTop: n2.pageYOffset };
}
function l(e2) {
return e2 ? (e2.nodeName || "").toLowerCase() : null;
}
function d(e2) {
return ((n(e2) ? e2.ownerDocument : e2.document) || window.document).documentElement;
}
function h(e2) {
return p(d(e2)).left + u(e2).scrollLeft;
}
function m(e2) {
return t(e2).getComputedStyle(e2);
}
function v(e2) {
var t2 = m(e2), n2 = t2.overflow, r2 = t2.overflowX, o2 = t2.overflowY;
return /auto|scroll|overlay|hidden/.test(n2 + o2 + r2);
}
function y(e2, n2, o2) {
void 0 === o2 && (o2 = false);
var i2, a2, f2 = r(n2), c2 = r(n2) && (function(e3) {
var t2 = e3.getBoundingClientRect(), n3 = s(t2.width) / e3.offsetWidth || 1, r2 = s(t2.height) / e3.offsetHeight || 1;
return 1 !== n3 || 1 !== r2;
})(n2), m2 = d(n2), y2 = p(e2, c2, o2), g2 = { scrollLeft: 0, scrollTop: 0 }, b2 = { x: 0, y: 0 };
return (f2 || !f2 && !o2) && (("body" !== l(n2) || v(m2)) && (g2 = (i2 = n2) !== t(i2) && r(i2) ? { scrollLeft: (a2 = i2).scrollLeft, scrollTop: a2.scrollTop } : u(i2)), r(n2) ? ((b2 = p(n2, true)).x += n2.clientLeft, b2.y += n2.clientTop) : m2 && (b2.x = h(m2))), { x: y2.left + g2.scrollLeft - b2.x, y: y2.top + g2.scrollTop - b2.y, width: y2.width, height: y2.height };
}
function g(e2) {
var t2 = p(e2), n2 = e2.offsetWidth, r2 = e2.offsetHeight;
return Math.abs(t2.width - n2) <= 1 && (n2 = t2.width), Math.abs(t2.height - r2) <= 1 && (r2 = t2.height), { x: e2.offsetLeft, y: e2.offsetTop, width: n2, height: r2 };
}
function b(e2) {
return "html" === l(e2) ? e2 : e2.assignedSlot || e2.parentNode || (o(e2) ? e2.host : null) || d(e2);
}
function x(e2) {
return ["html", "body", "#document"].indexOf(l(e2)) >= 0 ? e2.ownerDocument.body : r(e2) && v(e2) ? e2 : x(b(e2));
}
function w(e2, n2) {
var r2;
void 0 === n2 && (n2 = []);
var o2 = x(e2), i2 = o2 === (null == (r2 = e2.ownerDocument) ? void 0 : r2.body), a2 = t(o2), s2 = i2 ? [a2].concat(a2.visualViewport || [], v(o2) ? o2 : []) : o2, f2 = n2.concat(s2);
return i2 ? f2 : f2.concat(w(b(s2)));
}
function O(e2) {
return ["table", "td", "th"].indexOf(l(e2)) >= 0;
}
function j(e2) {
return r(e2) && "fixed" !== m(e2).position ? e2.offsetParent : null;
}
function E(e2) {
for (var n2 = t(e2), i2 = j(e2); i2 && O(i2) && "static" === m(i2).position; ) i2 = j(i2);
return i2 && ("html" === l(i2) || "body" === l(i2) && "static" === m(i2).position) ? n2 : i2 || (function(e3) {
var t2 = /firefox/i.test(f());
if (/Trident/i.test(f()) && r(e3) && "fixed" === m(e3).position) return null;
var n3 = b(e3);
for (o(n3) && (n3 = n3.host); r(n3) && ["html", "body"].indexOf(l(n3)) < 0; ) {
var i3 = m(n3);
if ("none" !== i3.transform || "none" !== i3.perspective || "paint" === i3.contain || -1 !== ["transform", "perspective"].indexOf(i3.willChange) || t2 && "filter" === i3.willChange || t2 && i3.filter && "none" !== i3.filter) return n3;
n3 = n3.parentNode;
}
return null;
})(e2) || n2;
}
var D = "top", A = "bottom", L = "right", P = "left", M = "auto", k = [D, A, L, P], W = "start", B = "end", H = "viewport", T = "popper", R = k.reduce((function(e2, t2) {
return e2.concat([t2 + "-" + W, t2 + "-" + B]);
}), []), S = [].concat(k, [M]).reduce((function(e2, t2) {
return e2.concat([t2, t2 + "-" + W, t2 + "-" + B]);
}), []), V = ["beforeRead", "read", "afterRead", "beforeMain", "main", "afterMain", "beforeWrite", "write", "afterWrite"];
function q(e2) {
var t2 = /* @__PURE__ */ new Map(), n2 = /* @__PURE__ */ new Set(), r2 = [];
function o2(e3) {
n2.add(e3.name), [].concat(e3.requires || [], e3.requiresIfExists || []).forEach((function(e4) {
if (!n2.has(e4)) {
var r3 = t2.get(e4);
r3 && o2(r3);
}
})), r2.push(e3);
}
return e2.forEach((function(e3) {
t2.set(e3.name, e3);
})), e2.forEach((function(e3) {
n2.has(e3.name) || o2(e3);
})), r2;
}
function C(e2, t2) {
var n2 = t2.getRootNode && t2.getRootNode();
if (e2.contains(t2)) return true;
if (n2 && o(n2)) {
var r2 = t2;
do {
if (r2 && e2.isSameNode(r2)) return true;
r2 = r2.parentNode || r2.host;
} while (r2);
}
return false;
}
function N(e2) {
return Object.assign({}, e2, { left: e2.x, top: e2.y, right: e2.x + e2.width, bottom: e2.y + e2.height });
}
function I(e2, r2, o2) {
return r2 === H ? N((function(e3, n2) {
var r3 = t(e3), o3 = d(e3), i2 = r3.visualViewport, a2 = o3.clientWidth, s2 = o3.clientHeight, f2 = 0, p2 = 0;
if (i2) {
a2 = i2.width, s2 = i2.height;
var u2 = c();
(u2 || !u2 && "fixed" === n2) && (f2 = i2.offsetLeft, p2 = i2.offsetTop);
}
return { width: a2, height: s2, x: f2 + h(e3), y: p2 };
})(e2, o2)) : n(r2) ? (function(e3, t2) {
var n2 = p(e3, false, "fixed" === t2);
return n2.top = n2.top + e3.clientTop, n2.left = n2.left + e3.clientLeft, n2.bottom = n2.top + e3.clientHeight, n2.right = n2.left + e3.clientWidth, n2.width = e3.clientWidth, n2.height = e3.clientHeight, n2.x = n2.left, n2.y = n2.top, n2;
})(r2, o2) : N((function(e3) {
var t2, n2 = d(e3), r3 = u(e3), o3 = null == (t2 = e3.ownerDocument) ? void 0 : t2.body, a2 = i(n2.scrollWidth, n2.clientWidth, o3 ? o3.scrollWidth : 0, o3 ? o3.clientWidth : 0), s2 = i(n2.scrollHeight, n2.clientHeight, o3 ? o3.scrollHeight : 0, o3 ? o3.clientHeight : 0), f2 = -r3.scrollLeft + h(e3), c2 = -r3.scrollTop;
return "rtl" === m(o3 || n2).direction && (f2 += i(n2.clientWidth, o3 ? o3.clientWidth : 0) - a2), { width: a2, height: s2, x: f2, y: c2 };
})(d(e2)));
}
function _(e2, t2, o2, s2) {
var f2 = "clippingParents" === t2 ? (function(e3) {
var t3 = w(b(e3)), o3 = ["absolute", "fixed"].indexOf(m(e3).position) >= 0 && r(e3) ? E(e3) : e3;
return n(o3) ? t3.filter((function(e4) {
return n(e4) && C(e4, o3) && "body" !== l(e4);
})) : [];
})(e2) : [].concat(t2), c2 = [].concat(f2, [o2]), p2 = c2[0], u2 = c2.reduce((function(t3, n2) {
var r2 = I(e2, n2, s2);
return t3.top = i(r2.top, t3.top), t3.right = a(r2.right, t3.right), t3.bottom = a(r2.bottom, t3.bottom), t3.left = i(r2.left, t3.left), t3;
}), I(e2, p2, s2));
return u2.width = u2.right - u2.left, u2.height = u2.bottom - u2.top, u2.x = u2.left, u2.y = u2.top, u2;
}
function F(e2) {
return e2.split("-")[0];
}
function U(e2) {
return e2.split("-")[1];
}
function z(e2) {
return ["top", "bottom"].indexOf(e2) >= 0 ? "x" : "y";
}
function X(e2) {
var t2, n2 = e2.reference, r2 = e2.element, o2 = e2.placement, i2 = o2 ? F(o2) : null, a2 = o2 ? U(o2) : null, s2 = n2.x + n2.width / 2 - r2.width / 2, f2 = n2.y + n2.height / 2 - r2.height / 2;
switch (i2) {
case D:
t2 = { x: s2, y: n2.y - r2.height };
break;
case A:
t2 = { x: s2, y: n2.y + n2.height };
break;
case L:
t2 = { x: n2.x + n2.width, y: f2 };
break;
case P:
t2 = { x: n2.x - r2.width, y: f2 };
break;
default:
t2 = { x: n2.x, y: n2.y };
}
var c2 = i2 ? z(i2) : null;
if (null != c2) {
var p2 = "y" === c2 ? "height" : "width";
switch (a2) {
case W:
t2[c2] = t2[c2] - (n2[p2] / 2 - r2[p2] / 2);
break;
case B:
t2[c2] = t2[c2] + (n2[p2] / 2 - r2[p2] / 2);
}
}
return t2;
}
function Y(e2) {
return Object.assign({}, { top: 0, right: 0, bottom: 0, left: 0 }, e2);
}
function G(e2, t2) {
return t2.reduce((function(t3, n2) {
return t3[n2] = e2, t3;
}), {});
}
function J(e2, t2) {
void 0 === t2 && (t2 = {});
var r2 = t2, o2 = r2.placement, i2 = void 0 === o2 ? e2.placement : o2, a2 = r2.strategy, s2 = void 0 === a2 ? e2.strategy : a2, f2 = r2.boundary, c2 = void 0 === f2 ? "clippingParents" : f2, u2 = r2.rootBoundary, l2 = void 0 === u2 ? H : u2, h2 = r2.elementContext, m2 = void 0 === h2 ? T : h2, v2 = r2.altBoundary, y2 = void 0 !== v2 && v2, g2 = r2.padding, b2 = void 0 === g2 ? 0 : g2, x2 = Y("number" != typeof b2 ? b2 : G(b2, k)), w2 = m2 === T ? "reference" : T, O2 = e2.rects.popper, j2 = e2.elements[y2 ? w2 : m2], E2 = _(n(j2) ? j2 : j2.contextElement || d(e2.elements.popper), c2, l2, s2), P2 = p(e2.elements.reference), M2 = X({ reference: P2, element: O2, strategy: "absolute", placement: i2 }), W2 = N(Object.assign({}, O2, M2)), B2 = m2 === T ? W2 : P2, R2 = { top: E2.top - B2.top + x2.top, bottom: B2.bottom - E2.bottom + x2.bottom, left: E2.left - B2.left + x2.left, right: B2.right - E2.right + x2.right }, S2 = e2.modifiersData.offset;
if (m2 === T && S2) {
var V2 = S2[i2];
Object.keys(R2).forEach((function(e3) {
var t3 = [L, A].indexOf(e3) >= 0 ? 1 : -1, n2 = [D, A].indexOf(e3) >= 0 ? "y" : "x";
R2[e3] += V2[n2] * t3;
}));
}
return R2;
}
var K = { placement: "bottom", modifiers: [], strategy: "absolute" };
function Q() {
for (var e2 = arguments.length, t2 = new Array(e2), n2 = 0; n2 < e2; n2++) t2[n2] = arguments[n2];
return !t2.some((function(e3) {
return !(e3 && "function" == typeof e3.getBoundingClientRect);
}));
}
function Z(e2) {
void 0 === e2 && (e2 = {});
var t2 = e2, r2 = t2.defaultModifiers, o2 = void 0 === r2 ? [] : r2, i2 = t2.defaultOptions, a2 = void 0 === i2 ? K : i2;
return function(e3, t3, r3) {
void 0 === r3 && (r3 = a2);
var i3, s2, f2 = { placement: "bottom", orderedModifiers: [], options: Object.assign({}, K, a2), modifiersData: {}, elements: { reference: e3, popper: t3 }, attributes: {}, styles: {} }, c2 = [], p2 = false, u2 = { state: f2, setOptions: function(r4) {
var i4 = "function" == typeof r4 ? r4(f2.options) : r4;
l2(), f2.options = Object.assign({}, a2, f2.options, i4), f2.scrollParents = { reference: n(e3) ? w(e3) : e3.contextElement ? w(e3.contextElement) : [], popper: w(t3) };
var s3, p3, d2 = (function(e4) {
var t4 = q(e4);
return V.reduce((function(e5, n2) {
return e5.concat(t4.filter((function(e6) {
return e6.phase === n2;
})));
}), []);
})((s3 = [].concat(o2, f2.options.modifiers), p3 = s3.reduce((function(e4, t4) {
var n2 = e4[t4.name];
return e4[t4.name] = n2 ? Object.assign({}, n2, t4, { options: Object.assign({}, n2.options, t4.options), data: Object.assign({}, n2.data, t4.data) }) : t4, e4;
}), {}), Object.keys(p3).map((function(e4) {
return p3[e4];
}))));
return f2.orderedModifiers = d2.filter((function(e4) {
return e4.enabled;
})), f2.orderedModifiers.forEach((function(e4) {
var t4 = e4.name, n2 = e4.options, r5 = void 0 === n2 ? {} : n2, o3 = e4.effect;
if ("function" == typeof o3) {
var i5 = o3({ state: f2, name: t4, instance: u2, options: r5 }), a3 = function() {
};
c2.push(i5 || a3);
}
})), u2.update();
}, forceUpdate: function() {
if (!p2) {
var e4 = f2.elements, t4 = e4.reference, n2 = e4.popper;
if (Q(t4, n2)) {
f2.rects = { reference: y(t4, E(n2), "fixed" === f2.options.strategy), popper: g(n2) }, f2.reset = false, f2.placement = f2.options.placement, f2.orderedModifiers.forEach((function(e5) {
return f2.modifiersData[e5.name] = Object.assign({}, e5.data);
}));
for (var r4 = 0; r4 < f2.orderedModifiers.length; r4++) if (true !== f2.reset) {
var o3 = f2.orderedModifiers[r4], i4 = o3.fn, a3 = o3.options, s3 = void 0 === a3 ? {} : a3, c3 = o3.name;
"function" == typeof i4 && (f2 = i4({ state: f2, options: s3, name: c3, instance: u2 }) || f2);
} else f2.reset = false, r4 = -1;
}
}
}, update: (i3 = function() {
return new Promise((function(e4) {
u2.forceUpdate(), e4(f2);
}));
}, function() {
return s2 || (s2 = new Promise((function(e4) {
Promise.resolve().then((function() {
s2 = void 0, e4(i3());
}));
}))), s2;
}), destroy: function() {
l2(), p2 = true;
} };
if (!Q(e3, t3)) return u2;
function l2() {
c2.forEach((function(e4) {
return e4();
})), c2 = [];
}
return u2.setOptions(r3).then((function(e4) {
!p2 && r3.onFirstUpdate && r3.onFirstUpdate(e4);
})), u2;
};
}
var $ = { passive: true };
var ee = { name: "eventListeners", enabled: true, phase: "write", fn: function() {
}, effect: function(e2) {
var n2 = e2.state, r2 = e2.instance, o2 = e2.options, i2 = o2.scroll, a2 = void 0 === i2 || i2, s2 = o2.resize, f2 = void 0 === s2 || s2, c2 = t(n2.elements.popper), p2 = [].concat(n2.scrollParents.reference, n2.scrollParents.popper);
return a2 && p2.forEach((function(e3) {
e3.addEventListener("scroll", r2.update, $);
})), f2 && c2.addEventListener("resize", r2.update, $), function() {
a2 && p2.forEach((function(e3) {
e3.removeEventListener("scroll", r2.update, $);
})), f2 && c2.removeEventListener("resize", r2.update, $);
};
}, data: {} };
var te = { name: "popperOffsets", enabled: true, phase: "read", fn: function(e2) {
var t2 = e2.state, n2 = e2.name;
t2.modifiersData[n2] = X({ reference: t2.rects.reference, element: t2.rects.popper, strategy: "absolute", placement: t2.placement });
}, data: {} }, ne = { top: "auto", right: "auto", bottom: "auto", left: "auto" };
function re(e2) {
var n2, r2 = e2.popper, o2 = e2.popperRect, i2 = e2.placement, a2 = e2.variation, f2 = e2.offsets, c2 = e2.position, p2 = e2.gpuAcceleration, u2 = e2.adaptive, l2 = e2.roundOffsets, h2 = e2.isFixed, v2 = f2.x, y2 = void 0 === v2 ? 0 : v2, g2 = f2.y, b2 = void 0 === g2 ? 0 : g2, x2 = "function" == typeof l2 ? l2({ x: y2, y: b2 }) : { x: y2, y: b2 };
y2 = x2.x, b2 = x2.y;
var w2 = f2.hasOwnProperty("x"), O2 = f2.hasOwnProperty("y"), j2 = P, M2 = D, k2 = window;
if (u2) {
var W2 = E(r2), H2 = "clientHeight", T2 = "clientWidth";
if (W2 === t(r2) && "static" !== m(W2 = d(r2)).position && "absolute" === c2 && (H2 = "scrollHeight", T2 = "scrollWidth"), W2 = W2, i2 === D || (i2 === P || i2 === L) && a2 === B) M2 = A, b2 -= (h2 && W2 === k2 && k2.visualViewport ? k2.visualViewport.height : W2[H2]) - o2.height, b2 *= p2 ? 1 : -1;
if (i2 === P || (i2 === D || i2 === A) && a2 === B) j2 = L, y2 -= (h2 && W2 === k2 && k2.visualViewport ? k2.visualViewport.width : W2[T2]) - o2.width, y2 *= p2 ? 1 : -1;
}
var R2, S2 = Object.assign({ position: c2 }, u2 && ne), V2 = true === l2 ? (function(e3, t2) {
var n3 = e3.x, r3 = e3.y, o3 = t2.devicePixelRatio || 1;
return { x: s(n3 * o3) / o3 || 0, y: s(r3 * o3) / o3 || 0 };
})({ x: y2, y: b2 }, t(r2)) : { x: y2, y: b2 };
return y2 = V2.x, b2 = V2.y, p2 ? Object.assign({}, S2, ((R2 = {})[M2] = O2 ? "0" : "", R2[j2] = w2 ? "0" : "", R2.transform = (k2.devicePixelRatio || 1) <= 1 ? "translate(" + y2 + "px, " + b2 + "px)" : "translate3d(" + y2 + "px, " + b2 + "px, 0)", R2)) : Object.assign({}, S2, ((n2 = {})[M2] = O2 ? b2 + "px" : "", n2[j2] = w2 ? y2 + "px" : "", n2.transform = "", n2));
}
var oe = { name: "computeStyles", enabled: true, phase: "beforeWrite", fn: function(e2) {
var t2 = e2.state, n2 = e2.options, r2 = n2.gpuAcceleration, o2 = void 0 === r2 || r2, i2 = n2.adaptive, a2 = void 0 === i2 || i2, s2 = n2.roundOffsets, f2 = void 0 === s2 || s2, c2 = { placement: F(t2.placement), variation: U(t2.placement), popper: t2.elements.popper, popperRect: t2.rects.popper, gpuAcceleration: o2, isFixed: "fixed" === t2.options.strategy };
null != t2.modifiersData.popperOffsets && (t2.styles.popper = Object.assign({}, t2.styles.popper, re(Object.assign({}, c2, { offsets: t2.modifiersData.popperOffsets, position: t2.options.strategy, adaptive: a2, roundOffsets: f2 })))), null != t2.modifiersData.arrow && (t2.styles.arrow = Object.assign({}, t2.styles.arrow, re(Object.assign({}, c2, { offsets: t2.modifiersData.arrow, position: "absolute", adaptive: false, roundOffsets: f2 })))), t2.attributes.popper = Object.assign({}, t2.attributes.popper, { "data-popper-placement": t2.placement });
}, data: {} };
var ie = { name: "applyStyles", enabled: true, phase: "write", fn: function(e2) {
var t2 = e2.state;
Object.keys(t2.elements).forEach((function(e3) {
var n2 = t2.styles[e3] || {}, o2 = t2.attributes[e3] || {}, i2 = t2.elements[e3];
r(i2) && l(i2) && (Object.assign(i2.style, n2), Object.keys(o2).forEach((function(e4) {
var t3 = o2[e4];
false === t3 ? i2.removeAttribute(e4) : i2.setAttribute(e4, true === t3 ? "" : t3);
})));
}));
}, effect: function(e2) {
var t2 = e2.state, n2 = { popper: { position: t2.options.strategy, left: "0", top: "0", margin: "0" }, arrow: { position: "absolute" }, reference: {} };
return Object.assign(t2.elements.popper.style, n2.popper), t2.styles = n2, t2.elements.arrow && Object.assign(t2.elements.arrow.style, n2.arrow), function() {
Object.keys(t2.elements).forEach((function(e3) {
var o2 = t2.elements[e3], i2 = t2.attributes[e3] || {}, a2 = Object.keys(t2.styles.hasOwnProperty(e3) ? t2.styles[e3] : n2[e3]).reduce((function(e4, t3) {
return e4[t3] = "", e4;
}), {});
r(o2) && l(o2) && (Object.assign(o2.style, a2), Object.keys(i2).forEach((function(e4) {
o2.removeAttribute(e4);
})));
}));
};
}, requires: ["computeStyles"] };
var ae = { name: "offset", enabled: true, phase: "main", requires: ["popperOffsets"], fn: function(e2) {
var t2 = e2.state, n2 = e2.options, r2 = e2.name, o2 = n2.offset, i2 = void 0 === o2 ? [0, 0] : o2, a2 = S.reduce((function(e3, n3) {
return e3[n3] = (function(e4, t3, n4) {
var r3 = F(e4), o3 = [P, D].indexOf(r3) >= 0 ? -1 : 1, i3 = "function" == typeof n4 ? n4(Object.assign({}, t3, { placement: e4 })) : n4, a3 = i3[0], s3 = i3[1];
return a3 = a3 || 0, s3 = (s3 || 0) * o3, [P, L].indexOf(r3) >= 0 ? { x: s3, y: a3 } : { x: a3, y: s3 };
})(n3, t2.rects, i2), e3;
}), {}), s2 = a2[t2.placement], f2 = s2.x, c2 = s2.y;
null != t2.modifiersData.popperOffsets && (t2.modifiersData.popperOffsets.x += f2, t2.modifiersData.popperOffsets.y += c2), t2.modifiersData[r2] = a2;
} }, se = { left: "right", right: "left", bottom: "top", top: "bottom" };
function fe(e2) {
return e2.replace(/left|right|bottom|top/g, (function(e3) {
return se[e3];
}));
}
var ce = { start: "end", end: "start" };
function pe(e2) {
return e2.replace(/start|end/g, (function(e3) {
return ce[e3];
}));
}
function ue(e2, t2) {
void 0 === t2 && (t2 = {});
var n2 = t2, r2 = n2.placement, o2 = n2.boundary, i2 = n2.rootBoundary, a2 = n2.padding, s2 = n2.flipVariations, f2 = n2.allowedAutoPlacements, c2 = void 0 === f2 ? S : f2, p2 = U(r2), u2 = p2 ? s2 ? R : R.filter((function(e3) {
return U(e3) === p2;
})) : k, l2 = u2.filter((function(e3) {
return c2.indexOf(e3) >= 0;
}));
0 === l2.length && (l2 = u2);
var d2 = l2.reduce((function(t3, n3) {
return t3[n3] = J(e2, { placement: n3, boundary: o2, rootBoundary: i2, padding: a2 })[F(n3)], t3;
}), {});
return Object.keys(d2).sort((function(e3, t3) {
return d2[e3] - d2[t3];
}));
}
var le = { name: "flip", enabled: true, phase: "main", fn: function(e2) {
var t2 = e2.state, n2 = e2.options, r2 = e2.name;
if (!t2.modifiersData[r2]._skip) {
for (var o2 = n2.mainAxis, i2 = void 0 === o2 || o2, a2 = n2.altAxis, s2 = void 0 === a2 || a2, f2 = n2.fallbackPlacements, c2 = n2.padding, p2 = n2.boundary, u2 = n2.rootBoundary, l2 = n2.altBoundary, d2 = n2.flipVariations, h2 = void 0 === d2 || d2, m2 = n2.allowedAutoPlacements, v2 = t2.options.placement, y2 = F(v2), g2 = f2 || (y2 === v2 || !h2 ? [fe(v2)] : (function(e3) {
if (F(e3) === M) return [];
var t3 = fe(e3);
return [pe(e3), t3, pe(t3)];
})(v2)), b2 = [v2].concat(g2).reduce((function(e3, n3) {
return e3.concat(F(n3) === M ? ue(t2, { placement: n3, boundary: p2, rootBoundary: u2, padding: c2, flipVariations: h2, allowedAutoPlacements: m2 }) : n3);
}), []), x2 = t2.rects.reference, w2 = t2.rects.popper, O2 = /* @__PURE__ */ new Map(), j2 = true, E2 = b2[0], k2 = 0; k2 < b2.length; k2++) {
var B2 = b2[k2], H2 = F(B2), T2 = U(B2) === W, R2 = [D, A].indexOf(H2) >= 0, S2 = R2 ? "width" : "height", V2 = J(t2, { placement: B2, boundary: p2, rootBoundary: u2, altBoundary: l2, padding: c2 }), q2 = R2 ? T2 ? L : P : T2 ? A : D;
x2[S2] > w2[S2] && (q2 = fe(q2));
var C2 = fe(q2), N2 = [];
if (i2 && N2.push(V2[H2] <= 0), s2 && N2.push(V2[q2] <= 0, V2[C2] <= 0), N2.every((function(e3) {
return e3;
}))) {
E2 = B2, j2 = false;
break;
}
O2.set(B2, N2);
}
if (j2) for (var I2 = function(e3) {
var t3 = b2.find((function(t4) {
var n3 = O2.get(t4);
if (n3) return n3.slice(0, e3).every((function(e4) {
return e4;
}));
}));
if (t3) return E2 = t3, "break";
}, _2 = h2 ? 3 : 1; _2 > 0; _2--) {
if ("break" === I2(_2)) break;
}
t2.placement !== E2 && (t2.modifiersData[r2]._skip = true, t2.placement = E2, t2.reset = true);
}
}, requiresIfExists: ["offset"], data: { _skip: false } };
function de(e2, t2, n2) {
return i(e2, a(t2, n2));
}
var he = { name: "preventOverflow", enabled: true, phase: "main", fn: function(e2) {
var t2 = e2.state, n2 = e2.options, r2 = e2.name, o2 = n2.mainAxis, s2 = void 0 === o2 || o2, f2 = n2.altAxis, c2 = void 0 !== f2 && f2, p2 = n2.boundary, u2 = n2.rootBoundary, l2 = n2.altBoundary, d2 = n2.padding, h2 = n2.tether, m2 = void 0 === h2 || h2, v2 = n2.tetherOffset, y2 = void 0 === v2 ? 0 : v2, b2 = J(t2, { boundary: p2, rootBoundary: u2, padding: d2, altBoundary: l2 }), x2 = F(t2.placement), w2 = U(t2.placement), O2 = !w2, j2 = z(x2), M2 = "x" === j2 ? "y" : "x", k2 = t2.modifiersData.popperOffsets, B2 = t2.rects.reference, H2 = t2.rects.popper, T2 = "function" == typeof y2 ? y2(Object.assign({}, t2.rects, { placement: t2.placement })) : y2, R2 = "number" == typeof T2 ? { mainAxis: T2, altAxis: T2 } : Object.assign({ mainAxis: 0, altAxis: 0 }, T2), S2 = t2.modifiersData.offset ? t2.modifiersData.offset[t2.placement] : null, V2 = { x: 0, y: 0 };
if (k2) {
if (s2) {
var q2, C2 = "y" === j2 ? D : P, N2 = "y" === j2 ? A : L, I2 = "y" === j2 ? "height" : "width", _2 = k2[j2], X2 = _2 + b2[C2], Y2 = _2 - b2[N2], G2 = m2 ? -H2[I2] / 2 : 0, K2 = w2 === W ? B2[I2] : H2[I2], Q2 = w2 === W ? -H2[I2] : -B2[I2], Z2 = t2.elements.arrow, $2 = m2 && Z2 ? g(Z2) : { width: 0, height: 0 }, ee2 = t2.modifiersData["arrow#persistent"] ? t2.modifiersData["arrow#persistent"].padding : { top: 0, right: 0, bottom: 0, left: 0 }, te2 = ee2[C2], ne2 = ee2[N2], re2 = de(0, B2[I2], $2[I2]), oe2 = O2 ? B2[I2] / 2 - G2 - re2 - te2 - R2.mainAxis : K2 - re2 - te2 - R2.mainAxis, ie2 = O2 ? -B2[I2] / 2 + G2 + re2 + ne2 + R2.mainAxis : Q2 + re2 + ne2 + R2.mainAxis, ae2 = t2.elements.arrow && E(t2.elements.arrow), se2 = ae2 ? "y" === j2 ? ae2.clientTop || 0 : ae2.clientLeft || 0 : 0, fe2 = null != (q2 = null == S2 ? void 0 : S2[j2]) ? q2 : 0, ce2 = _2 + ie2 - fe2, pe2 = de(m2 ? a(X2, _2 + oe2 - fe2 - se2) : X2, _2, m2 ? i(Y2, ce2) : Y2);
k2[j2] = pe2, V2[j2] = pe2 - _2;
}
if (c2) {
var ue2, le2 = "x" === j2 ? D : P, he2 = "x" === j2 ? A : L, me2 = k2[M2], ve2 = "y" === M2 ? "height" : "width", ye2 = me2 + b2[le2], ge2 = me2 - b2[he2], be2 = -1 !== [D, P].indexOf(x2), xe2 = null != (ue2 = null == S2 ? void 0 : S2[M2]) ? ue2 : 0, we2 = be2 ? ye2 : me2 - B2[ve2] - H2[ve2] - xe2 + R2.altAxis, Oe = be2 ? me2 + B2[ve2] + H2[ve2] - xe2 - R2.altAxis : ge2, je = m2 && be2 ? (function(e3, t3, n3) {
var r3 = de(e3, t3, n3);
return r3 > n3 ? n3 : r3;
})(we2, me2, Oe) : de(m2 ? we2 : ye2, me2, m2 ? Oe : ge2);
k2[M2] = je, V2[M2] = je - me2;
}
t2.modifiersData[r2] = V2;
}
}, requiresIfExists: ["offset"] };
var me = { name: "arrow", enabled: true, phase: "main", fn: function(e2) {
var t2, n2 = e2.state, r2 = e2.name, o2 = e2.options, i2 = n2.elements.arrow, a2 = n2.modifiersData.popperOffsets, s2 = F(n2.placement), f2 = z(s2), c2 = [P, L].indexOf(s2) >= 0 ? "height" : "width";
if (i2 && a2) {
var p2 = (function(e3, t3) {
return Y("number" != typeof (e3 = "function" == typeof e3 ? e3(Object.assign({}, t3.rects, { placement: t3.placement })) : e3) ? e3 : G(e3, k));
})(o2.padding, n2), u2 = g(i2), l2 = "y" === f2 ? D : P, d2 = "y" === f2 ? A : L, h2 = n2.rects.reference[c2] + n2.rects.reference[f2] - a2[f2] - n2.rects.popper[c2], m2 = a2[f2] - n2.rects.reference[f2], v2 = E(i2), y2 = v2 ? "y" === f2 ? v2.clientHeight || 0 : v2.clientWidth || 0 : 0, b2 = h2 / 2 - m2 / 2, x2 = p2[l2], w2 = y2 - u2[c2] - p2[d2], O2 = y2 / 2 - u2[c2] / 2 + b2, j2 = de(x2, O2, w2), M2 = f2;
n2.modifiersData[r2] = ((t2 = {})[M2] = j2, t2.centerOffset = j2 - O2, t2);
}
}, effect: function(e2) {
var t2 = e2.state, n2 = e2.options.element, r2 = void 0 === n2 ? "[data-popper-arrow]" : n2;
null != r2 && ("string" != typeof r2 || (r2 = t2.elements.popper.querySelector(r2))) && C(t2.elements.popper, r2) && (t2.elements.arrow = r2);
}, requires: ["popperOffsets"], requiresIfExists: ["preventOverflow"] };
function ve(e2, t2, n2) {
return void 0 === n2 && (n2 = { x: 0, y: 0 }), { top: e2.top - t2.height - n2.y, right: e2.right - t2.width + n2.x, bottom: e2.bottom - t2.height + n2.y, left: e2.left - t2.width - n2.x };
}
function ye(e2) {
return [D, L, A, P].some((function(t2) {
return e2[t2] >= 0;
}));
}
var ge = { name: "hide", enabled: true, phase: "main", requiresIfExists: ["preventOverflow"], fn: function(e2) {
var t2 = e2.state, n2 = e2.name, r2 = t2.rects.reference, o2 = t2.rects.popper, i2 = t2.modifiersData.preventOverflow, a2 = J(t2, { elementContext: "reference" }), s2 = J(t2, { altBoundary: true }), f2 = ve(a2, r2), c2 = ve(s2, o2, i2), p2 = ye(f2), u2 = ye(c2);
t2.modifiersData[n2] = { referenceClippingOffsets: f2, popperEscapeOffsets: c2, isReferenceHidden: p2, hasPopperEscaped: u2 }, t2.attributes.popper = Object.assign({}, t2.attributes.popper, { "data-popper-reference-hidden": p2, "data-popper-escaped": u2 });
} }, be = Z({ defaultModifiers: [ee, te, oe, ie] }), xe = [ee, te, oe, ie, ae, le, he, me, ge], we = Z({ defaultModifiers: xe });
e.applyStyles = ie, e.arrow = me, e.computeStyles = oe, e.createPopper = we, e.createPopperLite = be, e.defaultModifiers = xe, e.detectOverflow = J, e.eventListeners = ee, e.flip = le, e.hide = ge, e.offset = ae, e.popperGenerator = Z, e.popperOffsets = te, e.preventOverflow = he, Object.defineProperty(e, "__esModule", { value: true });
}));
}
});
// frontend/js/components/customSelectV2.js
var import_popper_esm_min = __toESM(require_popper_esm_min());
var CustomSelectV2 = class _CustomSelectV2 {
constructor(container) {
this.container = container;
this.trigger = this.container.querySelector(".custom-select-trigger");
this.nativeSelect = this.container.querySelector("select");
this.template = this.container.querySelector(".custom-select-panel-template");
if (!this.trigger || !this.nativeSelect || !this.template) {
console.warn("CustomSelectV2 cannot initialize: missing required elements.", this.container);
return;
}
this.panel = null;
this.popperInstance = null;
this.isOpen = false;
this.triggerText = this.trigger.querySelector("span");
if (typeof _CustomSelectV2.openInstance === "undefined") {
_CustomSelectV2.openInstance = null;
_CustomSelectV2.initGlobalListener();
}
this.updateTriggerText();
this.bindEvents();
}
static initGlobalListener() {
document.addEventListener("click", (event) => {
const instance = _CustomSelectV2.openInstance;
if (instance && !instance.container.contains(event.target) && (!instance.panel || !instance.panel.contains(event.target))) {
instance.close();
}
});
}
createPanel() {
const panelFragment = this.template.content.cloneNode(true);
this.panel = panelFragment.querySelector(".custom-select-panel");
document.body.appendChild(this.panel);
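// Clear placeholder markup cloned from the template; the options are rebuilt below.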
this.panel.innerHTML = "";
Array.from(this.nativeSelect.options).forEach((option) => {
const item = document.createElement("a");
item.href = "#";
item.className = "custom-select-option block w-full text-left px-3 py-1.5 text-sm text-zinc-700 hover:bg-zinc-100 dark:text-zinc-200 dark:hover:bg-zinc-700";
item.textContent = option.textContent;
item.dataset.value = option.value;
if (option.selected) {
item.classList.add("is-selected");
}
this.panel.appendChild(item);
});
this.panel.addEventListener("click", (event) => {
event.preventDefault();
const optionEl = event.target.closest(".custom-select-option");
if (optionEl) {
this.selectOption(optionEl);
}
});
}
bindEvents() {
this.trigger.addEventListener("click", (event) => {
event.stopPropagation();
if (_CustomSelectV2.openInstance && _CustomSelectV2.openInstance !== this) {
_CustomSelectV2.openInstance.close();
}
this.toggle();
});
}
selectOption(optionEl) {
const selectedValue = optionEl.dataset.value;
if (this.nativeSelect.value !== selectedValue) {
this.nativeSelect.value = selectedValue;
this.nativeSelect.dispatchEvent(new Event("change", { bubbles: true }));
}
this.updateTriggerText();
this.close();
}
updateTriggerText() {
const selectedOption = this.nativeSelect.options[this.nativeSelect.selectedIndex];
if (selectedOption) {
this.triggerText.textContent = selectedOption.textContent;
}
}
toggle() {
this.isOpen ? this.close() : this.open();
}
open() {
if (this.isOpen) return;
this.isOpen = true;
if (!this.panel) {
this.createPanel();
}
this.panel.style.display = "block";
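// Reading offsetHeight forces a reflow so Popper measures the panel at its rendered size.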
this.panel.offsetHeight;
this.popperInstance = (0, import_popper_esm_min.createPopper)(this.trigger, this.panel, {
placement: "top-start",
modifiers: [
{ name: "offset", options: { offset: [0, 8] } },
{ name: "flip", options: { fallbackPlacements: ["bottom-start"] } }
]
});
_CustomSelectV2.openInstance = this;
}
close() {
if (!this.isOpen) return;
this.isOpen = false;
if (this.popperInstance) {
this.popperInstance.destroy();
this.popperInstance = null;
}
if (this.panel) {
this.panel.remove();
this.panel = null;
}
if (_CustomSelectV2.openInstance === this) {
_CustomSelectV2.openInstance = null;
}
}
};
export {
require_popper_esm_min,
CustomSelectV2
};
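
A minimal usage sketch for the component above. The container selector is an assumption — the class only requires that each container hold a `.custom-select-trigger` button, a hidden native `<select>`, and a `.custom-select-panel-template` `<template>`:

// Hypothetical bootstrap; adjust the selector to the real page markup.
document.querySelectorAll("[data-custom-select]").forEach((el) => {
  new CustomSelectV2(el);
});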

View File

@@ -111,281 +111,6 @@ var CustomSelect = class _CustomSelect {
}
};
// frontend/js/components/ui.js
var ModalManager = class {
/**
* Shows a generic modal by its ID.
* @param {string} modalId The ID of the modal element to show.
*/
show(modalId) {
const modal = document.getElementById(modalId);
if (modal) {
modal.classList.remove("hidden");
} else {
console.error(`Modal with ID "${modalId}" not found.`);
}
}
/**
* Hides a generic modal by its ID.
* @param {string} modalId The ID of the modal element to hide.
*/
hide(modalId) {
const modal = document.getElementById(modalId);
if (modal) {
modal.classList.add("hidden");
} else {
console.error(`Modal with ID "${modalId}" not found.`);
}
}
/**
* Shows a confirmation dialog. This is a versatile method for 'Are you sure?' style prompts.
* It dynamically sets the title, message, and confirm action for a generic confirmation modal.
* @param {object} options - The options for the confirmation modal.
* @param {string} options.modalId - The ID of the confirmation modal element (e.g., 'resetModal', 'deleteConfirmModal').
* @param {string} options.title - The title to display in the modal header.
* @param {string} options.message - The message to display in the modal body. Can contain HTML.
* @param {function} options.onConfirm - The callback function to execute when the confirm button is clicked.
* @param {boolean} [options.disableConfirm=false] - Whether the confirm button should be initially disabled.
*/
showConfirm({ modalId, title, message, onConfirm, disableConfirm = false }) {
const modalElement = document.getElementById(modalId);
if (!modalElement) {
console.error(`Confirmation modal with ID "${modalId}" not found.`);
return;
}
const titleElement = modalElement.querySelector('[id$="ModalTitle"]');
const messageElement = modalElement.querySelector('[id$="ModalMessage"]');
const confirmButton = modalElement.querySelector('[id^="confirm"]');
if (!titleElement || !messageElement || !confirmButton) {
console.error(`Modal "${modalId}" is missing required child elements (title, message, or confirm button).`);
return;
}
titleElement.textContent = title;
messageElement.innerHTML = message;
confirmButton.disabled = disableConfirm;
const newConfirmButton = confirmButton.cloneNode(true);
confirmButton.parentNode.replaceChild(newConfirmButton, confirmButton);
newConfirmButton.onclick = () => onConfirm();
this.show(modalId);
}
/**
* Shows a result modal to indicate the outcome of an operation (success or failure).
* @param {boolean} success - If true, displays a success icon and title; otherwise, shows failure indicators.
* @param {string|Node} message - The message to display. Can be a simple string or a complex DOM Node for rich content.
* @param {boolean} [autoReload=false] - If true, the page will automatically reload when the modal is closed.
*/
showResult(success, message, autoReload = false) {
const modalElement = document.getElementById("resultModal");
if (!modalElement) {
console.error("Result modal with ID 'resultModal' not found.");
return;
}
const titleElement = document.getElementById("resultModalTitle");
const messageElement = document.getElementById("resultModalMessage");
const iconElement = document.getElementById("resultIcon");
const confirmButton = document.getElementById("resultModalConfirmBtn");
if (!titleElement || !messageElement || !iconElement || !confirmButton) {
console.error("Result modal is missing required child elements.");
return;
}
titleElement.textContent = success ? "\u64CD\u4F5C\u6210\u529F" : "\u64CD\u4F5C\u5931\u8D25";
if (success) {
iconElement.innerHTML = '<i class="fas fa-check-circle text-success-500"></i>';
iconElement.className = "text-6xl mb-3 text-success-500";
} else {
iconElement.innerHTML = '<i class="fas fa-times-circle text-danger-500"></i>';
iconElement.className = "text-6xl mb-3 text-danger-500";
}
messageElement.innerHTML = "";
if (typeof message === "string") {
const messageDiv = document.createElement("div");
messageDiv.innerText = message;
messageElement.appendChild(messageDiv);
} else if (message instanceof Node) {
messageElement.appendChild(message);
} else {
const messageDiv = document.createElement("div");
messageDiv.innerText = String(message);
messageElement.appendChild(messageDiv);
}
confirmButton.onclick = () => this.closeResult(autoReload);
this.show("resultModal");
}
/**
* Closes the result modal.
* @param {boolean} [reload=false] - If true, reloads the page after closing the modal.
*/
closeResult(reload = false) {
this.hide("resultModal");
if (reload) {
location.reload();
}
}
/**
* Shows and initializes the progress modal for long-running operations.
* @param {string} title - The title to display for the progress modal.
*/
showProgress(title) {
const modal = document.getElementById("progressModal");
if (!modal) {
console.error("Progress modal with ID 'progressModal' not found.");
return;
}
const titleElement = document.getElementById("progressModalTitle");
const statusText = document.getElementById("progressStatusText");
const progressBar = document.getElementById("progressBar");
const progressPercentage = document.getElementById("progressPercentage");
const progressLog = document.getElementById("progressLog");
const closeButton = document.getElementById("progressModalCloseBtn");
const closeIcon = document.getElementById("closeProgressModalBtn");
if (!titleElement || !statusText || !progressBar || !progressPercentage || !progressLog || !closeButton || !closeIcon) {
console.error("Progress modal is missing required child elements.");
return;
}
titleElement.textContent = title;
statusText.textContent = "\u51C6\u5907\u5F00\u59CB...";
progressBar.style.width = "0%";
progressPercentage.textContent = "0%";
progressLog.innerHTML = "";
closeButton.disabled = true;
closeIcon.disabled = true;
this.show("progressModal");
}
/**
* Updates the progress bar and status text within the progress modal.
* @param {number} processed - The number of items that have been processed.
* @param {number} total - The total number of items to process.
* @param {string} status - The current status message to display.
*/
updateProgress(processed, total, status) {
const modal = document.getElementById("progressModal");
if (!modal || modal.classList.contains("hidden")) return;
const progressBar = document.getElementById("progressBar");
const progressPercentage = document.getElementById("progressPercentage");
const statusText = document.getElementById("progressStatusText");
const closeButton = document.getElementById("progressModalCloseBtn");
const closeIcon = document.getElementById("closeProgressModalBtn");
const percentage = total > 0 ? Math.round(processed / total * 100) : 0;
progressBar.style.width = `${percentage}%`;
progressPercentage.textContent = `${percentage}%`;
statusText.textContent = status;
if (processed === total) {
closeButton.disabled = false;
closeIcon.disabled = false;
}
}
/**
* Adds a log entry to the progress modal's log area.
* @param {string} message - The log message to append.
* @param {boolean} [isError=false] - If true, styles the log entry as an error.
*/
addProgressLog(message, isError = false) {
const progressLog = document.getElementById("progressLog");
if (!progressLog) return;
const logEntry = document.createElement("div");
logEntry.textContent = message;
logEntry.className = isError ? "text-danger-600" : "text-gray-700";
progressLog.appendChild(logEntry);
progressLog.scrollTop = progressLog.scrollHeight;
}
/**
* Closes the progress modal.
* @param {boolean} [reload=false] - If true, reloads the page after closing.
*/
closeProgress(reload = false) {
this.hide("progressModal");
if (reload) {
location.reload();
}
}
};
var UIPatterns = class {
/**
* Animates numerical values in elements from 0 to their target number.
* The target number is read from the element's text content.
* @param {string} selector - The CSS selector for the elements to animate (e.g., '.stat-value').
* @param {number} [duration=1500] - The duration of the animation in milliseconds.
*/
animateCounters(selector = ".stat-value", duration = 1500) {
const statValues = document.querySelectorAll(selector);
statValues.forEach((valueElement) => {
const finalValue = parseInt(valueElement.textContent, 10);
if (isNaN(finalValue)) return;
if (!valueElement.dataset.originalValue) {
valueElement.dataset.originalValue = valueElement.textContent;
}
let startValue = 0;
const startTime = performance.now();
const updateCounter = (currentTime) => {
const elapsedTime = currentTime - startTime;
if (elapsedTime < duration) {
const progress = elapsedTime / duration;
const easeOutValue = 1 - Math.pow(1 - progress, 3);
const currentValue = Math.floor(easeOutValue * finalValue);
valueElement.textContent = currentValue;
requestAnimationFrame(updateCounter);
} else {
valueElement.textContent = valueElement.dataset.originalValue;
}
};
requestAnimationFrame(updateCounter);
});
}
/**
* Toggles the visibility of a content section with a smooth height animation.
* It expects a specific HTML structure where the header and content are within a common parent (e.g., a card).
* The content element should have a `collapsed` class when hidden.
* @param {HTMLElement} header - The header element that was clicked to trigger the toggle.
*/
toggleSection(header) {
const card = header.closest(".stats-card");
if (!card) return;
const content = card.querySelector(".key-content");
const toggleIcon = header.querySelector(".toggle-icon");
if (!content || !toggleIcon) {
console.error("Toggle section failed: Content or icon element not found.", { header });
return;
}
const isCollapsed = content.classList.contains("collapsed");
toggleIcon.classList.toggle("collapsed", !isCollapsed);
if (isCollapsed) {
content.classList.remove("collapsed");
content.style.maxHeight = null;
content.style.opacity = null;
content.style.paddingTop = null;
content.style.paddingBottom = null;
content.style.overflow = "hidden";
requestAnimationFrame(() => {
const targetHeight = content.scrollHeight;
content.style.maxHeight = `${targetHeight}px`;
content.style.opacity = "1";
content.style.paddingTop = "1rem";
content.style.paddingBottom = "1rem";
content.addEventListener("transitionend", function onExpansionEnd() {
content.removeEventListener("transitionend", onExpansionEnd);
if (!content.classList.contains("collapsed")) {
content.style.maxHeight = "";
content.style.overflow = "visible";
}
}, { once: true });
});
} else {
const currentHeight = content.scrollHeight;
content.style.maxHeight = `${currentHeight}px`;
content.style.overflow = "hidden";
requestAnimationFrame(() => {
content.style.maxHeight = "0px";
content.style.opacity = "0";
content.style.paddingTop = "0";
content.style.paddingBottom = "0";
content.classList.add("collapsed");
});
}
}
};
var modalManager = new ModalManager();
var uiPatterns = new UIPatterns();
// frontend/js/components/taskCenter.js
var TaskCenterManager = class {
constructor() {
@@ -810,8 +535,6 @@ var toastManager = new ToastManager();
export {
CustomSelect,
modalManager,
uiPatterns,
taskCenterManager,
toastManager
};

View File

@@ -0,0 +1,393 @@
// frontend/js/components/ui.js
var ModalManager = class {
/**
* Shows a generic modal by its ID.
* @param {string} modalId The ID of the modal element to show.
*/
show(modalId) {
const modal = document.getElementById(modalId);
if (modal) {
modal.classList.remove("hidden");
} else {
console.error(`Modal with ID "${modalId}" not found.`);
}
}
/**
* Hides a generic modal by its ID.
* @param {string} modalId The ID of the modal element to hide.
*/
hide(modalId) {
const modal = document.getElementById(modalId);
if (modal) {
modal.classList.add("hidden");
} else {
console.error(`Modal with ID "${modalId}" not found.`);
}
}
/**
* Shows a confirmation dialog. This is a versatile method for 'Are you sure?' style prompts.
* It dynamically sets the title, message, and confirm action for a generic confirmation modal.
* @param {object} options - The options for the confirmation modal.
* @param {string} options.modalId - The ID of the confirmation modal element (e.g., 'resetModal', 'deleteConfirmModal').
* @param {string} options.title - The title to display in the modal header.
* @param {string} options.message - The message to display in the modal body. Can contain HTML.
* @param {function} options.onConfirm - The callback function to execute when the confirm button is clicked.
* @param {boolean} [options.disableConfirm=false] - Whether the confirm button should be initially disabled.
*/
showConfirm({ modalId, title, message, onConfirm, disableConfirm = false }) {
const modalElement = document.getElementById(modalId);
if (!modalElement) {
console.error(`Confirmation modal with ID "${modalId}" not found.`);
return;
}
const titleElement = modalElement.querySelector('[id$="ModalTitle"]');
const messageElement = modalElement.querySelector('[id$="ModalMessage"]');
const confirmButton = modalElement.querySelector('[id^="confirm"]');
if (!titleElement || !messageElement || !confirmButton) {
console.error(`Modal "${modalId}" is missing required child elements (title, message, or confirm button).`);
return;
}
titleElement.textContent = title;
messageElement.innerHTML = message;
confirmButton.disabled = disableConfirm;
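// cloneNode does not copy event listeners, so replacing the button with its
// clone drops any confirm handler attached by a previous showConfirm call.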
const newConfirmButton = confirmButton.cloneNode(true);
confirmButton.parentNode.replaceChild(newConfirmButton, confirmButton);
newConfirmButton.onclick = () => onConfirm();
this.show(modalId);
}
/**
* Shows a result modal to indicate the outcome of an operation (success or failure).
* @param {boolean} success - If true, displays a success icon and title; otherwise, shows failure indicators.
* @param {string|Node} message - The message to display. Can be a simple string or a complex DOM Node for rich content.
* @param {boolean} [autoReload=false] - If true, the page will automatically reload when the modal is closed.
*/
showResult(success, message, autoReload = false) {
const modalElement = document.getElementById("resultModal");
if (!modalElement) {
console.error("Result modal with ID 'resultModal' not found.");
return;
}
const titleElement = document.getElementById("resultModalTitle");
const messageElement = document.getElementById("resultModalMessage");
const iconElement = document.getElementById("resultIcon");
const confirmButton = document.getElementById("resultModalConfirmBtn");
if (!titleElement || !messageElement || !iconElement || !confirmButton) {
console.error("Result modal is missing required child elements.");
return;
}
titleElement.textContent = success ? "\u64CD\u4F5C\u6210\u529F" : "\u64CD\u4F5C\u5931\u8D25";
if (success) {
iconElement.innerHTML = '<i class="fas fa-check-circle text-success-500"></i>';
iconElement.className = "text-6xl mb-3 text-success-500";
} else {
iconElement.innerHTML = '<i class="fas fa-times-circle text-danger-500"></i>';
iconElement.className = "text-6xl mb-3 text-danger-500";
}
messageElement.innerHTML = "";
if (typeof message === "string") {
const messageDiv = document.createElement("div");
messageDiv.innerText = message;
messageElement.appendChild(messageDiv);
} else if (message instanceof Node) {
messageElement.appendChild(message);
} else {
const messageDiv = document.createElement("div");
messageDiv.innerText = String(message);
messageElement.appendChild(messageDiv);
}
confirmButton.onclick = () => this.closeResult(autoReload);
this.show("resultModal");
}
/**
* Closes the result modal.
* @param {boolean} [reload=false] - If true, reloads the page after closing the modal.
*/
closeResult(reload = false) {
this.hide("resultModal");
if (reload) {
location.reload();
}
}
/**
* Shows and initializes the progress modal for long-running operations.
* @param {string} title - The title to display for the progress modal.
*/
showProgress(title) {
const modal = document.getElementById("progressModal");
if (!modal) {
console.error("Progress modal with ID 'progressModal' not found.");
return;
}
const titleElement = document.getElementById("progressModalTitle");
const statusText = document.getElementById("progressStatusText");
const progressBar = document.getElementById("progressBar");
const progressPercentage = document.getElementById("progressPercentage");
const progressLog = document.getElementById("progressLog");
const closeButton = document.getElementById("progressModalCloseBtn");
const closeIcon = document.getElementById("closeProgressModalBtn");
if (!titleElement || !statusText || !progressBar || !progressPercentage || !progressLog || !closeButton || !closeIcon) {
console.error("Progress modal is missing required child elements.");
return;
}
titleElement.textContent = title;
statusText.textContent = "\u51C6\u5907\u5F00\u59CB...";
progressBar.style.width = "0%";
progressPercentage.textContent = "0%";
progressLog.innerHTML = "";
closeButton.disabled = true;
closeIcon.disabled = true;
this.show("progressModal");
}
/**
* Updates the progress bar and status text within the progress modal.
* @param {number} processed - The number of items that have been processed.
* @param {number} total - The total number of items to process.
* @param {string} status - The current status message to display.
*/
updateProgress(processed, total, status) {
const modal = document.getElementById("progressModal");
if (!modal || modal.classList.contains("hidden")) return;
const progressBar = document.getElementById("progressBar");
const progressPercentage = document.getElementById("progressPercentage");
const statusText = document.getElementById("progressStatusText");
const closeButton = document.getElementById("progressModalCloseBtn");
const closeIcon = document.getElementById("closeProgressModalBtn");
const percentage = total > 0 ? Math.round(processed / total * 100) : 0;
progressBar.style.width = `${percentage}%`;
progressPercentage.textContent = `${percentage}%`;
statusText.textContent = status;
if (processed === total) {
closeButton.disabled = false;
closeIcon.disabled = false;
}
}
/**
* Adds a log entry to the progress modal's log area.
* @param {string} message - The log message to append.
* @param {boolean} [isError=false] - If true, styles the log entry as an error.
*/
addProgressLog(message, isError = false) {
const progressLog = document.getElementById("progressLog");
if (!progressLog) return;
const logEntry = document.createElement("div");
logEntry.textContent = message;
logEntry.className = isError ? "text-danger-600" : "text-gray-700";
progressLog.appendChild(logEntry);
progressLog.scrollTop = progressLog.scrollHeight;
}
/**
* Closes the progress modal.
* @param {boolean} [reload=false] - If true, reloads the page after closing.
*/
closeProgress(reload = false) {
this.hide("progressModal");
if (reload) {
location.reload();
}
}
};
var UIPatterns = class {
/**
* Animates numerical values in elements from 0 to their target number.
* The target number is read from the element's text content.
* @param {string} selector - The CSS selector for the elements to animate (e.g., '.stat-value').
* @param {number} [duration=1500] - The duration of the animation in milliseconds.
*/
animateCounters(selector = ".stat-value", duration = 1500) {
const statValues = document.querySelectorAll(selector);
statValues.forEach((valueElement) => {
const finalValue = parseInt(valueElement.textContent, 10);
if (isNaN(finalValue)) return;
if (!valueElement.dataset.originalValue) {
valueElement.dataset.originalValue = valueElement.textContent;
}
const startTime = performance.now();
const updateCounter = (currentTime) => {
const elapsedTime = currentTime - startTime;
if (elapsedTime < duration) {
const progress = elapsedTime / duration;
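// Cubic ease-out: the count rises quickly at first, then settles on the target.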
const easeOutValue = 1 - Math.pow(1 - progress, 3);
const currentValue = Math.floor(easeOutValue * finalValue);
valueElement.textContent = currentValue;
requestAnimationFrame(updateCounter);
} else {
valueElement.textContent = valueElement.dataset.originalValue;
}
};
requestAnimationFrame(updateCounter);
});
}
/**
* Toggles the visibility of a content section with a smooth height animation.
* It expects a specific HTML structure where the header and content are within a common parent (e.g., a card).
* The content element should have a `collapsed` class when hidden.
* @param {HTMLElement} header - The header element that was clicked to trigger the toggle.
*/
toggleSection(header) {
const card = header.closest(".stats-card");
if (!card) return;
const content = card.querySelector(".key-content");
const toggleIcon = header.querySelector(".toggle-icon");
if (!content || !toggleIcon) {
console.error("Toggle section failed: Content or icon element not found.", { header });
return;
}
const isCollapsed = content.classList.contains("collapsed");
toggleIcon.classList.toggle("collapsed", !isCollapsed);
if (isCollapsed) {
content.classList.remove("collapsed");
content.style.maxHeight = null;
content.style.opacity = null;
content.style.paddingTop = null;
content.style.paddingBottom = null;
content.style.overflow = "hidden";
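// Two-phase expand: measure scrollHeight on the next frame, animate max-height
// up to it, then clear the cap on transitionend so the content can reflow freely.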
requestAnimationFrame(() => {
const targetHeight = content.scrollHeight;
content.style.maxHeight = `${targetHeight}px`;
content.style.opacity = "1";
content.style.paddingTop = "1rem";
content.style.paddingBottom = "1rem";
content.addEventListener("transitionend", function onExpansionEnd() {
content.removeEventListener("transitionend", onExpansionEnd);
if (!content.classList.contains("collapsed")) {
content.style.maxHeight = "";
content.style.overflow = "visible";
}
}, { once: true });
});
} else {
const currentHeight = content.scrollHeight;
content.style.maxHeight = `${currentHeight}px`;
content.style.overflow = "hidden";
requestAnimationFrame(() => {
content.style.maxHeight = "0px";
content.style.opacity = "0";
content.style.paddingTop = "0";
content.style.paddingBottom = "0";
content.classList.add("collapsed");
});
}
}
/**
* Sets a button to a loading state by disabling it and showing a spinner.
* It stores the button's original content to be restored later.
* @param {HTMLButtonElement} button - The button element to modify.
*/
setButtonLoading(button) {
if (!button) return;
if (!button.dataset.originalContent) {
button.dataset.originalContent = button.innerHTML;
}
button.disabled = true;
button.innerHTML = '<i class="fas fa-spinner fa-spin"></i>';
}
/**
* Restores a button from its loading state to its original content and enables it.
* @param {HTMLButtonElement} button - The button element to restore.
*/
clearButtonLoading(button) {
if (!button) return;
if (button.dataset.originalContent) {
button.innerHTML = button.dataset.originalContent;
delete button.dataset.originalContent;
}
button.disabled = false;
}
/**
* Returns the HTML for a streaming text cursor animation.
* This is used as a placeholder in the chat UI while waiting for an assistant's response.
* @returns {string} The HTML string for the loader.
*/
renderStreamingLoader() {
return '<span class="streaming-cursor animate-pulse">\u258B</span>';
}
};
var modalManager = new ModalManager();
var uiPatterns = new UIPatterns();
// frontend/js/services/api.js
var APIClientError = class extends Error {
constructor(message, status, code, rawMessageFromServer) {
super(message);
this.name = "APIClientError";
this.status = status;
this.code = code;
this.rawMessageFromServer = rawMessageFromServer;
}
};
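// Promise cache for GET requests: repeat calls to the same URL reuse the
// in-flight (or resolved) promise; errors and 401s evict the entry. This is
// also why consumers must clone the Response before reading its body.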
var apiPromiseCache = /* @__PURE__ */ new Map();
async function apiFetch(url, options = {}) {
const isGetRequest = !options.method || options.method.toUpperCase() === "GET";
const cacheKey = isGetRequest && !options.noCache ? url : null;
if (cacheKey && apiPromiseCache.has(cacheKey)) {
return apiPromiseCache.get(cacheKey);
}
const token = localStorage.getItem("bearerToken");
const headers = {
"Content-Type": "application/json",
...options.headers
};
if (token) {
headers["Authorization"] = `Bearer ${token}`;
}
const requestPromise = (async () => {
try {
const response = await fetch(url, { ...options, headers });
if (response.status === 401) {
if (cacheKey) apiPromiseCache.delete(cacheKey);
localStorage.removeItem("bearerToken");
if (window.location.pathname !== "/login") {
window.location.href = "/login?error=\u4F1A\u8BDD\u5DF2\u8FC7\u671F\uFF0C\u8BF7\u91CD\u65B0\u767B\u5F55\u3002";
}
throw new APIClientError("Unauthorized", 401, "UNAUTHORIZED", "Session expired or token is invalid.");
}
if (!response.ok) {
let errorData = null;
let rawMessage = "";
try {
rawMessage = await response.text();
if (rawMessage) {
errorData = JSON.parse(rawMessage);
}
} catch (e) {
errorData = { error: { code: "UNKNOWN_FORMAT", message: rawMessage || response.statusText } };
}
const code = errorData?.error?.code || "UNKNOWN_ERROR";
const messageFromServer = errorData?.error?.message || rawMessage || "No message provided by server.";
const error = new APIClientError(
`API request failed: ${response.status}`,
response.status,
code,
messageFromServer
);
throw error;
}
return response;
} catch (error) {
if (cacheKey) apiPromiseCache.delete(cacheKey);
throw error;
}
})();
if (cacheKey) {
apiPromiseCache.set(cacheKey, requestPromise);
}
return requestPromise;
}
async function apiFetchJson(url, options = {}) {
  const response = await apiFetch(url, options);
  // Clone before parsing: cached GET responses may be read by several callers,
  // and a Response body can only be consumed once.
  return response.clone().json();
}
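// Example calls (hypothetical endpoint): repeated GETs share one cached
// promise; pass `noCache: true` to bypass the cache.
//   const stats = await apiFetchJson("/api/stats");
//   const fresh = await apiFetchJson("/api/stats", { noCache: true });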
export {
modalManager,
uiPatterns,
apiFetch,
apiFetchJson
};
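
A short usage sketch tying the exports above together. The modal ID is taken from the JSDoc example; the endpoint and button are hypothetical:

async function onDeleteClick(button) {
  modalManager.showConfirm({
    modalId: "deleteConfirmModal", // assumed to exist in the page template
    title: "Delete key",
    message: "This cannot be undone. Continue?",
    onConfirm: async () => {
      uiPatterns.setButtonLoading(button);
      try {
        await apiFetchJson("/api/keys/123", { method: "DELETE" });
        modalManager.showResult(true, "Key deleted.", true);
      } catch (err) {
        modalManager.showResult(false, err.rawMessageFromServer || err.message);
      } finally {
        uiPatterns.clearButtonLoading(button);
      }
    }
  });
}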

View File

@@ -1,3 +1,5 @@
import "./chunk-JSBRDJBE.js";
// frontend/js/pages/dashboard.js
function init() {
console.log("[Modern Frontend] Dashboard module loaded. Future logic will execute here.");

Some files were not shown because too many files have changed in this diff.