feat: add CORS support via CORS_ORIGINS env var
- New CORSMiddleware in server/cors.go - Reads comma-separated origins from CORS_ORIGINS env - Empty or "*" allows all origins - Handles preflight OPTIONS requests - Wraps existing LoggingMiddleware chain
This commit is contained in:
26
main.go
26
main.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"AbstractWizard/audit"
|
"AbstractWizard/audit"
|
||||||
"AbstractWizard/server"
|
"AbstractWizard/server"
|
||||||
@@ -11,9 +12,10 @@ import (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
cfg := server.AppConfig{
|
cfg := server.AppConfig{
|
||||||
ConfigDir: envOrDefault("CONFIG_DIR", "/config"),
|
ConfigDir: envOrDefault("CONFIG_DIR", "/config"),
|
||||||
ListenAddr: envOrDefault("LISTEN_ADDR", "127.0.0.1:8080"),
|
ListenAddr: envOrDefault("LISTEN_ADDR", "127.0.0.1:8080"),
|
||||||
MaxBackups: envOrDefaultInt("MAX_BACKUPS", 10),
|
MaxBackups: envOrDefaultInt("MAX_BACKUPS", 10),
|
||||||
|
CORSOrigins: parseCSV(os.Getenv("CORS_ORIGINS")),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.MkdirAll(cfg.ConfigDir, 0o755); err != nil {
|
if err := os.MkdirAll(cfg.ConfigDir, 0o755); err != nil {
|
||||||
@@ -23,7 +25,8 @@ func main() {
|
|||||||
auditLog := audit.NewLogger()
|
auditLog := audit.NewLogger()
|
||||||
srv := server.New(cfg, auditLog)
|
srv := server.New(cfg, auditLog)
|
||||||
|
|
||||||
log.Printf("config_dir=%s listen_addr=%s max_backups=%d", cfg.ConfigDir, cfg.ListenAddr, cfg.MaxBackups)
|
log.Printf("config_dir=%s listen_addr=%s max_backups=%d cors_origins=%v",
|
||||||
|
cfg.ConfigDir, cfg.ListenAddr, cfg.MaxBackups, cfg.CORSOrigins)
|
||||||
|
|
||||||
if err := srv.ListenAndServe(); err != nil {
|
if err := srv.ListenAndServe(); err != nil {
|
||||||
log.Fatalf("server error: %v", err)
|
log.Fatalf("server error: %v", err)
|
||||||
@@ -49,3 +52,18 @@ func envOrDefaultInt(key string, defaultVal int) int {
|
|||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseCSV(s string) []string {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parts := strings.Split(s, ",")
|
||||||
|
result := make([]string, 0, len(parts))
|
||||||
|
for _, p := range parts {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p != "" {
|
||||||
|
result = append(result, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
43
server/cors.go
Normal file
43
server/cors.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CORSMiddleware adds CORS headers based on allowed origins.
|
||||||
|
// If allowedOrigins is empty or contains "*", all origins are allowed.
|
||||||
|
func CORSMiddleware(allowedOrigins []string, next http.Handler) http.Handler {
|
||||||
|
allowAll := len(allowedOrigins) == 0
|
||||||
|
originSet := make(map[string]bool, len(allowedOrigins))
|
||||||
|
for _, o := range allowedOrigins {
|
||||||
|
o = strings.TrimSpace(o)
|
||||||
|
if o == "*" {
|
||||||
|
allowAll = true
|
||||||
|
}
|
||||||
|
originSet[o] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if allowAll || originSet[origin] {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, PATCH, POST, DELETE, OPTIONS")
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||||
|
w.Header().Set("Access-Control-Max-Age", "3600")
|
||||||
|
w.Header().Set("Vary", "Origin")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -48,7 +48,8 @@ func ParseMode(s string) (Mode, bool) {
|
|||||||
type AppConfig struct {
|
type AppConfig struct {
|
||||||
ConfigDir string
|
ConfigDir string
|
||||||
ListenAddr string
|
ListenAddr string
|
||||||
MaxBackups int
|
MaxBackups int
|
||||||
|
CORSOrigins []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Server is the main HTTP server.
|
// Server is the main HTTP server.
|
||||||
@@ -73,7 +74,7 @@ func New(cfg AppConfig, auditLog *audit.Logger) *Server {
|
|||||||
|
|
||||||
s.srv = &http.Server{
|
s.srv = &http.Server{
|
||||||
Addr: cfg.ListenAddr,
|
Addr: cfg.ListenAddr,
|
||||||
Handler: LoggingMiddleware(mux),
|
Handler: CORSMiddleware(cfg.CORSOrigins, LoggingMiddleware(mux)),
|
||||||
ReadTimeout: 10 * time.Second,
|
ReadTimeout: 10 * time.Second,
|
||||||
WriteTimeout: 30 * time.Second,
|
WriteTimeout: 30 * time.Second,
|
||||||
IdleTimeout: 60 * time.Second,
|
IdleTimeout: 60 * time.Second,
|
||||||
|
|||||||
Reference in New Issue
Block a user