Compare commits
11 Commits
ac0e0a8275
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 0839ec35a2 | |||
| 166437c0ac | |||
| c86e7a7ba4 | |||
| 01c9b34600 | |||
| 04d36e4d9e | |||
| ad1e6180cf | |||
| e026d8f324 | |||
| f2706d6fc8 | |||
| 3a95a07e8a | |||
| 6c7283d51b | |||
| 2b0b9b67dc |
@@ -3,10 +3,12 @@ package main
|
||||
import (
|
||||
"gemini-balancer/internal/app"
|
||||
"gemini-balancer/internal/container"
|
||||
"gemini-balancer/internal/logging"
|
||||
"log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
defer logging.Close()
|
||||
cont, err := container.BuildContainer()
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: Failed to build dependency container: %v", err)
|
||||
|
||||
@@ -14,6 +14,12 @@ server:
|
||||
log:
|
||||
level: "debug"
|
||||
|
||||
# 日志轮转配置
|
||||
max_size: 100 # MB
|
||||
max_backups: 7 # 保留文件数
|
||||
max_age: 30 # 保留天数
|
||||
compress: true # 压缩旧日志
|
||||
|
||||
redis:
|
||||
dsn: "redis://localhost:6379/0"
|
||||
|
||||
|
||||
@@ -5,4 +5,5 @@ esbuild ./frontend/js/main.js \
|
||||
--outdir=./web/static/js \
|
||||
--splitting \
|
||||
--format=esm \
|
||||
--loader:.css=css \
|
||||
--watch=forever
|
||||
|
||||
13
frontend/css/flatpickr.min.css
vendored
Normal file
13
frontend/css/flatpickr.min.css
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
||||
/* static/css/input.css */
|
||||
@import "tailwindcss";
|
||||
|
||||
/* @import "./css/flatpickr.min.css"; */
|
||||
/* =================================================================== */
|
||||
/* [核心] 定义 shadcn/ui 的设计系统变量 */
|
||||
/* =================================================================== */
|
||||
@@ -19,7 +19,7 @@
|
||||
--secondary: theme(colors.zinc.200);
|
||||
--secondary-foreground: theme(colors.zinc.900);
|
||||
|
||||
--destructive: theme(colors.red.600);
|
||||
--destructive: theme(colors.red.500);
|
||||
--destructive-foreground: theme(colors.white);
|
||||
--accent: theme(colors.zinc.100);
|
||||
--accent-foreground: theme(colors.zinc.900);
|
||||
@@ -69,10 +69,10 @@
|
||||
@apply bg-primary text-primary-foreground hover:bg-primary/90;
|
||||
}
|
||||
.btn-secondary {
|
||||
@apply bg-secondary text-secondary-foreground hover:bg-secondary/80;
|
||||
@apply bg-secondary text-secondary-foreground border border-zinc-500/30 hover:bg-secondary/80;
|
||||
}
|
||||
.btn-destructive {
|
||||
@apply bg-destructive text-destructive-foreground hover:bg-destructive/90;
|
||||
@apply bg-destructive text-destructive-foreground border border-zinc-500/30 hover:bg-destructive/90;
|
||||
}
|
||||
.btn-outline {
|
||||
@apply border border-input bg-background hover:bg-accent hover:text-accent-foreground;
|
||||
@@ -83,7 +83,9 @@
|
||||
.btn-link {
|
||||
@apply text-primary underline-offset-4 hover:underline;
|
||||
}
|
||||
|
||||
.btn-group-item.active {
|
||||
@apply bg-primary text-primary-foreground;
|
||||
}
|
||||
/* 按钮尺寸变体 */
|
||||
.btn-lg { @apply h-11 rounded-md px-8; }
|
||||
.btn-md { @apply h-10 px-4 py-2; }
|
||||
@@ -97,6 +99,155 @@
|
||||
focus-visible:outline-none focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-[(var(--ring))]
|
||||
disabled:cursor-not-allowed disabled:opacity-50;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------ */
|
||||
/* Custom Flatpickr Theme using shadcn/ui variables */
|
||||
/* ------------------------------------------------ */
|
||||
|
||||
.flatpickr-calendar {
|
||||
/* --- 主题样式 --- */
|
||||
@apply bg-background text-foreground rounded-lg shadow-lg border border-zinc-500/30 w-auto font-sans;
|
||||
animation: var(--animation-panel-in);
|
||||
width: 200px;
|
||||
/* --- 核心结构样式 --- */
|
||||
display: none;
|
||||
position: absolute;
|
||||
visibility: hidden;
|
||||
opacity: 0;
|
||||
padding: 0;
|
||||
z-index: 999;
|
||||
box-sizing: border-box;
|
||||
transition: opacity 0.15s ease-out, visibility 0.15s ease-out;
|
||||
}
|
||||
.flatpickr-calendar.open {
|
||||
opacity: 1;
|
||||
visibility: visible;
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
.flatpickr-calendar.not-ready {
|
||||
top: 0;
|
||||
left: 0;
|
||||
visibility: hidden;
|
||||
}
|
||||
.flatpickr-calendar.static {
|
||||
position: relative;
|
||||
top: auto;
|
||||
left: auto;
|
||||
display: block;
|
||||
visibility: visible;
|
||||
opacity: 1;
|
||||
}
|
||||
.flatpickr-calendar.static { position: relative; top: auto; left: auto; display: block; visibility: visible; opacity: 1; }
|
||||
.flatpickr-calendar.not-ready { top: 0; left: 0; visibility: hidden; }
|
||||
/* 月份导航区域 */
|
||||
.flatpickr-months {
|
||||
@apply flex items-center bg-transparent p-0 border-b border-zinc-500/30;
|
||||
}
|
||||
.flatpickr-month { @apply h-auto pt-2 pb-1; }
|
||||
.flatpickr-current-month {
|
||||
@apply flex flex-1 items-center justify-center text-foreground font-semibold text-sm h-auto;
|
||||
}
|
||||
|
||||
.flatpickr-current-month .numInputWrapper {
|
||||
@apply ml-0;
|
||||
}
|
||||
|
||||
.flatpickr-current-month .numInputWrapper input.numInput {
|
||||
|
||||
@apply w-14 text-center font-semibold bg-transparent border-0 p-0 text-sm text-foreground;
|
||||
|
||||
/* 移除默认的浏览器样式和聚焦时的轮廓 */
|
||||
@apply appearance-none focus:outline-none focus:ring-0;
|
||||
-moz-appearance: textfield;
|
||||
}
|
||||
/* 强制移除数字输入框在 Chrome, Safari, Edge 等 Webkit 浏览器中的上下箭头 */
|
||||
.flatpickr-current-month .numInputWrapper input.numInput::-webkit-outer-spin-button,
|
||||
.flatpickr-current-month .numInputWrapper input.numInput::-webkit-inner-spin-button {
|
||||
-webkit-appearance: none;
|
||||
margin: 0;
|
||||
}
|
||||
.flatpickr-current-month .cur-month { @apply font-semibold; }
|
||||
.flatpickr-current-month .flatpickr-monthDropdown-months {
|
||||
|
||||
@apply w-22 font-semibold bg-transparent border-0 p-0 text-sm text-foreground text-right;
|
||||
|
||||
@apply appearance-none focus:outline-none focus:ring-0;
|
||||
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap; /* 确保月份名不换行 */
|
||||
}
|
||||
option.flatpickr-monthDropdown-month {
|
||||
@apply bg-background text-foreground border-0;
|
||||
@apply dark:bg-zinc-800 dark:text-zinc-200;
|
||||
}
|
||||
.flatpickr-current-month .flatpickr-monthDropdown-months,
|
||||
.flatpickr-current-month .numInputWrapper input.numInput {
|
||||
@apply text-sm pl-0 ;
|
||||
}
|
||||
/* 导航箭头 */
|
||||
.flatpickr-prev-month,
|
||||
.flatpickr-next-month {
|
||||
@apply inline-flex items-center justify-center whitespace-nowrap rounded-md pt-2 pb-1 text-sm font-medium transition-colors
|
||||
focus-visible:outline-none focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-[(var(--ring))]
|
||||
disabled:pointer-events-none disabled:opacity-50
|
||||
hover:text-accent-foreground
|
||||
h-7 w-7 shrink-0;
|
||||
position: relative;
|
||||
}
|
||||
.flatpickr-prev-month svg,
|
||||
.flatpickr-next-month svg {
|
||||
@apply h-3 w-3 hover:h-4 hover:w-4 fill-zinc-500;
|
||||
}
|
||||
|
||||
/* 星期标题 */
|
||||
.flatpickr-weekdaycontainer {
|
||||
@apply flex justify-around p-1;
|
||||
}
|
||||
span.flatpickr-weekday {
|
||||
@apply flex-1 text-center text-muted-foreground font-medium;
|
||||
font-size: 0.7rem;
|
||||
}
|
||||
/* 日期网格 */
|
||||
.dayContainer {
|
||||
@apply flex flex-wrap p-1 pt-0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.flatpickr-day {
|
||||
@apply w-4 h-6.5 flex items-center justify-center rounded-full border-0 text-foreground transition-colors shrink-0; /* <--- 从 w-9 h-9 缩小 */
|
||||
flex-basis: 14.2857%;
|
||||
line-height: 1;
|
||||
cursor: pointer;
|
||||
font-size: 0.7rem; /* 介于 text-xs 和 text-sm 之间 */
|
||||
}
|
||||
.flatpickr-day:hover,
|
||||
.flatpickr-day:focus { @apply bg-accent text-accent-foreground outline-none; }
|
||||
.flatpickr-day.today { @apply border border-primary; }
|
||||
.flatpickr-day.selected,
|
||||
.flatpickr-day.startRange,
|
||||
.flatpickr-day.endRange,
|
||||
.flatpickr-day.selected:hover,
|
||||
.flatpickr-day.startRange:hover,
|
||||
.flatpickr-day.endRange:hover {
|
||||
@apply bg-primary text-primary-foreground;
|
||||
}
|
||||
.flatpickr-day.inRange { @apply bg-accent rounded-none shadow-none; }
|
||||
.flatpickr-day.startRange { @apply rounded-l-full; }
|
||||
.flatpickr-day.endRange { @apply rounded-r-full; }
|
||||
.flatpickr-day.disabled,
|
||||
.flatpickr-day.disabled:hover { @apply bg-transparent text-muted-foreground/50 cursor-not-allowed; }
|
||||
.flatpickr-day.nextMonthDay, .flatpickr-day.prevMonthDay { @apply text-muted-foreground/50; }
|
||||
.flatpickr-day.nextMonthDay:hover, .flatpickr-day.prevMonthDay:hover { @apply bg-accent; }
|
||||
|
||||
/* 清除按钮 */
|
||||
.flatpickr-calendar .flatpickr-clear-button {
|
||||
@apply h-6 py-3 inline-flex items-center justify-center whitespace-nowrap rounded-md text-xs font-medium transition-colors
|
||||
focus-visible:outline-none focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-[(var(--ring))]
|
||||
disabled:pointer-events-none disabled:opacity-50
|
||||
text-primary underline-offset-4 hover:text-sm
|
||||
w-full rounded-t-none border-t border-zinc-500/20;
|
||||
}
|
||||
}
|
||||
|
||||
@custom-variant dark (&:where(.dark, .dark *));
|
||||
@@ -214,8 +365,8 @@
|
||||
@apply w-[1.2rem] text-center;
|
||||
@apply transition-all duration-300 ease-in-out;
|
||||
/* 悬停和激活状态 */
|
||||
@apply group-hover:text-[#60a5fa] group-hover:[filter:drop-shadow(0_0_5px_rgba(59,130,246,0.5))];
|
||||
@apply group-data-[active='true']:text-[#60a5fa] group-data-[active='true']:[filter:drop-shadow(0_0_5px_rgba(59,130,246,0.7))];
|
||||
@apply group-hover:text-[#60a5fa] group-hover:filter-[drop-shadow(0_0_5px_rgba(59,130,246,0.5))];
|
||||
@apply group-data-[active='true']:text-[#60a5fa] group-data-[active='true']:filter-[drop-shadow(0_0_5px_rgba(59,130,246,0.7))];
|
||||
}
|
||||
/* 4. 指示器 */
|
||||
.nav-indicator {
|
||||
@@ -243,13 +394,13 @@
|
||||
@apply flex items-start p-3 w-full rounded-lg shadow-lg ring-1 ring-black/5 dark:ring-white/10 bg-white/80 dark:bg-zinc-800/80 backdrop-blur-md pointer-events-auto;
|
||||
}
|
||||
.toast-icon {
|
||||
@apply flex-shrink-0 w-8 h-8 rounded-full flex items-center justify-center text-white mr-3;
|
||||
@apply shrink-0 w-8 h-8 rounded-full flex items-center justify-center text-white mr-3;
|
||||
}
|
||||
.toast-icon-loading {@apply bg-blue-500;}
|
||||
.toast-icon-success {@apply bg-green-500;}
|
||||
.toast-icon-error {@apply bg-red-500;}
|
||||
.toast-content {
|
||||
@apply flex-grow;
|
||||
@apply grow
|
||||
}
|
||||
.toast-title {
|
||||
@apply font-semibold text-sm text-zinc-800 dark:text-zinc-100;
|
||||
@@ -273,7 +424,7 @@
|
||||
|
||||
/* --- 任务项主内容区 (左栏) --- */
|
||||
.task-item-main {
|
||||
@apply flex items-center justify-between flex-grow gap-1; /* flex-grow 使其占据所有可用空间 */
|
||||
@apply flex items-center justify-between grow gap-1; /* flex-grow 使其占据所有可用空间 */
|
||||
}
|
||||
/* 2. 任务项头部: 包含标题和时间戳 */
|
||||
.task-item-header {
|
||||
@@ -285,7 +436,7 @@
|
||||
}
|
||||
.task-item-timestamp {
|
||||
/* 融合了您原有的字体样式 */
|
||||
@apply text-xs self-start pt-1.5 pl-2 text-zinc-400 dark:text-zinc-500 flex-shrink-0;
|
||||
@apply text-xs self-start pt-1.5 pl-2 text-zinc-400 dark:text-zinc-500 shrink-0
|
||||
}
|
||||
|
||||
/* 3. [新增] 阶段动画的核心容器 */
|
||||
@@ -298,13 +449,13 @@
|
||||
@apply flex items-center gap-2 p-1.5 rounded-md transition-all duration-300 ease-in-out relative;
|
||||
}
|
||||
.task-stage-icon {
|
||||
@apply w-4 h-4 relative flex-shrink-0 text-zinc-400;
|
||||
@apply w-4 h-4 relative shrink-0 text-zinc-400;
|
||||
}
|
||||
.task-stage-icon i {
|
||||
@apply absolute top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 opacity-0 transition-opacity duration-200;
|
||||
}
|
||||
.task-stage-content {
|
||||
@apply flex-grow flex justify-between items-baseline text-xs;
|
||||
@apply grow justify-between items-baseline text-xs;
|
||||
}
|
||||
.task-stage-name {
|
||||
@apply text-zinc-600 dark:text-zinc-400;
|
||||
@@ -373,7 +524,7 @@
|
||||
}
|
||||
/* --- 4. 折叠/展开的雪佛兰图标 --- */
|
||||
.task-toggle-icon {
|
||||
@apply transition-transform duration-300 ease-in-out text-zinc-400 flex-shrink-0 ml-2;
|
||||
@apply transition-transform duration-300 ease-in-out text-zinc-400 shrink-0 ml-2;
|
||||
}
|
||||
/* --- 5. 展开状态下的图标旋转 --- */
|
||||
/*
|
||||
@@ -470,7 +621,7 @@
|
||||
* 2. 【新增】移动端首屏的 "当前分组" 选择器样式
|
||||
*/
|
||||
.mobile-group-selector {
|
||||
@apply flex-grow flex items-center justify-between p-3 border border-zinc-200 dark:border-zinc-700 rounded-lg;
|
||||
@apply grow flex items-center justify-between p-3 border border-zinc-200 dark:border-zinc-700 rounded-lg;
|
||||
}
|
||||
/* 移动端群组下拉列表样式 */
|
||||
.mobile-group-menu-active {
|
||||
@@ -621,7 +772,7 @@
|
||||
/* Tag Input Component */
|
||||
.tag-input-container {
|
||||
|
||||
@apply flex flex-wrap items-center gap-2 mt-1 w-full rounded-md bg-white dark:bg-zinc-700 border border-zinc-300 dark:border-zinc-600 p-2 min-h-[40px] focus-within:border-blue-500 focus-within:ring-1 focus-within:ring-blue-500;
|
||||
@apply flex flex-wrap items-center gap-2 mt-1 w-full rounded-md bg-white dark:bg-zinc-700 border border-zinc-300 dark:border-zinc-600 p-2 min-h-10 focus-within:border-blue-500 focus-within:ring-1 focus-within:ring-blue-500;
|
||||
}
|
||||
.tag-item {
|
||||
@apply flex items-center gap-x-1.5 bg-blue-100 dark:bg-blue-500/20 text-blue-800 dark:text-blue-200 text-sm font-medium rounded-full px-2.5 py-0.5;
|
||||
@@ -631,7 +782,7 @@
|
||||
}
|
||||
.tag-input-new {
|
||||
/* 使其在容器内垂直居中,感觉更好 */
|
||||
@apply flex-grow bg-transparent focus:outline-none text-sm self-center;
|
||||
@apply grow bg-transparent focus:outline-none text-sm self-center;
|
||||
}
|
||||
|
||||
/* 为复制按钮提供基础样式 */
|
||||
@@ -702,7 +853,7 @@
|
||||
}
|
||||
/* .tooltip-text is now dynamically generated by JS */
|
||||
.global-tooltip {
|
||||
@apply fixed z-[9999] w-max max-w-xs whitespace-normal rounded-lg bg-zinc-800 px-3 py-2 text-sm font-medium text-white shadow-lg transition-opacity duration-200;
|
||||
@apply fixed z-9999 w-max max-w-xs whitespace-normal rounded-lg bg-zinc-800 px-3 py-2 text-sm font-medium text-white shadow-lg transition-opacity duration-200;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -752,14 +903,18 @@
|
||||
@apply text-4xl;
|
||||
}
|
||||
|
||||
|
||||
/* =================================================================== */
|
||||
/* 自定义 table 样式 (Custom SweetAlert2 Styles) */
|
||||
/* =================================================================== */
|
||||
@layer components {
|
||||
/* --- [新增] 可复用的表格组件样式 --- */
|
||||
/* --- 可复用的表格组件样式 --- */
|
||||
.table {
|
||||
@apply w-full caption-bottom text-sm;
|
||||
}
|
||||
.table-header {
|
||||
/* 使用语义化颜色,自动适应暗色模式 */
|
||||
@apply sticky top-0 z-10 border-b border-border bg-muted/50;
|
||||
@apply sticky top-0 z-10 border-b border-border bg-zinc-200 dark:bg-zinc-900;
|
||||
}
|
||||
.table-header .table-row {
|
||||
/* 表头的 hover 效果通常与数据行不同,或者没有 */
|
||||
@@ -769,7 +924,7 @@
|
||||
@apply [&_tr:last-child]:border-0;
|
||||
}
|
||||
.table-row {
|
||||
@apply border-b border-border transition-colors hover:bg-muted/80;
|
||||
@apply border-b border-border transition-colors hover:bg-muted;
|
||||
}
|
||||
.table-head-cell {
|
||||
@apply h-12 px-4 text-left align-middle font-medium text-muted-foreground;
|
||||
@@ -777,4 +932,5 @@
|
||||
.table-cell {
|
||||
@apply p-4 align-middle;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
129
frontend/js/components/customSelectV2.js
Normal file
129
frontend/js/components/customSelectV2.js
Normal file
@@ -0,0 +1,129 @@
|
||||
// Filename: frontend/js/components/customSelectV2.js
|
||||
import { createPopper } from '../vendor/popper.esm.min.js';
|
||||
|
||||
export default class CustomSelectV2 {
|
||||
constructor(container) {
|
||||
this.container = container;
|
||||
this.trigger = this.container.querySelector('.custom-select-trigger');
|
||||
this.nativeSelect = this.container.querySelector('select');
|
||||
this.template = this.container.querySelector('.custom-select-panel-template');
|
||||
|
||||
if (!this.trigger || !this.nativeSelect || !this.template) {
|
||||
console.warn('CustomSelectV2 cannot initialize: missing required elements.', this.container);
|
||||
return;
|
||||
}
|
||||
|
||||
this.panel = null;
|
||||
this.popperInstance = null;
|
||||
this.isOpen = false;
|
||||
this.triggerText = this.trigger.querySelector('span');
|
||||
|
||||
if (typeof CustomSelectV2.openInstance === 'undefined') {
|
||||
CustomSelectV2.openInstance = null;
|
||||
CustomSelectV2.initGlobalListener();
|
||||
}
|
||||
|
||||
this.updateTriggerText();
|
||||
this.bindEvents();
|
||||
}
|
||||
|
||||
static initGlobalListener() {
|
||||
document.addEventListener('click', (event) => {
|
||||
const instance = CustomSelectV2.openInstance;
|
||||
if (instance && !instance.container.contains(event.target) && (!instance.panel || !instance.panel.contains(event.target))) {
|
||||
instance.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
createPanel() {
|
||||
const panelFragment = this.template.content.cloneNode(true);
|
||||
this.panel = panelFragment.querySelector('.custom-select-panel');
|
||||
document.body.appendChild(this.panel);
|
||||
|
||||
this.panel.innerHTML = '';
|
||||
Array.from(this.nativeSelect.options).forEach(option => {
|
||||
const item = document.createElement('a');
|
||||
item.href = '#';
|
||||
item.className = 'custom-select-option block w-full text-left px-3 py-1.5 text-sm text-zinc-700 hover:bg-zinc-100 dark:text-zinc-200 dark:hover:bg-zinc-700';
|
||||
item.textContent = option.textContent;
|
||||
item.dataset.value = option.value;
|
||||
if (option.selected) { item.classList.add('is-selected'); }
|
||||
this.panel.appendChild(item);
|
||||
});
|
||||
|
||||
this.panel.addEventListener('click', (event) => {
|
||||
event.preventDefault();
|
||||
const optionEl = event.target.closest('.custom-select-option');
|
||||
if (optionEl) { this.selectOption(optionEl); }
|
||||
});
|
||||
}
|
||||
|
||||
bindEvents() {
|
||||
this.trigger.addEventListener('click', (event) => {
|
||||
event.stopPropagation();
|
||||
if (CustomSelectV2.openInstance && CustomSelectV2.openInstance !== this) {
|
||||
CustomSelectV2.openInstance.close();
|
||||
}
|
||||
this.toggle();
|
||||
});
|
||||
}
|
||||
|
||||
selectOption(optionEl) {
|
||||
const selectedValue = optionEl.dataset.value;
|
||||
if (this.nativeSelect.value !== selectedValue) {
|
||||
this.nativeSelect.value = selectedValue;
|
||||
this.nativeSelect.dispatchEvent(new Event('change', { bubbles: true }));
|
||||
}
|
||||
this.updateTriggerText();
|
||||
this.close();
|
||||
}
|
||||
|
||||
updateTriggerText() {
|
||||
const selectedOption = this.nativeSelect.options[this.nativeSelect.selectedIndex];
|
||||
if (selectedOption) {
|
||||
this.triggerText.textContent = selectedOption.textContent;
|
||||
}
|
||||
}
|
||||
|
||||
toggle() { this.isOpen ? this.close() : this.open(); }
|
||||
|
||||
open() {
|
||||
if (this.isOpen) return;
|
||||
this.isOpen = true;
|
||||
|
||||
if (!this.panel) { this.createPanel(); }
|
||||
|
||||
this.panel.style.display = 'block';
|
||||
this.panel.offsetHeight;
|
||||
|
||||
this.popperInstance = createPopper(this.trigger, this.panel, {
|
||||
placement: 'top-start',
|
||||
modifiers: [
|
||||
{ name: 'offset', options: { offset: [0, 8] } },
|
||||
{ name: 'flip', options: { fallbackPlacements: ['bottom-start'] } }
|
||||
],
|
||||
});
|
||||
|
||||
CustomSelectV2.openInstance = this;
|
||||
}
|
||||
|
||||
close() {
|
||||
if (!this.isOpen) return;
|
||||
this.isOpen = false;
|
||||
|
||||
if (this.popperInstance) {
|
||||
this.popperInstance.destroy();
|
||||
this.popperInstance = null;
|
||||
}
|
||||
|
||||
if (this.panel) {
|
||||
this.panel.remove();
|
||||
this.panel = null;
|
||||
}
|
||||
|
||||
if (CustomSelectV2.openInstance === this) {
|
||||
CustomSelectV2.openInstance = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
95
frontend/js/components/filterPopover.js
Normal file
95
frontend/js/components/filterPopover.js
Normal file
@@ -0,0 +1,95 @@
|
||||
// Filename: frontend/js/components/filterPopover.js
|
||||
|
||||
import { createPopper } from '../vendor/popper.esm.min.js';
|
||||
|
||||
export default class FilterPopover {
|
||||
constructor(triggerElement, options, title) {
|
||||
if (!triggerElement || typeof createPopper !== 'function') {
|
||||
console.error('FilterPopover: Trigger element or Popper.js not found.');
|
||||
return;
|
||||
}
|
||||
this.triggerElement = triggerElement;
|
||||
this.options = options;
|
||||
this.title = title;
|
||||
this.selectedValues = new Set();
|
||||
|
||||
this._createPopoverHTML();
|
||||
this.popperInstance = createPopper(this.triggerElement, this.popoverElement, {
|
||||
placement: 'bottom-start',
|
||||
modifiers: [{ name: 'offset', options: { offset: [0, 8] } }],
|
||||
});
|
||||
|
||||
this._bindEvents();
|
||||
}
|
||||
|
||||
_createPopoverHTML() {
|
||||
this.popoverElement = document.createElement('div');
|
||||
this.popoverElement.className = 'hidden z-50 min-w-[12rem] rounded-md border-1 border-zinc-500/30 bg-popover bg-white dark:bg-zinc-900 p-2 text-popover-foreground shadow-md';
|
||||
this.popoverElement.innerHTML = `
|
||||
<div class="px-2 py-1.5 text-sm font-semibold">${this.title}</div>
|
||||
<div class="space-y-1 p-1">
|
||||
${this.options.map(option => `
|
||||
<label class="flex items-center space-x-2 px-2 py-1.5 rounded-md hover:bg-accent cursor-pointer">
|
||||
<input type="checkbox" value="${option.value}" class="h-4 w-4 rounded border-zinc-300 text-blue-600 focus:ring-blue-500">
|
||||
<span class="text-sm">${option.label}</span>
|
||||
</label>
|
||||
`).join('')}
|
||||
</div>
|
||||
<div class="border-t border-border mt-2 pt-2 px-2 flex justify-end space-x-2">
|
||||
<button data-action="clear" class="btn btn-ghost h-7 px-2 text-xs">清空</button>
|
||||
<button data-action="apply" class="btn btn-primary h-7 px-2 text-xs">应用</button>
|
||||
</div>
|
||||
`;
|
||||
document.body.appendChild(this.popoverElement);
|
||||
}
|
||||
|
||||
_bindEvents() {
|
||||
this.triggerElement.addEventListener('click', () => this.toggle());
|
||||
|
||||
document.addEventListener('click', (event) => {
|
||||
if (!this.popoverElement.contains(event.target) && !this.triggerElement.contains(event.target)) {
|
||||
this.hide();
|
||||
}
|
||||
});
|
||||
|
||||
this.popoverElement.addEventListener('click', (event) => {
|
||||
const target = event.target.closest('button');
|
||||
if (!target) return;
|
||||
const action = target.dataset.action;
|
||||
if (action === 'clear') this._handleClear();
|
||||
if (action === 'apply') this._handleApply();
|
||||
});
|
||||
}
|
||||
|
||||
_handleClear() {
|
||||
this.popoverElement.querySelectorAll('input[type="checkbox"]').forEach(cb => cb.checked = false);
|
||||
this.selectedValues.clear();
|
||||
this._handleApply();
|
||||
}
|
||||
|
||||
_handleApply() {
|
||||
this.selectedValues.clear();
|
||||
this.popoverElement.querySelectorAll('input:checked').forEach(cb => {
|
||||
this.selectedValues.add(cb.value);
|
||||
});
|
||||
|
||||
const filterChangeEvent = new CustomEvent('filter-change', {
|
||||
detail: {
|
||||
filterKey: this.triggerElement.id,
|
||||
selected: this.selectedValues
|
||||
}
|
||||
});
|
||||
this.triggerElement.dispatchEvent(filterChangeEvent);
|
||||
|
||||
this.hide();
|
||||
}
|
||||
|
||||
toggle() {
|
||||
this.popoverElement.classList.toggle('hidden');
|
||||
this.popperInstance.update();
|
||||
}
|
||||
|
||||
hide() {
|
||||
this.popoverElement.classList.add('hidden');
|
||||
}
|
||||
}
|
||||
@@ -326,6 +326,41 @@ class UIPatterns {
|
||||
});
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Sets a button to a loading state by disabling it and showing a spinner.
|
||||
* It stores the button's original content to be restored later.
|
||||
* @param {HTMLButtonElement} button - The button element to modify.
|
||||
*/
|
||||
setButtonLoading(button) {
|
||||
if (!button) return;
|
||||
// Store original content if it hasn't been stored already
|
||||
if (!button.dataset.originalContent) {
|
||||
button.dataset.originalContent = button.innerHTML;
|
||||
}
|
||||
button.disabled = true;
|
||||
button.innerHTML = '<i class="fas fa-spinner fa-spin"></i>';
|
||||
}
|
||||
/**
|
||||
* Restores a button from its loading state to its original content and enables it.
|
||||
* @param {HTMLButtonElement} button - The button element to restore.
|
||||
*/
|
||||
clearButtonLoading(button) {
|
||||
if (!button) return;
|
||||
if (button.dataset.originalContent) {
|
||||
button.innerHTML = button.dataset.originalContent;
|
||||
// Clean up the data attribute
|
||||
delete button.dataset.originalContent;
|
||||
}
|
||||
button.disabled = false;
|
||||
}
|
||||
/**
|
||||
* Returns the HTML for a streaming text cursor animation.
|
||||
* This is used as a placeholder in the chat UI while waiting for an assistant's response.
|
||||
* @returns {string} The HTML string for the loader.
|
||||
*/
|
||||
renderStreamingLoader() {
|
||||
return '<span class="streaming-cursor animate-pulse">▋</span>';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
// Filename: frontend/js/main.js
|
||||
|
||||
import Swal from './vendor/sweetalert2.esm.js';
|
||||
import './vendor/sweetalert2.min.css';
|
||||
import anime from './vendor/anime.esm.js';
|
||||
// === 1. 导入通用组件 (这些是所有页面都可能用到的,保持静态导入) ===
|
||||
import SlidingTabs from './components/slidingTabs.js';
|
||||
import CustomSelect from './components/customSelect.js';
|
||||
@@ -14,6 +16,7 @@ const pageModules = {
|
||||
'dashboard': () => import('./pages/dashboard.js'),
|
||||
'keys': () => import('./pages/keys/index.js'),
|
||||
'logs': () => import('./pages/logs/index.js'),
|
||||
'chat': () => import('./pages/chat/index.js'),
|
||||
// 'settings': () => import('./pages/settings.js'), // 未来启用 settings 页面
|
||||
// 未来新增的页面,只需在这里添加一行映射,esbuild会自动处理
|
||||
};
|
||||
@@ -48,3 +51,5 @@ window.modalManager = modalManager;
|
||||
window.taskCenterManager = taskCenterManager;
|
||||
window.toastManager = toastManager;
|
||||
window.uiPatterns = uiPatterns;
|
||||
window.Swal = Swal;
|
||||
window.anime = anime;
|
||||
178
frontend/js/pages/chat/SessionManager.js
Normal file
178
frontend/js/pages/chat/SessionManager.js
Normal file
@@ -0,0 +1,178 @@
|
||||
// Filename: frontend/js/pages/chat/SessionManager.js
|
||||
|
||||
import { nanoid } from 'https://cdn.jsdelivr.net/npm/nanoid/nanoid.js';
|
||||
|
||||
const LOCAL_STORAGE_KEY = 'gemini_chat_state';
|
||||
|
||||
/**
|
||||
* Manages the state and persistence of chat sessions.
|
||||
* This class handles loading from/saving to localStorage,
|
||||
* and all operations like creating, switching, and deleting sessions.
|
||||
*/
|
||||
export class SessionManager {
|
||||
constructor() {
|
||||
this.state = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the manager by loading state from localStorage or creating a default state.
|
||||
*/
|
||||
init() {
|
||||
this._loadState();
|
||||
}
|
||||
|
||||
// --- Public API for state access ---
|
||||
|
||||
getSessions() {
|
||||
return this.state.sessions;
|
||||
}
|
||||
|
||||
getCurrentSessionId() {
|
||||
return this.state.currentSessionId;
|
||||
}
|
||||
|
||||
getCurrentSession() {
|
||||
return this.state.sessions.find(s => s.id === this.state.currentSessionId);
|
||||
}
|
||||
|
||||
// --- Public API for state mutation ---
|
||||
|
||||
/**
|
||||
* Creates a new, empty session and sets it as the current one.
|
||||
*/
|
||||
createSession() {
|
||||
const newSessionId = nanoid();
|
||||
const newSession = {
|
||||
id: newSessionId,
|
||||
name: '新会话',
|
||||
systemPrompt: '',
|
||||
messages: [],
|
||||
modelConfig: { model: 'gemini-2.0-flash-lite' },
|
||||
params: { temperature: 0.7 }
|
||||
};
|
||||
this.state.sessions.unshift(newSession);
|
||||
this.state.currentSessionId = newSessionId;
|
||||
this._saveState();
|
||||
}
|
||||
|
||||
/**
|
||||
* Switches the current session to the one with the given ID.
|
||||
* @param {string} sessionId The ID of the session to switch to.
|
||||
*/
|
||||
switchSession(sessionId) {
|
||||
if (this.state.currentSessionId === sessionId) return;
|
||||
this.state.currentSessionId = sessionId;
|
||||
this._saveState();
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a session by its ID.
|
||||
* @param {string} sessionId The ID of the session to delete.
|
||||
*/
|
||||
deleteSession(sessionId) {
|
||||
this.state.sessions = this.state.sessions.filter(s => s.id !== sessionId);
|
||||
|
||||
if (this.state.currentSessionId === sessionId) {
|
||||
this.state.currentSessionId = this.state.sessions[0]?.id || null;
|
||||
if (!this.state.currentSessionId) {
|
||||
this._createInitialState(); // Create a new one if all are deleted
|
||||
}
|
||||
}
|
||||
|
||||
this._saveState();
|
||||
}
|
||||
|
||||
/**
|
||||
* [NEW] Clears all messages from the currently active session.
|
||||
*/
|
||||
clearCurrentSession() {
|
||||
const currentSession = this.getCurrentSession();
|
||||
if (currentSession) {
|
||||
currentSession.messages = [];
|
||||
this._saveState();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a message to the current session and updates the session name if it's the first message.
|
||||
* @param {object} message The message object to add.
|
||||
* @returns {object} The session that was updated.
|
||||
*/
|
||||
addMessage(message) {
|
||||
const currentSession = this.getCurrentSession();
|
||||
if (currentSession) {
|
||||
if (currentSession.messages.length === 0 && message.role === 'user') {
|
||||
currentSession.name = message.content.substring(0, 30);
|
||||
}
|
||||
currentSession.messages.push(message);
|
||||
this._saveState();
|
||||
return currentSession;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
deleteMessage(messageId) {
|
||||
const currentSession = this.getCurrentSession();
|
||||
if (currentSession) {
|
||||
const messageIndex = currentSession.messages.findIndex(m => m.id === messageId);
|
||||
if (messageIndex > -1) {
|
||||
currentSession.messages.splice(messageIndex, 1);
|
||||
this._saveState();
|
||||
console.log(`Message ${messageId} deleted.`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
truncateMessagesAfter(messageId) {
|
||||
const currentSession = this.getCurrentSession();
|
||||
if (currentSession) {
|
||||
const messageIndex = currentSession.messages.findIndex(m => m.id === messageId);
|
||||
// Ensure the message exists and it's not already the last one
|
||||
if (messageIndex > -1 && messageIndex < currentSession.messages.length - 1) {
|
||||
currentSession.messages.splice(messageIndex + 1);
|
||||
this._saveState();
|
||||
console.log(`Truncated messages after ${messageId}.`);
|
||||
}
|
||||
}
|
||||
}
|
||||
// --- Private persistence methods ---
|
||||
|
||||
_saveState() {
|
||||
try {
|
||||
localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(this.state));
|
||||
} catch (error) {
|
||||
console.error("Failed to save session state:", error);
|
||||
}
|
||||
}
|
||||
|
||||
_loadState() {
|
||||
try {
|
||||
const stateString = localStorage.getItem(LOCAL_STORAGE_KEY);
|
||||
if (stateString) {
|
||||
this.state = JSON.parse(stateString);
|
||||
} else {
|
||||
this._createInitialState();
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to load session state, creating initial state:", error);
|
||||
this._createInitialState();
|
||||
}
|
||||
}
|
||||
|
||||
_createInitialState() {
|
||||
const initialSessionId = nanoid();
|
||||
this.state = {
|
||||
sessions: [{
|
||||
id: initialSessionId,
|
||||
name: '新会话',
|
||||
systemPrompt: '',
|
||||
messages: [],
|
||||
modelConfig: { model: 'gemini-2.0-flash-lite' },
|
||||
params: { temperature: 0.7 }
|
||||
}],
|
||||
currentSessionId: initialSessionId,
|
||||
settings: {}
|
||||
};
|
||||
this._saveState();
|
||||
}
|
||||
}
|
||||
113
frontend/js/pages/chat/chatSettings.js
Normal file
113
frontend/js/pages/chat/chatSettings.js
Normal file
@@ -0,0 +1,113 @@
|
||||
// Filename: frontend/js/pages/chat/chatSettings.js
|
||||
|
||||
export class ChatSettings {
|
||||
constructor(elements) {
|
||||
// [MODIFIED] Store the root elements passed from ChatPage
|
||||
this.elements = {};
|
||||
this.elements.root = elements; // Keep a reference to all elements
|
||||
|
||||
// [MODIFIED] Query for specific elements this class controls, relative to their panels
|
||||
this._initScopedDOMElements();
|
||||
|
||||
// Initialize panel states to ensure they are collapsed on load
|
||||
this.elements.quickSettingsPanel.style.gridTemplateRows = '0fr';
|
||||
this.elements.sessionParamsPanel.style.gridTemplateRows = '0fr';
|
||||
}
|
||||
|
||||
// [NEW] A dedicated method to find elements within their specific panels
|
||||
_initScopedDOMElements() {
|
||||
this.elements.quickSettingsPanel = this.elements.root.quickSettingsPanel;
|
||||
this.elements.sessionParamsPanel = this.elements.root.sessionParamsPanel;
|
||||
this.elements.toggleQuickSettingsBtn = this.elements.root.toggleQuickSettingsBtn;
|
||||
this.elements.toggleSessionParamsBtn = this.elements.root.toggleSessionParamsBtn;
|
||||
|
||||
// Query elements within the quick settings panel
|
||||
this.elements.btnGroups = this.elements.quickSettingsPanel.querySelectorAll('.btn-group');
|
||||
this.elements.directRoutingOptions = this.elements.quickSettingsPanel.querySelector('#direct-routing-options');
|
||||
|
||||
// Query elements within the session params panel
|
||||
this.elements.temperatureSlider = this.elements.sessionParamsPanel.querySelector('#temperature-slider');
|
||||
this.elements.temperatureValue = this.elements.sessionParamsPanel.querySelector('#temperature-value');
|
||||
this.elements.contextSlider = this.elements.sessionParamsPanel.querySelector('#context-slider');
|
||||
this.elements.contextValue = this.elements.sessionParamsPanel.querySelector('#context-value');
|
||||
}
|
||||
|
||||
init() {
|
||||
if (!this.elements.toggleQuickSettingsBtn) {
|
||||
console.warn("ChatSettings: Aborting initialization, required elements not found.");
|
||||
return;
|
||||
}
|
||||
this._initPanelToggleListeners();
|
||||
this._initButtonGroupListeners();
|
||||
this._initSliderListeners();
|
||||
}
|
||||
|
||||
_initPanelToggleListeners() {
|
||||
this.elements.toggleQuickSettingsBtn.addEventListener('click', () =>
|
||||
this._togglePanel(this.elements.quickSettingsPanel, this.elements.toggleQuickSettingsBtn)
|
||||
);
|
||||
this.elements.toggleSessionParamsBtn.addEventListener('click', () =>
|
||||
this._togglePanel(this.elements.sessionParamsPanel, this.elements.toggleSessionParamsBtn)
|
||||
);
|
||||
}
|
||||
|
||||
_initButtonGroupListeners() {
|
||||
// [FIXED] This logic is now guaranteed to work with the correctly scoped elements.
|
||||
this.elements.btnGroups.forEach(group => {
|
||||
group.addEventListener('click', (e) => {
|
||||
const button = e.target.closest('.btn-group-item');
|
||||
if (!button) return;
|
||||
|
||||
group.querySelectorAll('.btn-group-item').forEach(btn => btn.removeAttribute('data-active'));
|
||||
button.setAttribute('data-active', 'true');
|
||||
|
||||
if (button.dataset.group === 'routing-mode') {
|
||||
this._handleRoutingModeChange(button.dataset.value);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
_initSliderListeners() {
|
||||
// [FIXED] Add null checks for robustness, now that elements are queried scoped.
|
||||
if (this.elements.temperatureSlider) {
|
||||
this.elements.temperatureSlider.addEventListener('input', () => {
|
||||
this.elements.temperatureValue.textContent = parseFloat(this.elements.temperatureSlider.value).toFixed(1);
|
||||
});
|
||||
}
|
||||
if (this.elements.contextSlider) {
|
||||
this.elements.contextSlider.addEventListener('input', () => {
|
||||
this.elements.contextValue.textContent = `${this.elements.contextSlider.value}k`;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
_handleRoutingModeChange(selectedValue) {
|
||||
// [FIXED] This logic now correctly targets the scoped element.
|
||||
if (this.elements.directRoutingOptions) {
|
||||
if (selectedValue === 'direct') {
|
||||
this.elements.directRoutingOptions.classList.remove('hidden');
|
||||
} else {
|
||||
this.elements.directRoutingOptions.classList.add('hidden');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_togglePanel(panel, button) {
|
||||
const isExpanded = panel.hasAttribute('data-expanded');
|
||||
|
||||
this.elements.quickSettingsPanel.removeAttribute('data-expanded');
|
||||
this.elements.sessionParamsPanel.removeAttribute('data-expanded');
|
||||
this.elements.toggleQuickSettingsBtn.removeAttribute('data-active');
|
||||
this.elements.toggleSessionParamsBtn.removeAttribute('data-active');
|
||||
|
||||
this.elements.quickSettingsPanel.style.gridTemplateRows = '0fr';
|
||||
this.elements.sessionParamsPanel.style.gridTemplateRows = '0fr';
|
||||
|
||||
if (!isExpanded) {
|
||||
panel.setAttribute('data-expanded', 'true');
|
||||
button.setAttribute('data-active', 'true');
|
||||
panel.style.gridTemplateRows = '1fr';
|
||||
}
|
||||
}
|
||||
}
|
||||
545
frontend/js/pages/chat/index.js
Normal file
545
frontend/js/pages/chat/index.js
Normal file
@@ -0,0 +1,545 @@
|
||||
// Filename: frontend/js/pages/chat/index.js
|
||||
|
||||
import { nanoid } from '../../vendor/nanoid.js';
|
||||
import { uiPatterns } from '../../components/ui.js';
|
||||
import { apiFetch } from '../../services/api.js';
|
||||
import { marked } from '../../vendor/marked.min.js';
|
||||
import { SessionManager } from './SessionManager.js';
|
||||
import { ChatSettings } from './chatSettings.js';
|
||||
import CustomSelectV2 from '../../components/customSelectV2.js';
|
||||
|
||||
marked.use({ breaks: true, gfm: true });
|
||||
|
||||
function getCookie(name) {
|
||||
let cookieValue = null;
|
||||
if (document.cookie && document.cookie !== '') {
|
||||
const cookies = document.cookie.split(';');
|
||||
for (let i = 0; i < cookies.length; i++) {
|
||||
const cookie = cookies[i].trim();
|
||||
if (cookie.substring(0, name.length + 1) === (name + '=')) {
|
||||
cookieValue = decodeURIComponent(cookie.substring(name.length + 1));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return cookieValue;
|
||||
}
|
||||
|
||||
class ChatPage {
|
||||
constructor() {
|
||||
this.sessionManager = new SessionManager();
|
||||
this.isStreaming = false;
|
||||
this.elements = {};
|
||||
this.initialized = false;
|
||||
this.searchTerm = '';
|
||||
this.settingsManager = null;
|
||||
}
|
||||
|
||||
init() {
|
||||
if (!document.querySelector('[data-page-id="chat"]')) { return; }
|
||||
|
||||
this.sessionManager.init();
|
||||
|
||||
this.initialized = true;
|
||||
this._initDOMElements();
|
||||
this._initComponents();
|
||||
this._initEventListeners();
|
||||
this._render();
|
||||
console.log("ChatPage initialized. Session management is delegated.", this.sessionManager.state);
|
||||
}
|
||||
|
||||
_initDOMElements() {
|
||||
this.elements.chatScrollContainer = document.getElementById('chat-scroll-container');
|
||||
this.elements.chatMessagesContainer = document.getElementById('chat-messages-container');
|
||||
this.elements.messageForm = document.getElementById('message-form');
|
||||
this.elements.messageInput = document.getElementById('message-input');
|
||||
this.elements.sendBtn = document.getElementById('send-btn');
|
||||
this.elements.newSessionBtn = document.getElementById('new-session-btn');
|
||||
this.elements.sessionListContainer = document.getElementById('session-list-container');
|
||||
this.elements.chatHeaderTitle = document.querySelector('.chat-header-title');
|
||||
this.elements.clearSessionBtn = document.getElementById('clear-session-btn');
|
||||
this.elements.sessionSearchInput = document.getElementById('session-search-input');
|
||||
this.elements.toggleQuickSettingsBtn = document.getElementById('toggle-quick-settings');
|
||||
this.elements.toggleSessionParamsBtn = document.getElementById('toggle-session-params');
|
||||
this.elements.quickSettingsPanel = document.getElementById('quick-settings-panel');
|
||||
this.elements.sessionParamsPanel = document.getElementById('session-params-panel');
|
||||
this.elements.directRoutingOptions = document.getElementById('direct-routing-options');
|
||||
this.elements.btnGroups = document.querySelectorAll('.btn-group');
|
||||
this.elements.temperatureSlider = document.getElementById('temperature-slider');
|
||||
this.elements.temperatureValue = document.getElementById('temperature-value');
|
||||
this.elements.contextSlider = document.getElementById('context-slider');
|
||||
this.elements.contextValue = document.getElementById('context-value');
|
||||
this.elements.groupSelectContainer = document.getElementById('group-select-container');
|
||||
}
|
||||
// [NEW] A dedicated method for initializing complex UI components
|
||||
_initComponents() {
|
||||
if (this.elements.groupSelectContainer) {
|
||||
new CustomSelectV2(this.elements.groupSelectContainer);
|
||||
}
|
||||
// In the future, we will initialize the model select component here as well
|
||||
}
|
||||
|
||||
_initEventListeners() {
|
||||
// --- Initialize Settings Manager First ---
|
||||
this.settingsManager = new ChatSettings(this.elements);
|
||||
this.settingsManager.init();
|
||||
// --- Core Chat Event Listeners ---
|
||||
this.elements.messageForm.addEventListener('submit', (e) => { e.preventDefault(); this._handleSendMessage(); });
|
||||
this.elements.messageInput.addEventListener('keydown', (e) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); this._handleSendMessage(); } });
|
||||
this.elements.messageInput.addEventListener('input', () => this._autoResizeTextarea());
|
||||
this.elements.newSessionBtn.addEventListener('click', () => {
|
||||
this.sessionManager.createSession();
|
||||
this._render();
|
||||
this.elements.messageInput.focus();
|
||||
});
|
||||
this.elements.sessionListContainer.addEventListener('click', (e) => {
|
||||
const sessionItem = e.target.closest('[data-session-id]');
|
||||
const deleteBtn = e.target.closest('.delete-session-btn');
|
||||
|
||||
if (deleteBtn) {
|
||||
e.preventDefault();
|
||||
const sessionId = deleteBtn.closest('[data-session-id]').dataset.sessionId;
|
||||
this._handleDeleteSession(sessionId);
|
||||
} else if (sessionItem) {
|
||||
e.preventDefault();
|
||||
const sessionId = sessionItem.dataset.sessionId;
|
||||
this.sessionManager.switchSession(sessionId);
|
||||
this._render();
|
||||
this.elements.messageInput.focus();
|
||||
}
|
||||
});
|
||||
this.elements.clearSessionBtn.addEventListener('click', () => this._handleClearSession());
|
||||
this.elements.sessionSearchInput.addEventListener('input', (e) => {
|
||||
this.searchTerm = e.target.value.trim();
|
||||
this._renderSessionList();
|
||||
});
|
||||
this.elements.chatMessagesContainer.addEventListener('click', (e) => {
|
||||
const messageElement = e.target.closest('[data-message-id]');
|
||||
if (!messageElement) return;
|
||||
const messageId = messageElement.dataset.messageId;
|
||||
const copyBtn = e.target.closest('.action-copy');
|
||||
const deleteBtn = e.target.closest('.action-delete');
|
||||
const retryBtn = e.target.closest('.action-retry');
|
||||
if (copyBtn) {
|
||||
this._handleCopyMessage(messageId);
|
||||
} else if (deleteBtn) {
|
||||
this._handleDeleteMessage(messageId, e.target);
|
||||
} else if (retryBtn) {
|
||||
this._handleRetryMessage(messageId);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
_handleCopyMessage(messageId) {
|
||||
const currentSession = this.sessionManager.getCurrentSession();
|
||||
if (!currentSession) return;
|
||||
const message = currentSession.messages.find(m => m.id === messageId);
|
||||
if (!message || !message.content) {
|
||||
console.error("Message content not found for copying.");
|
||||
return;
|
||||
}
|
||||
// Handle cases where content might be HTML (like error messages)
|
||||
// by stripping tags to get plain text.
|
||||
let textToCopy = message.content;
|
||||
if (textToCopy.includes('<') && textToCopy.includes('>')) {
|
||||
const tempDiv = document.createElement('div');
|
||||
tempDiv.innerHTML = textToCopy;
|
||||
textToCopy = tempDiv.textContent || tempDiv.innerText || '';
|
||||
}
|
||||
navigator.clipboard.writeText(textToCopy)
|
||||
.then(() => {
|
||||
Swal.fire({
|
||||
toast: true,
|
||||
position: 'top-end',
|
||||
icon: 'success',
|
||||
title: '已复制',
|
||||
showConfirmButton: false,
|
||||
timer: 1500,
|
||||
customClass: {
|
||||
popup: `${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}`
|
||||
}
|
||||
});
|
||||
})
|
||||
.catch(err => {
|
||||
console.error('Failed to copy text: ', err);
|
||||
Swal.fire({
|
||||
toast: true,
|
||||
position: 'top-end',
|
||||
icon: 'error',
|
||||
title: '复制失败',
|
||||
showConfirmButton: false,
|
||||
timer: 1500,
|
||||
customClass: {
|
||||
popup: `${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}`
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
_handleDeleteMessage(messageId, targetElement) {
|
||||
// Remove any existing popover first to prevent duplicates
|
||||
const existingPopover = document.getElementById('delete-confirmation-popover');
|
||||
if (existingPopover) {
|
||||
existingPopover.remove();
|
||||
}
|
||||
// Create the popover element with your specified dimensions.
|
||||
const popover = document.createElement('div');
|
||||
popover.id = 'delete-confirmation-popover';
|
||||
// [MODIFIED] - Using your w-36, and adding flexbox classes for centering.
|
||||
popover.className = 'absolute z-50 p-3 w-45 border border-border rounded-md shadow-lg bg-background text-popover-foreground flex flex-col items-center';
|
||||
|
||||
// [MODIFIED] - Added an icon and classes for horizontal centering.
|
||||
popover.innerHTML = `
|
||||
<div class="flex items-center gap-2 mb-2">
|
||||
<i class="fas fa-exclamation-circle text-yellow-500"></i>
|
||||
<p class="text-sm">确认删除此消息吗?</p>
|
||||
</div>
|
||||
<div class="flex translate-x-12 gap-2 w-full">
|
||||
<button class="btn btn-secondary rounded-xs w-12 btn-xs popover-cancel">取消</button>
|
||||
<button class="btn btn-destructive rounded-xs w-12 btn-xs popover-confirm">确认</button>
|
||||
</div>
|
||||
`;
|
||||
document.body.appendChild(popover);
|
||||
// Position the popover above the clicked icon
|
||||
const iconRect = targetElement.closest('button').getBoundingClientRect();
|
||||
const popoverRect = popover.getBoundingClientRect();
|
||||
popover.style.top = `${window.scrollY + iconRect.top - popoverRect.height - 8}px`;
|
||||
popover.style.left = `${window.scrollX + iconRect.left + (iconRect.width / 2) - (popoverRect.width / 2)}px`;
|
||||
// Event listener to close the popover if clicked outside
|
||||
const outsideClickListener = (event) => {
|
||||
if (!popover.contains(event.target) && event.target !== targetElement) {
|
||||
popover.remove();
|
||||
document.removeEventListener('click', outsideClickListener);
|
||||
}
|
||||
};
|
||||
setTimeout(() => document.addEventListener('click', outsideClickListener), 0);
|
||||
// Event listeners for the buttons inside the popover
|
||||
popover.querySelector('.popover-confirm').addEventListener('click', () => {
|
||||
this.sessionManager.deleteMessage(messageId);
|
||||
this._renderChatMessages();
|
||||
this._renderSessionList();
|
||||
popover.remove();
|
||||
document.removeEventListener('click', outsideClickListener);
|
||||
});
|
||||
popover.querySelector('.popover-cancel').addEventListener('click', () => {
|
||||
popover.remove();
|
||||
document.removeEventListener('click', outsideClickListener);
|
||||
});
|
||||
}
|
||||
|
||||
_handleRetryMessage(messageId) {
|
||||
if (this.isStreaming) return; // Prevent retrying while a response is already generating
|
||||
const currentSession = this.sessionManager.getCurrentSession();
|
||||
if (!currentSession) return;
|
||||
|
||||
const message = currentSession.messages.find(m => m.id === messageId);
|
||||
if (!message) return;
|
||||
if (message.role === 'user') {
|
||||
// Logic for retrying from a user's prompt
|
||||
this.sessionManager.truncateMessagesAfter(messageId);
|
||||
} else if (message.role === 'assistant') {
|
||||
// Logic for regenerating an assistant's response (must be the last one)
|
||||
this.sessionManager.deleteMessage(messageId);
|
||||
}
|
||||
// After data manipulation, update the UI and trigger a new response
|
||||
this._renderChatMessages();
|
||||
this._renderSessionList();
|
||||
this._getAssistantResponse();
|
||||
}
|
||||
_autoResizeTextarea() {
|
||||
const el = this.elements.messageInput;
|
||||
el.style.height = 'auto';
|
||||
el.style.height = (el.scrollHeight) + 'px';
|
||||
}
|
||||
|
||||
_handleSendMessage() {
|
||||
if (this.isStreaming) return;
|
||||
const content = this.elements.messageInput.value.trim();
|
||||
if (!content) return;
|
||||
|
||||
const userMessage = { id: nanoid(), role: 'user', content: content };
|
||||
this.sessionManager.addMessage(userMessage);
|
||||
|
||||
this._renderChatMessages();
|
||||
this._renderSessionList();
|
||||
|
||||
this.elements.messageInput.value = '';
|
||||
this._autoResizeTextarea();
|
||||
this.elements.messageInput.focus();
|
||||
this._getAssistantResponse();
|
||||
}
|
||||
|
||||
async _getAssistantResponse() {
|
||||
this.isStreaming = true;
|
||||
this._setLoadingState(true);
|
||||
const currentSession = this.sessionManager.getCurrentSession();
|
||||
const assistantMessageId = nanoid();
|
||||
let finalAssistantMessage = { id: assistantMessageId, role: 'assistant', content: '' };
|
||||
// Step 1: Create and render a temporary UI placeholder for streaming.
|
||||
// [MODIFIED] The placeholder now uses the three-dot animation.
|
||||
const placeholderHtml = `
|
||||
<div class="flex items-start gap-4" data-message-id="${assistantMessageId}">
|
||||
<span class="flex h-8 w-8 shrink-0 items-center justify-center rounded-full bg-primary text-primary-foreground">
|
||||
<i class="fas fa-robot"></i>
|
||||
</span>
|
||||
<div class="flex-1 space-y-2">
|
||||
<div class="relative group rounded-lg p-5 bg-primary/10 border/20">
|
||||
<div class="prose prose-sm max-w-none text-foreground message-content">
|
||||
<div class="flex items-center gap-1">
|
||||
<span class="h-2 w-2 bg-foreground/50 rounded-full animate-bounce" style="animation-delay: 0s;"></span>
|
||||
<span class="h-2 w-2 bg-foreground/50 rounded-full animate-bounce" style="animation-delay: 0.1s;"></span>
|
||||
<span class="h-2 w-2 bg-foreground/50 rounded-full animate-bounce" style="animation-delay: 0.2s;"></span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>`;
|
||||
this.elements.chatMessagesContainer.insertAdjacentHTML('beforeend', placeholderHtml);
|
||||
this._scrollToBottom();
|
||||
const assistantMessageContentEl = this.elements.chatMessagesContainer.querySelector(`[data-message-id="${assistantMessageId}"] .message-content`);
|
||||
try {
|
||||
const token = getCookie('gemini_admin_session');
|
||||
const headers = { 'Authorization': `Bearer ${token}` };
|
||||
const response = await apiFetch('/v1/chat/completions', {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
model: currentSession.modelConfig.model,
|
||||
messages: currentSession.messages.filter(m => m.content).map(({ role, content }) => ({ role, content })),
|
||||
stream: true,
|
||||
})
|
||||
});
|
||||
if (!response.body) throw new Error("Response body is null.");
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
const lines = chunk.split('\n').filter(line => line.trim().startsWith('data:'));
|
||||
|
||||
for (const line of lines) {
|
||||
const dataStr = line.replace(/^data: /, '').trim();
|
||||
if (dataStr !== '[DONE]') {
|
||||
try {
|
||||
const data = JSON.parse(dataStr);
|
||||
const deltaContent = data.choices[0]?.delta?.content;
|
||||
if (deltaContent) {
|
||||
finalAssistantMessage.content += deltaContent;
|
||||
assistantMessageContentEl.innerHTML = marked.parse(finalAssistantMessage.content);
|
||||
this._scrollToBottom();
|
||||
}
|
||||
} catch (e) { /* ignore malformed JSON */ }
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Fetch stream error:', error);
|
||||
const errorMessage = error.rawMessageFromServer || error.message;
|
||||
finalAssistantMessage.content = `<span class="text-red-500">请求失败: ${errorMessage}</span>`;
|
||||
} finally {
|
||||
this.sessionManager.addMessage(finalAssistantMessage);
|
||||
this._renderChatMessages();
|
||||
this._renderSessionList();
|
||||
this.isStreaming = false;
|
||||
this._setLoadingState(false);
|
||||
this.elements.messageInput.focus();
|
||||
}
|
||||
}
|
||||
|
||||
_renderMessage(message, replace = false, isLastMessage = false) {
|
||||
let contentHtml;
|
||||
if (message.role === 'user') {
|
||||
const escapedContent = message.content.replace(/</g, "<").replace(/>/g, ">");
|
||||
contentHtml = `<p class="text-sm text-foreground message-content">${escapedContent.replace(/\n/g, '<br>')}</p>`;
|
||||
} else {
|
||||
// [FIXED] Simplified logic: if it's an assistant message, it either has real content or error HTML.
|
||||
// The isStreamingPlaceholder case is now handled differently and removed from here.
|
||||
const isErrorHtml = message.content && message.content.includes('<span class="text-red-500">');
|
||||
contentHtml = isErrorHtml ?
|
||||
`<div class="message-content">${message.content}</div>` :
|
||||
`<div class="prose prose-sm max-w-none text-foreground message-content">${marked.parse(message.content || '')}</div>`;
|
||||
}
|
||||
|
||||
// [FIXED] No special handling for streaming placeholders needed anymore.
|
||||
// If a message has content, it gets actions. An error message has content, so it will get actions.
|
||||
let actionsHtml = '';
|
||||
let retryButton = '';
|
||||
if (message.role === 'user') {
|
||||
retryButton = `
|
||||
<button class="btn btn-ghost btn-icon w-6 h-6 hover:text-sky-500 action-retry rounded-full" title="从此消息重新生成">
|
||||
<i class="fas fa-redo text-xs"></i>
|
||||
</button>`;
|
||||
} else if (message.role === 'assistant' && isLastMessage) {
|
||||
// This now correctly applies to final error messages too.
|
||||
retryButton = `
|
||||
<button class="btn btn-ghost btn-icon w-6 h-6 hover:text-sky-500 action-retry rounded-full" title="重新生成回答">
|
||||
<i class="fas fa-redo text-xs"></i>
|
||||
</button>`;
|
||||
}
|
||||
const toolbarBaseClasses = "message-actions flex items-center gap-1 transition-opacity duration-200";
|
||||
const toolbarPositionClass = isLastMessage ? "mt-2" : "absolute bottom-2.5 right-2.5";
|
||||
const visibilityClass = isLastMessage ? "" : "opacity-0 group-hover:opacity-100";
|
||||
|
||||
actionsHtml = `
|
||||
<div class="${toolbarBaseClasses} ${toolbarPositionClass} ${visibilityClass}">
|
||||
${retryButton}
|
||||
<button class="btn btn-ghost btn-icon w-6 h-6 hover:text-sky-500 action-copy rounded-full" title="复制">
|
||||
<i class="far fa-copy text-xs"></i>
|
||||
</button>
|
||||
<button class="btn btn-ghost btn-icon w-6 h-6 hover:text-sky-500 action-delete rounded-full" title="删除">
|
||||
<i class="far fa-trash-alt text-xs"></i>
|
||||
</button>
|
||||
</div>
|
||||
`;
|
||||
|
||||
const messageBubbleClasses = `relative group rounded-lg p-5 ${message.role === 'user' ? 'bg-muted' : 'bg-primary/10 border/20'}`;
|
||||
const messageHtml = `
|
||||
<div class="flex items-start gap-4" data-message-id="${message.id}">
|
||||
<span class="flex h-8 w-8 shrink-0 items-center justify-center rounded-full ${message.role === 'user' ? 'bg-secondary text-secondary-foreground' : 'bg-primary text-primary-foreground'}">
|
||||
<i class="fas ${message.role === 'user' ? 'fa-user' : 'fa-robot'}"></i>
|
||||
</span>
|
||||
<div class="flex-1 space-y-2">
|
||||
<div class="${messageBubbleClasses}">
|
||||
${contentHtml}
|
||||
${actionsHtml}
|
||||
</div>
|
||||
</div>
|
||||
</div>`;
|
||||
|
||||
const existingElement = this.elements.chatMessagesContainer.querySelector(`[data-message-id="${message.id}"]`);
|
||||
if (replace && existingElement) {
|
||||
existingElement.outerHTML = messageHtml;
|
||||
} else if (!existingElement) {
|
||||
this.elements.chatMessagesContainer.insertAdjacentHTML('beforeend', messageHtml);
|
||||
}
|
||||
|
||||
if (!replace) { this._scrollToBottom(); }
|
||||
}
|
||||
|
||||
_scrollToBottom() {
|
||||
this.elements.chatScrollContainer.scrollTop = this.elements.chatScrollContainer.scrollHeight;
|
||||
}
|
||||
|
||||
_render() {
|
||||
this._renderSessionList();
|
||||
this._renderChatMessages();
|
||||
this._renderChatHeader();
|
||||
}
|
||||
|
||||
_renderSessionList() {
|
||||
let sessions = this.sessionManager.getSessions();
|
||||
const currentSessionId = this.sessionManager.getCurrentSessionId();
|
||||
if (this.searchTerm) {
|
||||
const lowerCaseSearchTerm = this.searchTerm.toLowerCase();
|
||||
sessions = sessions.filter(session => {
|
||||
const titleMatch = session.name.toLowerCase().includes(lowerCaseSearchTerm);
|
||||
const messageMatch = session.messages.some(message =>
|
||||
message.content && message.content.toLowerCase().includes(lowerCaseSearchTerm)
|
||||
);
|
||||
return titleMatch || messageMatch;
|
||||
});
|
||||
}
|
||||
|
||||
this.elements.sessionListContainer.innerHTML = sessions.map(session => {
|
||||
const isActive = session.id === currentSessionId;
|
||||
const lastMessage = session.messages.length > 0 ? session.messages[session.messages.length - 1].content : '新会话';
|
||||
|
||||
return `
|
||||
<div class="relative group flex items-center" data-session-id="${session.id}">
|
||||
<a href="#" class="grow flex flex-col items-start gap-2 rounded-lg p-3 text-left text-sm transition-all hover:bg-accent ${isActive ? 'bg-accent' : ''} min-w-0">
|
||||
<div class="w-full font-semibold truncate pr-2">${session.name}</div>
|
||||
<div class="w-full text-xs text-muted-foreground line-clamp-2">${session.messages.length > 0 ? (lastMessage.includes('<span class="text-red-500">') ? '[请求失败]' : lastMessage) : '新会话'}</div>
|
||||
</a>
|
||||
<button class="delete-session-btn absolute top-1/2 -translate-y-1/2 right-2 w-6 h-6 flex items-center justify-center rounded-full bg-muted text-muted-foreground opacity-0 group-hover:opacity-100 hover:bg-destructive/80 hover:text-destructive-foreground transition-all" aria-label="删除会话">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
</div>
|
||||
`;
|
||||
}).join('');
|
||||
}
|
||||
|
||||
_renderChatMessages() {
|
||||
this.elements.chatMessagesContainer.innerHTML = '';
|
||||
const currentSession = this.sessionManager.getCurrentSession();
|
||||
if (currentSession) {
|
||||
const messages = currentSession.messages;
|
||||
const lastMessageIndex = messages.length > 0 ? messages.length - 1 : -1;
|
||||
|
||||
messages.forEach((message, index) => {
|
||||
const isLastMessage = (index === lastMessageIndex);
|
||||
this._renderMessage(message, false, isLastMessage);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
_renderChatHeader() {
|
||||
const currentSession = this.sessionManager.getCurrentSession();
|
||||
if (currentSession && this.elements.chatHeaderTitle) {
|
||||
this.elements.chatHeaderTitle.textContent = currentSession.name;
|
||||
}
|
||||
}
|
||||
|
||||
_setLoadingState(isLoading) {
|
||||
this.elements.messageInput.disabled = isLoading;
|
||||
this.elements.sendBtn.disabled = isLoading;
|
||||
if (isLoading) {
|
||||
uiPatterns.setButtonLoading(this.elements.sendBtn);
|
||||
} else {
|
||||
uiPatterns.clearButtonLoading(this.elements.sendBtn);
|
||||
}
|
||||
}
|
||||
_handleClearSession() {
|
||||
Swal.fire({
|
||||
width: '22rem',
|
||||
backdrop: `rgba(0,0,0,0.5)`,
|
||||
heightAuto: false,
|
||||
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
|
||||
title: '确定要清空会话吗?',
|
||||
text: '当前会话的所有聊天记录将被删除,但会话本身会保留。',
|
||||
showCancelButton: true,
|
||||
confirmButtonText: '确认清空',
|
||||
cancelButtonText: '取消',
|
||||
reverseButtons: false,
|
||||
confirmButtonColor: '#ef4444',
|
||||
cancelButtonColor: '#6b7280',
|
||||
focusConfirm: false,
|
||||
focusCancel: true,
|
||||
}).then((result) => {
|
||||
if (result.isConfirmed) {
|
||||
this.sessionManager.clearCurrentSession(); // This method needs to be added to SessionManager
|
||||
this._render();
|
||||
}
|
||||
});
|
||||
}
|
||||
_handleDeleteSession(sessionId) {
|
||||
Swal.fire({
|
||||
width: '22rem',
|
||||
backdrop: `rgba(0,0,0,0.5)`,
|
||||
heightAuto: false,
|
||||
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
|
||||
title: '确定要删除吗?',
|
||||
text: '此会话的所有记录将被永久删除,无法撤销。',
|
||||
showCancelButton: true,
|
||||
confirmButtonText: '确认删除',
|
||||
cancelButtonText: '取消',
|
||||
reverseButtons: false,
|
||||
confirmButtonColor: '#ef4444',
|
||||
cancelButtonColor: '#6b7280',
|
||||
focusConfirm: false,
|
||||
focusCancel: true,
|
||||
}).then((result) => {
|
||||
if (result.isConfirmed) {
|
||||
this.sessionManager.deleteSession(sessionId);
|
||||
this._render();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export default function() {
|
||||
const page = new ChatPage();
|
||||
page.init();
|
||||
}
|
||||
@@ -427,7 +427,7 @@ class ApiKeyList {
|
||||
contentHtml = `
|
||||
<div class="task-item-main">
|
||||
<div class="task-item-icon-summary text-red-500"><i class="fas fa-exclamation-triangle"></i></div>
|
||||
<div class="task-item-content flex-grow">
|
||||
<div class="task-item-content grow">
|
||||
<p class="task-item-title">验证任务出错: ${maskedKey}</p>
|
||||
<p class="task-item-status text-red-500 truncate" title="${safeError}">${safeError}</p>
|
||||
</div>
|
||||
@@ -443,7 +443,7 @@ class ApiKeyList {
|
||||
contentHtml = `
|
||||
<div class="task-item-main">
|
||||
<div class="task-item-icon-summary"><i class="${iconClass}"></i></div>
|
||||
<div class="task-item-content flex-grow">
|
||||
<div class="task-item-content grow">
|
||||
<p class="task-item-title">${title}: ${maskedKey}</p>
|
||||
<p class="task-item-status truncate" title="${safeMessage}">${safeMessage}</p>
|
||||
</div>
|
||||
@@ -455,7 +455,7 @@ class ApiKeyList {
|
||||
contentHtml = `
|
||||
<div class="task-item-main gap-3">
|
||||
<div class="task-item-icon task-item-icon-running"><i class="fas fa-spinner animate-spin"></i></div>
|
||||
<div class="task-item-content flex-grow">
|
||||
<div class="task-item-content grow">
|
||||
<p class="task-item-title">正在验证: ${maskedKey}</p>
|
||||
<p class="task-item-status">运行中... (${data.processed}/${data.total})</p>
|
||||
</div>
|
||||
@@ -495,7 +495,7 @@ class ApiKeyList {
|
||||
data-mapping-id="${mappingId}">
|
||||
<input type="checkbox" class="api-key-checkbox h-4 w-4 rounded border-zinc-300 text-blue-600 focus:ring-blue-500 shrink-0">
|
||||
<span data-status-indicator class="w-2 h-2 rounded-full shrink-0"></span>
|
||||
<div class="flex-grow min-w-0">
|
||||
<div class="grow min-w-0">
|
||||
<p class="font-mono text-xs font-semibold truncate">${maskedKey}</p>
|
||||
<p class="text-xs text-zinc-400 mt-1">失败: ${errorCount} 次</p>
|
||||
</div>
|
||||
@@ -854,7 +854,7 @@ class ApiKeyList {
|
||||
contentHtml = `
|
||||
<div class="task-item-main gap-3">
|
||||
<div class="task-item-icon task-item-icon-running"><i class="fas fa-spinner animate-spin"></i></div>
|
||||
<div class="task-item-content flex-grow">
|
||||
<div class="task-item-content grow">
|
||||
<p class="task-item-title">批量验证 ${data.total} 个Key</p>
|
||||
<p class="task-item-status">运行中... (${data.processed}/${data.total})</p>
|
||||
</div>
|
||||
@@ -867,7 +867,7 @@ class ApiKeyList {
|
||||
contentHtml = `
|
||||
<div class="task-item-main">
|
||||
<div class="task-item-icon-summary text-red-500"><i class="fas fa-exclamation-triangle"></i></div>
|
||||
<div class="task-item-content flex-grow">
|
||||
<div class="task-item-content grow">
|
||||
<p class="task-item-title">批量验证任务出错</p>
|
||||
<p class="task-item-status text-red-500 truncate" title="${data.error}">${data.error}</p>
|
||||
</div>
|
||||
@@ -893,7 +893,7 @@ class ApiKeyList {
|
||||
return `
|
||||
<div class="flex items-start text-xs">
|
||||
<i class="fas fa-check-circle text-green-500 mt-0.5 mr-2"></i>
|
||||
<div class="flex-grow">
|
||||
<div class="grow">
|
||||
<p class="font-mono">${maskedKey}</p>
|
||||
<p class="text-zinc-400">${safeMessage}</p>
|
||||
</div>
|
||||
@@ -902,7 +902,7 @@ class ApiKeyList {
|
||||
return `
|
||||
<div class="flex items-start text-xs">
|
||||
<i class="fas fa-times-circle text-red-500 mt-0.5 mr-2"></i>
|
||||
<div class="flex-grow">
|
||||
<div class="grow">
|
||||
<p class="font-mono">${maskedKey}</p>
|
||||
<p class="text-zinc-400">${safeMessage}</p>
|
||||
</div>
|
||||
@@ -913,7 +913,7 @@ class ApiKeyList {
|
||||
contentHtml = `
|
||||
<div class="task-item-main">
|
||||
<div class="task-item-icon-summary"><i class="${overallIconClass}"></i></div>
|
||||
<div class="task-item-content flex-grow">
|
||||
<div class="task-item-content grow">
|
||||
<div class="flex justify-between items-center cursor-pointer" data-task-toggle>
|
||||
<p class="task-item-title">${summaryTitle}</p>
|
||||
<i class="fas fa-chevron-down task-toggle-icon"></i>
|
||||
@@ -1135,6 +1135,8 @@ class ApiKeyList {
|
||||
const actionConfig = this._getQuickActionConfig(action);
|
||||
if (actionConfig && actionConfig.requiresConfirm) {
|
||||
const result = await Swal.fire({
|
||||
backdrop: `rgba(0,0,0,0.5)`,
|
||||
heightAuto: false,
|
||||
target: '#main-content-wrapper',
|
||||
title: '请确认操作',
|
||||
html: actionConfig.confirmText,
|
||||
@@ -1246,7 +1248,7 @@ class ApiKeyList {
|
||||
contentHtml = `
|
||||
<div class="task-item-main gap-3">
|
||||
<div class="task-item-icon task-item-icon-running"><i class="fas fa-spinner animate-spin"></i></div>
|
||||
<div class="task-item-content flex-grow">
|
||||
<div class="task-item-content grow">
|
||||
<p class="task-item-title">${title}</p>
|
||||
<p class="task-item-status">运行中... (${data.processed}/${data.total})</p>
|
||||
</div>
|
||||
@@ -1256,7 +1258,7 @@ class ApiKeyList {
|
||||
contentHtml = `
|
||||
<div class="task-item-main">
|
||||
<div class="task-item-icon-summary text-red-500"><i class="fas fa-exclamation-triangle"></i></div>
|
||||
<div class="task-item-content flex-grow">
|
||||
<div class="task-item-content grow">
|
||||
<p class="task-item-title">${title}任务出错</p>
|
||||
<p class="task-item-status text-red-500 truncate" title="${safeError}">${safeError}</p>
|
||||
</div>
|
||||
@@ -1293,7 +1295,7 @@ class ApiKeyList {
|
||||
contentHtml = `
|
||||
<div class="task-item-main">
|
||||
<div class="task-item-icon-summary"><i class="${iconClass}"></i></div>
|
||||
<div class="task-item-content flex-grow">
|
||||
<div class="task-item-content grow">
|
||||
<p class="task-item-title">${title}</p>
|
||||
<p class="task-item-status truncate" title="${safeSummary}">${safeSummary}</p>
|
||||
</div>
|
||||
|
||||
151
frontend/js/pages/logs/batchActions.js
Normal file
151
frontend/js/pages/logs/batchActions.js
Normal file
@@ -0,0 +1,151 @@
|
||||
// Filename: frontend/js/components/BatchActions.js
|
||||
import { apiFetchJson } from '../../services/api.js';
|
||||
|
||||
// 存储对 LogsPage 实例的引用
|
||||
let logsPageInstance = null;
|
||||
|
||||
// 存储对 DOM 元素的引用
|
||||
const elements = {
|
||||
batchActionsBtn: null,
|
||||
batchActionsMenu: null,
|
||||
deleteSelectedLogsBtn: null,
|
||||
clearAllLogsBtn: null,
|
||||
};
|
||||
|
||||
// 用于处理页面点击以关闭菜单的绑定函数
|
||||
const handleDocumentClick = (event) => {
|
||||
if (!elements.batchActionsMenu.contains(event.target) && !elements.batchActionsBtn.contains(event.target)) {
|
||||
closeBatchActionsMenu();
|
||||
}
|
||||
};
|
||||
|
||||
// 关闭菜单的逻辑
|
||||
function closeBatchActionsMenu() {
|
||||
if (elements.batchActionsMenu && !elements.batchActionsMenu.classList.contains('hidden')) {
|
||||
elements.batchActionsMenu.classList.remove('opacity-100', 'scale-100');
|
||||
elements.batchActionsMenu.classList.add('opacity-0', 'scale-95');
|
||||
setTimeout(() => {
|
||||
elements.batchActionsMenu.classList.add('hidden');
|
||||
}, 100);
|
||||
document.removeEventListener('click', handleDocumentClick);
|
||||
}
|
||||
}
|
||||
|
||||
// 切换菜单显示/隐藏
|
||||
function handleBatchActionsToggle(event) {
|
||||
event.stopPropagation();
|
||||
const isHidden = elements.batchActionsMenu.classList.contains('hidden');
|
||||
if (isHidden) {
|
||||
elements.batchActionsMenu.classList.remove('hidden', 'opacity-0', 'scale-95');
|
||||
elements.batchActionsMenu.classList.add('opacity-100', 'scale-100');
|
||||
document.addEventListener('click', handleDocumentClick);
|
||||
} else {
|
||||
closeBatchActionsMenu();
|
||||
}
|
||||
}
|
||||
|
||||
// 处理删除选中日志的逻辑
|
||||
async function handleDeleteSelectedLogs() {
|
||||
closeBatchActionsMenu();
|
||||
const selectedIds = Array.from(logsPageInstance.state.selectedLogIds);
|
||||
if (selectedIds.length === 0) return;
|
||||
|
||||
Swal.fire({
|
||||
width: '20rem',
|
||||
backdrop: `rgba(0,0,0,0.5)`,
|
||||
heightAuto: false,
|
||||
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
|
||||
title: '确认删除',
|
||||
text: `您确定要删除选定的 ${selectedIds.length} 条日志吗?此操作不可撤销。`,
|
||||
icon: 'warning',
|
||||
showCancelButton: true,
|
||||
confirmButtonText: '确认删除',
|
||||
cancelButtonText: '取消',
|
||||
reverseButtons: false,
|
||||
confirmButtonColor: '#ef4444',
|
||||
cancelButtonColor: '#6b7280',
|
||||
focusCancel: true,
|
||||
target: '#main-content-wrapper',
|
||||
}).then(async (result) => {
|
||||
if (result.isConfirmed) {
|
||||
try {
|
||||
const idsQueryString = selectedIds.join(',');
|
||||
const url = `/admin/logs?ids=${idsQueryString}`;
|
||||
const { success, message } = await apiFetchJson(url, { method: 'DELETE' });
|
||||
if (success) {
|
||||
Swal.fire({ toast: true, position: 'top-end', icon: 'success', title: '删除成功', showConfirmButton: false, timer: 2000, timerProgressBar: true });
|
||||
logsPageInstance.loadAndRenderLogs(); // 使用实例刷新列表
|
||||
} else {
|
||||
throw new Error(message || '删除失败,请稍后重试。');
|
||||
}
|
||||
} catch (error) {
|
||||
Swal.fire({ icon: 'error', title: '操作失败', text: error.message, target: '#main-content-wrapper' });
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// [NEW] 处理清空所有日志的逻辑
|
||||
async function handleClearAllLogs() {
|
||||
closeBatchActionsMenu();
|
||||
|
||||
Swal.fire({
|
||||
width: '20rem',
|
||||
backdrop: `rgba(0,0,0,0.5)`,
|
||||
heightAuto: false,
|
||||
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
|
||||
title: '危险操作确认',
|
||||
html: `您确定要<strong class="text-red-500">清空全部</strong>日志吗?<br>此操作不可撤销!`,
|
||||
icon: 'warning',
|
||||
showCancelButton: true,
|
||||
confirmButtonText: '确认清空',
|
||||
cancelButtonText: '取消',
|
||||
reverseButtons: false,
|
||||
confirmButtonColor: '#ef4444',
|
||||
cancelButtonColor: '#6b7280',
|
||||
focusCancel: true,
|
||||
target: '#main-content-wrapper',
|
||||
}).then(async (result) => {
|
||||
if (result.isConfirmed) {
|
||||
try {
|
||||
const url = `/admin/logs/all`; // 后端清空所有日志的接口
|
||||
const { success, message } = await apiFetchJson(url, { method: 'DELETE' });
|
||||
if (success) {
|
||||
Swal.fire({ toast: true, position: 'top-end', icon: 'success', title: '清空成功', showConfirmButton: false, timer: 2000, timerProgressBar: true });
|
||||
logsPageInstance.loadAndRenderLogs(); // 刷新列表
|
||||
} else {
|
||||
throw new Error(message || '清空失败,请稍后重试。');
|
||||
}
|
||||
} catch (error) {
|
||||
Swal.fire({ icon: 'error', title: '操作失败', text: error.message, target: '#main-content-wrapper' });
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 初始化批量操作模块
|
||||
* @param {object} logsPage - LogsPage 类的实例
|
||||
*/
|
||||
export function initBatchActions(logsPage) {
|
||||
logsPageInstance = logsPage;
|
||||
|
||||
// 查询所有需要的 DOM 元素
|
||||
elements.batchActionsBtn = document.getElementById('batch-actions-btn');
|
||||
elements.batchActionsMenu = document.getElementById('batch-actions-menu');
|
||||
elements.deleteSelectedLogsBtn = document.getElementById('delete-selected-logs-btn');
|
||||
elements.clearAllLogsBtn = document.getElementById('clear-all-logs-btn'); // [NEW] 查询新按钮
|
||||
|
||||
if (!elements.batchActionsBtn) return; // 如果找不到主按钮,则不进行任何操作
|
||||
|
||||
// 绑定事件监听器
|
||||
elements.batchActionsBtn.addEventListener('click', handleBatchActionsToggle);
|
||||
if (elements.deleteSelectedLogsBtn) {
|
||||
elements.deleteSelectedLogsBtn.addEventListener('click', handleDeleteSelectedLogs);
|
||||
}
|
||||
if (elements.clearAllLogsBtn) { // [NEW] 绑定新按钮的事件
|
||||
elements.clearAllLogsBtn.addEventListener('click', handleClearAllLogs);
|
||||
}
|
||||
}
|
||||
// [NEW] - END
|
||||
@@ -1,13 +1,19 @@
|
||||
// Filename: frontend/js/pages/logs/index.js
|
||||
|
||||
import { apiFetchJson } from '../../services/api.js';
|
||||
import LogList from './logList.js';
|
||||
import CustomSelectV2 from '../../components/customSelectV2.js';
|
||||
import { debounce } from '../../utils/utils.js';
|
||||
import FilterPopover from '../../components/filterPopover.js';
|
||||
import { STATIC_ERROR_MAP, STATUS_CODE_MAP } from './logList.js';
|
||||
import SystemLogTerminal from './systemLog.js';
|
||||
import { initBatchActions } from './batchActions.js';
|
||||
import flatpickr from '../../vendor/flatpickr.js';
|
||||
import LogSettingsModal from './logSettingsModal.js';
|
||||
|
||||
// [最终版] 创建一个共享的数据仓库,用于缓存 Groups 和 Keys
|
||||
const dataStore = {
|
||||
groups: new Map(),
|
||||
keys: new Map(),
|
||||
};
|
||||
groups: new Map(),
|
||||
keys: new Map(),
|
||||
};
|
||||
|
||||
class LogsPage {
|
||||
constructor() {
|
||||
@@ -15,32 +21,597 @@ class LogsPage {
|
||||
logs: [],
|
||||
pagination: { page: 1, pages: 1, total: 0, page_size: 20 },
|
||||
isLoading: true,
|
||||
filters: { page: 1, page_size: 20 }
|
||||
filters: {
|
||||
page: 1,
|
||||
page_size: 20,
|
||||
q: '',
|
||||
key_ids: new Set(),
|
||||
group_ids: new Set(),
|
||||
error_types: new Set(),
|
||||
status_codes: new Set(),
|
||||
start_date: null,
|
||||
end_date: null,
|
||||
},
|
||||
selectedLogIds: new Set(),
|
||||
currentView: 'error',
|
||||
};
|
||||
|
||||
this.elements = {
|
||||
tableBody: document.getElementById('logs-table-body'),
|
||||
tabsContainer: document.querySelector('[data-sliding-tabs-container]'),
|
||||
contentContainer: document.getElementById('log-content-container'),
|
||||
errorFilters: document.getElementById('error-logs-filters'),
|
||||
systemControls: document.getElementById('system-logs-controls'),
|
||||
errorTemplate: document.getElementById('error-logs-template'),
|
||||
systemTemplate: document.getElementById('system-logs-template'),
|
||||
settingsBtn: document.querySelector('button[aria-label="日志设置"]'),
|
||||
};
|
||||
|
||||
this.initialized = !!this.elements.tableBody;
|
||||
|
||||
this.initialized = !!this.elements.contentContainer;
|
||||
if (this.initialized) {
|
||||
this.logList = new LogList(this.elements.tableBody, dataStore);
|
||||
this.logList = null;
|
||||
this.systemLogTerminal = null;
|
||||
this.debouncedLoadAndRender = debounce(() => this.loadAndRenderLogs(), 300);
|
||||
this.fp = null;
|
||||
this.themeObserver = null;
|
||||
this.settingsModal = null;
|
||||
this.currentSettings = {};
|
||||
}
|
||||
}
|
||||
|
||||
async init() {
|
||||
if (!this.initialized) return;
|
||||
this.initEventListeners();
|
||||
// 页面初始化:先加载群组,再加载日志
|
||||
this._initPermanentEventListeners();
|
||||
await this.loadCurrentSettings();
|
||||
this._initSettingsModal();
|
||||
await this.loadGroupsOnce();
|
||||
await this.loadAndRenderLogs();
|
||||
this.state.currentView = null;
|
||||
this.switchToView('error');
|
||||
}
|
||||
|
||||
initEventListeners() { /* 分页和筛选的事件监听器 */ }
|
||||
_initSettingsModal() {
|
||||
if (!this.elements.settingsBtn) return;
|
||||
this.settingsModal = new LogSettingsModal({
|
||||
onSave: this.handleSaveSettings.bind(this)
|
||||
});
|
||||
this.elements.settingsBtn.addEventListener('click', () => {
|
||||
|
||||
const settingsForModal = {
|
||||
log_level: this.currentSettings.log_level,
|
||||
auto_cleanup: {
|
||||
enabled: this.currentSettings.log_auto_cleanup_enabled,
|
||||
retention_days: this.currentSettings.log_auto_cleanup_retention_days,
|
||||
exec_time: this.currentSettings.log_auto_cleanup_time,
|
||||
interval: 'daily',
|
||||
}
|
||||
};
|
||||
this.settingsModal.open(settingsForModal);
|
||||
});
|
||||
}
|
||||
|
||||
async loadCurrentSettings() {
|
||||
try {
|
||||
const { success, data } = await apiFetchJson('/admin/settings');
|
||||
if (success) {
|
||||
this.currentSettings = data;
|
||||
} else {
|
||||
console.error('Failed to load settings from server.');
|
||||
this.currentSettings = { log_auto_cleanup_time: '04:05' };
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load log settings:', error);
|
||||
this.currentSettings = { log_auto_cleanup_time: '04:05' };
|
||||
}
|
||||
}
|
||||
|
||||
async handleSaveSettings(settingsData) {
|
||||
const partialPayload = {
|
||||
"log_level": settingsData.log_level,
|
||||
"log_auto_cleanup_enabled": settingsData.auto_cleanup.enabled,
|
||||
"log_auto_cleanup_time": settingsData.auto_cleanup.exec_time,
|
||||
};
|
||||
if (settingsData.auto_cleanup.enabled) {
|
||||
let retentionDays = settingsData.auto_cleanup.retention_days;
|
||||
if (retentionDays === null || retentionDays <= 0) {
|
||||
retentionDays = 30;
|
||||
}
|
||||
partialPayload.log_auto_cleanup_retention_days = retentionDays;
|
||||
}
|
||||
|
||||
console.log('Sending PARTIAL settings update to /admin/settings:', partialPayload);
|
||||
try {
|
||||
const { success, message } = await apiFetchJson('/admin/settings', {
|
||||
method: 'PUT',
|
||||
body: JSON.stringify(partialPayload)
|
||||
});
|
||||
if (!success) {
|
||||
throw new Error(message || 'Failed to save settings');
|
||||
}
|
||||
|
||||
Object.assign(this.currentSettings, partialPayload);
|
||||
} catch (error) {
|
||||
console.error('Error saving log settings:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
_initPermanentEventListeners() {
|
||||
this.elements.tabsContainer.addEventListener('click', (event) => {
|
||||
const tabItem = event.target.closest('[data-tab-target]');
|
||||
if (!tabItem) return;
|
||||
event.preventDefault();
|
||||
const viewName = tabItem.dataset.tabTarget;
|
||||
if (viewName) {
|
||||
this.switchToView(viewName);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
switchToView(viewName) {
|
||||
if (this.state.currentView === viewName && this.elements.contentContainer.innerHTML !== '') return;
|
||||
|
||||
if (this.systemLogTerminal) {
|
||||
this.systemLogTerminal.disconnect();
|
||||
this.systemLogTerminal = null;
|
||||
}
|
||||
if (this.fp) {
|
||||
this.fp.destroy();
|
||||
this.fp = null;
|
||||
}
|
||||
if (this.themeObserver) {
|
||||
this.themeObserver.disconnect();
|
||||
this.themeObserver = null;
|
||||
}
|
||||
this.state.currentView = viewName;
|
||||
this.elements.contentContainer.innerHTML = '';
|
||||
const isErrorView = viewName === 'error';
|
||||
this.elements.errorFilters.style.display = isErrorView ? 'flex' : 'none';
|
||||
this.elements.systemControls.style.display = isErrorView ? 'none' : 'flex';
|
||||
if (isErrorView) {
|
||||
const template = this.elements.errorTemplate.content.cloneNode(true);
|
||||
this.elements.contentContainer.appendChild(template);
|
||||
requestAnimationFrame(() => {
|
||||
this._initErrorLogView();
|
||||
});
|
||||
} else if (viewName === 'system') {
|
||||
const template = this.elements.systemTemplate.content.cloneNode(true);
|
||||
this.elements.contentContainer.appendChild(template);
|
||||
requestAnimationFrame(() => {
|
||||
this._initSystemLogView();
|
||||
});
|
||||
}
|
||||
}
|
||||
_initErrorLogView() {
|
||||
this.elements.tableBody = document.getElementById('logs-table-body');
|
||||
this.elements.selectedCount = document.querySelector('.flex-1.text-sm span.font-semibold:nth-child(1)');
|
||||
this.elements.totalCount = document.querySelector('.flex-1.text-sm span:last-child');
|
||||
this.elements.pageSizeSelect = document.querySelector('[data-component="custom-select-v2"] select');
|
||||
this.elements.pageInfo = document.querySelector('.flex.w-\\[100px\\]');
|
||||
this.elements.paginationBtns = document.querySelectorAll('[data-pagination-controls] button');
|
||||
this.elements.selectAllCheckbox = document.querySelector('thead .table-head-cell input[type="checkbox"]');
|
||||
this.elements.searchInput = document.getElementById('log-search-input');
|
||||
this.elements.errorTypeFilterBtn = document.getElementById('filter-error-type-btn');
|
||||
this.elements.errorCodeFilterBtn = document.getElementById('filter-error-code-btn');
|
||||
this.elements.dateRangeFilterBtn = document.getElementById('filter-date-range-btn');
|
||||
|
||||
this.logList = new LogList(this.elements.tableBody, dataStore);
|
||||
const selectContainer = document.querySelector('[data-component="custom-select-v2"]');
|
||||
if (selectContainer) { new CustomSelectV2(selectContainer); }
|
||||
|
||||
this.initFilterPopovers();
|
||||
this.initDateRangePicker();
|
||||
this.initEventListeners();
|
||||
this._observeThemeChanges();
|
||||
initBatchActions(this);
|
||||
this.loadAndRenderLogs();
|
||||
}
|
||||
|
||||
_observeThemeChanges() {
|
||||
const applyTheme = () => {
|
||||
if (!this.fp || !this.fp.calendarContainer) return;
|
||||
if (document.documentElement.classList.contains('dark')) {
|
||||
this.fp.calendarContainer.classList.add('dark');
|
||||
} else {
|
||||
this.fp.calendarContainer.classList.remove('dark');
|
||||
}
|
||||
};
|
||||
this.themeObserver = new MutationObserver((mutationsList) => {
|
||||
for (const mutation of mutationsList) {
|
||||
if (mutation.type === 'attributes' && mutation.attributeName === 'class') {
|
||||
applyTheme();
|
||||
}
|
||||
}
|
||||
});
|
||||
this.themeObserver.observe(document.documentElement, { attributes: true });
|
||||
|
||||
applyTheme();
|
||||
}
|
||||
|
||||
_initSystemLogView() {
|
||||
this.systemLogTerminal = new SystemLogTerminal(
|
||||
this.elements.contentContainer,
|
||||
this.elements.systemControls
|
||||
);
|
||||
Swal.fire({
|
||||
width: '20rem',
|
||||
backdrop: `rgba(0,0,0,0.5)`,
|
||||
heightAuto: false,
|
||||
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
|
||||
title: '系统终端日志',
|
||||
text: '您即将连接到实时系统日志流窗口。',
|
||||
showCancelButton: true,
|
||||
confirmButtonText: '确认',
|
||||
cancelButtonText: '取消',
|
||||
reverseButtons: false,
|
||||
confirmButtonColor: 'rgba(31, 102, 255, 0.8)',
|
||||
cancelButtonColor: '#6b7280',
|
||||
focusConfirm: false,
|
||||
focusCancel: false,
|
||||
target: '#main-content-wrapper',
|
||||
}).then((result) => {
|
||||
if (result.isConfirmed) {
|
||||
this.systemLogTerminal.connect();
|
||||
} else {
|
||||
const errorLogTab = Array.from(this.elements.tabsContainer.querySelectorAll('[data-tab-target="error"]'))[0];
|
||||
if (errorLogTab) errorLogTab.click();
|
||||
}
|
||||
});
|
||||
}
|
||||
initFilterPopovers() {
|
||||
const errorTypeOptions = [
|
||||
...Object.values(STATUS_CODE_MAP).map(v => ({ value: v.type, label: v.type })),
|
||||
...Object.values(STATIC_ERROR_MAP).map(v => ({ value: v.type, label: v.type }))
|
||||
];
|
||||
const uniqueErrorTypeOptions = Array.from(new Map(errorTypeOptions.map(item => [item.value, item])).values());
|
||||
if (this.elements.errorTypeFilterBtn) {
|
||||
new FilterPopover(this.elements.errorTypeFilterBtn, uniqueErrorTypeOptions, '筛选错误类型');
|
||||
}
|
||||
const statusCodeOptions = Object.keys(STATUS_CODE_MAP).map(code => ({ value: code, label: code }));
|
||||
if (this.elements.errorCodeFilterBtn) {
|
||||
new FilterPopover(this.elements.errorCodeFilterBtn, statusCodeOptions, '筛选状态码');
|
||||
}
|
||||
}
|
||||
initDateRangePicker() {
|
||||
if (!this.elements.dateRangeFilterBtn) return;
|
||||
const buttonTextSpan = this.elements.dateRangeFilterBtn.querySelector('span');
|
||||
const originalText = buttonTextSpan.textContent;
|
||||
this.fp = flatpickr(this.elements.dateRangeFilterBtn, {
|
||||
mode: 'range',
|
||||
dateFormat: 'Y-m-d',
|
||||
onClose: (selectedDates) => {
|
||||
if (selectedDates.length === 2) {
|
||||
const [start, end] = selectedDates;
|
||||
end.setHours(23, 59, 59, 999);
|
||||
this.state.filters.start_date = start.toISOString();
|
||||
this.state.filters.end_date = end.toISOString();
|
||||
const startDateStr = start.toISOString().split('T')[0];
|
||||
const endDateStr = end.toISOString().split('T')[0];
|
||||
|
||||
buttonTextSpan.textContent = `${startDateStr} ~ ${endDateStr}`;
|
||||
this.elements.dateRangeFilterBtn.classList.add('!border-primary', '!text-primary');
|
||||
this.state.filters.page = 1;
|
||||
this.loadAndRenderLogs();
|
||||
}
|
||||
},
|
||||
onReady: (selectedDates, dateStr, instance) => {
|
||||
if (document.documentElement.classList.contains('dark')) {
|
||||
instance.calendarContainer.classList.add('dark');
|
||||
}
|
||||
const clearButton = document.createElement('button');
|
||||
clearButton.textContent = '清除';
|
||||
clearButton.className = 'button flatpickr-button flatpickr-clear-button';
|
||||
clearButton.addEventListener('click', (e) => {
|
||||
e.preventDefault();
|
||||
instance.clear();
|
||||
|
||||
this.state.filters.start_date = null;
|
||||
this.state.filters.end_date = null;
|
||||
buttonTextSpan.textContent = originalText;
|
||||
this.elements.dateRangeFilterBtn.classList.remove('!border-primary', '!text-primary');
|
||||
|
||||
this.state.filters.page = 1;
|
||||
this.loadAndRenderLogs();
|
||||
instance.close();
|
||||
});
|
||||
instance.calendarContainer.appendChild(clearButton);
|
||||
const nativeMonthSelect = instance.monthsDropdownContainer;
|
||||
if (!nativeMonthSelect) return;
|
||||
const monthYearContainer = nativeMonthSelect.parentElement;
|
||||
const wrapper = document.createElement('div');
|
||||
wrapper.className = 'custom-select-v2-container relative inline-block text-left';
|
||||
|
||||
wrapper.innerHTML = `
|
||||
<button type="button" class="custom-select-trigger inline-flex justify-center items-center w-22 gap-x-1.5 rounded-md bg-transparent px-1 py-1 text-xs font-semibold text-foreground shadow-sm ring-0 ring-inset ring-input hover:bg-accent focus:outline-none focus:ring-1 focus:ring-offset-1 focus:ring-offset-background focus:ring-primary" aria-haspopup="true">
|
||||
<span class="truncate"></span>
|
||||
</button>
|
||||
`;
|
||||
|
||||
const template = document.createElement('template');
|
||||
template.className = 'custom-select-panel-template';
|
||||
|
||||
template.innerHTML = `
|
||||
<div class="custom-select-panel absolute z-1000 my-2 w-24 origin-top-right rounded-md bg-popover dark:bg-zinc-900 shadow-lg ring-1 ring-zinc-500/30 ring-opacity-5 focus:outline-none" role="menu" aria-orientation="vertical" tabindex="-1">
|
||||
</div>
|
||||
`;
|
||||
nativeMonthSelect.classList.add('hidden');
|
||||
wrapper.appendChild(nativeMonthSelect);
|
||||
wrapper.appendChild(template);
|
||||
monthYearContainer.prepend(wrapper);
|
||||
const customSelect = new CustomSelectV2(wrapper);
|
||||
instance.customMonthSelect = customSelect;
|
||||
},
|
||||
onMonthChange: (selectedDates, dateStr, instance) => {
|
||||
if (instance.customMonthSelect) {
|
||||
instance.customMonthSelect.updateTriggerText();
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
initEventListeners() {
|
||||
if (this.elements.pageSizeSelect) {
|
||||
this.elements.pageSizeSelect.addEventListener('change', (e) => this.changePageSize(parseInt(e.target.value, 10)));
|
||||
}
|
||||
if (this.elements.paginationBtns.length >= 4) {
|
||||
this.elements.paginationBtns[0].addEventListener('click', () => this.goToPage(1));
|
||||
this.elements.paginationBtns[1].addEventListener('click', () => this.goToPage(this.state.pagination.page - 1));
|
||||
this.elements.paginationBtns[2].addEventListener('click', () => this.goToPage(this.state.pagination.page + 1));
|
||||
this.elements.paginationBtns[3].addEventListener('click', () => this.goToPage(this.state.pagination.pages));
|
||||
}
|
||||
if (this.elements.selectAllCheckbox) {
|
||||
this.elements.selectAllCheckbox.addEventListener('change', (event) => this.handleSelectAllChange(event));
|
||||
}
|
||||
if (this.elements.tableBody) {
|
||||
this.elements.tableBody.addEventListener('click', (event) => {
|
||||
const checkbox = event.target.closest('input[type="checkbox"]');
|
||||
const actionButton = event.target.closest('button[data-action]');
|
||||
if (checkbox) {
|
||||
this.handleSelectionChange(checkbox);
|
||||
} else if (actionButton) {
|
||||
this._handleLogRowAction(actionButton);
|
||||
}
|
||||
});
|
||||
}
|
||||
if (this.elements.searchInput) {
|
||||
this.elements.searchInput.addEventListener('input', (event) => this.handleSearchInput(event));
|
||||
}
|
||||
if (this.elements.errorTypeFilterBtn) {
|
||||
this.elements.errorTypeFilterBtn.addEventListener('filter-change', (e) => this.handleFilterChange(e));
|
||||
}
|
||||
if (this.elements.errorCodeFilterBtn) {
|
||||
this.elements.errorCodeFilterBtn.addEventListener('filter-change', (e) => this.handleFilterChange(e));
|
||||
}
|
||||
}
|
||||
handleFilterChange(event) {
|
||||
const { filterKey, selected } = event.detail;
|
||||
if (filterKey === 'filter-error-type-btn') {
|
||||
this.state.filters.error_types = selected;
|
||||
} else if (filterKey === 'filter-error-code-btn') {
|
||||
this.state.filters.status_codes = selected;
|
||||
}
|
||||
this.state.filters.page = 1;
|
||||
this.loadAndRenderLogs();
|
||||
}
|
||||
handleSearchInput(event) {
|
||||
const searchTerm = event.target.value.trim().toLowerCase();
|
||||
this.state.filters.page = 1;
|
||||
this.state.filters.q = '';
|
||||
this.state.filters.key_ids = new Set();
|
||||
this.state.filters.group_ids = new Set();
|
||||
if (searchTerm === '') {
|
||||
this.debouncedLoadAndRender();
|
||||
return;
|
||||
}
|
||||
const matchedGroupIds = new Set();
|
||||
dataStore.groups.forEach(group => {
|
||||
if (group.display_name.toLowerCase().includes(searchTerm)) {
|
||||
matchedGroupIds.add(group.id);
|
||||
}
|
||||
});
|
||||
const matchedKeyIds = new Set();
|
||||
dataStore.keys.forEach(key => {
|
||||
if (key.APIKey && key.APIKey.toLowerCase().includes(searchTerm)) {
|
||||
matchedKeyIds.add(key.ID);
|
||||
}
|
||||
});
|
||||
if (matchedGroupIds.size > 0) this.state.filters.group_ids = matchedGroupIds;
|
||||
if (matchedKeyIds.size > 0) this.state.filters.key_ids = matchedKeyIds;
|
||||
if (matchedGroupIds.size === 0 && matchedKeyIds.size === 0) {
|
||||
this.state.filters.q = searchTerm;
|
||||
}
|
||||
this.debouncedLoadAndRender();
|
||||
}
|
||||
handleSelectionChange(checkbox) {
|
||||
const row = checkbox.closest('.table-row');
|
||||
if (!row) return;
|
||||
const logId = parseInt(row.dataset.logId, 10);
|
||||
if (isNaN(logId)) return;
|
||||
if (checkbox.checked) {
|
||||
this.state.selectedLogIds.add(logId);
|
||||
} else {
|
||||
this.state.selectedLogIds.delete(logId);
|
||||
}
|
||||
this.syncSelectionUI();
|
||||
}
|
||||
handleSelectAllChange(event) {
|
||||
const isChecked = event.target.checked;
|
||||
this.state.logs.forEach(log => {
|
||||
if (isChecked) {
|
||||
this.state.selectedLogIds.add(log.ID);
|
||||
} else {
|
||||
this.state.selectedLogIds.delete(log.ID);
|
||||
}
|
||||
});
|
||||
this.syncRowCheckboxes();
|
||||
this.syncSelectionUI();
|
||||
}
|
||||
syncRowCheckboxes() {
|
||||
const isAllChecked = this.elements.selectAllCheckbox.checked;
|
||||
this.elements.tableBody.querySelectorAll('input[type="checkbox"]').forEach(cb => {
|
||||
cb.checked = isAllChecked;
|
||||
});
|
||||
}
|
||||
syncSelectionUI() {
|
||||
if (!this.elements.selectAllCheckbox || !this.elements.selectedCount) return;
|
||||
const selectedCount = this.state.selectedLogIds.size;
|
||||
const visibleLogsCount = this.state.logs.length;
|
||||
|
||||
if (selectedCount === 0) {
|
||||
this.elements.selectAllCheckbox.checked = false;
|
||||
this.elements.selectAllCheckbox.indeterminate = false;
|
||||
} else if (selectedCount < visibleLogsCount) {
|
||||
this.elements.selectAllCheckbox.checked = false;
|
||||
this.elements.selectAllCheckbox.indeterminate = true;
|
||||
} else if (selectedCount === visibleLogsCount && visibleLogsCount > 0) {
|
||||
this.elements.selectAllCheckbox.checked = true;
|
||||
this.elements.selectAllCheckbox.indeterminate = false;
|
||||
}
|
||||
|
||||
this.elements.selectedCount.textContent = selectedCount;
|
||||
const hasSelection = selectedCount > 0;
|
||||
const deleteSelectedBtn = document.getElementById('delete-selected-logs-btn');
|
||||
if (deleteSelectedBtn) {
|
||||
deleteSelectedBtn.disabled = !hasSelection;
|
||||
}
|
||||
}
|
||||
|
||||
async _handleLogRowAction(button) {
|
||||
const action = button.dataset.action;
|
||||
const row = button.closest('.table-row');
|
||||
const isDarkMode = document.documentElement.classList.contains('dark');
|
||||
if (!row) return;
|
||||
const logId = parseInt(row.dataset.logId, 10);
|
||||
const log = this.state.logs.find(l => l.ID === logId);
|
||||
if (!log) {
|
||||
Swal.fire({ toast: true, position: 'top-end', icon: 'error', title: '找不到日志数据', showConfirmButton: false, timer: 2000 });
|
||||
return;
|
||||
}
|
||||
switch (action) {
|
||||
case 'view-log-details': {
|
||||
const detailsHtml = `
|
||||
<div class="space-y-3 text-left text-sm p-2">
|
||||
<div class="flex"><p class="w-24 font-semibold text-zinc-500 shrink-0">状态码</p><p class="font-mono text-zinc-800 dark:text-zinc-200">${log.StatusCode || 'N/A'}</p></div>
|
||||
<div class="flex"><p class="w-24 font-semibold text-zinc-500 shrink-0">状态</p><p class="font-mono text-zinc-800 dark:text-zinc-200">${log.Status || 'N/A'}</p></div>
|
||||
<div class="flex"><p class="w-24 font-semibold text-zinc-500 shrink-0">模型</p><p class="font-mono text-zinc-800 dark:text-zinc-200">${log.ModelName || 'N/A'}</p></div>
|
||||
<div class="border-t border-zinc-200 dark:border-zinc-700 my-2"></div>
|
||||
<div>
|
||||
<p class="font-semibold text-zinc-500 mb-1">错误消息</p>
|
||||
<div class="max-h-40 overflow-y-auto bg-zinc-100 dark:bg-zinc-800 p-2 rounded-md text-zinc-700 dark:text-zinc-300 wrap-break-word text-xs">
|
||||
${log.ErrorMessage ? log.ErrorMessage.replace(/\n/g, '<br>') : '无错误消息。'}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
Swal.fire({
|
||||
target: '#main-content-wrapper',
|
||||
width: '32rem',
|
||||
backdrop: `rgba(0,0,0,0.5)`,
|
||||
heightAuto: false,
|
||||
customClass: {
|
||||
popup: `swal2-custom-style rounded-xl ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}`,
|
||||
title: 'text-lg font-bold',
|
||||
htmlContainer: 'm-0 text-left',
|
||||
},
|
||||
title: '日志详情',
|
||||
html: detailsHtml,
|
||||
showCloseButton: false,
|
||||
showConfirmButton: false,
|
||||
});
|
||||
break;
|
||||
}
|
||||
case 'copy-api-key': {
|
||||
const key = dataStore.keys.get(log.KeyID);
|
||||
if (key && key.APIKey) {
|
||||
navigator.clipboard.writeText(key.APIKey).then(() => {
|
||||
Swal.fire({ toast: true, position: 'top-end', customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` }, icon: 'success', title: 'API Key 已复制', showConfirmButton: false, timer: 1500 });
|
||||
}).catch(err => {
|
||||
Swal.fire({ toast: true, position: 'top-end', icon: 'error', title: '复制失败', text: err.message, showConfirmButton: false, timer: 2000 });
|
||||
});
|
||||
} else {
|
||||
Swal.fire({ toast: true, position: 'top-end', icon: 'warning', title: '未找到完整的API Key', showConfirmButton: false, timer: 2000 });
|
||||
return;
|
||||
}
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
navigator.clipboard.writeText(key.APIKey).then(() => {
|
||||
Swal.fire({ toast: true, position: 'top-end', icon: 'success', title: 'API Key 已复制', showConfirmButton: false, timer: 1500 });
|
||||
}).catch(err => {
|
||||
Swal.fire({ toast: true, position: 'top-end', icon: 'error', title: '复制失败', text: err.message, showConfirmButton: false, timer: 2000 });
|
||||
});
|
||||
} else {
|
||||
// 如果不可用,则提供明确的错误提示
|
||||
Swal.fire({
|
||||
icon: 'error',
|
||||
title: '复制失败',
|
||||
text: '此功能需要安全连接 (HTTPS) 或在 localhost 环境下使用。',
|
||||
target: '#main-content-wrapper',
|
||||
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'delete-log': {
|
||||
Swal.fire({
|
||||
width: '20rem',
|
||||
backdrop: `rgba(0,0,0,0.5)`,
|
||||
heightAuto: false,
|
||||
customClass: { popup: `swal2-custom-style ${document.documentElement.classList.contains('dark') ? 'swal2-dark' : ''}` },
|
||||
title: '确认删除',
|
||||
text: `您确定要删除这条日志吗?此操作不可撤销。`,
|
||||
showCancelButton: true,
|
||||
confirmButtonText: '确认删除',
|
||||
cancelButtonText: '取消',
|
||||
reverseButtons: false,
|
||||
confirmButtonColor: '#ef4444',
|
||||
cancelButtonColor: '#6b7280',
|
||||
focusCancel: true,
|
||||
target: '#main-content-wrapper',
|
||||
}).then(async (result) => {
|
||||
if (result.isConfirmed) {
|
||||
try {
|
||||
const url = `/admin/logs?ids=${logId}`;
|
||||
const { success, message } = await apiFetchJson(url, { method: 'DELETE' });
|
||||
if (success) {
|
||||
Swal.fire({ toast: true, position: 'top-end', icon: 'success', title: '删除成功', showConfirmButton: false, timer: 2000, timerProgressBar: true });
|
||||
this.loadAndRenderLogs();
|
||||
} else {
|
||||
throw new Error(message || '删除失败,请稍后重试。');
|
||||
}
|
||||
} catch (error) {
|
||||
Swal.fire({ icon: 'error', title: '操作失败', text: error.message, target: '#main-content-wrapper' });
|
||||
}
|
||||
}
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
changePageSize(newSize) {
|
||||
this.state.filters.page_size = newSize;
|
||||
this.state.filters.page = 1;
|
||||
this.loadAndRenderLogs();
|
||||
}
|
||||
goToPage(page) {
|
||||
if (page < 1 || page > this.state.pagination.pages || this.state.isLoading) return;
|
||||
this.state.filters.page = page;
|
||||
this.loadAndRenderLogs();
|
||||
}
|
||||
updatePaginationUI() {
|
||||
const { page, pages, total } = this.state.pagination;
|
||||
if (this.elements.pageInfo) {
|
||||
this.elements.pageInfo.textContent = `第 ${page} / ${pages} 页`;
|
||||
}
|
||||
if (this.elements.totalCount) {
|
||||
this.elements.totalCount.textContent = total;
|
||||
}
|
||||
|
||||
if (this.elements.paginationBtns.length >= 4) {
|
||||
const isFirstPage = page === 1;
|
||||
const isLastPage = page === pages || pages === 0;
|
||||
this.elements.paginationBtns[0].disabled = isFirstPage;
|
||||
this.elements.paginationBtns[1].disabled = isFirstPage;
|
||||
this.elements.paginationBtns[2].disabled = isLastPage;
|
||||
this.elements.paginationBtns[3].disabled = isLastPage;
|
||||
}
|
||||
}
|
||||
async loadGroupsOnce() {
|
||||
if (dataStore.groups.size > 0) return; // 防止重复加载
|
||||
if (dataStore.groups.size > 0) return;
|
||||
try {
|
||||
const { success, data } = await apiFetchJson("/admin/keygroups");
|
||||
if (success && Array.isArray(data)) {
|
||||
@@ -53,33 +624,81 @@ class LogsPage {
|
||||
|
||||
async loadAndRenderLogs() {
|
||||
this.state.isLoading = true;
|
||||
this.logList.renderLoading();
|
||||
|
||||
this.state.selectedLogIds.clear();
|
||||
this.logList.renderLoading();
|
||||
this.updatePaginationUI();
|
||||
this.syncSelectionUI();
|
||||
try {
|
||||
const query = new URLSearchParams(this.state.filters);
|
||||
const { success, data } = await apiFetchJson(`/admin/logs?${query.toString()}`);
|
||||
const finalParams = {};
|
||||
const { filters } = this.state;
|
||||
|
||||
Object.keys(filters).forEach(key => {
|
||||
if (!(filters[key] instanceof Set)) {
|
||||
finalParams[key] = filters[key];
|
||||
}
|
||||
});
|
||||
// --- [MODIFIED] START: Combine all error-related filters into a single parameter for OR logic ---
|
||||
const allErrorCodes = new Set();
|
||||
const allStatusCodes = new Set(filters.status_codes);
|
||||
if (filters.error_types.size > 0) {
|
||||
filters.error_types.forEach(type => {
|
||||
// Find matching static error codes (e.g., 'API_KEY_INVALID')
|
||||
for (const [code, obj] of Object.entries(STATIC_ERROR_MAP)) {
|
||||
if (obj.type === type) {
|
||||
allErrorCodes.add(code);
|
||||
}
|
||||
}
|
||||
// Find matching status codes (e.g., 400, 401)
|
||||
for (const [code, obj] of Object.entries(STATUS_CODE_MAP)) {
|
||||
if (obj.type === type) {
|
||||
allStatusCodes.add(code);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (success && typeof data === 'object') {
|
||||
// Pass the combined codes to the backend. The backend will handle the OR logic.
|
||||
if (allErrorCodes.size > 0) finalParams.error_codes = [...allErrorCodes].join(',');
|
||||
if (allStatusCodes.size > 0) finalParams.status_codes = [...allStatusCodes].join(',');
|
||||
// --- [MODIFIED] END ---
|
||||
|
||||
if (filters.key_ids.size > 0) finalParams.key_ids = [...filters.key_ids].join(',');
|
||||
if (filters.group_ids.size > 0) finalParams.group_ids = [...filters.group_ids].join(',');
|
||||
|
||||
Object.keys(finalParams).forEach(key => {
|
||||
if (finalParams[key] === '' || finalParams[key] === null || finalParams[key] === undefined) {
|
||||
delete finalParams[key];
|
||||
}
|
||||
});
|
||||
const query = new URLSearchParams(finalParams);
|
||||
|
||||
const { success, data } = await apiFetchJson(
|
||||
`/admin/logs?${query.toString()}`,
|
||||
{ cache: 'no-cache', noCache: true }
|
||||
);
|
||||
if (success && typeof data === 'object' && data.items) {
|
||||
const { items, total, page, page_size } = data;
|
||||
this.state.logs = items;
|
||||
this.state.pagination = { page, page_size, total, pages: Math.ceil(total / page_size) };
|
||||
|
||||
// [核心] 在渲染前,按需批量加载本页日志所需的、尚未缓存的Key信息
|
||||
const totalPages = Math.ceil(total / page_size);
|
||||
this.state.pagination = { page, page_size, total, pages: totalPages > 0 ? totalPages : 1 };
|
||||
await this.enrichLogsWithKeyNames(items);
|
||||
|
||||
// 调用 render,此时 dataStore 中已包含所有需要的数据
|
||||
this.logList.render(this.state.logs, this.state.pagination);
|
||||
this.logList.render(this.state.logs, this.state.pagination, this.state.selectedLogIds);
|
||||
} else {
|
||||
this.state.logs = [];
|
||||
this.state.pagination = { ...this.state.pagination, total: 0, pages: 1, page: 1 };
|
||||
this.logList.render([], this.state.pagination);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to load logs:", error);
|
||||
this.state.logs = [];
|
||||
this.state.pagination = { ...this.state.pagination, total: 0, pages: 1, page: 1 };
|
||||
this.logList.render([], this.state.pagination);
|
||||
} finally {
|
||||
this.state.isLoading = false;
|
||||
this.updatePaginationUI();
|
||||
this.syncSelectionUI();
|
||||
}
|
||||
}
|
||||
|
||||
async enrichLogsWithKeyNames(logs) {
|
||||
const missingKeyIds = [...new Set(
|
||||
logs.filter(log => log.KeyID && !dataStore.keys.has(log.KeyID)).map(log => log.KeyID)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Filename: frontend/js/pages/logs/logList.js
|
||||
import { escapeHTML } from '../../utils/utils.js';
|
||||
|
||||
const STATIC_ERROR_MAP = {
|
||||
export const STATIC_ERROR_MAP = {
|
||||
'API_KEY_INVALID': { type: '密钥无效', style: 'red' },
|
||||
'INVALID_ARGUMENT': { type: '参数无效', style: 'red' },
|
||||
'PERMISSION_DENIED': { type: '权限不足', style: 'red' },
|
||||
@@ -13,8 +13,8 @@ const STATIC_ERROR_MAP = {
|
||||
'INTERNAL': { type: 'Google内部错误', style: 'yellow' },
|
||||
'UNAVAILABLE': { type: '服务不可用', style: 'yellow' },
|
||||
};
|
||||
// --- [更新] HTTP状态码到类型和样式的动态映射表 ---
|
||||
const STATUS_CODE_MAP = {
|
||||
// --- HTTP状态码到类型和样式的动态映射表 ---
|
||||
export const STATUS_CODE_MAP = {
|
||||
400: { type: '错误请求', style: 'red' },
|
||||
401: { type: '认证失败', style: 'red' },
|
||||
403: { type: '禁止访问', style: 'red' },
|
||||
@@ -55,7 +55,7 @@ class LogList {
|
||||
this.container.innerHTML = `<tr><td colspan="9" class="p-8 text-center text-muted-foreground"><i class="fas fa-spinner fa-spin mr-2"></i> 加载日志中...</td></tr>`;
|
||||
}
|
||||
|
||||
render(logs, pagination) {
|
||||
render(logs, pagination, selectedLogIds) {
|
||||
if (!this.container) return;
|
||||
if (!logs || logs.length === 0) {
|
||||
this.container.innerHTML = `<tr><td colspan="9" class="p-8 text-center text-muted-foreground">没有找到相关的日志记录。</td></tr>`;
|
||||
@@ -63,7 +63,10 @@ class LogList {
|
||||
}
|
||||
const { page, page_size } = pagination;
|
||||
const startIndex = (page - 1) * page_size;
|
||||
const logsHtml = logs.map((log, index) => this.createLogRowHtml(log, startIndex + index + 1)).join('');
|
||||
const logsHtml = logs.map((log, index) => {
|
||||
const isChecked = selectedLogIds.has(log.ID);
|
||||
return this.createLogRowHtml(log, startIndex + index + 1, isChecked);
|
||||
}).join('');
|
||||
this.container.innerHTML = logsHtml;
|
||||
}
|
||||
|
||||
@@ -75,7 +78,7 @@ class LogList {
|
||||
statusCodeHtml: `<span class="inline-flex items-center rounded-md bg-green-500/10 px-2 py-1 text-xs font-medium text-green-600">成功</span>`
|
||||
};
|
||||
}
|
||||
// 2. [新增] 特殊场景优先判断 (结合ErrorCode和ErrorMessage)
|
||||
// 2. 特殊场景优先判断 (结合ErrorCode和ErrorMessage)
|
||||
const codeMatch = log.ErrorCode ? log.ErrorCode.match(errorCodeRegex) : null;
|
||||
if (codeMatch && codeMatch[1] && log.ErrorMessage) {
|
||||
const code = parseInt(codeMatch[1], 10);
|
||||
@@ -125,7 +128,7 @@ class LogList {
|
||||
return `<div class="inline-block rounded bg-zinc-100 dark:bg-zinc-800 px-2 py-0.5"><span class="font-quinquefive text-xs tracking-wider ${styleClass}">${modelName}</span></div>`;
|
||||
}
|
||||
|
||||
createLogRowHtml(log, index) {
|
||||
createLogRowHtml(log, index, isChecked) {
|
||||
const group = this.dataStore.groups.get(log.GroupID);
|
||||
const groupName = group ? group.display_name : (log.GroupID ? `Group #${log.GroupID}` : 'N/A');
|
||||
const key = this.dataStore.keys.get(log.KeyID);
|
||||
@@ -140,9 +143,13 @@ class LogList {
|
||||
const modelNameFormatted = this._formatModelName(log.ModelName);
|
||||
const errorMessageAttr = log.ErrorMessage ? `data-error-message="${escape(log.ErrorMessage)}"` : '';
|
||||
const requestTime = new Date(log.RequestTime).toLocaleString();
|
||||
|
||||
const checkedAttr = isChecked ? 'checked' : '';
|
||||
return `
|
||||
<tr class="table-row" data-log-id="${log.ID}" ${errorMessageAttr}>
|
||||
<td class="table-cell"><input type="checkbox" class="h-4 w-4 rounded border-zinc-300 text-blue-600 focus:ring-blue-500"></td>
|
||||
<tr class="table-row group even:bg-zinc-200/30 dark:even:bg-black/10" data-log-id="${log.ID}" ${errorMessageAttr}>
|
||||
<td class="table-cell">
|
||||
<input type="checkbox" class="h-4 w-4 rounded border-zinc-300 text-blue-600 focus:ring-blue-500" ${checkedAttr}>
|
||||
</td>
|
||||
<td class="table-cell font-mono text-muted-foreground">${index}</td>
|
||||
<td class="table-cell font-medium font-mono">${apiKeyDisplay}</td>
|
||||
<td class="table-cell">${groupName}</td>
|
||||
@@ -150,14 +157,29 @@ class LogList {
|
||||
<td class="table-cell">${errorInfo.statusCodeHtml}</td>
|
||||
<td class="table-cell">${modelNameFormatted}</td>
|
||||
<td class="table-cell text-muted-foreground text-xs">${requestTime}</td>
|
||||
<td class="table-cell">
|
||||
<button class="btn btn-ghost btn-icon btn-sm" aria-label="查看详情">
|
||||
<i class="fas fa-ellipsis-h h-4 w-4"></i>
|
||||
</button>
|
||||
<td class="table-cell relative">
|
||||
<!-- [MODIFIED] - 2. 替换原有按钮为悬浮操作菜单 -->
|
||||
<div class="flex items-center justify-center">
|
||||
<!-- 默认显示的图标 -->
|
||||
<span class="text-zinc-400 group-hover:opacity-0 transition-opacity">
|
||||
<i class="fas fa-ellipsis-h h-4 w-4"></i>
|
||||
</span>
|
||||
<!-- 悬浮时显示的操作按钮 -->
|
||||
<div class="absolute right-2 top-1/2 -translate-y-1/2 flex items-center bg-zinc-100 dark:bg-zinc-700 rounded-full shadow-md opacity-0 group-hover:opacity-100 transition-opacity duration-200 z-10">
|
||||
<button class="px-2 py-1 text-zinc-500 hover:text-blue-500" data-action="view-log-details" title="查看详情">
|
||||
<i class="fas fa-eye"></i>
|
||||
</button>
|
||||
<button class="px-2 py-1 text-zinc-500 hover:text-green-500" data-action="copy-api-key" title="复制APIKey">
|
||||
<i class="fas fa-copy"></i>
|
||||
</button>
|
||||
<button class="px-2 py-1 text-zinc-500 hover:text-red-500" data-action="delete-log" title="删除日志">
|
||||
<i class="fas fa-trash-alt"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
`;
|
||||
}
|
||||
}
|
||||
|
||||
export default LogList;
|
||||
|
||||
148
frontend/js/pages/logs/logSettingsModal.js
Normal file
148
frontend/js/pages/logs/logSettingsModal.js
Normal file
@@ -0,0 +1,148 @@
|
||||
// Filename: frontend/js/pages/logs/logList.js
|
||||
import { modalManager } from '../../components/ui.js';
|
||||
|
||||
export default class LogSettingsModal {
|
||||
constructor({ onSave }) {
|
||||
this.modalId = 'log-settings-modal';
|
||||
this.onSave = onSave;
|
||||
const modal = document.getElementById(this.modalId);
|
||||
if (!modal) {
|
||||
throw new Error(`Modal with id "${this.modalId}" not found.`);
|
||||
}
|
||||
|
||||
this.elements = {
|
||||
modal: modal,
|
||||
title: document.getElementById('log-settings-modal-title'),
|
||||
saveBtn: document.getElementById('log-settings-save-btn'),
|
||||
|
||||
logLevelSelect: document.getElementById('log-level-select'),
|
||||
|
||||
cleanupEnableToggle: document.getElementById('log-cleanup-enable'),
|
||||
cleanupSettingsPanel: document.getElementById('log-cleanup-settings'),
|
||||
cleanupRetentionInput: document.getElementById('log-cleanup-retention-days'),
|
||||
retentionDaysGroup: document.getElementById('retention-days-group'),
|
||||
retentionPresetBtns: document.querySelectorAll('#retention-days-group button[data-days]'),
|
||||
cleanupExecTimeInput: document.getElementById('log-cleanup-exec-time'), // [NEW] 添加时间选择器元素
|
||||
};
|
||||
|
||||
this.activePresetClasses = ['!bg-primary', '!text-primary-foreground', '!border-primary', 'hover:!bg-primary/90'];
|
||||
this.inactivePresetClasses = ['modal-btn-secondary'];
|
||||
this._initEventListeners();
|
||||
}
|
||||
|
||||
open(settingsData = {}) {
|
||||
this._populateForm(settingsData);
|
||||
modalManager.show(this.modalId);
|
||||
}
|
||||
|
||||
close() {
|
||||
modalManager.hide(this.modalId);
|
||||
}
|
||||
|
||||
_initEventListeners() {
|
||||
this.elements.saveBtn.addEventListener('click', this._handleSave.bind(this));
|
||||
|
||||
this.elements.cleanupEnableToggle.addEventListener('change', (e) => {
|
||||
this.elements.cleanupSettingsPanel.classList.toggle('hidden', !e.target.checked);
|
||||
});
|
||||
|
||||
this._initRetentionPresets();
|
||||
|
||||
const closeAction = () => this.close();
|
||||
const closeTriggers = this.elements.modal.querySelectorAll(`[data-modal-close="${this.modalId}"]`);
|
||||
closeTriggers.forEach(trigger => trigger.addEventListener('click', closeAction));
|
||||
this.elements.modal.addEventListener('click', (event) => {
|
||||
if (event.target === this.elements.modal) closeAction();
|
||||
});
|
||||
}
|
||||
|
||||
_initRetentionPresets() {
|
||||
this.elements.retentionPresetBtns.forEach(btn => {
|
||||
btn.addEventListener('click', () => {
|
||||
const days = btn.dataset.days;
|
||||
this.elements.cleanupRetentionInput.value = days;
|
||||
this._updateActivePresetButton(days);
|
||||
});
|
||||
});
|
||||
|
||||
this.elements.cleanupRetentionInput.addEventListener('input', (e) => {
|
||||
this._updateActivePresetButton(e.target.value);
|
||||
});
|
||||
}
|
||||
|
||||
_updateActivePresetButton(currentValue) {
|
||||
this.elements.retentionPresetBtns.forEach(btn => {
|
||||
if (btn.dataset.days === currentValue) {
|
||||
btn.classList.remove(...this.inactivePresetClasses);
|
||||
btn.classList.add(...this.activePresetClasses);
|
||||
} else {
|
||||
btn.classList.remove(...this.activePresetClasses);
|
||||
btn.classList.add(...this.inactivePresetClasses);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async _handleSave() {
|
||||
const data = this._collectFormData();
|
||||
if (data.auto_cleanup.enabled && (!data.auto_cleanup.retention_days || data.auto_cleanup.retention_days <= 0)) {
|
||||
alert('启用自动清理时,保留天数必须是大于0的数字。');
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.onSave) {
|
||||
this.elements.saveBtn.disabled = true;
|
||||
this.elements.saveBtn.textContent = '保存中...';
|
||||
try {
|
||||
await this.onSave(data);
|
||||
this.close();
|
||||
} catch (error) {
|
||||
console.error("Failed to save log settings:", error);
|
||||
// 可以添加一个UI提示,比如 toast
|
||||
} finally {
|
||||
this.elements.saveBtn.disabled = false;
|
||||
this.elements.saveBtn.textContent = '保存设置';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// [MODIFIED] - 更新此方法以填充新的时间选择器
|
||||
_populateForm(data) {
|
||||
this.elements.logLevelSelect.value = data.log_level || 'INFO';
|
||||
|
||||
const cleanup = data.auto_cleanup || {};
|
||||
const isCleanupEnabled = cleanup.enabled || false;
|
||||
this.elements.cleanupEnableToggle.checked = isCleanupEnabled;
|
||||
this.elements.cleanupSettingsPanel.classList.toggle('hidden', !isCleanupEnabled);
|
||||
|
||||
const retentionDays = cleanup.retention_days || '';
|
||||
this.elements.cleanupRetentionInput.value = retentionDays;
|
||||
this._updateActivePresetButton(retentionDays.toString());
|
||||
|
||||
// [NEW] 填充执行时间,提供一个安全的默认值
|
||||
this.elements.cleanupExecTimeInput.value = cleanup.exec_time || '04:05';
|
||||
}
|
||||
|
||||
// [MODIFIED] - 更新此方法以收集新的时间数据
|
||||
_collectFormData() {
|
||||
const parseIntOrNull = (value) => {
|
||||
const trimmed = value.trim();
|
||||
if (trimmed === '') return null;
|
||||
const num = parseInt(trimmed, 10);
|
||||
return isNaN(num) ? null : num;
|
||||
};
|
||||
|
||||
const isCleanupEnabled = this.elements.cleanupEnableToggle.checked;
|
||||
|
||||
const formData = {
|
||||
log_level: this.elements.logLevelSelect.value,
|
||||
auto_cleanup: {
|
||||
enabled: isCleanupEnabled,
|
||||
interval: isCleanupEnabled ? 'daily' : null,
|
||||
retention_days: isCleanupEnabled ? parseIntOrNull(this.elements.cleanupRetentionInput.value) : null,
|
||||
exec_time: isCleanupEnabled ? this.elements.cleanupExecTimeInput.value : '04:05', // [NEW] 收集时间数据
|
||||
},
|
||||
};
|
||||
|
||||
return formData;
|
||||
}
|
||||
}
|
||||
176
frontend/js/pages/logs/systemLog.js
Normal file
176
frontend/js/pages/logs/systemLog.js
Normal file
@@ -0,0 +1,176 @@
|
||||
// Filename: frontend/js/pages/logs/systemLog.js
|
||||
|
||||
export default class SystemLogTerminal {
|
||||
constructor(container, controlsContainer) {
|
||||
this.container = container;
|
||||
this.controlsContainer = controlsContainer;
|
||||
this.ws = null;
|
||||
this.isPaused = false;
|
||||
this.shouldAutoScroll = true;
|
||||
this.reconnectAttempts = 0;
|
||||
this.maxReconnectAttempts = 5;
|
||||
this.isConnected = false;
|
||||
|
||||
this.elements = {
|
||||
output: this.container.querySelector('#log-terminal-output'),
|
||||
statusIndicator: this.controlsContainer.querySelector('#terminal-status-indicator'),
|
||||
clearBtn: this.controlsContainer.querySelector('[data-action="clear-terminal"]'),
|
||||
pauseBtn: this.controlsContainer.querySelector('[data-action="toggle-pause-terminal"]'),
|
||||
scrollBtn: this.controlsContainer.querySelector('[data-action="toggle-scroll-terminal"]'),
|
||||
connectBtn: this.controlsContainer.querySelector('[data-action="toggle-connect-terminal"]'),
|
||||
settingsBtn: this.controlsContainer.querySelector('[data-action="terminal-settings"]'),
|
||||
};
|
||||
|
||||
this._initEventListeners();
|
||||
}
|
||||
|
||||
_initEventListeners() {
|
||||
this.elements.clearBtn.addEventListener('click', () => this.clear());
|
||||
this.elements.pauseBtn.addEventListener('click', () => this.togglePause());
|
||||
this.elements.scrollBtn.addEventListener('click', () => this.toggleAutoScroll());
|
||||
this.elements.connectBtn.addEventListener('click', () => this.toggleConnect());
|
||||
this.elements.settingsBtn.addEventListener('click', () => this.openSettings());
|
||||
}
|
||||
|
||||
toggleConnect() {
|
||||
if (this.isConnected) {
|
||||
this.disconnect();
|
||||
} else {
|
||||
this.connect();
|
||||
}
|
||||
}
|
||||
|
||||
connect() {
|
||||
this.clear();
|
||||
this._appendMessage('info', '正在连接到实时日志流...');
|
||||
this._updateStatus('connecting', '连接中...');
|
||||
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
const wsUrl = `${protocol}//${window.location.host}/ws/system-logs`;
|
||||
|
||||
this.ws = new WebSocket(wsUrl);
|
||||
|
||||
this.ws.onopen = () => {
|
||||
this._appendMessage('info', '✓ 已连接到系统日志流');
|
||||
this._updateStatus('connected', '已连接');
|
||||
this.reconnectAttempts = 0;
|
||||
this.isConnected = true;
|
||||
this.elements.connectBtn.title = '断开';
|
||||
this.elements.connectBtn.querySelector('i').classList.replace('fa-plug', 'fa-minus-circle');
|
||||
};
|
||||
|
||||
this.ws.onmessage = (event) => {
|
||||
if (this.isPaused) return;
|
||||
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
const levelColors = {
|
||||
'error': 'text-red-500',
|
||||
'warning': 'text-yellow-400',
|
||||
'info': 'text-green-400',
|
||||
'debug': 'text-zinc-400'
|
||||
};
|
||||
const color = levelColors[data.level] || 'text-zinc-200';
|
||||
const timestamp = new Date(data.timestamp).toLocaleTimeString();
|
||||
const msg = `[${timestamp}] [${data.level.toUpperCase()}] ${data.message}`;
|
||||
this._appendMessage(color, msg);
|
||||
} catch (e) {
|
||||
this._appendMessage('text-zinc-200', event.data);
|
||||
}
|
||||
};
|
||||
|
||||
this.ws.onerror = (error) => {
|
||||
this._appendMessage('error', `✗ WebSocket 错误`);
|
||||
this._updateStatus('error', '连接错误');
|
||||
};
|
||||
|
||||
this.ws.onclose = () => {
|
||||
this._appendMessage('error', '✗ 连接已断开');
|
||||
this._updateStatus('disconnected', '未连接');
|
||||
this.isConnected = false;
|
||||
this.elements.connectBtn.title = '连接';
|
||||
this.elements.connectBtn.querySelector('i').classList.replace('fa-minus-circle', 'fa-plug');
|
||||
|
||||
if (this.reconnectAttempts < this.maxReconnectAttempts) {
|
||||
this.reconnectAttempts++;
|
||||
setTimeout(() => {
|
||||
this._appendMessage('info', `尝试重新连接 (${this.reconnectAttempts}/${this.maxReconnectAttempts})...`);
|
||||
this.connect();
|
||||
}, 3000);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
disconnect() {
|
||||
if (this.ws) {
|
||||
this.ws.close();
|
||||
this.ws = null;
|
||||
}
|
||||
this.reconnectAttempts = this.maxReconnectAttempts;
|
||||
this.isConnected = false;
|
||||
this._updateStatus('disconnected', '未连接');
|
||||
this.elements.connectBtn.title = '连接';
|
||||
this.elements.connectBtn.querySelector('i').classList.replace('fa-minus-circle', 'fa-plug');
|
||||
}
|
||||
|
||||
clear() {
|
||||
if(this.elements.output) {
|
||||
this.elements.output.innerHTML = '';
|
||||
}
|
||||
}
|
||||
|
||||
togglePause() {
|
||||
this.isPaused = !this.isPaused;
|
||||
const icon = this.elements.pauseBtn.querySelector('i');
|
||||
if (this.isPaused) {
|
||||
this.elements.pauseBtn.title = '继续';
|
||||
icon.classList.replace('fa-pause', 'fa-play');
|
||||
} else {
|
||||
this.elements.pauseBtn.title = '暂停';
|
||||
icon.classList.replace('fa-play', 'fa-pause');
|
||||
}
|
||||
}
|
||||
|
||||
toggleAutoScroll() {
|
||||
this.shouldAutoScroll = !this.shouldAutoScroll;
|
||||
this.elements.scrollBtn.title = this.shouldAutoScroll ? '自动滚动' : '手动滚动';
|
||||
}
|
||||
|
||||
openSettings() {
|
||||
// 实现设置功能
|
||||
console.log('打开设置');
|
||||
}
|
||||
|
||||
_appendMessage(colorClass, text) {
|
||||
if (!this.elements.output) return;
|
||||
|
||||
const p = document.createElement('p');
|
||||
p.className = colorClass;
|
||||
p.textContent = text;
|
||||
this.elements.output.appendChild(p);
|
||||
|
||||
if (this.shouldAutoScroll) {
|
||||
this.elements.output.scrollTop = this.elements.output.scrollHeight;
|
||||
}
|
||||
}
|
||||
|
||||
_updateStatus(status, text) {
|
||||
const indicator = this.elements.statusIndicator.querySelector('span.relative');
|
||||
const statusText = this.elements.statusIndicator.childNodes[2];
|
||||
|
||||
const colors = {
|
||||
'connecting': 'bg-yellow-500',
|
||||
'connected': 'bg-green-500',
|
||||
'disconnected': 'bg-zinc-500',
|
||||
'error': 'bg-red-500'
|
||||
};
|
||||
|
||||
indicator.querySelectorAll('span').forEach(span => {
|
||||
span.className = span.className.replace(/bg-\w+-\d+/g, colors[status] || colors.disconnected);
|
||||
});
|
||||
|
||||
if (statusText) {
|
||||
statusText.textContent = ` ${text}`;
|
||||
}
|
||||
}
|
||||
}
|
||||
1311
frontend/js/vendor/anime.esm.js
vendored
Normal file
1311
frontend/js/vendor/anime.esm.js
vendored
Normal file
File diff suppressed because it is too large
Load Diff
2
frontend/js/vendor/flatpickr.js
vendored
Normal file
2
frontend/js/vendor/flatpickr.js
vendored
Normal file
File diff suppressed because one or more lines are too long
72
frontend/js/vendor/marked.min.js
vendored
Normal file
72
frontend/js/vendor/marked.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
frontend/js/vendor/nanoid.js
vendored
Normal file
1
frontend/js/vendor/nanoid.js
vendored
Normal file
@@ -0,0 +1 @@
|
||||
let a="useandom-26T198340PX75pxJACKVERYMINDBUSHWOLF_GQZbfghjklqvwyzrict";export let nanoid=(e=21)=>{let t="",r=crypto.getRandomValues(new Uint8Array(e));for(let n=0;n<e;n++)t+=a[63&r[n]];return t};
|
||||
6
frontend/js/vendor/popper.esm.min.js
vendored
Normal file
6
frontend/js/vendor/popper.esm.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
4611
frontend/js/vendor/sweetalert2.esm.js
vendored
Normal file
4611
frontend/js/vendor/sweetalert2.esm.js
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1
frontend/js/vendor/sweetalert2.min.css
vendored
Normal file
1
frontend/js/vendor/sweetalert2.min.css
vendored
Normal file
File diff suppressed because one or more lines are too long
3
go.mod
3
go.mod
@@ -10,6 +10,7 @@ require (
|
||||
github.com/go-co-op/gocron v1.37.0
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.7.5
|
||||
github.com/microcosm-cc/bluemonday v1.0.27
|
||||
github.com/redis/go-redis/v9 v9.3.0
|
||||
@@ -17,6 +18,8 @@ require (
|
||||
github.com/spf13/viper v1.20.1
|
||||
go.uber.org/dig v1.19.0
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gorm.io/datatypes v1.0.5
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
|
||||
6
go.sum
6
go.sum
@@ -92,6 +92,8 @@ github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
@@ -311,6 +313,8 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
@@ -325,6 +329,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
||||
@@ -28,12 +28,6 @@ type GeminiChannel struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// 用于安全提取信息的本地结构体
|
||||
type requestMetadata struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *GeminiChannel {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -47,39 +41,58 @@ func NewGeminiChannel(logger *logrus.Logger, cfg *models.SystemSettings) *Gemini
|
||||
logger: logger,
|
||||
httpClient: &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 0,
|
||||
Timeout: 0, // Timeout is handled by the request context
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TransformRequest
|
||||
func (ch *GeminiChannel) TransformRequest(c *gin.Context, requestBody []byte) (newBody []byte, modelName string, err error) {
|
||||
modelName = ch.ExtractModel(c, requestBody)
|
||||
return requestBody, modelName, nil
|
||||
}
|
||||
|
||||
// ExtractModel
|
||||
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
|
||||
return ch.extractModelFromRequest(c, bodyBytes)
|
||||
}
|
||||
|
||||
// 统一的模型提取逻辑:优先从请求体解析,失败则回退到从URL路径解析。
|
||||
func (ch *GeminiChannel) extractModelFromRequest(c *gin.Context, bodyBytes []byte) string {
|
||||
var p struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
_ = json.Unmarshal(requestBody, &p)
|
||||
modelName = strings.TrimPrefix(p.Model, "models/")
|
||||
_ = json.Unmarshal(bodyBytes, &p)
|
||||
modelName := strings.TrimPrefix(p.Model, "models/")
|
||||
|
||||
if modelName == "" {
|
||||
modelName = ch.extractModelFromPath(c.Request.URL.Path)
|
||||
}
|
||||
return requestBody, modelName, nil
|
||||
return modelName
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) extractModelFromPath(path string) string {
|
||||
parts := strings.Split(path, "/")
|
||||
for _, part := range parts {
|
||||
// 覆盖更多模型名称格式
|
||||
if strings.HasPrefix(part, "gemini-") || strings.HasPrefix(part, "text-") || strings.HasPrefix(part, "embedding-") {
|
||||
modelPart := strings.Split(part, ":")[0]
|
||||
return modelPart
|
||||
return strings.Split(part, ":")[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsOpenAICompatibleRequest 通过纯粹的字符串操作判断,不依赖 gin.Context。
|
||||
func (ch *GeminiChannel) IsOpenAICompatibleRequest(c *gin.Context) bool {
|
||||
path := c.Request.URL.Path
|
||||
return strings.Contains(path, "/v1/chat/completions") || strings.Contains(path, "/v1/embeddings")
|
||||
return ch.isOpenAIPath(c.Request.URL.Path)
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) isOpenAIPath(path string) bool {
|
||||
return strings.Contains(path, "/v1/chat/completions") ||
|
||||
strings.Contains(path, "/v1/completions") ||
|
||||
strings.Contains(path, "/v1/embeddings") ||
|
||||
strings.Contains(path, "/v1/models") ||
|
||||
strings.Contains(path, "/v1/audio/")
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) ValidateKey(
|
||||
@@ -88,25 +101,28 @@ func (ch *GeminiChannel) ValidateKey(
|
||||
targetURL string,
|
||||
timeout time.Duration,
|
||||
) *CustomErrors.APIError {
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil)
|
||||
if err != nil {
|
||||
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request: "+err.Error())
|
||||
return CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "failed to create validation request")
|
||||
}
|
||||
|
||||
ch.ModifyRequest(req, apiKey)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "failed to send validation request: "+err.Error())
|
||||
return CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "validation request failed: "+err.Error())
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return nil
|
||||
}
|
||||
|
||||
errorBody, _ := io.ReadAll(resp.Body)
|
||||
parsedMessage := CustomErrors.ParseUpstreamError(errorBody)
|
||||
parsedMessage, _ := CustomErrors.ParseUpstreamError(errorBody)
|
||||
|
||||
return &CustomErrors.APIError{
|
||||
HTTPStatus: resp.StatusCode,
|
||||
Code: fmt.Sprintf("UPSTREAM_%d", resp.StatusCode),
|
||||
@@ -115,10 +131,6 @@ func (ch *GeminiChannel) ValidateKey(
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) ModifyRequest(req *http.Request, apiKey *models.APIKey) {
|
||||
// TODO: [Future Refactoring] Decouple auth logic from URL path.
|
||||
// The authentication method (e.g., Bearer token vs. API key in query) should ideally be a property
|
||||
// of the UpstreamEndpoint or a new "AuthProfile" entity, rather than being hardcoded based on URL patterns.
|
||||
// This would make the channel more generic and adaptable to new upstream provider types.
|
||||
if strings.Contains(req.URL.Path, "/v1beta/openai/") {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey.APIKey)
|
||||
} else {
|
||||
@@ -133,24 +145,22 @@ func (ch *GeminiChannel) IsStreamRequest(c *gin.Context, bodyBytes []byte) bool
|
||||
if strings.HasSuffix(c.Request.URL.Path, ":streamGenerateContent") {
|
||||
return true
|
||||
}
|
||||
var meta requestMetadata
|
||||
if err := json.Unmarshal(bodyBytes, &meta); err == nil {
|
||||
var meta struct {
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if json.Unmarshal(bodyBytes, &meta) == nil {
|
||||
return meta.Stream
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
|
||||
_, modelName, _ := ch.TransformRequest(c, bodyBytes)
|
||||
return modelName
|
||||
}
|
||||
|
||||
// RewritePath 使用 url.JoinPath 保证路径拼接的正确性。
|
||||
func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
|
||||
tempCtx := &gin.Context{Request: &http.Request{URL: &url.URL{Path: originalPath}}}
|
||||
var rewrittenSegment string
|
||||
if ch.IsOpenAICompatibleRequest(tempCtx) {
|
||||
var apiEndpoint string
|
||||
|
||||
if ch.isOpenAIPath(originalPath) {
|
||||
v1Index := strings.LastIndex(originalPath, "/v1/")
|
||||
var apiEndpoint string
|
||||
if v1Index != -1 {
|
||||
apiEndpoint = originalPath[v1Index+len("/v1/"):]
|
||||
} else {
|
||||
@@ -158,69 +168,76 @@ func (ch *GeminiChannel) RewritePath(basePath, originalPath string) string {
|
||||
}
|
||||
rewrittenSegment = "v1beta/openai/" + apiEndpoint
|
||||
} else {
|
||||
tempPath := originalPath
|
||||
if strings.HasPrefix(tempPath, "/v1/") {
|
||||
tempPath = "/v1beta/" + strings.TrimPrefix(tempPath, "/v1/")
|
||||
if strings.HasPrefix(originalPath, "/v1/") {
|
||||
rewrittenSegment = "v1beta/" + strings.TrimPrefix(originalPath, "/v1/")
|
||||
} else {
|
||||
rewrittenSegment = strings.TrimPrefix(originalPath, "/")
|
||||
}
|
||||
rewrittenSegment = strings.TrimPrefix(tempPath, "/")
|
||||
}
|
||||
trimmedBasePath := strings.TrimSuffix(basePath, "/")
|
||||
pathToJoin := rewrittenSegment
|
||||
|
||||
trimmedBasePath := strings.TrimSuffix(basePath, "/")
|
||||
|
||||
// 防止版本号重复拼接,例如 basePath 是 /v1beta,而重写段也是 v1beta/..
|
||||
versionPrefixes := []string{"v1beta", "v1"}
|
||||
for _, prefix := range versionPrefixes {
|
||||
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(pathToJoin, prefix+"/") {
|
||||
pathToJoin = strings.TrimPrefix(pathToJoin, prefix+"/")
|
||||
if strings.HasSuffix(trimmedBasePath, "/"+prefix) && strings.HasPrefix(rewrittenSegment, prefix+"/") {
|
||||
rewrittenSegment = strings.TrimPrefix(rewrittenSegment, prefix+"/")
|
||||
break
|
||||
}
|
||||
}
|
||||
finalPath, err := url.JoinPath(trimmedBasePath, pathToJoin)
|
||||
|
||||
finalPath, err := url.JoinPath(trimmedBasePath, rewrittenSegment)
|
||||
if err != nil {
|
||||
return trimmedBasePath + "/" + strings.TrimPrefix(pathToJoin, "/")
|
||||
// 回退到简单的字符串拼接
|
||||
return trimmedBasePath + "/" + strings.TrimPrefix(rewrittenSegment, "/")
|
||||
}
|
||||
return finalPath
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) ModifyResponse(resp *http.Response) error {
|
||||
// 这是一个桩实现,暂时不需要任何逻辑。
|
||||
return nil
|
||||
return nil // 桩实现
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) HandleError(c *gin.Context, err error) {
|
||||
// 这是一个桩实现,暂时不需要任何逻辑。
|
||||
// 桩实现
|
||||
}
|
||||
|
||||
// ==========================================================
|
||||
// ================== “智能路由”的核心引擎 ===================
|
||||
// ==========================================================
|
||||
func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartRequestParams) {
|
||||
log := ch.logger.WithField("correlation_id", params.CorrelationID)
|
||||
|
||||
targetURL, err := url.Parse(params.UpstreamURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to parse upstream URL")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL format"))
|
||||
log.WithError(err).Error("Invalid upstream URL")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Invalid upstream URL"))
|
||||
return
|
||||
}
|
||||
|
||||
targetURL.Path = c.Request.URL.Path
|
||||
targetURL.RawQuery = c.Request.URL.RawQuery
|
||||
|
||||
initialReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", targetURL.String(), bytes.NewReader(params.RequestBody))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create initial smart request")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, err.Error()))
|
||||
log.WithError(err).Error("Failed to create initial smart stream request")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to create request"))
|
||||
return
|
||||
}
|
||||
|
||||
ch.ModifyRequest(initialReq, params.APIKey)
|
||||
initialReq.Header.Del("Authorization")
|
||||
|
||||
resp, err := ch.httpClient.Do(initialReq)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Initial smart request failed")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, err.Error()))
|
||||
log.WithError(err).Error("Initial smart stream request failed")
|
||||
errToJSON(c, CustomErrors.NewAPIError(CustomErrors.ErrBadGateway, "Request to upstream failed"))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("Initial request received non-200 status: %d", resp.StatusCode)
|
||||
standardizedResp := ch.standardizeError(resp, params.LogTruncationLimit, log)
|
||||
defer standardizedResp.Body.Close()
|
||||
|
||||
c.Writer.WriteHeader(standardizedResp.StatusCode)
|
||||
for key, values := range standardizedResp.Header {
|
||||
for _, value := range values {
|
||||
@@ -228,45 +245,71 @@ func (ch *GeminiChannel) ProcessSmartStreamRequest(c *gin.Context, params SmartR
|
||||
}
|
||||
}
|
||||
io.Copy(c.Writer, standardizedResp.Body)
|
||||
|
||||
params.EventLogger.IsSuccess = false
|
||||
params.EventLogger.StatusCode = resp.StatusCode
|
||||
return
|
||||
}
|
||||
|
||||
ch.processStreamAndRetry(c, initialReq.Header, resp.Body, params, log)
|
||||
}
|
||||
|
||||
func (ch *GeminiChannel) processStreamAndRetry(
|
||||
c *gin.Context, initialRequestHeaders http.Header, initialReader io.ReadCloser, params SmartRequestParams, log *logrus.Entry,
|
||||
c *gin.Context,
|
||||
initialRequestHeaders http.Header,
|
||||
initialReader io.ReadCloser,
|
||||
params SmartRequestParams,
|
||||
log *logrus.Entry,
|
||||
) {
|
||||
defer initialReader.Close()
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
|
||||
flusher, _ := c.Writer.(http.Flusher)
|
||||
|
||||
var accumulatedText strings.Builder
|
||||
consecutiveRetryCount := 0
|
||||
currentReader := initialReader
|
||||
maxRetries := params.MaxRetries
|
||||
retryDelay := params.RetryDelay
|
||||
|
||||
log.Infof("Starting smart stream session. Max retries: %d", maxRetries)
|
||||
|
||||
for {
|
||||
if c.Request.Context().Err() != nil {
|
||||
log.Info("Client disconnected, stopping stream processing.")
|
||||
return
|
||||
}
|
||||
|
||||
var interruptionReason string
|
||||
scanner := bufio.NewScanner(currentReader)
|
||||
|
||||
for scanner.Scan() {
|
||||
if c.Request.Context().Err() != nil {
|
||||
log.Info("Client disconnected during scan.")
|
||||
return
|
||||
}
|
||||
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(c.Writer, "%s\n\n", line)
|
||||
flusher.Flush()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
var payload models.GeminiSSEPayload
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(payload.Candidates) > 0 {
|
||||
candidate := payload.Candidates[0]
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
@@ -285,52 +328,71 @@ func (ch *GeminiChannel) processStreamAndRetry(
|
||||
}
|
||||
}
|
||||
currentReader.Close()
|
||||
|
||||
if interruptionReason == "" {
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.WithError(err).Warn("Stream scanner encountered an error.")
|
||||
interruptionReason = "SCANNER_ERROR"
|
||||
} else {
|
||||
log.Warn("Stream dropped unexpectedly without a finish reason.")
|
||||
log.Warn("Stream connection dropped without a finish reason.")
|
||||
interruptionReason = "CONNECTION_DROP"
|
||||
}
|
||||
}
|
||||
|
||||
if consecutiveRetryCount >= maxRetries {
|
||||
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error.", interruptionReason)
|
||||
errData, _ := json.Marshal(map[string]interface{}{"error": map[string]interface{}{"code": http.StatusGatewayTimeout, "status": "DEADLINE_EXCEEDED", "message": fmt.Sprintf("Proxy retry limit exceeded. Last interruption: %s.", interruptionReason)}})
|
||||
log.Errorf("Retry limit exceeded. Last interruption: %s. Sending final error to client.", interruptionReason)
|
||||
errData, _ := json.Marshal(map[string]interface{}{
|
||||
"error": map[string]interface{}{
|
||||
"code": http.StatusGatewayTimeout,
|
||||
"status": "DEADLINE_EXCEEDED",
|
||||
"message": fmt.Sprintf("Proxy retry limit exceeded after multiple interruptions. Last reason: %s", interruptionReason),
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(errData))
|
||||
flusher.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
consecutiveRetryCount++
|
||||
params.EventLogger.Retries = consecutiveRetryCount
|
||||
log.Infof("Stream interrupted. Attempting retry %d/%d after %v.", consecutiveRetryCount, maxRetries, retryDelay)
|
||||
|
||||
time.Sleep(retryDelay)
|
||||
retryBody, _ := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
|
||||
|
||||
retryBody := buildRetryRequestBody(params.OriginalRequest, accumulatedText.String())
|
||||
retryBodyBytes, _ := json.Marshal(retryBody)
|
||||
|
||||
retryReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
|
||||
retryReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", params.UpstreamURL, bytes.NewReader(retryBodyBytes))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create retry request")
|
||||
continue
|
||||
}
|
||||
|
||||
retryReq.Header = initialRequestHeaders
|
||||
ch.ModifyRequest(retryReq, params.APIKey)
|
||||
retryReq.Header.Del("Authorization")
|
||||
|
||||
retryResp, err := ch.httpClient.Do(retryReq)
|
||||
if err != nil || retryResp.StatusCode != http.StatusOK || retryResp.Body == nil {
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Retry request failed.")
|
||||
} else {
|
||||
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
|
||||
if retryResp.Body != nil {
|
||||
retryResp.Body.Close()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Retry request failed")
|
||||
continue
|
||||
}
|
||||
|
||||
if retryResp.StatusCode != http.StatusOK {
|
||||
log.Errorf("Retry request received non-200 status: %d", retryResp.StatusCode)
|
||||
retryResp.Body.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
currentReader = retryResp.Body
|
||||
}
|
||||
}
|
||||
|
||||
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) (models.GeminiRequest, error) {
|
||||
// buildRetryRequestBody 正确处理多轮对话的上下文插入。
|
||||
func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText string) models.GeminiRequest {
|
||||
retryBody := originalBody
|
||||
|
||||
// 找到最后一个 'user' 角色的消息索引
|
||||
lastUserIndex := -1
|
||||
for i := len(retryBody.Contents) - 1; i >= 0; i-- {
|
||||
if retryBody.Contents[i].Role == "user" {
|
||||
@@ -338,25 +400,26 @@ func buildRetryRequestBody(originalBody models.GeminiRequest, accumulatedText st
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
history := []models.GeminiContent{
|
||||
{Role: "model", Parts: []models.Part{{Text: accumulatedText}}},
|
||||
{Role: "user", Parts: []models.Part{{Text: SmartRetryPrompt}}},
|
||||
}
|
||||
|
||||
if lastUserIndex != -1 {
|
||||
// 如果找到了 'user' 消息,将历史记录插入到其后
|
||||
newContents := make([]models.GeminiContent, 0, len(retryBody.Contents)+2)
|
||||
newContents = append(newContents, retryBody.Contents[:lastUserIndex+1]...)
|
||||
newContents = append(newContents, history...)
|
||||
newContents = append(newContents, retryBody.Contents[lastUserIndex+1:]...)
|
||||
retryBody.Contents = newContents
|
||||
} else {
|
||||
// 如果没有 'user' 消息(理论上不应发生),则直接追加
|
||||
retryBody.Contents = append(retryBody.Contents, history...)
|
||||
}
|
||||
return retryBody, nil
|
||||
}
|
||||
|
||||
// ===============================================
|
||||
// ========= 辅助函数区 (继承并强化) =========
|
||||
// ===============================================
|
||||
return retryBody
|
||||
}
|
||||
|
||||
type googleAPIError struct {
|
||||
Error struct {
|
||||
@@ -397,25 +460,28 @@ func truncate(s string, n int) string {
|
||||
return s
|
||||
}
|
||||
|
||||
// standardizeError
|
||||
func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int, log *logrus.Entry) *http.Response {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to read upstream error body")
|
||||
bodyBytes = []byte("Failed to read upstream error body: " + err.Error())
|
||||
bodyBytes = []byte("Failed to read upstream error body")
|
||||
}
|
||||
resp.Body.Close()
|
||||
log.Errorf("Upstream error body: %s", truncate(string(bodyBytes), truncateLimit))
|
||||
|
||||
log.Errorf("Upstream error: %s", truncate(string(bodyBytes), truncateLimit))
|
||||
|
||||
var standardizedPayload googleAPIError
|
||||
// 即使解析失败,也要构建一个标准的错误结构体
|
||||
if json.Unmarshal(bodyBytes, &standardizedPayload) != nil || standardizedPayload.Error.Code == 0 {
|
||||
standardizedPayload.Error.Code = resp.StatusCode
|
||||
standardizedPayload.Error.Message = http.StatusText(resp.StatusCode)
|
||||
standardizedPayload.Error.Status = statusToGoogleStatus(resp.StatusCode)
|
||||
standardizedPayload.Error.Details = []interface{}{map[string]string{
|
||||
"@type": "proxy.upstream.error",
|
||||
"@type": "proxy.upstream.unparsed.error",
|
||||
"body": truncate(string(bodyBytes), truncateLimit),
|
||||
}}
|
||||
}
|
||||
|
||||
newBodyBytes, _ := json.Marshal(standardizedPayload)
|
||||
newResp := &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
@@ -425,10 +491,13 @@ func (ch *GeminiChannel) standardizeError(resp *http.Response, truncateLimit int
|
||||
}
|
||||
newResp.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
newResp.Header.Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
return newResp
|
||||
}
|
||||
|
||||
// errToJSON
|
||||
func errToJSON(c *gin.Context, apiErr *CustomErrors.APIError) {
|
||||
if c.IsAborted() {
|
||||
return
|
||||
}
|
||||
c.JSON(apiErr.HTTPStatus, gin.H{"error": apiErr})
|
||||
}
|
||||
|
||||
@@ -13,9 +13,10 @@ type Config struct {
|
||||
Database DatabaseConfig
|
||||
Server ServerConfig
|
||||
Log LogConfig
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
SessionSecret string `mapstructure:"session_secret"`
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
SessionSecret string `mapstructure:"session_secret"`
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
Repository RepositoryConfig `mapstructure:"repository"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 存储数据库连接信息
|
||||
@@ -28,7 +29,9 @@ type DatabaseConfig struct {
|
||||
|
||||
// ServerConfig 存储HTTP服务器配置
|
||||
type ServerConfig struct {
|
||||
Port string `mapstructure:"port"`
|
||||
Port string `mapstructure:"port"`
|
||||
Host string `yaml:"host"`
|
||||
CORSOrigins []string `yaml:"cors_origins"`
|
||||
}
|
||||
|
||||
// LogConfig 存储日志配置
|
||||
@@ -37,25 +40,36 @@ type LogConfig struct {
|
||||
Format string `mapstructure:"format" json:"format"`
|
||||
EnableFile bool `mapstructure:"enable_file" json:"enable_file"`
|
||||
FilePath string `mapstructure:"file_path" json:"file_path"`
|
||||
|
||||
// 日志轮转配置(可选)
|
||||
MaxSize int `yaml:"max_size"` // MB,默认 100
|
||||
MaxBackups int `yaml:"max_backups"` // 默认 7
|
||||
MaxAge int `yaml:"max_age"` // 天,默认 30
|
||||
Compress bool `yaml:"compress"` // 默认 true
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
DSN string `mapstructure:"dsn"`
|
||||
}
|
||||
|
||||
type RepositoryConfig struct {
|
||||
BasePoolTTLMinutes int `mapstructure:"base_pool_ttl_minutes"`
|
||||
BasePoolTTIMinutes int `mapstructure:"base_pool_tti_minutes"`
|
||||
}
|
||||
|
||||
// LoadConfig 从文件和环境变量加载配置
|
||||
func LoadConfig() (*Config, error) {
|
||||
// 设置配置文件名和路径
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath(".")
|
||||
|
||||
viper.AddConfigPath("/etc/gemini-balancer/") // for production
|
||||
// 允许从环境变量读取
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
viper.AutomaticEnv()
|
||||
|
||||
// 设置默认值
|
||||
viper.SetDefault("server.port", "8080")
|
||||
viper.SetDefault("server.port", "9000")
|
||||
viper.SetDefault("log.level", "info")
|
||||
viper.SetDefault("log.format", "text")
|
||||
viper.SetDefault("log.enable_file", false)
|
||||
@@ -67,6 +81,9 @@ func LoadConfig() (*Config, error) {
|
||||
viper.SetDefault("database.conn_max_lifetime", "1h")
|
||||
viper.SetDefault("encryption_key", "")
|
||||
|
||||
viper.SetDefault("repository.base_pool_ttl_minutes", 60)
|
||||
viper.SetDefault("repository.base_pool_tti_minutes", 10)
|
||||
|
||||
// 读取配置文件
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
package container
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/app"
|
||||
"gemini-balancer/internal/channel"
|
||||
"gemini-balancer/internal/config"
|
||||
"gemini-balancer/internal/crypto"
|
||||
"gemini-balancer/internal/db"
|
||||
"gemini-balancer/internal/db/dialect"
|
||||
"gemini-balancer/internal/db/migrations"
|
||||
"gemini-balancer/internal/domain/proxy"
|
||||
"gemini-balancer/internal/domain/upstream"
|
||||
"gemini-balancer/internal/handlers"
|
||||
@@ -21,13 +18,10 @@ import (
|
||||
"gemini-balancer/internal/service"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"gemini-balancer/internal/task"
|
||||
"gemini-balancer/internal/webhandlers"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.uber.org/dig"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func BuildContainer() (*dig.Container, error) {
|
||||
@@ -35,20 +29,9 @@ func BuildContainer() (*dig.Container, error) {
|
||||
|
||||
// =========== 阶段一: 基础设施层 (Infrastructure) ===========
|
||||
container.Provide(config.LoadConfig)
|
||||
|
||||
container.Provide(func(cfg *config.Config, logger *logrus.Logger) (*gorm.DB, dialect.DialectAdapter, error) {
|
||||
gormDB, adapter, err := db.NewDB(cfg, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// 迁移运行逻辑
|
||||
if err := migrations.RunVersionedMigrations(gormDB, cfg, logger); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to run versioned migrations: %w", err)
|
||||
}
|
||||
return gormDB, adapter, nil
|
||||
})
|
||||
container.Provide(logging.NewLoggerWithWebSocket)
|
||||
container.Provide(db.NewDBWithMigrations)
|
||||
container.Provide(store.NewStore)
|
||||
container.Provide(logging.NewLogger)
|
||||
container.Provide(crypto.NewService)
|
||||
container.Provide(repository.NewAuthTokenRepository)
|
||||
container.Provide(repository.NewGroupRepository)
|
||||
@@ -85,14 +68,10 @@ func BuildContainer() (*dig.Container, error) {
|
||||
// --- Syncer & Loader for GroupManager ---
|
||||
container.Provide(service.NewGroupManagerLoader)
|
||||
// 为GroupManager配置Syncer
|
||||
container.Provide(func(loader syncer.LoaderFunc[service.GroupManagerCacheData], store store.Store, logger *logrus.Logger) (*syncer.CacheSyncer[service.GroupManagerCacheData], error) {
|
||||
const groupUpdateChannel = "groups:cache_invalidation"
|
||||
return syncer.NewCacheSyncer(loader, store, groupUpdateChannel)
|
||||
})
|
||||
container.Provide(service.NewGroupManagerSyncer)
|
||||
|
||||
// =========== 阶段三: 适配器与处理器层 (Handlers & Adapters) ===========
|
||||
|
||||
// 为Channel提供依赖 (Logger 和 *models.SystemSettings 数据插座)
|
||||
container.Provide(channel.NewGeminiChannel)
|
||||
container.Provide(func(ch *channel.GeminiChannel) channel.ChannelProxy { return ch })
|
||||
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/config"
|
||||
"gemini-balancer/internal/db/dialect"
|
||||
"gemini-balancer/internal/db/migrations"
|
||||
stdlog "log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -86,3 +88,16 @@ func NewDB(cfg *config.Config, appLogger *logrus.Logger) (*gorm.DB, dialect.Dial
|
||||
Logger.Info("Database connection established successfully.")
|
||||
return db, adapter, nil
|
||||
}
|
||||
|
||||
func NewDBWithMigrations(cfg *config.Config, logger *logrus.Logger) (*gorm.DB, dialect.DialectAdapter, error) {
|
||||
gormDB, adapter, err := NewDB(cfg, logger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := migrations.RunVersionedMigrations(gormDB, cfg, logger); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to run versioned migrations: %w", err)
|
||||
}
|
||||
|
||||
return gormDB, adapter, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -49,7 +50,6 @@ func (h *handler) registerRoutes(rg *gin.RouterGroup) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- 请求 DTO ---
|
||||
type CreateProxyConfigRequest struct {
|
||||
Address string `json:"address" binding:"required"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
@@ -64,12 +64,10 @@ type UpdateProxyConfigRequest struct {
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
// 单个检测的请求体 (与前端JS对齐)
|
||||
type CheckSingleProxyRequest struct {
|
||||
Proxy string `json:"proxy" binding:"required"`
|
||||
}
|
||||
|
||||
// 批量检测的请求体
|
||||
type CheckAllProxiesRequest struct {
|
||||
Proxies []string `json:"proxies" binding:"required"`
|
||||
}
|
||||
@@ -84,7 +82,7 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
if req.Status == "" {
|
||||
req.Status = "active" // 默认状态
|
||||
req.Status = "active"
|
||||
}
|
||||
|
||||
proxyConfig := models.ProxyConfig{
|
||||
@@ -98,7 +96,6 @@ func (h *handler) CreateProxyConfig(c *gin.Context) {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
// 写操作后,发布事件并使缓存失效
|
||||
h.publishAndInvalidate(proxyConfig.ID, "created")
|
||||
response.Created(c, proxyConfig)
|
||||
}
|
||||
@@ -199,17 +196,16 @@ func (h *handler) DeleteProxyConfig(c *gin.Context) {
|
||||
response.NoContent(c)
|
||||
}
|
||||
|
||||
// publishAndInvalidate 统一事件发布和缓存失效逻辑
|
||||
func (h *handler) publishAndInvalidate(proxyID uint, action string) {
|
||||
go h.manager.invalidate()
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
event := models.ProxyStatusChangedEvent{ProxyID: proxyID, Action: action}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = h.store.Publish(models.TopicProxyStatusChanged, eventData)
|
||||
_ = h.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
|
||||
}()
|
||||
}
|
||||
|
||||
// 新的 Handler 方法和 DTO
|
||||
type SyncProxiesRequest struct {
|
||||
Proxies []string `json:"proxies"`
|
||||
}
|
||||
@@ -220,14 +216,12 @@ func (h *handler) SyncProxies(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.manager.SyncProxiesInBackground(req.Proxies)
|
||||
|
||||
taskStatus, err := h.manager.SyncProxiesInBackground(c.Request.Context(), req.Proxies)
|
||||
if err != nil {
|
||||
|
||||
if errors.Is(err, ErrTaskConflict) {
|
||||
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
} else {
|
||||
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
|
||||
}
|
||||
return
|
||||
@@ -262,7 +256,7 @@ func (h *handler) CheckAllProxies(c *gin.Context) {
|
||||
|
||||
concurrency := cfg.ProxyCheckConcurrency
|
||||
if concurrency <= 0 {
|
||||
concurrency = 5 // 如果配置不合法,提供一个安全的默认值
|
||||
concurrency = 5
|
||||
}
|
||||
results := h.manager.CheckMultipleProxies(req.Proxies, timeout, concurrency)
|
||||
response.Success(c, results)
|
||||
|
||||
@@ -2,14 +2,13 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"gemini-balancer/internal/task"
|
||||
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -25,7 +24,7 @@ import (
|
||||
|
||||
const (
|
||||
TaskTypeProxySync = "proxy_sync"
|
||||
proxyChunkSize = 200 // 代理同步的批量大小
|
||||
proxyChunkSize = 200
|
||||
)
|
||||
|
||||
type ProxyCheckResult struct {
|
||||
@@ -35,13 +34,11 @@ type ProxyCheckResult struct {
|
||||
ErrorMessage string `json:"error_message"`
|
||||
}
|
||||
|
||||
// managerCacheData
|
||||
type managerCacheData struct {
|
||||
ActiveProxies []*models.ProxyConfig
|
||||
ProxiesByID map[uint]*models.ProxyConfig
|
||||
}
|
||||
|
||||
// manager结构体
|
||||
type manager struct {
|
||||
db *gorm.DB
|
||||
syncer *syncer.CacheSyncer[managerCacheData]
|
||||
@@ -80,21 +77,21 @@ func newManager(db *gorm.DB, syncer *syncer.CacheSyncer[managerCacheData], taskR
|
||||
}
|
||||
}
|
||||
|
||||
func (m *manager) SyncProxiesInBackground(proxyStrings []string) (*task.Status, error) {
|
||||
func (m *manager) SyncProxiesInBackground(ctx context.Context, proxyStrings []string) (*task.Status, error) {
|
||||
resourceID := "global_proxy_sync"
|
||||
taskStatus, err := m.task.StartTask(0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
|
||||
taskStatus, err := m.task.StartTask(ctx, 0, TaskTypeProxySync, resourceID, len(proxyStrings), 0)
|
||||
if err != nil {
|
||||
return nil, ErrTaskConflict
|
||||
}
|
||||
go m.runProxySyncTask(taskStatus.ID, proxyStrings)
|
||||
go m.runProxySyncTask(context.Background(), taskStatus.ID, proxyStrings)
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
|
||||
func (m *manager) runProxySyncTask(ctx context.Context, taskID string, finalProxyStrings []string) {
|
||||
resourceID := "global_proxy_sync"
|
||||
var allProxies []models.ProxyConfig
|
||||
if err := m.db.Find(&allProxies).Error; err != nil {
|
||||
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
|
||||
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed to fetch current proxies: %w", err))
|
||||
return
|
||||
}
|
||||
currentProxyMap := make(map[string]uint)
|
||||
@@ -125,19 +122,19 @@ func (m *manager) runProxySyncTask(taskID string, finalProxyStrings []string) {
|
||||
}
|
||||
if len(idsToDelete) > 0 {
|
||||
if err := m.bulkDeleteByIDs(idsToDelete); err != nil {
|
||||
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
|
||||
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy deletion: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(proxiesToAdd) > 0 {
|
||||
if err := m.bulkAdd(proxiesToAdd); err != nil {
|
||||
m.task.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
|
||||
m.task.EndTaskByID(ctx, taskID, resourceID, nil, fmt.Errorf("failed during proxy addition: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
result := gin.H{"added": len(proxiesToAdd), "deleted": len(idsToDelete), "final_total": len(finalProxyMap)}
|
||||
m.task.EndTaskByID(taskID, resourceID, result, nil)
|
||||
m.publishChangeEvent("proxies_synced")
|
||||
m.task.EndTaskByID(ctx, taskID, resourceID, result, nil)
|
||||
m.publishChangeEvent(ctx, "proxies_synced")
|
||||
go m.invalidate()
|
||||
}
|
||||
|
||||
@@ -184,14 +181,15 @@ func (m *manager) bulkDeleteByIDs(ids []uint) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *manager) bulkAdd(proxies []models.ProxyConfig) error {
|
||||
return m.db.CreateInBatches(proxies, proxyChunkSize).Error
|
||||
}
|
||||
|
||||
func (m *manager) publishChangeEvent(reason string) {
|
||||
func (m *manager) publishChangeEvent(ctx context.Context, reason string) {
|
||||
event := models.ProxyStatusChangedEvent{Action: reason}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = m.store.Publish(models.TopicProxyStatusChanged, eventData)
|
||||
_ = m.store.Publish(ctx, models.TopicProxyStatusChanged, eventData)
|
||||
}
|
||||
|
||||
func (m *manager) assignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error) {
|
||||
@@ -313,3 +311,8 @@ func (m *manager) checkProxyConnectivity(proxyCfg *models.ProxyConfig, timeout t
|
||||
defer resp.Body.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
type Manager interface {
|
||||
AssignProxyIfNeeded(apiKey *models.APIKey) (*models.ProxyConfig, error)
|
||||
// ... 其他需要暴露给外部服务的方法
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ type Module struct {
|
||||
|
||||
func NewModule(gormDB *gorm.DB, store store.Store, settingsManager *settings.SettingsManager, taskReporter task.Reporter, logger *logrus.Logger) (*Module, error) {
|
||||
loader := newManagerLoader(gormDB)
|
||||
cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation")
|
||||
cacheSyncer, err := syncer.NewCacheSyncer(loader, store, "proxies:cache_invalidation", logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
type APIError struct {
|
||||
HTTPStatus int
|
||||
Code string
|
||||
Status string `json:"status,omitempty"`
|
||||
Message string
|
||||
}
|
||||
|
||||
@@ -36,6 +37,7 @@ var (
|
||||
ErrForbidden = &APIError{HTTPStatus: http.StatusForbidden, Code: "FORBIDDEN", Message: "You do not have permission to access this resource"}
|
||||
ErrTaskInProgress = &APIError{HTTPStatus: http.StatusConflict, Code: "TASK_IN_PROGRESS", Message: "A task is already in progress"}
|
||||
ErrBadGateway = &APIError{HTTPStatus: http.StatusBadGateway, Code: "BAD_GATEWAY", Message: "Upstream service error"}
|
||||
ErrGatewayTimeout = &APIError{HTTPStatus: http.StatusGatewayTimeout, Code: "BAD_GATEWAY_TIMEOUT", Message: "Bad gateway timeout"}
|
||||
ErrNoActiveKeys = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_ACTIVE_KEYS", Message: "No active API keys available for this group"}
|
||||
ErrMaxRetriesExceeded = &APIError{HTTPStatus: http.StatusBadGateway, Code: "MAX_RETRIES_EXCEEDED", Message: "Request failed after maximum retries"}
|
||||
ErrNoKeysAvailable = &APIError{HTTPStatus: http.StatusServiceUnavailable, Code: "NO_KEYS_AVAILABLE", Message: "No API keys available to process the request"}
|
||||
@@ -44,6 +46,7 @@ var (
|
||||
ErrGroupNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "GROUP_NOT_FOUND", Message: "The specified group was not found."}
|
||||
ErrPermissionDenied = &APIError{HTTPStatus: http.StatusForbidden, Code: "PERMISSION_DENIED", Message: "Permission denied for this operation."}
|
||||
ErrConfigurationError = &APIError{HTTPStatus: http.StatusInternalServerError, Code: "CONFIGURATION_ERROR", Message: "A configuration error prevents this request from being processed."}
|
||||
ErrProxyNotAvailable = &APIError{HTTPStatus: http.StatusNotFound, Code: "PROXY_ERROR", Message: "Required proxy is not available for this request."}
|
||||
|
||||
ErrStateConflictMasterRevoked = &APIError{HTTPStatus: http.StatusConflict, Code: "STATE_CONFLICT_MASTER_REVOKED", Message: "Cannot perform this operation on a revoked key."}
|
||||
ErrNotFound = &APIError{HTTPStatus: http.StatusNotFound, Code: "NOT_FOUND", Message: "Resource not found"}
|
||||
@@ -60,11 +63,13 @@ func NewAPIError(base *APIError, message string) *APIError {
|
||||
}
|
||||
|
||||
// NewAPIErrorWithUpstream creates a new APIError specifically for wrapping raw upstream errors.
|
||||
func NewAPIErrorWithUpstream(statusCode int, code string, upstreamMessage string) *APIError {
|
||||
func NewAPIErrorWithUpstream(statusCode int, code string, bodyBytes []byte) *APIError {
|
||||
msg, status := ParseUpstreamError(bodyBytes)
|
||||
return &APIError{
|
||||
HTTPStatus: statusCode,
|
||||
Code: code,
|
||||
Message: upstreamMessage,
|
||||
Message: msg,
|
||||
Status: status,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ var permanentErrorSubstrings = []string{
|
||||
"permission_denied", // Catches the 'status' field in Google's JSON error, e.g., "status": "PERMISSION_DENIED".
|
||||
"service_disabled", // Catches the 'reason' field for disabled APIs, e.g., "reason": "SERVICE_DISABLED".
|
||||
"api has not been used",
|
||||
"reported as leaked", // Leaked
|
||||
}
|
||||
|
||||
// --- 2. Temporary Errors ---
|
||||
@@ -44,8 +45,10 @@ var permanentErrorSubstrings = []string{
|
||||
// Action: Increment consecutive error count, potentially disable the key.
|
||||
var temporaryErrorSubstrings = []string{
|
||||
"quota",
|
||||
"Quota exceeded",
|
||||
"limit reached",
|
||||
"insufficient",
|
||||
"request limit",
|
||||
"billing",
|
||||
"exceeded",
|
||||
"too many requests",
|
||||
@@ -71,6 +74,25 @@ var clientNetworkErrorSubstrings = []string{
|
||||
"broken pipe",
|
||||
"use of closed network connection",
|
||||
"request canceled",
|
||||
"invalid query parameters", // 参数解析错误,归类为客户端错误
|
||||
}
|
||||
|
||||
// --- 5. Retryable Network/Gateway Errors ---
|
||||
// Errors that indicate temporary network or gateway issues, should retry with same or different key.
|
||||
// Action: Retry the request.
|
||||
var retryableNetworkErrorSubstrings = []string{
|
||||
"bad gateway",
|
||||
"service unavailable",
|
||||
"gateway timeout",
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
"stream transmission interrupted", // ✅ 新增:流式传输中断
|
||||
"failed to establish stream", // ✅ 新增:流式连接建立失败
|
||||
"upstream connect error",
|
||||
"no healthy upstream",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
}
|
||||
|
||||
// IsPermanentUpstreamError checks if an upstream error indicates the key is permanently invalid.
|
||||
@@ -96,6 +118,11 @@ func IsClientNetworkError(err error) bool {
|
||||
return containsSubstring(err.Error(), clientNetworkErrorSubstrings)
|
||||
}
|
||||
|
||||
// IsRetryableNetworkError checks if an error is a temporary network/gateway issue.
|
||||
func IsRetryableNetworkError(msg string) bool {
|
||||
return containsSubstring(msg, retryableNetworkErrorSubstrings)
|
||||
}
|
||||
|
||||
// containsSubstring is a helper function to avoid code repetition.
|
||||
func containsSubstring(s string, substrings []string) bool {
|
||||
if s == "" {
|
||||
|
||||
@@ -2,7 +2,6 @@ package errors
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -10,67 +9,37 @@ const (
|
||||
maxErrorBodyLength = 2048
|
||||
)
|
||||
|
||||
// standardErrorResponse matches formats like: {"error": {"message": "..."}}
|
||||
type standardErrorResponse struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// vendorErrorResponse matches formats like: {"error_msg": "..."}
|
||||
type vendorErrorResponse struct {
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
// simpleErrorResponse matches formats like: {"error": "..."}
|
||||
type simpleErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// rootMessageErrorResponse matches formats like: {"message": "..."}
|
||||
type rootMessageErrorResponse struct {
|
||||
type upstreamErrorDetail struct {
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type upstreamErrorPayload struct {
|
||||
Error upstreamErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ParseUpstreamError attempts to parse a structured error message from an upstream response body
|
||||
func ParseUpstreamError(body []byte) string {
|
||||
// 1. Attempt to parse the standard OpenAI/Gemini format.
|
||||
var stdErr standardErrorResponse
|
||||
if err := json.Unmarshal(body, &stdErr); err == nil {
|
||||
if msg := strings.TrimSpace(stdErr.Error.Message); msg != "" {
|
||||
return truncateString(msg, maxErrorBodyLength)
|
||||
}
|
||||
func ParseUpstreamError(body []byte) (message string, status string) {
|
||||
if len(body) == 0 {
|
||||
return "Upstream returned an empty error body", ""
|
||||
}
|
||||
|
||||
// 2. Attempt to parse vendor-specific format (e.g., Baidu).
|
||||
var vendorErr vendorErrorResponse
|
||||
if err := json.Unmarshal(body, &vendorErr); err == nil {
|
||||
if msg := strings.TrimSpace(vendorErr.ErrorMsg); msg != "" {
|
||||
return truncateString(msg, maxErrorBodyLength)
|
||||
}
|
||||
// 优先级 1: 尝试解析 OpenAI 兼容接口返回的 `[{"error": {...}}]` 数组格式
|
||||
var arrayPayload []upstreamErrorPayload
|
||||
if err := json.Unmarshal(body, &arrayPayload); err == nil && len(arrayPayload) > 0 {
|
||||
detail := arrayPayload[0].Error
|
||||
return truncateString(detail.Message, maxErrorBodyLength), detail.Status
|
||||
}
|
||||
|
||||
// 3. Attempt to parse simple error format.
|
||||
var simpleErr simpleErrorResponse
|
||||
if err := json.Unmarshal(body, &simpleErr); err == nil {
|
||||
if msg := strings.TrimSpace(simpleErr.Error); msg != "" {
|
||||
return truncateString(msg, maxErrorBodyLength)
|
||||
}
|
||||
// 优先级 2: 尝试解析 Gemini 原生接口可能返回的 `{"error": {...}}` 单个对象格式
|
||||
var singlePayload upstreamErrorPayload
|
||||
if err := json.Unmarshal(body, &singlePayload); err == nil && singlePayload.Error.Message != "" {
|
||||
detail := singlePayload.Error
|
||||
return truncateString(detail.Message, maxErrorBodyLength), detail.Status
|
||||
}
|
||||
|
||||
// 4. Attempt to parse root-level message format.
|
||||
var rootMsgErr rootMessageErrorResponse
|
||||
if err := json.Unmarshal(body, &rootMsgErr); err == nil {
|
||||
if msg := strings.TrimSpace(rootMsgErr.Message); msg != "" {
|
||||
return truncateString(msg, maxErrorBodyLength)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Graceful Degradation: If all parsing fails, return the raw (but safe) body.
|
||||
return truncateString(string(body), maxErrorBodyLength)
|
||||
// 最终回退: 对于无法识别的 JSON 或纯文本错误
|
||||
return truncateString(string(body), maxErrorBodyLength), ""
|
||||
}
|
||||
|
||||
// truncateString ensures a string does not exceed a maximum length.
|
||||
// truncateString remains unchanged.
|
||||
func truncateString(s string, maxLength int) string {
|
||||
if len(s) > maxLength {
|
||||
return s[:maxLength]
|
||||
|
||||
@@ -31,11 +31,10 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService, db *gorm.DB, keyImpo
|
||||
}
|
||||
}
|
||||
|
||||
// DTOs for API requests
|
||||
type BulkAddKeysToGroupRequest struct {
|
||||
KeyGroupID uint `json:"key_group_id" binding:"required"`
|
||||
Keys string `json:"keys" binding:"required"`
|
||||
ValidateOnImport bool `json:"validate_on_import"` // OmitEmpty/default is false
|
||||
ValidateOnImport bool `json:"validate_on_import"`
|
||||
}
|
||||
|
||||
type BulkUnlinkKeysFromGroupRequest struct {
|
||||
@@ -72,11 +71,11 @@ type BulkTestKeysForGroupRequest struct {
|
||||
}
|
||||
|
||||
type BulkActionFilter struct {
|
||||
Status []string `json:"status"` // Changed to slice to accept multiple statuses
|
||||
Status []string `json:"status"`
|
||||
}
|
||||
type BulkActionRequest struct {
|
||||
Action string `json:"action" binding:"required,oneof=revalidate set_status delete"`
|
||||
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"` // For 'set_status' action
|
||||
NewStatus string `json:"new_status" binding:"omitempty,oneof=active disabled cooldown banned"`
|
||||
Filter BulkActionFilter `json:"filter" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -89,7 +88,7 @@ func (h *APIKeyHandler) AddMultipleKeysToGroup(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyImportService.StartAddKeysTask(req.KeyGroupID, req.Keys, req.ValidateOnImport)
|
||||
taskStatus, err := h.keyImportService.StartAddKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys, req.ValidateOnImport)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -104,7 +103,7 @@ func (h *APIKeyHandler) UnlinkMultipleKeysFromGroup(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(req.KeyGroupID, req.Keys)
|
||||
taskStatus, err := h.keyImportService.StartUnlinkKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -119,7 +118,7 @@ func (h *APIKeyHandler) HardDeleteMultipleKeys(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(req.Keys)
|
||||
taskStatus, err := h.keyImportService.StartHardDeleteKeysTask(c.Request.Context(), req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -134,7 +133,7 @@ func (h *APIKeyHandler) RestoreMultipleKeys(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyImportService.StartRestoreKeysTask(req.Keys)
|
||||
taskStatus, err := h.keyImportService.StartRestoreKeysTask(c.Request.Context(), req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -148,7 +147,7 @@ func (h *APIKeyHandler) TestMultipleKeys(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.keyValidationService.StartTestKeysTask(req.KeyGroupID, req.Keys)
|
||||
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), req.KeyGroupID, req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -172,7 +171,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if len(ids) > 0 {
|
||||
keys, err := h.apiKeyService.GetKeysByIds(ids)
|
||||
keys, err := h.apiKeyService.GetKeysByIds(c.Request.Context(), ids)
|
||||
if err != nil {
|
||||
response.Error(c, &errors.APIError{
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
@@ -191,7 +190,7 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
|
||||
if params.PageSize <= 0 {
|
||||
params.PageSize = 20
|
||||
}
|
||||
result, err := h.apiKeyService.ListAPIKeys(¶ms)
|
||||
result, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), ¶ms)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
@@ -201,19 +200,16 @@ func (h *APIKeyHandler) ListAPIKeys(c *gin.Context) {
|
||||
|
||||
// ListKeysForGroup handles the GET /keygroups/:id/keys request.
|
||||
func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
|
||||
// 1. Manually handle the path parameter.
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
|
||||
return
|
||||
}
|
||||
// 2. Bind query parameters using the correctly tagged struct.
|
||||
var params models.APIKeyQueryParams
|
||||
if err := c.ShouldBindQuery(¶ms); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
// 3. Set server-side defaults and the path parameter.
|
||||
if params.Page <= 0 {
|
||||
params.Page = 1
|
||||
}
|
||||
@@ -221,15 +217,11 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
|
||||
params.PageSize = 20
|
||||
}
|
||||
params.KeyGroupID = uint(groupID)
|
||||
// 4. Call the service layer.
|
||||
paginatedResult, err := h.apiKeyService.ListAPIKeys(¶ms)
|
||||
paginatedResult, err := h.apiKeyService.ListAPIKeys(c.Request.Context(), ¶ms)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 5. [THE FIX] Return a successful response using the standard `response.Success`
|
||||
// and a gin.H map, as confirmed to exist in your project.
|
||||
response.Success(c, gin.H{
|
||||
"items": paginatedResult.Items,
|
||||
"total": paginatedResult.Total,
|
||||
@@ -239,20 +231,17 @@ func (h *APIKeyHandler) ListKeysForGroup(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *APIKeyHandler) TestKeysForGroup(c *gin.Context) {
|
||||
// Group ID is now correctly sourced from the URL path.
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid group ID format"))
|
||||
return
|
||||
}
|
||||
// The request body is now simpler, only needing the keys.
|
||||
var req BulkTestKeysForGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
// Call the same underlying service, but with unambiguous context.
|
||||
taskStatus, err := h.keyValidationService.StartTestKeysTask(uint(groupID), req.Keys)
|
||||
taskStatus, err := h.keyValidationService.StartTestKeysTask(c.Request.Context(), uint(groupID), req.Keys)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrTaskInProgress, err.Error()))
|
||||
return
|
||||
@@ -267,7 +256,6 @@ func (h *APIKeyHandler) UpdateAPIKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
// UpdateGroupAPIKeyMapping handles updating a key's status within a specific group.
|
||||
// Route: PUT /keygroups/:id/apikeys/:keyId
|
||||
func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -284,8 +272,7 @@ func (h *APIKeyHandler) UpdateGroupAPIKeyMapping(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
// Directly use the service to handle the logic
|
||||
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(uint(groupID), uint(keyID), req.Status)
|
||||
updatedMapping, err := h.apiKeyService.UpdateMappingStatus(c.Request.Context(), uint(groupID), uint(keyID), req.Status)
|
||||
if err != nil {
|
||||
var apiErr *errors.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
@@ -305,7 +292,7 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid ID format"))
|
||||
return
|
||||
}
|
||||
if err := h.apiKeyService.HardDeleteAPIKeyByID(uint(id)); err != nil {
|
||||
if err := h.apiKeyService.HardDeleteAPIKeyByID(c.Request.Context(), uint(id)); err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
}
|
||||
@@ -313,7 +300,6 @@ func (h *APIKeyHandler) HardDeleteAPIKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
// RestoreKeysInGroup 恢复指定Key的接口
|
||||
// POST /keygroups/:id/apikeys/restore
|
||||
func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -325,7 +311,7 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(uint(groupID), req.KeyIDs)
|
||||
taskStatus, err := h.apiKeyService.StartRestoreKeysTask(c.Request.Context(), uint(groupID), req.KeyIDs)
|
||||
if err != nil {
|
||||
var apiErr *errors.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
@@ -339,14 +325,13 @@ func (h *APIKeyHandler) RestoreKeysInGroup(c *gin.Context) {
|
||||
}
|
||||
|
||||
// RestoreAllBannedInGroup 一键恢复所有Banned Key的接口
|
||||
// POST /keygroups/:id/apikeys/restore-all-banned
|
||||
func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
|
||||
groupID, err := strconv.ParseUint(c.Param("groupId"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
|
||||
return
|
||||
}
|
||||
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(uint(groupID))
|
||||
taskStatus, err := h.apiKeyService.StartRestoreAllBannedTask(c.Request.Context(), uint(groupID))
|
||||
if err != nil {
|
||||
var apiErr *errors.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
@@ -360,48 +345,39 @@ func (h *APIKeyHandler) RestoreAllBannedInGroup(c *gin.Context) {
|
||||
}
|
||||
|
||||
// HandleBulkAction handles generic bulk actions on a key group based on server-side filters.
|
||||
// Route: POST /keygroups/:id/bulk-actions
|
||||
func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
|
||||
// 1. Parse GroupID from URL
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
|
||||
return
|
||||
}
|
||||
// 2. Bind the JSON payload to our new DTO
|
||||
var req BulkActionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
return
|
||||
}
|
||||
// 3. Central logic: based on the action, call the appropriate service method.
|
||||
var task *task.Status
|
||||
var apiErr *errors.APIError
|
||||
switch req.Action {
|
||||
case "revalidate":
|
||||
// Assume keyValidationService has a method that accepts a filter
|
||||
task, err = h.keyValidationService.StartTestKeysByFilterTask(uint(groupID), req.Filter.Status)
|
||||
task, err = h.keyValidationService.StartTestKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
|
||||
case "set_status":
|
||||
if req.NewStatus == "" {
|
||||
apiErr = errors.NewAPIError(errors.ErrBadRequest, "new_status is required for set_status action")
|
||||
break
|
||||
}
|
||||
// Assume apiKeyService has a method to update status by filter
|
||||
targetStatus := models.APIKeyStatus(req.NewStatus) // Convert string to your model's type
|
||||
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(uint(groupID), req.Filter.Status, targetStatus)
|
||||
targetStatus := models.APIKeyStatus(req.NewStatus)
|
||||
task, err = h.apiKeyService.StartUpdateStatusByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status, targetStatus)
|
||||
case "delete":
|
||||
// Assume keyImportService has a method to unlink by filter
|
||||
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(uint(groupID), req.Filter.Status)
|
||||
task, err = h.keyImportService.StartUnlinkKeysByFilterTask(c.Request.Context(), uint(groupID), req.Filter.Status)
|
||||
default:
|
||||
apiErr = errors.NewAPIError(errors.ErrBadRequest, "Unsupported action: "+req.Action)
|
||||
}
|
||||
// 4. Handle errors from the switch block
|
||||
if apiErr != nil {
|
||||
response.Error(c, apiErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
// Attempt to parse it as a known APIError, otherwise, wrap it.
|
||||
var parsedErr *errors.APIError
|
||||
if errors.As(err, &parsedErr) {
|
||||
response.Error(c, parsedErr)
|
||||
@@ -410,21 +386,18 @@ func (h *APIKeyHandler) HandleBulkAction(c *gin.Context) {
|
||||
}
|
||||
return
|
||||
}
|
||||
// 5. Return the task status on success
|
||||
response.Success(c, task)
|
||||
}
|
||||
|
||||
// ExportKeysForGroup handles requests to export all keys for a group based on status filters.
|
||||
// Route: GET /keygroups/:id/apikeys/export
|
||||
func (h *APIKeyHandler) ExportKeysForGroup(c *gin.Context) {
|
||||
groupID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrBadRequest, "Invalid Group ID format"))
|
||||
return
|
||||
}
|
||||
// Use QueryArray to correctly parse `status[]=active&status[]=cooldown`
|
||||
statuses := c.QueryArray("status")
|
||||
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(uint(groupID), statuses)
|
||||
keyStrings, err := h.apiKeyService.GetAPIKeyStringsForExport(c.Request.Context(), uint(groupID), statuses)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ParseDBError(err))
|
||||
return
|
||||
|
||||
@@ -30,7 +30,7 @@ func (h *DashboardHandler) GetOverview(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// GetChart 获取仪表盘的图表数据
|
||||
// GetChart
|
||||
func (h *DashboardHandler) GetChart(c *gin.Context) {
|
||||
var groupID *uint
|
||||
if groupIDStr := c.Query("groupId"); groupIDStr != "" {
|
||||
@@ -40,7 +40,7 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
chartData, err := h.queryService.QueryHistoricalChart(groupID)
|
||||
chartData, err := h.queryService.QueryHistoricalChart(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
apiErr := errors.NewAPIError(errors.ErrDatabase, err.Error())
|
||||
c.JSON(apiErr.HTTPStatus, apiErr)
|
||||
@@ -49,10 +49,10 @@ func (h *DashboardHandler) GetChart(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, chartData)
|
||||
}
|
||||
|
||||
// GetRequestStats 处理对“期间调用概览”的请求
|
||||
// GetRequestStats
|
||||
func (h *DashboardHandler) GetRequestStats(c *gin.Context) {
|
||||
period := c.Param("period") // 从 URL 路径中获取 period
|
||||
stats, err := h.queryService.GetRequestStatsForPeriod(period)
|
||||
period := c.Param("period")
|
||||
stats, err := h.queryService.GetRequestStatsForPeriod(c.Request.Context(), period)
|
||||
if err != nil {
|
||||
apiErr := errors.NewAPIError(errors.ErrBadRequest, err.Error())
|
||||
c.JSON(apiErr.HTTPStatus, apiErr)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/errors"
|
||||
@@ -31,7 +32,6 @@ func NewKeyGroupHandler(gm *service.GroupManager, s store.Store, qs *service.Das
|
||||
}
|
||||
}
|
||||
|
||||
// DTOs & 辅助函数
|
||||
func isValidGroupName(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
@@ -40,7 +40,6 @@ func isValidGroupName(name string) bool {
|
||||
return match
|
||||
}
|
||||
|
||||
// KeyGroupOperationalSettings defines the shared operational settings for a key group.
|
||||
type KeyGroupOperationalSettings struct {
|
||||
EnableKeyCheck *bool `json:"enable_key_check"`
|
||||
KeyCheckIntervalMinutes *int `json:"key_check_interval_minutes"`
|
||||
@@ -52,7 +51,6 @@ type KeyGroupOperationalSettings struct {
|
||||
MaxRetries *int `json:"max_retries"`
|
||||
EnableSmartGateway *bool `json:"enable_smart_gateway"`
|
||||
}
|
||||
|
||||
type CreateKeyGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
DisplayName string `json:"display_name"`
|
||||
@@ -60,11 +58,8 @@ type CreateKeyGroupRequest struct {
|
||||
PollingStrategy string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
|
||||
EnableProxy bool `json:"enable_proxy"`
|
||||
ChannelType string `json:"channel_type"`
|
||||
|
||||
// Embed shared operational settings
|
||||
KeyGroupOperationalSettings
|
||||
}
|
||||
|
||||
type UpdateKeyGroupRequest struct {
|
||||
Name *string `json:"name"`
|
||||
DisplayName *string `json:"display_name"`
|
||||
@@ -72,15 +67,10 @@ type UpdateKeyGroupRequest struct {
|
||||
PollingStrategy *string `json:"polling_strategy" binding:"omitempty,oneof=sequential random weighted"`
|
||||
EnableProxy *bool `json:"enable_proxy"`
|
||||
ChannelType *string `json:"channel_type"`
|
||||
|
||||
// Embed shared operational settings
|
||||
KeyGroupOperationalSettings
|
||||
|
||||
// M:N associations
|
||||
AllowedUpstreams []string `json:"allowed_upstreams"`
|
||||
AllowedModels []string `json:"allowed_models"`
|
||||
}
|
||||
|
||||
type KeyGroupResponse struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -96,36 +86,30 @@ type KeyGroupResponse struct {
|
||||
AllowedModels []string `json:"allowed_models"`
|
||||
AllowedUpstreams []string `json:"allowed_upstreams"`
|
||||
}
|
||||
|
||||
// [NEW] Define the detailed response structure for a single group.
|
||||
type KeyGroupDetailsResponse struct {
|
||||
KeyGroupResponse
|
||||
Settings *models.GroupSettings `json:"settings,omitempty"`
|
||||
RequestConfig *models.RequestConfig `json:"request_config,omitempty"`
|
||||
}
|
||||
|
||||
// transformModelsToStrings converts a slice of GroupModelMapping pointers to a slice of model names.
|
||||
func transformModelsToStrings(mappings []*models.GroupModelMapping) []string {
|
||||
modelNames := make([]string, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
if mapping != nil { // Safety check
|
||||
if mapping != nil {
|
||||
modelNames = append(modelNames, mapping.ModelName)
|
||||
}
|
||||
}
|
||||
return modelNames
|
||||
}
|
||||
|
||||
// transformUpstreamsToStrings converts a slice of UpstreamEndpoint pointers to a slice of URLs.
|
||||
func transformUpstreamsToStrings(upstreams []*models.UpstreamEndpoint) []string {
|
||||
urls := make([]string, 0, len(upstreams))
|
||||
for _, upstream := range upstreams {
|
||||
if upstream != nil { // Safety check
|
||||
if upstream != nil {
|
||||
urls = append(urls, upstream.URL)
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount int64) KeyGroupResponse {
|
||||
return KeyGroupResponse{
|
||||
ID: group.ID,
|
||||
@@ -139,13 +123,10 @@ func (h *KeyGroupHandler) newKeyGroupResponse(group *models.KeyGroup, keyCount i
|
||||
CreatedAt: group.CreatedAt,
|
||||
UpdatedAt: group.UpdatedAt,
|
||||
Order: group.Order,
|
||||
AllowedModels: transformModelsToStrings(group.AllowedModels), // Call the new helper
|
||||
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams), // Call the new helper
|
||||
AllowedModels: transformModelsToStrings(group.AllowedModels),
|
||||
AllowedUpstreams: transformUpstreamsToStrings(group.AllowedUpstreams),
|
||||
}
|
||||
}
|
||||
|
||||
// packGroupSettings is a helper to convert request-level operational settings
|
||||
// into the model-level settings struct.
|
||||
func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSettings {
|
||||
return &models.KeyGroupSettings{
|
||||
EnableKeyCheck: settings.EnableKeyCheck,
|
||||
@@ -159,7 +140,6 @@ func packGroupSettings(settings KeyGroupOperationalSettings) *models.KeyGroupSet
|
||||
EnableSmartGateway: settings.EnableSmartGateway,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup, *errors.APIError) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
@@ -171,7 +151,6 @@ func (h *KeyGroupHandler) getGroupFromContext(c *gin.Context) (*models.KeyGroup,
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGroup) {
|
||||
if req.Name != nil {
|
||||
group.Name = *req.Name
|
||||
@@ -197,9 +176,10 @@ func applyUpdateRequestToGroup(req *UpdateKeyGroupRequest, group *models.KeyGrou
|
||||
// publishGroupChangeEvent encapsulates the logic for marshaling and publishing a group change event.
|
||||
func (h *KeyGroupHandler) publishGroupChangeEvent(groupID uint, reason string) {
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
event := models.KeyStatusChangedEvent{GroupID: groupID, ChangeReason: reason}
|
||||
eventData, _ := json.Marshal(event)
|
||||
h.store.Publish(models.TopicKeyStatusChanged, eventData)
|
||||
_ = h.store.Publish(ctx, models.TopicKeyStatusChanged, eventData)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -216,7 +196,6 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// The core logic remains, as it's specific to creation.
|
||||
p := bluemonday.StripTagsPolicy()
|
||||
sanitizedDisplayName := p.Sanitize(req.DisplayName)
|
||||
sanitizedDescription := p.Sanitize(req.Description)
|
||||
@@ -244,11 +223,9 @@ func (h *KeyGroupHandler) CreateKeyGroup(c *gin.Context) {
|
||||
response.Created(c, h.newKeyGroupResponse(keyGroup, 0))
|
||||
}
|
||||
|
||||
// 统一的处理器可以处理两种情况:
|
||||
// 1. GET /keygroups - 返回所有组的列表
|
||||
// 2. GET /keygroups/:id - 返回指定ID的单个组
|
||||
func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
|
||||
// Case 1: Get a single group
|
||||
if idStr := c.Param("id"); idStr != "" {
|
||||
group, apiErr := h.getGroupFromContext(c)
|
||||
if apiErr != nil {
|
||||
@@ -265,7 +242,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
|
||||
response.Success(c, detailedResponse)
|
||||
return
|
||||
}
|
||||
// Case 2: Get all groups
|
||||
allGroups := h.groupManager.GetAllGroups()
|
||||
responses := make([]KeyGroupResponse, 0, len(allGroups))
|
||||
for _, group := range allGroups {
|
||||
@@ -275,7 +251,6 @@ func (h *KeyGroupHandler) GetKeyGroups(c *gin.Context) {
|
||||
response.Success(c, responses)
|
||||
}
|
||||
|
||||
// UpdateKeyGroup
|
||||
func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
|
||||
group, apiErr := h.getGroupFromContext(c)
|
||||
if apiErr != nil {
|
||||
@@ -304,7 +279,6 @@ func (h *KeyGroupHandler) UpdateKeyGroup(c *gin.Context) {
|
||||
response.Success(c, h.newKeyGroupResponse(freshGroup, keyCount))
|
||||
}
|
||||
|
||||
// DeleteKeyGroup
|
||||
func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
|
||||
group, apiErr := h.getGroupFromContext(c)
|
||||
if apiErr != nil {
|
||||
@@ -320,14 +294,14 @@ func (h *KeyGroupHandler) DeleteKeyGroup(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": fmt.Sprintf("Group '%s' and its associated keys deleted successfully", groupName)})
|
||||
}
|
||||
|
||||
// GetKeyGroupStats
|
||||
func (h *KeyGroupHandler) GetKeyGroupStats(c *gin.Context) {
|
||||
group, apiErr := h.getGroupFromContext(c)
|
||||
if apiErr != nil {
|
||||
response.Error(c, apiErr)
|
||||
return
|
||||
}
|
||||
stats, err := h.queryService.GetGroupStats(group.ID)
|
||||
|
||||
stats, err := h.queryService.GetGroupStats(c.Request.Context(), group.ID)
|
||||
if err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrDatabase, err.Error()))
|
||||
return
|
||||
@@ -350,7 +324,6 @@ func (h *KeyGroupHandler) CloneKeyGroup(c *gin.Context) {
|
||||
response.Created(c, h.newKeyGroupResponse(clonedGroup, keyCount))
|
||||
}
|
||||
|
||||
// 更新分组排序
|
||||
func (h *KeyGroupHandler) UpdateKeyGroupOrder(c *gin.Context) {
|
||||
var payload []service.UpdateOrderPayload
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/response"
|
||||
"gemini-balancer/internal/service"
|
||||
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -19,22 +21,81 @@ func NewLogHandler(logService *service.LogService) *LogHandler {
|
||||
}
|
||||
|
||||
func (h *LogHandler) GetLogs(c *gin.Context) {
|
||||
// 调用新的服务函数,接收日志列表和总数
|
||||
logs, total, err := h.logService.GetLogs(c)
|
||||
queryParams := make(map[string]string)
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
queryParams[key] = values[0]
|
||||
}
|
||||
}
|
||||
params, err := service.ParseLogQueryParams(queryParams)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ErrBadRequest)
|
||||
return
|
||||
}
|
||||
logs, total, err := h.logService.GetLogs(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ErrDatabase)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"items": logs,
|
||||
"total": total,
|
||||
"page": params.Page,
|
||||
"page_size": params.PageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteLogs 删除选定日志 DELETE /admin/logs?ids=1,2,3
|
||||
func (h *LogHandler) DeleteLogs(c *gin.Context) {
|
||||
idsStr := c.Query("ids")
|
||||
if idsStr == "" {
|
||||
response.Error(c, errors.ErrBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var ids []uint
|
||||
for _, idStr := range strings.Split(idsStr, ",") {
|
||||
if id, err := strconv.ParseUint(strings.TrimSpace(idStr), 10, 32); err == nil {
|
||||
ids = append(ids, uint(id))
|
||||
}
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
response.Error(c, errors.ErrBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.logService.DeleteLogs(c.Request.Context(), ids); err != nil {
|
||||
response.Error(c, errors.ErrDatabase)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"deleted": len(ids)})
|
||||
}
|
||||
|
||||
// DeleteAllLogs 删除全部日志 DELETE /admin/logs/all
|
||||
func (h *LogHandler) DeleteAllLogs(c *gin.Context) {
|
||||
if err := h.logService.DeleteAllLogs(c.Request.Context()); err != nil {
|
||||
response.Error(c, errors.ErrDatabase)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "all logs deleted"})
|
||||
}
|
||||
|
||||
// DeleteOldLogs 删除旧日志 DELETE /admin/logs/old?days=30
|
||||
func (h *LogHandler) DeleteOldLogs(c *gin.Context) {
|
||||
daysStr := c.Query("days")
|
||||
days, err := strconv.Atoi(daysStr)
|
||||
if err != nil || days <= 0 {
|
||||
response.Error(c, errors.ErrBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
deleted, err := h.logService.DeleteOldLogs(c.Request.Context(), days)
|
||||
if err != nil {
|
||||
response.Error(c, errors.ErrDatabase)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析分页参数用于响应体
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
// 使用标准的分页响应结构
|
||||
response.Success(c, gin.H{
|
||||
"items": logs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
response.Success(c, gin.H{"deleted": deleted, "days": days})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,7 @@ import (
|
||||
"gemini-balancer/internal/settings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type SettingHandler struct {
|
||||
@@ -23,16 +24,35 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
var newSettingsMap map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&newSettingsMap); err != nil {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, err.Error()))
|
||||
logrus.WithError(err).Error("Failed to bind JSON in UpdateSettings")
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, "Invalid JSON: "+err.Error()))
|
||||
return
|
||||
}
|
||||
if err := h.settingsManager.UpdateSettings(newSettingsMap); err != nil {
|
||||
// TODO 可以根据错误类型返回更具体的错误
|
||||
|
||||
logrus.Debugf("Received settings update: %+v", newSettingsMap)
|
||||
|
||||
validKeys := make(map[string]interface{})
|
||||
for key, value := range newSettingsMap {
|
||||
if _, exists := h.settingsManager.IsValidKey(key); exists {
|
||||
validKeys[key] = value
|
||||
} else {
|
||||
logrus.Warnf("Invalid key received: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validKeys) == 0 {
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInvalidJSON, "No valid settings keys provided"))
|
||||
return
|
||||
}
|
||||
|
||||
logrus.Debugf("Valid keys to update: %+v", validKeys)
|
||||
|
||||
if err := h.settingsManager.UpdateSettings(validKeys); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update settings")
|
||||
response.Error(c, errors.NewAPIError(errors.ErrInternalServer, err.Error()))
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "Settings update request processed successfully."})
|
||||
|
||||
response.Success(c, gin.H{"message": "Settings updated successfully"})
|
||||
}
|
||||
|
||||
// ResetSettingsToDefaults resets all settings to their default values
|
||||
|
||||
@@ -33,7 +33,7 @@ func (h *TaskHandler) GetTaskStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
taskStatus, err := h.taskService.GetStatus(taskID)
|
||||
taskStatus, err := h.taskService.GetStatus(c.Request.Context(), taskID)
|
||||
if err != nil {
|
||||
// TODO 可以根据 service 层返回的具体错误类型进行更精细的处理
|
||||
response.Error(c, errors.NewAPIError(errors.ErrResourceNotFound, err.Error()))
|
||||
|
||||
68
internal/handlers/websocket_handler.go
Normal file
68
internal/handlers/websocket_handler.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// Filename: internal/handlers/websocket_handler.go
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type connWrapper struct {
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type WebSocketHandler struct {
|
||||
logger *logrus.Logger
|
||||
clients sync.Map
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
func NewWebSocketHandler(logger *logrus.Logger) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
logger: logger,
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WebSocketHandler) HandleSystemLogs(c *gin.Context) {
|
||||
conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("Failed to upgrade websocket")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
clientID := time.Now().UnixNano()
|
||||
h.clients.Store(clientID, &connWrapper{conn: conn})
|
||||
defer h.clients.Delete(clientID)
|
||||
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WebSocketHandler) BroadcastLog(entry *logrus.Entry) {
|
||||
msg := map[string]interface{}{
|
||||
"timestamp": entry.Time.Format(time.RFC3339),
|
||||
"level": entry.Level.String(),
|
||||
"message": entry.Message,
|
||||
"fields": entry.Data,
|
||||
}
|
||||
|
||||
h.clients.Range(func(key, value interface{}) bool {
|
||||
wrapper := value.(*connWrapper)
|
||||
wrapper.mu.Lock()
|
||||
wrapper.conn.WriteJSON(msg)
|
||||
wrapper.mu.Unlock()
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -9,20 +9,25 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
// 包级变量,用于存储日志轮转器
|
||||
var logRotator *lumberjack.Logger
|
||||
|
||||
// NewLogger 返回标准的 *logrus.Logger(兼容 Fx 依赖注入)
|
||||
func NewLogger(cfg *config.Config) *logrus.Logger {
|
||||
logger := logrus.New()
|
||||
|
||||
// 1. 设置日志级别
|
||||
// 设置日志级别
|
||||
level, err := logrus.ParseLevel(cfg.Log.Level)
|
||||
if err != nil {
|
||||
logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level specified, defaulting to 'info'.")
|
||||
logger.WithField("configured_level", cfg.Log.Level).Warn("Invalid log level, defaulting to 'info'")
|
||||
level = logrus.InfoLevel
|
||||
}
|
||||
logger.SetLevel(level)
|
||||
|
||||
// 2. 设置日志格式
|
||||
// 设置日志格式
|
||||
if cfg.Log.Format == "json" {
|
||||
logger.SetFormatter(&logrus.JSONFormatter{
|
||||
TimestampFormat: "2006-01-02T15:04:05.000Z07:00",
|
||||
@@ -39,36 +44,57 @@ func NewLogger(cfg *config.Config) *logrus.Logger {
|
||||
})
|
||||
}
|
||||
|
||||
// 3. 设置日志输出
|
||||
// 添加全局字段
|
||||
hostname, _ := os.Hostname()
|
||||
logger = logger.WithFields(logrus.Fields{
|
||||
"service": "gemini-balancer",
|
||||
"hostname": hostname,
|
||||
}).Logger
|
||||
|
||||
// 设置日志输出
|
||||
if cfg.Log.EnableFile {
|
||||
if cfg.Log.FilePath == "" {
|
||||
logger.Warn("Log file is enabled but no file path is specified. Logging to console only.")
|
||||
logger.Warn("Log file enabled but no path specified. Logging to console only")
|
||||
logger.SetOutput(os.Stdout)
|
||||
return logger
|
||||
}
|
||||
|
||||
logDir := filepath.Dir(cfg.Log.FilePath)
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
logger.WithError(err).Warn("Failed to create log directory. Logging to console only.")
|
||||
if err := os.MkdirAll(logDir, 0750); err != nil {
|
||||
logger.WithError(err).Warn("Failed to create log directory. Logging to console only")
|
||||
logger.SetOutput(os.Stdout)
|
||||
return logger
|
||||
}
|
||||
|
||||
logFile, err := os.OpenFile(cfg.Log.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Failed to open log file. Logging to console only.")
|
||||
logger.SetOutput(os.Stdout)
|
||||
return logger
|
||||
// 配置日志轮转(保存到包级变量)
|
||||
logRotator = &lumberjack.Logger{
|
||||
Filename: cfg.Log.FilePath,
|
||||
MaxSize: getOrDefault(cfg.Log.MaxSize, 100),
|
||||
MaxBackups: getOrDefault(cfg.Log.MaxBackups, 7),
|
||||
MaxAge: getOrDefault(cfg.Log.MaxAge, 30),
|
||||
Compress: cfg.Log.Compress,
|
||||
}
|
||||
|
||||
// 同时输出到控制台和文件
|
||||
logger.SetOutput(io.MultiWriter(os.Stdout, logFile))
|
||||
logger.WithField("log_file_path", cfg.Log.FilePath).Info("Logging is now configured to output to both console and file.")
|
||||
logger.SetOutput(io.MultiWriter(os.Stdout, logRotator))
|
||||
logger.WithField("log_file", cfg.Log.FilePath).Info("Logging to both console and file")
|
||||
} else {
|
||||
// 仅输出到控制台
|
||||
logger.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
logger.Info("Root logger initialized.")
|
||||
logger.Info("Logger initialized successfully")
|
||||
return logger
|
||||
}
|
||||
|
||||
// Close 关闭日志轮转器(在 main.go 中调用)
|
||||
func Close() {
|
||||
if logRotator != nil {
|
||||
logRotator.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func getOrDefault(value, defaultValue int) int {
|
||||
if value <= 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
36
internal/logging/websocket_hook.go
Normal file
36
internal/logging/websocket_hook.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Filename: internal/logging/websocket_hook.go
|
||||
package logging
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/config"
|
||||
"gemini-balancer/internal/handlers"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type WebSocketHook struct {
|
||||
broadcaster func(*logrus.Entry)
|
||||
}
|
||||
|
||||
func NewWebSocketHook(broadcaster func(*logrus.Entry)) *WebSocketHook {
|
||||
return &WebSocketHook{broadcaster: broadcaster}
|
||||
}
|
||||
|
||||
func (h *WebSocketHook) Levels() []logrus.Level {
|
||||
return logrus.AllLevels
|
||||
}
|
||||
|
||||
func (h *WebSocketHook) Fire(entry *logrus.Entry) error {
|
||||
if h.broadcaster != nil {
|
||||
go h.broadcaster(entry)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewLoggerWithWebSocket(cfg *config.Config) (*logrus.Logger, *handlers.WebSocketHandler) {
|
||||
logger := NewLogger(cfg)
|
||||
wsHandler := handlers.NewWebSocketHandler(logger)
|
||||
hook := NewWebSocketHook(wsHandler.BroadcastLog)
|
||||
logger.AddHook(hook)
|
||||
return logger, wsHandler
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
// Filename: internal/middleware/auth.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
@@ -7,76 +8,115 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// === API Admin 认证管道 (/admin/* API路由) ===
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
func APIAdminAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
|
||||
// APIAdminAuthMiddleware 管理后台 API 认证
|
||||
func APIAdminAuthMiddleware(
|
||||
securityService *service.SecurityService,
|
||||
logger *logrus.Logger,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenValue := extractBearerToken(c)
|
||||
if tokenValue == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization token is missing"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "Authentication required",
|
||||
Code: "AUTH_MISSING",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// ✅ 只传 token 参数(移除 context)
|
||||
authToken, err := securityService.AuthenticateToken(tokenValue)
|
||||
if err != nil || !authToken.IsAdmin {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or non-admin token"})
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Authentication failed")
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "Invalid authentication",
|
||||
Code: "AUTH_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !authToken.IsAdmin {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse{
|
||||
Error: "Admin access required",
|
||||
Code: "AUTH_FORBIDDEN",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("adminUser", authToken)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// === /v1 Proxy 认证 ===
|
||||
|
||||
func ProxyAuthMiddleware(securityService *service.SecurityService) gin.HandlerFunc {
|
||||
// ProxyAuthMiddleware 代理请求认证
|
||||
func ProxyAuthMiddleware(
|
||||
securityService *service.SecurityService,
|
||||
logger *logrus.Logger,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tokenValue := extractProxyToken(c)
|
||||
if tokenValue == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "API key is missing from request"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "API key required",
|
||||
Code: "KEY_MISSING",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// ✅ 只传 token 参数(移除 context)
|
||||
authToken, err := securityService.AuthenticateToken(tokenValue)
|
||||
if err != nil {
|
||||
// 通用信息,避免泄露过多信息
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid or inactive token provided"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Error: "Invalid API key",
|
||||
Code: "KEY_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("authToken", authToken)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// extractProxyToken 按优先级提取 token
|
||||
func extractProxyToken(c *gin.Context) string {
|
||||
if key := c.Query("key"); key != "" {
|
||||
return key
|
||||
}
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
// 优先级 1: Authorization Header
|
||||
if token := extractBearerToken(c); token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// 优先级 2: X-Api-Key
|
||||
if key := c.GetHeader("X-Api-Key"); key != "" {
|
||||
return key
|
||||
}
|
||||
|
||||
// 优先级 3: X-Goog-Api-Key
|
||||
if key := c.GetHeader("X-Goog-Api-Key"); key != "" {
|
||||
return key
|
||||
}
|
||||
return ""
|
||||
|
||||
// 优先级 4: Query 参数(不推荐)
|
||||
return c.Query("key")
|
||||
}
|
||||
|
||||
// === 辅助函数 ===
|
||||
|
||||
// extractBearerToken 提取 Bearer Token
|
||||
func extractBearerToken(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
return parts[1]
|
||||
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, prefix) {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(authHeader[len(prefix):])
|
||||
}
|
||||
90
internal/middleware/cors.go
Normal file
90
internal/middleware/cors.go
Normal file
@@ -0,0 +1,90 @@
|
||||
// Filename: internal/middleware/cors.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string
|
||||
AllowedMethods []string
|
||||
AllowedHeaders []string
|
||||
ExposedHeaders []string
|
||||
AllowCredentials bool
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
func CORSMiddleware(config CORSConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
|
||||
// 检查 origin 是否允许
|
||||
if origin != "" && isOriginAllowed(origin, config.AllowedOrigins) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if config.AllowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
if len(config.ExposedHeaders) > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Expose-Headers",
|
||||
strings.Join(config.ExposedHeaders, ", "))
|
||||
}
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
if len(config.AllowedMethods) > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods",
|
||||
strings.Join(config.AllowedMethods, ", "))
|
||||
}
|
||||
|
||||
if len(config.AllowedHeaders) > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers",
|
||||
strings.Join(config.AllowedHeaders, ", "))
|
||||
}
|
||||
|
||||
if config.MaxAge > 0 {
|
||||
c.Writer.Header().Set("Access-Control-Max-Age",
|
||||
string(rune(config.MaxAge)))
|
||||
}
|
||||
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if allowed == "*" || allowed == origin {
|
||||
return true
|
||||
}
|
||||
// 支持通配符子域名
|
||||
if strings.HasPrefix(allowed, "*.") {
|
||||
domain := strings.TrimPrefix(allowed, "*.")
|
||||
if strings.HasSuffix(origin, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 使用示例
|
||||
func SetupCORS(r *gin.Engine) {
|
||||
r.Use(CORSMiddleware(CORSConfig{
|
||||
AllowedOrigins: []string{"https://yourdomain.com", "*.yourdomain.com"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Authorization", "Content-Type", "X-Api-Key"},
|
||||
ExposedHeaders: []string{"X-Request-Id"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 3600,
|
||||
}))
|
||||
}
|
||||
@@ -1,84 +1,213 @@
|
||||
// Filename: internal/middleware/log_redaction.go
|
||||
// Filename: internal/middleware/logging.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const RedactedBodyKey = "redactedBody"
|
||||
const RedactedAuthHeaderKey = "redactedAuthHeader"
|
||||
const RedactedValue = `"[REDACTED]"`
|
||||
const (
|
||||
RedactedBodyKey = "redactedBody"
|
||||
RedactedAuthHeaderKey = "redactedAuthHeader"
|
||||
RedactedValue = `"[REDACTED]"`
|
||||
)
|
||||
|
||||
// 预编译正则表达式(全局变量,提升性能)
|
||||
var (
|
||||
// JSON 敏感字段脱敏
|
||||
jsonSensitiveKeys = regexp.MustCompile(`("(?i:api_key|apikey|token|password|secret|authorization|key|keys|auth)"\s*:\s*)"[^"]*"`)
|
||||
|
||||
// Bearer Token 脱敏
|
||||
bearerTokenPattern = regexp.MustCompile(`^(Bearer\s+)\S+$`)
|
||||
|
||||
// URL 中的 key 参数脱敏
|
||||
queryKeyPattern = regexp.MustCompile(`([?&](?i:key|token|apikey)=)[^&\s]+`)
|
||||
)
|
||||
|
||||
// RedactionMiddleware 请求数据脱敏中间件
|
||||
func RedactionMiddleware() gin.HandlerFunc {
|
||||
// Pre-compile regex for efficiency
|
||||
jsonKeyPattern := regexp.MustCompile(`("api_key"|"keys")\s*:\s*"[^"]*"`)
|
||||
bearerTokenPattern := regexp.MustCompile(`^(Bearer\s+)\S+$`)
|
||||
return func(c *gin.Context) {
|
||||
// --- 1. Redact Request Body ---
|
||||
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" {
|
||||
if bodyBytes, err := io.ReadAll(c.Request.Body); err == nil {
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
bodyString := string(bodyBytes)
|
||||
|
||||
redactedBody := jsonKeyPattern.ReplaceAllString(bodyString, `$1:`+RedactedValue)
|
||||
|
||||
c.Set(RedactedBodyKey, redactedBody)
|
||||
}
|
||||
}
|
||||
// --- 2. Redact Authorization Header ---
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if bearerTokenPattern.MatchString(authHeader) {
|
||||
redactedHeader := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
|
||||
c.Set(RedactedAuthHeaderKey, redactedHeader)
|
||||
}
|
||||
// 1. 脱敏请求体
|
||||
if shouldRedactBody(c) {
|
||||
redactRequestBody(c)
|
||||
}
|
||||
|
||||
// 2. 脱敏认证头
|
||||
redactAuthHeader(c)
|
||||
|
||||
// 3. 脱敏 URL 查询参数
|
||||
redactQueryParams(c)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// LogrusLogger is a Gin middleware that logs requests using a Logrus logger.
|
||||
// It consumes redacted data prepared by the RedactionMiddleware.
|
||||
// shouldRedactBody 判断是否需要脱敏请求体
|
||||
func shouldRedactBody(c *gin.Context) bool {
|
||||
method := c.Request.Method
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
// 只处理包含 JSON 的 POST/PUT/PATCH/DELETE 请求
|
||||
return (method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE") &&
|
||||
strings.Contains(contentType, "application/json")
|
||||
}
|
||||
|
||||
// redactRequestBody 脱敏请求体
|
||||
func redactRequestBody(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 恢复请求体供后续使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 脱敏敏感字段
|
||||
bodyString := string(bodyBytes)
|
||||
redactedBody := jsonSensitiveKeys.ReplaceAllString(bodyString, `$1`+RedactedValue)
|
||||
|
||||
c.Set(RedactedBodyKey, redactedBody)
|
||||
}
|
||||
|
||||
// redactAuthHeader 脱敏认证头
|
||||
func redactAuthHeader(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if bearerTokenPattern.MatchString(authHeader) {
|
||||
redacted := bearerTokenPattern.ReplaceAllString(authHeader, `${1}[REDACTED]`)
|
||||
c.Set(RedactedAuthHeaderKey, redacted)
|
||||
} else {
|
||||
// 对于非 Bearer 的 token,全部脱敏
|
||||
c.Set(RedactedAuthHeaderKey, "[REDACTED]")
|
||||
}
|
||||
|
||||
// 同时处理其他敏感 Header
|
||||
sensitiveHeaders := []string{"X-Api-Key", "X-Goog-Api-Key", "Api-Key"}
|
||||
for _, header := range sensitiveHeaders {
|
||||
if value := c.GetHeader(header); value != "" {
|
||||
c.Set("redacted_"+header, "[REDACTED]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// redactQueryParams 脱敏 URL 查询参数
|
||||
func redactQueryParams(c *gin.Context) {
|
||||
rawQuery := c.Request.URL.RawQuery
|
||||
if rawQuery == "" {
|
||||
return
|
||||
}
|
||||
|
||||
redacted := queryKeyPattern.ReplaceAllString(rawQuery, `${1}[REDACTED]`)
|
||||
if redacted != rawQuery {
|
||||
c.Set("redactedQuery", redacted)
|
||||
}
|
||||
}
|
||||
|
||||
// LogrusLogger Gin 请求日志中间件(使用 Logrus)
|
||||
func LogrusLogger(logger *logrus.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
method := c.Request.Method
|
||||
|
||||
// Process request
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// After request, gather data and log
|
||||
// 计算延迟
|
||||
latency := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
entry := logger.WithFields(logrus.Fields{
|
||||
"status_code": statusCode,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"client_ip": c.ClientIP(),
|
||||
"method": c.Request.Method,
|
||||
"path": path,
|
||||
})
|
||||
// 构建日志字段
|
||||
fields := logrus.Fields{
|
||||
"status": statusCode,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"ip": clientIP,
|
||||
"method": method,
|
||||
"path": path,
|
||||
}
|
||||
|
||||
// 添加请求 ID(如果存在)
|
||||
if requestID := getRequestID(c); requestID != "" {
|
||||
fields["request_id"] = requestID
|
||||
}
|
||||
|
||||
// 添加脱敏后的数据
|
||||
if redactedBody, exists := c.Get(RedactedBodyKey); exists {
|
||||
entry = entry.WithField("body", redactedBody)
|
||||
fields["body"] = redactedBody
|
||||
}
|
||||
|
||||
if redactedAuth, exists := c.Get(RedactedAuthHeaderKey); exists {
|
||||
entry = entry.WithField("authorization", redactedAuth)
|
||||
fields["authorization"] = redactedAuth
|
||||
}
|
||||
|
||||
if redactedQuery, exists := c.Get("redactedQuery"); exists {
|
||||
fields["query"] = redactedQuery
|
||||
}
|
||||
|
||||
// 添加用户信息(如果已认证)
|
||||
if user := getAuthenticatedUser(c); user != "" {
|
||||
fields["user"] = user
|
||||
}
|
||||
|
||||
// 根据状态码选择日志级别
|
||||
entry := logger.WithFields(fields)
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
entry.Error(c.Errors.String())
|
||||
fields["errors"] = c.Errors.String()
|
||||
entry.Error("Request failed")
|
||||
} else {
|
||||
entry.Info("request handled")
|
||||
switch {
|
||||
case statusCode >= 500:
|
||||
entry.Error("Server error")
|
||||
case statusCode >= 400:
|
||||
entry.Warn("Client error")
|
||||
case statusCode >= 300:
|
||||
entry.Info("Redirect")
|
||||
default:
|
||||
// 只在 Debug 模式记录成功请求
|
||||
if logger.Level >= logrus.DebugLevel {
|
||||
entry.Debug("Request completed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRequestID 获取请求 ID
|
||||
func getRequestID(c *gin.Context) string {
|
||||
if id, exists := c.Get("request_id"); exists {
|
||||
if requestID, ok := id.(string); ok {
|
||||
return requestID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getAuthenticatedUser 获取已认证用户标识
|
||||
func getAuthenticatedUser(c *gin.Context) string {
|
||||
// 尝试从不同来源获取用户信息
|
||||
if user, exists := c.Get("adminUser"); exists {
|
||||
if authToken, ok := user.(interface{ GetID() string }); ok {
|
||||
return authToken.GetID()
|
||||
}
|
||||
}
|
||||
|
||||
if user, exists := c.Get("authToken"); exists {
|
||||
if authToken, ok := user.(interface{ GetID() string }); ok {
|
||||
return authToken.GetID()
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
86
internal/middleware/rate_limit.go
Normal file
86
internal/middleware/rate_limit.go
Normal file
@@ -0,0 +1,86 @@
|
||||
// Filename: internal/middleware/rate_limit.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
limiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
r rate.Limit // 每秒请求数
|
||||
b int // 突发容量
|
||||
}
|
||||
|
||||
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
r: r,
|
||||
b: b,
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
limiter, exists := rl.limiters[key]
|
||||
if !exists {
|
||||
limiter = rate.NewLimiter(rl.r, rl.b)
|
||||
rl.limiters[key] = limiter
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// 定期清理不活跃的限制器
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
// 简单策略:定期清空(生产环境应该用 LRU)
|
||||
rl.limiters = make(map[string]*rate.Limiter)
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimitMiddleware(limiter *RateLimiter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 按 IP 限流
|
||||
key := c.ClientIP()
|
||||
|
||||
// 如果有认证 token,按 token 限流(更精确)
|
||||
if authToken, exists := c.Get("authToken"); exists {
|
||||
if token, ok := authToken.(interface{ GetID() string }); ok {
|
||||
key = "token:" + token.GetID()
|
||||
}
|
||||
}
|
||||
|
||||
l := limiter.getLimiter(key)
|
||||
if !l.Allow() {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "Rate limit exceeded",
|
||||
"code": "RATE_LIMIT",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 使用示例
|
||||
func SetupRateLimit(r *gin.Engine) {
|
||||
limiter := NewRateLimiter(10, 20) // 每秒 10 个请求,突发 20
|
||||
go limiter.cleanup()
|
||||
|
||||
r.Use(RateLimitMiddleware(limiter))
|
||||
}
|
||||
39
internal/middleware/request_id.go
Normal file
39
internal/middleware/request_id.go
Normal file
@@ -0,0 +1,39 @@
|
||||
// Filename: internal/middleware/request_id.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// RequestIDMiddleware 请求 ID 追踪中间件
|
||||
func RequestIDMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 1. 尝试从 Header 获取现有的 Request ID
|
||||
requestID := c.GetHeader("X-Request-Id")
|
||||
|
||||
// 2. 如果没有,生成新的
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 3. 设置到 Context
|
||||
c.Set("request_id", requestID)
|
||||
|
||||
// 4. 返回给客户端(用于追踪)
|
||||
c.Writer.Header().Set("X-Request-Id", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestID 获取当前请求的 Request ID
|
||||
func GetRequestID(c *gin.Context) string {
|
||||
if id, exists := c.Get("request_id"); exists {
|
||||
if requestID, ok := id.(string); ok {
|
||||
return requestID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Filename: internal/middleware/security.go
|
||||
// Filename: internal/middleware/security.go (简化版)
|
||||
|
||||
package middleware
|
||||
|
||||
@@ -6,26 +6,136 @@ import (
|
||||
"gemini-balancer/internal/service"
|
||||
"gemini-balancer/internal/settings"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func IPBanMiddleware(securityService *service.SecurityService, settingsManager *settings.SettingsManager) gin.HandlerFunc {
|
||||
// 简单的缓存项
|
||||
type cacheItem struct {
|
||||
value bool
|
||||
expiration int64
|
||||
}
|
||||
|
||||
// 简单的 TTL 缓存实现
|
||||
type IPBanCache struct {
|
||||
items map[string]*cacheItem
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewIPBanCache() *IPBanCache {
|
||||
cache := &IPBanCache{
|
||||
items: make(map[string]*cacheItem),
|
||||
ttl: 1 * time.Minute,
|
||||
}
|
||||
|
||||
// 启动清理协程
|
||||
go cache.cleanup()
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *IPBanCache) Get(key string) (bool, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
return false, false
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().UnixNano() > item.expiration {
|
||||
return false, false
|
||||
}
|
||||
|
||||
return item.value, true
|
||||
}
|
||||
|
||||
func (c *IPBanCache) Set(key string, value bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.items[key] = &cacheItem{
|
||||
value: value,
|
||||
expiration: time.Now().Add(c.ttl).UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *IPBanCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
func (c *IPBanCache) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.mu.Lock()
|
||||
now := time.Now().UnixNano()
|
||||
for key, item := range c.items {
|
||||
if now > item.expiration {
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func IPBanMiddleware(
|
||||
securityService *service.SecurityService,
|
||||
settingsManager *settings.SettingsManager,
|
||||
banCache *IPBanCache,
|
||||
logger *logrus.Logger,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !settingsManager.IsIPBanEnabled() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
ip := c.ClientIP()
|
||||
isBanned, err := securityService.IsIPBanned(c.Request.Context(), ip)
|
||||
if err != nil {
|
||||
|
||||
// 查缓存
|
||||
if isBanned, exists := banCache.Get(ip); exists {
|
||||
if isBanned {
|
||||
logger.WithField("ip", ip).Debug("IP blocked (cached)")
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "Access denied",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if isBanned {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "您的IP已被暂时封禁,请稍后再试"})
|
||||
|
||||
// 查数据库
|
||||
ctx := c.Request.Context()
|
||||
isBanned, err := securityService.IsIPBanned(ctx, ip)
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("ip", ip).Error("Failed to check IP ban status")
|
||||
|
||||
// 降级策略:允许访问
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
banCache.Set(ip, isBanned)
|
||||
|
||||
if isBanned {
|
||||
logger.WithField("ip", ip).Info("IP blocked")
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"error": "Access denied",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
52
internal/middleware/timeout.go
Normal file
52
internal/middleware/timeout.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Filename: internal/middleware/timeout.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 创建带超时的 context
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 替换 request context
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
// 使用 channel 等待请求完成
|
||||
finished := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
c.Next()
|
||||
close(finished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-finished:
|
||||
// 请求正常完成
|
||||
return
|
||||
case <-ctx.Done():
|
||||
// 超时
|
||||
c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
|
||||
"error": "Request timeout",
|
||||
"code": "TIMEOUT",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用示例
|
||||
func SetupTimeout(r *gin.Engine) {
|
||||
// 对 API 路由设置 30 秒超时
|
||||
api := r.Group("/api")
|
||||
api.Use(TimeoutMiddleware(30 * time.Second))
|
||||
{
|
||||
// ... API routes
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,151 @@
|
||||
// Filename: internal/middleware/web.go
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"gemini-balancer/internal/service"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
AdminSessionCookie = "gemini_admin_session"
|
||||
SessionMaxAge = 3600 * 24 * 7 // 7天
|
||||
CacheTTL = 5 * time.Minute
|
||||
CleanupInterval = 10 * time.Minute // 降低清理频率
|
||||
SessionRefreshTime = 30 * time.Minute
|
||||
)
|
||||
|
||||
// ==================== 缓存层 ====================
|
||||
|
||||
type authCacheEntry struct {
|
||||
Token interface{}
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type authCache struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]*authCacheEntry
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
var webAuthCache = newAuthCache(CacheTTL)
|
||||
|
||||
func newAuthCache(ttl time.Duration) *authCache {
|
||||
c := &authCache{
|
||||
cache: make(map[string]*authCacheEntry),
|
||||
ttl: ttl,
|
||||
}
|
||||
go c.cleanupLoop()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *authCache) get(key string) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entry, exists := c.cache[key]
|
||||
if !exists || time.Now().After(entry.ExpiresAt) {
|
||||
return nil, false
|
||||
}
|
||||
return entry.Token, true
|
||||
}
|
||||
|
||||
func (c *authCache) set(key string, token interface{}) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cache[key] = &authCacheEntry{
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(c.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *authCache) delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.cache, key)
|
||||
}
|
||||
|
||||
func (c *authCache) cleanupLoop() {
|
||||
ticker := time.NewTicker(CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *authCache) cleanup() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
count := 0
|
||||
for key, entry := range c.cache {
|
||||
if now.After(entry.ExpiresAt) {
|
||||
delete(c.cache, key)
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
logrus.Debugf("[AuthCache] Cleaned up %d expired entries", count)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 会话刷新缓存 ====================
|
||||
|
||||
var sessionRefreshCache = struct {
|
||||
sync.RWMutex
|
||||
timestamps map[string]time.Time
|
||||
}{
|
||||
timestamps: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// 定期清理刷新时间戳
|
||||
func init() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
sessionRefreshCache.Lock()
|
||||
now := time.Now()
|
||||
for key, ts := range sessionRefreshCache.timestamps {
|
||||
if now.Sub(ts) > 2*time.Hour {
|
||||
delete(sessionRefreshCache.timestamps, key)
|
||||
}
|
||||
}
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ==================== Cookie 操作 ====================
|
||||
|
||||
func SetAdminSessionCookie(c *gin.Context, adminToken string) {
|
||||
c.SetCookie(AdminSessionCookie, adminToken, 3600*24*7, "/", "", false, true)
|
||||
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(AdminSessionCookie, adminToken, SessionMaxAge, "/", "", secure, true)
|
||||
}
|
||||
|
||||
func SetAdminSessionCookieWithAge(c *gin.Context, adminToken string, maxAge int) {
|
||||
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(AdminSessionCookie, adminToken, maxAge, "/", "", secure, true)
|
||||
}
|
||||
|
||||
func ClearAdminSessionCookie(c *gin.Context) {
|
||||
c.SetSameSite(http.SameSiteStrictMode)
|
||||
c.SetCookie(AdminSessionCookie, "", -1, "/", "", false, true)
|
||||
}
|
||||
|
||||
@@ -29,26 +157,258 @@ func ExtractTokenFromCookie(c *gin.Context) string {
|
||||
return cookie
|
||||
}
|
||||
|
||||
// ==================== 认证中间件 ====================
|
||||
|
||||
func WebAdminAuthMiddleware(authService *service.SecurityService) gin.HandlerFunc {
|
||||
logger := logrus.New()
|
||||
logger.SetLevel(getLogLevel())
|
||||
|
||||
return func(c *gin.Context) {
|
||||
cookie := ExtractTokenFromCookie(c)
|
||||
log.Printf("[WebAuth_Guard] Intercepting request for: %s", c.Request.URL.Path)
|
||||
log.Printf("[WebAuth_Guard] Found session cookie value: '%s'", cookie)
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
if err != nil {
|
||||
log.Printf("[WebAuth_Guard] FATAL: AuthenticateToken FAILED. Error: %v. Redirecting to /login.", err)
|
||||
} else if !authToken.IsAdmin {
|
||||
log.Printf("[WebAuth_Guard] FATAL: Token validated, but IsAdmin is FALSE. Redirecting to /login.")
|
||||
} else {
|
||||
log.Printf("[WebAuth_Guard] SUCCESS: Token validated and IsAdmin is TRUE. Allowing access.")
|
||||
}
|
||||
if err != nil || !authToken.IsAdmin {
|
||||
if cookie == "" {
|
||||
logger.Debug("[WebAuth] No session cookie found")
|
||||
ClearAdminSessionCookie(c)
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
c.Abort()
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
cacheKey := hashToken(cookie)
|
||||
|
||||
if cachedToken, found := webAuthCache.get(cacheKey); found {
|
||||
logger.Debug("[WebAuth] Using cached token")
|
||||
c.Set("adminUser", cachedToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("[WebAuth] Cache miss, authenticating...")
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("[WebAuth] Authentication failed")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !authToken.IsAdmin {
|
||||
logger.Warn("[WebAuth] User is not admin")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
webAuthCache.set(cacheKey, authToken)
|
||||
logger.Debug("[WebAuth] Authentication success, token cached")
|
||||
|
||||
c.Set("adminUser", authToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func WebAdminAuthMiddlewareWithLogger(authService *service.SecurityService, logger *logrus.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
cookie := ExtractTokenFromCookie(c)
|
||||
|
||||
if cookie == "" {
|
||||
logger.Debug("No session cookie found")
|
||||
ClearAdminSessionCookie(c)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
cacheKey := hashToken(cookie)
|
||||
if cachedToken, found := webAuthCache.get(cacheKey); found {
|
||||
c.Set("adminUser", cachedToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Token authentication failed")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !authToken.IsAdmin {
|
||||
logger.Warn("Token valid but user is not admin")
|
||||
ClearAdminSessionCookie(c)
|
||||
webAuthCache.delete(cacheKey)
|
||||
redirectToLogin(c)
|
||||
return
|
||||
}
|
||||
|
||||
webAuthCache.set(cacheKey, authToken)
|
||||
c.Set("adminUser", authToken)
|
||||
refreshSessionIfNeeded(c, cookie)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
func hashToken(token string) string {
|
||||
h := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func redirectToLogin(c *gin.Context) {
|
||||
if isAjaxRequest(c) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Session expired",
|
||||
"code": "AUTH_REQUIRED",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
originalPath := c.Request.URL.Path
|
||||
if originalPath != "/" && originalPath != "/login" {
|
||||
c.Redirect(http.StatusFound, "/login?redirect="+originalPath)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
}
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func isAjaxRequest(c *gin.Context) bool {
|
||||
// 检查 Content-Type
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查 Accept(优先检查 JSON)
|
||||
accept := c.GetHeader("Accept")
|
||||
if strings.Contains(accept, "application/json") &&
|
||||
!strings.Contains(accept, "text/html") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 兼容旧版 XMLHttpRequest
|
||||
return c.GetHeader("X-Requested-With") == "XMLHttpRequest"
|
||||
}
|
||||
|
||||
func refreshSessionIfNeeded(c *gin.Context, token string) {
|
||||
tokenHash := hashToken(token)
|
||||
|
||||
sessionRefreshCache.RLock()
|
||||
lastRefresh, exists := sessionRefreshCache.timestamps[tokenHash]
|
||||
sessionRefreshCache.RUnlock()
|
||||
|
||||
if !exists || time.Since(lastRefresh) > SessionRefreshTime {
|
||||
SetAdminSessionCookie(c, token)
|
||||
|
||||
sessionRefreshCache.Lock()
|
||||
sessionRefreshCache.timestamps[tokenHash] = time.Now()
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func getLogLevel() logrus.Level {
|
||||
level := os.Getenv("LOG_LEVEL")
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
return logrus.DebugLevel
|
||||
case "warn":
|
||||
return logrus.WarnLevel
|
||||
case "error":
|
||||
return logrus.ErrorLevel
|
||||
default:
|
||||
return logrus.InfoLevel
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
func GetAdminUserFromContext(c *gin.Context) (interface{}, bool) {
|
||||
return c.Get("adminUser")
|
||||
}
|
||||
|
||||
func InvalidateTokenCache(token string) {
|
||||
tokenHash := hashToken(token)
|
||||
webAuthCache.delete(tokenHash)
|
||||
|
||||
// 同时清理刷新时间戳
|
||||
sessionRefreshCache.Lock()
|
||||
delete(sessionRefreshCache.timestamps, tokenHash)
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
|
||||
func ClearAllAuthCache() {
|
||||
webAuthCache.mu.Lock()
|
||||
webAuthCache.cache = make(map[string]*authCacheEntry)
|
||||
webAuthCache.mu.Unlock()
|
||||
|
||||
sessionRefreshCache.Lock()
|
||||
sessionRefreshCache.timestamps = make(map[string]time.Time)
|
||||
sessionRefreshCache.Unlock()
|
||||
}
|
||||
|
||||
// ==================== 调试工具 ====================
|
||||
|
||||
type SessionInfo struct {
|
||||
HasCookie bool `json:"has_cookie"`
|
||||
IsValid bool `json:"is_valid"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
IsCached bool `json:"is_cached"`
|
||||
LastActivity string `json:"last_activity"`
|
||||
}
|
||||
|
||||
func GetSessionInfo(c *gin.Context, authService *service.SecurityService) SessionInfo {
|
||||
info := SessionInfo{
|
||||
HasCookie: false,
|
||||
IsValid: false,
|
||||
IsAdmin: false,
|
||||
IsCached: false,
|
||||
LastActivity: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
cookie := ExtractTokenFromCookie(c)
|
||||
if cookie == "" {
|
||||
return info
|
||||
}
|
||||
|
||||
info.HasCookie = true
|
||||
|
||||
cacheKey := hashToken(cookie)
|
||||
if _, found := webAuthCache.get(cacheKey); found {
|
||||
info.IsCached = true
|
||||
}
|
||||
|
||||
authToken, err := authService.AuthenticateToken(cookie)
|
||||
if err != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
info.IsValid = true
|
||||
info.IsAdmin = authToken.IsAdmin
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func GetCacheStats() map[string]interface{} {
|
||||
webAuthCache.mu.RLock()
|
||||
cacheSize := len(webAuthCache.cache)
|
||||
webAuthCache.mu.RUnlock()
|
||||
|
||||
sessionRefreshCache.RLock()
|
||||
refreshSize := len(sessionRefreshCache.timestamps)
|
||||
sessionRefreshCache.RUnlock()
|
||||
return map[string]interface{}{
|
||||
"auth_cache_entries": cacheSize,
|
||||
"refresh_cache_entries": refreshSize,
|
||||
"ttl_seconds": int(webAuthCache.ttl.Seconds()),
|
||||
"cleanup_interval": int(CleanupInterval.Seconds()),
|
||||
"session_refresh_time": int(SessionRefreshTime.Seconds()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,3 +77,9 @@ type APIKeyDetails struct {
|
||||
CooldownUntil *time.Time `json:"cooldown_until"`
|
||||
EncryptedKey string
|
||||
}
|
||||
|
||||
// SettingsManager 定义了系统设置管理器的抽象接口。
|
||||
|
||||
type SettingsManager interface {
|
||||
GetSettings() *SystemSettings
|
||||
}
|
||||
|
||||
@@ -101,6 +101,7 @@ type RequestLog struct {
|
||||
LatencyMs int
|
||||
IsSuccess bool
|
||||
StatusCode int
|
||||
Status string `gorm:"type:varchar(100);index"`
|
||||
ModelName string `gorm:"type:varchar(100);index"`
|
||||
GroupID *uint `gorm:"index"`
|
||||
KeyID *uint `gorm:"index"`
|
||||
|
||||
@@ -11,6 +11,7 @@ type SystemSettings struct {
|
||||
BlacklistThreshold int `json:"blacklist_threshold" default:"3" name:"拉黑阈值" category:"密钥设置" desc:"一个Key连续失败多少次后进入冷却状态。"`
|
||||
KeyCooldownMinutes int `json:"key_cooldown_minutes" default:"10" name:"密钥冷却时长(分钟)" category:"密钥设置" desc:"一个Key进入冷却状态后需要等待的时间,单位为分钟。"`
|
||||
LogFlushIntervalSeconds int `json:"log_flush_interval_seconds" default:"10" name:"日志刷新间隔(秒)" category:"日志设置" desc:"异步日志写入数据库的间隔时间(秒)。"`
|
||||
MaxRequestBodySizeMB int `json:"max_request_body_size_mb" default:"10" name:"最大请求体大小 (MB)" category:"请求设置" desc:"允许代理接收的最大请求体大小,单位为MB。超过此大小的请求将被拒绝。"`
|
||||
|
||||
PollingStrategy PollingStrategy `json:"polling_strategy" default:"random" name:"全局轮询策略" category:"调度设置" desc:"智能聚合模式下,从所有可用密钥中选择一个的默认策略。可选值: sequential(顺序), random(随机), weighted(加权)。"`
|
||||
|
||||
@@ -26,6 +27,8 @@ type SystemSettings struct {
|
||||
BaseKeyCheckEndpoint string `json:"base_key_check_endpoint" default:"https://generativelanguage.googleapis.com/v1beta/models" name:"全局Key检查端点" category:"健康检查" desc:"用于全局Key身份检查的目标URL。"`
|
||||
BaseKeyCheckModel string `json:"base_key_check_model" default:"gemini-2.0-flash-lite" name:"默认Key检查模型" category:"健康检查" desc:"用于分组健康检查和手动密钥测试时的默认回退模型。"`
|
||||
|
||||
KeyCheckSchedulerIntervalSeconds int `json:"key_check_scheduler_interval_seconds" default:"60" name:"Key检查调度器间隔(秒)" category:"健康检查" desc:"动态调度器检查各组是否需要执行健康检查的周期。"`
|
||||
|
||||
EnableUpstreamCheck bool `json:"enable_upstream_check" default:"true" name:"启用上游检查" category:"健康检查" desc:"是否启用对上游服务(Upstream)的健康检查。"`
|
||||
UpstreamCheckTimeoutSeconds int `json:"upstream_check_timeout_seconds" default:"20" name:"上游检查超时(秒)" category:"健康检查" desc:"对单个上游服务进行健康检查时的网络超时时间。"`
|
||||
|
||||
@@ -41,6 +44,10 @@ type SystemSettings struct {
|
||||
MaxLoginAttempts int `json:"max_login_attempts" default:"5" name:"最大登录失败次数" category:"安全设置" desc:"在一个IP被封禁前,允许的连续登录失败次数。"`
|
||||
IPBanDurationMinutes int `json:"ip_ban_duration_minutes" default:"15" name:"IP封禁时长(分钟)" category:"安全设置" desc:"IP被封禁的时长,单位为分钟。"`
|
||||
|
||||
// BasePool 相关配置
|
||||
// BasePoolTTLMinutes int `json:"base_pool_ttl_minutes" default:"30" name:"基础资源池最大生存时间(分钟)" category:"基础资源池" desc:"一个动态构建的基础资源池(BasePool)在Redis中的最大生存时间。到期后即使仍在活跃使用也会被强制重建。"`
|
||||
// BasePoolTTIMinutes int `json:"base_pool_tti_minutes" default:"10" name:"基础资源池空闲超时(分钟)" category:"基础资源池" desc:"一个基础资源池(BasePool)在连续无请求后,自动销毁的空闲等待时间。"`
|
||||
|
||||
//智能网关
|
||||
LogTruncationLimit int `json:"log_truncation_limit" default:"8000" name:"日志截断长度" category:"日志设置" desc:"在日志中记录上游响应或错误时,保留的最大字符数。0表示不截断。"`
|
||||
EnableSmartGateway bool `json:"enable_smart_gateway" default:"false" name:"启用智能网关" category:"代理设置" desc:"开启后,系统将对流式请求进行智能中断续传、错误标准化等优化。关闭后,系统将作为一个纯净、无干扰的透明代理。"`
|
||||
@@ -64,6 +71,10 @@ type SystemSettings struct {
|
||||
LogBufferCapacity int `json:"log_buffer_capacity" default:"1000" name:"日志缓冲区容量" category:"日志设置" desc:"内存中日志缓冲区的最大容量,超过则可能丢弃日志。"`
|
||||
LogFlushBatchSize int `json:"log_flush_batch_size" default:"100" name:"日志刷新批次大小" category:"日志设置" desc:"每次向数据库批量写入日志的最大数量。"`
|
||||
|
||||
LogAutoCleanupEnabled bool `json:"log_auto_cleanup_enabled" default:"false" name:"开启请求日志自动清理" category:"日志配置" desc:"启用后,系统将每日定时删除旧的请求日志。"`
|
||||
LogAutoCleanupRetentionDays int `json:"log_auto_cleanup_retention_days" default:"30" name:"日志保留天数" category:"日志配置" desc:"自动清理任务将保留最近 N 天的日志。"`
|
||||
LogAutoCleanupTime string `json:"log_auto_cleanup_time" default:"04:05" name:"每日清理执行时间" category:"日志配置" desc:"自动清理任务执行的时间点(24小时制,例如 04:05)。"`
|
||||
|
||||
// --- API配置 ---
|
||||
CustomHeaders map[string]string `json:"custom_headers" name:"自定义Headers" category:"API配置" ` // 默认为nil
|
||||
|
||||
|
||||
@@ -1,63 +1,96 @@
|
||||
// Filename: internal/pongo/renderer.go
|
||||
|
||||
package pongo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/flosch/pongo2/v6"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/render"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Renderer struct {
|
||||
Context pongo2.Context
|
||||
tplSet *pongo2.TemplateSet
|
||||
mu sync.RWMutex
|
||||
globalContext pongo2.Context
|
||||
tplSet *pongo2.TemplateSet
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
func New(directory string, isDebug bool) *Renderer {
|
||||
func New(directory string, isDebug bool, logger *logrus.Logger) *Renderer {
|
||||
loader := pongo2.MustNewLocalFileSystemLoader(directory)
|
||||
tplSet := pongo2.NewSet("gin-pongo-templates", loader)
|
||||
tplSet.Debug = isDebug
|
||||
return &Renderer{Context: make(pongo2.Context), tplSet: tplSet}
|
||||
return &Renderer{
|
||||
globalContext: make(pongo2.Context),
|
||||
tplSet: tplSet,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Instance returns a new render.HTML instance for a single request.
|
||||
func (p *Renderer) Instance(name string, data interface{}) render.Render {
|
||||
var glob pongo2.Context
|
||||
if p.Context != nil {
|
||||
glob = p.Context
|
||||
}
|
||||
// SetGlobalContext 线程安全地设置全局上下文
|
||||
func (p *Renderer) SetGlobalContext(key string, value interface{}) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.globalContext[key] = value
|
||||
}
|
||||
|
||||
// Warmup 预加载模板
|
||||
func (p *Renderer) Warmup(templateNames ...string) error {
|
||||
for _, name := range templateNames {
|
||||
if _, err := p.tplSet.FromCache(name); err != nil {
|
||||
return fmt.Errorf("failed to warmup template '%s': %w", name, err)
|
||||
}
|
||||
}
|
||||
p.logger.WithField("count", len(templateNames)).Info("Templates warmed up")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Renderer) Instance(name string, data interface{}) render.Render {
|
||||
// 安全读取全局上下文
|
||||
p.mu.RLock()
|
||||
glob := make(pongo2.Context, len(p.globalContext))
|
||||
for k, v := range p.globalContext {
|
||||
glob[k] = v
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
// 解析请求数据
|
||||
var context pongo2.Context
|
||||
if data != nil {
|
||||
if ginContext, ok := data.(gin.H); ok {
|
||||
context = pongo2.Context(ginContext)
|
||||
} else if pongoContext, ok := data.(pongo2.Context); ok {
|
||||
context = pongoContext
|
||||
} else if m, ok := data.(map[string]interface{}); ok {
|
||||
context = m
|
||||
} else {
|
||||
switch v := data.(type) {
|
||||
case gin.H:
|
||||
context = pongo2.Context(v)
|
||||
case pongo2.Context:
|
||||
context = v
|
||||
case map[string]interface{}:
|
||||
context = v
|
||||
default:
|
||||
context = make(pongo2.Context)
|
||||
}
|
||||
} else {
|
||||
context = make(pongo2.Context)
|
||||
}
|
||||
|
||||
// 合并上下文(请求数据优先)
|
||||
for k, v := range glob {
|
||||
if _, ok := context[k]; !ok {
|
||||
if _, exists := context[k]; !exists {
|
||||
context[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 加载模板
|
||||
tpl, err := p.tplSet.FromCache(name)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to load template '%s': %v", name, err))
|
||||
p.logger.WithError(err).WithField("template", name).Error("Failed to load template")
|
||||
return &ErrorHTML{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("template load error: %s", name),
|
||||
}
|
||||
}
|
||||
|
||||
return &HTML{
|
||||
p: p,
|
||||
Template: tpl,
|
||||
Name: name,
|
||||
Data: context,
|
||||
@@ -65,7 +98,6 @@ func (p *Renderer) Instance(name string, data interface{}) render.Render {
|
||||
}
|
||||
|
||||
type HTML struct {
|
||||
p *Renderer
|
||||
Template *pongo2.Template
|
||||
Name string
|
||||
Data pongo2.Context
|
||||
@@ -82,15 +114,31 @@ func (h *HTML) Render(w http.ResponseWriter) error {
|
||||
}
|
||||
|
||||
func (h *HTML) WriteContentType(w http.ResponseWriter) {
|
||||
header := w.Header()
|
||||
if val := header["Content-Type"]; len(val) == 0 {
|
||||
header["Content-Type"] = []string{"text/html; charset=utf-8"}
|
||||
if w.Header().Get("Content-Type") == "" {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorHTML 错误渲染器
|
||||
type ErrorHTML struct {
|
||||
StatusCode int
|
||||
Error error
|
||||
}
|
||||
|
||||
func (e *ErrorHTML) Render(w http.ResponseWriter) error {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(e.StatusCode)
|
||||
_, err := w.Write([]byte(e.Error.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *ErrorHTML) WriteContentType(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
}
|
||||
|
||||
// C 获取或创建 pongo2 上下文
|
||||
func C(ctx *gin.Context) pongo2.Context {
|
||||
p, exists := ctx.Get("pongo2")
|
||||
if exists {
|
||||
if p, exists := ctx.Get("pongo2"); exists {
|
||||
if pCtx, ok := p.(pongo2.Context); ok {
|
||||
return pCtx
|
||||
}
|
||||
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
type AuthTokenRepository interface {
|
||||
GetAllTokensWithGroups() ([]*models.AuthToken, error)
|
||||
BatchUpdateTokens(updates []*models.TokenUpdateRequest) error
|
||||
GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error) // <-- Add this line
|
||||
SeedAdminToken(encryptedToken, tokenHash string) error // <-- And this line for the seeder
|
||||
GetTokenByHashedValue(tokenHash string) (*models.AuthToken, error)
|
||||
SeedAdminToken(encryptedToken, tokenHash string) error
|
||||
}
|
||||
|
||||
type gormAuthTokenRepository struct {
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
// Filename: internal/repository/key_cache.go
|
||||
// Filename: internal/repository/key_cache.go (最终定稿)
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// --- Redis Key 常量定义 ---
|
||||
const (
|
||||
KeyGroup = "group:%d:keys:active"
|
||||
KeyDetails = "key:%d:details"
|
||||
@@ -22,13 +24,16 @@ const (
|
||||
BasePoolRandomCooldown = "basepool:%s:keys:random:cooldown"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
r.logger.Info("Starting to load all keys and associations into cache, including polling structures...")
|
||||
// LoadAllKeysToStore 从数据库加载所有密钥和映射关系,并完整重建Redis缓存。
|
||||
func (r *gormKeyRepository) LoadAllKeysToStore(ctx context.Context) error {
|
||||
r.logger.Info("Starting full cache rebuild for all keys and polling structures.")
|
||||
|
||||
var allMappings []*models.GroupAPIKeyMapping
|
||||
if err := r.db.Preload("APIKey").Find(&allMappings).Error; err != nil {
|
||||
return fmt.Errorf("failed to load all mappings with APIKeys from DB: %w", err)
|
||||
return fmt.Errorf("failed to load mappings with preloaded APIKeys: %w", err)
|
||||
}
|
||||
|
||||
// 1. 批量解密所有涉及的密钥
|
||||
keyMap := make(map[uint]*models.APIKey)
|
||||
for _, m := range allMappings {
|
||||
if m.APIKey != nil {
|
||||
@@ -40,16 +45,16 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
keysToDecrypt = append(keysToDecrypt, *k)
|
||||
}
|
||||
if err := r.decryptKeys(keysToDecrypt); err != nil {
|
||||
r.logger.WithError(err).Error("Critical error during cache preload: batch decryption failed.")
|
||||
r.logger.WithError(err).Error("Batch decryption failed during cache rebuild.")
|
||||
// 即使解密失败,也继续尝试加载未加密或已解密的部分
|
||||
}
|
||||
decryptedKeyMap := make(map[uint]models.APIKey)
|
||||
for _, k := range keysToDecrypt {
|
||||
decryptedKeyMap[k.ID] = k
|
||||
}
|
||||
|
||||
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
|
||||
pipe := r.store.Pipeline()
|
||||
detailsToSet := make(map[string][]byte)
|
||||
// 2. 清理所有分组的旧轮询结构
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
var allGroups []*models.KeyGroup
|
||||
if err := r.db.Find(&allGroups).Error; err == nil {
|
||||
for _, group := range allGroups {
|
||||
@@ -62,26 +67,41 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
)
|
||||
}
|
||||
} else {
|
||||
r.logger.WithError(err).Error("Failed to get all groups for cache cleanup")
|
||||
r.logger.WithError(err).Error("Failed to get groups for cache cleanup; proceeding with rebuild.")
|
||||
}
|
||||
|
||||
// 3. 准备批量更新数据
|
||||
activeKeysByGroup := make(map[uint][]*models.GroupAPIKeyMapping)
|
||||
detailsToSet := make(map[string]any)
|
||||
|
||||
for _, mapping := range allMappings {
|
||||
if mapping.APIKey == nil {
|
||||
continue
|
||||
}
|
||||
decryptedKey, ok := decryptedKeyMap[mapping.APIKeyID]
|
||||
if !ok {
|
||||
continue
|
||||
continue // 跳过解密失败的密钥
|
||||
}
|
||||
|
||||
// 准备 KeyDetails 和 KeyMapping 的 MSet 数据
|
||||
keyJSON, _ := json.Marshal(decryptedKey)
|
||||
detailsToSet[fmt.Sprintf(KeyDetails, decryptedKey.ID)] = keyJSON
|
||||
mappingJSON, _ := json.Marshal(mapping)
|
||||
detailsToSet[fmt.Sprintf(KeyMapping, mapping.KeyGroupID, decryptedKey.ID)] = mappingJSON
|
||||
|
||||
if mapping.Status == models.StatusActive {
|
||||
activeKeysByGroup[mapping.KeyGroupID] = append(activeKeysByGroup[mapping.KeyGroupID], mapping)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 使用 MSet 批量写入详情和映射缓存
|
||||
if len(detailsToSet) > 0 {
|
||||
if err := r.store.MSet(ctx, detailsToSet); err != nil {
|
||||
r.logger.WithError(err).Error("Failed to MSet key details and mappings during cache rebuild.")
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 在Pipeline中重建所有分组的轮询结构
|
||||
for groupID, activeMappings := range activeKeysByGroup {
|
||||
if len(activeMappings) == 0 {
|
||||
continue
|
||||
@@ -100,22 +120,19 @@ func (r *gormKeyRepository) LoadAllKeysToStore() error {
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), activeKeyIDs...)
|
||||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), activeKeyIDs...)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), activeKeyIDs...)
|
||||
go r.store.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
|
||||
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), lruMembers)
|
||||
}
|
||||
|
||||
// 6. 执行Pipeline
|
||||
if err := pipe.Exec(); err != nil {
|
||||
return fmt.Errorf("failed to execute pipeline for cache rebuild: %w", err)
|
||||
}
|
||||
for key, value := range detailsToSet {
|
||||
if err := r.store.Set(key, value, 0); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to set key detail in cache: %s", key)
|
||||
}
|
||||
return fmt.Errorf("pipeline execution for polling structures failed: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Cache rebuild complete, including all polling structures.")
|
||||
r.logger.Info("Full cache rebuild completed successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateStoreCacheForKey 更新单个APIKey的详情缓存 (K-V)。
|
||||
func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
|
||||
if err := r.decryptKey(key); err != nil {
|
||||
return fmt.Errorf("failed to decrypt key %d for cache update: %w", key.ID, err)
|
||||
@@ -124,81 +141,104 @@ func (r *gormKeyRepository) updateStoreCacheForKey(key *models.APIKey) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal key %d for cache update: %w", key.ID, err)
|
||||
}
|
||||
return r.store.Set(fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
|
||||
return r.store.Set(context.Background(), fmt.Sprintf(KeyDetails, key.ID), keyJSON, 0)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) removeStoreCacheForKey(key *models.APIKey) error {
|
||||
groupIDs, err := r.GetGroupsForKey(key.ID)
|
||||
// removeStoreCacheForKey 从所有缓存结构中彻底移除一个APIKey。
|
||||
func (r *gormKeyRepository) removeStoreCacheForKey(ctx context.Context, key *models.APIKey) error {
|
||||
groupIDs, err := r.GetGroupsForKey(ctx, key.ID)
|
||||
if err != nil {
|
||||
r.logger.Warnf("failed to get groups for key %d to clean up cache lists: %v", key.ID, err)
|
||||
r.logger.WithError(err).Warnf("Failed to get groups for key %d during cache removal, cleanup may be partial.", key.ID)
|
||||
}
|
||||
|
||||
pipe := r.store.Pipeline()
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
pipe.Del(fmt.Sprintf(KeyDetails, key.ID))
|
||||
|
||||
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
|
||||
for _, groupID := range groupIDs {
|
||||
pipe.Del(fmt.Sprintf(KeyMapping, groupID, key.ID))
|
||||
|
||||
keyIDStr := strconv.FormatUint(uint64(key.ID), 10)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||||
go r.store.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
}
|
||||
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
// updateStoreCacheForMapping 根据单个映射关系的状态,原子性地更新所有相关的缓存结构。
|
||||
func (r *gormKeyRepository) updateStoreCacheForMapping(mapping *models.GroupAPIKeyMapping) error {
|
||||
pipe := r.store.Pipeline()
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", mapping.KeyGroupID)
|
||||
pipe.LRem(activeKeyListKey, 0, mapping.APIKeyID)
|
||||
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
||||
groupID := mapping.KeyGroupID
|
||||
ctx := context.Background()
|
||||
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
|
||||
// 统一、无条件地从所有轮询结构中移除,确保状态清洁
|
||||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||||
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
|
||||
// 如果新状态是 Active,则重新添加到所有轮询结构中
|
||||
if mapping.Status == models.StatusActive {
|
||||
pipe.LPush(activeKeyListKey, mapping.APIKeyID)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
|
||||
}
|
||||
|
||||
// 无论状态如何,都更新映射详情的 K-V 缓存
|
||||
mappingJSON, err := json.Marshal(mapping)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal mapping: %w", err)
|
||||
}
|
||||
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
|
||||
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error {
|
||||
// HandleCacheUpdateEventBatch 批量、原子性地更新多个映射关系的缓存。
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error {
|
||||
if len(mappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
groupUpdates := make(map[uint]struct {
|
||||
ToAdd []interface{}
|
||||
ToRemove []interface{}
|
||||
})
|
||||
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
|
||||
for _, mapping := range mappings {
|
||||
keyIDStr := strconv.FormatUint(uint64(mapping.APIKeyID), 10)
|
||||
update, ok := groupUpdates[mapping.KeyGroupID]
|
||||
if !ok {
|
||||
update = struct {
|
||||
ToAdd []interface{}
|
||||
ToRemove []interface{}
|
||||
}{}
|
||||
}
|
||||
groupID := mapping.KeyGroupID
|
||||
|
||||
// 对于批处理中的每一个mapping,都执行完整的、正确的“先删后增”逻辑
|
||||
pipe.SRem(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LRem(fmt.Sprintf(KeyGroupSequential, groupID), 0, keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
pipe.SRem(fmt.Sprintf(KeyGroupRandomCooldown, groupID), keyIDStr)
|
||||
pipe.ZRem(fmt.Sprintf(KeyGroupLRU, groupID), keyIDStr)
|
||||
|
||||
if mapping.Status == models.StatusActive {
|
||||
update.ToRemove = append(update.ToRemove, keyIDStr)
|
||||
update.ToAdd = append(update.ToAdd, keyIDStr)
|
||||
} else {
|
||||
update.ToRemove = append(update.ToRemove, keyIDStr)
|
||||
}
|
||||
groupUpdates[mapping.KeyGroupID] = update
|
||||
}
|
||||
pipe := r.store.Pipeline()
|
||||
var pipelineError error
|
||||
for groupID, updates := range groupUpdates {
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
||||
if len(updates.ToRemove) > 0 {
|
||||
for _, keyID := range updates.ToRemove {
|
||||
pipe.LRem(activeKeyListKey, 0, keyID)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroup, groupID), keyIDStr)
|
||||
pipe.LPush(fmt.Sprintf(KeyGroupSequential, groupID), keyIDStr)
|
||||
pipe.SAdd(fmt.Sprintf(KeyGroupRandomMain, groupID), keyIDStr)
|
||||
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
pipe.ZAdd(fmt.Sprintf(KeyGroupLRU, groupID), map[string]float64{keyIDStr: score})
|
||||
}
|
||||
if len(updates.ToAdd) > 0 {
|
||||
pipe.LPush(activeKeyListKey, updates.ToAdd...)
|
||||
}
|
||||
|
||||
mappingJSON, _ := json.Marshal(mapping) // 在批处理中忽略单个marshal错误,以保证大部分更新成功
|
||||
pipe.Set(fmt.Sprintf(KeyMapping, groupID, mapping.APIKeyID), mappingJSON, 0)
|
||||
}
|
||||
if err := pipe.Exec(); err != nil {
|
||||
pipelineError = fmt.Errorf("redis pipeline execution failed: %w", err)
|
||||
}
|
||||
return pipelineError
|
||||
|
||||
return pipe.Exec()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
|
||||
"context"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -22,7 +23,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
|
||||
keyHashes := make([]string, len(keys))
|
||||
keyValueToHashMap := make(map[string]string)
|
||||
for i, k := range keys {
|
||||
// All incoming keys must have plaintext APIKey
|
||||
if k.APIKey == "" {
|
||||
return nil, fmt.Errorf("cannot add key at index %d: plaintext APIKey is empty", i)
|
||||
}
|
||||
@@ -34,7 +34,6 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
|
||||
var finalKeys []models.APIKey
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var existingKeys []models.APIKey
|
||||
// [MODIFIED] Query by hash to find existing keys.
|
||||
if err := tx.Unscoped().Where("api_key_hash IN ?", keyHashes).Find(&existingKeys).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -68,24 +67,20 @@ func (r *gormKeyRepository) AddKeys(keys []models.APIKey) ([]models.APIKey, erro
|
||||
}
|
||||
}
|
||||
if len(keysToCreate) > 0 {
|
||||
// [MODIFIED] Create now only provides encrypted data and hash.
|
||||
if err := tx.Clauses(clause.OnConflict{DoNothing: true}, clause.Returning{}).Create(&keysToCreate).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// [MODIFIED] Final select uses hashes to retrieve all relevant keys.
|
||||
if err := tx.Where("api_key_hash IN ?", keyHashes).Find(&finalKeys).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// [CRITICAL] Decrypt all keys before returning them to the service layer.
|
||||
|
||||
return r.decryptKeys(finalKeys)
|
||||
})
|
||||
return finalKeys, err
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) Update(key *models.APIKey) error {
|
||||
// [CRITICAL] Before saving, check if the plaintext APIKey field was populated.
|
||||
// This indicates a potential change that needs to be re-encrypted.
|
||||
if key.APIKey != "" {
|
||||
encryptedKey, err := r.crypto.Encrypt(key.APIKey)
|
||||
if err != nil {
|
||||
@@ -97,16 +92,16 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
|
||||
key.APIKeyHash = hex.EncodeToString(hash[:])
|
||||
}
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
// GORM automatically ignores `key.APIKey` because of the `gorm:"-"` tag.
|
||||
|
||||
return tx.Save(key).Error
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// For the cache update, we need the plaintext. Decrypt if it's not already populated.
|
||||
|
||||
if err := r.decryptKey(key); err != nil {
|
||||
r.logger.Warnf("DB updated key ID %d, but decryption for cache failed: %v", key.ID, err)
|
||||
return nil // Continue without cache update if decryption fails.
|
||||
return nil
|
||||
}
|
||||
if err := r.updateStoreCacheForKey(key); err != nil {
|
||||
r.logger.Warnf("DB updated key ID %d, but cache update failed: %v", key.ID, err)
|
||||
@@ -115,7 +110,7 @@ func (r *gormKeyRepository) Update(key *models.APIKey) error {
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HardDeleteByID(id uint) error {
|
||||
key, err := r.GetKeyByID(id) // This now returns a decrypted key
|
||||
key, err := r.GetKeyByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,7 +120,7 @@ func (r *gormKeyRepository) HardDeleteByID(id uint) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.removeStoreCacheForKey(key); err != nil {
|
||||
if err := r.removeStoreCacheForKey(context.Background(), key); err != nil {
|
||||
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", id, err)
|
||||
}
|
||||
return nil
|
||||
@@ -140,16 +135,13 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
|
||||
hash := sha256.Sum256([]byte(v))
|
||||
hashes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
// Find the full key objects first to update the cache later.
|
||||
var keysToDelete []models.APIKey
|
||||
// [MODIFIED] Find by hash.
|
||||
if err := r.db.Where("api_key_hash IN ?", hashes).Find(&keysToDelete).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(keysToDelete) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
// Decrypt them to ensure cache has plaintext if needed.
|
||||
if err := r.decryptKeys(keysToDelete); err != nil {
|
||||
r.logger.Warnf("Decryption failed for keys before hard delete, cache removal may be impacted: %v", err)
|
||||
}
|
||||
@@ -167,7 +159,7 @@ func (r *gormKeyRepository) HardDeleteByValues(keyValues []string) (int64, error
|
||||
return 0, err
|
||||
}
|
||||
for i := range keysToDelete {
|
||||
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
|
||||
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
|
||||
r.logger.Warnf("DB deleted key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
||||
}
|
||||
}
|
||||
@@ -194,7 +186,6 @@ func (r *gormKeyRepository) GetKeysByIDs(ids []uint) ([]models.APIKey, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// [CRITICAL] Decrypt before returning.
|
||||
return keys, r.decryptKeys(keys)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -110,13 +111,13 @@ func (r *gormKeyRepository) deleteOrphanKeysLogic(db *gorm.DB) (int64, error) {
|
||||
}
|
||||
|
||||
result := db.Delete(&models.APIKey{}, orphanKeyIDs)
|
||||
//result := db.Unscoped().Delete(&models.APIKey{}, orphanKeyIDs)
|
||||
if result.Error != nil {
|
||||
return 0, result.Error
|
||||
}
|
||||
|
||||
for i := range keysToDelete {
|
||||
if err := r.removeStoreCacheForKey(&keysToDelete[i]); err != nil {
|
||||
// [修正] 使用 context.Background() 调用已更新的缓存清理函数
|
||||
if err := r.removeStoreCacheForKey(context.Background(), &keysToDelete[i]); err != nil {
|
||||
r.logger.Warnf("DB deleted orphan key ID %d, but cache removal failed: %v", keysToDelete[i].ID, err)
|
||||
}
|
||||
}
|
||||
@@ -144,7 +145,7 @@ func (r *gormKeyRepository) GetActiveMasterKeys() ([]*models.APIKey, error) {
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error {
|
||||
func (r *gormKeyRepository) UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error {
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
result := tx.Model(&models.APIKey{}).
|
||||
Where("id = ?", keyID).
|
||||
@@ -160,7 +161,7 @@ func (r *gormKeyRepository) UpdateAPIKeyStatus(keyID uint, status models.MasterA
|
||||
if err == nil {
|
||||
r.logger.Infof("MasterStatus for key ID %d changed, triggering a full cache reload.", keyID)
|
||||
go func() {
|
||||
if err := r.LoadAllKeysToStore(); err != nil {
|
||||
if err := r.LoadAllKeysToStore(context.Background()); err != nil {
|
||||
r.logger.Errorf("Failed to reload cache after MasterStatus change for key ID %d: %v", keyID, err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
|
||||
func (r *gormKeyRepository) LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error {
|
||||
if len(keyIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -34,12 +35,12 @@ func (r *gormKeyRepository) LinkKeysToGroup(groupID uint, keyIDs []uint) error {
|
||||
}
|
||||
|
||||
for _, keyID := range keyIDs {
|
||||
r.store.SAdd(fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.SAdd(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (int64, error) {
|
||||
func (r *gormKeyRepository) UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (int64, error) {
|
||||
if len(keyIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -63,16 +64,16 @@ func (r *gormKeyRepository) UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (in
|
||||
|
||||
activeKeyListKey := fmt.Sprintf("group:%d:keys:active", groupID)
|
||||
for _, keyID := range keyIDs {
|
||||
r.store.SRem(fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.LRem(activeKeyListKey, 0, strconv.Itoa(int(keyID)))
|
||||
r.store.SRem(context.Background(), fmt.Sprintf("key:%d:groups", keyID), groupID)
|
||||
r.store.LRem(context.Background(), activeKeyListKey, 0, strconv.Itoa(int(keyID)))
|
||||
}
|
||||
|
||||
return unlinkedCount, nil
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
|
||||
func (r *gormKeyRepository) GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error) {
|
||||
cacheKey := fmt.Sprintf("key:%d:groups", keyID)
|
||||
strGroupIDs, err := r.store.SMembers(cacheKey)
|
||||
strGroupIDs, err := r.store.SMembers(context.Background(), cacheKey)
|
||||
if err != nil || len(strGroupIDs) == 0 {
|
||||
var groupIDs []uint
|
||||
dbErr := r.db.Table("group_api_key_mappings").Where("api_key_id = ?", keyID).Pluck("key_group_id", &groupIDs).Error
|
||||
@@ -84,7 +85,7 @@ func (r *gormKeyRepository) GetGroupsForKey(keyID uint) ([]uint, error) {
|
||||
for _, id := range groupIDs {
|
||||
interfaceSlice = append(interfaceSlice, id)
|
||||
}
|
||||
r.store.SAdd(cacheKey, interfaceSlice...)
|
||||
r.store.SAdd(context.Background(), cacheKey, interfaceSlice...)
|
||||
}
|
||||
return groupIDs, nil
|
||||
}
|
||||
@@ -103,7 +104,7 @@ func (r *gormKeyRepository) GetMapping(groupID, keyID uint) (*models.GroupAPIKey
|
||||
return &mapping, err
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) UpdateMapping(mapping *models.GroupAPIKeyMapping) error {
|
||||
func (r *gormKeyRepository) UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error {
|
||||
err := r.executeTransactionWithRetry(func(tx *gorm.DB) error {
|
||||
return tx.Save(mapping).Error
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -17,40 +18,40 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
CacheTTL = 5 * time.Minute
|
||||
EmptyPoolPlaceholder = "EMPTY_POOL"
|
||||
EmptyCacheTTL = 1 * time.Minute
|
||||
CacheTTL = 5 * time.Minute
|
||||
EmptyCacheTTL = 1 * time.Minute
|
||||
)
|
||||
|
||||
// SelectOneActiveKey 根据指定的轮询策略,从缓存中高效地选取一个可用的API密钥。
|
||||
|
||||
func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
// SelectOneActiveKey 根据指定的轮询策略,从单个密钥组缓存中选取一个可用的API密钥。
|
||||
func (r *gormKeyRepository) SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
if group == nil {
|
||||
return nil, nil, fmt.Errorf("group cannot be nil")
|
||||
}
|
||||
var keyIDStr string
|
||||
var err error
|
||||
|
||||
switch group.PollingStrategy {
|
||||
case models.StrategySequential:
|
||||
sequentialKey := fmt.Sprintf(KeyGroupSequential, group.ID)
|
||||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||||
|
||||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||||
case models.StrategyWeighted:
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, group.ID)
|
||||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||||
if zerr == nil && len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
|
||||
if zerr == nil {
|
||||
if len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
} else {
|
||||
zerr = gorm.ErrRecordNotFound
|
||||
}
|
||||
}
|
||||
err = zerr
|
||||
|
||||
case models.StrategyRandom:
|
||||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, group.ID)
|
||||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, group.ID)
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||||
|
||||
default: // 默认或未指定策略时,使用基础的随机策略
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
|
||||
default:
|
||||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||||
keyIDStr, err = r.store.SRandMember(activeKeySetKey)
|
||||
keyIDStr, err = r.store.SRandMember(ctx, activeKeySetKey)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
@@ -58,65 +59,70 @@ func (r *gormKeyRepository) SelectOneActiveKey(group *models.KeyGroup) (*models.
|
||||
r.logger.WithError(err).Errorf("Failed to select key for group %d with strategy %s", group.ID, group.PollingStrategy)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if keyIDStr == "" {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
|
||||
apiKey, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
if parseErr != nil {
|
||||
r.logger.WithError(parseErr).Errorf("Invalid key ID format in group %d cache: %s", group.ID, keyIDStr)
|
||||
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
|
||||
}
|
||||
apiKey, mapping, err := r.getKeyDetailsFromCache(ctx, uint(keyID), group.ID)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Warnf("Cache inconsistency: Failed to get details for selected key ID %d", keyID)
|
||||
// TODO 可以在此加入重试逻辑,再次调用 SelectOneActiveKey(group)
|
||||
r.logger.WithError(err).Warnf("Cache inconsistency for key ID %d in group %d", keyID, group.ID)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go r.UpdateKeyUsageTimestamp(group.ID, uint(keyID))
|
||||
go func() {
|
||||
updateCtx, cancel := r.withTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
r.UpdateKeyUsageTimestamp(updateCtx, group.ID, uint(keyID))
|
||||
}()
|
||||
}
|
||||
|
||||
return apiKey, mapping, nil
|
||||
}
|
||||
|
||||
// SelectOneActiveKeyFromBasePool 为智能聚合模式设计的全新轮询器。
|
||||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||||
// 生成唯一的池ID,确保不同请求组合的轮询状态相互隔离
|
||||
poolID := generatePoolID(pool.CandidateGroups)
|
||||
// SelectOneActiveKeyFromBasePool 从智能聚合池中选取一个可用Key。
|
||||
func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error) {
|
||||
if pool == nil || len(pool.CandidateGroups) == 0 {
|
||||
return nil, nil, fmt.Errorf("invalid or empty base pool configuration")
|
||||
}
|
||||
poolID := r.generatePoolID(pool.CandidateGroups)
|
||||
log := r.logger.WithField("pool_id", poolID)
|
||||
|
||||
if err := r.ensureBasePoolCacheExists(pool, poolID); err != nil {
|
||||
log.WithError(err).Error("Failed to ensure BasePool cache exists.")
|
||||
if err := r.ensureBasePoolCacheExists(ctx, pool, poolID); err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.WithError(err).Error("Failed to ensure BasePool cache exists")
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var keyIDStr string
|
||||
var err error
|
||||
|
||||
switch pool.PollingStrategy {
|
||||
case models.StrategySequential:
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
keyIDStr, err = r.store.Rotate(sequentialKey)
|
||||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||||
case models.StrategyWeighted:
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
results, zerr := r.store.ZRange(lruKey, 0, 0)
|
||||
if zerr == nil && len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
results, zerr := r.store.ZRange(ctx, lruKey, 0, 0)
|
||||
if zerr == nil {
|
||||
if len(results) > 0 {
|
||||
keyIDStr = results[0]
|
||||
} else {
|
||||
zerr = gorm.ErrRecordNotFound
|
||||
}
|
||||
}
|
||||
err = zerr
|
||||
case models.StrategyRandom:
|
||||
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
|
||||
cooldownPoolKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(mainPoolKey, cooldownPoolKey)
|
||||
default: // 默认策略,应该在 ensureCache 中处理,但作为降级方案
|
||||
log.Warnf("Default polling strategy triggered inside selection. This should be rare.")
|
||||
|
||||
keyIDStr, err = r.store.PopAndCycleSetMember(ctx, mainPoolKey, cooldownPoolKey)
|
||||
default:
|
||||
log.Warnf("Unknown polling strategy '%s'. Using sequential as fallback.", pool.PollingStrategy)
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
keyIDStr, err = r.store.LIndex(sequentialKey, 0)
|
||||
keyIDStr, err = r.store.Rotate(ctx, sequentialKey)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
if errors.Is(err, store.ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
log.WithError(err).Errorf("Failed to select key from BasePool with strategy %s", pool.PollingStrategy)
|
||||
@@ -125,153 +131,266 @@ func (r *gormKeyRepository) SelectOneActiveKeyFromBasePool(pool *BasePool) (*mod
|
||||
if keyIDStr == "" {
|
||||
return nil, nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
|
||||
for _, group := range pool.CandidateGroups {
|
||||
apiKey, mapping, cacheErr := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
if cacheErr == nil && apiKey != nil && mapping != nil {
|
||||
|
||||
if pool.PollingStrategy == models.StrategyWeighted {
|
||||
|
||||
go r.updateKeyUsageTimestampForPool(poolID, uint(keyID))
|
||||
}
|
||||
return apiKey, group, nil
|
||||
go func() {
|
||||
bgCtx, cancel := r.withTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
r.refreshBasePoolHeartbeat(bgCtx, poolID)
|
||||
}()
|
||||
keyID, parseErr := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
if parseErr != nil {
|
||||
log.WithError(parseErr).Errorf("Invalid key ID format in BasePool cache: %s", keyIDStr)
|
||||
return nil, nil, fmt.Errorf("invalid key ID in cache: %w", parseErr)
|
||||
}
|
||||
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
|
||||
groupIDStr, err := r.store.HGet(ctx, keyToGroupMapKey, keyIDStr)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Cache inconsistency: KeyID %d found in pool but not in key-to-group map", keyID)
|
||||
return nil, nil, errors.New("cache inconsistency: key has no origin group mapping")
|
||||
}
|
||||
groupID, parseErr := strconv.ParseUint(groupIDStr, 10, 64)
|
||||
if parseErr != nil {
|
||||
log.WithError(parseErr).Errorf("Invalid group ID format in key-to-group map for key %d: %s", keyID, groupIDStr)
|
||||
return nil, nil, errors.New("cache inconsistency: invalid group id in mapping")
|
||||
}
|
||||
apiKey, _, err := r.getKeyDetailsFromCache(ctx, uint(keyID), uint(groupID))
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("Cache inconsistency: Failed to get details for key %d in mapped group %d", keyID, groupID)
|
||||
return nil, nil, err
|
||||
}
|
||||
var originGroup *models.KeyGroup
|
||||
for _, g := range pool.CandidateGroups {
|
||||
if g.ID == uint(groupID) {
|
||||
originGroup = g
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
log.Errorf("Cache inconsistency: Selected KeyID %d from BasePool but could not find its origin group.", keyID)
|
||||
return nil, nil, errors.New("cache inconsistency: selected key has no origin group")
|
||||
if originGroup == nil {
|
||||
log.Errorf("Logic error: Mapped GroupID %d not found in pool's candidate groups list", groupID)
|
||||
return nil, nil, errors.New("cache inconsistency: mapped group not in candidate list")
|
||||
}
|
||||
if pool.PollingStrategy == models.StrategyWeighted {
|
||||
go func() {
|
||||
bgCtx, cancel := r.withTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
r.updateKeyUsageTimestampForPool(bgCtx, poolID, uint(keyID))
|
||||
}()
|
||||
}
|
||||
return apiKey, originGroup, nil
|
||||
}
|
||||
|
||||
// ensureBasePoolCacheExists 动态创建 BasePool 的 Redis 结构
|
||||
func (r *gormKeyRepository) ensureBasePoolCacheExists(pool *BasePool, poolID string) error {
|
||||
listKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
|
||||
// --- [逻辑优化] 提前处理“毒丸”,让逻辑更清晰 ---
|
||||
exists, err := r.store.Exists(listKey)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to check existence for pool_id '%s'", poolID)
|
||||
return err // 直接返回读取错误
|
||||
// ensureBasePoolCacheExists 动态创建或验证 BasePool 的 Redis 缓存结构。
|
||||
func (r *gormKeyRepository) ensureBasePoolCacheExists(ctx context.Context, pool *BasePool, poolID string) error {
|
||||
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
|
||||
emptyMarkerKey := fmt.Sprintf("basepool:empty:%s", poolID)
|
||||
// 预检查,快速失败
|
||||
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
if exists {
|
||||
val, err := r.store.LIndex(listKey, 0)
|
||||
if err != nil {
|
||||
// 如果连 LIndex 都失败,说明缓存可能已损坏,允许重建
|
||||
r.logger.WithError(err).Warnf("Cache for pool_id '%s' exists but is unreadable. Forcing rebuild.", poolID)
|
||||
} else {
|
||||
if val == EmptyPoolPlaceholder {
|
||||
return gorm.ErrRecordNotFound // 已知为空,直接返回
|
||||
}
|
||||
return nil // 缓存有效,直接返回
|
||||
}
|
||||
}
|
||||
// --- [锁机制优化] 增加分布式锁,防止并发构建时的惊群效应 ---
|
||||
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
|
||||
acquired, err := r.store.SetNX(lockKey, []byte("1"), 10*time.Second) // 10秒锁超时
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock for basepool build.")
|
||||
return err
|
||||
}
|
||||
if !acquired {
|
||||
// 未获取到锁,等待一小段时间后重试,让持有锁的协程完成构建
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return r.ensureBasePoolCacheExists(pool, poolID)
|
||||
}
|
||||
defer r.store.Del(lockKey) // 确保在函数退出时释放锁
|
||||
// 双重检查,防止在获取锁的间隙,已有其他协程完成了构建
|
||||
if exists, _ := r.store.Exists(listKey); exists {
|
||||
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
|
||||
return nil
|
||||
}
|
||||
r.logger.Infof("BasePool cache for pool_id '%s' not found or is unreadable. Building now...", poolID)
|
||||
var allActiveKeyIDs []string
|
||||
lruMembers := make(map[string]float64)
|
||||
// 获取分布式锁
|
||||
lockKey := fmt.Sprintf("lock:basepool:%s", poolID)
|
||||
if err := r.acquireLock(ctx, lockKey); err != nil {
|
||||
return err // acquireLock 内部已记录日志并返回明确错误
|
||||
}
|
||||
defer r.releaseLock(context.Background(), lockKey)
|
||||
// 双重检查锁定
|
||||
if exists, _ := r.store.Exists(ctx, emptyMarkerKey); exists {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
if exists, _ := r.store.Exists(ctx, heartbeatKey); exists {
|
||||
return nil
|
||||
}
|
||||
// 在执行重度操作前,最后检查一次上下文是否已取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
r.logger.Infof("Building BasePool cache for pool_id '%s'", poolID)
|
||||
// 手动聚合所有 Keys 并同时构建 key-to-group 映射
|
||||
keyToGroupMap := make(map[string]any)
|
||||
allKeyIDsSet := make(map[string]struct{})
|
||||
for _, group := range pool.CandidateGroups {
|
||||
activeKeySetKey := fmt.Sprintf(KeyGroup, group.ID)
|
||||
groupKeyIDs, err := r.store.SMembers(activeKeySetKey)
|
||||
|
||||
// --- [核心修正] ---
|
||||
// 这是整个问题的根源。我们绝不能在读取失败时,默默地`continue`。
|
||||
// 任何读取源数据的失败,都必须被视为一次构建过程的彻底失败,并立即中止。
|
||||
groupKeySet := fmt.Sprintf(KeyGroup, group.ID)
|
||||
groupKeyIDs, err := r.store.SMembers(ctx, groupKeySet)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Errorf("FATAL: Failed to read active keys for group %d during BasePool build. Aborting build process for pool_id '%s'.", group.ID, poolID)
|
||||
// 返回这个瞬时错误。这会导致本次请求失败,但绝不会写入“毒丸”,
|
||||
// 从而给了下一次请求一个全新的、成功的机会。
|
||||
return err
|
||||
r.logger.WithError(err).Warnf("Failed to get members for group %d during pool build", group.ID)
|
||||
continue
|
||||
}
|
||||
// 只有在 SMembers 成功时,才继续处理
|
||||
allActiveKeyIDs = append(allActiveKeyIDs, groupKeyIDs...)
|
||||
for _, keyIDStr := range groupKeyIDs {
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
_, mapping, err := r.getKeyDetailsFromCache(uint(keyID), group.ID)
|
||||
if err == nil && mapping != nil {
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
lruMembers[keyIDStr] = score
|
||||
groupIDStr := strconv.FormatUint(uint64(group.ID), 10)
|
||||
for _, keyID := range groupKeyIDs {
|
||||
if _, exists := allKeyIDsSet[keyID]; !exists {
|
||||
allKeyIDsSet[keyID] = struct{}{}
|
||||
keyToGroupMap[keyID] = groupIDStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- [逻辑修正] ---
|
||||
// 只有在“我们成功读取了所有数据,但发现数据本身是空的”这种情况下,
|
||||
// 才允许写入“毒丸”。
|
||||
if len(allActiveKeyIDs) == 0 {
|
||||
r.logger.Warnf("No active keys found for any candidate groups for pool_id '%s'. Setting empty pool placeholder.", poolID)
|
||||
pipe := r.store.Pipeline()
|
||||
pipe.LPush(listKey, EmptyPoolPlaceholder)
|
||||
pipe.Expire(listKey, EmptyCacheTTL)
|
||||
if err := pipe.Exec(); err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to set empty pool placeholder for pool_id '%s'", poolID)
|
||||
// 处理空池情况
|
||||
if len(allKeyIDsSet) == 0 {
|
||||
emptyCacheTTL := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute / 2
|
||||
if emptyCacheTTL < time.Minute {
|
||||
emptyCacheTTL = time.Minute
|
||||
}
|
||||
r.logger.Warnf("No active keys found for pool_id '%s', setting empty marker.", poolID)
|
||||
if err := r.store.Set(ctx, emptyMarkerKey, []byte("1"), emptyCacheTTL); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to set empty marker for pool_id '%s'", poolID)
|
||||
}
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
// 使用管道填充所有轮询结构
|
||||
pipe := r.store.Pipeline()
|
||||
// 1. 顺序
|
||||
pipe.LPush(fmt.Sprintf(BasePoolSequential, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
// 2. 随机
|
||||
pipe.SAdd(fmt.Sprintf(BasePoolRandomMain, poolID), toInterfaceSlice(allActiveKeyIDs)...)
|
||||
|
||||
// 设置合理的过期时间,例如5分钟,以防止孤儿数据
|
||||
pipe.Expire(fmt.Sprintf(BasePoolSequential, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomMain, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolRandomCooldown, poolID), CacheTTL)
|
||||
pipe.Expire(fmt.Sprintf(BasePoolLRU, poolID), CacheTTL)
|
||||
|
||||
allActiveKeyIDs := make([]string, 0, len(allKeyIDsSet))
|
||||
for keyID := range allKeyIDsSet {
|
||||
allActiveKeyIDs = append(allActiveKeyIDs, keyID)
|
||||
}
|
||||
// 使用 Pipeline 原子化构建所有缓存结构
|
||||
basePoolTTL := time.Duration(r.config.Repository.BasePoolTTLMinutes) * time.Minute
|
||||
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
|
||||
mainPoolKey := fmt.Sprintf(BasePoolRandomMain, poolID)
|
||||
sequentialKey := fmt.Sprintf(BasePoolSequential, poolID)
|
||||
cooldownKey := fmt.Sprintf(BasePoolRandomCooldown, poolID)
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
keyToGroupMapKey := fmt.Sprintf("basepool:%s:key_to_group", poolID)
|
||||
pipe := r.store.Pipeline(ctx)
|
||||
pipe.Del(mainPoolKey, sequentialKey, cooldownKey, lruKey, emptyMarkerKey, keyToGroupMapKey)
|
||||
pipe.SAdd(mainPoolKey, r.toInterfaceSlice(allActiveKeyIDs)...)
|
||||
pipe.LPush(sequentialKey, r.toInterfaceSlice(allActiveKeyIDs)...)
|
||||
if len(keyToGroupMap) > 0 {
|
||||
pipe.HSet(keyToGroupMapKey, keyToGroupMap)
|
||||
pipe.Expire(keyToGroupMapKey, basePoolTTL)
|
||||
}
|
||||
pipe.Expire(mainPoolKey, basePoolTTL)
|
||||
pipe.Expire(sequentialKey, basePoolTTL)
|
||||
pipe.Expire(cooldownKey, basePoolTTL)
|
||||
pipe.Expire(lruKey, basePoolTTL)
|
||||
pipe.Set(heartbeatKey, []byte("1"), basePoolTTI)
|
||||
if err := pipe.Exec(); err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to populate polling structures for pool_id '%s'", poolID)
|
||||
cleanupCtx, cancel := r.withTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
r.store.Del(cleanupCtx, mainPoolKey, sequentialKey, cooldownKey, lruKey, heartbeatKey, emptyMarkerKey, keyToGroupMapKey)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(lruMembers) > 0 {
|
||||
r.store.ZAdd(fmt.Sprintf(BasePoolLRU, poolID), lruMembers)
|
||||
}
|
||||
// 异步填充 LRU 缓存,并传入已构建好的映射
|
||||
go r.populateBasePoolLRUCache(context.Background(), poolID, allActiveKeyIDs, keyToGroupMap)
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- 辅助方法 ---
|
||||
// acquireLock 封装了带重试和指数退避的分布式锁获取逻辑。
|
||||
func (r *gormKeyRepository) acquireLock(ctx context.Context, lockKey string) error {
|
||||
const (
|
||||
lockTTL = 30 * time.Second
|
||||
lockMaxRetries = 5
|
||||
lockBaseBackoff = 50 * time.Millisecond
|
||||
)
|
||||
for i := 0; i < lockMaxRetries; i++ {
|
||||
acquired, err := r.store.SetNX(ctx, lockKey, []byte("1"), lockTTL)
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Error("Failed to attempt acquiring distributed lock")
|
||||
return err
|
||||
}
|
||||
if acquired {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(lockBaseBackoff * (1 << i))
|
||||
}
|
||||
return fmt.Errorf("failed to acquire lock for key '%s' after %d retries", lockKey, lockMaxRetries)
|
||||
}
|
||||
|
||||
// releaseLock 封装了分布式锁的释放逻辑。
|
||||
func (r *gormKeyRepository) releaseLock(ctx context.Context, lockKey string) {
|
||||
if err := r.store.Del(ctx, lockKey); err != nil {
|
||||
r.logger.WithError(err).Errorf("Failed to release distributed lock for key '%s'", lockKey)
|
||||
}
|
||||
}
|
||||
|
||||
// withTimeout 是 context.WithTimeout 的一个简单包装,便于测试和模拟。
|
||||
func (r *gormKeyRepository) withTimeout(duration time.Duration) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), duration)
|
||||
}
|
||||
|
||||
// refreshBasePoolHeartbeat 异步刷新心跳Key的TTI
|
||||
func (r *gormKeyRepository) refreshBasePoolHeartbeat(ctx context.Context, poolID string) {
|
||||
basePoolTTI := time.Duration(r.config.Repository.BasePoolTTIMinutes) * time.Minute
|
||||
heartbeatKey := fmt.Sprintf("basepool:%s:heartbeat", poolID)
|
||||
// 使用 EXPIRE 命令来刷新,如果Key不存在,它什么也不做,是安全的
|
||||
if err := r.store.Expire(ctx, heartbeatKey, basePoolTTI); err != nil {
|
||||
if ctx.Err() == nil { // 避免在context取消后打印不必要的错误
|
||||
r.logger.WithError(err).Warnf("Failed to refresh heartbeat for pool_id '%s'", poolID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// populateBasePoolLRUCache 异步填充 BasePool 的 LRU 缓存结构
|
||||
func (r *gormKeyRepository) populateBasePoolLRUCache(
|
||||
parentCtx context.Context,
|
||||
currentPoolID string,
|
||||
keys []string,
|
||||
keyToGroupMap map[string]any,
|
||||
) {
|
||||
lruMembers := make(map[string]float64, len(keys))
|
||||
for _, keyIDStr := range keys {
|
||||
select {
|
||||
case <-parentCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
groupIDStr, ok := keyToGroupMap[keyIDStr].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
keyID, _ := strconv.ParseUint(keyIDStr, 10, 64)
|
||||
groupID, _ := strconv.ParseUint(groupIDStr, 10, 64)
|
||||
mappingKey := fmt.Sprintf(KeyMapping, groupID, keyID)
|
||||
data, err := r.store.Get(parentCtx, mappingKey)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var mapping models.GroupAPIKeyMapping
|
||||
if json.Unmarshal(data, &mapping) == nil {
|
||||
var score float64
|
||||
if mapping.LastUsedAt != nil {
|
||||
score = float64(mapping.LastUsedAt.UnixMilli())
|
||||
}
|
||||
lruMembers[keyIDStr] = score
|
||||
}
|
||||
}
|
||||
if len(lruMembers) > 0 {
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, currentPoolID)
|
||||
if err := r.store.ZAdd(parentCtx, lruKey, lruMembers); err != nil {
|
||||
if parentCtx.Err() == nil {
|
||||
r.logger.WithError(err).Warnf("Failed to populate LRU cache for pool '%s'", currentPoolID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateKeyUsageTimestampForPool 更新 BasePool 的 LUR ZSET
|
||||
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(poolID string, keyID uint) {
|
||||
func (r *gormKeyRepository) updateKeyUsageTimestampForPool(ctx context.Context, poolID string, keyID uint) {
|
||||
lruKey := fmt.Sprintf(BasePoolLRU, poolID)
|
||||
r.store.ZAdd(lruKey, map[string]float64{
|
||||
strconv.FormatUint(uint64(keyID), 10): nowMilli(),
|
||||
err := r.store.ZAdd(ctx, lruKey, map[string]float64{
|
||||
strconv.FormatUint(uint64(keyID), 10): r.nowMilli(),
|
||||
})
|
||||
if err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to update key usage for pool %s", poolID)
|
||||
}
|
||||
}
|
||||
|
||||
// generatePoolID 根据候选组ID列表生成一个稳定的、唯一的字符串ID
|
||||
func generatePoolID(groups []*models.KeyGroup) string {
|
||||
func (r *gormKeyRepository) generatePoolID(groups []*models.KeyGroup) string {
|
||||
ids := make([]int, len(groups))
|
||||
for i, g := range groups {
|
||||
ids[i] = int(g.ID)
|
||||
}
|
||||
sort.Ints(ids)
|
||||
|
||||
h := sha1.New()
|
||||
io.WriteString(h, fmt.Sprintf("%v", ids))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// toInterfaceSlice 类型转换辅助函数
|
||||
func toInterfaceSlice(slice []string) []interface{} {
|
||||
func (r *gormKeyRepository) toInterfaceSlice(slice []string) []interface{} {
|
||||
result := make([]interface{}, len(slice))
|
||||
for i, v := range slice {
|
||||
result[i] = v
|
||||
@@ -280,13 +399,13 @@ func toInterfaceSlice(slice []string) []interface{} {
|
||||
}
|
||||
|
||||
// nowMilli 返回当前的Unix毫秒时间戳,用于LRU/Weighted策略
|
||||
func nowMilli() float64 {
|
||||
func (r *gormKeyRepository) nowMilli() float64 {
|
||||
return float64(time.Now().UnixMilli())
|
||||
}
|
||||
|
||||
// getKeyDetailsFromCache 从缓存中获取Key和Mapping的JSON数据。
|
||||
func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
apiKeyJSON, err := r.store.Get(fmt.Sprintf(KeyDetails, keyID))
|
||||
func (r *gormKeyRepository) getKeyDetailsFromCache(ctx context.Context, keyID, groupID uint) (*models.APIKey, *models.GroupAPIKeyMapping, error) {
|
||||
apiKeyJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyDetails, keyID))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get key details for key %d: %w", keyID, err)
|
||||
}
|
||||
@@ -295,7 +414,7 @@ func (r *gormKeyRepository) getKeyDetailsFromCache(keyID, groupID uint) (*models
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal api key %d: %w", keyID, err)
|
||||
}
|
||||
|
||||
mappingJSON, err := r.store.Get(fmt.Sprintf(KeyMapping, groupID, keyID))
|
||||
mappingJSON, err := r.store.Get(ctx, fmt.Sprintf(KeyMapping, groupID, keyID))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get mapping details for key %d in group %d: %w", keyID, groupID, err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
// Filename: internal/repository/key_writer.go
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -9,7 +11,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
|
||||
func (r *gormKeyRepository) UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint) {
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
|
||||
timestamp := float64(time.Now().UnixMilli())
|
||||
|
||||
@@ -17,52 +19,51 @@ func (r *gormKeyRepository) UpdateKeyUsageTimestamp(groupID, keyID uint) {
|
||||
strconv.FormatUint(uint64(keyID), 10): timestamp,
|
||||
}
|
||||
|
||||
if err := r.store.ZAdd(lruKey, members); err != nil {
|
||||
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to update usage timestamp for key %d in group %d", keyID, groupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
func (r *gormKeyRepository) SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
r.logger.Infof("SYNC: Directly updating polling caches for G:%d K:%d -> %s", groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
func (r *gormKeyRepository) HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
r.logger.Infof("EVENT: Updating polling caches for G:%d K:%d -> %s from an event", groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(groupID, keyID, newStatus)
|
||||
r.updatePollingCachesLogic(ctx, groupID, keyID, newStatus)
|
||||
}
|
||||
|
||||
func (r *gormKeyRepository) updatePollingCachesLogic(groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
func (r *gormKeyRepository) updatePollingCachesLogic(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus) {
|
||||
keyIDStr := strconv.FormatUint(uint64(keyID), 10)
|
||||
sequentialKey := fmt.Sprintf(KeyGroupSequential, groupID)
|
||||
lruKey := fmt.Sprintf(KeyGroupLRU, groupID)
|
||||
mainPoolKey := fmt.Sprintf(KeyGroupRandomMain, groupID)
|
||||
cooldownPoolKey := fmt.Sprintf(KeyGroupRandomCooldown, groupID)
|
||||
|
||||
_ = r.store.LRem(sequentialKey, 0, keyIDStr)
|
||||
_ = r.store.ZRem(lruKey, keyIDStr)
|
||||
_ = r.store.SRem(mainPoolKey, keyIDStr)
|
||||
_ = r.store.SRem(cooldownPoolKey, keyIDStr)
|
||||
_ = r.store.LRem(ctx, sequentialKey, 0, keyIDStr)
|
||||
_ = r.store.ZRem(ctx, lruKey, keyIDStr)
|
||||
_ = r.store.SRem(ctx, mainPoolKey, keyIDStr)
|
||||
_ = r.store.SRem(ctx, cooldownPoolKey, keyIDStr)
|
||||
|
||||
if newStatus == models.StatusActive {
|
||||
if err := r.store.LPush(sequentialKey, keyIDStr); err != nil {
|
||||
if err := r.store.LPush(ctx, sequentialKey, keyIDStr); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to sequential list for group %d", keyID, groupID)
|
||||
}
|
||||
members := map[string]float64{keyIDStr: 0}
|
||||
if err := r.store.ZAdd(lruKey, members); err != nil {
|
||||
if err := r.store.ZAdd(ctx, lruKey, members); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to LRU zset for group %d", keyID, groupID)
|
||||
}
|
||||
if err := r.store.SAdd(mainPoolKey, keyIDStr); err != nil {
|
||||
if err := r.store.SAdd(ctx, mainPoolKey, keyIDStr); err != nil {
|
||||
r.logger.WithError(err).Warnf("Failed to add key %d to random main pool for group %d", keyID, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateKeyStatusAfterRequest is the new central hub for handling feedback.
|
||||
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
|
||||
func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError) {
|
||||
if success {
|
||||
if group.PollingStrategy == models.StrategyWeighted {
|
||||
go r.UpdateKeyUsageTimestamp(group.ID, key.ID)
|
||||
go r.UpdateKeyUsageTimestamp(context.Background(), group.ID, key.ID)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -72,6 +73,5 @@ func (r *gormKeyRepository) UpdateKeyStatusAfterRequest(group *models.KeyGroup,
|
||||
}
|
||||
r.logger.Warnf("Request failed for KeyID %d in GroupID %d with error: %s. Temporarily removing from active polling caches.", key.ID, group.ID, apiErr.Message)
|
||||
|
||||
// This call is correct. It uses the synchronous, direct method.
|
||||
r.SyncKeyStatusInPollingCaches(group.ID, key.ID, models.StatusCooldown)
|
||||
r.SyncKeyStatusInPollingCaches(ctx, group.ID, key.ID, models.StatusCooldown)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gemini-balancer/internal/config"
|
||||
"gemini-balancer/internal/crypto"
|
||||
"gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -22,8 +24,8 @@ type BasePool struct {
|
||||
|
||||
type KeyRepository interface {
|
||||
// --- 核心选取与调度 --- key_selector
|
||||
SelectOneActiveKey(group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
|
||||
SelectOneActiveKeyFromBasePool(pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
|
||||
SelectOneActiveKey(ctx context.Context, group *models.KeyGroup) (*models.APIKey, *models.GroupAPIKeyMapping, error)
|
||||
SelectOneActiveKeyFromBasePool(ctx context.Context, pool *BasePool) (*models.APIKey, *models.KeyGroup, error)
|
||||
|
||||
// --- 加密与解密 --- key_crud
|
||||
Decrypt(key *models.APIKey) error
|
||||
@@ -37,16 +39,16 @@ type KeyRepository interface {
|
||||
GetKeyByID(id uint) (*models.APIKey, error)
|
||||
GetKeyByValue(keyValue string) (*models.APIKey, error)
|
||||
GetKeysByValues(keyValues []string) ([]models.APIKey, error)
|
||||
GetKeysByIDs(ids []uint) ([]models.APIKey, error) // [新增] 根据一组主键ID批量获取Key
|
||||
GetKeysByIDs(ids []uint) ([]models.APIKey, error)
|
||||
GetKeysByGroup(groupID uint) ([]models.APIKey, error)
|
||||
CountByGroup(groupID uint) (int64, error)
|
||||
|
||||
// --- 多对多关系管理 --- key_mapping
|
||||
LinkKeysToGroup(groupID uint, keyIDs []uint) error
|
||||
UnlinkKeysFromGroup(groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
|
||||
GetGroupsForKey(keyID uint) ([]uint, error)
|
||||
LinkKeysToGroup(ctx context.Context, groupID uint, keyIDs []uint) error
|
||||
UnlinkKeysFromGroup(ctx context.Context, groupID uint, keyIDs []uint) (unlinkedCount int64, err error)
|
||||
GetGroupsForKey(ctx context.Context, keyID uint) ([]uint, error)
|
||||
GetMapping(groupID, keyID uint) (*models.GroupAPIKeyMapping, error)
|
||||
UpdateMapping(mapping *models.GroupAPIKeyMapping) error
|
||||
UpdateMapping(ctx context.Context, mapping *models.GroupAPIKeyMapping) error
|
||||
GetPaginatedKeysAndMappingsByGroup(params *models.APIKeyQueryParams) ([]*models.APIKeyDetails, int64, error)
|
||||
GetKeysByValuesAndGroupID(values []string, groupID uint) ([]models.APIKey, error)
|
||||
FindKeyValuesByStatus(groupID uint, statuses []string) ([]string, error)
|
||||
@@ -55,8 +57,8 @@ type KeyRepository interface {
|
||||
UpdateMappingWithoutCache(mapping *models.GroupAPIKeyMapping) error
|
||||
|
||||
// --- 缓存管理 --- key_cache
|
||||
LoadAllKeysToStore() error
|
||||
HandleCacheUpdateEventBatch(mappings []*models.GroupAPIKeyMapping) error
|
||||
LoadAllKeysToStore(ctx context.Context) error
|
||||
HandleCacheUpdateEventBatch(ctx context.Context, mappings []*models.GroupAPIKeyMapping) error
|
||||
|
||||
// --- 维护与后台任务 --- key_maintenance
|
||||
StreamKeysToWriter(groupID uint, statusFilter string, writer io.Writer) error
|
||||
@@ -65,16 +67,14 @@ type KeyRepository interface {
|
||||
DeleteOrphanKeys() (int64, error)
|
||||
DeleteOrphanKeysTx(tx *gorm.DB) (int64, error)
|
||||
GetActiveMasterKeys() ([]*models.APIKey, error)
|
||||
UpdateAPIKeyStatus(keyID uint, status models.MasterAPIKeyStatus) error
|
||||
UpdateAPIKeyStatus(ctx context.Context, keyID uint, status models.MasterAPIKeyStatus) error
|
||||
HardDeleteSoftDeletedBefore(date time.Time) (int64, error)
|
||||
|
||||
// --- 轮询策略的"写"操作 --- key_writer
|
||||
UpdateKeyUsageTimestamp(groupID, keyID uint)
|
||||
// 同步更新缓存,供核心业务使用
|
||||
SyncKeyStatusInPollingCaches(groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
// 异步更新缓存,供事件订阅者使用
|
||||
HandleCacheUpdateEvent(groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
UpdateKeyStatusAfterRequest(group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
|
||||
UpdateKeyUsageTimestamp(ctx context.Context, groupID, keyID uint)
|
||||
SyncKeyStatusInPollingCaches(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
HandleCacheUpdateEvent(ctx context.Context, groupID, keyID uint, newStatus models.APIKeyStatus)
|
||||
UpdateKeyStatusAfterRequest(ctx context.Context, group *models.KeyGroup, key *models.APIKey, success bool, apiErr *errors.APIError)
|
||||
}
|
||||
|
||||
type GroupRepository interface {
|
||||
@@ -88,18 +88,20 @@ type gormKeyRepository struct {
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
crypto *crypto.Service
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
type gormGroupRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service) KeyRepository {
|
||||
func NewKeyRepository(db *gorm.DB, s store.Store, logger *logrus.Logger, crypto *crypto.Service, cfg *config.Config) KeyRepository {
|
||||
return &gormKeyRepository{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "repository.key🔗"),
|
||||
crypto: crypto,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewRouter(
|
||||
@@ -35,6 +36,7 @@ func NewRouter(
|
||||
settingHandler *handlers.SettingHandler,
|
||||
dashboardHandler *handlers.DashboardHandler,
|
||||
taskHandler *handlers.TaskHandler,
|
||||
wsHandler *handlers.WebSocketHandler,
|
||||
// Web Page Handlers
|
||||
webAuthHandler *webhandlers.WebAuthHandler,
|
||||
pageHandler *webhandlers.PageHandler,
|
||||
@@ -42,70 +44,215 @@ func NewRouter(
|
||||
upstreamModule *upstream.Module,
|
||||
proxyModule *proxy.Module,
|
||||
) *gin.Engine {
|
||||
// === 1. 创建全局 Logger(统一管理)===
|
||||
logger := createLogger(cfg)
|
||||
|
||||
// === 2. 设置 Gin 运行模式 ===
|
||||
if cfg.Log.Level != "debug" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
router := gin.Default()
|
||||
|
||||
router.Static("/static", "./web/static")
|
||||
// CORS 配置
|
||||
config := cors.Config{
|
||||
// 允许前端的来源。在生产环境中,需改为实际域名
|
||||
AllowOrigins: []string{"http://localhost:9000"},
|
||||
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization"},
|
||||
ExposeHeaders: []string{"Content-Length"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 12 * time.Hour,
|
||||
}
|
||||
router.Use(cors.New(config))
|
||||
isDebug := gin.Mode() != gin.ReleaseMode
|
||||
router.HTMLRender = pongo.New("web/templates", isDebug)
|
||||
// === 3. 创建 Router(使用 gin.New() 以便完全控制中间件)===
|
||||
router := gin.New()
|
||||
|
||||
// --- 基础设施 ---
|
||||
router.GET("/", func(c *gin.Context) { c.Redirect(http.StatusMovedPermanently, "/dashboard") })
|
||||
router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) })
|
||||
// --- 统一的认证管道 ---
|
||||
apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService)
|
||||
// === 4. 注册全局中间件(按执行顺序)===
|
||||
setupGlobalMiddleware(router, logger)
|
||||
|
||||
// === 5. 配置静态文件和模板 ===
|
||||
setupStaticAndTemplates(router, logger)
|
||||
|
||||
// === 6. 配置 CORS ===
|
||||
setupCORS(router, cfg)
|
||||
|
||||
// === 7. 注册基础路由 ===
|
||||
setupBasicRoutes(router)
|
||||
|
||||
// === 8. 创建认证中间件 ===
|
||||
apiAdminAuth := middleware.APIAdminAuthMiddleware(securityService, logger)
|
||||
webAdminAuth := middleware.WebAdminAuthMiddleware(securityService)
|
||||
|
||||
router.Use(gin.RecoveryWithWriter(os.Stdout))
|
||||
// --- 将正确的依赖和中间件管道传递下去 ---
|
||||
registerProxyRoutes(router, proxyHandler, securityService)
|
||||
registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler, logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
|
||||
registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager)
|
||||
// === 9. 注册业务路由(按功能分组)===
|
||||
registerPublicAPIRoutes(router, apiAuthHandler, securityService, settingsManager, logger)
|
||||
registerWebRoutes(router, webAdminAuth, webAuthHandler, pageHandler)
|
||||
registerAdminRoutes(router, apiAdminAuth, keyGroupHandler, tokensHandler, apiKeyHandler,
|
||||
logHandler, settingHandler, dashboardHandler, taskHandler, upstreamModule, proxyModule)
|
||||
registerWebSocketRoutes(router, wsHandler)
|
||||
registerProxyRoutes(router, proxyHandler, securityService, logger)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
// createLogger 创建并配置全局 Logger
|
||||
func createLogger(cfg *config.Config) *logrus.Logger {
|
||||
logger := logrus.New()
|
||||
|
||||
// 设置日志格式
|
||||
if cfg.Log.Format == "json" {
|
||||
logger.SetFormatter(&logrus.JSONFormatter{
|
||||
TimestampFormat: time.RFC3339,
|
||||
})
|
||||
} else {
|
||||
logger.SetFormatter(&logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
TimestampFormat: "2006-01-02 15:04:05",
|
||||
})
|
||||
}
|
||||
|
||||
// 设置日志级别
|
||||
switch cfg.Log.Level {
|
||||
case "debug":
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
case "info":
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
case "warn":
|
||||
logger.SetLevel(logrus.WarnLevel)
|
||||
case "error":
|
||||
logger.SetLevel(logrus.ErrorLevel)
|
||||
default:
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
}
|
||||
|
||||
// 设置输出(可选:输出到文件)
|
||||
logger.SetOutput(os.Stdout)
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
// setupGlobalMiddleware 设置全局中间件
|
||||
func setupGlobalMiddleware(router *gin.Engine, logger *logrus.Logger) {
|
||||
// 1. 请求 ID 中间件(用于链路追踪)
|
||||
router.Use(middleware.RequestIDMiddleware())
|
||||
|
||||
// 2. 数据脱敏中间件(在日志前执行)
|
||||
router.Use(middleware.RedactionMiddleware())
|
||||
|
||||
// 3. 日志中间件
|
||||
router.Use(middleware.LogrusLogger(logger))
|
||||
|
||||
// 4. 错误恢复中间件
|
||||
router.Use(gin.RecoveryWithWriter(os.Stdout))
|
||||
}
|
||||
|
||||
// setupStaticAndTemplates 配置静态文件和模板
|
||||
func setupStaticAndTemplates(router *gin.Engine, logger *logrus.Logger) {
|
||||
router.Static("/static", "./web/static")
|
||||
|
||||
isDebug := gin.Mode() != gin.ReleaseMode
|
||||
router.HTMLRender = pongo.New("web/templates", isDebug, logger)
|
||||
}
|
||||
|
||||
// setupCORS 配置 CORS
|
||||
func setupCORS(router *gin.Engine, cfg *config.Config) {
|
||||
corsConfig := cors.Config{
|
||||
AllowOrigins: getCORSOrigins(cfg),
|
||||
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Request-Id"},
|
||||
ExposeHeaders: []string{"Content-Length", "X-Request-Id"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 12 * time.Hour,
|
||||
}
|
||||
router.Use(cors.New(corsConfig))
|
||||
}
|
||||
|
||||
// getCORSOrigins 获取 CORS 允许的来源
|
||||
func getCORSOrigins(cfg *config.Config) []string {
|
||||
// 默认值
|
||||
origins := []string{"http://localhost:9000"}
|
||||
|
||||
// 从配置读取(修复:移除 nil 检查)
|
||||
if len(cfg.Server.CORSOrigins) > 0 {
|
||||
origins = cfg.Server.CORSOrigins
|
||||
}
|
||||
|
||||
return origins
|
||||
}
|
||||
|
||||
// setupBasicRoutes 设置基础路由
|
||||
func setupBasicRoutes(router *gin.Engine) {
|
||||
// 根路径重定向
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusMovedPermanently, "/dashboard")
|
||||
})
|
||||
|
||||
// 健康检查
|
||||
router.GET("/health", handleHealthCheck)
|
||||
|
||||
// 版本信息(可选)
|
||||
router.GET("/version", handleVersion)
|
||||
}
|
||||
|
||||
// handleHealthCheck 健康检查处理器
|
||||
func handleHealthCheck(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"time": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleVersion 版本信息处理器
|
||||
func handleVersion(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"version": "1.0.0", // 可以从配置或编译时变量读取
|
||||
"build": "latest",
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 路由注册函数 ====================
|
||||
|
||||
// registerProxyRoutes 注册代理路由
|
||||
func registerProxyRoutes(
|
||||
router *gin.Engine, proxyHandler *handlers.ProxyHandler, securityService *service.SecurityService,
|
||||
router *gin.Engine,
|
||||
proxyHandler *handlers.ProxyHandler,
|
||||
securityService *service.SecurityService,
|
||||
logger *logrus.Logger,
|
||||
) {
|
||||
// 通用的代理认证中间件
|
||||
proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService)
|
||||
// --- 模式一: 智能聚合模式 (根路径) ---
|
||||
// /v1 和 /v1beta 路径作为默认入口,服务于 BasePool 聚合逻辑
|
||||
// 创建代理认证中间件
|
||||
proxyAuthMiddleware := middleware.ProxyAuthMiddleware(securityService, logger)
|
||||
|
||||
// 模式一: 智能聚合模式(默认入口)
|
||||
registerAggregateProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
|
||||
|
||||
// 模式二: 精确路由模式(按组名路由)
|
||||
registerGroupProxyRoutes(router, proxyHandler, proxyAuthMiddleware)
|
||||
}
|
||||
|
||||
// registerAggregateProxyRoutes 注册聚合代理路由
|
||||
func registerAggregateProxyRoutes(
|
||||
router *gin.Engine,
|
||||
proxyHandler *handlers.ProxyHandler,
|
||||
authMiddleware gin.HandlerFunc,
|
||||
) {
|
||||
// /v1 路径组
|
||||
v1 := router.Group("/v1")
|
||||
v1.Use(proxyAuthMiddleware)
|
||||
v1.Use(authMiddleware)
|
||||
{
|
||||
v1.Any("/*path", proxyHandler.HandleProxy)
|
||||
}
|
||||
|
||||
// /v1beta 路径组
|
||||
v1beta := router.Group("/v1beta")
|
||||
v1beta.Use(proxyAuthMiddleware)
|
||||
v1beta.Use(authMiddleware)
|
||||
{
|
||||
v1beta.Any("/*path", proxyHandler.HandleProxy)
|
||||
}
|
||||
// --- 模式二: 精确路由模式 (/proxy/:group_name) ---
|
||||
// 创建一个新的、物理隔离的路由组,用于按组名精确路由
|
||||
}
|
||||
|
||||
// registerGroupProxyRoutes 注册分组代理路由
|
||||
func registerGroupProxyRoutes(
|
||||
router *gin.Engine,
|
||||
proxyHandler *handlers.ProxyHandler,
|
||||
authMiddleware gin.HandlerFunc,
|
||||
) {
|
||||
proxyGroup := router.Group("/proxy/:group_name")
|
||||
proxyGroup.Use(proxyAuthMiddleware)
|
||||
proxyGroup.Use(authMiddleware)
|
||||
{
|
||||
// 捕获所有子路径 (例如 /v1/chat/completions),并全部交给同一个 ProxyHandler。
|
||||
proxyGroup.Any("/*path", proxyHandler.HandleProxy)
|
||||
}
|
||||
}
|
||||
|
||||
// registerAdminRoutes
|
||||
// registerAdminRoutes 注册管理后台 API 路由
|
||||
func registerAdminRoutes(
|
||||
router *gin.Engine,
|
||||
authMiddleware gin.HandlerFunc,
|
||||
@@ -121,74 +268,115 @@ func registerAdminRoutes(
|
||||
) {
|
||||
admin := router.Group("/admin", authMiddleware)
|
||||
{
|
||||
// --- KeyGroup Base Routes ---
|
||||
admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
|
||||
admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
|
||||
admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
|
||||
// --- KeyGroup Specific Routes (by :id) ---
|
||||
admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
|
||||
admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
|
||||
admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
|
||||
admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
|
||||
admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
|
||||
admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
|
||||
// --- APIKey Sub-resource Routes under a KeyGroup ---
|
||||
keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
|
||||
{
|
||||
keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
|
||||
keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
|
||||
keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
|
||||
keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
|
||||
keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
|
||||
keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
|
||||
}
|
||||
// KeyGroup 路由
|
||||
registerKeyGroupRoutes(admin, keyGroupHandler, apiKeyHandler)
|
||||
|
||||
// Global key operations
|
||||
admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
|
||||
// admin.PUT("/apikeys/:id", apiKeyHandler.UpdateAPIKey) // DEPRECATED: Status is now contextual
|
||||
admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys) // Test keys globally
|
||||
admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey) // Hard delete a single key
|
||||
admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys) // Hard delete multiple keys
|
||||
admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys) // Restore multiple keys globally
|
||||
// APIKey 全局路由
|
||||
registerAPIKeyRoutes(admin, apiKeyHandler)
|
||||
|
||||
// --- Global Routes ---
|
||||
admin.GET("/tokens", tokensHandler.GetAllTokens)
|
||||
admin.PUT("/tokens", tokensHandler.UpdateTokens)
|
||||
admin.GET("/logs", logHandler.GetLogs)
|
||||
admin.GET("/settings", settingHandler.GetSettings)
|
||||
admin.PUT("/settings", settingHandler.UpdateSettings)
|
||||
admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
|
||||
// 系统管理路由
|
||||
registerSystemRoutes(admin, tokensHandler, logHandler, settingHandler, taskHandler)
|
||||
|
||||
// 用于查询异步任务的状态
|
||||
admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
|
||||
// 仪表盘路由
|
||||
registerDashboardRoutes(admin, dashboardHandler)
|
||||
|
||||
// 领域模块
|
||||
// 领域模块路由
|
||||
upstreamModule.RegisterRoutes(admin)
|
||||
proxyModule.RegisterRoutes(admin)
|
||||
// --- 全局仪表盘路由 ---
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/overview", dashboardHandler.GetOverview)
|
||||
dashboard.GET("/chart", dashboardHandler.GetChart)
|
||||
dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats) // 点击详情
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerWebRoutes
|
||||
// registerKeyGroupRoutes 注册 KeyGroup 相关路由
|
||||
func registerKeyGroupRoutes(
|
||||
admin *gin.RouterGroup,
|
||||
keyGroupHandler *handlers.KeyGroupHandler,
|
||||
apiKeyHandler *handlers.APIKeyHandler,
|
||||
) {
|
||||
// 基础路由
|
||||
admin.POST("/keygroups", keyGroupHandler.CreateKeyGroup)
|
||||
admin.GET("/keygroups", keyGroupHandler.GetKeyGroups)
|
||||
admin.PUT("/keygroups/order", keyGroupHandler.UpdateKeyGroupOrder)
|
||||
|
||||
// 特定 KeyGroup 路由
|
||||
admin.GET("/keygroups/:id", keyGroupHandler.GetKeyGroups)
|
||||
admin.PUT("/keygroups/:id", keyGroupHandler.UpdateKeyGroup)
|
||||
admin.DELETE("/keygroups/:id", keyGroupHandler.DeleteKeyGroup)
|
||||
admin.POST("/keygroups/:id/clone", keyGroupHandler.CloneKeyGroup)
|
||||
admin.GET("/keygroups/:id/stats", keyGroupHandler.GetKeyGroupStats)
|
||||
admin.POST("/keygroups/:id/bulk-actions", apiKeyHandler.HandleBulkAction)
|
||||
|
||||
// KeyGroup 的 APIKey 子资源
|
||||
keyGroupAPIKeys := admin.Group("/keygroups/:id/apikeys")
|
||||
{
|
||||
keyGroupAPIKeys.GET("", apiKeyHandler.ListKeysForGroup)
|
||||
keyGroupAPIKeys.GET("/export", apiKeyHandler.ExportKeysForGroup)
|
||||
keyGroupAPIKeys.POST("/bulk", apiKeyHandler.AddMultipleKeysToGroup)
|
||||
keyGroupAPIKeys.DELETE("/bulk", apiKeyHandler.UnlinkMultipleKeysFromGroup)
|
||||
keyGroupAPIKeys.POST("/test", apiKeyHandler.TestKeysForGroup)
|
||||
keyGroupAPIKeys.PUT("/:keyId", apiKeyHandler.UpdateGroupAPIKeyMapping)
|
||||
}
|
||||
}
|
||||
|
||||
// registerAPIKeyRoutes 注册 APIKey 全局路由
|
||||
func registerAPIKeyRoutes(admin *gin.RouterGroup, apiKeyHandler *handlers.APIKeyHandler) {
|
||||
admin.GET("/apikeys", apiKeyHandler.ListAPIKeys)
|
||||
admin.POST("/apikeys/test", apiKeyHandler.TestMultipleKeys)
|
||||
admin.DELETE("/apikeys/:id", apiKeyHandler.HardDeleteAPIKey)
|
||||
admin.DELETE("/apikeys/bulk", apiKeyHandler.HardDeleteMultipleKeys)
|
||||
admin.PUT("/apikeys/bulk/restore", apiKeyHandler.RestoreMultipleKeys)
|
||||
}
|
||||
|
||||
// registerSystemRoutes 注册系统管理路由
|
||||
func registerSystemRoutes(
|
||||
admin *gin.RouterGroup,
|
||||
tokensHandler *handlers.TokensHandler,
|
||||
logHandler *handlers.LogHandler,
|
||||
settingHandler *handlers.SettingHandler,
|
||||
taskHandler *handlers.TaskHandler,
|
||||
) {
|
||||
// Token 管理
|
||||
admin.GET("/tokens", tokensHandler.GetAllTokens)
|
||||
admin.PUT("/tokens", tokensHandler.UpdateTokens)
|
||||
|
||||
// 日志管理
|
||||
admin.GET("/logs", logHandler.GetLogs)
|
||||
admin.DELETE("/logs", logHandler.DeleteLogs) // 删除选定
|
||||
admin.DELETE("/logs/all", logHandler.DeleteAllLogs) // 删除全部
|
||||
admin.DELETE("/logs/old", logHandler.DeleteOldLogs) // 删除旧日志
|
||||
|
||||
// 设置管理
|
||||
admin.GET("/settings", settingHandler.GetSettings)
|
||||
admin.PUT("/settings", settingHandler.UpdateSettings)
|
||||
admin.PUT("/settings/reset", settingHandler.ResetSettingsToDefaults)
|
||||
|
||||
// 任务管理
|
||||
admin.GET("/tasks/:id", taskHandler.GetTaskStatus)
|
||||
}
|
||||
|
||||
// registerDashboardRoutes 注册仪表盘路由
|
||||
func registerDashboardRoutes(admin *gin.RouterGroup, dashboardHandler *handlers.DashboardHandler) {
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/overview", dashboardHandler.GetOverview)
|
||||
dashboard.GET("/chart", dashboardHandler.GetChart)
|
||||
dashboard.GET("/stats/:period", dashboardHandler.GetRequestStats)
|
||||
}
|
||||
}
|
||||
|
||||
// registerWebRoutes 注册 Web 页面路由
|
||||
func registerWebRoutes(
|
||||
router *gin.Engine,
|
||||
authMiddleware gin.HandlerFunc,
|
||||
webAuthHandler *webhandlers.WebAuthHandler,
|
||||
pageHandler *webhandlers.PageHandler,
|
||||
) {
|
||||
// 公开的认证路由
|
||||
router.GET("/login", webAuthHandler.ShowLoginPage)
|
||||
router.POST("/login", webAuthHandler.HandleLogin)
|
||||
router.GET("/logout", webAuthHandler.HandleLogout)
|
||||
// For Test only router.Run("127.0.0.1:9000")
|
||||
// 受保护的Admin Web界面
|
||||
|
||||
// 受保护的管理界面
|
||||
webGroup := router.Group("/", authMiddleware)
|
||||
webGroup.Use(authMiddleware)
|
||||
{
|
||||
webGroup.GET("/keys", pageHandler.ShowKeysPage)
|
||||
webGroup.GET("/settings", pageHandler.ShowConfigEditorPage)
|
||||
@@ -197,14 +385,35 @@ func registerWebRoutes(
|
||||
webGroup.GET("/tasks", pageHandler.ShowTasksPage)
|
||||
webGroup.GET("/chat", pageHandler.ShowChatPage)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// registerPublicAPIRoutes 无需后台登录的公共API路由
|
||||
func registerPublicAPIRoutes(router *gin.Engine, apiAuthHandler *handlers.APIAuthHandler, securityService *service.SecurityService, settingsManager *settings.SettingsManager) {
|
||||
ipBanMiddleware := middleware.IPBanMiddleware(securityService, settingsManager)
|
||||
publicAPIGroup := router.Group("/api")
|
||||
// registerPublicAPIRoutes 注册公共 API 路由
|
||||
func registerPublicAPIRoutes(
|
||||
router *gin.Engine,
|
||||
apiAuthHandler *handlers.APIAuthHandler,
|
||||
securityService *service.SecurityService,
|
||||
settingsManager *settings.SettingsManager,
|
||||
logger *logrus.Logger,
|
||||
) {
|
||||
// 创建 IP 封禁中间件
|
||||
ipBanCache := middleware.NewIPBanCache()
|
||||
ipBanMiddleware := middleware.IPBanMiddleware(
|
||||
securityService,
|
||||
settingsManager,
|
||||
ipBanCache,
|
||||
logger,
|
||||
)
|
||||
|
||||
// 公共 API 路由组
|
||||
publicAPI := router.Group("/api")
|
||||
{
|
||||
publicAPIGroup.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
|
||||
publicAPI.POST("/login", ipBanMiddleware, apiAuthHandler.HandleLogin)
|
||||
// 可以在这里添加其他公共 API 路由
|
||||
// publicAPI.POST("/register", ipBanMiddleware, apiAuthHandler.HandleRegister)
|
||||
// publicAPI.POST("/forgot-password", ipBanMiddleware, apiAuthHandler.HandleForgotPassword)
|
||||
}
|
||||
}
|
||||
|
||||
func registerWebSocketRoutes(router *gin.Engine, wsHandler *handlers.WebSocketHandler) {
|
||||
router.GET("/ws/system-logs", wsHandler.HandleSystemLogs)
|
||||
}
|
||||
|
||||
@@ -1,42 +1,68 @@
|
||||
// Filename: internal/scheduler/scheduler.go
|
||||
// [REVISED] - 用这个更智能的版本完整替换
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt" // [NEW] 导入 fmt
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/service"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"strconv" // [NEW] 导入 strconv
|
||||
"strings" // [NEW] 导入 strings
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-co-op/gocron"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ... (Scheduler struct 和 NewScheduler 保持不变) ...
|
||||
const LogCleanupTaskTag = "log-cleanup-task"
|
||||
|
||||
type Scheduler struct {
|
||||
gocronScheduler *gocron.Scheduler
|
||||
logger *logrus.Entry
|
||||
statsService *service.StatsService
|
||||
logService *service.LogService
|
||||
settingsManager *settings.SettingsManager
|
||||
keyRepo repository.KeyRepository
|
||||
// healthCheckService *service.HealthCheckService // 健康检查任务预留
|
||||
store store.Store
|
||||
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewScheduler(statsSvc *service.StatsService, keyRepo repository.KeyRepository, logger *logrus.Logger) *Scheduler {
|
||||
func NewScheduler(
|
||||
statsSvc *service.StatsService,
|
||||
logSvc *service.LogService,
|
||||
keyRepo repository.KeyRepository,
|
||||
settingsMgr *settings.SettingsManager,
|
||||
store store.Store,
|
||||
logger *logrus.Logger,
|
||||
) *Scheduler {
|
||||
s := gocron.NewScheduler(time.UTC)
|
||||
s.TagsUnique()
|
||||
return &Scheduler{
|
||||
gocronScheduler: s,
|
||||
logger: logger.WithField("component", "Scheduler📆"),
|
||||
statsService: statsSvc,
|
||||
logService: logSvc,
|
||||
settingsManager: settingsMgr,
|
||||
keyRepo: keyRepo,
|
||||
store: store,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// ... (Start 和 listenForSettingsUpdates 保持不变) ...
|
||||
func (s *Scheduler) Start() {
|
||||
s.logger.Info("Starting scheduler and registering jobs...")
|
||||
|
||||
// --- 任务注册 ---
|
||||
// 使用CRON表达式,精确定义“每小时的第5分钟”执行
|
||||
// --- 注册静态定时任务 ---
|
||||
_, err := s.gocronScheduler.Cron("5 * * * *").Tag("stats-aggregation").Do(func() {
|
||||
s.logger.Info("Executing hourly request stats aggregation...")
|
||||
if err := s.statsService.AggregateHourlyStats(); err != nil {
|
||||
ctx := context.Background()
|
||||
if err := s.statsService.AggregateHourlyStats(ctx); err != nil {
|
||||
s.logger.WithError(err).Error("Hourly stats aggregation failed.")
|
||||
} else {
|
||||
s.logger.Info("Hourly stats aggregation completed successfully.")
|
||||
@@ -45,26 +71,9 @@ func (s *Scheduler) Start() {
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to schedule [stats-aggregation]: %v", err)
|
||||
}
|
||||
|
||||
// 任务二:(预留) 自动健康检查 (例如:每10分钟一次)
|
||||
/*
|
||||
_, err = s.gocronScheduler.Every(10).Minutes().Tag("auto-health-check").Do(func() {
|
||||
s.logger.Info("Executing periodic health check for all groups...")
|
||||
// s.healthCheckService.StartGlobalCheckTask() // 伪代码
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to schedule [auto-health-check]: %v", err)
|
||||
}
|
||||
*/
|
||||
// [NEW] --- 任务三: 清理软删除的API Keys ---
|
||||
// Executes once daily at 3:15 AM UTC.
|
||||
_, err = s.gocronScheduler.Cron("15 3 * * *").Tag("cleanup-soft-deleted-keys").Do(func() {
|
||||
s.logger.Info("Executing daily cleanup of soft-deleted API keys...")
|
||||
|
||||
// Let's assume a retention period of 7 days for now.
|
||||
// In a real scenario, this should come from settings.
|
||||
const retentionDays = 7
|
||||
|
||||
count, err := s.keyRepo.HardDeleteSoftDeletedBefore(time.Now().AddDate(0, 0, -retentionDays))
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Daily cleanup of soft-deleted keys failed.")
|
||||
@@ -77,14 +86,125 @@ func (s *Scheduler) Start() {
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to schedule [cleanup-soft-deleted-keys]: %v", err)
|
||||
}
|
||||
// --- 任务注册结束 ---
|
||||
|
||||
s.gocronScheduler.StartAsync() // 异步启动,不阻塞应用主线程
|
||||
// --- 动态任务初始化 ---
|
||||
if err := s.UpdateLogCleanupTask(); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to initialize log cleanup task on startup.")
|
||||
}
|
||||
// --- 启动后台监听器和调度器 ---
|
||||
s.wg.Add(1)
|
||||
go s.listenForSettingsUpdates()
|
||||
s.gocronScheduler.StartAsync()
|
||||
s.logger.Info("Scheduler started.")
|
||||
}
|
||||
func (s *Scheduler) listenForSettingsUpdates() {
|
||||
defer s.wg.Done()
|
||||
s.logger.Info("Starting listener for system settings updates...")
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping settings update listener.")
|
||||
return
|
||||
default:
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
subscription, err := s.store.Subscribe(ctx, settings.SettingsUpdateChannel)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to subscribe to settings channel, retrying in 5s...")
|
||||
cancel()
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
s.logger.Infof("Successfully subscribed to channel '%s'.", settings.SettingsUpdateChannel)
|
||||
listenLoop:
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-subscription.Channel():
|
||||
if !ok {
|
||||
s.logger.Warn("Subscription channel closed by publisher. Re-subscribing...")
|
||||
break listenLoop
|
||||
}
|
||||
s.logger.Infof("Received settings update notification: %s", string(msg.Payload))
|
||||
if err := s.UpdateLogCleanupTask(); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to update log cleanup task after notification.")
|
||||
}
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping settings update listener.")
|
||||
subscription.Close()
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
subscription.Close()
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// [MODIFIED] - UpdateLogCleanupTask 现在会动态生成 cron 表达式
|
||||
func (s *Scheduler) UpdateLogCleanupTask() error {
|
||||
if err := s.gocronScheduler.RemoveByTag(LogCleanupTaskTag); err != nil {
|
||||
// This is not an error, just means the job didn't exist
|
||||
}
|
||||
|
||||
settings := s.settingsManager.GetSettings()
|
||||
if !settings.LogAutoCleanupEnabled || settings.LogAutoCleanupRetentionDays <= 0 {
|
||||
s.logger.Info("Log auto-cleanup is disabled. Task removed or not scheduled.")
|
||||
return nil
|
||||
}
|
||||
|
||||
days := settings.LogAutoCleanupRetentionDays
|
||||
|
||||
// [NEW] 解析时间并生成 cron 表达式
|
||||
cronSpec, err := parseTimeToCron(settings.LogAutoCleanupTime)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warnf("Invalid cleanup time format '%s'. Falling back to default '04:05'.", settings.LogAutoCleanupTime)
|
||||
cronSpec = "5 4 * * *" // 安全回退
|
||||
}
|
||||
|
||||
s.logger.Infof("Scheduling/updating daily log cleanup task to retain last %d days of logs, using cron spec: '%s'", days, cronSpec)
|
||||
|
||||
_, err = s.gocronScheduler.Cron(cronSpec).Tag(LogCleanupTaskTag).Do(func() {
|
||||
s.logger.Infof("Executing daily log cleanup, deleting logs older than %d days...", days)
|
||||
ctx := context.Background()
|
||||
deletedCount, err := s.logService.DeleteOldLogs(ctx, days)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Daily log cleanup task failed.")
|
||||
} else {
|
||||
s.logger.Infof("Daily log cleanup task completed. Deleted %d old logs.", deletedCount)
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to schedule new log cleanup task.")
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("Log cleanup task updated successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// [NEW] - 用于解析 "HH:mm" 格式时间为 cron 表达式的辅助函数
|
||||
func parseTimeToCron(timeStr string) (string, error) {
|
||||
parts := strings.Split(timeStr, ":")
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("invalid time format, expected HH:mm")
|
||||
}
|
||||
|
||||
hour, err := strconv.Atoi(parts[0])
|
||||
if err != nil || hour < 0 || hour > 23 {
|
||||
return "", fmt.Errorf("invalid hour value: %s", parts[0])
|
||||
}
|
||||
|
||||
minute, err := strconv.Atoi(parts[1])
|
||||
if err != nil || minute < 0 || minute > 59 {
|
||||
return "", fmt.Errorf("invalid minute value: %s", parts[1])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%d %d * * *", minute, hour), nil
|
||||
}
|
||||
func (s *Scheduler) Stop() {
|
||||
s.logger.Info("Stopping scheduler...")
|
||||
close(s.stopChan)
|
||||
s.gocronScheduler.Stop()
|
||||
s.logger.Info("Scheduler stopped.")
|
||||
s.wg.Wait()
|
||||
s.logger.Info("Scheduler stopped gracefully.")
|
||||
}
|
||||
|
||||
@@ -1,96 +1,183 @@
|
||||
// Filename: internal/service/analytics_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/db/dialect"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gemini-balancer/internal/db/dialect"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
flushLoopInterval = 1 * time.Minute
|
||||
defaultFlushInterval = 1 * time.Minute
|
||||
maxRetryAttempts = 3
|
||||
retryDelay = 5 * time.Second
|
||||
)
|
||||
|
||||
type AnalyticsServiceLogger struct{ *logrus.Entry }
|
||||
|
||||
type AnalyticsService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
dialect dialect.DialectAdapter
|
||||
settingsManager *settings.SettingsManager
|
||||
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
dialect dialect.DialectAdapter
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// 统计指标
|
||||
eventsReceived atomic.Uint64
|
||||
eventsProcessed atomic.Uint64
|
||||
eventsFailed atomic.Uint64
|
||||
flushCount atomic.Uint64
|
||||
recordsFlushed atomic.Uint64
|
||||
flushErrors atomic.Uint64
|
||||
lastFlushTime time.Time
|
||||
lastFlushMutex sync.RWMutex
|
||||
|
||||
// 运行时配置
|
||||
flushInterval time.Duration
|
||||
configMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewAnalyticsService(db *gorm.DB, s store.Store, logger *logrus.Logger, d dialect.DialectAdapter) *AnalyticsService {
|
||||
func NewAnalyticsService(
|
||||
db *gorm.DB,
|
||||
s store.Store,
|
||||
logger *logrus.Logger,
|
||||
d dialect.DialectAdapter,
|
||||
settingsManager *settings.SettingsManager,
|
||||
) *AnalyticsService {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &AnalyticsService{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "Analytics📊"),
|
||||
stopChan: make(chan struct{}),
|
||||
dialect: d,
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "Analytics📊"),
|
||||
dialect: d,
|
||||
settingsManager: settingsManager,
|
||||
stopChan: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
flushInterval: defaultFlushInterval,
|
||||
lastFlushTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) Start() {
|
||||
s.wg.Add(2) // 2 (flushLoop, eventListener)
|
||||
go s.flushLoop()
|
||||
s.wg.Add(3)
|
||||
go s.eventListener()
|
||||
s.logger.Info("AnalyticsService (Command Side) started.")
|
||||
go s.flushLoop()
|
||||
go s.metricsReporter()
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"flush_interval": s.flushInterval,
|
||||
}).Info("AnalyticsService started")
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) Stop() {
|
||||
s.logger.Info("AnalyticsService stopping...")
|
||||
close(s.stopChan)
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
s.logger.Info("AnalyticsService stopped. Performing final data flush...")
|
||||
s.flushToDB() // 停止前刷盘
|
||||
s.logger.Info("AnalyticsService final data flush completed.")
|
||||
|
||||
s.logger.Info("Performing final data flush...")
|
||||
s.flushToDB()
|
||||
|
||||
// 输出最终统计
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"events_received": s.eventsReceived.Load(),
|
||||
"events_processed": s.eventsProcessed.Load(),
|
||||
"events_failed": s.eventsFailed.Load(),
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"records_flushed": s.recordsFlushed.Load(),
|
||||
"flush_errors": s.flushErrors.Load(),
|
||||
}).Info("AnalyticsService stopped")
|
||||
}
|
||||
|
||||
// 事件监听循环
|
||||
func (s *AnalyticsService) eventListener() {
|
||||
defer s.wg.Done()
|
||||
sub, err := s.store.Subscribe(models.TopicRequestFinished)
|
||||
|
||||
sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
s.logger.WithError(err).Error("Failed to subscribe to request events, analytics disabled")
|
||||
return
|
||||
}
|
||||
defer sub.Close()
|
||||
s.logger.Info("AnalyticsService subscribed to request events.")
|
||||
defer func() {
|
||||
if err := sub.Close(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to close subscription")
|
||||
}
|
||||
}()
|
||||
|
||||
s.logger.Info("Subscribed to request events for analytics")
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal analytics event: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleAnalyticsEvent(&event)
|
||||
s.handleMessage(msg)
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("AnalyticsService stopping event listener.")
|
||||
s.logger.Info("Event listener stopping")
|
||||
return
|
||||
|
||||
case <-s.ctx.Done():
|
||||
s.logger.Info("Event listener context cancelled")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) {
|
||||
if event.RequestLog.GroupID == nil {
|
||||
// 处理单条消息
|
||||
func (s *AnalyticsService) handleMessage(msg *store.Message) {
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to unmarshal analytics event")
|
||||
s.eventsFailed.Add(1)
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("analytics:hourly:%s", time.Now().UTC().Format("2006-01-02T15"))
|
||||
|
||||
s.eventsReceived.Add(1)
|
||||
|
||||
if err := s.handleAnalyticsEvent(&event); err != nil {
|
||||
s.eventsFailed.Add(1)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"correlation_id": event.CorrelationID,
|
||||
"group_id": event.RequestLog.GroupID,
|
||||
}).WithError(err).Warn("Failed to process analytics event")
|
||||
} else {
|
||||
s.eventsProcessed.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理分析事件
|
||||
func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEvent) error {
|
||||
if event.RequestLog.GroupID == nil {
|
||||
return nil // 跳过无 GroupID 的事件
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
key := fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15"))
|
||||
fieldPrefix := fmt.Sprintf("%d:%s", *event.RequestLog.GroupID, event.RequestLog.ModelName)
|
||||
pipe := s.store.Pipeline()
|
||||
|
||||
pipe := s.store.Pipeline(ctx)
|
||||
pipe.HIncrBy(key, fieldPrefix+":requests", 1)
|
||||
|
||||
if event.RequestLog.IsSuccess {
|
||||
pipe.HIncrBy(key, fieldPrefix+":success", 1)
|
||||
}
|
||||
@@ -100,79 +187,213 @@ func (s *AnalyticsService) handleAnalyticsEvent(event *models.RequestFinishedEve
|
||||
if event.RequestLog.CompletionTokens > 0 {
|
||||
pipe.HIncrBy(key, fieldPrefix+":completion", int64(event.RequestLog.CompletionTokens))
|
||||
}
|
||||
|
||||
// 设置过期时间(保留48小时)
|
||||
pipe.Expire(key, 48*time.Hour)
|
||||
|
||||
if err := pipe.Exec(); err != nil {
|
||||
s.logger.Warnf("[%s] Failed to record analytics event to store for group %d: %v", event.CorrelationID, *event.RequestLog.GroupID, err)
|
||||
return fmt.Errorf("redis pipeline failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新循环
|
||||
func (s *AnalyticsService) flushLoop() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(flushLoopInterval)
|
||||
|
||||
s.configMutex.RLock()
|
||||
interval := s.flushInterval
|
||||
s.configMutex.RUnlock()
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.logger.WithField("interval", interval).Info("Flush loop started")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.flushToDB()
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Flush loop stopping")
|
||||
return
|
||||
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 刷写到数据库
|
||||
func (s *AnalyticsService) flushToDB() {
|
||||
start := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
keysToFlush := []string{
|
||||
fmt.Sprintf("analytics:hourly:%s", now.Add(-1*time.Hour).Format("2006-01-02T15")),
|
||||
fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")),
|
||||
}
|
||||
keysToFlush := s.generateFlushKeys(now)
|
||||
|
||||
totalRecords := 0
|
||||
totalErrors := 0
|
||||
|
||||
for _, key := range keysToFlush {
|
||||
data, err := s.store.HGetAll(key)
|
||||
if err != nil || len(data) == 0 {
|
||||
continue
|
||||
records, err := s.flushSingleKey(ctx, key, now)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).WithField("key", key).Error("Failed to flush key")
|
||||
totalErrors++
|
||||
s.flushErrors.Add(1)
|
||||
} else {
|
||||
totalRecords += records
|
||||
}
|
||||
}
|
||||
|
||||
statsToFlush, parsedFields := s.parseStatsFromHash(now.Truncate(time.Hour), data)
|
||||
s.recordsFlushed.Add(uint64(totalRecords))
|
||||
s.flushCount.Add(1)
|
||||
|
||||
if len(statsToFlush) > 0 {
|
||||
upsertClause := s.dialect.OnConflictUpdateAll(
|
||||
[]string{"time", "group_id", "model_name"}, // conflict columns
|
||||
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}, // update columns
|
||||
)
|
||||
err := s.db.Clauses(upsertClause).Create(&statsToFlush).Error
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to flush analytics data for key %s: %v", key, err)
|
||||
} else {
|
||||
s.logger.Infof("Successfully flushed %d records from key %s.", len(statsToFlush), key)
|
||||
_ = s.store.HDel(key, parsedFields...)
|
||||
}
|
||||
}
|
||||
s.lastFlushMutex.Lock()
|
||||
s.lastFlushTime = time.Now()
|
||||
s.lastFlushMutex.Unlock()
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
if totalRecords > 0 || totalErrors > 0 {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"records_flushed": totalRecords,
|
||||
"keys_processed": len(keysToFlush),
|
||||
"errors": totalErrors,
|
||||
"duration": duration,
|
||||
}).Info("Analytics data flush completed")
|
||||
} else {
|
||||
s.logger.WithField("duration", duration).Debug("Analytics flush completed (no data)")
|
||||
}
|
||||
}
|
||||
|
||||
// 生成需要刷新的 Redis 键
|
||||
func (s *AnalyticsService) generateFlushKeys(now time.Time) []string {
|
||||
keys := make([]string, 0, 4)
|
||||
|
||||
// 当前小时
|
||||
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", now.Format("2006-01-02T15")))
|
||||
|
||||
// 前3个小时(处理延迟和时区问题)
|
||||
for i := 1; i <= 3; i++ {
|
||||
pastHour := now.Add(-time.Duration(i) * time.Hour)
|
||||
keys = append(keys, fmt.Sprintf("analytics:hourly:%s", pastHour.Format("2006-01-02T15")))
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// 刷写单个 Redis 键
|
||||
func (s *AnalyticsService) flushSingleKey(ctx context.Context, key string, baseTime time.Time) (int, error) {
|
||||
data, err := s.store.HGetAll(ctx, key)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get hash data: %w", err)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return 0, nil // 无数据,跳过
|
||||
}
|
||||
|
||||
// 解析时间戳
|
||||
hourStr := strings.TrimPrefix(key, "analytics:hourly:")
|
||||
recordTime, err := time.Parse("2006-01-02T15", hourStr)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).WithField("key", key).Warn("Failed to parse time from key")
|
||||
recordTime = baseTime.Truncate(time.Hour)
|
||||
}
|
||||
|
||||
statsToFlush, parsedFields := s.parseStatsFromHash(recordTime, data)
|
||||
|
||||
if len(statsToFlush) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// 使用事务 + 重试机制
|
||||
var dbErr error
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
dbErr = s.upsertStatsWithTransaction(ctx, statsToFlush)
|
||||
if dbErr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if attempt < maxRetryAttempts {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"attempt": attempt,
|
||||
"key": key,
|
||||
}).WithError(dbErr).Warn("Database upsert failed, retrying...")
|
||||
time.Sleep(retryDelay)
|
||||
}
|
||||
}
|
||||
|
||||
if dbErr != nil {
|
||||
return 0, fmt.Errorf("failed to upsert after %d attempts: %w", maxRetryAttempts, dbErr)
|
||||
}
|
||||
|
||||
// 删除已处理的字段
|
||||
if len(parsedFields) > 0 {
|
||||
if err := s.store.HDel(ctx, key, parsedFields...); err != nil {
|
||||
s.logger.WithError(err).WithField("key", key).Warn("Failed to delete flushed fields from Redis")
|
||||
}
|
||||
}
|
||||
|
||||
return len(statsToFlush), nil
|
||||
}
|
||||
|
||||
// 使用事务批量 upsert
|
||||
func (s *AnalyticsService) upsertStatsWithTransaction(ctx context.Context, stats []models.StatsHourly) error {
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
upsertClause := s.dialect.OnConflictUpdateAll(
|
||||
[]string{"time", "group_id", "model_name"},
|
||||
[]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"},
|
||||
)
|
||||
return tx.Clauses(upsertClause).Create(&stats).Error
|
||||
})
|
||||
}
|
||||
|
||||
// 解析 Redis Hash 数据
|
||||
func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]string) ([]models.StatsHourly, []string) {
|
||||
tempAggregator := make(map[string]*models.StatsHourly)
|
||||
var parsedFields []string
|
||||
parsedFields := make([]string, 0, len(data))
|
||||
|
||||
for field, valueStr := range data {
|
||||
parts := strings.Split(field, ":")
|
||||
if len(parts) != 3 {
|
||||
s.logger.WithField("field", field).Warn("Invalid field format")
|
||||
continue
|
||||
}
|
||||
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
|
||||
|
||||
groupIDStr, modelName, counterType := parts[0], parts[1], parts[2]
|
||||
aggKey := groupIDStr + ":" + modelName
|
||||
|
||||
if _, ok := tempAggregator[aggKey]; !ok {
|
||||
gid, err := strconv.Atoi(groupIDStr)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"field": field,
|
||||
"group_id": groupIDStr,
|
||||
}).Warn("Invalid group ID")
|
||||
continue
|
||||
}
|
||||
|
||||
tempAggregator[aggKey] = &models.StatsHourly{
|
||||
Time: t,
|
||||
GroupID: uint(gid),
|
||||
ModelName: modelName,
|
||||
}
|
||||
}
|
||||
val, _ := strconv.ParseInt(valueStr, 10, 64)
|
||||
|
||||
val, err := strconv.ParseInt(valueStr, 10, 64)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"field": field,
|
||||
"value": valueStr,
|
||||
}).Warn("Invalid counter value")
|
||||
continue
|
||||
}
|
||||
|
||||
switch counterType {
|
||||
case "requests":
|
||||
tempAggregator[aggKey].RequestCount = val
|
||||
@@ -182,14 +403,92 @@ func (s *AnalyticsService) parseStatsFromHash(t time.Time, data map[string]strin
|
||||
tempAggregator[aggKey].PromptTokens = val
|
||||
case "completion":
|
||||
tempAggregator[aggKey].CompletionTokens = val
|
||||
default:
|
||||
s.logger.WithField("counter_type", counterType).Warn("Unknown counter type")
|
||||
continue
|
||||
}
|
||||
|
||||
parsedFields = append(parsedFields, field)
|
||||
}
|
||||
var result []models.StatsHourly
|
||||
|
||||
result := make([]models.StatsHourly, 0, len(tempAggregator))
|
||||
for _, stats := range tempAggregator {
|
||||
if stats.RequestCount > 0 {
|
||||
result = append(result, *stats)
|
||||
}
|
||||
}
|
||||
|
||||
return result, parsedFields
|
||||
}
|
||||
|
||||
// 定期输出统计信息
|
||||
func (s *AnalyticsService) metricsReporter() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.reportMetrics()
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsService) reportMetrics() {
|
||||
s.lastFlushMutex.RLock()
|
||||
lastFlush := s.lastFlushTime
|
||||
s.lastFlushMutex.RUnlock()
|
||||
|
||||
received := s.eventsReceived.Load()
|
||||
processed := s.eventsProcessed.Load()
|
||||
failed := s.eventsFailed.Load()
|
||||
|
||||
var successRate float64
|
||||
if received > 0 {
|
||||
successRate = float64(processed) / float64(received) * 100
|
||||
}
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"events_received": received,
|
||||
"events_processed": processed,
|
||||
"events_failed": failed,
|
||||
"success_rate": fmt.Sprintf("%.2f%%", successRate),
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"records_flushed": s.recordsFlushed.Load(),
|
||||
"flush_errors": s.flushErrors.Load(),
|
||||
"last_flush_ago": time.Since(lastFlush).Round(time.Second),
|
||||
}).Info("Analytics metrics")
|
||||
}
|
||||
|
||||
// GetMetrics 返回当前统计指标(供监控使用)
|
||||
func (s *AnalyticsService) GetMetrics() map[string]interface{} {
|
||||
s.lastFlushMutex.RLock()
|
||||
lastFlush := s.lastFlushTime
|
||||
s.lastFlushMutex.RUnlock()
|
||||
|
||||
received := s.eventsReceived.Load()
|
||||
processed := s.eventsProcessed.Load()
|
||||
|
||||
var successRate float64
|
||||
if received > 0 {
|
||||
successRate = float64(processed) / float64(received) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"events_received": received,
|
||||
"events_processed": processed,
|
||||
"events_failed": s.eventsFailed.Load(),
|
||||
"success_rate": successRate,
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"records_flushed": s.recordsFlushed.Load(),
|
||||
"flush_errors": s.flushErrors.Load(),
|
||||
"last_flush_ago": time.Since(lastFlush).Seconds(),
|
||||
"flush_interval": s.flushInterval.Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,164 +1,300 @@
|
||||
// Filename: internal/service/dashboard_query_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const overviewCacheChannel = "syncer:cache:dashboard_overview"
|
||||
const (
|
||||
overviewCacheChannel = "syncer:cache:dashboard_overview"
|
||||
defaultChartDays = 7
|
||||
cacheLoadTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// DashboardQueryService 负责所有面向前端的仪表盘数据查询。
|
||||
var (
|
||||
// 图表颜色调色板
|
||||
chartColorPalette = []string{
|
||||
"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0",
|
||||
"#9966FF", "#FF9F40", "#C9CBCF", "#4D5360",
|
||||
}
|
||||
)
|
||||
|
||||
type DashboardQueryService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
overviewSyncer *syncer.CacheSyncer[*models.DashboardStatsResponse]
|
||||
logger *logrus.Entry
|
||||
stopChan chan struct{}
|
||||
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// 统计指标
|
||||
queryCount atomic.Uint64
|
||||
cacheHits atomic.Uint64
|
||||
cacheMisses atomic.Uint64
|
||||
overviewLoadCount atomic.Uint64
|
||||
lastQueryTime time.Time
|
||||
lastQueryMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewDashboardQueryService(db *gorm.DB, s store.Store, logger *logrus.Logger) (*DashboardQueryService, error) {
|
||||
qs := &DashboardQueryService{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "DashboardQueryService"),
|
||||
stopChan: make(chan struct{}),
|
||||
func NewDashboardQueryService(
|
||||
db *gorm.DB,
|
||||
s store.Store,
|
||||
logger *logrus.Logger,
|
||||
) (*DashboardQueryService, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
service := &DashboardQueryService{
|
||||
db: db,
|
||||
store: s,
|
||||
logger: logger.WithField("component", "DashboardQuery📈"),
|
||||
stopChan: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
lastQueryTime: time.Now(),
|
||||
}
|
||||
|
||||
loader := qs.loadOverviewData
|
||||
overviewSyncer, err := syncer.NewCacheSyncer(loader, s, overviewCacheChannel)
|
||||
// 创建 CacheSyncer
|
||||
overviewSyncer, err := syncer.NewCacheSyncer(
|
||||
service.loadOverviewData,
|
||||
s,
|
||||
overviewCacheChannel,
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to create overview cache syncer: %w", err)
|
||||
}
|
||||
qs.overviewSyncer = overviewSyncer
|
||||
return qs, nil
|
||||
service.overviewSyncer = overviewSyncer
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) Start() {
|
||||
s.wg.Add(2)
|
||||
go s.eventListener()
|
||||
s.logger.Info("DashboardQueryService started and listening for invalidation events.")
|
||||
go s.metricsReporter()
|
||||
|
||||
s.logger.Info("DashboardQueryService started")
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) Stop() {
|
||||
s.logger.Info("DashboardQueryService stopping...")
|
||||
close(s.stopChan)
|
||||
s.logger.Info("DashboardQueryService and its CacheSyncer have been stopped.")
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
|
||||
// 输出最终统计
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"total_queries": s.queryCount.Load(),
|
||||
"cache_hits": s.cacheHits.Load(),
|
||||
"cache_misses": s.cacheMisses.Load(),
|
||||
"overview_loads": s.overviewLoadCount.Load(),
|
||||
}).Info("DashboardQueryService stopped")
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) GetGroupStats(groupID uint) (map[string]any, error) {
|
||||
// ==================== 核心查询方法 ====================
|
||||
|
||||
// GetDashboardOverviewData 获取仪表盘概览数据(带缓存)
|
||||
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
s.queryCount.Add(1)
|
||||
|
||||
cachedDataPtr := s.overviewSyncer.Get()
|
||||
if cachedDataPtr == nil {
|
||||
s.cacheMisses.Add(1)
|
||||
s.logger.Warn("Overview cache is empty, attempting to load...")
|
||||
|
||||
// 触发立即加载
|
||||
if err := s.overviewSyncer.Invalidate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to trigger cache reload: %w", err)
|
||||
}
|
||||
|
||||
// 等待加载完成(最多30秒)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheLoadTimeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if data := s.overviewSyncer.Get(); data != nil {
|
||||
s.cacheHits.Add(1)
|
||||
return data, nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timeout waiting for overview cache to load")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.cacheHits.Add(1)
|
||||
return cachedDataPtr, nil
|
||||
}
|
||||
|
||||
// InvalidateOverviewCache 手动失效概览缓存
|
||||
func (s *DashboardQueryService) InvalidateOverviewCache() error {
|
||||
s.logger.Info("Manually invalidating overview cache")
|
||||
return s.overviewSyncer.Invalidate()
|
||||
}
|
||||
|
||||
// GetGroupStats 获取指定分组的统计数据
|
||||
func (s *DashboardQueryService) GetGroupStats(ctx context.Context, groupID uint) (map[string]any, error) {
|
||||
s.queryCount.Add(1)
|
||||
s.updateLastQueryTime()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// 1. 从 Redis 获取 Key 统计
|
||||
statsKey := fmt.Sprintf("stats:group:%d", groupID)
|
||||
keyStatsMap, err := s.store.HGetAll(statsKey)
|
||||
keyStatsMap, err := s.store.HGetAll(ctx, statsKey)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to get key stats from cache for group %d", groupID)
|
||||
s.logger.WithError(err).Errorf("Failed to get key stats for group %d", groupID)
|
||||
return nil, fmt.Errorf("failed to get key stats from cache: %w", err)
|
||||
}
|
||||
|
||||
keyStats := make(map[string]int64)
|
||||
for k, v := range keyStatsMap {
|
||||
val, _ := strconv.ParseInt(v, 10, 64)
|
||||
keyStats[k] = val
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
// 2. 查询请求统计(使用 UTC 时间)
|
||||
now := time.Now().UTC()
|
||||
oneHourAgo := now.Add(-1 * time.Hour)
|
||||
twentyFourHoursAgo := now.Add(-24 * time.Hour)
|
||||
|
||||
type requestStatsResult struct {
|
||||
TotalRequests int64
|
||||
SuccessRequests int64
|
||||
}
|
||||
|
||||
var last1Hour, last24Hours requestStatsResult
|
||||
s.db.Model(&models.StatsHourly{}).
|
||||
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
|
||||
Scan(&last1Hour)
|
||||
s.db.Model(&models.StatsHourly{}).
|
||||
Select("SUM(request_count) as total_requests, SUM(success_count) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
|
||||
Scan(&last24Hours)
|
||||
failureRate1h := 0.0
|
||||
if last1Hour.TotalRequests > 0 {
|
||||
failureRate1h = float64(last1Hour.TotalRequests-last1Hour.SuccessRequests) / float64(last1Hour.TotalRequests) * 100
|
||||
}
|
||||
failureRate24h := 0.0
|
||||
if last24Hours.TotalRequests > 0 {
|
||||
failureRate24h = float64(last24Hours.TotalRequests-last24Hours.SuccessRequests) / float64(last24Hours.TotalRequests) * 100
|
||||
}
|
||||
last1HourStats := map[string]any{
|
||||
"total_requests": last1Hour.TotalRequests,
|
||||
"success_requests": last1Hour.SuccessRequests,
|
||||
"failure_rate": failureRate1h,
|
||||
}
|
||||
last24HoursStats := map[string]any{
|
||||
"total_requests": last24Hours.TotalRequests,
|
||||
"success_requests": last24Hours.SuccessRequests,
|
||||
"failure_rate": failureRate24h,
|
||||
|
||||
// 并发查询优化
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
// 查询最近1小时
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
|
||||
Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, oneHourAgo).
|
||||
Scan(&last1Hour).Error; err != nil {
|
||||
errChan <- fmt.Errorf("failed to query 1h stats: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 查询最近24小时
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.db.WithContext(ctx).Model(&models.StatsHourly{}).
|
||||
Select("COALESCE(SUM(request_count), 0) as total_requests, COALESCE(SUM(success_count), 0) as success_requests").
|
||||
Where("group_id = ? AND time >= ?", groupID, twentyFourHoursAgo).
|
||||
Scan(&last24Hours).Error; err != nil {
|
||||
errChan <- fmt.Errorf("failed to query 24h stats: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// 检查错误
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 计算失败率
|
||||
failureRate1h := s.calculateFailureRate(last1Hour.TotalRequests, last1Hour.SuccessRequests)
|
||||
failureRate24h := s.calculateFailureRate(last24Hours.TotalRequests, last24Hours.SuccessRequests)
|
||||
|
||||
result := map[string]any{
|
||||
"key_stats": keyStats,
|
||||
"last_1_hour": last1HourStats,
|
||||
"last_24_hours": last24HoursStats,
|
||||
"key_stats": keyStats,
|
||||
"last_1_hour": map[string]any{
|
||||
"total_requests": last1Hour.TotalRequests,
|
||||
"success_requests": last1Hour.SuccessRequests,
|
||||
"failed_requests": last1Hour.TotalRequests - last1Hour.SuccessRequests,
|
||||
"failure_rate": failureRate1h,
|
||||
},
|
||||
"last_24_hours": map[string]any{
|
||||
"total_requests": last24Hours.TotalRequests,
|
||||
"success_requests": last24Hours.SuccessRequests,
|
||||
"failed_requests": last24Hours.TotalRequests - last24Hours.SuccessRequests,
|
||||
"failure_rate": failureRate24h,
|
||||
},
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"duration": duration,
|
||||
}).Debug("Group stats query completed")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) eventListener() {
|
||||
keyStatusSub, _ := s.store.Subscribe(models.TopicKeyStatusChanged)
|
||||
upstreamStatusSub, _ := s.store.Subscribe(models.TopicUpstreamHealthChanged)
|
||||
defer keyStatusSub.Close()
|
||||
defer upstreamStatusSub.Close()
|
||||
for {
|
||||
select {
|
||||
case <-keyStatusSub.Channel():
|
||||
s.logger.Info("Received key status changed event, invalidating overview cache...")
|
||||
_ = s.InvalidateOverviewCache()
|
||||
case <-upstreamStatusSub.Channel():
|
||||
s.logger.Info("Received upstream status changed event, invalidating overview cache...")
|
||||
_ = s.InvalidateOverviewCache()
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping dashboard event listener.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// QueryHistoricalChart 查询历史图表数据
|
||||
func (s *DashboardQueryService) QueryHistoricalChart(ctx context.Context, groupID *uint) (*models.ChartData, error) {
|
||||
s.queryCount.Add(1)
|
||||
s.updateLastQueryTime()
|
||||
|
||||
// GetDashboardOverviewData 从 Syncer 缓存中高速获取仪表盘概览数据。
|
||||
func (s *DashboardQueryService) GetDashboardOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
cachedDataPtr := s.overviewSyncer.Get()
|
||||
if cachedDataPtr == nil {
|
||||
return &models.DashboardStatsResponse{}, fmt.Errorf("overview cache is not available or still syncing")
|
||||
}
|
||||
return cachedDataPtr, nil
|
||||
}
|
||||
start := time.Now()
|
||||
|
||||
func (s *DashboardQueryService) InvalidateOverviewCache() error {
|
||||
return s.overviewSyncer.Invalidate()
|
||||
}
|
||||
|
||||
// QueryHistoricalChart 查询历史图表数据。
|
||||
func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.ChartData, error) {
|
||||
type ChartPoint struct {
|
||||
TimeLabel string `gorm:"column:time_label"`
|
||||
ModelName string `gorm:"column:model_name"`
|
||||
TotalRequests int64 `gorm:"column:total_requests"`
|
||||
}
|
||||
sevenDaysAgo := time.Now().Add(-24 * 7 * time.Hour).Truncate(time.Hour)
|
||||
|
||||
// 查询最近7天数据(使用 UTC)
|
||||
sevenDaysAgo := time.Now().UTC().AddDate(0, 0, -defaultChartDays).Truncate(time.Hour)
|
||||
|
||||
// 根据数据库类型构建时间格式化子句
|
||||
sqlFormat, goFormat := s.buildTimeFormatSelectClause()
|
||||
selectClause := fmt.Sprintf("%s as time_label, model_name, SUM(request_count) as total_requests", sqlFormat)
|
||||
query := s.db.Model(&models.StatsHourly{}).Select(selectClause).Where("time >= ?", sevenDaysAgo).Group("time_label, model_name").Order("time_label ASC")
|
||||
selectClause := fmt.Sprintf(
|
||||
"%s as time_label, model_name, COALESCE(SUM(request_count), 0) as total_requests",
|
||||
sqlFormat,
|
||||
)
|
||||
|
||||
// 构建查询
|
||||
query := s.db.WithContext(ctx).
|
||||
Model(&models.StatsHourly{}).
|
||||
Select(selectClause).
|
||||
Where("time >= ?", sevenDaysAgo).
|
||||
Group("time_label, model_name").
|
||||
Order("time_label ASC")
|
||||
|
||||
if groupID != nil && *groupID > 0 {
|
||||
query = query.Where("group_id = ?", *groupID)
|
||||
}
|
||||
|
||||
var points []ChartPoint
|
||||
if err := query.Find(&points).Error; err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to query chart data: %w", err)
|
||||
}
|
||||
|
||||
// 构建数据集
|
||||
datasets := make(map[string]map[string]int64)
|
||||
for _, p := range points {
|
||||
if _, ok := datasets[p.ModelName]; !ok {
|
||||
@@ -166,150 +302,495 @@ func (s *DashboardQueryService) QueryHistoricalChart(groupID *uint) (*models.Cha
|
||||
}
|
||||
datasets[p.ModelName][p.TimeLabel] = p.TotalRequests
|
||||
}
|
||||
|
||||
// 生成时间标签(按小时)
|
||||
var labels []string
|
||||
for t := sevenDaysAgo; t.Before(time.Now()); t = t.Add(time.Hour) {
|
||||
for t := sevenDaysAgo; t.Before(time.Now().UTC()); t = t.Add(time.Hour) {
|
||||
labels = append(labels, t.Format(goFormat))
|
||||
}
|
||||
chartData := &models.ChartData{Labels: labels, Datasets: make([]models.ChartDataset, 0)}
|
||||
colorPalette := []string{"#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"}
|
||||
|
||||
// 构建图表数据
|
||||
chartData := &models.ChartData{
|
||||
Labels: labels,
|
||||
Datasets: make([]models.ChartDataset, 0, len(datasets)),
|
||||
}
|
||||
|
||||
colorIndex := 0
|
||||
for modelName, dataPoints := range datasets {
|
||||
dataArray := make([]int64, len(labels))
|
||||
for i, label := range labels {
|
||||
dataArray[i] = dataPoints[label]
|
||||
}
|
||||
|
||||
chartData.Datasets = append(chartData.Datasets, models.ChartDataset{
|
||||
Label: modelName,
|
||||
Data: dataArray,
|
||||
Color: colorPalette[colorIndex%len(colorPalette)],
|
||||
Color: chartColorPalette[colorIndex%len(chartColorPalette)],
|
||||
})
|
||||
colorIndex++
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"points": len(points),
|
||||
"datasets": len(chartData.Datasets),
|
||||
"duration": duration,
|
||||
}).Debug("Historical chart query completed")
|
||||
|
||||
return chartData, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
s.logger.Info("[CacheSyncer] Starting to load overview data from database...")
|
||||
startTime := time.Now()
|
||||
resp := &models.DashboardStatsResponse{
|
||||
KeyStatusCount: make(map[models.APIKeyStatus]int64),
|
||||
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
|
||||
KeyCount: models.StatCard{}, // 确保KeyCount是一个空的结构体,而不是nil
|
||||
RequestCount24h: models.StatCard{}, // 同上
|
||||
TokenCount: make(map[string]any),
|
||||
UpstreamHealthStatus: make(map[string]string),
|
||||
RPM: models.StatCard{},
|
||||
RequestCounts: make(map[string]int64),
|
||||
}
|
||||
// --- 1. Aggregate Operational Status from Mappings ---
|
||||
type MappingStatusResult struct {
|
||||
Status models.APIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
var mappingStatusResults []MappingStatusResult
|
||||
if err := s.db.Model(&models.GroupAPIKeyMapping{}).Select("status, count(*) as count").Group("status").Find(&mappingStatusResults).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query mapping status stats: %w", err)
|
||||
}
|
||||
for _, res := range mappingStatusResults {
|
||||
resp.KeyStatusCount[res.Status] = res.Count
|
||||
}
|
||||
// GetRequestStatsForPeriod 获取指定时间段的请求统计
|
||||
func (s *DashboardQueryService) GetRequestStatsForPeriod(ctx context.Context, period string) (gin.H, error) {
|
||||
s.queryCount.Add(1)
|
||||
s.updateLastQueryTime()
|
||||
|
||||
// --- 2. Aggregate Master Status from APIKeys ---
|
||||
type MasterStatusResult struct {
|
||||
Status models.MasterAPIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
var masterStatusResults []MasterStatusResult
|
||||
if err := s.db.Model(&models.APIKey{}).Select("master_status as status, count(*) as count").Group("master_status").Find(&masterStatusResults).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to query master status stats: %w", err)
|
||||
}
|
||||
var totalKeys, invalidKeys int64
|
||||
for _, res := range masterStatusResults {
|
||||
resp.MasterStatusCount[res.Status] = res.Count
|
||||
totalKeys += res.Count
|
||||
if res.Status != models.MasterStatusActive {
|
||||
invalidKeys += res.Count
|
||||
}
|
||||
}
|
||||
resp.KeyCount = models.StatCard{Value: float64(totalKeys), SubValue: invalidKeys, SubValueTip: "非活跃身份密钥数"}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 1. RPM (1分钟), RPH (1小时), RPD (今日): 从“瞬时记忆”(request_logs)中精确查询
|
||||
var count1m, count1h, count1d int64
|
||||
// RPM: 从此刻倒推1分钟
|
||||
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Minute)).Count(&count1m)
|
||||
// RPH: 从此刻倒推1小时
|
||||
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", now.Add(-1*time.Hour)).Count(&count1h)
|
||||
|
||||
// RPD: 从今天零点 (UTC) 到此刻
|
||||
year, month, day := now.UTC().Date()
|
||||
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
s.db.Model(&models.RequestLog{}).Where("request_time >= ?", startOfDay).Count(&count1d)
|
||||
// 2. RP30D (30天): 从“长期记忆”(stats_hourly)中高效查询,以保证性能
|
||||
var count30d int64
|
||||
s.db.Model(&models.StatsHourly{}).Where("time >= ?", now.AddDate(0, 0, -30)).Select("COALESCE(SUM(request_count), 0)").Scan(&count30d)
|
||||
|
||||
resp.RequestCounts["1m"] = count1m
|
||||
resp.RequestCounts["1h"] = count1h
|
||||
resp.RequestCounts["1d"] = count1d
|
||||
resp.RequestCounts["30d"] = count30d
|
||||
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if err := s.db.Find(&upstreams).Error; err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to load upstream statuses for dashboard.")
|
||||
} else {
|
||||
for _, u := range upstreams {
|
||||
resp.UpstreamHealthStatus[u.URL] = u.Status
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
s.logger.Infof("[CacheSyncer] Successfully finished loading overview data in %s.", duration)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) GetRequestStatsForPeriod(period string) (gin.H, error) {
|
||||
var startTime time.Time
|
||||
now := time.Now()
|
||||
now := time.Now().UTC()
|
||||
|
||||
switch period {
|
||||
case "1m":
|
||||
startTime = now.Add(-1 * time.Minute)
|
||||
case "1h":
|
||||
startTime = now.Add(-1 * time.Hour)
|
||||
case "1d":
|
||||
year, month, day := now.UTC().Date()
|
||||
year, month, day := now.Date()
|
||||
startTime = time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid period specified: %s", period)
|
||||
return nil, fmt.Errorf("invalid period specified: %s (must be 1m, 1h, or 1d)", period)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Total int64
|
||||
Success int64
|
||||
}
|
||||
|
||||
err := s.db.Model(&models.RequestLog{}).
|
||||
Select("count(*) as total, sum(case when is_success = true then 1 else 0 end) as success").
|
||||
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
|
||||
Select("COUNT(*) as total, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success").
|
||||
Where("request_time >= ?", startTime).
|
||||
Scan(&result).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to query request stats: %w", err)
|
||||
}
|
||||
|
||||
return gin.H{
|
||||
"period": period,
|
||||
"total": result.Total,
|
||||
"success": result.Success,
|
||||
"failure": result.Total - result.Success,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ==================== 内部方法 ====================
|
||||
|
||||
// loadOverviewData 加载仪表盘概览数据(供 CacheSyncer 调用)
|
||||
func (s *DashboardQueryService) loadOverviewData() (*models.DashboardStatsResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
s.overviewLoadCount.Add(1)
|
||||
startTime := time.Now()
|
||||
|
||||
s.logger.Info("Starting to load dashboard overview data...")
|
||||
|
||||
resp := &models.DashboardStatsResponse{
|
||||
KeyStatusCount: make(map[models.APIKeyStatus]int64),
|
||||
MasterStatusCount: make(map[models.MasterAPIKeyStatus]int64),
|
||||
KeyCount: models.StatCard{},
|
||||
RequestCount24h: models.StatCard{},
|
||||
TokenCount: make(map[string]any),
|
||||
UpstreamHealthStatus: make(map[string]string),
|
||||
RPM: models.StatCard{},
|
||||
RequestCounts: make(map[string]int64),
|
||||
}
|
||||
|
||||
var loadErr error
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 10)
|
||||
|
||||
// 1. 并发加载 Key 映射状态统计
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.loadMappingStatusStats(ctx, resp); err != nil {
|
||||
errChan <- fmt.Errorf("mapping stats: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 2. 并发加载 Master Key 状态统计
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.loadMasterStatusStats(ctx, resp); err != nil {
|
||||
errChan <- fmt.Errorf("master stats: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 3. 并发加载请求统计
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.loadRequestCounts(ctx, resp); err != nil {
|
||||
errChan <- fmt.Errorf("request counts: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 4. 并发加载上游健康状态
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.loadUpstreamHealth(ctx, resp); err != nil {
|
||||
// 上游健康状态失败不阻塞整体加载
|
||||
s.logger.WithError(err).Warn("Failed to load upstream health status")
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待所有加载任务完成
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// 收集错误
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
loadErr = err
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if loadErr != nil {
|
||||
s.logger.WithError(loadErr).Error("Failed to load overview data")
|
||||
return nil, loadErr
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"duration": duration,
|
||||
"total_keys": resp.KeyCount.Value,
|
||||
"requests_1d": resp.RequestCounts["1d"],
|
||||
"upstreams": len(resp.UpstreamHealthStatus),
|
||||
}).Info("Successfully loaded dashboard overview data")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// loadMappingStatusStats 加载 Key 映射状态统计
|
||||
func (s *DashboardQueryService) loadMappingStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
|
||||
type MappingStatusResult struct {
|
||||
Status models.APIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
|
||||
var results []MappingStatusResult
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.GroupAPIKeyMapping{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&results).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, res := range results {
|
||||
resp.KeyStatusCount[res.Status] = res.Count
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadMasterStatusStats 加载 Master Key 状态统计
|
||||
func (s *DashboardQueryService) loadMasterStatusStats(ctx context.Context, resp *models.DashboardStatsResponse) error {
|
||||
type MasterStatusResult struct {
|
||||
Status models.MasterAPIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
|
||||
var results []MasterStatusResult
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.APIKey{}).
|
||||
Select("master_status as status, COUNT(*) as count").
|
||||
Group("master_status").
|
||||
Find(&results).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var totalKeys, invalidKeys int64
|
||||
for _, res := range results {
|
||||
resp.MasterStatusCount[res.Status] = res.Count
|
||||
totalKeys += res.Count
|
||||
if res.Status != models.MasterStatusActive {
|
||||
invalidKeys += res.Count
|
||||
}
|
||||
}
|
||||
|
||||
resp.KeyCount = models.StatCard{
|
||||
Value: float64(totalKeys),
|
||||
SubValue: invalidKeys,
|
||||
SubValueTip: "非活跃身份密钥数",
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadRequestCounts 加载请求计数统计
|
||||
func (s *DashboardQueryService) loadRequestCounts(ctx context.Context, resp *models.DashboardStatsResponse) error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
// 使用 RequestLog 表查询短期数据
|
||||
var count1m, count1h int64
|
||||
|
||||
// 最近1分钟
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.RequestLog{}).
|
||||
Where("request_time >= ?", now.Add(-1*time.Minute)).
|
||||
Count(&count1m).Error; err != nil {
|
||||
return fmt.Errorf("1m count: %w", err)
|
||||
}
|
||||
|
||||
// 最近1小时
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.RequestLog{}).
|
||||
Where("request_time >= ?", now.Add(-1*time.Hour)).
|
||||
Count(&count1h).Error; err != nil {
|
||||
return fmt.Errorf("1h count: %w", err)
|
||||
}
|
||||
|
||||
// 今天(UTC)
|
||||
year, month, day := now.Date()
|
||||
startOfDay := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
var count1d int64
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.RequestLog{}).
|
||||
Where("request_time >= ?", startOfDay).
|
||||
Count(&count1d).Error; err != nil {
|
||||
return fmt.Errorf("1d count: %w", err)
|
||||
}
|
||||
|
||||
// 最近30天使用聚合表
|
||||
var count30d int64
|
||||
if err := s.db.WithContext(ctx).
|
||||
Model(&models.StatsHourly{}).
|
||||
Where("time >= ?", now.AddDate(0, 0, -30)).
|
||||
Select("COALESCE(SUM(request_count), 0)").
|
||||
Scan(&count30d).Error; err != nil {
|
||||
return fmt.Errorf("30d count: %w", err)
|
||||
}
|
||||
|
||||
resp.RequestCounts["1m"] = count1m
|
||||
resp.RequestCounts["1h"] = count1h
|
||||
resp.RequestCounts["1d"] = count1d
|
||||
resp.RequestCounts["30d"] = count30d
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadUpstreamHealth 加载上游健康状态
|
||||
func (s *DashboardQueryService) loadUpstreamHealth(ctx context.Context, resp *models.DashboardStatsResponse) error {
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if err := s.db.WithContext(ctx).Find(&upstreams).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
resp.UpstreamHealthStatus[u.URL] = u.Status
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== 事件监听 ====================
|
||||
|
||||
// eventListener 监听缓存失效事件
|
||||
func (s *DashboardQueryService) eventListener() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// 订阅事件
|
||||
keyStatusSub, err1 := s.store.Subscribe(s.ctx, models.TopicKeyStatusChanged)
|
||||
upstreamStatusSub, err2 := s.store.Subscribe(s.ctx, models.TopicUpstreamHealthChanged)
|
||||
|
||||
// 错误处理
|
||||
if err1 != nil {
|
||||
s.logger.WithError(err1).Error("Failed to subscribe to key status events")
|
||||
keyStatusSub = nil
|
||||
}
|
||||
if err2 != nil {
|
||||
s.logger.WithError(err2).Error("Failed to subscribe to upstream status events")
|
||||
upstreamStatusSub = nil
|
||||
}
|
||||
|
||||
// 如果全部失败,直接返回
|
||||
if keyStatusSub == nil && upstreamStatusSub == nil {
|
||||
s.logger.Error("All event subscriptions failed, listener disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 安全关闭订阅
|
||||
defer func() {
|
||||
if keyStatusSub != nil {
|
||||
if err := keyStatusSub.Close(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to close key status subscription")
|
||||
}
|
||||
}
|
||||
if upstreamStatusSub != nil {
|
||||
if err := upstreamStatusSub.Close(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to close upstream status subscription")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"key_status_sub": keyStatusSub != nil,
|
||||
"upstream_status_sub": upstreamStatusSub != nil,
|
||||
}).Info("Event listener started")
|
||||
|
||||
neverReady := make(chan *store.Message)
|
||||
close(neverReady) // 立即关闭,确保永远不会阻塞
|
||||
|
||||
for {
|
||||
// 动态选择有效的 channel
|
||||
var keyStatusChan <-chan *store.Message = neverReady
|
||||
if keyStatusSub != nil {
|
||||
keyStatusChan = keyStatusSub.Channel()
|
||||
}
|
||||
|
||||
var upstreamStatusChan <-chan *store.Message = neverReady
|
||||
if upstreamStatusSub != nil {
|
||||
upstreamStatusChan = upstreamStatusSub.Channel()
|
||||
}
|
||||
|
||||
select {
|
||||
case _, ok := <-keyStatusChan:
|
||||
if !ok {
|
||||
s.logger.Warn("Key status channel closed")
|
||||
keyStatusSub = nil
|
||||
continue
|
||||
}
|
||||
s.logger.Debug("Received key status changed event")
|
||||
if err := s.InvalidateOverviewCache(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to invalidate cache on key status change")
|
||||
}
|
||||
|
||||
case _, ok := <-upstreamStatusChan:
|
||||
if !ok {
|
||||
s.logger.Warn("Upstream status channel closed")
|
||||
upstreamStatusSub = nil
|
||||
continue
|
||||
}
|
||||
s.logger.Debug("Received upstream status changed event")
|
||||
if err := s.InvalidateOverviewCache(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to invalidate cache on upstream status change")
|
||||
}
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Event listener stopping (stopChan)")
|
||||
return
|
||||
|
||||
case <-s.ctx.Done():
|
||||
s.logger.Info("Event listener stopping (context cancelled)")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 监控指标 ====================
|
||||
|
||||
// metricsReporter 定期输出统计信息
|
||||
func (s *DashboardQueryService) metricsReporter() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.reportMetrics()
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardQueryService) reportMetrics() {
|
||||
s.lastQueryMutex.RLock()
|
||||
lastQuery := s.lastQueryTime
|
||||
s.lastQueryMutex.RUnlock()
|
||||
|
||||
totalQueries := s.queryCount.Load()
|
||||
hits := s.cacheHits.Load()
|
||||
misses := s.cacheMisses.Load()
|
||||
|
||||
var cacheHitRate float64
|
||||
if hits+misses > 0 {
|
||||
cacheHitRate = float64(hits) / float64(hits+misses) * 100
|
||||
}
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"total_queries": totalQueries,
|
||||
"cache_hits": hits,
|
||||
"cache_misses": misses,
|
||||
"cache_hit_rate": fmt.Sprintf("%.2f%%", cacheHitRate),
|
||||
"overview_loads": s.overviewLoadCount.Load(),
|
||||
"last_query_ago": time.Since(lastQuery).Round(time.Second),
|
||||
}).Info("DashboardQuery metrics")
|
||||
}
|
||||
|
||||
// GetMetrics 返回当前统计指标(供监控使用)
|
||||
func (s *DashboardQueryService) GetMetrics() map[string]interface{} {
|
||||
s.lastQueryMutex.RLock()
|
||||
lastQuery := s.lastQueryTime
|
||||
s.lastQueryMutex.RUnlock()
|
||||
|
||||
hits := s.cacheHits.Load()
|
||||
misses := s.cacheMisses.Load()
|
||||
|
||||
var cacheHitRate float64
|
||||
if hits+misses > 0 {
|
||||
cacheHitRate = float64(hits) / float64(hits+misses) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_queries": s.queryCount.Load(),
|
||||
"cache_hits": hits,
|
||||
"cache_misses": misses,
|
||||
"cache_hit_rate": cacheHitRate,
|
||||
"overview_loads": s.overviewLoadCount.Load(),
|
||||
"last_query_ago": time.Since(lastQuery).Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// calculateFailureRate 计算失败率
|
||||
func (s *DashboardQueryService) calculateFailureRate(total, success int64) float64 {
|
||||
if total == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return float64(total-success) / float64(total) * 100
|
||||
}
|
||||
|
||||
// updateLastQueryTime 更新最后查询时间
|
||||
func (s *DashboardQueryService) updateLastQueryTime() {
|
||||
s.lastQueryMutex.Lock()
|
||||
s.lastQueryTime = time.Now()
|
||||
s.lastQueryMutex.Unlock()
|
||||
}
|
||||
|
||||
// buildTimeFormatSelectClause 根据数据库类型构建时间格式化子句
|
||||
func (s *DashboardQueryService) buildTimeFormatSelectClause() (string, string) {
|
||||
dialect := s.db.Dialector.Name()
|
||||
switch dialect {
|
||||
case "mysql":
|
||||
return "DATE_FORMAT(time, '%Y-%m-%d %H:00:00')", "2006-01-02 15:00:00"
|
||||
case "postgres":
|
||||
return "TO_CHAR(time, 'YYYY-MM-DD HH24:00:00')", "2006-01-02 15:00:00"
|
||||
case "sqlite":
|
||||
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
|
||||
default:
|
||||
s.logger.WithField("dialect", dialect).Warn("Unknown database dialect, using SQLite format")
|
||||
return "strftime('%Y-%m-%d %H:00:00', time)", "2006-01-02 15:00:00"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
// Filename: internal/service/db_log_writer_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
@@ -18,132 +20,324 @@ type DBLogWriterService struct {
|
||||
db *gorm.DB
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
logBuffer chan *models.RequestLog
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
SettingsManager *settings.SettingsManager
|
||||
settingsManager *settings.SettingsManager
|
||||
|
||||
logBuffer chan *models.RequestLog
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// 统计指标
|
||||
totalReceived atomic.Uint64
|
||||
totalFlushed atomic.Uint64
|
||||
totalDropped atomic.Uint64
|
||||
flushCount atomic.Uint64
|
||||
lastFlushTime time.Time
|
||||
lastFlushMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewDBLogWriterService(db *gorm.DB, s store.Store, settings *settings.SettingsManager, logger *logrus.Logger) *DBLogWriterService {
|
||||
cfg := settings.GetSettings()
|
||||
func NewDBLogWriterService(
|
||||
db *gorm.DB,
|
||||
s store.Store,
|
||||
settingsManager *settings.SettingsManager,
|
||||
logger *logrus.Logger,
|
||||
) *DBLogWriterService {
|
||||
cfg := settingsManager.GetSettings()
|
||||
bufferCapacity := cfg.LogBufferCapacity
|
||||
if bufferCapacity <= 0 {
|
||||
bufferCapacity = 1000
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &DBLogWriterService{
|
||||
db: db,
|
||||
store: s,
|
||||
SettingsManager: settings,
|
||||
settingsManager: settingsManager,
|
||||
logger: logger.WithField("component", "DBLogWriter📝"),
|
||||
// 使用配置值来创建缓冲区
|
||||
logBuffer: make(chan *models.RequestLog, bufferCapacity),
|
||||
stopChan: make(chan struct{}),
|
||||
logBuffer: make(chan *models.RequestLog, bufferCapacity),
|
||||
stopChan: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
lastFlushTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) Start() {
|
||||
s.wg.Add(2) // 一个用于事件监听,一个用于数据库写入
|
||||
|
||||
// 启动事件监听器
|
||||
s.wg.Add(2)
|
||||
go s.eventListenerLoop()
|
||||
// 启动数据库写入器
|
||||
go s.dbWriterLoop()
|
||||
|
||||
s.logger.Info("DBLogWriterService started.")
|
||||
// 定期输出统计信息
|
||||
s.wg.Add(1)
|
||||
go s.metricsReporter()
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"buffer_capacity": cap(s.logBuffer),
|
||||
}).Info("DBLogWriterService started")
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) Stop() {
|
||||
s.logger.Info("DBLogWriterService stopping...")
|
||||
close(s.stopChan) // 通知所有goroutine停止
|
||||
s.wg.Wait() // 等待所有goroutine完成
|
||||
s.logger.Info("DBLogWriterService stopped.")
|
||||
close(s.stopChan)
|
||||
s.cancel() // 取消上下文
|
||||
s.wg.Wait()
|
||||
|
||||
// 输出最终统计
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"total_received": s.totalReceived.Load(),
|
||||
"total_flushed": s.totalFlushed.Load(),
|
||||
"total_dropped": s.totalDropped.Load(),
|
||||
"flush_count": s.flushCount.Load(),
|
||||
}).Info("DBLogWriterService stopped")
|
||||
}
|
||||
|
||||
// eventListenerLoop 负责从store接收事件并放入内存缓冲区
|
||||
// 事件监听循环
|
||||
func (s *DBLogWriterService) eventListenerLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
sub, err := s.store.Subscribe(models.TopicRequestFinished)
|
||||
sub, err := s.store.Subscribe(s.ctx, models.TopicRequestFinished)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicRequestFinished, err)
|
||||
s.logger.WithError(err).Error("Failed to subscribe to request events, log writing disabled")
|
||||
return
|
||||
}
|
||||
defer sub.Close()
|
||||
defer func() {
|
||||
if err := sub.Close(); err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to close subscription")
|
||||
}
|
||||
}()
|
||||
|
||||
s.logger.Info("Subscribed to request events for database logging.")
|
||||
s.logger.Info("Subscribed to request events for database logging")
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal event for logging: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 将事件中的日志部分放入缓冲区
|
||||
select {
|
||||
case s.logBuffer <- &event.RequestLog:
|
||||
default:
|
||||
s.logger.Warn("Log buffer is full. A log message might be dropped.")
|
||||
}
|
||||
s.handleMessage(msg)
|
||||
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Event listener loop stopping.")
|
||||
// 关闭缓冲区,以通知dbWriterLoop处理完剩余日志后退出
|
||||
s.logger.Info("Event listener loop stopping")
|
||||
close(s.logBuffer)
|
||||
return
|
||||
|
||||
case <-s.ctx.Done():
|
||||
s.logger.Info("Event listener context cancelled")
|
||||
close(s.logBuffer)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dbWriterLoop 负责从内存缓冲区批量读取日志并写入数据库
|
||||
// 处理单条消息
|
||||
func (s *DBLogWriterService) handleMessage(msg *store.Message) {
|
||||
var event models.RequestFinishedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to unmarshal request event")
|
||||
return
|
||||
}
|
||||
|
||||
s.totalReceived.Add(1)
|
||||
|
||||
select {
|
||||
case s.logBuffer <- &event.RequestLog:
|
||||
// 成功入队
|
||||
default:
|
||||
// 缓冲区满,丢弃日志
|
||||
dropped := s.totalDropped.Add(1)
|
||||
if dropped%100 == 1 { // 每100条丢失输出一次警告
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"total_dropped": dropped,
|
||||
"buffer_capacity": cap(s.logBuffer),
|
||||
"buffer_len": len(s.logBuffer),
|
||||
}).Warn("Log buffer full, messages being dropped")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 数据库写入循环
|
||||
func (s *DBLogWriterService) dbWriterLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// 在启动时获取一次配置
|
||||
cfg := s.SettingsManager.GetSettings()
|
||||
cfg := s.settingsManager.GetSettings()
|
||||
batchSize := cfg.LogFlushBatchSize
|
||||
if batchSize <= 0 {
|
||||
batchSize = 100
|
||||
}
|
||||
|
||||
flushTimeout := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
|
||||
if flushTimeout <= 0 {
|
||||
flushTimeout = 5 * time.Second
|
||||
flushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
|
||||
if flushInterval <= 0 {
|
||||
flushInterval = 5 * time.Second
|
||||
}
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"batch_size": batchSize,
|
||||
"flush_interval": flushInterval,
|
||||
}).Info("DB writer loop started")
|
||||
|
||||
batch := make([]*models.RequestLog, 0, batchSize)
|
||||
ticker := time.NewTicker(flushTimeout)
|
||||
ticker := time.NewTicker(flushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 配置热更新检查(每分钟)
|
||||
configTicker := time.NewTicker(1 * time.Minute)
|
||||
defer configTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case logEntry, ok := <-s.logBuffer:
|
||||
if !ok {
|
||||
// 通道关闭,刷新剩余日志
|
||||
if len(batch) > 0 {
|
||||
s.flushBatch(batch)
|
||||
}
|
||||
s.logger.Info("DB writer loop finished.")
|
||||
s.logger.Info("DB writer loop finished")
|
||||
return
|
||||
}
|
||||
|
||||
batch = append(batch, logEntry)
|
||||
if len(batch) >= batchSize { // 使用配置的批次大小
|
||||
if len(batch) >= batchSize {
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
if len(batch) > 0 {
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
|
||||
case <-configTicker.C:
|
||||
// 热更新配置
|
||||
cfg := s.settingsManager.GetSettings()
|
||||
newBatchSize := cfg.LogFlushBatchSize
|
||||
if newBatchSize <= 0 {
|
||||
newBatchSize = 100
|
||||
}
|
||||
newFlushInterval := time.Duration(cfg.LogFlushIntervalSeconds) * time.Second
|
||||
if newFlushInterval <= 0 {
|
||||
newFlushInterval = 5 * time.Second
|
||||
}
|
||||
|
||||
if newBatchSize != batchSize {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"old": batchSize,
|
||||
"new": newBatchSize,
|
||||
}).Info("Batch size updated")
|
||||
batchSize = newBatchSize
|
||||
if len(batch) >= batchSize {
|
||||
s.flushBatch(batch)
|
||||
batch = make([]*models.RequestLog, 0, batchSize)
|
||||
}
|
||||
}
|
||||
|
||||
if newFlushInterval != flushInterval {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"old": flushInterval,
|
||||
"new": newFlushInterval,
|
||||
}).Info("Flush interval updated")
|
||||
flushInterval = newFlushInterval
|
||||
ticker.Reset(flushInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flushBatch 将一个批次的日志写入数据库
|
||||
// 批量刷写到数据库
|
||||
func (s *DBLogWriterService) flushBatch(batch []*models.RequestLog) {
|
||||
if err := s.db.CreateInBatches(batch, len(batch)).Error; err != nil {
|
||||
s.logger.WithField("batch_size", len(batch)).WithError(err).Error("Failed to flush log batch to database.")
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.db.WithContext(ctx).CreateInBatches(batch, len(batch)).Error
|
||||
duration := time.Since(start)
|
||||
|
||||
s.lastFlushMutex.Lock()
|
||||
s.lastFlushTime = time.Now()
|
||||
s.lastFlushMutex.Unlock()
|
||||
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"batch_size": len(batch),
|
||||
"duration": duration,
|
||||
}).WithError(err).Error("Failed to flush log batch to database")
|
||||
} else {
|
||||
s.logger.Infof("Successfully flushed %d logs to database.", len(batch))
|
||||
flushed := s.totalFlushed.Add(uint64(len(batch)))
|
||||
flushCount := s.flushCount.Add(1)
|
||||
|
||||
// 只在慢写入或大批量时输出日志
|
||||
if duration > 1*time.Second || len(batch) > 500 {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"batch_size": len(batch),
|
||||
"duration": duration,
|
||||
"total_flushed": flushed,
|
||||
"flush_count": flushCount,
|
||||
}).Info("Log batch flushed to database")
|
||||
} else {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"batch_size": len(batch),
|
||||
"duration": duration,
|
||||
}).Debug("Log batch flushed to database")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 定期输出统计信息
|
||||
func (s *DBLogWriterService) metricsReporter() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.reportMetrics()
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DBLogWriterService) reportMetrics() {
|
||||
s.lastFlushMutex.RLock()
|
||||
lastFlush := s.lastFlushTime
|
||||
s.lastFlushMutex.RUnlock()
|
||||
|
||||
received := s.totalReceived.Load()
|
||||
flushed := s.totalFlushed.Load()
|
||||
dropped := s.totalDropped.Load()
|
||||
pending := uint64(len(s.logBuffer))
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"received": received,
|
||||
"flushed": flushed,
|
||||
"dropped": dropped,
|
||||
"pending": pending,
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"last_flush": time.Since(lastFlush).Round(time.Second),
|
||||
"buffer_usage": float64(pending) / float64(cap(s.logBuffer)) * 100,
|
||||
"success_rate": float64(flushed) / float64(received) * 100,
|
||||
}).Info("DBLogWriter metrics")
|
||||
}
|
||||
|
||||
// GetMetrics 返回当前统计指标(供监控使用)
|
||||
func (s *DBLogWriterService) GetMetrics() map[string]interface{} {
|
||||
s.lastFlushMutex.RLock()
|
||||
lastFlush := s.lastFlushTime
|
||||
s.lastFlushMutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_received": s.totalReceived.Load(),
|
||||
"total_flushed": s.totalFlushed.Load(),
|
||||
"total_dropped": s.totalDropped.Load(),
|
||||
"flush_count": s.flushCount.Load(),
|
||||
"buffer_pending": len(s.logBuffer),
|
||||
"buffer_capacity": cap(s.logBuffer),
|
||||
"last_flush_ago": time.Since(lastFlush).Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// Filename: internal/service/group_manager.go (Syncer升级版)
|
||||
|
||||
// Filename: internal/service/group_manager.go
|
||||
package service
|
||||
|
||||
import (
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
"gemini-balancer/internal/pkg/reflectutil"
|
||||
"gemini-balancer/internal/repository"
|
||||
"gemini-balancer/internal/settings"
|
||||
"gemini-balancer/internal/store"
|
||||
"gemini-balancer/internal/syncer"
|
||||
"gemini-balancer/internal/utils"
|
||||
"net/url"
|
||||
@@ -29,10 +29,9 @@ type GroupManagerCacheData struct {
|
||||
Groups []*models.KeyGroup
|
||||
GroupsByName map[string]*models.KeyGroup
|
||||
GroupsByID map[uint]*models.KeyGroup
|
||||
KeyCounts map[uint]int64 // GroupID -> Total Key Count
|
||||
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64 // GroupID -> Status -> Count
|
||||
KeyCounts map[uint]int64
|
||||
KeyStatusCounts map[uint]map[models.APIKeyStatus]int64
|
||||
}
|
||||
|
||||
type GroupManager struct {
|
||||
db *gorm.DB
|
||||
keyRepo repository.KeyRepository
|
||||
@@ -41,7 +40,6 @@ type GroupManager struct {
|
||||
syncer *syncer.CacheSyncer[GroupManagerCacheData]
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
type UpdateOrderPayload struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
Order int `json:"order"`
|
||||
@@ -49,43 +47,19 @@ type UpdateOrderPayload struct {
|
||||
|
||||
func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc[GroupManagerCacheData] {
|
||||
return func() (GroupManagerCacheData, error) {
|
||||
logger.Debugf("[GML-LOG 1/5] ---> Entering NewGroupManagerLoader...")
|
||||
var groups []*models.KeyGroup
|
||||
logger.Debugf("[GML-LOG 2/5] About to execute DB query with Preloads...")
|
||||
|
||||
if err := db.Preload("AllowedUpstreams").
|
||||
Preload("AllowedModels").
|
||||
Preload("Settings").
|
||||
Preload("RequestConfig").
|
||||
Preload("Mappings").
|
||||
Find(&groups).Error; err != nil {
|
||||
logger.Errorf("[GML-LOG] CRITICAL: DB query for groups failed: %v", err)
|
||||
return GroupManagerCacheData{}, fmt.Errorf("failed to load key groups for cache: %w", err)
|
||||
return GroupManagerCacheData{}, fmt.Errorf("failed to load groups: %w", err)
|
||||
}
|
||||
logger.Debugf("[GML-LOG 2.1/5] DB query for groups finished. Found %d group records.", len(groups))
|
||||
|
||||
var allMappings []*models.GroupAPIKeyMapping
|
||||
if err := db.Find(&allMappings).Error; err != nil {
|
||||
logger.Errorf("[GML-LOG] CRITICAL: DB query for mappings failed: %v", err)
|
||||
return GroupManagerCacheData{}, fmt.Errorf("failed to load key mappings for cache: %w", err)
|
||||
}
|
||||
logger.Debugf("[GML-LOG 2.2/5] DB query for mappings finished. Found %d total mapping records.", len(allMappings))
|
||||
|
||||
mappingsByGroupID := make(map[uint][]*models.GroupAPIKeyMapping)
|
||||
for i := range allMappings {
|
||||
mapping := allMappings[i] // Avoid pointer issues with range
|
||||
mappingsByGroupID[mapping.KeyGroupID] = append(mappingsByGroupID[mapping.KeyGroupID], mapping)
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if mappings, ok := mappingsByGroupID[group.ID]; ok {
|
||||
group.Mappings = mappings
|
||||
}
|
||||
}
|
||||
logger.Debugf("[GML-LOG 3/5] Finished manually associating mappings to groups.")
|
||||
|
||||
keyCounts := make(map[uint]int64, len(groups))
|
||||
keyStatusCounts := make(map[uint]map[models.APIKeyStatus]int64, len(groups))
|
||||
|
||||
groupsByName := make(map[string]*models.KeyGroup, len(groups))
|
||||
groupsByID := make(map[uint]*models.KeyGroup, len(groups))
|
||||
for _, group := range groups {
|
||||
keyCounts[group.ID] = int64(len(group.Mappings))
|
||||
statusCounts := make(map[models.APIKeyStatus]int64)
|
||||
@@ -93,20 +67,9 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
|
||||
statusCounts[mapping.Status]++
|
||||
}
|
||||
keyStatusCounts[group.ID] = statusCounts
|
||||
groupsByName[group.Name] = group
|
||||
groupsByID[group.ID] = group
|
||||
}
|
||||
groupsByName := make(map[string]*models.KeyGroup, len(groups))
|
||||
groupsByID := make(map[uint]*models.KeyGroup, len(groups))
|
||||
|
||||
logger.Debugf("[GML-LOG 4/5] Starting to process group records into maps...")
|
||||
for i, group := range groups {
|
||||
if group == nil {
|
||||
logger.Debugf("[GML] CRITICAL: Found a 'nil' group pointer at index %d! This is the most likely cause of the panic.", i)
|
||||
} else {
|
||||
groupsByName[group.Name] = group
|
||||
groupsByID[group.ID] = group
|
||||
}
|
||||
}
|
||||
logger.Debugf("[GML-LOG 5/5] Finished processing records. Building final cache data...")
|
||||
return GroupManagerCacheData{
|
||||
Groups: groups,
|
||||
GroupsByName: groupsByName,
|
||||
@@ -116,7 +79,6 @@ func NewGroupManagerLoader(db *gorm.DB, logger *logrus.Logger) syncer.LoaderFunc
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func NewGroupManager(
|
||||
db *gorm.DB,
|
||||
keyRepo repository.KeyRepository,
|
||||
@@ -134,138 +96,67 @@ func NewGroupManager(
|
||||
logger: logger.WithField("component", "GroupManager"),
|
||||
}
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetAllGroups() []*models.KeyGroup {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.Groups) == 0 {
|
||||
return []*models.KeyGroup{}
|
||||
}
|
||||
groupsToOrder := cache.Groups
|
||||
sort.Slice(groupsToOrder, func(i, j int) bool {
|
||||
if groupsToOrder[i].Order != groupsToOrder[j].Order {
|
||||
return groupsToOrder[i].Order < groupsToOrder[j].Order
|
||||
groups := gm.syncer.Get().Groups
|
||||
sort.Slice(groups, func(i, j int) bool {
|
||||
if groups[i].Order != groups[j].Order {
|
||||
return groups[i].Order < groups[j].Order
|
||||
}
|
||||
return groupsToOrder[i].ID < groupsToOrder[j].ID
|
||||
return groups[i].ID < groups[j].ID
|
||||
})
|
||||
return groupsToOrder
|
||||
return groups
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetKeyCount(groupID uint) int64 {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.KeyCounts) == 0 {
|
||||
return 0
|
||||
}
|
||||
count := cache.KeyCounts[groupID]
|
||||
return count
|
||||
return gm.syncer.Get().KeyCounts[groupID]
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetKeyStatusCount(groupID uint) map[models.APIKeyStatus]int64 {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.KeyStatusCounts) == 0 {
|
||||
return make(map[models.APIKeyStatus]int64)
|
||||
}
|
||||
if counts, ok := cache.KeyStatusCounts[groupID]; ok {
|
||||
if counts, ok := gm.syncer.Get().KeyStatusCounts[groupID]; ok {
|
||||
return counts
|
||||
}
|
||||
return make(map[models.APIKeyStatus]int64)
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetGroupByName(name string) (*models.KeyGroup, bool) {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.GroupsByName) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
group, ok := cache.GroupsByName[name]
|
||||
group, ok := gm.syncer.Get().GroupsByName[name]
|
||||
return group, ok
|
||||
}
|
||||
|
||||
func (gm *GroupManager) GetGroupByID(id uint) (*models.KeyGroup, bool) {
|
||||
cache := gm.syncer.Get()
|
||||
if len(cache.GroupsByID) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
group, ok := cache.GroupsByID[id]
|
||||
group, ok := gm.syncer.Get().GroupsByID[id]
|
||||
return group, ok
|
||||
}
|
||||
|
||||
func (gm *GroupManager) Stop() {
|
||||
gm.syncer.Stop()
|
||||
}
|
||||
|
||||
func (gm *GroupManager) Invalidate() error {
|
||||
return gm.syncer.Invalidate()
|
||||
}
|
||||
|
||||
// --- Write Operations ---
|
||||
|
||||
// CreateKeyGroup creates a new key group, including its operational settings, and invalidates the cache.
|
||||
func (gm *GroupManager) CreateKeyGroup(group *models.KeyGroup, settings *models.KeyGroupSettings) error {
|
||||
if !utils.IsValidGroupName(group.Name) {
|
||||
return errors.New("invalid group name: must contain only lowercase letters, numbers, and hyphens")
|
||||
}
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 1. Create the group itself to get an ID
|
||||
if err := tx.Create(group).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. If settings are provided, create the associated GroupSettings record
|
||||
if settings != nil {
|
||||
// Only marshal non-nil fields to keep the JSON clean
|
||||
settingsToMarshal := make(map[string]interface{})
|
||||
if settings.EnableKeyCheck != nil {
|
||||
settingsToMarshal["enable_key_check"] = settings.EnableKeyCheck
|
||||
settingsJSON, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal settings: %w", err)
|
||||
}
|
||||
if settings.KeyCheckIntervalMinutes != nil {
|
||||
settingsToMarshal["key_check_interval_minutes"] = settings.KeyCheckIntervalMinutes
|
||||
groupSettings := models.GroupSettings{
|
||||
GroupID: group.ID,
|
||||
SettingsJSON: datatypes.JSON(settingsJSON),
|
||||
}
|
||||
if settings.KeyBlacklistThreshold != nil {
|
||||
settingsToMarshal["key_blacklist_threshold"] = settings.KeyBlacklistThreshold
|
||||
}
|
||||
if settings.KeyCooldownMinutes != nil {
|
||||
settingsToMarshal["key_cooldown_minutes"] = settings.KeyCooldownMinutes
|
||||
}
|
||||
if settings.KeyCheckConcurrency != nil {
|
||||
settingsToMarshal["key_check_concurrency"] = settings.KeyCheckConcurrency
|
||||
}
|
||||
if settings.KeyCheckEndpoint != nil {
|
||||
settingsToMarshal["key_check_endpoint"] = settings.KeyCheckEndpoint
|
||||
}
|
||||
if settings.KeyCheckModel != nil {
|
||||
settingsToMarshal["key_check_model"] = settings.KeyCheckModel
|
||||
}
|
||||
if settings.MaxRetries != nil {
|
||||
settingsToMarshal["max_retries"] = settings.MaxRetries
|
||||
}
|
||||
if settings.EnableSmartGateway != nil {
|
||||
settingsToMarshal["enable_smart_gateway"] = settings.EnableSmartGateway
|
||||
}
|
||||
if len(settingsToMarshal) > 0 {
|
||||
settingsJSON, err := json.Marshal(settingsToMarshal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal group settings: %w", err)
|
||||
}
|
||||
groupSettings := models.GroupSettings{
|
||||
GroupID: group.ID,
|
||||
SettingsJSON: datatypes.JSON(settingsJSON),
|
||||
}
|
||||
if err := tx.Create(&groupSettings).Error; err != nil {
|
||||
return fmt.Errorf("failed to save group settings: %w", err)
|
||||
}
|
||||
if err := tx.Create(&groupSettings).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
if err == nil {
|
||||
go gm.Invalidate()
|
||||
}
|
||||
|
||||
go gm.Invalidate()
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateKeyGroup updates an existing key group, its settings, and associations, then invalidates the cache.
|
||||
func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *models.KeyGroupSettings, upstreamURLs []string, modelNames []string) error {
|
||||
if !utils.IsValidGroupName(group.Name) {
|
||||
return fmt.Errorf("invalid group name: must contain only lowercase letters, numbers, and hyphens")
|
||||
@@ -273,7 +164,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
uniqueUpstreamURLs := uniqueStrings(upstreamURLs)
|
||||
uniqueModelNames := uniqueStrings(modelNames)
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
// --- 1. Update AllowedUpstreams (M:N relationship) ---
|
||||
var upstreams []*models.UpstreamEndpoint
|
||||
if len(uniqueUpstreamURLs) > 0 {
|
||||
if err := tx.Where("url IN ?", uniqueUpstreamURLs).Find(&upstreams).Error; err != nil {
|
||||
@@ -283,7 +173,6 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
if err := tx.Model(group).Association("AllowedUpstreams").Replace(upstreams); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(group).Association("AllowedModels").Clear(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -296,11 +185,9 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Model(group).Updates(group).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var existingSettings models.GroupSettings
|
||||
if err := tx.Where("group_id = ?", group.ID).First(&existingSettings).Error; err != nil && err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
@@ -308,15 +195,15 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
var currentSettingsData models.KeyGroupSettings
|
||||
if len(existingSettings.SettingsJSON) > 0 {
|
||||
if err := json.Unmarshal(existingSettings.SettingsJSON, ¤tSettingsData); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal existing group settings: %w", err)
|
||||
return fmt.Errorf("failed to unmarshal existing settings: %w", err)
|
||||
}
|
||||
}
|
||||
if err := reflectutil.MergeNilFields(¤tSettingsData, newSettings); err != nil {
|
||||
return fmt.Errorf("failed to merge group settings: %w", err)
|
||||
return fmt.Errorf("failed to merge settings: %w", err)
|
||||
}
|
||||
updatedJSON, err := json.Marshal(currentSettingsData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal updated group settings: %w", err)
|
||||
return fmt.Errorf("failed to marshal updated settings: %w", err)
|
||||
}
|
||||
existingSettings.GroupID = group.ID
|
||||
existingSettings.SettingsJSON = datatypes.JSON(updatedJSON)
|
||||
@@ -327,55 +214,25 @@ func (gm *GroupManager) UpdateKeyGroup(group *models.KeyGroup, newSettings *mode
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteKeyGroup deletes a key group and subsequently cleans up any keys that have become orphans.
|
||||
func (gm *GroupManager) DeleteKeyGroup(id uint) error {
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
gm.logger.Infof("Starting transaction to delete KeyGroup ID: %d", id)
|
||||
// Step 1: First, retrieve the group object we are about to delete.
|
||||
var group models.KeyGroup
|
||||
if err := tx.First(&group, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
gm.logger.Warnf("Attempted to delete a non-existent KeyGroup with ID: %d", id)
|
||||
return nil // Don't treat as an error, the group is already gone.
|
||||
return nil
|
||||
}
|
||||
gm.logger.WithError(err).Errorf("Failed to find KeyGroup with ID: %d for deletion", id)
|
||||
return err
|
||||
}
|
||||
// Step 2: Clear all many-to-many and one-to-many associations using GORM's safe methods.
|
||||
if err := tx.Model(&group).Association("AllowedUpstreams").Clear(); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedUpstreams' association for KeyGroup ID: %d", id)
|
||||
if err := tx.Select("AllowedUpstreams", "AllowedModels", "Mappings", "Settings").Delete(&group).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&group).Association("AllowedModels").Clear(); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to clear 'AllowedModels' association for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&group).Association("Mappings").Clear(); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to clear 'Mappings' (API Key associations) for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
// Also clear settings if they exist to maintain data integrity
|
||||
if err := tx.Model(&group).Association("Settings").Delete(group.Settings); err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to delete 'Settings' association for KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
// Step 3: Delete the KeyGroup itself.
|
||||
if err := tx.Delete(&group).Error; err != nil {
|
||||
gm.logger.WithError(err).Errorf("Failed to delete KeyGroup ID: %d", id)
|
||||
return err
|
||||
}
|
||||
gm.logger.Infof("KeyGroup ID %d associations cleared and entity deleted. Triggering orphan key cleanup.", id)
|
||||
// Step 4: Trigger the orphan key cleanup (this logic remains the same and is correct).
|
||||
deletedCount, err := gm.keyRepo.DeleteOrphanKeysTx(tx)
|
||||
if err != nil {
|
||||
gm.logger.WithError(err).Error("Failed to clean up orphan keys after deleting group.")
|
||||
return err
|
||||
}
|
||||
if deletedCount > 0 {
|
||||
gm.logger.Infof("Successfully cleaned up %d orphan keys.", deletedCount)
|
||||
gm.logger.Infof("Cleaned up %d orphan keys after deleting group %d", deletedCount, id)
|
||||
}
|
||||
gm.logger.Infof("Transaction for deleting KeyGroup ID: %d completed successfully.", id)
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
@@ -383,7 +240,6 @@ func (gm *GroupManager) DeleteKeyGroup(id uint) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
var originalGroup models.KeyGroup
|
||||
if err := gm.db.
|
||||
@@ -392,7 +248,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
Preload("AllowedUpstreams").
|
||||
Preload("AllowedModels").
|
||||
First(&originalGroup, id).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to find original group with id %d: %w", id, err)
|
||||
return nil, fmt.Errorf("failed to find original group %d: %w", id, err)
|
||||
}
|
||||
newGroup := originalGroup
|
||||
timestamp := time.Now().Unix()
|
||||
@@ -401,31 +257,25 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
newGroup.DisplayName = fmt.Sprintf("%s-clone-%d", originalGroup.DisplayName, timestamp)
|
||||
newGroup.CreatedAt = time.Time{}
|
||||
newGroup.UpdatedAt = time.Time{}
|
||||
|
||||
newGroup.RequestConfigID = nil
|
||||
newGroup.RequestConfig = nil
|
||||
newGroup.Mappings = nil
|
||||
newGroup.AllowedUpstreams = nil
|
||||
newGroup.AllowedModels = nil
|
||||
err := gm.db.Transaction(func(tx *gorm.DB) error {
|
||||
|
||||
if err := tx.Create(&newGroup).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if originalGroup.RequestConfig != nil {
|
||||
newRequestConfig := *originalGroup.RequestConfig
|
||||
newRequestConfig.ID = 0 // Mark as new record
|
||||
|
||||
newRequestConfig.ID = 0
|
||||
if err := tx.Create(&newRequestConfig).Error; err != nil {
|
||||
return fmt.Errorf("failed to clone request config: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Model(&newGroup).Update("request_config_id", newRequestConfig.ID).Error; err != nil {
|
||||
return fmt.Errorf("failed to link new group to cloned request config: %w", err)
|
||||
return fmt.Errorf("failed to link cloned request config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var originalSettings models.GroupSettings
|
||||
err := tx.Where("group_id = ?", originalGroup.ID).First(&originalSettings).Error
|
||||
if err == nil && len(originalSettings.SettingsJSON) > 0 {
|
||||
@@ -434,12 +284,11 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
SettingsJSON: originalSettings.SettingsJSON,
|
||||
}
|
||||
if err := tx.Create(&newSettings).Error; err != nil {
|
||||
return fmt.Errorf("failed to clone group settings: %w", err)
|
||||
return fmt.Errorf("failed to clone settings: %w", err)
|
||||
}
|
||||
} else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("failed to query original group settings: %w", err)
|
||||
return fmt.Errorf("failed to query original settings: %w", err)
|
||||
}
|
||||
|
||||
if len(originalGroup.Mappings) > 0 {
|
||||
newMappings := make([]models.GroupAPIKeyMapping, len(originalGroup.Mappings))
|
||||
for i, oldMapping := range originalGroup.Mappings {
|
||||
@@ -454,7 +303,7 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
}
|
||||
}
|
||||
if err := tx.Create(&newMappings).Error; err != nil {
|
||||
return fmt.Errorf("failed to clone key group mappings: %w", err)
|
||||
return fmt.Errorf("failed to clone mappings: %w", err)
|
||||
}
|
||||
}
|
||||
if len(originalGroup.AllowedUpstreams) > 0 {
|
||||
@@ -469,13 +318,10 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go gm.Invalidate()
|
||||
|
||||
var finalClonedGroup models.KeyGroup
|
||||
if err := gm.db.
|
||||
Preload("RequestConfig").
|
||||
@@ -487,10 +333,8 @@ func (gm *GroupManager) CloneKeyGroup(id uint) (*models.KeyGroup, error) {
|
||||
}
|
||||
return &finalClonedGroup, nil
|
||||
}
|
||||
|
||||
func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.KeyGroupSettings, error) {
|
||||
globalSettings := gm.settingsManager.GetSettings()
|
||||
s := "gemini-1.5-flash" // Per user feedback for default model
|
||||
opConfig := &models.KeyGroupSettings{
|
||||
EnableKeyCheck: &globalSettings.EnableBaseKeyCheck,
|
||||
KeyCheckConcurrency: &globalSettings.BaseKeyCheckConcurrency,
|
||||
@@ -498,52 +342,43 @@ func (gm *GroupManager) BuildOperationalConfig(group *models.KeyGroup) (*models.
|
||||
KeyCheckEndpoint: &globalSettings.DefaultUpstreamURL,
|
||||
KeyBlacklistThreshold: &globalSettings.BlacklistThreshold,
|
||||
KeyCooldownMinutes: &globalSettings.KeyCooldownMinutes,
|
||||
KeyCheckModel: &s,
|
||||
KeyCheckModel: &globalSettings.BaseKeyCheckModel,
|
||||
MaxRetries: &globalSettings.MaxRetries,
|
||||
EnableSmartGateway: &globalSettings.EnableSmartGateway,
|
||||
}
|
||||
|
||||
if group == nil {
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
var groupSettingsRecord models.GroupSettings
|
||||
err := gm.db.Where("group_id = ?", group.ID).First(&groupSettingsRecord).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
gm.logger.WithError(err).Errorf("Failed to query group settings for group ID %d", group.ID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(groupSettingsRecord.SettingsJSON) == 0 {
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
var groupSpecificSettings models.KeyGroupSettings
|
||||
if err := json.Unmarshal(groupSettingsRecord.SettingsJSON, &groupSpecificSettings); err != nil {
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings JSON.")
|
||||
return opConfig, err
|
||||
}
|
||||
|
||||
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group-specific settings over defaults.")
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Warn("Failed to unmarshal group settings")
|
||||
return opConfig, nil
|
||||
}
|
||||
if err := reflectutil.MergeNilFields(opConfig, &groupSpecificSettings); err != nil {
|
||||
gm.logger.WithError(err).WithField("group_id", group.ID).Error("Failed to merge group settings")
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
|
||||
group, ok := gm.GetGroupByID(groupID)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("group with id %d not found", groupID)
|
||||
return "", fmt.Errorf("group %d not found", groupID)
|
||||
}
|
||||
opConfig, err := gm.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build operational config for group %d: %w", groupID, err)
|
||||
return "", err
|
||||
}
|
||||
globalSettings := gm.settingsManager.GetSettings()
|
||||
baseURL := globalSettings.DefaultUpstreamURL
|
||||
@@ -551,7 +386,7 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
|
||||
baseURL = *opConfig.KeyCheckEndpoint
|
||||
}
|
||||
if baseURL == "" {
|
||||
return "", fmt.Errorf("no key check endpoint or default upstream URL is configured for group %d", groupID)
|
||||
return "", fmt.Errorf("no endpoint configured for group %d", groupID)
|
||||
}
|
||||
modelName := globalSettings.BaseKeyCheckModel
|
||||
if opConfig.KeyCheckModel != nil && *opConfig.KeyCheckModel != "" {
|
||||
@@ -559,38 +394,41 @@ func (gm *GroupManager) BuildKeyCheckEndpoint(groupID uint) (string, error) {
|
||||
}
|
||||
parsedURL, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse base URL '%s': %w", baseURL, err)
|
||||
return "", fmt.Errorf("invalid URL '%s': %w", baseURL, err)
|
||||
}
|
||||
cleanedPath := parsedURL.Path
|
||||
cleanedPath = strings.TrimSuffix(cleanedPath, "/")
|
||||
cleanedPath = strings.TrimSuffix(cleanedPath, "/v1beta")
|
||||
parsedURL.Path = path.Join(cleanedPath, "v1beta", "models", modelName)
|
||||
finalEndpoint := parsedURL.String()
|
||||
return finalEndpoint, nil
|
||||
cleanedPath := strings.TrimSuffix(strings.TrimSuffix(parsedURL.Path, "/"), "/v1beta")
|
||||
parsedURL.Path = path.Join(cleanedPath, "v1beta/models", modelName)
|
||||
return parsedURL.String(), nil
|
||||
}
|
||||
|
||||
func (gm *GroupManager) UpdateOrder(payload []UpdateOrderPayload) error {
|
||||
ordersMap := make(map[uint]int, len(payload))
|
||||
for _, item := range payload {
|
||||
ordersMap[item.ID] = item.Order
|
||||
}
|
||||
if err := gm.groupRepo.UpdateOrderInTransaction(ordersMap); err != nil {
|
||||
gm.logger.WithError(err).Error("Failed to update group order in transaction")
|
||||
return fmt.Errorf("database transaction failed: %w", err)
|
||||
return fmt.Errorf("failed to update order: %w", err)
|
||||
}
|
||||
gm.logger.Info("Group order updated successfully, invalidating cache...")
|
||||
go gm.Invalidate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func uniqueStrings(slice []string) []string {
|
||||
keys := make(map[string]struct{})
|
||||
list := []string{}
|
||||
for _, entry := range slice {
|
||||
if _, value := keys[entry]; !value {
|
||||
keys[entry] = struct{}{}
|
||||
list = append(list, entry)
|
||||
seen := make(map[string]struct{}, len(slice))
|
||||
result := make([]string, 0, len(slice))
|
||||
for _, s := range slice {
|
||||
if _, exists := seen[s]; !exists {
|
||||
seen[s] = struct{}{}
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return list
|
||||
return result
|
||||
}
|
||||
|
||||
// GroupManager配置Syncer
|
||||
func NewGroupManagerSyncer(
|
||||
loader syncer.LoaderFunc[GroupManagerCacheData],
|
||||
store store.Store,
|
||||
logger *logrus.Logger,
|
||||
) (*syncer.CacheSyncer[GroupManagerCacheData], error) {
|
||||
const groupUpdateChannel = "groups:cache_invalidation"
|
||||
return syncer.NewCacheSyncer(loader, store, groupUpdateChannel, logger)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -22,6 +23,10 @@ const (
|
||||
TaskTypeHardDeleteKeys = "hard_delete_keys"
|
||||
TaskTypeRestoreKeys = "restore_keys"
|
||||
chunkSize = 500
|
||||
|
||||
// 任务超时时间常量化
|
||||
defaultTaskTimeout = 15 * time.Minute
|
||||
longTaskTimeout = time.Hour
|
||||
)
|
||||
|
||||
type KeyImportService struct {
|
||||
@@ -42,295 +47,425 @@ func NewKeyImportService(ts task.Reporter, kr repository.KeyRepository, s store.
|
||||
}
|
||||
}
|
||||
|
||||
// --- 通用的 Panic-Safe 任務執行器 ---
|
||||
func (s *KeyImportService) runTaskWithRecovery(taskID string, resourceID string, taskFunc func()) {
|
||||
// runTaskWithRecovery 统一的任务恢复包装器
|
||||
func (s *KeyImportService) runTaskWithRecovery(ctx context.Context, taskID string, resourceID string, taskFunc func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err := fmt.Errorf("panic recovered in task %s: %v", taskID, r)
|
||||
s.logger.Error(err)
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, err)
|
||||
s.logger.WithField("task_id", taskID).Error(err)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
}
|
||||
}()
|
||||
taskFunc()
|
||||
}
|
||||
|
||||
// --- Public Task Starters ---
|
||||
|
||||
func (s *KeyImportService) StartAddKeysTask(groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
|
||||
// StartAddKeysTask 启动批量添加密钥任务
|
||||
func (s *KeyImportService) StartAddKeysTask(ctx context.Context, groupID uint, keysText string, validateOnImport bool) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found in input text")
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), 15*time.Minute)
|
||||
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeAddKeysToGroup, resourceID, len(keys), defaultTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runAddKeysTask(taskStatus.ID, resourceID, groupID, keys, validateOnImport)
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runAddKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys, validateOnImport)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartUnlinkKeysTask(groupID uint, keysText string) (*task.Status, error) {
|
||||
// StartUnlinkKeysTask 启动批量解绑密钥任务
|
||||
func (s *KeyImportService) StartUnlinkKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), time.Hour)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, uint(groupID), TaskTypeUnlinkKeysFromGroup, resourceID, len(keys), longTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runUnlinkKeysTask(taskStatus.ID, resourceID, groupID, keys)
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runUnlinkKeysTask(context.Background(), taskStatus.ID, resourceID, groupID, keys)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartHardDeleteKeysTask(keysText string) (*task.Status, error) {
|
||||
// StartHardDeleteKeysTask 启动硬删除密钥任务
|
||||
func (s *KeyImportService) StartHardDeleteKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := "global_hard_delete" // Global lock
|
||||
taskStatus, err := s.taskService.StartTask(0, TaskTypeHardDeleteKeys, resourceID, len(keys), time.Hour)
|
||||
|
||||
resourceID := "global_hard_delete"
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeHardDeleteKeys, resourceID, len(keys), longTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runHardDeleteKeysTask(taskStatus.ID, resourceID, keys)
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runHardDeleteKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyImportService) StartRestoreKeysTask(keysText string) (*task.Status, error) {
|
||||
// StartRestoreKeysTask 启动恢复密钥任务
|
||||
func (s *KeyImportService) StartRestoreKeysTask(ctx context.Context, keysText string) (*task.Status, error) {
|
||||
keys := utils.ParseKeysFromText(keysText)
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found")
|
||||
}
|
||||
resourceID := "global_restore_keys" // Global lock
|
||||
taskStatus, err := s.taskService.StartTask(0, TaskTypeRestoreKeys, resourceID, len(keys), time.Hour)
|
||||
|
||||
resourceID := "global_restore_keys"
|
||||
taskStatus, err := s.taskService.StartTask(ctx, 0, TaskTypeRestoreKeys, resourceID, len(keys), longTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go s.runTaskWithRecovery(taskStatus.ID, resourceID, func() {
|
||||
s.runRestoreKeysTask(taskStatus.ID, resourceID, keys)
|
||||
|
||||
go s.runTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, func() {
|
||||
s.runRestoreKeysTask(context.Background(), taskStatus.ID, resourceID, keys)
|
||||
})
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
// --- Private Task Runners ---
|
||||
// StartUnlinkKeysByFilterTask 根据状态过滤条件批量解绑
|
||||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
|
||||
|
||||
func (s *KeyImportService) runAddKeysTask(taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
// 步骤 1: 对输入的原始 key 列表进行去重。
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeyStrings []string
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeyStrings = append(uniqueKeyStrings, kStr)
|
||||
}
|
||||
}
|
||||
if len(uniqueKeyStrings) == 0 {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"newly_linked_count": 0, "already_linked_count": 0}, nil)
|
||||
return
|
||||
}
|
||||
// 步骤 2: 确保所有 Key 在主表中存在(创建或恢复),并获取它们完整的实体。
|
||||
keysToEnsure := make([]models.APIKey, len(uniqueKeyStrings))
|
||||
for i, keyStr := range uniqueKeyStrings {
|
||||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||||
}
|
||||
allKeyModels, err := s.keyRepo.AddKeys(keysToEnsure)
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return nil, fmt.Errorf("no keys found matching the provided filter")
|
||||
}
|
||||
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
|
||||
|
||||
return s.StartUnlinkKeysTask(ctx, groupID, keysAsText)
|
||||
}
|
||||
|
||||
// ==================== 核心任务执行逻辑 ====================
|
||||
|
||||
// runAddKeysTask 执行批量添加密钥
|
||||
func (s *KeyImportService) runAddKeysTask(ctx context.Context, taskID string, resourceID string, groupID uint, keys []string, validateOnImport bool) {
|
||||
// 1. 去重
|
||||
uniqueKeys := s.deduplicateKeys(keys)
|
||||
if len(uniqueKeys) == 0 {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, gin.H{
|
||||
"newly_linked_count": 0,
|
||||
"already_linked_count": 0,
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
// 步骤 3: 找出在这些 Key 中,哪些【已经】被链接到了当前分组。
|
||||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeyStrings, groupID)
|
||||
|
||||
// 2. 确保所有密钥在数据库中存在(幂等操作)
|
||||
allKeyModels, err := s.ensureKeysExist(uniqueKeys)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to check for already linked keys: %w", err))
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to ensure keys exist: %w", err))
|
||||
return
|
||||
}
|
||||
alreadyLinkedIDSet := make(map[uint]struct{})
|
||||
for _, key := range alreadyLinkedModels {
|
||||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||||
|
||||
// 3. 过滤已关联的密钥
|
||||
keysToLink, alreadyLinkedCount, err := s.filterNewKeys(allKeyModels, groupID, uniqueKeys)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to check linked keys: %w", err))
|
||||
return
|
||||
}
|
||||
// 步骤 4: 确定【真正需要】被链接到当前分组的 key 列表 (我们的"工作集")。
|
||||
var keysToLink []models.APIKey
|
||||
for _, key := range allKeyModels {
|
||||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||||
keysToLink = append(keysToLink, key)
|
||||
}
|
||||
}
|
||||
// 步骤 5: 更新任务的 Total 总量为精确的 "工作集" 大小。
|
||||
if err := s.taskService.UpdateTotalByID(taskID, len(keysToLink)); err != nil {
|
||||
|
||||
// 4. 更新任务的实际处理总数
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(keysToLink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
// 步骤 6: 分块处理【链接Key到组】的操作,并实时更新进度。
|
||||
|
||||
// 5. 批量关联密钥到组
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
}
|
||||
for i := 0; i < len(idsToLink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToLink) {
|
||||
end = len(idsToLink)
|
||||
}
|
||||
chunk := idsToLink[i:end]
|
||||
if err := s.keyRepo.LinkKeysToGroup(groupID, chunk); err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed to link keys: %w", err))
|
||||
return
|
||||
}
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
if err := s.linkKeysInChunks(ctx, taskID, groupID, keysToLink); err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 步骤 7: 准备最终结果并结束任务。
|
||||
// 6. 根据验证标志处理密钥状态
|
||||
if len(keysToLink) > 0 {
|
||||
s.processNewlyLinkedKeys(ctx, groupID, keysToLink, validateOnImport)
|
||||
}
|
||||
|
||||
// 7. 返回结果
|
||||
result := gin.H{
|
||||
"newly_linked_count": len(keysToLink),
|
||||
"already_linked_count": len(alreadyLinkedIDSet),
|
||||
"already_linked_count": alreadyLinkedCount,
|
||||
"total_linked_count": len(allKeyModels),
|
||||
}
|
||||
// 步骤 8: 根据 `validateOnImport` 标志, 发布事件或直接激活 (只对新链接的keys操作)。
|
||||
if len(keysToLink) > 0 {
|
||||
idsToLink := make([]uint, len(keysToLink))
|
||||
for i, key := range keysToLink {
|
||||
idsToLink[i] = key.ID
|
||||
}
|
||||
if validateOnImport {
|
||||
s.publishImportGroupCompletedEvent(groupID, idsToLink)
|
||||
for _, keyID := range idsToLink {
|
||||
s.publishSingleKeyChangeEvent(groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
}
|
||||
} else {
|
||||
for _, keyID := range idsToLink {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(groupID, keyID, models.StatusActive); err != nil {
|
||||
s.logger.Errorf("Failed to directly activate key ID %d in group %d: %v", keyID, groupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
// runUnlinkKeysTask
|
||||
func (s *KeyImportService) runUnlinkKeysTask(taskID, resourceID string, groupID uint, keys []string) {
|
||||
uniqueKeysMap := make(map[string]struct{})
|
||||
var uniqueKeys []string
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeys = append(uniqueKeys, kStr)
|
||||
}
|
||||
}
|
||||
// 步骤 1: 一次性找出所有输入 Key 中,实际存在于本组的 Key 实体。这是我们的"工作集"。
|
||||
// runUnlinkKeysTask 执行批量解绑密钥
|
||||
func (s *KeyImportService) runUnlinkKeysTask(ctx context.Context, taskID, resourceID string, groupID uint, keys []string) {
|
||||
// 1. 去重
|
||||
uniqueKeys := s.deduplicateKeys(keys)
|
||||
|
||||
// 2. 查找需要解绑的密钥
|
||||
keysToUnlink, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, fmt.Errorf("failed to find keys to unlink: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keysToUnlink) == 0 {
|
||||
result := gin.H{"unlinked_count": 0, "hard_deleted_count": 0, "not_found_count": len(uniqueKeys)}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
result := gin.H{
|
||||
"unlinked_count": 0,
|
||||
"hard_deleted_count": 0,
|
||||
"not_found_count": len(uniqueKeys),
|
||||
}
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
return
|
||||
}
|
||||
idsToUnlink := make([]uint, len(keysToUnlink))
|
||||
for i, key := range keysToUnlink {
|
||||
idsToUnlink[i] = key.ID
|
||||
}
|
||||
// 步骤 2: 更新任务的 Total 总量为精确的 "工作集" 大小。
|
||||
if err := s.taskService.UpdateTotalByID(taskID, len(idsToUnlink)); err != nil {
|
||||
|
||||
// 3. 提取密钥 ID
|
||||
idsToUnlink := s.extractKeyIDs(keysToUnlink)
|
||||
|
||||
// 4. 更新任务总数
|
||||
if err := s.taskService.UpdateTotalByID(ctx, taskID, len(idsToUnlink)); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to update total for task %s", taskID)
|
||||
}
|
||||
var totalUnlinked int64
|
||||
// 步骤 3: 分块处理【解绑Key】的操作,并上报进度。
|
||||
for i := 0; i < len(idsToUnlink); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(idsToUnlink) {
|
||||
end = len(idsToUnlink)
|
||||
}
|
||||
chunk := idsToUnlink[i:end]
|
||||
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(groupID, chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("chunk failed: could not unlink keys: %w", err))
|
||||
return
|
||||
}
|
||||
totalUnlinked += unlinked
|
||||
|
||||
for _, keyID := range chunk {
|
||||
s.publishSingleKeyChangeEvent(groupID, keyID, models.StatusActive, "", "key_unlinked")
|
||||
}
|
||||
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
// 5. 批量解绑
|
||||
totalUnlinked, err := s.unlinkKeysInChunks(ctx, taskID, groupID, idsToUnlink)
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. 清理孤立密钥
|
||||
totalDeleted, err := s.keyRepo.DeleteOrphanKeys()
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warn("Orphan key cleanup failed after unlink task.")
|
||||
}
|
||||
|
||||
// 7. 返回结果
|
||||
result := gin.H{
|
||||
"unlinked_count": totalUnlinked,
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": len(uniqueKeys) - int(totalUnlinked),
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runHardDeleteKeysTask(taskID, resourceID string, keys []string) {
|
||||
var totalDeleted int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
chunk := keys[i:end]
|
||||
// runHardDeleteKeysTask 执行硬删除密钥
|
||||
func (s *KeyImportService) runHardDeleteKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
totalDeleted, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
|
||||
return s.keyRepo.HardDeleteByValues(chunk)
|
||||
})
|
||||
|
||||
deleted, err := s.keyRepo.HardDeleteByValues(chunk)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to hard delete chunk: %w", err))
|
||||
return
|
||||
}
|
||||
totalDeleted += deleted
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"hard_deleted_count": totalDeleted,
|
||||
"not_found_count": int64(len(keys)) - totalDeleted,
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(0, "keys_hard_deleted") // Global event
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_hard_deleted")
|
||||
}
|
||||
|
||||
func (s *KeyImportService) runRestoreKeysTask(taskID, resourceID string, keys []string) {
|
||||
var restoredCount int64
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
chunk := keys[i:end]
|
||||
// runRestoreKeysTask 执行恢复密钥
|
||||
func (s *KeyImportService) runRestoreKeysTask(ctx context.Context, taskID, resourceID string, keys []string) {
|
||||
restoredCount, err := s.processKeysInChunks(ctx, taskID, keys, func(chunk []string) (int64, error) {
|
||||
return s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
|
||||
})
|
||||
|
||||
count, err := s.keyRepo.UpdateMasterStatusByValues(chunk, models.MasterStatusActive)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskID, resourceID, nil, fmt.Errorf("failed to restore chunk: %w", err))
|
||||
return
|
||||
}
|
||||
restoredCount += count
|
||||
_ = s.taskService.UpdateProgressByID(taskID, i+len(chunk))
|
||||
if err != nil {
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
result := gin.H{
|
||||
"restored_count": restoredCount,
|
||||
"not_found_count": int64(len(keys)) - restoredCount,
|
||||
}
|
||||
s.taskService.EndTaskByID(taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(0, "keys_bulk_restored") // Global event
|
||||
s.endTaskWithResult(ctx, taskID, resourceID, result, nil)
|
||||
s.publishChangeEvent(ctx, 0, "keys_bulk_restored")
|
||||
}
|
||||
|
||||
func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// deduplicateKeys 去重密钥列表
|
||||
func (s *KeyImportService) deduplicateKeys(keys []string) []string {
|
||||
uniqueKeysMap := make(map[string]struct{}, len(keys))
|
||||
uniqueKeys := make([]string, 0, len(keys))
|
||||
|
||||
for _, kStr := range keys {
|
||||
if _, exists := uniqueKeysMap[kStr]; !exists {
|
||||
uniqueKeysMap[kStr] = struct{}{}
|
||||
uniqueKeys = append(uniqueKeys, kStr)
|
||||
}
|
||||
}
|
||||
return uniqueKeys
|
||||
}
|
||||
|
||||
// ensureKeysExist 确保所有密钥在数据库中存在
|
||||
func (s *KeyImportService) ensureKeysExist(keys []string) ([]models.APIKey, error) {
|
||||
keysToEnsure := make([]models.APIKey, len(keys))
|
||||
for i, keyStr := range keys {
|
||||
keysToEnsure[i] = models.APIKey{APIKey: keyStr}
|
||||
}
|
||||
return s.keyRepo.AddKeys(keysToEnsure)
|
||||
}
|
||||
|
||||
// filterNewKeys 过滤已关联的密钥,返回需要新增的密钥
|
||||
func (s *KeyImportService) filterNewKeys(allKeyModels []models.APIKey, groupID uint, uniqueKeys []string) ([]models.APIKey, int, error) {
|
||||
alreadyLinkedModels, err := s.keyRepo.GetKeysByValuesAndGroupID(uniqueKeys, groupID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
alreadyLinkedIDSet := make(map[uint]struct{}, len(alreadyLinkedModels))
|
||||
for _, key := range alreadyLinkedModels {
|
||||
alreadyLinkedIDSet[key.ID] = struct{}{}
|
||||
}
|
||||
|
||||
keysToLink := make([]models.APIKey, 0, len(allKeyModels)-len(alreadyLinkedIDSet))
|
||||
for _, key := range allKeyModels {
|
||||
if _, exists := alreadyLinkedIDSet[key.ID]; !exists {
|
||||
keysToLink = append(keysToLink, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keysToLink, len(alreadyLinkedIDSet), nil
|
||||
}
|
||||
|
||||
// extractKeyIDs 提取密钥 ID 列表
|
||||
func (s *KeyImportService) extractKeyIDs(keys []models.APIKey) []uint {
|
||||
ids := make([]uint, len(keys))
|
||||
for i, key := range keys {
|
||||
ids[i] = key.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// linkKeysInChunks 分块关联密钥到组
|
||||
func (s *KeyImportService) linkKeysInChunks(ctx context.Context, taskID string, groupID uint, keysToLink []models.APIKey) error {
|
||||
idsToLink := s.extractKeyIDs(keysToLink)
|
||||
|
||||
for i := 0; i < len(idsToLink); i += chunkSize {
|
||||
end := min(i+chunkSize, len(idsToLink))
|
||||
chunk := idsToLink[i:end]
|
||||
|
||||
if err := s.keyRepo.LinkKeysToGroup(ctx, groupID, chunk); err != nil {
|
||||
return fmt.Errorf("chunk failed to link keys: %w", err)
|
||||
}
|
||||
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// unlinkKeysInChunks 分块解绑密钥
|
||||
func (s *KeyImportService) unlinkKeysInChunks(ctx context.Context, taskID string, groupID uint, idsToUnlink []uint) (int64, error) {
|
||||
var totalUnlinked int64
|
||||
|
||||
for i := 0; i < len(idsToUnlink); i += chunkSize {
|
||||
end := min(i+chunkSize, len(idsToUnlink))
|
||||
chunk := idsToUnlink[i:end]
|
||||
|
||||
unlinked, err := s.keyRepo.UnlinkKeysFromGroup(ctx, groupID, chunk)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("chunk failed: could not unlink keys: %w", err)
|
||||
}
|
||||
|
||||
totalUnlinked += unlinked
|
||||
|
||||
// 发布解绑事件
|
||||
for _, keyID := range chunk {
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, models.StatusActive, "", "key_unlinked")
|
||||
}
|
||||
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
|
||||
}
|
||||
|
||||
return totalUnlinked, nil
|
||||
}
|
||||
|
||||
// processKeysInChunks 通用的分块处理密钥逻辑
|
||||
func (s *KeyImportService) processKeysInChunks(
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
keys []string,
|
||||
processFunc func(chunk []string) (int64, error),
|
||||
) (int64, error) {
|
||||
var totalProcessed int64
|
||||
|
||||
for i := 0; i < len(keys); i += chunkSize {
|
||||
end := min(i+chunkSize, len(keys))
|
||||
chunk := keys[i:end]
|
||||
|
||||
count, err := processFunc(chunk)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to process chunk: %w", err)
|
||||
}
|
||||
|
||||
totalProcessed += count
|
||||
_ = s.taskService.UpdateProgressByID(ctx, taskID, end)
|
||||
}
|
||||
|
||||
return totalProcessed, nil
|
||||
}
|
||||
|
||||
// processNewlyLinkedKeys 处理新关联的密钥(验证或直接激活)
|
||||
func (s *KeyImportService) processNewlyLinkedKeys(ctx context.Context, groupID uint, keysToLink []models.APIKey, validateOnImport bool) {
|
||||
idsToLink := s.extractKeyIDs(keysToLink)
|
||||
|
||||
if validateOnImport {
|
||||
// 发布批量导入完成事件,触发验证
|
||||
s.publishImportGroupCompletedEvent(ctx, groupID, idsToLink)
|
||||
|
||||
// 发布单个密钥状态变更事件
|
||||
for _, keyID := range idsToLink {
|
||||
s.publishSingleKeyChangeEvent(ctx, groupID, keyID, "", models.StatusPendingValidation, "key_linked")
|
||||
}
|
||||
} else {
|
||||
// 直接激活密钥,不进行验证
|
||||
for _, keyID := range idsToLink {
|
||||
if _, err := s.apiKeyService.UpdateMappingStatus(ctx, groupID, keyID, models.StatusActive); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
}).Errorf("Failed to directly activate key: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// endTaskWithResult 统一的任务结束处理
|
||||
func (s *KeyImportService) endTaskWithResult(ctx context.Context, taskID, resourceID string, result gin.H, err error) {
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"task_id": taskID,
|
||||
"resource_id": resourceID,
|
||||
}).WithError(err).Error("Task failed")
|
||||
}
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, result, err)
|
||||
}
|
||||
|
||||
// ==================== 事件发布方法 ====================
|
||||
|
||||
// publishSingleKeyChangeEvent 发布单个密钥状态变更事件
|
||||
func (s *KeyImportService) publishSingleKeyChangeEvent(ctx context.Context, groupID, keyID uint, oldStatus, newStatus models.APIKeyStatus, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
KeyID: keyID,
|
||||
@@ -339,59 +474,88 @@ func (s *KeyImportService) publishSingleKeyChangeEvent(groupID, keyID uint, oldS
|
||||
ChangeReason: reason,
|
||||
ChangedAt: time.Now(),
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if err := s.store.Publish(models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithError(err).WithFields(logrus.Fields{
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
"reason": reason,
|
||||
}).Error("Failed to publish single key change event.")
|
||||
}).WithError(err).Error("Failed to marshal key change event")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_id": keyID,
|
||||
"reason": reason,
|
||||
}).WithError(err).Error("Failed to publish single key change event")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyImportService) publishChangeEvent(groupID uint, reason string) {
|
||||
// publishChangeEvent 发布通用变更事件
|
||||
func (s *KeyImportService) publishChangeEvent(ctx context.Context, groupID uint, reason string) {
|
||||
event := models.KeyStatusChangedEvent{
|
||||
GroupID: groupID,
|
||||
ChangeReason: reason,
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
_ = s.store.Publish(models.TopicKeyStatusChanged, eventData)
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"reason": reason,
|
||||
}).WithError(err).Error("Failed to marshal change event")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicKeyStatusChanged, eventData); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"reason": reason,
|
||||
}).WithError(err).Error("Failed to publish change event")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *KeyImportService) publishImportGroupCompletedEvent(groupID uint, keyIDs []uint) {
|
||||
// publishImportGroupCompletedEvent 发布批量导入完成事件
|
||||
func (s *KeyImportService) publishImportGroupCompletedEvent(ctx context.Context, groupID uint, keyIDs []uint) {
|
||||
if len(keyIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
event := models.ImportGroupCompletedEvent{
|
||||
GroupID: groupID,
|
||||
KeyIDs: keyIDs,
|
||||
CompletedAt: time.Now(),
|
||||
}
|
||||
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to marshal ImportGroupCompletedEvent.")
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_count": len(keyIDs),
|
||||
}).WithError(err).Error("Failed to marshal ImportGroupCompletedEvent")
|
||||
return
|
||||
}
|
||||
if err := s.store.Publish(models.TopicImportGroupCompleted, eventData); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to publish ImportGroupCompletedEvent.")
|
||||
|
||||
if err := s.store.Publish(ctx, models.TopicImportGroupCompleted, eventData); err != nil {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_count": len(keyIDs),
|
||||
}).WithError(err).Error("Failed to publish ImportGroupCompletedEvent")
|
||||
} else {
|
||||
s.logger.Infof("Published ImportGroupCompletedEvent for group %d with %d keys.", groupID, len(keyIDs))
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"key_count": len(keyIDs),
|
||||
}).Info("Published ImportGroupCompletedEvent")
|
||||
}
|
||||
}
|
||||
|
||||
// [NEW] StartUnlinkKeysByFilterTask starts a task to unlink keys matching a status filter.
|
||||
func (s *KeyImportService) StartUnlinkKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting unlink task for group %d with status filter: %v", groupID, statuses)
|
||||
// 1. [New] Find the keys to operate on.
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find keys by filter: %w", err)
|
||||
// min 返回两个整数中的较小值
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return nil, fmt.Errorf("no keys found matching the provided filter")
|
||||
}
|
||||
// 2. [REUSE] Convert to text and call the existing, robust unlink task logic.
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to unlink for group %d.", len(keyValues), groupID)
|
||||
return s.StartUnlinkKeysTask(groupID, keysAsText)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/channel"
|
||||
@@ -24,26 +25,38 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TaskTypeTestKeys = "test_keys"
|
||||
TaskTypeTestKeys = "test_keys"
|
||||
defaultConcurrency = 10
|
||||
maxValidationConcurrency = 100
|
||||
validationTaskTimeout = time.Hour
|
||||
)
|
||||
|
||||
type KeyValidationService struct {
|
||||
taskService task.Reporter
|
||||
channel channel.ChannelProxy
|
||||
db *gorm.DB
|
||||
SettingsManager *settings.SettingsManager
|
||||
settingsManager *settings.SettingsManager
|
||||
groupManager *GroupManager
|
||||
store store.Store
|
||||
keyRepo repository.KeyRepository
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm.DB, ss *settings.SettingsManager, gm *GroupManager, st store.Store, kr repository.KeyRepository, logger *logrus.Logger) *KeyValidationService {
|
||||
func NewKeyValidationService(
|
||||
ts task.Reporter,
|
||||
ch channel.ChannelProxy,
|
||||
db *gorm.DB,
|
||||
ss *settings.SettingsManager,
|
||||
gm *GroupManager,
|
||||
st store.Store,
|
||||
kr repository.KeyRepository,
|
||||
logger *logrus.Logger,
|
||||
) *KeyValidationService {
|
||||
return &KeyValidationService{
|
||||
taskService: ts,
|
||||
channel: ch,
|
||||
db: db,
|
||||
SettingsManager: ss,
|
||||
settingsManager: ss,
|
||||
groupManager: gm,
|
||||
store: st,
|
||||
keyRepo: kr,
|
||||
@@ -51,53 +64,54 @@ func NewKeyValidationService(ts task.Reporter, ch channel.ChannelProxy, db *gorm
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 公开接口 ====================
|
||||
|
||||
// ValidateSingleKey 验证单个密钥
|
||||
func (s *KeyValidationService) ValidateSingleKey(key *models.APIKey, timeout time.Duration, endpoint string) error {
|
||||
// 1. 解密密钥
|
||||
if err := s.keyRepo.Decrypt(key); err != nil {
|
||||
return fmt.Errorf("failed to decrypt key %d for validation: %w", key.ID, err)
|
||||
}
|
||||
|
||||
// 2. 创建 HTTP 客户端和请求
|
||||
client := &http.Client{Timeout: timeout}
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to create request for key validation (ID: %d): %v", key.ID, err)
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"key_id": key.ID,
|
||||
"endpoint": endpoint,
|
||||
}).Error("Failed to create validation request")
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
s.channel.ModifyRequest(req, key) // Use the injected channel to modify the request
|
||||
// 3. 修改请求(添加密钥认证头)
|
||||
s.channel.ModifyRequest(req, key)
|
||||
|
||||
// 4. 执行请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// This is a network-level error (e.g., timeout, DNS issue)
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 5. 检查响应状态
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil // Success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the body for more error details
|
||||
bodyBytes, readErr := io.ReadAll(resp.Body)
|
||||
var errorMsg string
|
||||
if readErr != nil {
|
||||
errorMsg = "Failed to read error response body"
|
||||
} else {
|
||||
errorMsg = string(bodyBytes)
|
||||
}
|
||||
|
||||
// This is a validation failure with a specific HTTP status code
|
||||
return &CustomErrors.APIError{
|
||||
HTTPStatus: resp.StatusCode,
|
||||
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
|
||||
Code: "VALIDATION_FAILED",
|
||||
}
|
||||
// 6. 处理错误响应
|
||||
return s.buildValidationError(resp)
|
||||
}
|
||||
|
||||
// --- 异步任务方法 (全面适配新task包) ---
|
||||
func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string) (*task.Status, error) {
|
||||
// StartTestKeysTask 启动批量密钥测试任务
|
||||
func (s *KeyValidationService) StartTestKeysTask(ctx context.Context, groupID uint, keysText string) (*task.Status, error) {
|
||||
// 1. 解析和验证输入
|
||||
keyStrings := utils.ParseKeysFromText(keysText)
|
||||
if len(keyStrings) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrBadRequest, "no valid keys found in input text")
|
||||
}
|
||||
|
||||
// 2. 查询密钥模型
|
||||
apiKeyModels, err := s.keyRepo.GetKeysByValues(keyStrings)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
@@ -105,116 +119,345 @@ func (s *KeyValidationService) StartTestKeysTask(groupID uint, keysText string)
|
||||
if len(apiKeyModels) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrResourceNotFound, "none of the provided keys were found in the system")
|
||||
}
|
||||
|
||||
// 3. 批量解密密钥
|
||||
if err := s.keyRepo.DecryptBatch(apiKeyModels); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task.")
|
||||
s.logger.WithError(err).Error("Failed to batch decrypt keys for validation task")
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, "Failed to decrypt keys for validation")
|
||||
}
|
||||
|
||||
// 4. 获取组配置
|
||||
group, ok := s.groupManager.GetGroupByID(groupID)
|
||||
if !ok {
|
||||
// [FIX] Correctly use the NewAPIError constructor for a missing group.
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrGroupNotFound, fmt.Sprintf("group with id %d not found", groupID))
|
||||
}
|
||||
|
||||
opConfig, err := s.groupManager.BuildOperationalConfig(group)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build operational config: %v", err))
|
||||
}
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), time.Hour)
|
||||
if err != nil {
|
||||
return nil, err // Pass up the error from task service (e.g., "task already running")
|
||||
}
|
||||
settings := s.SettingsManager.GetSettings()
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
|
||||
// 5. 构建验证端点
|
||||
endpoint, err := s.groupManager.BuildKeyCheckEndpoint(groupID)
|
||||
if err != nil {
|
||||
s.taskService.EndTaskByID(taskStatus.ID, resourceID, nil, err) // End task with error if endpoint fails
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrInternalServer, fmt.Sprintf("failed to build endpoint: %v", err))
|
||||
}
|
||||
|
||||
// 6. 创建任务
|
||||
resourceID := fmt.Sprintf("group-%d", groupID)
|
||||
taskStatus, err := s.taskService.StartTask(ctx, groupID, TaskTypeTestKeys, resourceID, len(apiKeyModels), validationTaskTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var concurrency int
|
||||
if opConfig.KeyCheckConcurrency != nil {
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
} else {
|
||||
concurrency = settings.BaseKeyCheckConcurrency
|
||||
}
|
||||
go s.runTestKeysTask(taskStatus.ID, resourceID, groupID, apiKeyModels, timeout, endpoint, concurrency)
|
||||
|
||||
// 7. 准备任务参数
|
||||
params := s.buildValidationParams(opConfig)
|
||||
|
||||
// 8. 启动异步验证任务
|
||||
go s.runTestKeysTaskWithRecovery(context.Background(), taskStatus.ID, resourceID, groupID, apiKeyModels, params, endpoint)
|
||||
|
||||
return taskStatus, nil
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) runTestKeysTask(taskID string, resourceID string, groupID uint, keys []models.APIKey, timeout time.Duration, endpoint string, concurrency int) {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
finalResults := make([]models.KeyTestResult, len(keys))
|
||||
processedCount := 0
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10
|
||||
}
|
||||
type job struct {
|
||||
Index int
|
||||
Value models.APIKey
|
||||
}
|
||||
jobs := make(chan job, len(keys))
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := range jobs {
|
||||
apiKeyModel := j.Value
|
||||
keyToValidate := apiKeyModel
|
||||
validationErr := s.ValidateSingleKey(&keyToValidate, timeout, endpoint)
|
||||
// StartTestKeysByFilterTask 根据状态过滤启动批量测试任务
|
||||
func (s *KeyValidationService) StartTestKeysByFilterTask(ctx context.Context, groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"group_id": groupID,
|
||||
"statuses": statuses,
|
||||
}).Info("Starting test task with status filter")
|
||||
|
||||
var currentResult models.KeyTestResult
|
||||
event := models.RequestFinishedEvent{
|
||||
RequestLog: models.RequestLog{
|
||||
// GroupID 和 KeyID 在 RequestLog 模型中是指针,需要取地址
|
||||
GroupID: &groupID,
|
||||
KeyID: &apiKeyModel.ID,
|
||||
},
|
||||
}
|
||||
if validationErr == nil {
|
||||
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "valid", Message: "Validation successful."}
|
||||
event.RequestLog.IsSuccess = true
|
||||
} else {
|
||||
var apiErr *CustomErrors.APIError
|
||||
if CustomErrors.As(validationErr, &apiErr) {
|
||||
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "invalid", Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message)}
|
||||
event.Error = apiErr
|
||||
} else {
|
||||
currentResult = models.KeyTestResult{Key: apiKeyModel.APIKey, Status: "error", Message: "Validation check failed: " + validationErr.Error()}
|
||||
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
|
||||
}
|
||||
event.RequestLog.IsSuccess = false
|
||||
}
|
||||
eventData, _ := json.Marshal(event)
|
||||
if err := s.store.Publish(models.TopicRequestFinished, eventData); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to publish RequestFinishedEvent for validation of key ID %d", apiKeyModel.ID)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
finalResults[j.Index] = currentResult
|
||||
processedCount++
|
||||
_ = s.taskService.UpdateProgressByID(taskID, processedCount)
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
for i, k := range keys {
|
||||
jobs <- job{Index: i, Value: k}
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
s.taskService.EndTaskByID(taskID, resourceID, gin.H{"results": finalResults}, nil)
|
||||
}
|
||||
|
||||
func (s *KeyValidationService) StartTestKeysByFilterTask(groupID uint, statuses []string) (*task.Status, error) {
|
||||
s.logger.Infof("Starting test task for group %d with status filter: %v", groupID, statuses)
|
||||
// 1. 根据过滤条件查询密钥
|
||||
keyValues, err := s.keyRepo.FindKeyValuesByStatus(groupID, statuses)
|
||||
if err != nil {
|
||||
return nil, CustomErrors.ParseDBError(err)
|
||||
}
|
||||
|
||||
if len(keyValues) == 0 {
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria.")
|
||||
return nil, CustomErrors.NewAPIError(CustomErrors.ErrNotFound, "No keys found matching the filter criteria")
|
||||
}
|
||||
|
||||
// 2. 转换为文本格式并启动任务
|
||||
keysAsText := strings.Join(keyValues, "\n")
|
||||
s.logger.Infof("Found %d keys to validate for group %d.", len(keyValues), groupID)
|
||||
return s.StartTestKeysTask(groupID, keysAsText)
|
||||
s.logger.Infof("Found %d keys to validate for group %d", len(keyValues), groupID)
|
||||
|
||||
return s.StartTestKeysTask(ctx, groupID, keysAsText)
|
||||
}
|
||||
|
||||
// ==================== 核心任务执行逻辑 ====================
|
||||
|
||||
// validationParams 验证参数封装
|
||||
type validationParams struct {
|
||||
timeout time.Duration
|
||||
concurrency int
|
||||
}
|
||||
|
||||
// buildValidationParams 构建验证参数
|
||||
func (s *KeyValidationService) buildValidationParams(opConfig *models.KeyGroupSettings) validationParams {
|
||||
settings := s.settingsManager.GetSettings()
|
||||
// 从配置读取超时时间(而非硬编码)
|
||||
timeout := time.Duration(settings.KeyCheckTimeoutSeconds) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second // 仅在配置无效时使用默认值
|
||||
}
|
||||
// 从配置读取并发数(优先级:组配置 > 全局配置 > 兜底默认值)
|
||||
var concurrency int
|
||||
if opConfig.KeyCheckConcurrency != nil && *opConfig.KeyCheckConcurrency > 0 {
|
||||
concurrency = *opConfig.KeyCheckConcurrency
|
||||
} else if settings.BaseKeyCheckConcurrency > 0 {
|
||||
concurrency = settings.BaseKeyCheckConcurrency
|
||||
} else {
|
||||
concurrency = defaultConcurrency // 兜底默认值
|
||||
}
|
||||
// 限制最大并发数(防护措施)
|
||||
if concurrency > maxValidationConcurrency {
|
||||
concurrency = maxValidationConcurrency
|
||||
}
|
||||
return validationParams{
|
||||
timeout: timeout,
|
||||
concurrency: concurrency,
|
||||
}
|
||||
}
|
||||
|
||||
// runTestKeysTaskWithRecovery 带恢复机制的任务执行包装器
|
||||
func (s *KeyValidationService) runTestKeysTaskWithRecovery(
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
resourceID string,
|
||||
groupID uint,
|
||||
keys []models.APIKey,
|
||||
params validationParams,
|
||||
endpoint string,
|
||||
) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err := fmt.Errorf("panic recovered in validation task %s: %v", taskID, r)
|
||||
s.logger.WithField("task_id", taskID).Error(err)
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, nil, err)
|
||||
}
|
||||
}()
|
||||
|
||||
s.runTestKeysTask(ctx, taskID, resourceID, groupID, keys, params, endpoint)
|
||||
}
|
||||
|
||||
// runTestKeysTask 执行批量密钥验证任务
|
||||
func (s *KeyValidationService) runTestKeysTask(
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
resourceID string,
|
||||
groupID uint,
|
||||
keys []models.APIKey,
|
||||
params validationParams,
|
||||
endpoint string,
|
||||
) {
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"task_id": taskID,
|
||||
"group_id": groupID,
|
||||
"key_count": len(keys),
|
||||
"concurrency": params.concurrency,
|
||||
"timeout": params.timeout,
|
||||
}).Info("Starting validation task")
|
||||
|
||||
// 1. 初始化结果收集
|
||||
results := make([]models.KeyTestResult, len(keys))
|
||||
|
||||
// 2. 创建任务分发器
|
||||
dispatcher := newValidationDispatcher(
|
||||
keys,
|
||||
params.concurrency,
|
||||
s,
|
||||
ctx,
|
||||
taskID,
|
||||
groupID,
|
||||
endpoint,
|
||||
params.timeout,
|
||||
)
|
||||
|
||||
// 3. 执行并发验证
|
||||
dispatcher.run(results)
|
||||
|
||||
// 4. 完成任务
|
||||
s.taskService.EndTaskByID(ctx, taskID, resourceID, gin.H{"results": results}, nil)
|
||||
|
||||
s.logger.WithFields(logrus.Fields{
|
||||
"task_id": taskID,
|
||||
"group_id": groupID,
|
||||
"processed": len(results),
|
||||
}).Info("Validation task completed")
|
||||
}
|
||||
|
||||
// ==================== 验证调度器 ====================
|
||||
|
||||
// validationJob 验证作业
|
||||
type validationJob struct {
|
||||
index int
|
||||
key models.APIKey
|
||||
}
|
||||
|
||||
// validationDispatcher 验证任务分发器
|
||||
type validationDispatcher struct {
|
||||
keys []models.APIKey
|
||||
concurrency int
|
||||
service *KeyValidationService
|
||||
ctx context.Context
|
||||
taskID string
|
||||
groupID uint
|
||||
endpoint string
|
||||
timeout time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
processedCount int
|
||||
}
|
||||
|
||||
// newValidationDispatcher 创建验证分发器
|
||||
func newValidationDispatcher(
|
||||
keys []models.APIKey,
|
||||
concurrency int,
|
||||
service *KeyValidationService,
|
||||
ctx context.Context,
|
||||
taskID string,
|
||||
groupID uint,
|
||||
endpoint string,
|
||||
timeout time.Duration,
|
||||
) *validationDispatcher {
|
||||
return &validationDispatcher{
|
||||
keys: keys,
|
||||
concurrency: concurrency,
|
||||
service: service,
|
||||
ctx: ctx,
|
||||
taskID: taskID,
|
||||
groupID: groupID,
|
||||
endpoint: endpoint,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// run 执行并发验证
|
||||
func (d *validationDispatcher) run(results []models.KeyTestResult) {
|
||||
var wg sync.WaitGroup
|
||||
jobs := make(chan validationJob, len(d.keys))
|
||||
|
||||
// 启动 worker pool
|
||||
for i := 0; i < d.concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go d.worker(&wg, jobs, results)
|
||||
}
|
||||
|
||||
// 分发任务
|
||||
for i, key := range d.keys {
|
||||
jobs <- validationJob{index: i, key: key}
|
||||
}
|
||||
close(jobs)
|
||||
|
||||
// 等待所有 worker 完成
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// worker 验证工作协程
|
||||
func (d *validationDispatcher) worker(wg *sync.WaitGroup, jobs <-chan validationJob, results []models.KeyTestResult) {
|
||||
defer wg.Done()
|
||||
|
||||
for job := range jobs {
|
||||
result := d.validateKey(job.key)
|
||||
|
||||
d.mu.Lock()
|
||||
results[job.index] = result
|
||||
d.processedCount++
|
||||
_ = d.service.taskService.UpdateProgressByID(d.ctx, d.taskID, d.processedCount)
|
||||
d.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// validateKey 验证单个密钥并返回结果
|
||||
func (d *validationDispatcher) validateKey(key models.APIKey) models.KeyTestResult {
|
||||
// 1. 执行验证
|
||||
validationErr := d.service.ValidateSingleKey(&key, d.timeout, d.endpoint)
|
||||
|
||||
// 2. 构建结果和事件
|
||||
result, event := d.buildResultAndEvent(key, validationErr)
|
||||
|
||||
// 3. 发布验证事件
|
||||
d.publishValidationEvent(key.ID, event)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// buildResultAndEvent 构建验证结果和事件
|
||||
func (d *validationDispatcher) buildResultAndEvent(key models.APIKey, validationErr error) (models.KeyTestResult, models.RequestFinishedEvent) {
|
||||
event := models.RequestFinishedEvent{
|
||||
RequestLog: models.RequestLog{
|
||||
GroupID: &d.groupID,
|
||||
KeyID: &key.ID,
|
||||
},
|
||||
}
|
||||
|
||||
if validationErr == nil {
|
||||
// 验证成功
|
||||
event.RequestLog.IsSuccess = true
|
||||
return models.KeyTestResult{
|
||||
Key: key.APIKey,
|
||||
Status: "valid",
|
||||
Message: "Validation successful",
|
||||
}, event
|
||||
}
|
||||
|
||||
// 验证失败
|
||||
event.RequestLog.IsSuccess = false
|
||||
|
||||
var apiErr *CustomErrors.APIError
|
||||
if CustomErrors.As(validationErr, &apiErr) {
|
||||
event.Error = apiErr
|
||||
return models.KeyTestResult{
|
||||
Key: key.APIKey,
|
||||
Status: "invalid",
|
||||
Message: fmt.Sprintf("Invalid key (HTTP %d): %s", apiErr.HTTPStatus, apiErr.Message),
|
||||
}, event
|
||||
}
|
||||
|
||||
// 其他错误
|
||||
event.Error = &CustomErrors.APIError{Message: validationErr.Error()}
|
||||
return models.KeyTestResult{
|
||||
Key: key.APIKey,
|
||||
Status: "error",
|
||||
Message: "Validation check failed: " + validationErr.Error(),
|
||||
}, event
|
||||
}
|
||||
|
||||
// publishValidationEvent 发布验证事件
|
||||
func (d *validationDispatcher) publishValidationEvent(keyID uint, event models.RequestFinishedEvent) {
|
||||
eventData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
d.service.logger.WithFields(logrus.Fields{
|
||||
"key_id": keyID,
|
||||
"group_id": d.groupID,
|
||||
}).WithError(err).Error("Failed to marshal validation event")
|
||||
return
|
||||
}
|
||||
|
||||
if err := d.service.store.Publish(d.ctx, models.TopicRequestFinished, eventData); err != nil {
|
||||
d.service.logger.WithFields(logrus.Fields{
|
||||
"key_id": keyID,
|
||||
"group_id": d.groupID,
|
||||
}).WithError(err).Error("Failed to publish validation event")
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// buildValidationError 构建验证错误
|
||||
func (s *KeyValidationService) buildValidationError(resp *http.Response) error {
|
||||
bodyBytes, readErr := io.ReadAll(resp.Body)
|
||||
|
||||
var errorMsg string
|
||||
if readErr != nil {
|
||||
errorMsg = "Failed to read error response body"
|
||||
s.logger.WithError(readErr).Warn("Failed to read validation error response")
|
||||
} else {
|
||||
errorMsg = string(bodyBytes)
|
||||
}
|
||||
|
||||
return &CustomErrors.APIError{
|
||||
HTTPStatus: resp.StatusCode,
|
||||
Message: fmt.Sprintf("Validation failed with status %d: %s", resp.StatusCode, errorMsg),
|
||||
Code: "VALIDATION_FAILED",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,77 +2,196 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LogService struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func NewLogService(db *gorm.DB) *LogService {
|
||||
return &LogService{db: db}
|
||||
func NewLogService(db *gorm.DB, logger *logrus.Logger) *LogService {
|
||||
return &LogService{
|
||||
db: db,
|
||||
logger: logger.WithField("component", "LogService"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LogService) Record(log *models.RequestLog) error {
|
||||
return s.db.Create(log).Error
|
||||
func (s *LogService) Record(ctx context.Context, log *models.RequestLog) error {
|
||||
return s.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
func (s *LogService) GetLogs(c *gin.Context) ([]models.RequestLog, int64, error) {
|
||||
// LogQueryParams 解耦 Gin,使用结构体传参
|
||||
type LogQueryParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
ModelName string
|
||||
IsSuccess *bool // 使用指针区分"未设置"和"false"
|
||||
StatusCode *int
|
||||
KeyIDs []string
|
||||
GroupIDs []string
|
||||
Q string
|
||||
ErrorCodes []string
|
||||
StatusCodes []string
|
||||
}
|
||||
|
||||
func (s *LogService) GetLogs(ctx context.Context, params LogQueryParams) ([]models.RequestLog, int64, error) {
|
||||
// 参数校验
|
||||
if params.Page < 1 {
|
||||
params.Page = 1
|
||||
}
|
||||
if params.PageSize < 1 || params.PageSize > 100 {
|
||||
params.PageSize = 20
|
||||
}
|
||||
|
||||
var logs []models.RequestLog
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&models.RequestLog{}).Scopes(s.filtersScope(c))
|
||||
|
||||
// 先计算总数
|
||||
// 构建基础查询
|
||||
query := s.db.WithContext(ctx).Model(&models.RequestLog{})
|
||||
query = s.applyFilters(query, params)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
return nil, 0, fmt.Errorf("failed to count logs: %w", err)
|
||||
}
|
||||
if total == 0 {
|
||||
return []models.RequestLog{}, 0, nil
|
||||
}
|
||||
|
||||
// 再执行分页查询
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
err := query.Order("request_time desc").Limit(pageSize).Offset(offset).Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
offset := (params.Page - 1) * params.PageSize
|
||||
if err := query.Order("request_time DESC").
|
||||
Limit(params.PageSize).
|
||||
Offset(offset).
|
||||
Find(&logs).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query logs: %w", err)
|
||||
}
|
||||
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
func (s *LogService) filtersScope(c *gin.Context) func(db *gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
if modelName := c.Query("model_name"); modelName != "" {
|
||||
db = db.Where("model_name = ?", modelName)
|
||||
}
|
||||
if isSuccessStr := c.Query("is_success"); isSuccessStr != "" {
|
||||
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
|
||||
db = db.Where("is_success = ?", isSuccess)
|
||||
}
|
||||
}
|
||||
if statusCodeStr := c.Query("status_code"); statusCodeStr != "" {
|
||||
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
|
||||
db = db.Where("status_code = ?", statusCode)
|
||||
}
|
||||
}
|
||||
if keyIDStr := c.Query("key_id"); keyIDStr != "" {
|
||||
if keyID, err := strconv.ParseUint(keyIDStr, 10, 64); err == nil {
|
||||
db = db.Where("key_id = ?", keyID)
|
||||
}
|
||||
}
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
if groupID, err := strconv.ParseUint(groupIDStr, 10, 64); err == nil {
|
||||
db = db.Where("group_id = ?", groupID)
|
||||
}
|
||||
}
|
||||
return db
|
||||
func (s *LogService) applyFilters(query *gorm.DB, params LogQueryParams) *gorm.DB {
|
||||
if params.IsSuccess != nil {
|
||||
query = query.Where("is_success = ?", *params.IsSuccess)
|
||||
} else {
|
||||
query = query.Where("is_success = ?", false)
|
||||
}
|
||||
if params.ModelName != "" {
|
||||
query = query.Where("model_name = ?", params.ModelName)
|
||||
}
|
||||
if params.StatusCode != nil {
|
||||
query = query.Where("status_code = ?", *params.StatusCode)
|
||||
}
|
||||
if len(params.KeyIDs) > 0 {
|
||||
query = query.Where("key_id IN (?)", params.KeyIDs)
|
||||
}
|
||||
if len(params.GroupIDs) > 0 {
|
||||
query = query.Where("group_id IN (?)", params.GroupIDs)
|
||||
}
|
||||
hasErrorCodes := len(params.ErrorCodes) > 0
|
||||
hasStatusCodes := len(params.StatusCodes) > 0
|
||||
if hasErrorCodes && hasStatusCodes {
|
||||
query = query.Where(
|
||||
s.db.Where("error_code IN (?)", params.ErrorCodes).
|
||||
Or("status_code IN (?)", params.StatusCodes),
|
||||
)
|
||||
} else if hasErrorCodes {
|
||||
query = query.Where("error_code IN (?)", params.ErrorCodes)
|
||||
} else if hasStatusCodes {
|
||||
query = query.Where("status_code IN (?)", params.StatusCodes)
|
||||
}
|
||||
if params.Q != "" {
|
||||
searchQuery := "%" + params.Q + "%"
|
||||
query = query.Where(
|
||||
"model_name LIKE ? OR error_code LIKE ? OR error_message LIKE ? OR CAST(status_code AS CHAR) LIKE ?",
|
||||
searchQuery, searchQuery, searchQuery, searchQuery,
|
||||
)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// ParseLogQueryParams 在 Handler 层调用,解析 Gin 参数
|
||||
func ParseLogQueryParams(queryParams map[string]string) (LogQueryParams, error) {
|
||||
params := LogQueryParams{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
|
||||
if pageStr, ok := queryParams["page"]; ok {
|
||||
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
|
||||
params.Page = page
|
||||
}
|
||||
}
|
||||
|
||||
if pageSizeStr, ok := queryParams["page_size"]; ok {
|
||||
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
|
||||
params.PageSize = pageSize
|
||||
}
|
||||
}
|
||||
|
||||
if modelName, ok := queryParams["model_name"]; ok {
|
||||
params.ModelName = modelName
|
||||
}
|
||||
|
||||
if isSuccessStr, ok := queryParams["is_success"]; ok {
|
||||
if isSuccess, err := strconv.ParseBool(isSuccessStr); err == nil {
|
||||
params.IsSuccess = &isSuccess
|
||||
} else {
|
||||
return params, fmt.Errorf("invalid is_success parameter: %s", isSuccessStr)
|
||||
}
|
||||
}
|
||||
|
||||
if statusCodeStr, ok := queryParams["status_code"]; ok {
|
||||
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil {
|
||||
params.StatusCode = &statusCode
|
||||
} else {
|
||||
return params, fmt.Errorf("invalid status_code parameter: %s", statusCodeStr)
|
||||
}
|
||||
}
|
||||
|
||||
if keyIDsStr, ok := queryParams["key_ids"]; ok {
|
||||
params.KeyIDs = strings.Split(keyIDsStr, ",")
|
||||
}
|
||||
|
||||
if groupIDsStr, ok := queryParams["group_ids"]; ok {
|
||||
params.GroupIDs = strings.Split(groupIDsStr, ",")
|
||||
}
|
||||
|
||||
if errorCodesStr, ok := queryParams["error_codes"]; ok {
|
||||
params.ErrorCodes = strings.Split(errorCodesStr, ",")
|
||||
}
|
||||
if statusCodesStr, ok := queryParams["status_codes"]; ok {
|
||||
params.StatusCodes = strings.Split(statusCodesStr, ",")
|
||||
}
|
||||
if q, ok := queryParams["q"]; ok {
|
||||
params.Q = q
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// DeleteLogs 删除指定ID的日志
|
||||
func (s *LogService) DeleteLogs(ctx context.Context, ids []uint) error {
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("no log IDs provided")
|
||||
}
|
||||
return s.db.WithContext(ctx).Delete(&models.RequestLog{}, ids).Error
|
||||
}
|
||||
|
||||
// DeleteAllLogs 删除所有日志
|
||||
func (s *LogService) DeleteAllLogs(ctx context.Context) error {
|
||||
return s.db.WithContext(ctx).Where("1 = 1").Delete(&models.RequestLog{}).Error
|
||||
}
|
||||
|
||||
// DeleteOldLogs 删除指定天数之前的日志
|
||||
func (s *LogService) DeleteOldLogs(ctx context.Context, days int) (int64, error) {
|
||||
if days <= 0 {
|
||||
return 0, fmt.Errorf("days must be positive")
|
||||
}
|
||||
result := s.db.WithContext(ctx).
|
||||
Where("request_time < DATE_SUB(NOW(), INTERVAL ? DAY)", days).
|
||||
Delete(&models.RequestLog{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// Filename: internal/service/resource_service.go
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"gemini-balancer/internal/domain/proxy"
|
||||
apperrors "gemini-balancer/internal/errors"
|
||||
"gemini-balancer/internal/models"
|
||||
"gemini-balancer/internal/repository"
|
||||
@@ -15,10 +15,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoResourceAvailable = errors.New("no available resource found for the request")
|
||||
)
|
||||
|
||||
// RequestResources 封装了一次成功请求所需的所有资源。
|
||||
type RequestResources struct {
|
||||
KeyGroup *models.KeyGroup
|
||||
APIKey *models.APIKey
|
||||
@@ -27,86 +24,92 @@ type RequestResources struct {
|
||||
RequestConfig *models.RequestConfig
|
||||
}
|
||||
|
||||
// ResourceService 负责根据请求参数和业务规则,动态地选择和分配API密钥及相关资源。
|
||||
type ResourceService struct {
|
||||
settingsManager *settings.SettingsManager
|
||||
groupManager *GroupManager
|
||||
keyRepo repository.KeyRepository
|
||||
authTokenRepo repository.AuthTokenRepository
|
||||
apiKeyService *APIKeyService
|
||||
proxyManager *proxy.Module
|
||||
logger *logrus.Entry
|
||||
initOnce sync.Once
|
||||
}
|
||||
|
||||
// NewResourceService 创建并初始化一个新的 ResourceService 实例。
|
||||
func NewResourceService(
|
||||
sm *settings.SettingsManager,
|
||||
gm *GroupManager,
|
||||
kr repository.KeyRepository,
|
||||
atr repository.AuthTokenRepository,
|
||||
aks *APIKeyService,
|
||||
pm *proxy.Module,
|
||||
logger *logrus.Logger,
|
||||
) *ResourceService {
|
||||
logger.Debugf("[FORENSIC PROBE | INJECTION | ResourceService] Received 'keyRepo' param. Fingerprint: %p", kr)
|
||||
rs := &ResourceService{
|
||||
settingsManager: sm,
|
||||
groupManager: gm,
|
||||
keyRepo: kr,
|
||||
authTokenRepo: atr,
|
||||
apiKeyService: aks,
|
||||
proxyManager: pm,
|
||||
logger: logger.WithField("component", "ResourceService📦️"),
|
||||
}
|
||||
|
||||
// 使用 sync.Once 确保预热任务在服务生命周期内仅执行一次
|
||||
rs.initOnce.Do(func() {
|
||||
go rs.preWarmCache(logger)
|
||||
go rs.preWarmCache()
|
||||
})
|
||||
return rs
|
||||
|
||||
}
|
||||
|
||||
// --- [模式一:智能聚合模式] ---
|
||||
func (s *ResourceService) GetResourceFromBasePool(authToken *models.AuthToken, modelName string) (*RequestResources, error) {
|
||||
// GetResourceFromBasePool 使用智能聚合池模式获取资源。
|
||||
func (s *ResourceService) GetResourceFromBasePool(ctx context.Context, authToken *models.AuthToken, modelName string) (*RequestResources, error) {
|
||||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "model_name": modelName, "mode": "BasePool"})
|
||||
log.Debug("Entering BasePool resource acquisition.")
|
||||
// 1.筛选出所有符合条件的候选组,并按优先级排序
|
||||
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken.AllowedGroups)
|
||||
|
||||
candidateGroups := s.filterAndSortCandidateGroups(modelName, authToken)
|
||||
if len(candidateGroups) == 0 {
|
||||
log.Warn("No candidate groups found for BasePool construction.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
}
|
||||
// 2.从 BasePool中,根据系统全局策略选择一个Key
|
||||
|
||||
basePool := &repository.BasePool{
|
||||
CandidateGroups: candidateGroups,
|
||||
PollingStrategy: s.settingsManager.GetSettings().PollingStrategy,
|
||||
}
|
||||
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(basePool)
|
||||
|
||||
apiKey, selectedGroup, err := s.keyRepo.SelectOneActiveKeyFromBasePool(ctx, basePool)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to select a key from the BasePool.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
}
|
||||
// 3. 组装最终资源
|
||||
// [关键] 在此模式下,RequestConfig 永远是空的,以保证透明性。
|
||||
|
||||
resources, err := s.assembleRequestResources(selectedGroup, apiKey)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to assemble resources after selecting key from BasePool.")
|
||||
return nil, err
|
||||
}
|
||||
resources.RequestConfig = &models.RequestConfig{} // 强制为空
|
||||
resources.RequestConfig = &models.RequestConfig{} // BasePool 模式使用默认请求配置
|
||||
|
||||
log.Infof("Successfully selected KeyID %d from GroupID %d for the BasePool.", apiKey.ID, selectedGroup.ID)
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
// --- [模式二:精确路由模式] ---
|
||||
func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, groupName string) (*RequestResources, error) {
|
||||
// GetResourceFromGroup 使用精确路由模式(指定密钥组)获取资源。
|
||||
func (s *ResourceService) GetResourceFromGroup(ctx context.Context, authToken *models.AuthToken, groupName string) (*RequestResources, error) {
|
||||
log := s.logger.WithFields(logrus.Fields{"token_id": authToken.ID, "group_name": groupName, "mode": "PreciseRoute"})
|
||||
log.Debug("Entering PreciseRoute resource acquisition.")
|
||||
|
||||
targetGroup, ok := s.groupManager.GetGroupByName(groupName)
|
||||
|
||||
if !ok {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrGroupNotFound, "The specified group does not exist.")
|
||||
}
|
||||
|
||||
if !s.isTokenAllowedForGroup(authToken, targetGroup.ID) {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrPermissionDenied, "Token does not have permission to access this group.")
|
||||
}
|
||||
|
||||
apiKey, _, err := s.keyRepo.SelectOneActiveKey(targetGroup)
|
||||
apiKey, _, err := s.keyRepo.SelectOneActiveKey(ctx, targetGroup)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to select a key from the precisely targeted group.")
|
||||
return nil, apperrors.ErrNoKeysAvailable
|
||||
@@ -117,39 +120,39 @@ func (s *ResourceService) GetResourceFromGroup(authToken *models.AuthToken, grou
|
||||
log.WithError(err).Error("Failed to assemble resources for precise route.")
|
||||
return nil, err
|
||||
}
|
||||
resources.RequestConfig = targetGroup.RequestConfig
|
||||
resources.RequestConfig = targetGroup.RequestConfig // 精确路由使用该组的特定请求配置
|
||||
|
||||
log.Infof("Successfully selected KeyID %d by precise routing to GroupID %d.", apiKey.ID, targetGroup.ID)
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
// GetAllowedModelsForToken 获取指定认证令牌有权访问的所有模型名称列表。
|
||||
func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken) []string {
|
||||
allGroups := s.groupManager.GetAllGroups()
|
||||
if len(allGroups) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
allowedModelsSet := make(map[string]struct{})
|
||||
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
if authToken.IsAdmin {
|
||||
for _, group := range allGroups {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
allowedGroupIDs[group.ID] = true
|
||||
}
|
||||
} else {
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
for _, ag := range authToken.AllowedGroups {
|
||||
allowedGroupIDs[ag.ID] = true
|
||||
}
|
||||
for _, group := range allGroups {
|
||||
if _, ok := allowedGroupIDs[group.ID]; ok {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
}
|
||||
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
allowedModelsSet := make(map[string]struct{})
|
||||
for _, group := range allGroups {
|
||||
if allowedGroupIDs[group.ID] {
|
||||
for _, modelMapping := range group.AllowedModels {
|
||||
allowedModelsSet[modelMapping.ModelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(allowedModelsSet))
|
||||
for modelName := range allowedModelsSet {
|
||||
result = append(result, modelName)
|
||||
@@ -158,20 +161,52 @@ func (s *ResourceService) GetAllowedModelsForToken(authToken *models.AuthToken)
|
||||
return result
|
||||
}
|
||||
|
||||
// ReportRequestResult 向 APIKeyService 报告请求的最终结果,以便更新密钥状态。
|
||||
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
|
||||
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
|
||||
return
|
||||
}
|
||||
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
|
||||
}
|
||||
|
||||
// --- 私有辅助方法 ---
|
||||
|
||||
// preWarmCache 在后台执行一次性的缓存预热任务。
|
||||
func (s *ResourceService) preWarmCache() {
|
||||
time.Sleep(2 * time.Second) // 等待其他服务组件可能完成初始化
|
||||
s.logger.Info("Performing initial key cache pre-warming...")
|
||||
|
||||
// 强制加载 GroupManager 缓存
|
||||
s.logger.Info("Pre-warming GroupManager cache...")
|
||||
_ = s.groupManager.GetAllGroups()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // 给予更长的超时
|
||||
defer cancel()
|
||||
|
||||
if err := s.keyRepo.LoadAllKeysToStore(ctx); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
|
||||
} else {
|
||||
s.logger.Info("Initial key cache pre-warming completed successfully.")
|
||||
}
|
||||
}
|
||||
|
||||
// assembleRequestResources 根据密钥组和API密钥组装最终的资源对象。
|
||||
func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKey *models.APIKey) (*RequestResources, error) {
|
||||
selectedUpstream := s.selectUpstreamForGroup(group)
|
||||
if selectedUpstream == nil {
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrConfigurationError, "Selected group has no valid upstream and no global default is set.")
|
||||
}
|
||||
var proxyConfig *models.ProxyConfig
|
||||
// [注意] 代理逻辑需要一个 proxyModule 实例,我们暂时置空。后续需要重新注入依赖。
|
||||
// if group.EnableProxy && s.proxyModule != nil {
|
||||
// var err error
|
||||
// proxyConfig, err = s.proxyModule.AssignProxyIfNeeded(apiKey)
|
||||
// if err != nil {
|
||||
// s.logger.WithError(err).Warnf("Failed to assign proxy for API key %d.", apiKey.ID)
|
||||
// }
|
||||
// }
|
||||
var err error
|
||||
// 只有在组明确启用代理时,才为其分配代理
|
||||
if group.EnableProxy {
|
||||
proxyConfig, err = s.proxyManager.AssignProxyIfNeeded(apiKey)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Group '%s' (ID: %d) requires a proxy, but failed to assign one for KeyID %d", group.Name, group.ID, apiKey.ID)
|
||||
// 根据业务需求,这里必须返回错误,因为代理是该组的强制要求
|
||||
return nil, apperrors.NewAPIError(apperrors.ErrProxyNotAvailable, "Required proxy is not available for this request.")
|
||||
}
|
||||
}
|
||||
return &RequestResources{
|
||||
KeyGroup: group,
|
||||
APIKey: apiKey,
|
||||
@@ -180,8 +215,10 @@ func (s *ResourceService) assembleRequestResources(group *models.KeyGroup, apiKe
|
||||
}, nil
|
||||
}
|
||||
|
||||
// selectUpstreamForGroup 为指定的密钥组选择一个上游端点。
|
||||
func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models.UpstreamEndpoint {
|
||||
if len(group.AllowedUpstreams) > 0 {
|
||||
// (未来可扩展负载均衡逻辑)
|
||||
return group.AllowedUpstreams[0]
|
||||
}
|
||||
globalSettings := s.settingsManager.GetSettings()
|
||||
@@ -191,62 +228,39 @@ func (s *ResourceService) selectUpstreamForGroup(group *models.KeyGroup) *models
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) preWarmCache(logger *logrus.Logger) error {
|
||||
time.Sleep(2 * time.Second)
|
||||
s.logger.Info("Performing initial key cache pre-warming...")
|
||||
if err := s.keyRepo.LoadAllKeysToStore(); err != nil {
|
||||
logger.WithError(err).Error("Failed to perform initial key cache pre-warming.")
|
||||
return err
|
||||
}
|
||||
s.logger.Info("Initial key cache pre-warming completed successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ResourceService) GetResourcesForRequest(modelName string, allowedGroups []*models.KeyGroup) (*RequestResources, error) {
|
||||
return nil, errors.New("GetResourcesForRequest is deprecated; use GetResourceFromBasePool or GetResourceFromGroup")
|
||||
}
|
||||
|
||||
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, allowedGroupsFromToken []*models.KeyGroup) []*models.KeyGroup {
|
||||
// filterAndSortCandidateGroups 根据模型名称和令牌权限,筛选并排序出合格的候选密钥组。
|
||||
func (s *ResourceService) filterAndSortCandidateGroups(modelName string, authToken *models.AuthToken) []*models.KeyGroup {
|
||||
allGroupsFromCache := s.groupManager.GetAllGroups()
|
||||
var candidateGroups []*models.KeyGroup
|
||||
// 1. 确定权限范围
|
||||
allowedGroupIDs := make(map[uint]bool)
|
||||
isTokenRestricted := len(allowedGroupsFromToken) > 0
|
||||
if isTokenRestricted {
|
||||
for _, ag := range allowedGroupsFromToken {
|
||||
allowedGroupIDs[ag.ID] = true
|
||||
}
|
||||
}
|
||||
// 2. 筛选
|
||||
|
||||
for _, group := range allGroupsFromCache {
|
||||
// 检查Token权限
|
||||
if isTokenRestricted && !allowedGroupIDs[group.ID] {
|
||||
// 检查令牌权限
|
||||
if !s.isTokenAllowedForGroup(authToken, group.ID) {
|
||||
continue
|
||||
}
|
||||
// 检查模型是否被允许
|
||||
isModelAllowed := false
|
||||
if len(group.AllowedModels) == 0 { // 如果组不限制模型,则允许
|
||||
isModelAllowed = true
|
||||
} else {
|
||||
for _, m := range group.AllowedModels {
|
||||
if m.ModelName == modelName {
|
||||
isModelAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if isModelAllowed {
|
||||
// 检查模型支持情况 (如果组内未限制模型,则默认支持所有模型)
|
||||
if len(group.AllowedModels) == 0 || s.groupSupportsModel(group, modelName) {
|
||||
candidateGroups = append(candidateGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
// 3.按 Order 字段升序排序
|
||||
sort.SliceStable(candidateGroups, func(i, j int) bool {
|
||||
return candidateGroups[i].Order < candidateGroups[j].Order
|
||||
})
|
||||
return candidateGroups
|
||||
}
|
||||
|
||||
// groupSupportsModel 检查指定的密钥组是否支持给定的模型名称。
|
||||
func (s *ResourceService) groupSupportsModel(group *models.KeyGroup, modelName string) bool {
|
||||
for _, m := range group.AllowedModels {
|
||||
if m.ModelName == modelName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTokenAllowedForGroup 检查指定的认证令牌是否有权访问给定的密钥组。
|
||||
func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, groupID uint) bool {
|
||||
if authToken.IsAdmin {
|
||||
return true
|
||||
@@ -258,10 +272,3 @@ func (s *ResourceService) isTokenAllowedForGroup(authToken *models.AuthToken, gr
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *ResourceService) ReportRequestResult(resources *RequestResources, success bool, apiErr *apperrors.APIError) {
|
||||
if resources == nil || resources.KeyGroup == nil || resources.APIKey == nil {
|
||||
return
|
||||
}
|
||||
s.apiKeyService.HandleRequestResult(resources.KeyGroup, resources.APIKey, success, apiErr)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func (s *SecurityService) AuthenticateToken(tokenValue string) (*models.AuthToke
|
||||
// IsIPBanned
|
||||
func (s *SecurityService) IsIPBanned(ctx context.Context, ip string) (bool, error) {
|
||||
banKey := fmt.Sprintf("banned_ip:%s", ip)
|
||||
return s.store.Exists(banKey)
|
||||
return s.store.Exists(ctx, banKey)
|
||||
}
|
||||
|
||||
// RecordFailedLoginAttempt
|
||||
@@ -61,7 +61,7 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
|
||||
return nil
|
||||
}
|
||||
|
||||
count, err := s.store.HIncrBy(loginAttemptsKey, ip, 1)
|
||||
count, err := s.store.HIncrBy(ctx, loginAttemptsKey, ip, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -71,12 +71,12 @@ func (s *SecurityService) RecordFailedLoginAttempt(ctx context.Context, ip strin
|
||||
banDuration := s.SettingsManager.GetIPBanDuration()
|
||||
banKey := fmt.Sprintf("banned_ip:%s", ip)
|
||||
|
||||
if err := s.store.Set(banKey, []byte("1"), banDuration); err != nil {
|
||||
if err := s.store.Set(ctx, banKey, []byte("1"), banDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Warnf("IP BANNED: IP [%s] has been banned for %v due to excessive failed login attempts.", ip, banDuration)
|
||||
|
||||
s.store.HDel(loginAttemptsKey, ip)
|
||||
s.store.HDel(ctx, loginAttemptsKey, ip)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/models"
|
||||
@@ -34,75 +35,121 @@ func NewStatsService(db *gorm.DB, s store.Store, repo repository.KeyRepository,
|
||||
|
||||
func (s *StatsService) Start() {
|
||||
s.logger.Info("Starting event listener for stats maintenance.")
|
||||
sub, err := s.store.Subscribe(models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("Failed to subscribe to topic %s: %v", models.TopicKeyStatusChanged, err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer sub.Close()
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.KeyStatusChangedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal KeyStatusChangedEvent: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleKeyStatusChange(&event)
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping stats event listener.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go s.listenForEvents()
|
||||
}
|
||||
|
||||
func (s *StatsService) Stop() {
|
||||
close(s.stopChan)
|
||||
}
|
||||
|
||||
func (s *StatsService) listenForEvents() {
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
s.logger.Info("Stopping stats event listener.")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
sub, err := s.store.Subscribe(ctx, models.TopicKeyStatusChanged)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to subscribe: %v, retrying in 5s", err)
|
||||
cancel()
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info("Subscribed to key status changes")
|
||||
s.handleSubscription(sub, cancel)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatsService) handleSubscription(sub store.Subscription, cancel context.CancelFunc) {
|
||||
defer sub.Close()
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-sub.Channel():
|
||||
var event models.KeyStatusChangedEvent
|
||||
if err := json.Unmarshal(msg.Payload, &event); err != nil {
|
||||
s.logger.Errorf("Failed to unmarshal event: %v", err)
|
||||
continue
|
||||
}
|
||||
s.handleKeyStatusChange(&event)
|
||||
case <-s.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatsService) handleKeyStatusChange(event *models.KeyStatusChangedEvent) {
|
||||
if event.GroupID == 0 {
|
||||
s.logger.Warnf("Received KeyStatusChangedEvent with no GroupID. Reason: %s, KeyID: %d. Skipping.", event.ChangeReason, event.KeyID)
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
statsKey := fmt.Sprintf("stats:group:%d", event.GroupID)
|
||||
s.logger.Infof("Handling key status change for Group %d, KeyID: %d, Reason: %s", event.GroupID, event.KeyID, event.ChangeReason)
|
||||
|
||||
switch event.ChangeReason {
|
||||
case "key_unlinked", "key_hard_deleted":
|
||||
if event.OldStatus != "" {
|
||||
s.store.HIncrBy(statsKey, "total_keys", -1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", -1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to decrement total_keys for group %d", event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
s.logger.Warnf("Received '%s' event for group %d without OldStatus, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
case "key_linked":
|
||||
if event.NewStatus != "" {
|
||||
s.store.HIncrBy(statsKey, "total_keys", 1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, "total_keys", 1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to increment total_keys for group %d", event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
s.logger.Warnf("Received 'key_linked' event for group %d without NewStatus, forcing recalculation.", event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
case "manual_update", "error_threshold_reached", "key_recovered", "invalid_api_key":
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1)
|
||||
s.store.HIncrBy(statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1)
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.OldStatus), -1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to decrement %s_keys for group %d", event.OldStatus, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
if _, err := s.store.HIncrBy(ctx, statsKey, fmt.Sprintf("%s_keys", event.NewStatus), 1); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to increment %s_keys for group %d", event.NewStatus, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
return
|
||||
}
|
||||
default:
|
||||
s.logger.Warnf("Unhandled event reason '%s' for group %d, forcing recalculation.", event.ChangeReason, event.GroupID)
|
||||
s.RecalculateGroupKeyStats(event.GroupID)
|
||||
s.RecalculateGroupKeyStats(ctx, event.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
|
||||
func (s *StatsService) RecalculateGroupKeyStats(ctx context.Context, groupID uint) error {
|
||||
s.logger.Warnf("Performing full recalculation for group %d key stats.", groupID)
|
||||
var results []struct {
|
||||
Status models.APIKeyStatus
|
||||
Count int64
|
||||
}
|
||||
if err := s.db.Model(&models.GroupAPIKeyMapping{}).
|
||||
if err := s.db.WithContext(ctx).Model(&models.GroupAPIKeyMapping{}).
|
||||
Where("key_group_id = ?", groupID).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
@@ -111,45 +158,36 @@ func (s *StatsService) RecalculateGroupKeyStats(groupID uint) error {
|
||||
}
|
||||
statsKey := fmt.Sprintf("stats:group:%d", groupID)
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
totalKeys := int64(0)
|
||||
updates := map[string]interface{}{
|
||||
"active_keys": int64(0),
|
||||
"disabled_keys": int64(0),
|
||||
"error_keys": int64(0),
|
||||
"total_keys": int64(0),
|
||||
}
|
||||
for _, res := range results {
|
||||
updates[fmt.Sprintf("%s_keys", res.Status)] = res.Count
|
||||
totalKeys += res.Count
|
||||
updates["total_keys"] = updates["total_keys"].(int64) + res.Count
|
||||
}
|
||||
updates["total_keys"] = totalKeys
|
||||
|
||||
if err := s.store.Del(statsKey); err != nil {
|
||||
if err := s.store.Del(ctx, statsKey); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to delete stale stats key for group %d before recalculation.", groupID)
|
||||
}
|
||||
if err := s.store.HSet(statsKey, updates); err != nil {
|
||||
if err := s.store.HSet(ctx, statsKey, updates); err != nil {
|
||||
return fmt.Errorf("failed to HSet recalculated stats for group %d: %w", groupID, err)
|
||||
}
|
||||
s.logger.Infof("Successfully recalculated stats for group %d using HSet.", groupID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StatsService) GetDashboardStats() (*models.DashboardStatsResponse, error) {
|
||||
// TODO 逻辑:
|
||||
// 1. 从Redis中获取所有分组的Key统计 (HGetAll)
|
||||
// 2. 从 stats_hourly 表中获取过去24小时的请求数和错误率
|
||||
// 3. 组合成 DashboardStatsResponse
|
||||
// ... 这个方法的具体实现,我们可以在DashboardQueryService中完成,
|
||||
// 这里我们先确保StatsService的核心职责(维护缓存)已经完成。
|
||||
// 为了编译通过,我们先返回一个空对象。
|
||||
|
||||
// 伪代码:
|
||||
// keyCounts, _ := s.store.HGetAll("stats:global:keys")
|
||||
// ...
|
||||
|
||||
func (s *StatsService) GetDashboardStats(ctx context.Context) (*models.DashboardStatsResponse, error) {
|
||||
return &models.DashboardStatsResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *StatsService) AggregateHourlyStats() error {
|
||||
func (s *StatsService) AggregateHourlyStats(ctx context.Context) error {
|
||||
s.logger.Info("Starting aggregation of the last hour's request data...")
|
||||
now := time.Now()
|
||||
endTime := now.Truncate(time.Hour) // 例如:15:23 -> 15:00
|
||||
startTime := endTime.Add(-1 * time.Hour) // 15:00 -> 14:00
|
||||
endTime := now.Truncate(time.Hour)
|
||||
startTime := endTime.Add(-1 * time.Hour)
|
||||
|
||||
s.logger.Infof("Aggregating data for time window: [%s, %s)", startTime.Format(time.RFC3339), endTime.Format(time.RFC3339))
|
||||
type aggregationResult struct {
|
||||
@@ -161,7 +199,8 @@ func (s *StatsService) AggregateHourlyStats() error {
|
||||
CompletionTokens int64
|
||||
}
|
||||
var results []aggregationResult
|
||||
err := s.db.Model(&models.RequestLog{}).
|
||||
|
||||
err := s.db.WithContext(ctx).Model(&models.RequestLog{}).
|
||||
Select("group_id, model_name, COUNT(*) as request_count, SUM(CASE WHEN is_success = true THEN 1 ELSE 0 END) as success_count, SUM(prompt_tokens) as prompt_tokens, SUM(completion_tokens) as completion_tokens").
|
||||
Where("request_time >= ? AND request_time < ?", startTime, endTime).
|
||||
Group("group_id, model_name").
|
||||
@@ -179,7 +218,7 @@ func (s *StatsService) AggregateHourlyStats() error {
|
||||
var hourlyStats []models.StatsHourly
|
||||
for _, res := range results {
|
||||
hourlyStats = append(hourlyStats, models.StatsHourly{
|
||||
Time: startTime, // 所有记录的时间戳都是该小时的起点
|
||||
Time: startTime,
|
||||
GroupID: res.GroupID,
|
||||
ModelName: res.ModelName,
|
||||
RequestCount: res.RequestCount,
|
||||
@@ -189,8 +228,18 @@ func (s *StatsService) AggregateHourlyStats() error {
|
||||
})
|
||||
}
|
||||
|
||||
return s.db.Clauses(clause.OnConflict{
|
||||
if err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "time"}, {Name: "group_id"}, {Name: "model_name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"request_count", "success_count", "prompt_tokens", "completion_tokens"}),
|
||||
}).Create(&hourlyStats).Error
|
||||
}).Create(&hourlyStats).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.WithContext(ctx).
|
||||
Where("request_time >= ? AND request_time < ?", startTime, endTime).
|
||||
Delete(&models.RequestLog{}).Error; err != nil {
|
||||
s.logger.WithError(err).Warn("Failed to delete aggregated request logs")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func NewTokenManager(repo repository.AuthTokenRepository, store store.Store, log
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged)
|
||||
s, err := syncer.NewCacheSyncer(tokenLoader, store, TopicTokenChanged, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token manager syncer: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// file: gemini-balancer\internal\settings\settings.go
|
||||
// Filename: gemini-balancer/internal/settings/settings.go (最终审计修复版)
|
||||
package settings
|
||||
|
||||
import (
|
||||
@@ -19,7 +19,9 @@ import (
|
||||
const SettingsUpdateChannel = "system_settings:updated"
|
||||
const DefaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
// SettingsManager [核心修正] syncer现在缓存正确的“蓝图”类型
|
||||
var _ models.SettingsManager = (*SettingsManager)(nil)
|
||||
|
||||
// SettingsManager 负责管理系统的动态设置,包括从数据库加载、缓存同步和更新。
|
||||
type SettingsManager struct {
|
||||
db *gorm.DB
|
||||
syncer *syncer.CacheSyncer[*models.SystemSettings]
|
||||
@@ -27,13 +29,14 @@ type SettingsManager struct {
|
||||
jsonToFieldType map[string]reflect.Type // 用于将JSON字段映射到Go类型
|
||||
}
|
||||
|
||||
// NewSettingsManager 创建一个新的 SettingsManager 实例。
|
||||
func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (*SettingsManager, error) {
|
||||
sm := &SettingsManager{
|
||||
db: db,
|
||||
logger: logger.WithField("component", "SettingsManager⚙️"),
|
||||
jsonToFieldType: make(map[string]reflect.Type),
|
||||
}
|
||||
// settingsLoader 的职责:读取“砖块”,组装并返回“蓝图”
|
||||
|
||||
settingsType := reflect.TypeOf(models.SystemSettings{})
|
||||
for i := 0; i < settingsType.NumField(); i++ {
|
||||
field := settingsType.Field(i)
|
||||
@@ -42,102 +45,89 @@ func NewSettingsManager(db *gorm.DB, store store.Store, logger *logrus.Logger) (
|
||||
sm.jsonToFieldType[jsonTag] = field.Type
|
||||
}
|
||||
}
|
||||
// settingsLoader 的职责:读取“砖块”,智能组装成“蓝图”
|
||||
|
||||
settingsLoader := func() (*models.SystemSettings, error) {
|
||||
sm.logger.Info("Loading system settings from database...")
|
||||
var dbRecords []models.Setting
|
||||
if err := sm.db.Find(&dbRecords).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to load system settings from db: %w", err)
|
||||
}
|
||||
|
||||
settingsMap := make(map[string]string)
|
||||
for _, record := range dbRecords {
|
||||
settingsMap[record.Key] = record.Value
|
||||
}
|
||||
// 从一个包含了所有“出厂设置”的“蓝图”开始
|
||||
|
||||
settings := defaultSystemSettings()
|
||||
v := reflect.ValueOf(settings).Elem()
|
||||
t := v.Type()
|
||||
// [智能卸货]
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Type().Field(i)
|
||||
fieldValue := v.Field(i)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if dbValue, ok := settingsMap[jsonTag]; ok {
|
||||
|
||||
if dbValue, ok := settingsMap[jsonTag]; ok {
|
||||
if err := parseAndSetField(fieldValue, dbValue); err != nil {
|
||||
sm.logger.Warnf("Failed to set config field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
|
||||
sm.logger.Warnf("Failed to set field '%s' from DB value '%s': %v. Using default.", field.Name, dbValue, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "" {
|
||||
if settings.DefaultUpstreamURL != "" {
|
||||
// 如果全局上游URL已设置,则基于它构建新的检查端点。
|
||||
originalEndpoint := settings.BaseKeyCheckEndpoint
|
||||
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
|
||||
settings.BaseKeyCheckEndpoint = derivedEndpoint
|
||||
sm.logger.Infof(
|
||||
"BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL. Original: '%s', New: '%s'",
|
||||
originalEndpoint, derivedEndpoint,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// [评估确认] 派生逻辑与原始版本在功能和日志行为上完全一致。
|
||||
if (settings.BaseKeyCheckEndpoint == DefaultGeminiEndpoint || settings.BaseKeyCheckEndpoint == "") && settings.DefaultUpstreamURL != "" {
|
||||
derivedEndpoint := strings.TrimSuffix(settings.DefaultUpstreamURL, "/") + "/models"
|
||||
sm.logger.Infof("BaseKeyCheckEndpoint is dynamically derived from DefaultUpstreamURL: %s", derivedEndpoint)
|
||||
settings.BaseKeyCheckEndpoint = derivedEndpoint
|
||||
} else if settings.BaseKeyCheckEndpoint != DefaultGeminiEndpoint && settings.BaseKeyCheckEndpoint != "" {
|
||||
// 恢复 else 日志,以明确告知用户正在使用自定义覆盖。
|
||||
sm.logger.Infof("BaseKeyCheckEndpoint is using a user-defined override: %s", settings.BaseKeyCheckEndpoint)
|
||||
}
|
||||
|
||||
sm.logger.Info("System settings loaded and cached.")
|
||||
sm.DisplaySettings(settings)
|
||||
return settings, nil
|
||||
}
|
||||
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel)
|
||||
|
||||
s, err := syncer.NewCacheSyncer(settingsLoader, store, SettingsUpdateChannel, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system settings syncer: %w", err)
|
||||
}
|
||||
sm.syncer = s
|
||||
go sm.ensureSettingsInitialized()
|
||||
|
||||
if err := sm.ensureSettingsInitialized(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure system settings are initialized: %w", err)
|
||||
}
|
||||
|
||||
return sm, nil
|
||||
}
|
||||
|
||||
// GetSettings [核心修正] 现在它正确地返回我们需要的“蓝图”
|
||||
// GetSettings 返回当前缓存的系统设置。
|
||||
func (sm *SettingsManager) GetSettings() *models.SystemSettings {
|
||||
return sm.syncer.Get()
|
||||
}
|
||||
|
||||
// UpdateSettings [核心修正] 它接收更新,并将它们转换为“砖块”存入数据库
|
||||
// UpdateSettings 更新一个或多个系统设置。
|
||||
func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) error {
|
||||
var settingsToUpdate []models.Setting
|
||||
|
||||
for key, value := range settingsMap {
|
||||
fieldType, ok := sm.jsonToFieldType[key]
|
||||
if !ok {
|
||||
sm.logger.Warnf("Received update for unknown setting key '%s', ignoring.", key)
|
||||
continue
|
||||
}
|
||||
var dbValue string
|
||||
// [智能打包]
|
||||
// 如果字段是 slice 或 map,我们就将传入的 interface{} “打包”成 JSON string
|
||||
kind := fieldType.Kind()
|
||||
if kind == reflect.Slice || kind == reflect.Map {
|
||||
jsonBytes, marshalErr := json.Marshal(value)
|
||||
if marshalErr != nil {
|
||||
// [真正的错误处理] 如果打包失败,我们记录日志,并跳过这个“坏掉的集装箱”。
|
||||
sm.logger.Warnf("Failed to marshal setting '%s' to JSON: %v, skipping update.", key, marshalErr)
|
||||
continue // 跳过,继续处理下一个key
|
||||
}
|
||||
dbValue = string(jsonBytes)
|
||||
} else if kind == reflect.Bool {
|
||||
if b, ok := value.(bool); ok {
|
||||
dbValue = strconv.FormatBool(b)
|
||||
} else {
|
||||
dbValue = "false"
|
||||
}
|
||||
} else {
|
||||
dbValue = fmt.Sprintf("%v", value)
|
||||
|
||||
dbValue, err := sm.convertToDBValue(key, value, fieldType)
|
||||
if err != nil {
|
||||
sm.logger.Warnf("Failed to convert value for setting '%s': %v. Skipping update.", key, err)
|
||||
continue
|
||||
}
|
||||
|
||||
settingsToUpdate = append(settingsToUpdate, models.Setting{
|
||||
Key: key,
|
||||
Value: dbValue,
|
||||
})
|
||||
}
|
||||
|
||||
if len(settingsToUpdate) > 0 {
|
||||
err := sm.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
@@ -147,83 +137,20 @@ func (sm *SettingsManager) UpdateSettings(settingsMap map[string]interface{}) er
|
||||
return fmt.Errorf("failed to update settings in db: %w", err)
|
||||
}
|
||||
}
|
||||
return sm.syncer.Invalidate()
|
||||
}
|
||||
|
||||
// ensureSettingsInitialized [核心修正] 确保DB中有所有“砖块”的定义
|
||||
func (sm *SettingsManager) ensureSettingsInitialized() {
|
||||
defaults := defaultSystemSettings()
|
||||
v := reflect.ValueOf(defaults).Elem()
|
||||
t := v.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldValue := v.Field(i)
|
||||
key := field.Tag.Get("json")
|
||||
if key == "" || key == "-" {
|
||||
continue
|
||||
}
|
||||
var existing models.Setting
|
||||
if err := sm.db.Where("key = ?", key).First(&existing).Error; err == gorm.ErrRecordNotFound {
|
||||
|
||||
var defaultValue string
|
||||
kind := fieldValue.Kind()
|
||||
// [智能初始化]
|
||||
if kind == reflect.Slice || kind == reflect.Map {
|
||||
// 为复杂类型,生成一个“空的”JSON字符串,例如 "[]" 或 "{}"
|
||||
jsonBytes, _ := json.Marshal(fieldValue.Interface())
|
||||
defaultValue = string(jsonBytes)
|
||||
} else {
|
||||
defaultValue = field.Tag.Get("default")
|
||||
}
|
||||
setting := models.Setting{
|
||||
Key: key,
|
||||
Value: defaultValue,
|
||||
Name: field.Tag.Get("name"),
|
||||
Description: field.Tag.Get("desc"),
|
||||
Category: field.Tag.Get("category"),
|
||||
DefaultValue: field.Tag.Get("default"), // 元数据中的default,永远来自tag
|
||||
}
|
||||
if err := sm.db.Create(&setting).Error; err != nil {
|
||||
sm.logger.Errorf("Failed to initialize setting '%s': %v", key, err)
|
||||
}
|
||||
}
|
||||
if err := sm.syncer.Invalidate(); err != nil {
|
||||
sm.logger.Errorf("CRITICAL: Database settings updated, but cache invalidation failed: %v", err)
|
||||
return fmt.Errorf("settings updated but cache invalidation failed, system may be inconsistent: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetAndSaveSettings [核心新增] 將所有配置重置為其在 'default' 標籤中定義的值。
|
||||
|
||||
// ResetAndSaveSettings 将所有设置重置为其默认值。
|
||||
func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error) {
|
||||
defaults := defaultSystemSettings()
|
||||
v := reflect.ValueOf(defaults).Elem()
|
||||
t := v.Type()
|
||||
var settingsToSave []models.Setting
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldValue := v.Field(i)
|
||||
key := field.Tag.Get("json")
|
||||
if key == "" || key == "-" {
|
||||
continue
|
||||
}
|
||||
settingsToSave := sm.buildSettingsFromDefaults(defaults)
|
||||
|
||||
var defaultValue string
|
||||
kind := fieldValue.Kind()
|
||||
// [智能重置]
|
||||
if kind == reflect.Slice || kind == reflect.Map {
|
||||
jsonBytes, _ := json.Marshal(fieldValue.Interface())
|
||||
defaultValue = string(jsonBytes)
|
||||
} else {
|
||||
defaultValue = field.Tag.Get("default")
|
||||
}
|
||||
setting := models.Setting{
|
||||
Key: key,
|
||||
Value: defaultValue,
|
||||
Name: field.Tag.Get("name"),
|
||||
Description: field.Tag.Get("desc"),
|
||||
Category: field.Tag.Get("category"),
|
||||
DefaultValue: field.Tag.Get("default"),
|
||||
}
|
||||
settingsToSave = append(settingsToSave, setting)
|
||||
}
|
||||
if len(settingsToSave) > 0 {
|
||||
err := sm.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
@@ -233,8 +160,99 @@ func (sm *SettingsManager) ResetAndSaveSettings() (*models.SystemSettings, error
|
||||
return nil, fmt.Errorf("failed to reset settings in db: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := sm.syncer.Invalidate(); err != nil {
|
||||
sm.logger.Errorf("Failed to invalidate settings cache after reset: %v", err)
|
||||
sm.logger.Errorf("CRITICAL: Database settings reset, but cache invalidation failed: %v", err)
|
||||
return nil, fmt.Errorf("settings reset but cache invalidation failed: %w", err)
|
||||
}
|
||||
|
||||
return defaults, nil
|
||||
}
|
||||
|
||||
// --- 私有辅助函数 ---
|
||||
|
||||
func (sm *SettingsManager) ensureSettingsInitialized() error {
|
||||
defaults := defaultSystemSettings()
|
||||
settingsToCreate := sm.buildSettingsFromDefaults(defaults)
|
||||
|
||||
for _, setting := range settingsToCreate {
|
||||
var existing models.Setting
|
||||
err := sm.db.Where("key = ?", setting.Key).First(&existing).Error
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
sm.logger.Infof("Initializing new setting '%s'", setting.Key)
|
||||
if createErr := sm.db.Create(&setting).Error; createErr != nil {
|
||||
return fmt.Errorf("failed to create initial setting '%s': %w", setting.Key, createErr)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to check for existing setting '%s': %w", setting.Key, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SettingsManager) buildSettingsFromDefaults(defaults *models.SystemSettings) []models.Setting {
|
||||
v := reflect.ValueOf(defaults).Elem()
|
||||
t := v.Type()
|
||||
var settings []models.Setting
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldValue := v.Field(i)
|
||||
key := field.Tag.Get("json")
|
||||
|
||||
if key == "" || key == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
var defaultValue string
|
||||
kind := fieldValue.Kind()
|
||||
|
||||
if kind == reflect.Slice || kind == reflect.Map {
|
||||
jsonBytes, _ := json.Marshal(fieldValue.Interface())
|
||||
defaultValue = string(jsonBytes)
|
||||
} else {
|
||||
defaultValue = field.Tag.Get("default")
|
||||
}
|
||||
|
||||
settings = append(settings, models.Setting{
|
||||
Key: key,
|
||||
Value: defaultValue,
|
||||
Name: field.Tag.Get("name"),
|
||||
Description: field.Tag.Get("desc"),
|
||||
Category: field.Tag.Get("category"),
|
||||
DefaultValue: field.Tag.Get("default"),
|
||||
})
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
// [修正] 使用空白标识符 `_` 修复 "unused parameter" 警告。
|
||||
func (sm *SettingsManager) convertToDBValue(_ string, value interface{}, fieldType reflect.Type) (string, error) {
|
||||
kind := fieldType.Kind()
|
||||
|
||||
switch kind {
|
||||
case reflect.Slice, reflect.Map:
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal to JSON: %w", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
|
||||
case reflect.Bool:
|
||||
b, ok := value.(bool)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("expected bool, but got %T", value)
|
||||
}
|
||||
return strconv.FormatBool(b), nil
|
||||
|
||||
default:
|
||||
return fmt.Sprintf("%v", value), nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidKey 检查给定的 JSON key 是否是有效的设置字段
|
||||
func (sm *SettingsManager) IsValidKey(key string) (reflect.Type, bool) {
|
||||
fieldType, ok := sm.jsonToFieldType[key]
|
||||
return fieldType, ok
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Filename: internal/store/factory.go
|
||||
package store
|
||||
|
||||
import (
|
||||
@@ -11,7 +12,6 @@ import (
|
||||
|
||||
// NewStore creates a new store based on the application configuration.
|
||||
func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
|
||||
// 检查是否有Redis配置
|
||||
if cfg.Redis.DSN != "" {
|
||||
opts, err := redis.ParseURL(cfg.Redis.DSN)
|
||||
if err != nil {
|
||||
@@ -20,10 +20,10 @@ func NewStore(cfg *config.Config, logger *logrus.Logger) (Store, error) {
|
||||
client := redis.NewClient(opts)
|
||||
if err := client.Ping(context.Background()).Err(); err != nil {
|
||||
logger.WithError(err).Warnf("WARN: Failed to connect to Redis (%s), falling back to in-memory store. Error: %v", cfg.Redis.DSN, err)
|
||||
return NewMemoryStore(logger), nil // 连接失败,也回退到内存模式,但不返回错误
|
||||
return NewMemoryStore(logger), nil
|
||||
}
|
||||
logger.Info("Successfully connected to Redis. Using Redis as store.")
|
||||
return NewRedisStore(client), nil
|
||||
return NewRedisStore(client, logger), nil
|
||||
}
|
||||
logger.Info("INFO: Redis DSN not configured, falling back to in-memory store.")
|
||||
return NewMemoryStore(logger), nil
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
// Filename: internal/store/memory_store.go (经同行审查后最终修复版)
|
||||
// Filename: internal/store/memory_store.go
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ensure memoryStore implements Store interface
|
||||
var _ Store = (*memoryStore)(nil)
|
||||
|
||||
type memoryStoreItem struct {
|
||||
@@ -32,7 +35,6 @@ type memoryStore struct {
|
||||
items map[string]*memoryStoreItem
|
||||
pubsub map[string][]chan *Message
|
||||
mu sync.RWMutex
|
||||
// [USER SUGGESTION APPLIED] 使用带锁的随机数源以保证并发安全
|
||||
rng *rand.Rand
|
||||
rngMu sync.Mutex
|
||||
logger *logrus.Entry
|
||||
@@ -42,7 +44,6 @@ func NewMemoryStore(logger *logrus.Logger) Store {
|
||||
store := &memoryStore{
|
||||
items: make(map[string]*memoryStoreItem),
|
||||
pubsub: make(map[string][]chan *Message),
|
||||
// 使用当前时间作为种子,创建一个新的随机数源
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
logger: logger.WithField("component", "store.memory 🗱"),
|
||||
}
|
||||
@@ -50,13 +51,12 @@ func NewMemoryStore(logger *logrus.Logger) Store {
|
||||
return store
|
||||
}
|
||||
|
||||
// [USER SUGGESTION INCORPORATED] Fix #1: 使用 now := time.Now() 进行原子性检查
|
||||
func (s *memoryStore) startGCollector() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
now := time.Now() // 避免在循环中重复调用
|
||||
now := time.Now()
|
||||
for key, item := range s.items {
|
||||
if !item.expireAt.IsZero() && now.After(item.expireAt) {
|
||||
delete(s.items, key)
|
||||
@@ -66,92 +66,10 @@ func (s *memoryStore) startGCollector() {
|
||||
}
|
||||
}
|
||||
|
||||
// [USER SUGGESTION INCORPORATED] Fix #2 & #3: 修复了致命的nil检查和类型断言问题
|
||||
func (s *memoryStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// --- 所有方法签名都增加了 context.Context 参数以匹配接口 ---
|
||||
// --- 内存实现可以忽略该参数,用 _ 接收 ---
|
||||
|
||||
mainItem, mainOk := s.items[mainKey]
|
||||
var mainSet map[string]struct{}
|
||||
|
||||
if mainOk && !mainItem.isExpired() {
|
||||
// 安全地进行类型断言
|
||||
mainSet, mainOk = mainItem.value.(map[string]struct{})
|
||||
// 确保断言成功且集合不为空
|
||||
mainOk = mainOk && len(mainSet) > 0
|
||||
} else {
|
||||
mainOk = false
|
||||
}
|
||||
|
||||
if !mainOk {
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
// 安全地进行类型断言
|
||||
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
|
||||
if !cooldownSetOk || len(cooldownSet) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
s.items[mainKey] = cooldownItem
|
||||
delete(s.items, cooldownKey)
|
||||
mainSet = cooldownSet
|
||||
}
|
||||
|
||||
var popped string
|
||||
for k := range mainSet {
|
||||
popped = k
|
||||
break
|
||||
}
|
||||
delete(mainSet, popped)
|
||||
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
s.items[cooldownKey] = cooldownItem
|
||||
}
|
||||
// 安全地处理冷却池
|
||||
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
|
||||
if !ok {
|
||||
cooldownSet = make(map[string]struct{})
|
||||
cooldownItem.value = cooldownSet
|
||||
}
|
||||
cooldownSet[popped] = struct{}{}
|
||||
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
// SRandMember [并发修复版] 使用带锁的rng
|
||||
func (s *memoryStore) SRandMember(key string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
s.rngMu.Lock()
|
||||
n := s.rng.Intn(len(members))
|
||||
s.rngMu.Unlock()
|
||||
|
||||
return members[n], nil
|
||||
}
|
||||
|
||||
// --- 以下是其余函数的最终版本,它们都遵循了安全、原子的锁策略 ---
|
||||
|
||||
func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
func (s *memoryStore) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var expireAt time.Time
|
||||
@@ -162,7 +80,7 @@ func (s *memoryStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Get(key string) ([]byte, error) {
|
||||
func (s *memoryStore) Get(_ context.Context, key string) ([]byte, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -175,7 +93,7 @@ func (s *memoryStore) Get(key string) ([]byte, error) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
func (s *memoryStore) Del(keys ...string) error {
|
||||
func (s *memoryStore) Del(_ context.Context, keys ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, key := range keys {
|
||||
@@ -184,14 +102,25 @@ func (s *memoryStore) Del(keys ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Exists(key string) (bool, error) {
|
||||
func (s *memoryStore) Exists(_ context.Context, key string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
return ok && !item.isExpired(), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
func (s *memoryStore) Expire(_ context.Context, key string, expiration time.Duration) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok {
|
||||
return ErrNotFound
|
||||
}
|
||||
item.expireAt = time.Now().Add(expiration)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SetNX(_ context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -208,7 +137,7 @@ func (s *memoryStore) SetNX(key string, value []byte, ttl time.Duration) (bool,
|
||||
|
||||
func (s *memoryStore) Close() error { return nil }
|
||||
|
||||
func (s *memoryStore) HDel(key string, fields ...string) error {
|
||||
func (s *memoryStore) HDel(_ context.Context, key string, fields ...string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -223,7 +152,7 @@ func (s *memoryStore) HDel(key string, fields ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HSet(key string, values map[string]any) error {
|
||||
func (s *memoryStore) HSet(_ context.Context, key string, values map[string]any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -242,7 +171,22 @@ func (s *memoryStore) HSet(key string, values map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
|
||||
func (s *memoryStore) HGet(_ context.Context, key, field string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
if hash, ok := item.value.(map[string]string); ok {
|
||||
if value, exists := hash[field]; exists {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
func (s *memoryStore) HGetAll(_ context.Context, key string) (map[string]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -259,7 +203,7 @@ func (s *memoryStore) HGetAll(key string) (map[string]string, error) {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
func (s *memoryStore) HIncrBy(_ context.Context, key, field string, incr int64) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -281,7 +225,7 @@ func (s *memoryStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
return newVal, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LPush(key string, values ...any) error {
|
||||
func (s *memoryStore) LPush(_ context.Context, key string, values ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -301,7 +245,7 @@ func (s *memoryStore) LPush(key string, values ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LRem(key string, count int64, value any) error {
|
||||
func (s *memoryStore) LRem(_ context.Context, key string, count int64, value any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -326,7 +270,7 @@ func (s *memoryStore) LRem(key string, count int64, value any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SAdd(key string, members ...any) error {
|
||||
func (s *memoryStore) SAdd(_ context.Context, key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -345,7 +289,7 @@ func (s *memoryStore) SAdd(key string, members ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
func (s *memoryStore) SPopN(_ context.Context, key string, count int64) ([]string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -375,7 +319,7 @@ func (s *memoryStore) SPopN(key string, count int64) ([]string, error) {
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SMembers(key string) ([]string, error) {
|
||||
func (s *memoryStore) SMembers(_ context.Context, key string) ([]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -393,7 +337,7 @@ func (s *memoryStore) SMembers(key string) ([]string, error) {
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SRem(key string, members ...any) error {
|
||||
func (s *memoryStore) SRem(_ context.Context, key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -410,7 +354,51 @@ func (s *memoryStore) SRem(key string, members ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Rotate(key string) (string, error) {
|
||||
func (s *memoryStore) SRandMember(_ context.Context, key string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
set, ok := item.value.(map[string]struct{})
|
||||
if !ok || len(set) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
members := make([]string, 0, len(set))
|
||||
for member := range set {
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.rngMu.Lock()
|
||||
n := s.rng.Intn(len(members))
|
||||
s.rngMu.Unlock()
|
||||
return members[n], nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) SUnionStore(_ context.Context, destination string, keys ...string) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
unionSet := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
item, ok := s.items[key]
|
||||
if !ok || item.isExpired() {
|
||||
continue
|
||||
}
|
||||
if set, ok := item.value.(map[string]struct{}); ok {
|
||||
for member := range set {
|
||||
unionSet[member] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
destItem := &memoryStoreItem{value: unionSet}
|
||||
s.items[destination] = destItem
|
||||
return int64(len(unionSet)), nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) Rotate(_ context.Context, key string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -426,7 +414,7 @@ func (s *memoryStore) Rotate(key string) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) LIndex(key string, index int64) (string, error) {
|
||||
func (s *memoryStore) LIndex(_ context.Context, key string, index int64) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -447,8 +435,17 @@ func (s *memoryStore) LIndex(key string, index int64) (string, error) {
|
||||
return list[index], nil
|
||||
}
|
||||
|
||||
// Zset methods... (ZAdd, ZRange, ZRem)
|
||||
func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
func (s *memoryStore) MSet(ctx context.Context, values map[string]any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for key, value := range values {
|
||||
// 内存存储不支持独立的 TTL,因此我们假设永不过期
|
||||
s.items[key] = &memoryStoreItem{value: value, expireAt: time.Time{}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) ZAdd(_ context.Context, key string, members map[string]float64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -471,8 +468,6 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
for val, score := range membersMap {
|
||||
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
|
||||
}
|
||||
// NOTE: This ZSet implementation is simple but not performant for large sets.
|
||||
// A production implementation would use a skip list or a balanced tree.
|
||||
sort.Slice(newZSet, func(i, j int) bool {
|
||||
if newZSet[i].Score == newZSet[j].Score {
|
||||
return newZSet[i].Value < newZSet[j].Value
|
||||
@@ -482,7 +477,7 @@ func (s *memoryStore) ZAdd(key string, members map[string]float64) error {
|
||||
item.value = newZSet
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
func (s *memoryStore) ZRange(_ context.Context, key string, start, stop int64) ([]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -515,7 +510,7 @@ func (s *memoryStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (s *memoryStore) ZRem(key string, members ...any) error {
|
||||
func (s *memoryStore) ZRem(_ context.Context, key string, members ...any) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
item, ok := s.items[key]
|
||||
@@ -540,13 +535,56 @@ func (s *memoryStore) ZRem(key string, members ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pipeline implementation
|
||||
func (s *memoryStore) PopAndCycleSetMember(_ context.Context, mainKey, cooldownKey string) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
mainItem, mainOk := s.items[mainKey]
|
||||
var mainSet map[string]struct{}
|
||||
if mainOk && !mainItem.isExpired() {
|
||||
mainSet, mainOk = mainItem.value.(map[string]struct{})
|
||||
mainOk = mainOk && len(mainSet) > 0
|
||||
} else {
|
||||
mainOk = false
|
||||
}
|
||||
if !mainOk {
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
cooldownSet, cooldownSetOk := cooldownItem.value.(map[string]struct{})
|
||||
if !cooldownSetOk || len(cooldownSet) == 0 {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
s.items[mainKey] = cooldownItem
|
||||
delete(s.items, cooldownKey)
|
||||
mainSet = cooldownSet
|
||||
}
|
||||
var popped string
|
||||
for k := range mainSet {
|
||||
popped = k
|
||||
break
|
||||
}
|
||||
delete(mainSet, popped)
|
||||
cooldownItem, cooldownOk := s.items[cooldownKey]
|
||||
if !cooldownOk || cooldownItem.isExpired() {
|
||||
cooldownItem = &memoryStoreItem{value: make(map[string]struct{})}
|
||||
s.items[cooldownKey] = cooldownItem
|
||||
}
|
||||
cooldownSet, ok := cooldownItem.value.(map[string]struct{})
|
||||
if !ok {
|
||||
cooldownSet = make(map[string]struct{})
|
||||
cooldownItem.value = cooldownSet
|
||||
}
|
||||
cooldownSet[popped] = struct{}{}
|
||||
return popped, nil
|
||||
}
|
||||
|
||||
type memoryPipeliner struct {
|
||||
store *memoryStore
|
||||
ops []func()
|
||||
}
|
||||
|
||||
func (s *memoryStore) Pipeline() Pipeliner {
|
||||
func (s *memoryStore) Pipeline(_ context.Context) Pipeliner {
|
||||
return &memoryPipeliner{store: s}
|
||||
}
|
||||
func (p *memoryPipeliner) Exec() error {
|
||||
@@ -559,7 +597,6 @@ func (p *memoryPipeliner) Exec() error {
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) Expire(key string, expiration time.Duration) {
|
||||
// [USER SUGGESTION APPLIED] Fix #4: Capture value, not reference
|
||||
capturedKey := key
|
||||
p.ops = append(p.ops, func() {
|
||||
if item, ok := p.store.items[capturedKey]; ok {
|
||||
@@ -576,6 +613,22 @@ func (p *memoryPipeliner) Del(keys ...string) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) Set(key string, value []byte, expiration time.Duration) {
|
||||
capturedKey := key
|
||||
capturedValue := value
|
||||
p.ops = append(p.ops, func() {
|
||||
var expireAt time.Time
|
||||
if expiration > 0 {
|
||||
expireAt = time.Now().Add(expiration)
|
||||
}
|
||||
p.store.items[capturedKey] = &memoryStoreItem{
|
||||
value: capturedValue,
|
||||
expireAt: expireAt,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) SAdd(key string, members ...any) {
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
@@ -615,7 +668,6 @@ func (p *memoryPipeliner) SRem(key string, members ...any) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
capturedKey := key
|
||||
capturedValues := make([]any, len(values))
|
||||
@@ -637,11 +689,150 @@ func (p *memoryPipeliner) LPush(key string, values ...any) {
|
||||
item.value = append(stringValues, list...)
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {}
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {}
|
||||
func (p *memoryPipeliner) LRem(key string, count int64, value any) {
|
||||
capturedKey := key
|
||||
capturedValue := fmt.Sprintf("%v", value)
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
return
|
||||
}
|
||||
list, ok := item.value.([]string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
newList := make([]string, 0, len(list))
|
||||
removed := int64(0)
|
||||
for _, v := range list {
|
||||
shouldRemove := v == capturedValue && (count == 0 || removed < count)
|
||||
if shouldRemove {
|
||||
removed++
|
||||
} else {
|
||||
newList = append(newList, v)
|
||||
}
|
||||
}
|
||||
item.value = newList
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) HSet(key string, values map[string]any) {
|
||||
capturedKey := key
|
||||
capturedValues := make(map[string]any, len(values))
|
||||
for k, v := range values {
|
||||
capturedValues[k] = v
|
||||
}
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
hash, ok := item.value.(map[string]string)
|
||||
if !ok {
|
||||
hash = make(map[string]string)
|
||||
item.value = hash
|
||||
}
|
||||
for field, value := range capturedValues {
|
||||
hash[field] = fmt.Sprintf("%v", value)
|
||||
}
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
capturedKey := key
|
||||
capturedField := field
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make(map[string]string)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
hash, ok := item.value.(map[string]string)
|
||||
if !ok {
|
||||
hash = make(map[string]string)
|
||||
item.value = hash
|
||||
}
|
||||
current, _ := strconv.ParseInt(hash[capturedField], 10, 64)
|
||||
hash[capturedField] = strconv.FormatInt(current+incr, 10)
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) ZAdd(key string, members map[string]float64) {
|
||||
capturedKey := key
|
||||
capturedMembers := make(map[string]float64, len(members))
|
||||
for k, v := range members {
|
||||
capturedMembers[k] = v
|
||||
}
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
item = &memoryStoreItem{value: make([]zsetMember, 0)}
|
||||
p.store.items[capturedKey] = item
|
||||
}
|
||||
zset, ok := item.value.([]zsetMember)
|
||||
if !ok {
|
||||
zset = make([]zsetMember, 0)
|
||||
}
|
||||
membersMap := make(map[string]float64, len(zset))
|
||||
for _, z := range zset {
|
||||
membersMap[z.Value] = z.Score
|
||||
}
|
||||
for memberVal, score := range capturedMembers {
|
||||
membersMap[memberVal] = score
|
||||
}
|
||||
newZSet := make([]zsetMember, 0, len(membersMap))
|
||||
for val, score := range membersMap {
|
||||
newZSet = append(newZSet, zsetMember{Value: val, Score: score})
|
||||
}
|
||||
sort.Slice(newZSet, func(i, j int) bool {
|
||||
if newZSet[i].Score == newZSet[j].Score {
|
||||
return newZSet[i].Value < newZSet[j].Value
|
||||
}
|
||||
return newZSet[i].Score < newZSet[j].Score
|
||||
})
|
||||
item.value = newZSet
|
||||
})
|
||||
}
|
||||
func (p *memoryPipeliner) ZRem(key string, members ...any) {
|
||||
capturedKey := key
|
||||
capturedMembers := make([]any, len(members))
|
||||
copy(capturedMembers, members)
|
||||
p.ops = append(p.ops, func() {
|
||||
item, ok := p.store.items[capturedKey]
|
||||
if !ok || item.isExpired() {
|
||||
return
|
||||
}
|
||||
zset, ok := item.value.([]zsetMember)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
membersToRemove := make(map[string]struct{}, len(capturedMembers))
|
||||
for _, m := range capturedMembers {
|
||||
membersToRemove[fmt.Sprintf("%v", m)] = struct{}{}
|
||||
}
|
||||
newZSet := make([]zsetMember, 0, len(zset))
|
||||
for _, z := range zset {
|
||||
if _, exists := membersToRemove[z.Value]; !exists {
|
||||
newZSet = append(newZSet, z)
|
||||
}
|
||||
}
|
||||
item.value = newZSet
|
||||
})
|
||||
}
|
||||
|
||||
func (p *memoryPipeliner) MSet(values map[string]any) {
|
||||
capturedValues := make(map[string]any, len(values))
|
||||
for k, v := range values {
|
||||
capturedValues[k] = v
|
||||
}
|
||||
p.ops = append(p.ops, func() {
|
||||
for key, value := range capturedValues {
|
||||
p.store.items[key] = &memoryStoreItem{
|
||||
value: value,
|
||||
expireAt: time.Time{}, // Pipelined MSet 同样假设永不过期
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// --- Pub/Sub implementation (remains unchanged) ---
|
||||
type memorySubscription struct {
|
||||
store *memoryStore
|
||||
channelName string
|
||||
@@ -649,10 +840,11 @@ type memorySubscription struct {
|
||||
}
|
||||
|
||||
func (ms *memorySubscription) Channel() <-chan *Message { return ms.msgChan }
|
||||
func (ms *memorySubscription) ChannelName() string { return ms.channelName }
|
||||
func (ms *memorySubscription) Close() error {
|
||||
return ms.store.removeSubscriber(ms.channelName, ms.msgChan)
|
||||
}
|
||||
func (s *memoryStore) Publish(channel string, message []byte) error {
|
||||
func (s *memoryStore) Publish(_ context.Context, channel string, message []byte) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
subscribers, ok := s.pubsub[channel]
|
||||
@@ -669,7 +861,7 @@ func (s *memoryStore) Publish(channel string, message []byte) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *memoryStore) Subscribe(channel string) (Subscription, error) {
|
||||
func (s *memoryStore) Subscribe(_ context.Context, channel string) (Subscription, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
msgChan := make(chan *Message, 10)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Filename: internal/store/redis_store.go
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
@@ -8,22 +10,20 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ensure RedisStore implements Store interface
|
||||
var _ Store = (*RedisStore)(nil)
|
||||
|
||||
// RedisStore is a Redis-backed key-value store.
|
||||
type RedisStore struct {
|
||||
client *redis.Client
|
||||
popAndCycleScript *redis.Script
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewRedisStore creates a new RedisStore instance.
|
||||
func NewRedisStore(client *redis.Client) Store {
|
||||
// Lua script for atomic pop-and-cycle operation.
|
||||
// KEYS[1]: main set key
|
||||
// KEYS[2]: cooldown set key
|
||||
func NewRedisStore(client *redis.Client, logger *logrus.Logger) Store {
|
||||
const script = `
|
||||
if redis.call('SCARD', KEYS[1]) == 0 then
|
||||
if redis.call('SCARD', KEYS[2]) == 0 then
|
||||
@@ -36,15 +36,16 @@ func NewRedisStore(client *redis.Client) Store {
|
||||
return &RedisStore{
|
||||
client: client,
|
||||
popAndCycleScript: redis.NewScript(script),
|
||||
logger: logger.WithField("component", "store.redis 🗄️"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RedisStore) Set(key string, value []byte, ttl time.Duration) error {
|
||||
return s.client.Set(context.Background(), key, value, ttl).Err()
|
||||
func (s *RedisStore) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return s.client.Set(ctx, key, value, ttl).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Get(key string) ([]byte, error) {
|
||||
val, err := s.client.Get(context.Background(), key).Bytes()
|
||||
func (s *RedisStore) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := s.client.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrNotFound
|
||||
@@ -54,53 +55,67 @@ func (s *RedisStore) Get(key string) ([]byte, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) Del(keys ...string) error {
|
||||
func (s *RedisStore) Del(ctx context.Context, keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.Del(context.Background(), keys...).Err()
|
||||
return s.client.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Exists(key string) (bool, error) {
|
||||
val, err := s.client.Exists(context.Background(), key).Result()
|
||||
func (s *RedisStore) Exists(ctx context.Context, key string) (bool, error) {
|
||||
val, err := s.client.Exists(ctx, key).Result()
|
||||
return val > 0, err
|
||||
}
|
||||
|
||||
func (s *RedisStore) SetNX(key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
return s.client.SetNX(context.Background(), key, value, ttl).Result()
|
||||
func (s *RedisStore) SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error) {
|
||||
return s.client.SetNX(ctx, key, value, ttl).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Close() error {
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HSet(key string, values map[string]any) error {
|
||||
return s.client.HSet(context.Background(), key, values).Err()
|
||||
func (s *RedisStore) Expire(ctx context.Context, key string, expiration time.Duration) error {
|
||||
return s.client.Expire(ctx, key, expiration).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HGetAll(key string) (map[string]string, error) {
|
||||
return s.client.HGetAll(context.Background(), key).Result()
|
||||
func (s *RedisStore) HSet(ctx context.Context, key string, values map[string]any) error {
|
||||
return s.client.HSet(ctx, key, values).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HIncrBy(key, field string, incr int64) (int64, error) {
|
||||
return s.client.HIncrBy(context.Background(), key, field, incr).Result()
|
||||
func (s *RedisStore) HGet(ctx context.Context, key, field string) (string, error) {
|
||||
val, err := s.client.HGet(ctx, key, field).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
func (s *RedisStore) HDel(key string, fields ...string) error {
|
||||
func (s *RedisStore) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||||
return s.client.HGetAll(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) {
|
||||
return s.client.HIncrBy(ctx, key, field, incr).Result()
|
||||
}
|
||||
func (s *RedisStore) HDel(ctx context.Context, key string, fields ...string) error {
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.HDel(context.Background(), key, fields...).Err()
|
||||
return s.client.HDel(ctx, key, fields...).Err()
|
||||
}
|
||||
func (s *RedisStore) LPush(key string, values ...any) error {
|
||||
return s.client.LPush(context.Background(), key, values...).Err()
|
||||
func (s *RedisStore) LPush(ctx context.Context, key string, values ...any) error {
|
||||
return s.client.LPush(ctx, key, values...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) LRem(key string, count int64, value any) error {
|
||||
return s.client.LRem(context.Background(), key, count, value).Err()
|
||||
func (s *RedisStore) LRem(ctx context.Context, key string, count int64, value any) error {
|
||||
return s.client.LRem(ctx, key, count, value).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) Rotate(key string) (string, error) {
|
||||
val, err := s.client.RPopLPush(context.Background(), key, key).Result()
|
||||
func (s *RedisStore) Rotate(ctx context.Context, key string) (string, error) {
|
||||
val, err := s.client.RPopLPush(ctx, key, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
@@ -110,29 +125,40 @@ func (s *RedisStore) Rotate(key string) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) SAdd(key string, members ...any) error {
|
||||
return s.client.SAdd(context.Background(), key, members...).Err()
|
||||
func (s *RedisStore) MSet(ctx context.Context, values map[string]any) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Redis MSet 命令需要 [key1, value1, key2, value2, ...] 格式的切片
|
||||
pairs := make([]interface{}, 0, len(values)*2)
|
||||
for k, v := range values {
|
||||
pairs = append(pairs, k, v)
|
||||
}
|
||||
return s.client.MSet(ctx, pairs...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SPopN(key string, count int64) ([]string, error) {
|
||||
return s.client.SPopN(context.Background(), key, count).Result()
|
||||
func (s *RedisStore) SAdd(ctx context.Context, key string, members ...any) error {
|
||||
return s.client.SAdd(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SMembers(key string) ([]string, error) {
|
||||
return s.client.SMembers(context.Background(), key).Result()
|
||||
func (s *RedisStore) SPopN(ctx context.Context, key string, count int64) ([]string, error) {
|
||||
return s.client.SPopN(ctx, key, count).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRem(key string, members ...any) error {
|
||||
func (s *RedisStore) SMembers(ctx context.Context, key string) ([]string, error) {
|
||||
return s.client.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRem(ctx context.Context, key string, members ...any) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.SRem(context.Background(), key, members...).Err()
|
||||
return s.client.SRem(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) SRandMember(key string) (string, error) {
|
||||
member, err := s.client.SRandMember(context.Background(), key).Result()
|
||||
func (s *RedisStore) SRandMember(ctx context.Context, key string) (string, error) {
|
||||
member, err := s.client.SRandMember(ctx, key).Result()
|
||||
if err != nil {
|
||||
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
@@ -141,81 +167,50 @@ func (s *RedisStore) SRandMember(key string) (string, error) {
|
||||
return member, nil
|
||||
}
|
||||
|
||||
// === 新增方法实现 ===
|
||||
func (s *RedisStore) SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error) {
|
||||
if len(keys) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return s.client.SUnionStore(ctx, destination, keys...).Result()
|
||||
}
|
||||
|
||||
func (s *RedisStore) ZAdd(key string, members map[string]float64) error {
|
||||
func (s *RedisStore) ZAdd(ctx context.Context, key string, members map[string]float64) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
redisMembers := make([]redis.Z, 0, len(members))
|
||||
redisMembers := make([]redis.Z, len(members))
|
||||
i := 0
|
||||
for member, score := range members {
|
||||
redisMembers = append(redisMembers, redis.Z{Score: score, Member: member})
|
||||
redisMembers[i] = redis.Z{Score: score, Member: member}
|
||||
i++
|
||||
}
|
||||
return s.client.ZAdd(context.Background(), key, redisMembers...).Err()
|
||||
return s.client.ZAdd(ctx, key, redisMembers...).Err()
|
||||
}
|
||||
func (s *RedisStore) ZRange(key string, start, stop int64) ([]string, error) {
|
||||
return s.client.ZRange(context.Background(), key, start, stop).Result()
|
||||
func (s *RedisStore) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
|
||||
return s.client.ZRange(ctx, key, start, stop).Result()
|
||||
}
|
||||
func (s *RedisStore) ZRem(key string, members ...any) error {
|
||||
func (s *RedisStore) ZRem(ctx context.Context, key string, members ...any) error {
|
||||
if len(members) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.client.ZRem(context.Background(), key, members...).Err()
|
||||
return s.client.ZRem(ctx, key, members...).Err()
|
||||
}
|
||||
func (s *RedisStore) PopAndCycleSetMember(mainKey, cooldownKey string) (string, error) {
|
||||
val, err := s.popAndCycleScript.Run(context.Background(), s.client, []string{mainKey, cooldownKey}).Result()
|
||||
func (s *RedisStore) PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error) {
|
||||
val, err := s.popAndCycleScript.Run(ctx, s.client, []string{mainKey, cooldownKey}).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
// Lua script returns a string, so we need to type assert
|
||||
if str, ok := val.(string); ok {
|
||||
return str, nil
|
||||
}
|
||||
return "", ErrNotFound // This happens if both sets were empty and the script returned nil
|
||||
return "", ErrNotFound
|
||||
}
|
||||
|
||||
type redisPipeliner struct{ pipe redis.Pipeliner }
|
||||
|
||||
func (p *redisPipeliner) HSet(key string, values map[string]any) {
|
||||
p.pipe.HSet(context.Background(), key, values)
|
||||
}
|
||||
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
p.pipe.HIncrBy(context.Background(), key, field, incr)
|
||||
}
|
||||
func (p *redisPipeliner) Exec() error {
|
||||
_, err := p.pipe.Exec(context.Background())
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Del(keys ...string) {
|
||||
if len(keys) > 0 {
|
||||
p.pipe.Del(context.Background(), keys...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) SAdd(key string, members ...any) {
|
||||
p.pipe.SAdd(context.Background(), key, members...)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) SRem(key string, members ...any) {
|
||||
if len(members) > 0 {
|
||||
p.pipe.SRem(context.Background(), key, members...)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) LPush(key string, values ...any) {
|
||||
p.pipe.LPush(context.Background(), key, values...)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) LRem(key string, count int64, value any) {
|
||||
p.pipe.LRem(context.Background(), key, count, value)
|
||||
}
|
||||
|
||||
func (s *RedisStore) LIndex(key string, index int64) (string, error) {
|
||||
val, err := s.client.LIndex(context.Background(), key, index).Result()
|
||||
func (s *RedisStore) LIndex(ctx context.Context, key string, index int64) (string, error) {
|
||||
val, err := s.client.LIndex(ctx, key, index).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return "", ErrNotFound
|
||||
@@ -225,47 +220,131 @@ func (s *RedisStore) LIndex(key string, index int64) (string, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
|
||||
p.pipe.Expire(context.Background(), key, expiration)
|
||||
type redisPipeliner struct {
|
||||
pipe redis.Pipeliner
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *RedisStore) Pipeline() Pipeliner {
|
||||
return &redisPipeliner{pipe: s.client.Pipeline()}
|
||||
func (s *RedisStore) Pipeline(ctx context.Context) Pipeliner {
|
||||
return &redisPipeliner{
|
||||
pipe: s.client.Pipeline(),
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Exec() error {
|
||||
_, err := p.pipe.Exec(p.ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Del(keys ...string) { p.pipe.Del(p.ctx, keys...) }
|
||||
func (p *redisPipeliner) Expire(key string, expiration time.Duration) {
|
||||
p.pipe.Expire(p.ctx, key, expiration)
|
||||
}
|
||||
func (p *redisPipeliner) HSet(key string, values map[string]any) { p.pipe.HSet(p.ctx, key, values) }
|
||||
func (p *redisPipeliner) HIncrBy(key, field string, incr int64) {
|
||||
p.pipe.HIncrBy(p.ctx, key, field, incr)
|
||||
}
|
||||
func (p *redisPipeliner) LPush(key string, values ...any) { p.pipe.LPush(p.ctx, key, values...) }
|
||||
func (p *redisPipeliner) LRem(key string, count int64, value any) {
|
||||
p.pipe.LRem(p.ctx, key, count, value)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) Set(key string, value []byte, expiration time.Duration) {
|
||||
p.pipe.Set(p.ctx, key, value, expiration)
|
||||
}
|
||||
|
||||
func (p *redisPipeliner) MSet(values map[string]any) {
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
p.pipe.MSet(p.ctx, values)
|
||||
}
|
||||
func (p *redisPipeliner) SAdd(key string, members ...any) { p.pipe.SAdd(p.ctx, key, members...) }
|
||||
func (p *redisPipeliner) SRem(key string, members ...any) { p.pipe.SRem(p.ctx, key, members...) }
|
||||
func (p *redisPipeliner) ZAdd(key string, members map[string]float64) {
|
||||
if len(members) == 0 {
|
||||
return
|
||||
}
|
||||
redisMembers := make([]redis.Z, len(members))
|
||||
i := 0
|
||||
for member, score := range members {
|
||||
redisMembers[i] = redis.Z{Score: score, Member: member}
|
||||
i++
|
||||
}
|
||||
p.pipe.ZAdd(p.ctx, key, redisMembers...)
|
||||
}
|
||||
func (p *redisPipeliner) ZRem(key string, members ...any) { p.pipe.ZRem(p.ctx, key, members...) }
|
||||
|
||||
type redisSubscription struct {
|
||||
pubsub *redis.PubSub
|
||||
msgChan chan *Message
|
||||
once sync.Once
|
||||
pubsub *redis.PubSub
|
||||
msgChan chan *Message
|
||||
logger *logrus.Entry
|
||||
wg sync.WaitGroup
|
||||
close context.CancelFunc
|
||||
channelName string
|
||||
}
|
||||
|
||||
func (s *RedisStore) Subscribe(ctx context.Context, channel string) (Subscription, error) {
|
||||
pubsub := s.client.Subscribe(ctx, channel)
|
||||
_, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
|
||||
}
|
||||
subCtx, cancel := context.WithCancel(context.Background())
|
||||
sub := &redisSubscription{
|
||||
pubsub: pubsub,
|
||||
msgChan: make(chan *Message, 10),
|
||||
logger: s.logger,
|
||||
close: cancel,
|
||||
channelName: channel,
|
||||
}
|
||||
sub.wg.Add(1)
|
||||
go sub.bridge(subCtx)
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) bridge(ctx context.Context) {
|
||||
defer rs.wg.Done()
|
||||
defer close(rs.msgChan)
|
||||
redisCh := rs.pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case redisMsg, ok := <-redisCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
msg := &Message{
|
||||
Channel: redisMsg.Channel,
|
||||
Payload: []byte(redisMsg.Payload),
|
||||
}
|
||||
select {
|
||||
case rs.msgChan <- msg:
|
||||
default:
|
||||
rs.logger.Warnf("Message dropped for channel '%s' due to slow consumer.", rs.channelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) Channel() <-chan *Message {
|
||||
rs.once.Do(func() {
|
||||
rs.msgChan = make(chan *Message)
|
||||
go func() {
|
||||
defer close(rs.msgChan)
|
||||
for redisMsg := range rs.pubsub.Channel() {
|
||||
rs.msgChan <- &Message{
|
||||
Channel: redisMsg.Channel,
|
||||
Payload: []byte(redisMsg.Payload),
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
return rs.msgChan
|
||||
}
|
||||
|
||||
func (rs *redisSubscription) Close() error { return rs.pubsub.Close() }
|
||||
|
||||
func (s *RedisStore) Publish(channel string, message []byte) error {
|
||||
return s.client.Publish(context.Background(), channel, message).Err()
|
||||
func (rs *redisSubscription) ChannelName() string {
|
||||
return rs.channelName
|
||||
}
|
||||
|
||||
func (s *RedisStore) Subscribe(channel string) (Subscription, error) {
|
||||
pubsub := s.client.Subscribe(context.Background(), channel)
|
||||
_, err := pubsub.Receive(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to subscribe to channel %s: %w", channel, err)
|
||||
}
|
||||
return &redisSubscription{pubsub: pubsub}, nil
|
||||
func (rs *redisSubscription) Close() error {
|
||||
rs.close()
|
||||
err := rs.pubsub.Close()
|
||||
rs.wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *RedisStore) Publish(ctx context.Context, channel string, message []byte) error {
|
||||
return s.client.Publish(ctx, channel, message).Err()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
// Filename: internal/store/store.go
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
@@ -17,6 +20,7 @@ type Message struct {
|
||||
// Subscription represents an active subscription to a pub/sub channel.
|
||||
type Subscription interface {
|
||||
Channel() <-chan *Message
|
||||
ChannelName() string
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -31,6 +35,8 @@ type Pipeliner interface {
|
||||
HIncrBy(key, field string, incr int64)
|
||||
|
||||
// SET
|
||||
MSet(values map[string]any)
|
||||
Set(key string, value []byte, expiration time.Duration)
|
||||
SAdd(key string, members ...any)
|
||||
SRem(key string, members ...any)
|
||||
|
||||
@@ -38,6 +44,10 @@ type Pipeliner interface {
|
||||
LPush(key string, values ...any)
|
||||
LRem(key string, count int64, value any)
|
||||
|
||||
// ZSET
|
||||
ZAdd(key string, members map[string]float64)
|
||||
ZRem(key string, members ...any)
|
||||
|
||||
// Execution
|
||||
Exec() error
|
||||
}
|
||||
@@ -45,44 +55,48 @@ type Pipeliner interface {
|
||||
// Store is the master interface for our cache service.
|
||||
type Store interface {
|
||||
// Basic K/V operations
|
||||
Set(key string, value []byte, ttl time.Duration) error
|
||||
Get(key string) ([]byte, error)
|
||||
Del(keys ...string) error
|
||||
Exists(key string) (bool, error)
|
||||
SetNX(key string, value []byte, ttl time.Duration) (bool, error)
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Del(ctx context.Context, keys ...string) error
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
SetNX(ctx context.Context, key string, value []byte, ttl time.Duration) (bool, error)
|
||||
MSet(ctx context.Context, values map[string]any) error
|
||||
|
||||
// HASH operations
|
||||
HSet(key string, values map[string]any) error
|
||||
HGetAll(key string) (map[string]string, error)
|
||||
HIncrBy(key, field string, incr int64) (int64, error)
|
||||
HDel(key string, fields ...string) error // [新增]
|
||||
HSet(ctx context.Context, key string, values map[string]any) error
|
||||
HGet(ctx context.Context, key, field string) (string, error)
|
||||
HGetAll(ctx context.Context, key string) (map[string]string, error)
|
||||
HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error)
|
||||
HDel(ctx context.Context, key string, fields ...string) error
|
||||
|
||||
// LIST operations
|
||||
LPush(key string, values ...any) error
|
||||
LRem(key string, count int64, value any) error
|
||||
Rotate(key string) (string, error)
|
||||
LIndex(key string, index int64) (string, error)
|
||||
LPush(ctx context.Context, key string, values ...any) error
|
||||
LRem(ctx context.Context, key string, count int64, value any) error
|
||||
Rotate(ctx context.Context, key string) (string, error)
|
||||
LIndex(ctx context.Context, key string, index int64) (string, error)
|
||||
Expire(ctx context.Context, key string, expiration time.Duration) error
|
||||
|
||||
// SET operations
|
||||
SAdd(key string, members ...any) error
|
||||
SPopN(key string, count int64) ([]string, error)
|
||||
SMembers(key string) ([]string, error)
|
||||
SRem(key string, members ...any) error
|
||||
SRandMember(key string) (string, error)
|
||||
SAdd(ctx context.Context, key string, members ...any) error
|
||||
SPopN(ctx context.Context, key string, count int64) ([]string, error)
|
||||
SMembers(ctx context.Context, key string) ([]string, error)
|
||||
SRem(ctx context.Context, key string, members ...any) error
|
||||
SRandMember(ctx context.Context, key string) (string, error)
|
||||
SUnionStore(ctx context.Context, destination string, keys ...string) (int64, error)
|
||||
|
||||
// Pub/Sub operations
|
||||
Publish(channel string, message []byte) error
|
||||
Subscribe(channel string) (Subscription, error)
|
||||
Publish(ctx context.Context, channel string, message []byte) error
|
||||
Subscribe(ctx context.Context, channel string) (Subscription, error)
|
||||
|
||||
// Pipeline (optional) - 我们在redis实现它,内存版暂时不实现
|
||||
Pipeline() Pipeliner
|
||||
// Pipeline
|
||||
Pipeline(ctx context.Context) Pipeliner
|
||||
|
||||
// Close closes the store and releases any underlying resources.
|
||||
Close() error
|
||||
|
||||
// === 新增方法,支持轮询策略 ===
|
||||
ZAdd(key string, members map[string]float64) error
|
||||
ZRange(key string, start, stop int64) ([]string, error)
|
||||
ZRem(key string, members ...any) error
|
||||
PopAndCycleSetMember(mainKey, cooldownKey string) (string, error)
|
||||
// ZSET operations
|
||||
ZAdd(ctx context.Context, key string, members map[string]float64) error
|
||||
ZRange(ctx context.Context, key string, start, stop int64) ([]string, error)
|
||||
ZRem(ctx context.Context, key string, members ...any) error
|
||||
PopAndCycleSetMember(ctx context.Context, mainKey, cooldownKey string) (string, error)
|
||||
}
|
||||
|
||||
@@ -1,48 +1,57 @@
|
||||
package syncer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gemini-balancer/internal/store"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
ReconnectDelay = 5 * time.Second
|
||||
ReloadTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// LoaderFunc
|
||||
type LoaderFunc[T any] func() (T, error)
|
||||
|
||||
// CacheSyncer
|
||||
type CacheSyncer[T any] struct {
|
||||
mu sync.RWMutex
|
||||
cache T
|
||||
loader LoaderFunc[T]
|
||||
store store.Store
|
||||
channelName string
|
||||
logger *logrus.Entry
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewCacheSyncer
|
||||
func NewCacheSyncer[T any](
|
||||
loader LoaderFunc[T],
|
||||
store store.Store,
|
||||
channelName string,
|
||||
logger *logrus.Logger,
|
||||
) (*CacheSyncer[T], error) {
|
||||
s := &CacheSyncer[T]{
|
||||
loader: loader,
|
||||
store: store,
|
||||
channelName: channelName,
|
||||
logger: logger.WithField("component", fmt.Sprintf("CacheSyncer[%s]", channelName)),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
if err := s.reload(); err != nil {
|
||||
return nil, fmt.Errorf("initial load for %s failed: %w", channelName, err)
|
||||
return nil, fmt.Errorf("initial load failed: %w", err)
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.listenForUpdates()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Get, Invalidate, Stop, reload 方法 .
|
||||
func (s *CacheSyncer[T]) Get() T {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -50,33 +59,60 @@ func (s *CacheSyncer[T]) Get() T {
|
||||
}
|
||||
|
||||
func (s *CacheSyncer[T]) Invalidate() error {
|
||||
log.Printf("INFO: Publishing invalidation notification on channel '%s'", s.channelName)
|
||||
return s.store.Publish(s.channelName, []byte("reload"))
|
||||
s.logger.Info("Publishing invalidation notification")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.store.Publish(ctx, s.channelName, []byte("reload")); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to publish invalidation")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *CacheSyncer[T]) Stop() {
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
log.Printf("INFO: CacheSyncer for channel '%s' stopped.", s.channelName)
|
||||
s.logger.Info("CacheSyncer stopped")
|
||||
}
|
||||
|
||||
func (s *CacheSyncer[T]) reload() error {
|
||||
log.Printf("INFO: Reloading cache for channel '%s'...", s.channelName)
|
||||
newData, err := s.loader()
|
||||
if err != nil {
|
||||
log.Printf("ERROR: Failed to reload cache for '%s': %v", s.channelName, err)
|
||||
return err
|
||||
s.logger.Info("Reloading cache...")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ReloadTimeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
data T
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan result, 1)
|
||||
|
||||
go func() {
|
||||
data, err := s.loader()
|
||||
resultChan <- result{data, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-resultChan:
|
||||
if res.err != nil {
|
||||
s.logger.WithError(res.err).Error("Failed to reload cache")
|
||||
return res.err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.cache = res.data
|
||||
s.mu.Unlock()
|
||||
s.logger.Info("Cache reloaded successfully")
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.logger.Error("Cache reload timeout")
|
||||
return fmt.Errorf("reload timeout after %v", ReloadTimeout)
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.cache = newData
|
||||
s.mu.Unlock()
|
||||
log.Printf("INFO: Cache for channel '%s' reloaded successfully.", s.channelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// listenForUpdates ...
|
||||
func (s *CacheSyncer[T]) listenForUpdates() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
@@ -84,31 +120,39 @@ func (s *CacheSyncer[T]) listenForUpdates() {
|
||||
default:
|
||||
}
|
||||
|
||||
subscription, err := s.store.Subscribe(s.channelName)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: Failed to subscribe to '%s', retrying in 5s: %v", s.channelName, err)
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
log.Printf("INFO: Subscribed to channel '%s' for cache invalidation.", s.channelName)
|
||||
|
||||
subscriberLoop:
|
||||
for {
|
||||
if err := s.subscribeAndListen(); err != nil {
|
||||
s.logger.WithError(err).Warnf("Subscription error, retrying in %v", ReconnectDelay)
|
||||
select {
|
||||
case _, ok := <-subscription.Channel():
|
||||
if !ok {
|
||||
log.Printf("WARN: Subscription channel '%s' closed, will re-subscribe.", s.channelName)
|
||||
break subscriberLoop
|
||||
}
|
||||
log.Printf("INFO: Received invalidation notification on '%s', reloading cache.", s.channelName)
|
||||
if err := s.reload(); err != nil {
|
||||
log.Printf("ERROR: Failed to reload cache for '%s' after notification: %v", s.channelName, err)
|
||||
}
|
||||
case <-time.After(ReconnectDelay):
|
||||
case <-s.stopChan:
|
||||
subscription.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
subscription.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CacheSyncer[T]) subscribeAndListen() error {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
subscription, err := s.store.Subscribe(ctx, s.channelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe: %w", err)
|
||||
}
|
||||
defer subscription.Close()
|
||||
s.logger.Info("Subscribed to channel")
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-subscription.Channel():
|
||||
if !ok {
|
||||
return fmt.Errorf("subscription channel closed")
|
||||
}
|
||||
s.logger.WithField("message", string(msg.Payload)).Info("Received invalidation notification")
|
||||
if err := s.reload(); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to reload after notification")
|
||||
}
|
||||
case <-s.stopChan:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Filename: internal/task/task.go (最终校准版)
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -12,18 +12,18 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ResultTTL = 60 * time.Minute
|
||||
ResultTTL = 60 * time.Minute
|
||||
DefaultTimeout = 24 * time.Hour
|
||||
LockTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// Reporter 接口,定义了领域如何与任务服务交互。
|
||||
type Reporter interface {
|
||||
StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
|
||||
EndTaskByID(taskID, resourceID string, result any, taskErr error)
|
||||
UpdateProgressByID(taskID string, processed int) error
|
||||
UpdateTotalByID(taskID string, total int) error
|
||||
StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error)
|
||||
EndTaskByID(ctx context.Context, taskID, resourceID string, result any, taskErr error)
|
||||
UpdateProgressByID(ctx context.Context, taskID string, processed int) error
|
||||
UpdateTotalByID(ctx context.Context, taskID string, total int) error
|
||||
}
|
||||
|
||||
// Status 代表一个后台任务的完整状态
|
||||
type Status struct {
|
||||
ID string `json:"id"`
|
||||
TaskType string `json:"task_type"`
|
||||
@@ -38,13 +38,11 @@ type Status struct {
|
||||
DurationSeconds float64 `json:"duration_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// Task 是任务管理的核心服务
|
||||
type Task struct {
|
||||
store store.Store
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// NewTask 是 Task 的构造函数
|
||||
func NewTask(store store.Store, logger *logrus.Logger) *Task {
|
||||
return &Task{
|
||||
store: store,
|
||||
@@ -62,21 +60,27 @@ func (s *Task) getTaskDataKey(taskID string) string {
|
||||
return fmt.Sprintf("task:data:%s", taskID)
|
||||
}
|
||||
|
||||
// --- 新增的輔助函數,用於獲取原子標記的鍵 ---
|
||||
func (s *Task) getIsRunningFlagKey(taskID string) string {
|
||||
return fmt.Sprintf("task:running:%s", taskID)
|
||||
}
|
||||
|
||||
func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
||||
func (s *Task) StartTask(ctx context.Context, keyGroupID uint, taskType, resourceID string, total int, timeout time.Duration) (*Status, error) {
|
||||
lockKey := s.getResourceLockKey(resourceID)
|
||||
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
|
||||
|
||||
if existingTaskID, err := s.store.Get(lockKey); err == nil && len(existingTaskID) > 0 {
|
||||
locked, err := s.store.SetNX(ctx, lockKey, []byte(taskID), LockTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire task lock: %w", err)
|
||||
}
|
||||
if !locked {
|
||||
existingTaskID, _ := s.store.Get(ctx, lockKey)
|
||||
return nil, fmt.Errorf("a task is already running for this resource (ID: %s)", string(existingTaskID))
|
||||
}
|
||||
|
||||
taskID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), keyGroupID)
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
if timeout == 0 {
|
||||
timeout = DefaultTimeout
|
||||
}
|
||||
|
||||
status := &Status{
|
||||
ID: taskID,
|
||||
TaskType: taskType,
|
||||
@@ -85,81 +89,71 @@ func (s *Task) StartTask(keyGroupID uint, taskType, resourceID string, total int
|
||||
Total: total,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
statusBytes, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize new task status: %w", err)
|
||||
|
||||
if err := s.saveStatus(ctx, taskID, status, timeout); err != nil {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
return nil, fmt.Errorf("failed to save task status: %w", err)
|
||||
}
|
||||
|
||||
if timeout == 0 {
|
||||
timeout = ResultTTL * 24
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
if err := s.store.Set(ctx, runningFlagKey, []byte("1"), timeout); err != nil {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
_ = s.store.Del(ctx, s.getTaskDataKey(taskID))
|
||||
return nil, fmt.Errorf("failed to set running flag: %w", err)
|
||||
}
|
||||
|
||||
if err := s.store.Set(lockKey, []byte(taskID), timeout); err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire task resource lock: %w", err)
|
||||
}
|
||||
if err := s.store.Set(taskKey, statusBytes, timeout); err != nil {
|
||||
_ = s.store.Del(lockKey)
|
||||
return nil, fmt.Errorf("failed to set new task data in store: %w", err)
|
||||
}
|
||||
|
||||
// 創建一個獨立的“運行中”標記,它的存在與否是原子性的
|
||||
if err := s.store.Set(runningFlagKey, []byte("1"), timeout); err != nil {
|
||||
_ = s.store.Del(lockKey)
|
||||
_ = s.store.Del(taskKey)
|
||||
return nil, fmt.Errorf("failed to set task running flag: %w", err)
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *Task) EndTaskByID(taskID, resourceID string, resultData any, taskErr error) {
|
||||
func (s *Task) EndTaskByID(ctx context.Context, taskID, resourceID string, resultData any, taskErr error) {
|
||||
lockKey := s.getResourceLockKey(resourceID)
|
||||
defer func() {
|
||||
if err := s.store.Del(lockKey); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to release resource lock '%s' for task %s.", lockKey, taskID)
|
||||
}
|
||||
}()
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
_ = s.store.Del(runningFlagKey)
|
||||
status, err := s.GetStatus(taskID)
|
||||
if err != nil {
|
||||
|
||||
s.logger.WithError(err).Errorf("Could not get task status for task ID %s during EndTask. Lock has been released, but task data may be stale.", taskID)
|
||||
defer func() {
|
||||
_ = s.store.Del(ctx, lockKey)
|
||||
_ = s.store.Del(ctx, runningFlagKey)
|
||||
}()
|
||||
|
||||
status, err := s.GetStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to get task status for %s during EndTask", taskID)
|
||||
return
|
||||
}
|
||||
|
||||
if !status.IsRunning {
|
||||
s.logger.Warnf("EndTaskByID called for an already finished task: %s", taskID)
|
||||
s.logger.Warnf("EndTaskByID called for already finished task: %s", taskID)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
status.IsRunning = false
|
||||
status.FinishedAt = &now
|
||||
status.DurationSeconds = now.Sub(status.StartedAt).Seconds()
|
||||
|
||||
if taskErr != nil {
|
||||
status.Error = taskErr.Error()
|
||||
} else {
|
||||
status.Result = resultData
|
||||
}
|
||||
updatedTaskBytes, _ := json.Marshal(status)
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
if err := s.store.Set(taskKey, updatedTaskBytes, ResultTTL); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to save final status for task %s.", taskID)
|
||||
|
||||
if err := s.saveStatus(ctx, taskID, status, ResultTTL); err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to save final status for task %s", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus 通过ID获取任务状态,供外部(如API Handler)调用
|
||||
func (s *Task) GetStatus(taskID string) (*Status, error) {
|
||||
func (s *Task) GetStatus(ctx context.Context, taskID string) (*Status, error) {
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
statusBytes, err := s.store.Get(taskKey)
|
||||
statusBytes, err := s.store.Get(ctx, taskKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
return nil, errors.New("task not found")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get task status from store: %w", err)
|
||||
return nil, fmt.Errorf("failed to get task status: %w", err)
|
||||
}
|
||||
|
||||
var status Status
|
||||
if err := json.Unmarshal(statusBytes, &status); err != nil {
|
||||
return nil, fmt.Errorf("corrupted task data in store for ID %s", taskID)
|
||||
return nil, fmt.Errorf("corrupted task data for ID %s: %w", taskID, err)
|
||||
}
|
||||
|
||||
if !status.IsRunning && status.FinishedAt != nil {
|
||||
@@ -169,46 +163,51 @@ func (s *Task) GetStatus(taskID string) (*Status, error) {
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
// UpdateProgressByID 通过ID更新任务进度
|
||||
func (s *Task) updateTask(taskID string, updater func(status *Status)) error {
|
||||
func (s *Task) updateTask(ctx context.Context, taskID string, updater func(status *Status)) error {
|
||||
runningFlagKey := s.getIsRunningFlagKey(taskID)
|
||||
if _, err := s.store.Get(runningFlagKey); err != nil {
|
||||
// 任务已结束,静默返回是预期行为。
|
||||
return nil
|
||||
if _, err := s.store.Get(ctx, runningFlagKey); err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
return errors.New("task is not running")
|
||||
}
|
||||
return fmt.Errorf("failed to check running flag: %w", err)
|
||||
}
|
||||
status, err := s.GetStatus(taskID)
|
||||
|
||||
status, err := s.GetStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to get task status for update on task %s. Update will not be saved.", taskID)
|
||||
return nil
|
||||
return fmt.Errorf("failed to get task status: %w", err)
|
||||
}
|
||||
|
||||
if !status.IsRunning {
|
||||
return nil
|
||||
return errors.New("task is not running")
|
||||
}
|
||||
// 调用传入的 updater 函数来修改 status
|
||||
|
||||
updater(status)
|
||||
statusBytes, marshalErr := json.Marshal(status)
|
||||
if marshalErr != nil {
|
||||
s.logger.WithError(marshalErr).Errorf("Failed to serialize status for update on task %s.", taskID)
|
||||
return nil
|
||||
}
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
// 使用更长的TTL,确保运行中的任务不会过早过期
|
||||
if err := s.store.Set(taskKey, statusBytes, ResultTTL*24); err != nil {
|
||||
s.logger.WithError(err).Warnf("Failed to save update for task %s.", taskID)
|
||||
}
|
||||
return nil
|
||||
|
||||
return s.saveStatus(ctx, taskID, status, DefaultTimeout)
|
||||
}
|
||||
|
||||
// [REFACTORED] UpdateProgressByID 现在是一个简单的、调用通用更新器的包装器。
|
||||
func (s *Task) UpdateProgressByID(taskID string, processed int) error {
|
||||
return s.updateTask(taskID, func(status *Status) {
|
||||
func (s *Task) UpdateProgressByID(ctx context.Context, taskID string, processed int) error {
|
||||
return s.updateTask(ctx, taskID, func(status *Status) {
|
||||
status.Processed = processed
|
||||
})
|
||||
}
|
||||
|
||||
// [REFACTORED] UpdateTotalByID 现在也是一个简单的、调用通用更新器的包装器。
|
||||
func (s *Task) UpdateTotalByID(taskID string, total int) error {
|
||||
return s.updateTask(taskID, func(status *Status) {
|
||||
func (s *Task) UpdateTotalByID(ctx context.Context, taskID string, total int) error {
|
||||
return s.updateTask(ctx, taskID, func(status *Status) {
|
||||
status.Total = total
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Task) saveStatus(ctx context.Context, taskID string, status *Status, ttl time.Duration) error {
|
||||
statusBytes, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize status: %w", err)
|
||||
}
|
||||
|
||||
taskKey := s.getTaskDataKey(taskID)
|
||||
if err := s.store.Set(ctx, taskKey, statusBytes, ttl); err != nil {
|
||||
return fmt.Errorf("failed to save status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,43 +1,118 @@
|
||||
// Filename: internal/webhandlers/auth_handler.go (最终现代化改造版)
|
||||
// Filename: internal/webhandlers/auth_handler.go
|
||||
|
||||
package webhandlers
|
||||
|
||||
import (
|
||||
"gemini-balancer/internal/middleware"
|
||||
"gemini-balancer/internal/service" // [核心改造] 依赖service层
|
||||
"gemini-balancer/internal/service"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// WebAuthHandler [核心改造] 依赖关系净化,注入SecurityService
|
||||
// WebAuthHandler Web 认证处理器
|
||||
type WebAuthHandler struct {
|
||||
securityService *service.SecurityService
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
// NewWebAuthHandler [核心改造] 构造函数更新
|
||||
// NewWebAuthHandler 创建 WebAuthHandler
|
||||
func NewWebAuthHandler(securityService *service.SecurityService) *WebAuthHandler {
|
||||
logger := logrus.New()
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
|
||||
return &WebAuthHandler{
|
||||
securityService: securityService,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ShowLoginPage 保持不变
|
||||
// ShowLoginPage 显示登录页面
|
||||
func (h *WebAuthHandler) ShowLoginPage(c *gin.Context) {
|
||||
errMsg := c.Query("error")
|
||||
from := c.Query("from") // 可以从登录失败的页面返回
|
||||
|
||||
// 验证重定向路径(防止开放重定向攻击)
|
||||
redirectPath := h.validateRedirectPath(c.Query("redirect"))
|
||||
|
||||
// 如果已登录,直接重定向
|
||||
if cookie := middleware.ExtractTokenFromCookie(c); cookie != "" {
|
||||
if _, err := h.securityService.AuthenticateToken(cookie); err == nil {
|
||||
c.Redirect(http.StatusFound, redirectPath)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.HTML(http.StatusOK, "auth.html", gin.H{
|
||||
"error": errMsg,
|
||||
"from": from,
|
||||
"error": errMsg,
|
||||
"redirect": redirectPath,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleLogin [核心改造] 认证逻辑完全委托给SecurityService
|
||||
// HandleLogin 已废弃(项目无用户名系统)
|
||||
func (h *WebAuthHandler) HandleLogin(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, "/login?error=DEPRECATED_LOGIN_METHOD")
|
||||
}
|
||||
|
||||
// HandleLogout 保持不变
|
||||
// HandleLogout 处理登出请求
|
||||
func (h *WebAuthHandler) HandleLogout(c *gin.Context) {
|
||||
cookie := middleware.ExtractTokenFromCookie(c)
|
||||
|
||||
if cookie != "" {
|
||||
// 尝试获取 Token 信息用于日志
|
||||
authToken, err := h.securityService.AuthenticateToken(cookie)
|
||||
if err == nil {
|
||||
h.logger.WithFields(logrus.Fields{
|
||||
"token_id": authToken.ID,
|
||||
"client_ip": c.ClientIP(),
|
||||
}).Info("User logged out")
|
||||
} else {
|
||||
h.logger.WithField("client_ip", c.ClientIP()).Warn("Logout with invalid token")
|
||||
}
|
||||
|
||||
// 使缓存失效
|
||||
middleware.InvalidateTokenCache(cookie)
|
||||
} else {
|
||||
h.logger.WithField("client_ip", c.ClientIP()).Debug("Logout without session cookie")
|
||||
}
|
||||
|
||||
// 清除 Cookie
|
||||
middleware.ClearAdminSessionCookie(c)
|
||||
|
||||
// 重定向到登录页
|
||||
c.Redirect(http.StatusFound, "/login")
|
||||
}
|
||||
|
||||
// validateRedirectPath 验证重定向路径(防止开放重定向攻击)
|
||||
func (h *WebAuthHandler) validateRedirectPath(path string) string {
|
||||
defaultPath := "/dashboard"
|
||||
|
||||
if path == "" {
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
// 只允许内部路径
|
||||
if !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") {
|
||||
h.logger.WithField("path", path).Warn("Invalid redirect path blocked")
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
// 白名单验证
|
||||
allowedPaths := []string{
|
||||
"/dashboard",
|
||||
"/keys",
|
||||
"/settings",
|
||||
"/logs",
|
||||
"/tasks",
|
||||
"/chat",
|
||||
}
|
||||
|
||||
for _, allowed := range allowedPaths {
|
||||
if strings.HasPrefix(path, allowed) {
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
@@ -8,6 +8,12 @@ module.exports = {
|
||||
'./web/templates/**/*.html',
|
||||
'./web/static/js/**/*.js',
|
||||
],
|
||||
safelist: [
|
||||
'grid-rows-[1]',
|
||||
{
|
||||
pattern: /data-\[(expanded|active)\]/,
|
||||
},
|
||||
],
|
||||
theme: {
|
||||
extend: {
|
||||
// 定义语义化颜色
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
1926
web/static/js/chat-2W4NJWMO.js
Normal file
1926
web/static/js/chat-2W4NJWMO.js
Normal file
File diff suppressed because it is too large
Load Diff
30
web/static/js/chunk-JSBRDJBE.js
Normal file
30
web/static/js/chunk-JSBRDJBE.js
Normal file
@@ -0,0 +1,30 @@
|
||||
var __create = Object.create;
|
||||
var __defProp = Object.defineProperty;
|
||||
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
|
||||
var __getOwnPropNames = Object.getOwnPropertyNames;
|
||||
var __getProtoOf = Object.getPrototypeOf;
|
||||
var __hasOwnProp = Object.prototype.hasOwnProperty;
|
||||
var __commonJS = (cb, mod) => function __require() {
|
||||
return mod || (0, cb[__getOwnPropNames(cb)[0]])((mod = { exports: {} }).exports, mod), mod.exports;
|
||||
};
|
||||
var __copyProps = (to, from, except, desc) => {
|
||||
if (from && typeof from === "object" || typeof from === "function") {
|
||||
for (let key of __getOwnPropNames(from))
|
||||
if (!__hasOwnProp.call(to, key) && key !== except)
|
||||
__defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
|
||||
}
|
||||
return to;
|
||||
};
|
||||
var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__getProtoOf(mod)) : {}, __copyProps(
|
||||
// If the importer is in node compatibility mode or this is not an ESM
|
||||
// file that has been converted to a CommonJS file using a Babel-
|
||||
// compatible transform (i.e. "__esModule" has not been set), then set
|
||||
// "default" to the CommonJS "module.exports" for node compatibility.
|
||||
isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target,
|
||||
mod
|
||||
));
|
||||
|
||||
export {
|
||||
__commonJS,
|
||||
__toESM
|
||||
};
|
||||
@@ -1,83 +0,0 @@
|
||||
// frontend/js/services/api.js
|
||||
var APIClientError = class extends Error {
|
||||
constructor(message, status, code, rawMessageFromServer) {
|
||||
super(message);
|
||||
this.name = "APIClientError";
|
||||
this.status = status;
|
||||
this.code = code;
|
||||
this.rawMessageFromServer = rawMessageFromServer;
|
||||
}
|
||||
};
|
||||
var apiPromiseCache = /* @__PURE__ */ new Map();
|
||||
async function apiFetch(url, options = {}) {
|
||||
const isGetRequest = !options.method || options.method.toUpperCase() === "GET";
|
||||
const cacheKey = isGetRequest && !options.noCache ? url : null;
|
||||
if (cacheKey && apiPromiseCache.has(cacheKey)) {
|
||||
return apiPromiseCache.get(cacheKey);
|
||||
}
|
||||
const token = localStorage.getItem("bearerToken");
|
||||
const headers = {
|
||||
"Content-Type": "application/json",
|
||||
...options.headers
|
||||
};
|
||||
if (token) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
const requestPromise = (async () => {
|
||||
try {
|
||||
const response = await fetch(url, { ...options, headers });
|
||||
if (response.status === 401) {
|
||||
if (cacheKey) apiPromiseCache.delete(cacheKey);
|
||||
localStorage.removeItem("bearerToken");
|
||||
if (window.location.pathname !== "/login") {
|
||||
window.location.href = "/login?error=\u4F1A\u8BDD\u5DF2\u8FC7\u671F\uFF0C\u8BF7\u91CD\u65B0\u767B\u5F55\u3002";
|
||||
}
|
||||
throw new APIClientError("Unauthorized", 401, "UNAUTHORIZED", "Session expired or token is invalid.");
|
||||
}
|
||||
if (!response.ok) {
|
||||
let errorData = null;
|
||||
let rawMessage = "";
|
||||
try {
|
||||
rawMessage = await response.text();
|
||||
if (rawMessage) {
|
||||
errorData = JSON.parse(rawMessage);
|
||||
}
|
||||
} catch (e) {
|
||||
errorData = { error: { code: "UNKNOWN_FORMAT", message: rawMessage || response.statusText } };
|
||||
}
|
||||
const code = errorData?.error?.code || "UNKNOWN_ERROR";
|
||||
const messageFromServer = errorData?.error?.message || rawMessage || "No message provided by server.";
|
||||
const error = new APIClientError(
|
||||
`API request failed: ${response.status}`,
|
||||
response.status,
|
||||
code,
|
||||
messageFromServer
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (cacheKey) apiPromiseCache.delete(cacheKey);
|
||||
throw error;
|
||||
}
|
||||
})();
|
||||
if (cacheKey) {
|
||||
apiPromiseCache.set(cacheKey, requestPromise);
|
||||
}
|
||||
return requestPromise;
|
||||
}
|
||||
async function apiFetchJson(url, options = {}) {
|
||||
try {
|
||||
const response = await apiFetch(url, options);
|
||||
const clonedResponse = response.clone();
|
||||
const jsonData = await clonedResponse.json();
|
||||
return jsonData;
|
||||
} catch (error) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export {
|
||||
apiFetch,
|
||||
apiFetchJson
|
||||
};
|
||||
606
web/static/js/chunk-T5V6LQ42.js
Normal file
606
web/static/js/chunk-T5V6LQ42.js
Normal file
@@ -0,0 +1,606 @@
|
||||
import {
|
||||
__commonJS,
|
||||
__toESM
|
||||
} from "./chunk-JSBRDJBE.js";
|
||||
|
||||
// frontend/js/vendor/popper.esm.min.js
|
||||
var require_popper_esm_min = __commonJS({
|
||||
"frontend/js/vendor/popper.esm.min.js"(exports, module) {
|
||||
!(function(e, t) {
|
||||
"object" == typeof exports && "undefined" != typeof module ? t(exports) : "function" == typeof define && define.amd ? define(["exports"], t) : t((e = "undefined" != typeof globalThis ? globalThis : e || self).Popper = {});
|
||||
})(exports, (function(e) {
|
||||
"use strict";
|
||||
function t(e2) {
|
||||
if (null == e2) return window;
|
||||
if ("[object Window]" !== e2.toString()) {
|
||||
var t2 = e2.ownerDocument;
|
||||
return t2 && t2.defaultView || window;
|
||||
}
|
||||
return e2;
|
||||
}
|
||||
function n(e2) {
|
||||
return e2 instanceof t(e2).Element || e2 instanceof Element;
|
||||
}
|
||||
function r(e2) {
|
||||
return e2 instanceof t(e2).HTMLElement || e2 instanceof HTMLElement;
|
||||
}
|
||||
function o(e2) {
|
||||
return "undefined" != typeof ShadowRoot && (e2 instanceof t(e2).ShadowRoot || e2 instanceof ShadowRoot);
|
||||
}
|
||||
var i = Math.max, a = Math.min, s = Math.round;
|
||||
function f() {
|
||||
var e2 = navigator.userAgentData;
|
||||
return null != e2 && e2.brands && Array.isArray(e2.brands) ? e2.brands.map((function(e3) {
|
||||
return e3.brand + "/" + e3.version;
|
||||
})).join(" ") : navigator.userAgent;
|
||||
}
|
||||
function c() {
|
||||
return !/^((?!chrome|android).)*safari/i.test(f());
|
||||
}
|
||||
function p(e2, o2, i2) {
|
||||
void 0 === o2 && (o2 = false), void 0 === i2 && (i2 = false);
|
||||
var a2 = e2.getBoundingClientRect(), f2 = 1, p2 = 1;
|
||||
o2 && r(e2) && (f2 = e2.offsetWidth > 0 && s(a2.width) / e2.offsetWidth || 1, p2 = e2.offsetHeight > 0 && s(a2.height) / e2.offsetHeight || 1);
|
||||
var u2 = (n(e2) ? t(e2) : window).visualViewport, l2 = !c() && i2, d2 = (a2.left + (l2 && u2 ? u2.offsetLeft : 0)) / f2, h2 = (a2.top + (l2 && u2 ? u2.offsetTop : 0)) / p2, m2 = a2.width / f2, v2 = a2.height / p2;
|
||||
return { width: m2, height: v2, top: h2, right: d2 + m2, bottom: h2 + v2, left: d2, x: d2, y: h2 };
|
||||
}
|
||||
function u(e2) {
|
||||
var n2 = t(e2);
|
||||
return { scrollLeft: n2.pageXOffset, scrollTop: n2.pageYOffset };
|
||||
}
|
||||
function l(e2) {
|
||||
return e2 ? (e2.nodeName || "").toLowerCase() : null;
|
||||
}
|
||||
function d(e2) {
|
||||
return ((n(e2) ? e2.ownerDocument : e2.document) || window.document).documentElement;
|
||||
}
|
||||
function h(e2) {
|
||||
return p(d(e2)).left + u(e2).scrollLeft;
|
||||
}
|
||||
function m(e2) {
|
||||
return t(e2).getComputedStyle(e2);
|
||||
}
|
||||
function v(e2) {
|
||||
var t2 = m(e2), n2 = t2.overflow, r2 = t2.overflowX, o2 = t2.overflowY;
|
||||
return /auto|scroll|overlay|hidden/.test(n2 + o2 + r2);
|
||||
}
|
||||
function y(e2, n2, o2) {
|
||||
void 0 === o2 && (o2 = false);
|
||||
var i2, a2, f2 = r(n2), c2 = r(n2) && (function(e3) {
|
||||
var t2 = e3.getBoundingClientRect(), n3 = s(t2.width) / e3.offsetWidth || 1, r2 = s(t2.height) / e3.offsetHeight || 1;
|
||||
return 1 !== n3 || 1 !== r2;
|
||||
})(n2), m2 = d(n2), y2 = p(e2, c2, o2), g2 = { scrollLeft: 0, scrollTop: 0 }, b2 = { x: 0, y: 0 };
|
||||
return (f2 || !f2 && !o2) && (("body" !== l(n2) || v(m2)) && (g2 = (i2 = n2) !== t(i2) && r(i2) ? { scrollLeft: (a2 = i2).scrollLeft, scrollTop: a2.scrollTop } : u(i2)), r(n2) ? ((b2 = p(n2, true)).x += n2.clientLeft, b2.y += n2.clientTop) : m2 && (b2.x = h(m2))), { x: y2.left + g2.scrollLeft - b2.x, y: y2.top + g2.scrollTop - b2.y, width: y2.width, height: y2.height };
|
||||
}
|
||||
function g(e2) {
|
||||
var t2 = p(e2), n2 = e2.offsetWidth, r2 = e2.offsetHeight;
|
||||
return Math.abs(t2.width - n2) <= 1 && (n2 = t2.width), Math.abs(t2.height - r2) <= 1 && (r2 = t2.height), { x: e2.offsetLeft, y: e2.offsetTop, width: n2, height: r2 };
|
||||
}
|
||||
function b(e2) {
|
||||
return "html" === l(e2) ? e2 : e2.assignedSlot || e2.parentNode || (o(e2) ? e2.host : null) || d(e2);
|
||||
}
|
||||
function x(e2) {
|
||||
return ["html", "body", "#document"].indexOf(l(e2)) >= 0 ? e2.ownerDocument.body : r(e2) && v(e2) ? e2 : x(b(e2));
|
||||
}
|
||||
function w(e2, n2) {
|
||||
var r2;
|
||||
void 0 === n2 && (n2 = []);
|
||||
var o2 = x(e2), i2 = o2 === (null == (r2 = e2.ownerDocument) ? void 0 : r2.body), a2 = t(o2), s2 = i2 ? [a2].concat(a2.visualViewport || [], v(o2) ? o2 : []) : o2, f2 = n2.concat(s2);
|
||||
return i2 ? f2 : f2.concat(w(b(s2)));
|
||||
}
|
||||
function O(e2) {
|
||||
return ["table", "td", "th"].indexOf(l(e2)) >= 0;
|
||||
}
|
||||
function j(e2) {
|
||||
return r(e2) && "fixed" !== m(e2).position ? e2.offsetParent : null;
|
||||
}
|
||||
function E(e2) {
|
||||
for (var n2 = t(e2), i2 = j(e2); i2 && O(i2) && "static" === m(i2).position; ) i2 = j(i2);
|
||||
return i2 && ("html" === l(i2) || "body" === l(i2) && "static" === m(i2).position) ? n2 : i2 || (function(e3) {
|
||||
var t2 = /firefox/i.test(f());
|
||||
if (/Trident/i.test(f()) && r(e3) && "fixed" === m(e3).position) return null;
|
||||
var n3 = b(e3);
|
||||
for (o(n3) && (n3 = n3.host); r(n3) && ["html", "body"].indexOf(l(n3)) < 0; ) {
|
||||
var i3 = m(n3);
|
||||
if ("none" !== i3.transform || "none" !== i3.perspective || "paint" === i3.contain || -1 !== ["transform", "perspective"].indexOf(i3.willChange) || t2 && "filter" === i3.willChange || t2 && i3.filter && "none" !== i3.filter) return n3;
|
||||
n3 = n3.parentNode;
|
||||
}
|
||||
return null;
|
||||
})(e2) || n2;
|
||||
}
|
||||
var D = "top", A = "bottom", L = "right", P = "left", M = "auto", k = [D, A, L, P], W = "start", B = "end", H = "viewport", T = "popper", R = k.reduce((function(e2, t2) {
|
||||
return e2.concat([t2 + "-" + W, t2 + "-" + B]);
|
||||
}), []), S = [].concat(k, [M]).reduce((function(e2, t2) {
|
||||
return e2.concat([t2, t2 + "-" + W, t2 + "-" + B]);
|
||||
}), []), V = ["beforeRead", "read", "afterRead", "beforeMain", "main", "afterMain", "beforeWrite", "write", "afterWrite"];
|
||||
function q(e2) {
|
||||
var t2 = /* @__PURE__ */ new Map(), n2 = /* @__PURE__ */ new Set(), r2 = [];
|
||||
function o2(e3) {
|
||||
n2.add(e3.name), [].concat(e3.requires || [], e3.requiresIfExists || []).forEach((function(e4) {
|
||||
if (!n2.has(e4)) {
|
||||
var r3 = t2.get(e4);
|
||||
r3 && o2(r3);
|
||||
}
|
||||
})), r2.push(e3);
|
||||
}
|
||||
return e2.forEach((function(e3) {
|
||||
t2.set(e3.name, e3);
|
||||
})), e2.forEach((function(e3) {
|
||||
n2.has(e3.name) || o2(e3);
|
||||
})), r2;
|
||||
}
|
||||
function C(e2, t2) {
|
||||
var n2 = t2.getRootNode && t2.getRootNode();
|
||||
if (e2.contains(t2)) return true;
|
||||
if (n2 && o(n2)) {
|
||||
var r2 = t2;
|
||||
do {
|
||||
if (r2 && e2.isSameNode(r2)) return true;
|
||||
r2 = r2.parentNode || r2.host;
|
||||
} while (r2);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
function N(e2) {
|
||||
return Object.assign({}, e2, { left: e2.x, top: e2.y, right: e2.x + e2.width, bottom: e2.y + e2.height });
|
||||
}
|
||||
function I(e2, r2, o2) {
|
||||
return r2 === H ? N((function(e3, n2) {
|
||||
var r3 = t(e3), o3 = d(e3), i2 = r3.visualViewport, a2 = o3.clientWidth, s2 = o3.clientHeight, f2 = 0, p2 = 0;
|
||||
if (i2) {
|
||||
a2 = i2.width, s2 = i2.height;
|
||||
var u2 = c();
|
||||
(u2 || !u2 && "fixed" === n2) && (f2 = i2.offsetLeft, p2 = i2.offsetTop);
|
||||
}
|
||||
return { width: a2, height: s2, x: f2 + h(e3), y: p2 };
|
||||
})(e2, o2)) : n(r2) ? (function(e3, t2) {
|
||||
var n2 = p(e3, false, "fixed" === t2);
|
||||
return n2.top = n2.top + e3.clientTop, n2.left = n2.left + e3.clientLeft, n2.bottom = n2.top + e3.clientHeight, n2.right = n2.left + e3.clientWidth, n2.width = e3.clientWidth, n2.height = e3.clientHeight, n2.x = n2.left, n2.y = n2.top, n2;
|
||||
})(r2, o2) : N((function(e3) {
|
||||
var t2, n2 = d(e3), r3 = u(e3), o3 = null == (t2 = e3.ownerDocument) ? void 0 : t2.body, a2 = i(n2.scrollWidth, n2.clientWidth, o3 ? o3.scrollWidth : 0, o3 ? o3.clientWidth : 0), s2 = i(n2.scrollHeight, n2.clientHeight, o3 ? o3.scrollHeight : 0, o3 ? o3.clientHeight : 0), f2 = -r3.scrollLeft + h(e3), c2 = -r3.scrollTop;
|
||||
return "rtl" === m(o3 || n2).direction && (f2 += i(n2.clientWidth, o3 ? o3.clientWidth : 0) - a2), { width: a2, height: s2, x: f2, y: c2 };
|
||||
})(d(e2)));
|
||||
}
|
||||
function _(e2, t2, o2, s2) {
|
||||
var f2 = "clippingParents" === t2 ? (function(e3) {
|
||||
var t3 = w(b(e3)), o3 = ["absolute", "fixed"].indexOf(m(e3).position) >= 0 && r(e3) ? E(e3) : e3;
|
||||
return n(o3) ? t3.filter((function(e4) {
|
||||
return n(e4) && C(e4, o3) && "body" !== l(e4);
|
||||
})) : [];
|
||||
})(e2) : [].concat(t2), c2 = [].concat(f2, [o2]), p2 = c2[0], u2 = c2.reduce((function(t3, n2) {
|
||||
var r2 = I(e2, n2, s2);
|
||||
return t3.top = i(r2.top, t3.top), t3.right = a(r2.right, t3.right), t3.bottom = a(r2.bottom, t3.bottom), t3.left = i(r2.left, t3.left), t3;
|
||||
}), I(e2, p2, s2));
|
||||
return u2.width = u2.right - u2.left, u2.height = u2.bottom - u2.top, u2.x = u2.left, u2.y = u2.top, u2;
|
||||
}
|
||||
function F(e2) {
|
||||
return e2.split("-")[0];
|
||||
}
|
||||
function U(e2) {
|
||||
return e2.split("-")[1];
|
||||
}
|
||||
function z(e2) {
|
||||
return ["top", "bottom"].indexOf(e2) >= 0 ? "x" : "y";
|
||||
}
|
||||
function X(e2) {
|
||||
var t2, n2 = e2.reference, r2 = e2.element, o2 = e2.placement, i2 = o2 ? F(o2) : null, a2 = o2 ? U(o2) : null, s2 = n2.x + n2.width / 2 - r2.width / 2, f2 = n2.y + n2.height / 2 - r2.height / 2;
|
||||
switch (i2) {
|
||||
case D:
|
||||
t2 = { x: s2, y: n2.y - r2.height };
|
||||
break;
|
||||
case A:
|
||||
t2 = { x: s2, y: n2.y + n2.height };
|
||||
break;
|
||||
case L:
|
||||
t2 = { x: n2.x + n2.width, y: f2 };
|
||||
break;
|
||||
case P:
|
||||
t2 = { x: n2.x - r2.width, y: f2 };
|
||||
break;
|
||||
default:
|
||||
t2 = { x: n2.x, y: n2.y };
|
||||
}
|
||||
var c2 = i2 ? z(i2) : null;
|
||||
if (null != c2) {
|
||||
var p2 = "y" === c2 ? "height" : "width";
|
||||
switch (a2) {
|
||||
case W:
|
||||
t2[c2] = t2[c2] - (n2[p2] / 2 - r2[p2] / 2);
|
||||
break;
|
||||
case B:
|
||||
t2[c2] = t2[c2] + (n2[p2] / 2 - r2[p2] / 2);
|
||||
}
|
||||
}
|
||||
return t2;
|
||||
}
|
||||
function Y(e2) {
|
||||
return Object.assign({}, { top: 0, right: 0, bottom: 0, left: 0 }, e2);
|
||||
}
|
||||
function G(e2, t2) {
|
||||
return t2.reduce((function(t3, n2) {
|
||||
return t3[n2] = e2, t3;
|
||||
}), {});
|
||||
}
|
||||
function J(e2, t2) {
|
||||
void 0 === t2 && (t2 = {});
|
||||
var r2 = t2, o2 = r2.placement, i2 = void 0 === o2 ? e2.placement : o2, a2 = r2.strategy, s2 = void 0 === a2 ? e2.strategy : a2, f2 = r2.boundary, c2 = void 0 === f2 ? "clippingParents" : f2, u2 = r2.rootBoundary, l2 = void 0 === u2 ? H : u2, h2 = r2.elementContext, m2 = void 0 === h2 ? T : h2, v2 = r2.altBoundary, y2 = void 0 !== v2 && v2, g2 = r2.padding, b2 = void 0 === g2 ? 0 : g2, x2 = Y("number" != typeof b2 ? b2 : G(b2, k)), w2 = m2 === T ? "reference" : T, O2 = e2.rects.popper, j2 = e2.elements[y2 ? w2 : m2], E2 = _(n(j2) ? j2 : j2.contextElement || d(e2.elements.popper), c2, l2, s2), P2 = p(e2.elements.reference), M2 = X({ reference: P2, element: O2, strategy: "absolute", placement: i2 }), W2 = N(Object.assign({}, O2, M2)), B2 = m2 === T ? W2 : P2, R2 = { top: E2.top - B2.top + x2.top, bottom: B2.bottom - E2.bottom + x2.bottom, left: E2.left - B2.left + x2.left, right: B2.right - E2.right + x2.right }, S2 = e2.modifiersData.offset;
|
||||
if (m2 === T && S2) {
|
||||
var V2 = S2[i2];
|
||||
Object.keys(R2).forEach((function(e3) {
|
||||
var t3 = [L, A].indexOf(e3) >= 0 ? 1 : -1, n2 = [D, A].indexOf(e3) >= 0 ? "y" : "x";
|
||||
R2[e3] += V2[n2] * t3;
|
||||
}));
|
||||
}
|
||||
return R2;
|
||||
}
|
||||
var K = { placement: "bottom", modifiers: [], strategy: "absolute" };
|
||||
function Q() {
|
||||
for (var e2 = arguments.length, t2 = new Array(e2), n2 = 0; n2 < e2; n2++) t2[n2] = arguments[n2];
|
||||
return !t2.some((function(e3) {
|
||||
return !(e3 && "function" == typeof e3.getBoundingClientRect);
|
||||
}));
|
||||
}
|
||||
function Z(e2) {
|
||||
void 0 === e2 && (e2 = {});
|
||||
var t2 = e2, r2 = t2.defaultModifiers, o2 = void 0 === r2 ? [] : r2, i2 = t2.defaultOptions, a2 = void 0 === i2 ? K : i2;
|
||||
return function(e3, t3, r3) {
|
||||
void 0 === r3 && (r3 = a2);
|
||||
var i3, s2, f2 = { placement: "bottom", orderedModifiers: [], options: Object.assign({}, K, a2), modifiersData: {}, elements: { reference: e3, popper: t3 }, attributes: {}, styles: {} }, c2 = [], p2 = false, u2 = { state: f2, setOptions: function(r4) {
|
||||
var i4 = "function" == typeof r4 ? r4(f2.options) : r4;
|
||||
l2(), f2.options = Object.assign({}, a2, f2.options, i4), f2.scrollParents = { reference: n(e3) ? w(e3) : e3.contextElement ? w(e3.contextElement) : [], popper: w(t3) };
|
||||
var s3, p3, d2 = (function(e4) {
|
||||
var t4 = q(e4);
|
||||
return V.reduce((function(e5, n2) {
|
||||
return e5.concat(t4.filter((function(e6) {
|
||||
return e6.phase === n2;
|
||||
})));
|
||||
}), []);
|
||||
})((s3 = [].concat(o2, f2.options.modifiers), p3 = s3.reduce((function(e4, t4) {
|
||||
var n2 = e4[t4.name];
|
||||
return e4[t4.name] = n2 ? Object.assign({}, n2, t4, { options: Object.assign({}, n2.options, t4.options), data: Object.assign({}, n2.data, t4.data) }) : t4, e4;
|
||||
}), {}), Object.keys(p3).map((function(e4) {
|
||||
return p3[e4];
|
||||
}))));
|
||||
return f2.orderedModifiers = d2.filter((function(e4) {
|
||||
return e4.enabled;
|
||||
})), f2.orderedModifiers.forEach((function(e4) {
|
||||
var t4 = e4.name, n2 = e4.options, r5 = void 0 === n2 ? {} : n2, o3 = e4.effect;
|
||||
if ("function" == typeof o3) {
|
||||
var i5 = o3({ state: f2, name: t4, instance: u2, options: r5 }), a3 = function() {
|
||||
};
|
||||
c2.push(i5 || a3);
|
||||
}
|
||||
})), u2.update();
|
||||
}, forceUpdate: function() {
|
||||
if (!p2) {
|
||||
var e4 = f2.elements, t4 = e4.reference, n2 = e4.popper;
|
||||
if (Q(t4, n2)) {
|
||||
f2.rects = { reference: y(t4, E(n2), "fixed" === f2.options.strategy), popper: g(n2) }, f2.reset = false, f2.placement = f2.options.placement, f2.orderedModifiers.forEach((function(e5) {
|
||||
return f2.modifiersData[e5.name] = Object.assign({}, e5.data);
|
||||
}));
|
||||
for (var r4 = 0; r4 < f2.orderedModifiers.length; r4++) if (true !== f2.reset) {
|
||||
var o3 = f2.orderedModifiers[r4], i4 = o3.fn, a3 = o3.options, s3 = void 0 === a3 ? {} : a3, c3 = o3.name;
|
||||
"function" == typeof i4 && (f2 = i4({ state: f2, options: s3, name: c3, instance: u2 }) || f2);
|
||||
} else f2.reset = false, r4 = -1;
|
||||
}
|
||||
}
|
||||
}, update: (i3 = function() {
|
||||
return new Promise((function(e4) {
|
||||
u2.forceUpdate(), e4(f2);
|
||||
}));
|
||||
}, function() {
|
||||
return s2 || (s2 = new Promise((function(e4) {
|
||||
Promise.resolve().then((function() {
|
||||
s2 = void 0, e4(i3());
|
||||
}));
|
||||
}))), s2;
|
||||
}), destroy: function() {
|
||||
l2(), p2 = true;
|
||||
} };
|
||||
if (!Q(e3, t3)) return u2;
|
||||
function l2() {
|
||||
c2.forEach((function(e4) {
|
||||
return e4();
|
||||
})), c2 = [];
|
||||
}
|
||||
return u2.setOptions(r3).then((function(e4) {
|
||||
!p2 && r3.onFirstUpdate && r3.onFirstUpdate(e4);
|
||||
})), u2;
|
||||
};
|
||||
}
|
||||
var $ = { passive: true };
|
||||
var ee = { name: "eventListeners", enabled: true, phase: "write", fn: function() {
|
||||
}, effect: function(e2) {
|
||||
var n2 = e2.state, r2 = e2.instance, o2 = e2.options, i2 = o2.scroll, a2 = void 0 === i2 || i2, s2 = o2.resize, f2 = void 0 === s2 || s2, c2 = t(n2.elements.popper), p2 = [].concat(n2.scrollParents.reference, n2.scrollParents.popper);
|
||||
return a2 && p2.forEach((function(e3) {
|
||||
e3.addEventListener("scroll", r2.update, $);
|
||||
})), f2 && c2.addEventListener("resize", r2.update, $), function() {
|
||||
a2 && p2.forEach((function(e3) {
|
||||
e3.removeEventListener("scroll", r2.update, $);
|
||||
})), f2 && c2.removeEventListener("resize", r2.update, $);
|
||||
};
|
||||
}, data: {} };
|
||||
var te = { name: "popperOffsets", enabled: true, phase: "read", fn: function(e2) {
|
||||
var t2 = e2.state, n2 = e2.name;
|
||||
t2.modifiersData[n2] = X({ reference: t2.rects.reference, element: t2.rects.popper, strategy: "absolute", placement: t2.placement });
|
||||
}, data: {} }, ne = { top: "auto", right: "auto", bottom: "auto", left: "auto" };
|
||||
function re(e2) {
|
||||
var n2, r2 = e2.popper, o2 = e2.popperRect, i2 = e2.placement, a2 = e2.variation, f2 = e2.offsets, c2 = e2.position, p2 = e2.gpuAcceleration, u2 = e2.adaptive, l2 = e2.roundOffsets, h2 = e2.isFixed, v2 = f2.x, y2 = void 0 === v2 ? 0 : v2, g2 = f2.y, b2 = void 0 === g2 ? 0 : g2, x2 = "function" == typeof l2 ? l2({ x: y2, y: b2 }) : { x: y2, y: b2 };
|
||||
y2 = x2.x, b2 = x2.y;
|
||||
var w2 = f2.hasOwnProperty("x"), O2 = f2.hasOwnProperty("y"), j2 = P, M2 = D, k2 = window;
|
||||
if (u2) {
|
||||
var W2 = E(r2), H2 = "clientHeight", T2 = "clientWidth";
|
||||
if (W2 === t(r2) && "static" !== m(W2 = d(r2)).position && "absolute" === c2 && (H2 = "scrollHeight", T2 = "scrollWidth"), W2 = W2, i2 === D || (i2 === P || i2 === L) && a2 === B) M2 = A, b2 -= (h2 && W2 === k2 && k2.visualViewport ? k2.visualViewport.height : W2[H2]) - o2.height, b2 *= p2 ? 1 : -1;
|
||||
if (i2 === P || (i2 === D || i2 === A) && a2 === B) j2 = L, y2 -= (h2 && W2 === k2 && k2.visualViewport ? k2.visualViewport.width : W2[T2]) - o2.width, y2 *= p2 ? 1 : -1;
|
||||
}
|
||||
var R2, S2 = Object.assign({ position: c2 }, u2 && ne), V2 = true === l2 ? (function(e3, t2) {
|
||||
var n3 = e3.x, r3 = e3.y, o3 = t2.devicePixelRatio || 1;
|
||||
return { x: s(n3 * o3) / o3 || 0, y: s(r3 * o3) / o3 || 0 };
|
||||
})({ x: y2, y: b2 }, t(r2)) : { x: y2, y: b2 };
|
||||
return y2 = V2.x, b2 = V2.y, p2 ? Object.assign({}, S2, ((R2 = {})[M2] = O2 ? "0" : "", R2[j2] = w2 ? "0" : "", R2.transform = (k2.devicePixelRatio || 1) <= 1 ? "translate(" + y2 + "px, " + b2 + "px)" : "translate3d(" + y2 + "px, " + b2 + "px, 0)", R2)) : Object.assign({}, S2, ((n2 = {})[M2] = O2 ? b2 + "px" : "", n2[j2] = w2 ? y2 + "px" : "", n2.transform = "", n2));
|
||||
}
|
||||
var oe = { name: "computeStyles", enabled: true, phase: "beforeWrite", fn: function(e2) {
|
||||
var t2 = e2.state, n2 = e2.options, r2 = n2.gpuAcceleration, o2 = void 0 === r2 || r2, i2 = n2.adaptive, a2 = void 0 === i2 || i2, s2 = n2.roundOffsets, f2 = void 0 === s2 || s2, c2 = { placement: F(t2.placement), variation: U(t2.placement), popper: t2.elements.popper, popperRect: t2.rects.popper, gpuAcceleration: o2, isFixed: "fixed" === t2.options.strategy };
|
||||
null != t2.modifiersData.popperOffsets && (t2.styles.popper = Object.assign({}, t2.styles.popper, re(Object.assign({}, c2, { offsets: t2.modifiersData.popperOffsets, position: t2.options.strategy, adaptive: a2, roundOffsets: f2 })))), null != t2.modifiersData.arrow && (t2.styles.arrow = Object.assign({}, t2.styles.arrow, re(Object.assign({}, c2, { offsets: t2.modifiersData.arrow, position: "absolute", adaptive: false, roundOffsets: f2 })))), t2.attributes.popper = Object.assign({}, t2.attributes.popper, { "data-popper-placement": t2.placement });
|
||||
}, data: {} };
|
||||
var ie = { name: "applyStyles", enabled: true, phase: "write", fn: function(e2) {
|
||||
var t2 = e2.state;
|
||||
Object.keys(t2.elements).forEach((function(e3) {
|
||||
var n2 = t2.styles[e3] || {}, o2 = t2.attributes[e3] || {}, i2 = t2.elements[e3];
|
||||
r(i2) && l(i2) && (Object.assign(i2.style, n2), Object.keys(o2).forEach((function(e4) {
|
||||
var t3 = o2[e4];
|
||||
false === t3 ? i2.removeAttribute(e4) : i2.setAttribute(e4, true === t3 ? "" : t3);
|
||||
})));
|
||||
}));
|
||||
}, effect: function(e2) {
|
||||
var t2 = e2.state, n2 = { popper: { position: t2.options.strategy, left: "0", top: "0", margin: "0" }, arrow: { position: "absolute" }, reference: {} };
|
||||
return Object.assign(t2.elements.popper.style, n2.popper), t2.styles = n2, t2.elements.arrow && Object.assign(t2.elements.arrow.style, n2.arrow), function() {
|
||||
Object.keys(t2.elements).forEach((function(e3) {
|
||||
var o2 = t2.elements[e3], i2 = t2.attributes[e3] || {}, a2 = Object.keys(t2.styles.hasOwnProperty(e3) ? t2.styles[e3] : n2[e3]).reduce((function(e4, t3) {
|
||||
return e4[t3] = "", e4;
|
||||
}), {});
|
||||
r(o2) && l(o2) && (Object.assign(o2.style, a2), Object.keys(i2).forEach((function(e4) {
|
||||
o2.removeAttribute(e4);
|
||||
})));
|
||||
}));
|
||||
};
|
||||
}, requires: ["computeStyles"] };
|
||||
var ae = { name: "offset", enabled: true, phase: "main", requires: ["popperOffsets"], fn: function(e2) {
|
||||
var t2 = e2.state, n2 = e2.options, r2 = e2.name, o2 = n2.offset, i2 = void 0 === o2 ? [0, 0] : o2, a2 = S.reduce((function(e3, n3) {
|
||||
return e3[n3] = (function(e4, t3, n4) {
|
||||
var r3 = F(e4), o3 = [P, D].indexOf(r3) >= 0 ? -1 : 1, i3 = "function" == typeof n4 ? n4(Object.assign({}, t3, { placement: e4 })) : n4, a3 = i3[0], s3 = i3[1];
|
||||
return a3 = a3 || 0, s3 = (s3 || 0) * o3, [P, L].indexOf(r3) >= 0 ? { x: s3, y: a3 } : { x: a3, y: s3 };
|
||||
})(n3, t2.rects, i2), e3;
|
||||
}), {}), s2 = a2[t2.placement], f2 = s2.x, c2 = s2.y;
|
||||
null != t2.modifiersData.popperOffsets && (t2.modifiersData.popperOffsets.x += f2, t2.modifiersData.popperOffsets.y += c2), t2.modifiersData[r2] = a2;
|
||||
} }, se = { left: "right", right: "left", bottom: "top", top: "bottom" };
|
||||
function fe(e2) {
|
||||
return e2.replace(/left|right|bottom|top/g, (function(e3) {
|
||||
return se[e3];
|
||||
}));
|
||||
}
|
||||
var ce = { start: "end", end: "start" };
|
||||
function pe(e2) {
|
||||
return e2.replace(/start|end/g, (function(e3) {
|
||||
return ce[e3];
|
||||
}));
|
||||
}
|
||||
function ue(e2, t2) {
|
||||
void 0 === t2 && (t2 = {});
|
||||
var n2 = t2, r2 = n2.placement, o2 = n2.boundary, i2 = n2.rootBoundary, a2 = n2.padding, s2 = n2.flipVariations, f2 = n2.allowedAutoPlacements, c2 = void 0 === f2 ? S : f2, p2 = U(r2), u2 = p2 ? s2 ? R : R.filter((function(e3) {
|
||||
return U(e3) === p2;
|
||||
})) : k, l2 = u2.filter((function(e3) {
|
||||
return c2.indexOf(e3) >= 0;
|
||||
}));
|
||||
0 === l2.length && (l2 = u2);
|
||||
var d2 = l2.reduce((function(t3, n3) {
|
||||
return t3[n3] = J(e2, { placement: n3, boundary: o2, rootBoundary: i2, padding: a2 })[F(n3)], t3;
|
||||
}), {});
|
||||
return Object.keys(d2).sort((function(e3, t3) {
|
||||
return d2[e3] - d2[t3];
|
||||
}));
|
||||
}
|
||||
var le = { name: "flip", enabled: true, phase: "main", fn: function(e2) {
|
||||
var t2 = e2.state, n2 = e2.options, r2 = e2.name;
|
||||
if (!t2.modifiersData[r2]._skip) {
|
||||
for (var o2 = n2.mainAxis, i2 = void 0 === o2 || o2, a2 = n2.altAxis, s2 = void 0 === a2 || a2, f2 = n2.fallbackPlacements, c2 = n2.padding, p2 = n2.boundary, u2 = n2.rootBoundary, l2 = n2.altBoundary, d2 = n2.flipVariations, h2 = void 0 === d2 || d2, m2 = n2.allowedAutoPlacements, v2 = t2.options.placement, y2 = F(v2), g2 = f2 || (y2 === v2 || !h2 ? [fe(v2)] : (function(e3) {
|
||||
if (F(e3) === M) return [];
|
||||
var t3 = fe(e3);
|
||||
return [pe(e3), t3, pe(t3)];
|
||||
})(v2)), b2 = [v2].concat(g2).reduce((function(e3, n3) {
|
||||
return e3.concat(F(n3) === M ? ue(t2, { placement: n3, boundary: p2, rootBoundary: u2, padding: c2, flipVariations: h2, allowedAutoPlacements: m2 }) : n3);
|
||||
}), []), x2 = t2.rects.reference, w2 = t2.rects.popper, O2 = /* @__PURE__ */ new Map(), j2 = true, E2 = b2[0], k2 = 0; k2 < b2.length; k2++) {
|
||||
var B2 = b2[k2], H2 = F(B2), T2 = U(B2) === W, R2 = [D, A].indexOf(H2) >= 0, S2 = R2 ? "width" : "height", V2 = J(t2, { placement: B2, boundary: p2, rootBoundary: u2, altBoundary: l2, padding: c2 }), q2 = R2 ? T2 ? L : P : T2 ? A : D;
|
||||
x2[S2] > w2[S2] && (q2 = fe(q2));
|
||||
var C2 = fe(q2), N2 = [];
|
||||
if (i2 && N2.push(V2[H2] <= 0), s2 && N2.push(V2[q2] <= 0, V2[C2] <= 0), N2.every((function(e3) {
|
||||
return e3;
|
||||
}))) {
|
||||
E2 = B2, j2 = false;
|
||||
break;
|
||||
}
|
||||
O2.set(B2, N2);
|
||||
}
|
||||
if (j2) for (var I2 = function(e3) {
|
||||
var t3 = b2.find((function(t4) {
|
||||
var n3 = O2.get(t4);
|
||||
if (n3) return n3.slice(0, e3).every((function(e4) {
|
||||
return e4;
|
||||
}));
|
||||
}));
|
||||
if (t3) return E2 = t3, "break";
|
||||
}, _2 = h2 ? 3 : 1; _2 > 0; _2--) {
|
||||
if ("break" === I2(_2)) break;
|
||||
}
|
||||
t2.placement !== E2 && (t2.modifiersData[r2]._skip = true, t2.placement = E2, t2.reset = true);
|
||||
}
|
||||
}, requiresIfExists: ["offset"], data: { _skip: false } };
|
||||
function de(e2, t2, n2) {
|
||||
return i(e2, a(t2, n2));
|
||||
}
|
||||
var he = { name: "preventOverflow", enabled: true, phase: "main", fn: function(e2) {
|
||||
var t2 = e2.state, n2 = e2.options, r2 = e2.name, o2 = n2.mainAxis, s2 = void 0 === o2 || o2, f2 = n2.altAxis, c2 = void 0 !== f2 && f2, p2 = n2.boundary, u2 = n2.rootBoundary, l2 = n2.altBoundary, d2 = n2.padding, h2 = n2.tether, m2 = void 0 === h2 || h2, v2 = n2.tetherOffset, y2 = void 0 === v2 ? 0 : v2, b2 = J(t2, { boundary: p2, rootBoundary: u2, padding: d2, altBoundary: l2 }), x2 = F(t2.placement), w2 = U(t2.placement), O2 = !w2, j2 = z(x2), M2 = "x" === j2 ? "y" : "x", k2 = t2.modifiersData.popperOffsets, B2 = t2.rects.reference, H2 = t2.rects.popper, T2 = "function" == typeof y2 ? y2(Object.assign({}, t2.rects, { placement: t2.placement })) : y2, R2 = "number" == typeof T2 ? { mainAxis: T2, altAxis: T2 } : Object.assign({ mainAxis: 0, altAxis: 0 }, T2), S2 = t2.modifiersData.offset ? t2.modifiersData.offset[t2.placement] : null, V2 = { x: 0, y: 0 };
|
||||
if (k2) {
|
||||
if (s2) {
|
||||
var q2, C2 = "y" === j2 ? D : P, N2 = "y" === j2 ? A : L, I2 = "y" === j2 ? "height" : "width", _2 = k2[j2], X2 = _2 + b2[C2], Y2 = _2 - b2[N2], G2 = m2 ? -H2[I2] / 2 : 0, K2 = w2 === W ? B2[I2] : H2[I2], Q2 = w2 === W ? -H2[I2] : -B2[I2], Z2 = t2.elements.arrow, $2 = m2 && Z2 ? g(Z2) : { width: 0, height: 0 }, ee2 = t2.modifiersData["arrow#persistent"] ? t2.modifiersData["arrow#persistent"].padding : { top: 0, right: 0, bottom: 0, left: 0 }, te2 = ee2[C2], ne2 = ee2[N2], re2 = de(0, B2[I2], $2[I2]), oe2 = O2 ? B2[I2] / 2 - G2 - re2 - te2 - R2.mainAxis : K2 - re2 - te2 - R2.mainAxis, ie2 = O2 ? -B2[I2] / 2 + G2 + re2 + ne2 + R2.mainAxis : Q2 + re2 + ne2 + R2.mainAxis, ae2 = t2.elements.arrow && E(t2.elements.arrow), se2 = ae2 ? "y" === j2 ? ae2.clientTop || 0 : ae2.clientLeft || 0 : 0, fe2 = null != (q2 = null == S2 ? void 0 : S2[j2]) ? q2 : 0, ce2 = _2 + ie2 - fe2, pe2 = de(m2 ? a(X2, _2 + oe2 - fe2 - se2) : X2, _2, m2 ? i(Y2, ce2) : Y2);
|
||||
k2[j2] = pe2, V2[j2] = pe2 - _2;
|
||||
}
|
||||
if (c2) {
|
||||
var ue2, le2 = "x" === j2 ? D : P, he2 = "x" === j2 ? A : L, me2 = k2[M2], ve2 = "y" === M2 ? "height" : "width", ye2 = me2 + b2[le2], ge2 = me2 - b2[he2], be2 = -1 !== [D, P].indexOf(x2), xe2 = null != (ue2 = null == S2 ? void 0 : S2[M2]) ? ue2 : 0, we2 = be2 ? ye2 : me2 - B2[ve2] - H2[ve2] - xe2 + R2.altAxis, Oe = be2 ? me2 + B2[ve2] + H2[ve2] - xe2 - R2.altAxis : ge2, je = m2 && be2 ? (function(e3, t3, n3) {
|
||||
var r3 = de(e3, t3, n3);
|
||||
return r3 > n3 ? n3 : r3;
|
||||
})(we2, me2, Oe) : de(m2 ? we2 : ye2, me2, m2 ? Oe : ge2);
|
||||
k2[M2] = je, V2[M2] = je - me2;
|
||||
}
|
||||
t2.modifiersData[r2] = V2;
|
||||
}
|
||||
}, requiresIfExists: ["offset"] };
|
||||
var me = { name: "arrow", enabled: true, phase: "main", fn: function(e2) {
|
||||
var t2, n2 = e2.state, r2 = e2.name, o2 = e2.options, i2 = n2.elements.arrow, a2 = n2.modifiersData.popperOffsets, s2 = F(n2.placement), f2 = z(s2), c2 = [P, L].indexOf(s2) >= 0 ? "height" : "width";
|
||||
if (i2 && a2) {
|
||||
var p2 = (function(e3, t3) {
|
||||
return Y("number" != typeof (e3 = "function" == typeof e3 ? e3(Object.assign({}, t3.rects, { placement: t3.placement })) : e3) ? e3 : G(e3, k));
|
||||
})(o2.padding, n2), u2 = g(i2), l2 = "y" === f2 ? D : P, d2 = "y" === f2 ? A : L, h2 = n2.rects.reference[c2] + n2.rects.reference[f2] - a2[f2] - n2.rects.popper[c2], m2 = a2[f2] - n2.rects.reference[f2], v2 = E(i2), y2 = v2 ? "y" === f2 ? v2.clientHeight || 0 : v2.clientWidth || 0 : 0, b2 = h2 / 2 - m2 / 2, x2 = p2[l2], w2 = y2 - u2[c2] - p2[d2], O2 = y2 / 2 - u2[c2] / 2 + b2, j2 = de(x2, O2, w2), M2 = f2;
|
||||
n2.modifiersData[r2] = ((t2 = {})[M2] = j2, t2.centerOffset = j2 - O2, t2);
|
||||
}
|
||||
}, effect: function(e2) {
|
||||
var t2 = e2.state, n2 = e2.options.element, r2 = void 0 === n2 ? "[data-popper-arrow]" : n2;
|
||||
null != r2 && ("string" != typeof r2 || (r2 = t2.elements.popper.querySelector(r2))) && C(t2.elements.popper, r2) && (t2.elements.arrow = r2);
|
||||
}, requires: ["popperOffsets"], requiresIfExists: ["preventOverflow"] };
|
||||
function ve(e2, t2, n2) {
|
||||
return void 0 === n2 && (n2 = { x: 0, y: 0 }), { top: e2.top - t2.height - n2.y, right: e2.right - t2.width + n2.x, bottom: e2.bottom - t2.height + n2.y, left: e2.left - t2.width - n2.x };
|
||||
}
|
||||
function ye(e2) {
|
||||
return [D, L, A, P].some((function(t2) {
|
||||
return e2[t2] >= 0;
|
||||
}));
|
||||
}
|
||||
var ge = { name: "hide", enabled: true, phase: "main", requiresIfExists: ["preventOverflow"], fn: function(e2) {
|
||||
var t2 = e2.state, n2 = e2.name, r2 = t2.rects.reference, o2 = t2.rects.popper, i2 = t2.modifiersData.preventOverflow, a2 = J(t2, { elementContext: "reference" }), s2 = J(t2, { altBoundary: true }), f2 = ve(a2, r2), c2 = ve(s2, o2, i2), p2 = ye(f2), u2 = ye(c2);
|
||||
t2.modifiersData[n2] = { referenceClippingOffsets: f2, popperEscapeOffsets: c2, isReferenceHidden: p2, hasPopperEscaped: u2 }, t2.attributes.popper = Object.assign({}, t2.attributes.popper, { "data-popper-reference-hidden": p2, "data-popper-escaped": u2 });
|
||||
} }, be = Z({ defaultModifiers: [ee, te, oe, ie] }), xe = [ee, te, oe, ie, ae, le, he, me, ge], we = Z({ defaultModifiers: xe });
|
||||
e.applyStyles = ie, e.arrow = me, e.computeStyles = oe, e.createPopper = we, e.createPopperLite = be, e.defaultModifiers = xe, e.detectOverflow = J, e.eventListeners = ee, e.flip = le, e.hide = ge, e.offset = ae, e.popperGenerator = Z, e.popperOffsets = te, e.preventOverflow = he, Object.defineProperty(e, "__esModule", { value: true });
|
||||
}));
|
||||
}
|
||||
});
|
||||
|
||||
// frontend/js/components/customSelectV2.js
|
||||
var import_popper_esm_min = __toESM(require_popper_esm_min());
|
||||
var CustomSelectV2 = class _CustomSelectV2 {
|
||||
constructor(container) {
|
||||
this.container = container;
|
||||
this.trigger = this.container.querySelector(".custom-select-trigger");
|
||||
this.nativeSelect = this.container.querySelector("select");
|
||||
this.template = this.container.querySelector(".custom-select-panel-template");
|
||||
if (!this.trigger || !this.nativeSelect || !this.template) {
|
||||
console.warn("CustomSelectV2 cannot initialize: missing required elements.", this.container);
|
||||
return;
|
||||
}
|
||||
this.panel = null;
|
||||
this.popperInstance = null;
|
||||
this.isOpen = false;
|
||||
this.triggerText = this.trigger.querySelector("span");
|
||||
if (typeof _CustomSelectV2.openInstance === "undefined") {
|
||||
_CustomSelectV2.openInstance = null;
|
||||
_CustomSelectV2.initGlobalListener();
|
||||
}
|
||||
this.updateTriggerText();
|
||||
this.bindEvents();
|
||||
}
|
||||
static initGlobalListener() {
|
||||
document.addEventListener("click", (event) => {
|
||||
const instance = _CustomSelectV2.openInstance;
|
||||
if (instance && !instance.container.contains(event.target) && (!instance.panel || !instance.panel.contains(event.target))) {
|
||||
instance.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
createPanel() {
|
||||
const panelFragment = this.template.content.cloneNode(true);
|
||||
this.panel = panelFragment.querySelector(".custom-select-panel");
|
||||
document.body.appendChild(this.panel);
|
||||
this.panel.innerHTML = "";
|
||||
Array.from(this.nativeSelect.options).forEach((option) => {
|
||||
const item = document.createElement("a");
|
||||
item.href = "#";
|
||||
item.className = "custom-select-option block w-full text-left px-3 py-1.5 text-sm text-zinc-700 hover:bg-zinc-100 dark:text-zinc-200 dark:hover:bg-zinc-700";
|
||||
item.textContent = option.textContent;
|
||||
item.dataset.value = option.value;
|
||||
if (option.selected) {
|
||||
item.classList.add("is-selected");
|
||||
}
|
||||
this.panel.appendChild(item);
|
||||
});
|
||||
this.panel.addEventListener("click", (event) => {
|
||||
event.preventDefault();
|
||||
const optionEl = event.target.closest(".custom-select-option");
|
||||
if (optionEl) {
|
||||
this.selectOption(optionEl);
|
||||
}
|
||||
});
|
||||
}
|
||||
bindEvents() {
|
||||
this.trigger.addEventListener("click", (event) => {
|
||||
event.stopPropagation();
|
||||
if (_CustomSelectV2.openInstance && _CustomSelectV2.openInstance !== this) {
|
||||
_CustomSelectV2.openInstance.close();
|
||||
}
|
||||
this.toggle();
|
||||
});
|
||||
}
|
||||
selectOption(optionEl) {
|
||||
const selectedValue = optionEl.dataset.value;
|
||||
if (this.nativeSelect.value !== selectedValue) {
|
||||
this.nativeSelect.value = selectedValue;
|
||||
this.nativeSelect.dispatchEvent(new Event("change", { bubbles: true }));
|
||||
}
|
||||
this.updateTriggerText();
|
||||
this.close();
|
||||
}
|
||||
updateTriggerText() {
|
||||
const selectedOption = this.nativeSelect.options[this.nativeSelect.selectedIndex];
|
||||
if (selectedOption) {
|
||||
this.triggerText.textContent = selectedOption.textContent;
|
||||
}
|
||||
}
|
||||
toggle() {
|
||||
this.isOpen ? this.close() : this.open();
|
||||
}
|
||||
open() {
|
||||
if (this.isOpen) return;
|
||||
this.isOpen = true;
|
||||
if (!this.panel) {
|
||||
this.createPanel();
|
||||
}
|
||||
this.panel.style.display = "block";
|
||||
this.panel.offsetHeight;
|
||||
this.popperInstance = (0, import_popper_esm_min.createPopper)(this.trigger, this.panel, {
|
||||
placement: "top-start",
|
||||
modifiers: [
|
||||
{ name: "offset", options: { offset: [0, 8] } },
|
||||
{ name: "flip", options: { fallbackPlacements: ["bottom-start"] } }
|
||||
]
|
||||
});
|
||||
_CustomSelectV2.openInstance = this;
|
||||
}
|
||||
close() {
|
||||
if (!this.isOpen) return;
|
||||
this.isOpen = false;
|
||||
if (this.popperInstance) {
|
||||
this.popperInstance.destroy();
|
||||
this.popperInstance = null;
|
||||
}
|
||||
if (this.panel) {
|
||||
this.panel.remove();
|
||||
this.panel = null;
|
||||
}
|
||||
if (_CustomSelectV2.openInstance === this) {
|
||||
_CustomSelectV2.openInstance = null;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export {
|
||||
require_popper_esm_min,
|
||||
CustomSelectV2
|
||||
};
|
||||
@@ -111,281 +111,6 @@ var CustomSelect = class _CustomSelect {
|
||||
}
|
||||
};
|
||||
|
||||
// frontend/js/components/ui.js
|
||||
var ModalManager = class {
|
||||
/**
|
||||
* Shows a generic modal by its ID.
|
||||
* @param {string} modalId The ID of the modal element to show.
|
||||
*/
|
||||
show(modalId) {
|
||||
const modal = document.getElementById(modalId);
|
||||
if (modal) {
|
||||
modal.classList.remove("hidden");
|
||||
} else {
|
||||
console.error(`Modal with ID "${modalId}" not found.`);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Hides a generic modal by its ID.
|
||||
* @param {string} modalId The ID of the modal element to hide.
|
||||
*/
|
||||
hide(modalId) {
|
||||
const modal = document.getElementById(modalId);
|
||||
if (modal) {
|
||||
modal.classList.add("hidden");
|
||||
} else {
|
||||
console.error(`Modal with ID "${modalId}" not found.`);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Shows a confirmation dialog. This is a versatile method for 'Are you sure?' style prompts.
|
||||
* It dynamically sets the title, message, and confirm action for a generic confirmation modal.
|
||||
* @param {object} options - The options for the confirmation modal.
|
||||
* @param {string} options.modalId - The ID of the confirmation modal element (e.g., 'resetModal', 'deleteConfirmModal').
|
||||
* @param {string} options.title - The title to display in the modal header.
|
||||
* @param {string} options.message - The message to display in the modal body. Can contain HTML.
|
||||
* @param {function} options.onConfirm - The callback function to execute when the confirm button is clicked.
|
||||
* @param {boolean} [options.disableConfirm=false] - Whether the confirm button should be initially disabled.
|
||||
*/
|
||||
showConfirm({ modalId, title, message, onConfirm, disableConfirm = false }) {
|
||||
const modalElement = document.getElementById(modalId);
|
||||
if (!modalElement) {
|
||||
console.error(`Confirmation modal with ID "${modalId}" not found.`);
|
||||
return;
|
||||
}
|
||||
const titleElement = modalElement.querySelector('[id$="ModalTitle"]');
|
||||
const messageElement = modalElement.querySelector('[id$="ModalMessage"]');
|
||||
const confirmButton = modalElement.querySelector('[id^="confirm"]');
|
||||
if (!titleElement || !messageElement || !confirmButton) {
|
||||
console.error(`Modal "${modalId}" is missing required child elements (title, message, or confirm button).`);
|
||||
return;
|
||||
}
|
||||
titleElement.textContent = title;
|
||||
messageElement.innerHTML = message;
|
||||
confirmButton.disabled = disableConfirm;
|
||||
const newConfirmButton = confirmButton.cloneNode(true);
|
||||
confirmButton.parentNode.replaceChild(newConfirmButton, confirmButton);
|
||||
newConfirmButton.onclick = () => onConfirm();
|
||||
this.show(modalId);
|
||||
}
|
||||
/**
|
||||
* Shows a result modal to indicate the outcome of an operation (success or failure).
|
||||
* @param {boolean} success - If true, displays a success icon and title; otherwise, shows failure indicators.
|
||||
* @param {string|Node} message - The message to display. Can be a simple string or a complex DOM Node for rich content.
|
||||
* @param {boolean} [autoReload=false] - If true, the page will automatically reload when the modal is closed.
|
||||
*/
|
||||
showResult(success, message, autoReload = false) {
|
||||
const modalElement = document.getElementById("resultModal");
|
||||
if (!modalElement) {
|
||||
console.error("Result modal with ID 'resultModal' not found.");
|
||||
return;
|
||||
}
|
||||
const titleElement = document.getElementById("resultModalTitle");
|
||||
const messageElement = document.getElementById("resultModalMessage");
|
||||
const iconElement = document.getElementById("resultIcon");
|
||||
const confirmButton = document.getElementById("resultModalConfirmBtn");
|
||||
if (!titleElement || !messageElement || !iconElement || !confirmButton) {
|
||||
console.error("Result modal is missing required child elements.");
|
||||
return;
|
||||
}
|
||||
titleElement.textContent = success ? "\u64CD\u4F5C\u6210\u529F" : "\u64CD\u4F5C\u5931\u8D25";
|
||||
if (success) {
|
||||
iconElement.innerHTML = '<i class="fas fa-check-circle text-success-500"></i>';
|
||||
iconElement.className = "text-6xl mb-3 text-success-500";
|
||||
} else {
|
||||
iconElement.innerHTML = '<i class="fas fa-times-circle text-danger-500"></i>';
|
||||
iconElement.className = "text-6xl mb-3 text-danger-500";
|
||||
}
|
||||
messageElement.innerHTML = "";
|
||||
if (typeof message === "string") {
|
||||
const messageDiv = document.createElement("div");
|
||||
messageDiv.innerText = message;
|
||||
messageElement.appendChild(messageDiv);
|
||||
} else if (message instanceof Node) {
|
||||
messageElement.appendChild(message);
|
||||
} else {
|
||||
const messageDiv = document.createElement("div");
|
||||
messageDiv.innerText = String(message);
|
||||
messageElement.appendChild(messageDiv);
|
||||
}
|
||||
confirmButton.onclick = () => this.closeResult(autoReload);
|
||||
this.show("resultModal");
|
||||
}
|
||||
/**
|
||||
* Closes the result modal.
|
||||
* @param {boolean} [reload=false] - If true, reloads the page after closing the modal.
|
||||
*/
|
||||
closeResult(reload = false) {
|
||||
this.hide("resultModal");
|
||||
if (reload) {
|
||||
location.reload();
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Shows and initializes the progress modal for long-running operations.
|
||||
* @param {string} title - The title to display for the progress modal.
|
||||
*/
|
||||
showProgress(title) {
|
||||
const modal = document.getElementById("progressModal");
|
||||
if (!modal) {
|
||||
console.error("Progress modal with ID 'progressModal' not found.");
|
||||
return;
|
||||
}
|
||||
const titleElement = document.getElementById("progressModalTitle");
|
||||
const statusText = document.getElementById("progressStatusText");
|
||||
const progressBar = document.getElementById("progressBar");
|
||||
const progressPercentage = document.getElementById("progressPercentage");
|
||||
const progressLog = document.getElementById("progressLog");
|
||||
const closeButton = document.getElementById("progressModalCloseBtn");
|
||||
const closeIcon = document.getElementById("closeProgressModalBtn");
|
||||
if (!titleElement || !statusText || !progressBar || !progressPercentage || !progressLog || !closeButton || !closeIcon) {
|
||||
console.error("Progress modal is missing required child elements.");
|
||||
return;
|
||||
}
|
||||
titleElement.textContent = title;
|
||||
statusText.textContent = "\u51C6\u5907\u5F00\u59CB...";
|
||||
progressBar.style.width = "0%";
|
||||
progressPercentage.textContent = "0%";
|
||||
progressLog.innerHTML = "";
|
||||
closeButton.disabled = true;
|
||||
closeIcon.disabled = true;
|
||||
this.show("progressModal");
|
||||
}
|
||||
/**
|
||||
* Updates the progress bar and status text within the progress modal.
|
||||
* @param {number} processed - The number of items that have been processed.
|
||||
* @param {number} total - The total number of items to process.
|
||||
* @param {string} status - The current status message to display.
|
||||
*/
|
||||
updateProgress(processed, total, status) {
|
||||
const modal = document.getElementById("progressModal");
|
||||
if (!modal || modal.classList.contains("hidden")) return;
|
||||
const progressBar = document.getElementById("progressBar");
|
||||
const progressPercentage = document.getElementById("progressPercentage");
|
||||
const statusText = document.getElementById("progressStatusText");
|
||||
const closeButton = document.getElementById("progressModalCloseBtn");
|
||||
const closeIcon = document.getElementById("closeProgressModalBtn");
|
||||
const percentage = total > 0 ? Math.round(processed / total * 100) : 0;
|
||||
progressBar.style.width = `${percentage}%`;
|
||||
progressPercentage.textContent = `${percentage}%`;
|
||||
statusText.textContent = status;
|
||||
if (processed === total) {
|
||||
closeButton.disabled = false;
|
||||
closeIcon.disabled = false;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Adds a log entry to the progress modal's log area.
|
||||
* @param {string} message - The log message to append.
|
||||
* @param {boolean} [isError=false] - If true, styles the log entry as an error.
|
||||
*/
|
||||
addProgressLog(message, isError = false) {
|
||||
const progressLog = document.getElementById("progressLog");
|
||||
if (!progressLog) return;
|
||||
const logEntry = document.createElement("div");
|
||||
logEntry.textContent = message;
|
||||
logEntry.className = isError ? "text-danger-600" : "text-gray-700";
|
||||
progressLog.appendChild(logEntry);
|
||||
progressLog.scrollTop = progressLog.scrollHeight;
|
||||
}
|
||||
/**
|
||||
* Closes the progress modal.
|
||||
* @param {boolean} [reload=false] - If true, reloads the page after closing.
|
||||
*/
|
||||
closeProgress(reload = false) {
|
||||
this.hide("progressModal");
|
||||
if (reload) {
|
||||
location.reload();
|
||||
}
|
||||
}
|
||||
};
|
||||
var UIPatterns = class {
|
||||
/**
|
||||
* Animates numerical values in elements from 0 to their target number.
|
||||
* The target number is read from the element's text content.
|
||||
* @param {string} selector - The CSS selector for the elements to animate (e.g., '.stat-value').
|
||||
* @param {number} [duration=1500] - The duration of the animation in milliseconds.
|
||||
*/
|
||||
animateCounters(selector = ".stat-value", duration = 1500) {
|
||||
const statValues = document.querySelectorAll(selector);
|
||||
statValues.forEach((valueElement) => {
|
||||
const finalValue = parseInt(valueElement.textContent, 10);
|
||||
if (isNaN(finalValue)) return;
|
||||
if (!valueElement.dataset.originalValue) {
|
||||
valueElement.dataset.originalValue = valueElement.textContent;
|
||||
}
|
||||
let startValue = 0;
|
||||
const startTime = performance.now();
|
||||
const updateCounter = (currentTime) => {
|
||||
const elapsedTime = currentTime - startTime;
|
||||
if (elapsedTime < duration) {
|
||||
const progress = elapsedTime / duration;
|
||||
const easeOutValue = 1 - Math.pow(1 - progress, 3);
|
||||
const currentValue = Math.floor(easeOutValue * finalValue);
|
||||
valueElement.textContent = currentValue;
|
||||
requestAnimationFrame(updateCounter);
|
||||
} else {
|
||||
valueElement.textContent = valueElement.dataset.originalValue;
|
||||
}
|
||||
};
|
||||
requestAnimationFrame(updateCounter);
|
||||
});
|
||||
}
|
||||
/**
|
||||
* Toggles the visibility of a content section with a smooth height animation.
|
||||
* It expects a specific HTML structure where the header and content are within a common parent (e.g., a card).
|
||||
* The content element should have a `collapsed` class when hidden.
|
||||
* @param {HTMLElement} header - The header element that was clicked to trigger the toggle.
|
||||
*/
|
||||
toggleSection(header) {
|
||||
const card = header.closest(".stats-card");
|
||||
if (!card) return;
|
||||
const content = card.querySelector(".key-content");
|
||||
const toggleIcon = header.querySelector(".toggle-icon");
|
||||
if (!content || !toggleIcon) {
|
||||
console.error("Toggle section failed: Content or icon element not found.", { header });
|
||||
return;
|
||||
}
|
||||
const isCollapsed = content.classList.contains("collapsed");
|
||||
toggleIcon.classList.toggle("collapsed", !isCollapsed);
|
||||
if (isCollapsed) {
|
||||
content.classList.remove("collapsed");
|
||||
content.style.maxHeight = null;
|
||||
content.style.opacity = null;
|
||||
content.style.paddingTop = null;
|
||||
content.style.paddingBottom = null;
|
||||
content.style.overflow = "hidden";
|
||||
requestAnimationFrame(() => {
|
||||
const targetHeight = content.scrollHeight;
|
||||
content.style.maxHeight = `${targetHeight}px`;
|
||||
content.style.opacity = "1";
|
||||
content.style.paddingTop = "1rem";
|
||||
content.style.paddingBottom = "1rem";
|
||||
content.addEventListener("transitionend", function onExpansionEnd() {
|
||||
content.removeEventListener("transitionend", onExpansionEnd);
|
||||
if (!content.classList.contains("collapsed")) {
|
||||
content.style.maxHeight = "";
|
||||
content.style.overflow = "visible";
|
||||
}
|
||||
}, { once: true });
|
||||
});
|
||||
} else {
|
||||
const currentHeight = content.scrollHeight;
|
||||
content.style.maxHeight = `${currentHeight}px`;
|
||||
content.style.overflow = "hidden";
|
||||
requestAnimationFrame(() => {
|
||||
content.style.maxHeight = "0px";
|
||||
content.style.opacity = "0";
|
||||
content.style.paddingTop = "0";
|
||||
content.style.paddingBottom = "0";
|
||||
content.classList.add("collapsed");
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
var modalManager = new ModalManager();
|
||||
var uiPatterns = new UIPatterns();
|
||||
|
||||
// frontend/js/components/taskCenter.js
|
||||
var TaskCenterManager = class {
|
||||
constructor() {
|
||||
@@ -810,8 +535,6 @@ var toastManager = new ToastManager();
|
||||
|
||||
export {
|
||||
CustomSelect,
|
||||
modalManager,
|
||||
uiPatterns,
|
||||
taskCenterManager,
|
||||
toastManager
|
||||
};
|
||||
393
web/static/js/chunk-VOGCL6QZ.js
Normal file
393
web/static/js/chunk-VOGCL6QZ.js
Normal file
@@ -0,0 +1,393 @@
|
||||
// frontend/js/components/ui.js
|
||||
var ModalManager = class {
|
||||
/**
|
||||
* Shows a generic modal by its ID.
|
||||
* @param {string} modalId The ID of the modal element to show.
|
||||
*/
|
||||
show(modalId) {
|
||||
const modal = document.getElementById(modalId);
|
||||
if (modal) {
|
||||
modal.classList.remove("hidden");
|
||||
} else {
|
||||
console.error(`Modal with ID "${modalId}" not found.`);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Hides a generic modal by its ID.
|
||||
* @param {string} modalId The ID of the modal element to hide.
|
||||
*/
|
||||
hide(modalId) {
|
||||
const modal = document.getElementById(modalId);
|
||||
if (modal) {
|
||||
modal.classList.add("hidden");
|
||||
} else {
|
||||
console.error(`Modal with ID "${modalId}" not found.`);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Shows a confirmation dialog. This is a versatile method for 'Are you sure?' style prompts.
|
||||
* It dynamically sets the title, message, and confirm action for a generic confirmation modal.
|
||||
* @param {object} options - The options for the confirmation modal.
|
||||
* @param {string} options.modalId - The ID of the confirmation modal element (e.g., 'resetModal', 'deleteConfirmModal').
|
||||
* @param {string} options.title - The title to display in the modal header.
|
||||
* @param {string} options.message - The message to display in the modal body. Can contain HTML.
|
||||
* @param {function} options.onConfirm - The callback function to execute when the confirm button is clicked.
|
||||
* @param {boolean} [options.disableConfirm=false] - Whether the confirm button should be initially disabled.
|
||||
*/
|
||||
showConfirm({ modalId, title, message, onConfirm, disableConfirm = false }) {
|
||||
const modalElement = document.getElementById(modalId);
|
||||
if (!modalElement) {
|
||||
console.error(`Confirmation modal with ID "${modalId}" not found.`);
|
||||
return;
|
||||
}
|
||||
const titleElement = modalElement.querySelector('[id$="ModalTitle"]');
|
||||
const messageElement = modalElement.querySelector('[id$="ModalMessage"]');
|
||||
const confirmButton = modalElement.querySelector('[id^="confirm"]');
|
||||
if (!titleElement || !messageElement || !confirmButton) {
|
||||
console.error(`Modal "${modalId}" is missing required child elements (title, message, or confirm button).`);
|
||||
return;
|
||||
}
|
||||
titleElement.textContent = title;
|
||||
messageElement.innerHTML = message;
|
||||
confirmButton.disabled = disableConfirm;
|
||||
const newConfirmButton = confirmButton.cloneNode(true);
|
||||
confirmButton.parentNode.replaceChild(newConfirmButton, confirmButton);
|
||||
newConfirmButton.onclick = () => onConfirm();
|
||||
this.show(modalId);
|
||||
}
|
||||
/**
|
||||
* Shows a result modal to indicate the outcome of an operation (success or failure).
|
||||
* @param {boolean} success - If true, displays a success icon and title; otherwise, shows failure indicators.
|
||||
* @param {string|Node} message - The message to display. Can be a simple string or a complex DOM Node for rich content.
|
||||
* @param {boolean} [autoReload=false] - If true, the page will automatically reload when the modal is closed.
|
||||
*/
|
||||
showResult(success, message, autoReload = false) {
|
||||
const modalElement = document.getElementById("resultModal");
|
||||
if (!modalElement) {
|
||||
console.error("Result modal with ID 'resultModal' not found.");
|
||||
return;
|
||||
}
|
||||
const titleElement = document.getElementById("resultModalTitle");
|
||||
const messageElement = document.getElementById("resultModalMessage");
|
||||
const iconElement = document.getElementById("resultIcon");
|
||||
const confirmButton = document.getElementById("resultModalConfirmBtn");
|
||||
if (!titleElement || !messageElement || !iconElement || !confirmButton) {
|
||||
console.error("Result modal is missing required child elements.");
|
||||
return;
|
||||
}
|
||||
titleElement.textContent = success ? "\u64CD\u4F5C\u6210\u529F" : "\u64CD\u4F5C\u5931\u8D25";
|
||||
if (success) {
|
||||
iconElement.innerHTML = '<i class="fas fa-check-circle text-success-500"></i>';
|
||||
iconElement.className = "text-6xl mb-3 text-success-500";
|
||||
} else {
|
||||
iconElement.innerHTML = '<i class="fas fa-times-circle text-danger-500"></i>';
|
||||
iconElement.className = "text-6xl mb-3 text-danger-500";
|
||||
}
|
||||
messageElement.innerHTML = "";
|
||||
if (typeof message === "string") {
|
||||
const messageDiv = document.createElement("div");
|
||||
messageDiv.innerText = message;
|
||||
messageElement.appendChild(messageDiv);
|
||||
} else if (message instanceof Node) {
|
||||
messageElement.appendChild(message);
|
||||
} else {
|
||||
const messageDiv = document.createElement("div");
|
||||
messageDiv.innerText = String(message);
|
||||
messageElement.appendChild(messageDiv);
|
||||
}
|
||||
confirmButton.onclick = () => this.closeResult(autoReload);
|
||||
this.show("resultModal");
|
||||
}
|
||||
/**
|
||||
* Closes the result modal.
|
||||
* @param {boolean} [reload=false] - If true, reloads the page after closing the modal.
|
||||
*/
|
||||
closeResult(reload = false) {
|
||||
this.hide("resultModal");
|
||||
if (reload) {
|
||||
location.reload();
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Shows and initializes the progress modal for long-running operations.
|
||||
* @param {string} title - The title to display for the progress modal.
|
||||
*/
|
||||
showProgress(title) {
|
||||
const modal = document.getElementById("progressModal");
|
||||
if (!modal) {
|
||||
console.error("Progress modal with ID 'progressModal' not found.");
|
||||
return;
|
||||
}
|
||||
const titleElement = document.getElementById("progressModalTitle");
|
||||
const statusText = document.getElementById("progressStatusText");
|
||||
const progressBar = document.getElementById("progressBar");
|
||||
const progressPercentage = document.getElementById("progressPercentage");
|
||||
const progressLog = document.getElementById("progressLog");
|
||||
const closeButton = document.getElementById("progressModalCloseBtn");
|
||||
const closeIcon = document.getElementById("closeProgressModalBtn");
|
||||
if (!titleElement || !statusText || !progressBar || !progressPercentage || !progressLog || !closeButton || !closeIcon) {
|
||||
console.error("Progress modal is missing required child elements.");
|
||||
return;
|
||||
}
|
||||
titleElement.textContent = title;
|
||||
statusText.textContent = "\u51C6\u5907\u5F00\u59CB...";
|
||||
progressBar.style.width = "0%";
|
||||
progressPercentage.textContent = "0%";
|
||||
progressLog.innerHTML = "";
|
||||
closeButton.disabled = true;
|
||||
closeIcon.disabled = true;
|
||||
this.show("progressModal");
|
||||
}
|
||||
/**
|
||||
* Updates the progress bar and status text within the progress modal.
|
||||
* @param {number} processed - The number of items that have been processed.
|
||||
* @param {number} total - The total number of items to process.
|
||||
* @param {string} status - The current status message to display.
|
||||
*/
|
||||
updateProgress(processed, total, status) {
|
||||
const modal = document.getElementById("progressModal");
|
||||
if (!modal || modal.classList.contains("hidden")) return;
|
||||
const progressBar = document.getElementById("progressBar");
|
||||
const progressPercentage = document.getElementById("progressPercentage");
|
||||
const statusText = document.getElementById("progressStatusText");
|
||||
const closeButton = document.getElementById("progressModalCloseBtn");
|
||||
const closeIcon = document.getElementById("closeProgressModalBtn");
|
||||
const percentage = total > 0 ? Math.round(processed / total * 100) : 0;
|
||||
progressBar.style.width = `${percentage}%`;
|
||||
progressPercentage.textContent = `${percentage}%`;
|
||||
statusText.textContent = status;
|
||||
if (processed === total) {
|
||||
closeButton.disabled = false;
|
||||
closeIcon.disabled = false;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Adds a log entry to the progress modal's log area.
|
||||
* @param {string} message - The log message to append.
|
||||
* @param {boolean} [isError=false] - If true, styles the log entry as an error.
|
||||
*/
|
||||
addProgressLog(message, isError = false) {
|
||||
const progressLog = document.getElementById("progressLog");
|
||||
if (!progressLog) return;
|
||||
const logEntry = document.createElement("div");
|
||||
logEntry.textContent = message;
|
||||
logEntry.className = isError ? "text-danger-600" : "text-gray-700";
|
||||
progressLog.appendChild(logEntry);
|
||||
progressLog.scrollTop = progressLog.scrollHeight;
|
||||
}
|
||||
/**
|
||||
* Closes the progress modal.
|
||||
* @param {boolean} [reload=false] - If true, reloads the page after closing.
|
||||
*/
|
||||
closeProgress(reload = false) {
|
||||
this.hide("progressModal");
|
||||
if (reload) {
|
||||
location.reload();
|
||||
}
|
||||
}
|
||||
};
|
||||
var UIPatterns = class {
|
||||
/**
|
||||
* Animates numerical values in elements from 0 to their target number.
|
||||
* The target number is read from the element's text content.
|
||||
* @param {string} selector - The CSS selector for the elements to animate (e.g., '.stat-value').
|
||||
* @param {number} [duration=1500] - The duration of the animation in milliseconds.
|
||||
*/
|
||||
animateCounters(selector = ".stat-value", duration = 1500) {
|
||||
const statValues = document.querySelectorAll(selector);
|
||||
statValues.forEach((valueElement) => {
|
||||
const finalValue = parseInt(valueElement.textContent, 10);
|
||||
if (isNaN(finalValue)) return;
|
||||
if (!valueElement.dataset.originalValue) {
|
||||
valueElement.dataset.originalValue = valueElement.textContent;
|
||||
}
|
||||
let startValue = 0;
|
||||
const startTime = performance.now();
|
||||
const updateCounter = (currentTime) => {
|
||||
const elapsedTime = currentTime - startTime;
|
||||
if (elapsedTime < duration) {
|
||||
const progress = elapsedTime / duration;
|
||||
const easeOutValue = 1 - Math.pow(1 - progress, 3);
|
||||
const currentValue = Math.floor(easeOutValue * finalValue);
|
||||
valueElement.textContent = currentValue;
|
||||
requestAnimationFrame(updateCounter);
|
||||
} else {
|
||||
valueElement.textContent = valueElement.dataset.originalValue;
|
||||
}
|
||||
};
|
||||
requestAnimationFrame(updateCounter);
|
||||
});
|
||||
}
|
||||
/**
|
||||
* Toggles the visibility of a content section with a smooth height animation.
|
||||
* It expects a specific HTML structure where the header and content are within a common parent (e.g., a card).
|
||||
* The content element should have a `collapsed` class when hidden.
|
||||
* @param {HTMLElement} header - The header element that was clicked to trigger the toggle.
|
||||
*/
|
||||
toggleSection(header) {
|
||||
const card = header.closest(".stats-card");
|
||||
if (!card) return;
|
||||
const content = card.querySelector(".key-content");
|
||||
const toggleIcon = header.querySelector(".toggle-icon");
|
||||
if (!content || !toggleIcon) {
|
||||
console.error("Toggle section failed: Content or icon element not found.", { header });
|
||||
return;
|
||||
}
|
||||
const isCollapsed = content.classList.contains("collapsed");
|
||||
toggleIcon.classList.toggle("collapsed", !isCollapsed);
|
||||
if (isCollapsed) {
|
||||
content.classList.remove("collapsed");
|
||||
content.style.maxHeight = null;
|
||||
content.style.opacity = null;
|
||||
content.style.paddingTop = null;
|
||||
content.style.paddingBottom = null;
|
||||
content.style.overflow = "hidden";
|
||||
requestAnimationFrame(() => {
|
||||
const targetHeight = content.scrollHeight;
|
||||
content.style.maxHeight = `${targetHeight}px`;
|
||||
content.style.opacity = "1";
|
||||
content.style.paddingTop = "1rem";
|
||||
content.style.paddingBottom = "1rem";
|
||||
content.addEventListener("transitionend", function onExpansionEnd() {
|
||||
content.removeEventListener("transitionend", onExpansionEnd);
|
||||
if (!content.classList.contains("collapsed")) {
|
||||
content.style.maxHeight = "";
|
||||
content.style.overflow = "visible";
|
||||
}
|
||||
}, { once: true });
|
||||
});
|
||||
} else {
|
||||
const currentHeight = content.scrollHeight;
|
||||
content.style.maxHeight = `${currentHeight}px`;
|
||||
content.style.overflow = "hidden";
|
||||
requestAnimationFrame(() => {
|
||||
content.style.maxHeight = "0px";
|
||||
content.style.opacity = "0";
|
||||
content.style.paddingTop = "0";
|
||||
content.style.paddingBottom = "0";
|
||||
content.classList.add("collapsed");
|
||||
});
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Sets a button to a loading state by disabling it and showing a spinner.
|
||||
* It stores the button's original content to be restored later.
|
||||
* @param {HTMLButtonElement} button - The button element to modify.
|
||||
*/
|
||||
setButtonLoading(button) {
|
||||
if (!button) return;
|
||||
if (!button.dataset.originalContent) {
|
||||
button.dataset.originalContent = button.innerHTML;
|
||||
}
|
||||
button.disabled = true;
|
||||
button.innerHTML = '<i class="fas fa-spinner fa-spin"></i>';
|
||||
}
|
||||
/**
|
||||
* Restores a button from its loading state to its original content and enables it.
|
||||
* @param {HTMLButtonElement} button - The button element to restore.
|
||||
*/
|
||||
clearButtonLoading(button) {
|
||||
if (!button) return;
|
||||
if (button.dataset.originalContent) {
|
||||
button.innerHTML = button.dataset.originalContent;
|
||||
delete button.dataset.originalContent;
|
||||
}
|
||||
button.disabled = false;
|
||||
}
|
||||
/**
|
||||
* Returns the HTML for a streaming text cursor animation.
|
||||
* This is used as a placeholder in the chat UI while waiting for an assistant's response.
|
||||
* @returns {string} The HTML string for the loader.
|
||||
*/
|
||||
renderStreamingLoader() {
|
||||
return '<span class="streaming-cursor animate-pulse">\u258B</span>';
|
||||
}
|
||||
};
|
||||
var modalManager = new ModalManager();
|
||||
var uiPatterns = new UIPatterns();
|
||||
|
||||
// frontend/js/services/api.js
|
||||
var APIClientError = class extends Error {
|
||||
constructor(message, status, code, rawMessageFromServer) {
|
||||
super(message);
|
||||
this.name = "APIClientError";
|
||||
this.status = status;
|
||||
this.code = code;
|
||||
this.rawMessageFromServer = rawMessageFromServer;
|
||||
}
|
||||
};
|
||||
var apiPromiseCache = /* @__PURE__ */ new Map();
|
||||
async function apiFetch(url, options = {}) {
|
||||
const isGetRequest = !options.method || options.method.toUpperCase() === "GET";
|
||||
const cacheKey = isGetRequest && !options.noCache ? url : null;
|
||||
if (cacheKey && apiPromiseCache.has(cacheKey)) {
|
||||
return apiPromiseCache.get(cacheKey);
|
||||
}
|
||||
const token = localStorage.getItem("bearerToken");
|
||||
const headers = {
|
||||
"Content-Type": "application/json",
|
||||
...options.headers
|
||||
};
|
||||
if (token) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
const requestPromise = (async () => {
|
||||
try {
|
||||
const response = await fetch(url, { ...options, headers });
|
||||
if (response.status === 401) {
|
||||
if (cacheKey) apiPromiseCache.delete(cacheKey);
|
||||
localStorage.removeItem("bearerToken");
|
||||
if (window.location.pathname !== "/login") {
|
||||
window.location.href = "/login?error=\u4F1A\u8BDD\u5DF2\u8FC7\u671F\uFF0C\u8BF7\u91CD\u65B0\u767B\u5F55\u3002";
|
||||
}
|
||||
throw new APIClientError("Unauthorized", 401, "UNAUTHORIZED", "Session expired or token is invalid.");
|
||||
}
|
||||
if (!response.ok) {
|
||||
let errorData = null;
|
||||
let rawMessage = "";
|
||||
try {
|
||||
rawMessage = await response.text();
|
||||
if (rawMessage) {
|
||||
errorData = JSON.parse(rawMessage);
|
||||
}
|
||||
} catch (e) {
|
||||
errorData = { error: { code: "UNKNOWN_FORMAT", message: rawMessage || response.statusText } };
|
||||
}
|
||||
const code = errorData?.error?.code || "UNKNOWN_ERROR";
|
||||
const messageFromServer = errorData?.error?.message || rawMessage || "No message provided by server.";
|
||||
const error = new APIClientError(
|
||||
`API request failed: ${response.status}`,
|
||||
response.status,
|
||||
code,
|
||||
messageFromServer
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (cacheKey) apiPromiseCache.delete(cacheKey);
|
||||
throw error;
|
||||
}
|
||||
})();
|
||||
if (cacheKey) {
|
||||
apiPromiseCache.set(cacheKey, requestPromise);
|
||||
}
|
||||
return requestPromise;
|
||||
}
|
||||
async function apiFetchJson(url, options = {}) {
|
||||
try {
|
||||
const response = await apiFetch(url, options);
|
||||
const clonedResponse = response.clone();
|
||||
const jsonData = await clonedResponse.json();
|
||||
return jsonData;
|
||||
} catch (error) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export {
|
||||
modalManager,
|
||||
uiPatterns,
|
||||
apiFetch,
|
||||
apiFetchJson
|
||||
};
|
||||
@@ -1,3 +1,5 @@
|
||||
import "./chunk-JSBRDJBE.js";
|
||||
|
||||
// frontend/js/pages/dashboard.js
|
||||
function init() {
|
||||
console.log("[Modern Frontend] Dashboard module loaded. Future logic will execute here.");
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user