// Package middleware provides the gin handlers shared by the HTTP layer:
// bearer-token authentication, request metrics and upload size limiting.
package middleware

import (
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/rs/zerolog"
	"go.uber.org/fx"

	"base/config"
	"base/pkg/jwt"
	"base/pkg/metrics"
)

// Middleware is the set of gin handlers exposed by this package.
type Middleware interface {
	// Metrics returns a handler that records one HTTP metric per request.
	Metrics() gin.HandlerFunc
	// FileSizeLimit returns a handler that rejects multipart requests
	// containing a file larger than maxSize bytes.
	FileSizeLimit(maxSize int64) gin.HandlerFunc
	// AuthShield returns a handler that requires a valid bearer access token.
	AuthShield() gin.HandlerFunc
}

// middleware is the default Middleware implementation.
type middleware struct {
	metrics      *metrics.Metrics
	logger       zerolog.Logger
	config       *config.AppConfig
	tokenService jwt.TokenService
}

// Param lists the dependencies injected by fx into NewMiddleware.
type Param struct {
	Metrics *metrics.Metrics
	Logger  zerolog.Logger
	Config  *config.AppConfig
	fx.In
}

const (
	// UserIDKey is the gin context key under which AuthShield stores the
	// subject of the verified token.
	UserIDKey = "userID"
)

// NewMiddleware builds the Middleware from its injected dependencies. The
// token service is created from the JWT secret and the access/refresh token
// expirations found in the application config.
func NewMiddleware(lc fx.Lifecycle, param Param) Middleware {
	// The hook has no start/stop callbacks; it is kept as a placeholder.
	lc.Append(fx.Hook{})

	return &middleware{
		metrics: param.Metrics,
		logger:  param.Logger,
		config:  param.Config,
		tokenService: jwt.New(
			param.Config.JWT.Secret,
			param.Config.JWT.AccessTokenExpiration,
			param.Config.JWT.RefreshTokenExpiration,
		),
	}
}

// rejectUnauthorized aborts the handler chain with the standard 401 JSON body.
func (m *middleware) rejectUnauthorized(c *gin.Context) {
	c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
		"message": "unauthorized",
		"status":  http.StatusUnauthorized,
	})
}

// AuthShield returns a handler that authenticates the request from its
// "Authorization: Bearer <token>" header. On success the token subject is
// stored in the gin context under UserIDKey and the chain continues; on any
// failure the request is aborted with 401 and a warning is logged.
func (m *middleware) AuthShield() gin.HandlerFunc {
	return func(c *gin.Context) {
		path := c.Request.URL.Path

		authorizationHeader := c.GetHeader("Authorization")
		if authorizationHeader == "" {
			m.logger.Warn().
				Str("path", path).
				Msg("Authorization header is empty")
			m.rejectUnauthorized(c)
			return
		}

		parts := strings.SplitN(authorizationHeader, " ", 2)
		if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
			// The header value is deliberately not logged: it may carry
			// credentials (another scheme, or a bare token with no scheme).
			m.logger.Warn().
				Str("path", path).
				Msg("Authorization header format is invalid")
			m.rejectUnauthorized(c)
			return
		}

		// More than one space may separate the scheme from the token, so
		// strip surrounding whitespace before verifying.
		accessToken := strings.TrimSpace(parts[1])
		if accessToken == "" {
			m.logger.Warn().
				Str("path", path).
				Msg("Authorization token is empty")
			m.rejectUnauthorized(c)
			return
		}

		m.logger.Debug().
			Str("path", path).
			Msg("Using access token from Authorization header")

		token, err := m.tokenService.VerifyToken(c.Request.Context(), accessToken)
		if err != nil {
			m.logger.Warn().
				Err(err).
				Str("path", path).
				Msg("Authorization token is invalid")
			m.rejectUnauthorized(c)
			return
		}

		m.logger.Debug().
			Str("sub", token.Sub).
			Str("path", path).
			Msg("Authorization token is valid")

		c.Set(UserIDKey, token.Sub)
		c.Next()
	}
}

// Metrics returns a handler that times each request and records its method,
// normalized path, status code and duration. Requests to /health,
// /health/live, /metrics and any path containing "/swagger/" are not recorded.
func (m *middleware) Metrics() gin.HandlerFunc {
	return func(c *gin.Context) {
		start := time.Now()

		// Wrap the writer so the status code can be read back after the
		// downstream handlers have run.
		recorder := &StatusRecorder{
			ResponseWriter: c.Writer,
			statusCode:     http.StatusOK, // Default status code
		}
		c.Writer = recorder

		c.Next()

		statusCode := recorder.GetStatusCode()
		path := c.Request.URL.Path

		if path == "/health" || path == "/metrics" || path == "/health/live" || strings.Contains(path, "/swagger/") {
			return
		}

		// Normalize path to prevent metric cardinality explosion.
		normalizedPath := m.metrics.NormalizePath(path)

		m.metrics.RecordHTTPRequest(c.Request.Method, normalizedPath, strconv.Itoa(statusCode), time.Since(start))
	}
}

// FileSizeLimit returns a handler that rejects multipart requests containing
// a file larger than maxSize bytes with 413 Request Entity Too Large.
//
// maxSize is also passed to ParseMultipartForm, where it bounds how much of
// the form is held in memory, not the size of the request body. The
// body-too-large branch therefore only fires when the body has been wrapped
// with http.MaxBytesReader, which this handler does not do itself.
//
// Requests that are not multipart, or whose form fails to parse for any other
// reason, are passed through to the next handler.
func (m *middleware) FileSizeLimit(maxSize int64) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Parse the form only if no earlier handler has done so.
		if c.Request.MultipartForm == nil {
			if err := c.Request.ParseMultipartForm(maxSize); err != nil {
				var tooLarge *http.MaxBytesError
				switch {
				case errors.As(err, &tooLarge):
					m.logger.Warn().
						Int64("maxSize", maxSize).
						Str("path", c.Request.URL.Path).
						Str("ip", c.ClientIP()).
						Msg("File size limit exceeded")
					c.JSON(http.StatusRequestEntityTooLarge, gin.H{
						"error": fmt.Sprintf("File size exceeds the maximum allowed size of %d bytes", maxSize),
					})
					c.Abort()
					return
				case errors.Is(err, http.ErrNotMultipart):
					// Not an upload: nothing to check and nothing to report.
				default:
					// Other parsing errors should not block the request.
					m.logger.Error().
						Err(err).
						Str("path", c.Request.URL.Path).
						Msg("Failed to parse multipart form")
				}
			}
		}

		// Check individual file sizes.
		if form := c.Request.MultipartForm; form != nil {
			for fieldName, files := range form.File {
				for _, file := range files {
					if file.Size <= maxSize {
						continue
					}
					m.logger.Warn().
						Int64("fileSize", file.Size).
						Int64("maxSize", maxSize).
						Str("filename", file.Filename).
						Str("fieldName", fieldName).
						Str("path", c.Request.URL.Path).
						Str("ip", c.ClientIP()).
						Msg("File size limit exceeded")
					c.JSON(http.StatusRequestEntityTooLarge, gin.H{
						"error": fmt.Sprintf("File '%s' size (%d bytes) exceeds the maximum allowed size of %d bytes",
							file.Filename, file.Size, maxSize),
					})
					c.Abort()
					return
				}
			}
		}

		c.Next()
	}
}