203 lines
4.9 KiB
Go
203 lines
4.9 KiB
Go
// config/config.go
|
|
package config
|
|
|
|
import (
|
|
"log"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Config holds every runtime setting of the service. LoadFromEnv fills it
// from environment variables and validates it before returning.
type Config struct {
	// ---- Authentication ----

	// Username is the login name (AUTH_USERNAME).
	Username string
	// Password is the login password (AUTH_PASSWORD); required.
	Password string
	// SessionSecret is the session key material (SESSION_SECRET); presumably
	// used to sign/encrypt session data — confirm against the session code.
	SessionSecret string
	// SessionTimeout is the session timeout (SESSION_TIMEOUT).
	SessionTimeout time.Duration

	// ---- Security limits ----

	// RateLimit is the request budget (RATE_LIMIT_REQUESTS), presumably
	// applied per RateLimitWindow.
	RateLimit int
	// RateLimitWindow is the rate-limit window length (RATE_LIMIT_WINDOW).
	RateLimitWindow time.Duration
	// MaxResponseSize is the response size limit in bytes (MAX_RESPONSE_SIZE).
	MaxResponseSize int64

	// ---- Proxy behaviour ----

	// AllowedSchemes lists the accepted URL schemes (ALLOWED_SCHEMES).
	AllowedSchemes []string
	// UserAgent is the User-Agent string (USER_AGENT), presumably sent on
	// outbound requests.
	UserAgent string

	// ---- Blocklists ----

	// BlockedDomains lists host names that must not be reached (BLOCKED_DOMAINS).
	BlockedDomains []string
	// BlockedCIDRs lists address ranges that must not be reached (BLOCKED_CIDRS).
	BlockedCIDRs []string

	// ---- Cache ----

	// CacheEnabled toggles the cache (CACHE_ENABLED).
	CacheEnabled bool
	// CacheMaxSize is the cache capacity (CACHE_MAX_SIZE), presumably in bytes.
	CacheMaxSize int64
	// CacheTTL is the cache entry lifetime (CACHE_TTL).
	CacheTTL time.Duration

	// ---- Server ----

	// Port is the server port (PORT), kept as a string.
	Port string
}
|
|
|
|
func LoadFromEnv() *Config {
|
|
cfg := &Config{
|
|
Username: getEnv("AUTH_USERNAME", "admin"),
|
|
Password: getEnvRequired("AUTH_PASSWORD"),
|
|
SessionSecret: getEnvOrGenerate("SESSION_SECRET"),
|
|
SessionTimeout: parseDuration(getEnv("SESSION_TIMEOUT", "30m")),
|
|
RateLimit: parseInt(getEnv("RATE_LIMIT_REQUESTS", "100")),
|
|
RateLimitWindow: parseDuration(getEnv("RATE_LIMIT_WINDOW", "1m")),
|
|
MaxResponseSize: parseInt64(getEnv("MAX_RESPONSE_SIZE", "52428800")),
|
|
AllowedSchemes: parseList(getEnv("ALLOWED_SCHEMES", "http,https")),
|
|
UserAgent: getEnv("USER_AGENT", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"),
|
|
BlockedDomains: parseList(getEnv("BLOCKED_DOMAINS", getDefaultBlockedDomains())),
|
|
BlockedCIDRs: parseList(getEnv("BLOCKED_CIDRS", getDefaultBlockedCIDRs())),
|
|
CacheEnabled: parseBool(getEnv("CACHE_ENABLED", "true")),
|
|
CacheMaxSize: parseInt64(getEnv("CACHE_MAX_SIZE", "104857600")),
|
|
CacheTTL: parseDuration(getEnv("CACHE_TTL", "1h")),
|
|
Port: getEnv("PORT", "8080"),
|
|
}
|
|
|
|
// 验证配置
|
|
cfg.validate()
|
|
|
|
return cfg
|
|
}
|
|
|
|
func (c *Config) validate() {
|
|
if c.Password == "" {
|
|
log.Fatal("AUTH_PASSWORD is required")
|
|
}
|
|
|
|
if c.Password == "your_secure_password_here" ||
|
|
c.Password == "change_this" {
|
|
log.Fatal("Please change the default AUTH_PASSWORD")
|
|
}
|
|
|
|
if len(c.Password) < 8 {
|
|
log.Fatal("AUTH_PASSWORD must be at least 8 characters")
|
|
}
|
|
|
|
if c.SessionSecret == "" {
|
|
log.Fatal("SESSION_SECRET is required")
|
|
}
|
|
|
|
if len(c.SessionSecret) < 32 {
|
|
log.Fatal("SESSION_SECRET must be at least 32 characters")
|
|
}
|
|
|
|
if c.SessionTimeout < time.Minute {
|
|
log.Fatal("SESSION_TIMEOUT must be at least 1 minute")
|
|
}
|
|
|
|
if c.RateLimit < 1 {
|
|
log.Fatal("RATE_LIMIT_REQUESTS must be at least 1")
|
|
}
|
|
|
|
if c.MaxResponseSize < 1024 {
|
|
log.Fatal("MAX_RESPONSE_SIZE must be at least 1KB")
|
|
}
|
|
}
|
|
|
|
// getEnv returns the value of the environment variable key, or defaultValue
// when the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
|
|
|
|
// getEnvRequired returns the value of the environment variable key. An unset
// or empty variable is fatal: "<key> is required" is logged and the process
// exits via log.Fatalf.
func getEnvRequired(key string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	log.Fatalf("%s is required", key)
	return "" // not reached: log.Fatalf exits the process
}
|
|
|
|
func getEnvOrGenerate(key string) string {
|
|
if value := os.Getenv(key); value != "" {
|
|
return value
|
|
}
|
|
|
|
// 生成随机密钥
|
|
secret := GenerateSessionSecret()
|
|
log.Printf("WARNING: %s not set, generated random value", key)
|
|
log.Printf("Add this to your .env file: %s=%s", key, secret)
|
|
return secret
|
|
}
|
|
|
|
// parseDuration parses s with time.ParseDuration (e.g. "30m", "1h30m").
// An unparsable value is fatal: it is logged and the process exits.
func parseDuration(s string) time.Duration {
	d, parseErr := time.ParseDuration(s)
	if parseErr == nil {
		return d
	}
	log.Fatalf("Invalid duration: %s", s)
	return 0 // not reached: log.Fatalf exits the process
}
|
|
|
|
// parseInt converts the decimal string s to an int via strconv.Atoi.
// An unparsable value is fatal: it is logged and the process exits.
func parseInt(s string) int {
	if n, err := strconv.Atoi(s); err == nil {
		return n
	}
	log.Fatalf("Invalid integer: %s", s)
	return 0 // not reached: log.Fatalf exits the process
}
|
|
|
|
// parseInt64 converts the base-10 string s to an int64 via strconv.ParseInt.
// An unparsable or out-of-range value is fatal: it is logged and the process
// exits.
func parseInt64(s string) int64 {
	const base, bitSize = 10, 64
	n, err := strconv.ParseInt(s, base, bitSize)
	if err == nil {
		return n
	}
	log.Fatalf("Invalid integer: %s", s)
	return 0 // not reached: log.Fatalf exits the process
}
|
|
|
|
// parseBool interprets s with strconv.ParseBool (accepting 1, t, T, TRUE,
// true, True, 0, f, F, FALSE, false, False). Any other value is fatal: it is
// logged and the process exits.
func parseBool(s string) bool {
	if b, err := strconv.ParseBool(s); err == nil {
		return b
	}
	log.Fatalf("Invalid boolean: %s", s)
	return false // not reached: log.Fatalf exits the process
}
|
|
|
|
// parseList splits a comma-separated string into its elements, trimming
// surrounding whitespace and dropping empty entries. The result is never nil:
// an empty or all-blank input yields an empty, non-nil slice.
func parseList(s string) []string {
	items := []string{}
	isComma := func(r rune) bool { return r == ',' }
	for _, field := range strings.FieldsFunc(s, isComma) {
		field = strings.TrimSpace(field)
		if field == "" {
			continue // entries such as " " between commas
		}
		items = append(items, field)
	}
	return items
}
|
|
|
|
// getDefaultBlockedDomains returns the built-in BLOCKED_DOMAINS default as a
// single comma-separated string (the format parseList consumes). It lists
// loopback/unspecified hosts, local/internal names, and well-known cloud
// instance-metadata endpoints.
func getDefaultBlockedDomains() string {
	defaults := [...]string{
		// Loopback and unspecified hosts.
		"localhost",
		"127.0.0.1",
		"0.0.0.0",
		// Local and internal names; the "*." prefix presumably denotes a
		// suffix wildcard for whatever matches these entries — confirm.
		"*.local",
		"internal",
		// Cloud instance-metadata endpoints.
		"metadata.google.internal",
		"169.254.169.254",
		"metadata.azure.com",
		"metadata.packet.net",
	}
	return strings.Join(defaults[:], ",")
}
|
|
|
|
// getDefaultBlockedCIDRs returns the built-in BLOCKED_CIDRS default as a
// single comma-separated string (the format parseList consumes). It covers
// IPv4 and IPv6 ranges that point back into the host or its private network:
// unspecified, loopback, RFC 1918 private, carrier-grade NAT, unique-local,
// and link-local addresses.
//
// The whole 127.0.0.0/8 loopback block and 0.0.0.0/8 are included. Without
// them, addresses such as 127.0.0.2, or host names that resolve to loopback,
// escape the blocklist even though "localhost"/"127.0.0.1" are blocked by
// name. IPv4-mapped IPv6 (::ffff:0:0/96) is deliberately NOT listed: with
// net.IPNet.Contains that prefix matches every plain IPv4 address.
func getDefaultBlockedCIDRs() string {
	return strings.Join([]string{
		"0.0.0.0/8",      // "this network"; on Linux, 0.0.0.0 reaches the local host
		"10.0.0.0/8",     // RFC 1918 private
		"100.64.0.0/10",  // RFC 6598 carrier-grade NAT
		"127.0.0.0/8",    // IPv4 loopback (entire block, not just 127.0.0.1)
		"169.254.0.0/16", // IPv4 link-local, includes 169.254.169.254 metadata
		"172.16.0.0/12",  // RFC 1918 private
		"192.168.0.0/16", // RFC 1918 private
		"::/128",         // IPv6 unspecified
		"::1/128",        // IPv6 loopback
		"fc00::/7",       // IPv6 unique local
		"fe80::/10",      // IPv6 link-local
	}, ",")
}
|