From 047f0b8422e565e05411dbb395d05198ec16b57d Mon Sep 17 00:00:00 2001 From: zhi Date: Wed, 11 Mar 2026 10:07:32 +0000 Subject: [PATCH] feat: add CORS support via CORS_ORIGINS env var - New CORSMiddleware in server/cors.go - Reads comma-separated origins from CORS_ORIGINS env - Empty or "*" allows all origins - Handles preflight OPTIONS requests - Wraps existing LoggingMiddleware chain --- main.go | 26 ++++++++++++++++++++++---- server/cors.go | 43 +++++++++++++++++++++++++++++++++++++++++++ server/server.go | 5 +++-- 3 files changed, 68 insertions(+), 6 deletions(-) create mode 100644 server/cors.go diff --git a/main.go b/main.go index 068535b..28b066b 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "log" "os" "strconv" + "strings" "AbstractWizard/audit" "AbstractWizard/server" @@ -11,9 +12,10 @@ import ( func main() { cfg := server.AppConfig{ - ConfigDir: envOrDefault("CONFIG_DIR", "/config"), - ListenAddr: envOrDefault("LISTEN_ADDR", "127.0.0.1:8080"), - MaxBackups: envOrDefaultInt("MAX_BACKUPS", 10), + ConfigDir: envOrDefault("CONFIG_DIR", "/config"), + ListenAddr: envOrDefault("LISTEN_ADDR", "127.0.0.1:8080"), + MaxBackups: envOrDefaultInt("MAX_BACKUPS", 10), + CORSOrigins: parseCSV(os.Getenv("CORS_ORIGINS")), } if err := os.MkdirAll(cfg.ConfigDir, 0o755); err != nil { @@ -23,7 +25,8 @@ func main() { auditLog := audit.NewLogger() srv := server.New(cfg, auditLog) - log.Printf("config_dir=%s listen_addr=%s max_backups=%d", cfg.ConfigDir, cfg.ListenAddr, cfg.MaxBackups) + log.Printf("config_dir=%s listen_addr=%s max_backups=%d cors_origins=%v", + cfg.ConfigDir, cfg.ListenAddr, cfg.MaxBackups, cfg.CORSOrigins) if err := srv.ListenAndServe(); err != nil { log.Fatalf("server error: %v", err) @@ -49,3 +52,18 @@ func envOrDefaultInt(key string, defaultVal int) int { } return n } + +func parseCSV(s string) []string { + if s == "" { + return nil + } + parts := strings.Split(s, ",") + result := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + result = append(result, p) + } + } + return result +} diff --git a/server/cors.go b/server/cors.go new file mode 100644 index 0000000..042735b --- /dev/null +++ b/server/cors.go @@ -0,0 +1,43 @@ +package server + +import ( + "net/http" + "strings" +) + +// CORSMiddleware adds CORS headers based on allowed origins. +// If allowedOrigins is empty or contains "*", all origins are allowed. +func CORSMiddleware(allowedOrigins []string, next http.Handler) http.Handler { + allowAll := len(allowedOrigins) == 0 + originSet := make(map[string]bool, len(allowedOrigins)) + for _, o := range allowedOrigins { + o = strings.TrimSpace(o) + if o == "*" { + allowAll = true + } + originSet[o] = true + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin == "" { + next.ServeHTTP(w, r) + return + } + + if allowAll || originSet[origin] { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, PATCH, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.Header().Set("Access-Control-Max-Age", "3600") + w.Header().Set("Vary", "Origin") + } + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/server/server.go b/server/server.go index 27cd468..a310ca5 100644 --- a/server/server.go +++ b/server/server.go @@ -48,7 +48,8 @@ func ParseMode(s string) (Mode, bool) { type AppConfig struct { ConfigDir string ListenAddr string - MaxBackups int + MaxBackups int + CORSOrigins []string } // Server is the main HTTP server. @@ -73,7 +74,7 @@ func New(cfg AppConfig, auditLog *audit.Logger) *Server { s.srv = &http.Server{ Addr: cfg.ListenAddr, - Handler: LoggingMiddleware(mux), + Handler: CORSMiddleware(cfg.CORSOrigins, LoggingMiddleware(mux)), ReadTimeout: 10 * time.Second, WriteTimeout: 30 * time.Second, IdleTimeout: 60 * time.Second,