Files
gotail/backend/internal/middleware/middleware.go

210 lines
5.3 KiB
Go

package middleware
import (
"crypto/subtle"
"net/http"
"strings"
"sync"
"time"
"github.com/rs/zerolog/log"
"golang.org/x/time/rate"
)
// AuthMiddleware enforces HTTP Basic authentication against a single
// configured username/password pair.
type AuthMiddleware struct {
	enabled  bool   // when false, Handler lets every request through
	username string // expected Basic-auth username
	password string // expected Basic-auth password
}

// NewAuthMiddleware returns an AuthMiddleware that, when enabled is true,
// admits only requests carrying exactly the given Basic-auth credentials.
func NewAuthMiddleware(enabled bool, username, password string) *AuthMiddleware {
	return &AuthMiddleware{
		enabled:  enabled,
		username: username,
		password: password,
	}
}

// Handler wraps next so that, when authentication is enabled, next is only
// reached by requests whose Basic-auth credentials match the configured pair.
// Requests without credentials, or with wrong ones, receive a 401 challenge;
// wrong credentials are additionally logged with the supplied username and
// the remote address.
func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !m.enabled {
			next.ServeHTTP(w, r)
			return
		}

		user, pass, ok := r.BasicAuth()
		if !ok {
			m.requestAuth(w)
			return
		}

		// Both comparisons run before the results are combined, so a wrong
		// username takes as long to reject as a wrong password.
		usernameMatch := constantTimeEqual(user, m.username)
		passwordMatch := constantTimeEqual(pass, m.password)
		if usernameMatch && passwordMatch {
			next.ServeHTTP(w, r)
			return
		}

		log.Warn().
			Str("username", user).
			Str("remote_addr", r.RemoteAddr).
			Msg("Authentication failed")
		m.requestAuth(w)
	})
}

// requestAuth writes a 401 Unauthorized response carrying a Basic-auth
// challenge so browsers prompt for credentials.
func (m *AuthMiddleware) requestAuth(w http.ResponseWriter) {
	w.Header().Set("WWW-Authenticate", `Basic realm="Web Tail Pro - Restricted Access"`)
	w.WriteHeader(http.StatusUnauthorized)
	// The status line is already sent; a failed body write has no recovery.
	_, _ = w.Write([]byte("Unauthorized\n"))
}

// constantTimeEqual reports whether given equals expected. Its running time
// depends on the lengths involved, not on where the strings differ or on
// whether len(given) equals len(expected).
//
// subtle.ConstantTimeCompare alone returns immediately on a length mismatch,
// which would let an attacker learn the length of the secret by timing.
func constantTimeEqual(given, expected string) bool {
	g, e := []byte(given), []byte(expected)
	if subtle.ConstantTimeEq(int32(len(g)), int32(len(e))) == 1 {
		return subtle.ConstantTimeCompare(g, e) == 1
	}
	// Lengths differ: do a comparison of the same cost (expected vs itself)
	// so the work depends only on len(expected), then report a mismatch.
	_ = subtle.ConstantTimeCompare(e, e)
	return false
}
// CORSMiddleware adds Cross-Origin Resource Sharing (CORS) response headers
// for requests whose Origin is on its allow-list.
type CORSMiddleware struct {
	// allowedOrigins holds exact Origin values to accept. An empty list, or
	// an entry equal to "*", accepts every origin.
	allowedOrigins []string
}

// NewCORSMiddleware returns a CORSMiddleware accepting the given origins.
// The slice is copied, so later changes by the caller do not affect it.
func NewCORSMiddleware(origins []string) *CORSMiddleware {
	return &CORSMiddleware{allowedOrigins: append([]string(nil), origins...)}
}

// Handler wraps next with CORS handling. For a request whose Origin is
// allowed, the Origin is echoed back together with the permitted methods,
// headers and credentials flag. Every OPTIONS request is answered here with
// 200 OK and never reaches next; all other requests are passed on.
func (m *CORSMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		// The CORS headers depend on the request's Origin, so shared caches
		// must not reuse this response for a different origin.
		h.Add("Vary", "Origin")

		// Requests without an Origin header (e.g. non-browser clients) need no
		// CORS headers; skipping them also avoids emitting an empty
		// Access-Control-Allow-Origin value under an allow-all configuration.
		origin := r.Header.Get("Origin")
		if origin != "" && m.isAllowedOrigin(origin) {
			// NOTE(review): echoing the origin together with
			// Allow-Credentials means a wildcard configuration (empty list or
			// "*") lets ANY site send credentialed requests and read the
			// responses. Browsers refuse a literal "*" with credentials for
			// exactly this reason; prefer an explicit allow-list when
			// authentication is enabled.
			h.Set("Access-Control-Allow-Origin", origin)
			h.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
			h.Set("Access-Control-Allow-Credentials", "true")
		}

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// isAllowedOrigin reports whether origin is accepted: always when no origins
// are configured, otherwise when it matches an entry exactly or an entry is "*".
func (m *CORSMiddleware) isAllowedOrigin(origin string) bool {
	if len(m.allowedOrigins) == 0 {
		return true
	}
	for _, allowed := range m.allowedOrigins {
		if allowed == "*" || allowed == origin {
			return true
		}
	}
	return false
}
// RateLimitMiddleware enforces a per-client token-bucket rate limit, keyed by
// the client address that getIP derives from each request.
type RateLimitMiddleware struct {
	// limiters maps a client key to that client's limiter.
	// NOTE(review): entries are never evicted, so memory grows with every
	// distinct key seen; because getIP trusts client-supplied headers, a
	// caller can inflate this map at will. Consider periodic expiry.
	limiters map[string]*rate.Limiter
	mu       sync.RWMutex // guards limiters
	rps      int          // per-client events per second; burst is 2*rps
}

// NewRateLimitMiddleware returns a RateLimitMiddleware that allows rps
// requests per second per client, with bursts of up to 2*rps.
// NOTE(review): presumably rps > 0 is expected; confirm with callers.
func NewRateLimitMiddleware(rps int) *RateLimitMiddleware {
	return &RateLimitMiddleware{
		limiters: make(map[string]*rate.Limiter),
		rps:      rps,
	}
}

// Handler wraps next so that each client is limited by its own limiter.
// Requests over the limit are logged and rejected with 429 Too Many Requests.
func (m *RateLimitMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip := m.getIP(r)
		if m.getLimiter(ip).Allow() {
			next.ServeHTTP(w, r)
			return
		}
		log.Warn().Str("ip", ip).Str("path", r.URL.Path).Msg("Rate limit exceeded")
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
	})
}

// getLimiter returns the limiter for ip, creating it on first use. Concurrent
// callers with the same ip always receive the same limiter.
func (m *RateLimitMiddleware) getLimiter(ip string) *rate.Limiter {
	m.mu.RLock()
	limiter, ok := m.limiters[ip]
	m.mu.RUnlock()
	if ok {
		return limiter
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	// Re-check under the write lock: another goroutine may have created the
	// limiter after our RUnlock. Without this, simultaneous first requests
	// from one client would each get (and overwrite) a fresh full bucket.
	if limiter, ok = m.limiters[ip]; ok {
		return limiter
	}
	limiter = rate.NewLimiter(rate.Limit(m.rps), m.rps*2)
	m.limiters[ip] = limiter
	return limiter
}

// getIP derives the rate-limit key for r: the first X-Forwarded-For entry,
// else X-Real-IP, else the host part of r.RemoteAddr.
//
// NOTE(review): both headers are supplied by the client unless a trusted
// reverse proxy overwrites them, so a client talking to this server directly
// can change them per request to dodge the limit.
func (m *RateLimitMiddleware) getIP(r *http.Request) string {
	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		// The header may carry a chain ("client, proxy1, proxy2"); keep only
		// the first entry so one client maps to one key.
		if i := strings.IndexByte(fwd, ','); i >= 0 {
			fwd = fwd[:i]
		}
		if ip := strings.TrimSpace(fwd); ip != "" {
			return ip
		}
	}
	if ip := r.Header.Get("X-Real-IP"); ip != "" {
		return ip
	}

	addr := r.RemoteAddr
	switch {
	case strings.HasPrefix(addr, "["):
		// Bracketed IPv6 literal, e.g. "[::1]:8080". Splitting at the first
		// colon would reduce every IPv6 client to the same key "[".
		if end := strings.IndexByte(addr, ']'); end > 0 {
			return addr[1:end]
		}
	case strings.Count(addr, ":") == 1:
		// IPv4 address or host name with a port, e.g. "10.0.0.1:8080".
		return addr[:strings.IndexByte(addr, ':')]
	}
	// No recognizable port: use the address as-is.
	return addr
}
// SecurityHeadersMiddleware sets defensive HTTP response headers on every
// response and then delegates to next.
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "DENY")
		// The legacy XSS auditor has been removed from current browsers, and
		// in the ones that had it, "1; mode=block" could create
		// vulnerabilities in otherwise safe pages. OWASP and MDN recommend
		// disabling it explicitly and relying on Content-Security-Policy.
		h.Set("X-XSS-Protection", "0")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// NOTE(review): 'unsafe-inline' in script-src largely defeats CSP's
		// XSS protection; move inline scripts to files (or use nonces) and
		// drop it when the frontend allows.
		h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
		next.ServeHTTP(w, r)
	})
}
// LoggingMiddleware emits one Info-level log entry per request once next has
// returned, recording the method, path, remote address, response status and
// elapsed time.
func LoggingMiddleware(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()

		// next writes through this wrapper so the status it sends can be read
		// back afterwards; it starts at 200, the status net/http uses when a
		// handler never calls WriteHeader.
		recorder := &responseWriter{
			ResponseWriter: w,
			statusCode:     http.StatusOK,
		}
		next.ServeHTTP(recorder, r)

		entry := log.Info()
		entry = entry.Str("method", r.Method)
		entry = entry.Str("path", r.URL.Path)
		entry = entry.Str("remote_addr", r.RemoteAddr)
		entry = entry.Int("status", recorder.statusCode)
		entry = entry.Dur("duration", time.Since(began))
		entry.Msg("HTTP request")
	}
	return http.HandlerFunc(serve)
}
// responseWriter wraps an http.ResponseWriter to remember the status code
// actually sent to the client, for LoggingMiddleware.
//
// It forwards Flush to the wrapped writer when that writer supports it, which
// streaming responses need, and exposes Unwrap so http.ResponseController
// (Go 1.20+) can reach other optional interfaces such as http.Hijacker.
// NOTE(review): a direct w.(http.Hijacker) assertion still fails through this
// wrapper; forwarding Hijack would require the bufio and net imports.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int  // status sent (or implied) for the response
	wroteHeader bool // true once a final (non-1xx) status has been committed
}

// WriteHeader records code, unless a final status was already committed, and
// forwards it. Only the first final status takes effect in net/http, so later
// calls must not overwrite the recorded value.
func (rw *responseWriter) WriteHeader(code int) {
	if !rw.wroteHeader {
		rw.statusCode = code
		// 1xx informational responses (except 101 Switching Protocols) may be
		// followed by the real status, so they do not lock it in.
		informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
		rw.wroteHeader = !informational
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer, first recording the implicit
// 200 OK that net/http sends when no status was written.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.markImplicitOK()
	return rw.ResponseWriter.Write(b)
}

// Flush forwards to the wrapped writer's Flush when it implements
// http.Flusher, and is a no-op otherwise.
func (rw *responseWriter) Flush() {
	if f, ok := rw.ResponseWriter.(http.Flusher); ok {
		rw.markImplicitOK()
		f.Flush()
	}
}

// Unwrap returns the wrapped ResponseWriter for http.ResponseController.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}

// markImplicitOK records 200 OK as the final status if none was committed
// yet, mirroring what net/http sends on the first body write or flush.
func (rw *responseWriter) markImplicitOK() {
	if !rw.wroteHeader {
		rw.statusCode = http.StatusOK
		rw.wroteHeader = true
	}
}