212 lines
5.1 KiB
Go
212 lines
5.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/rs/zerolog"
|
|
"go.uber.org/fx"
|
|
|
|
"base/config"
|
|
"base/pkg/jwt"
|
|
"base/pkg/metrics"
|
|
)
|
|
|
|
type Middleware interface {
|
|
Metrics() gin.HandlerFunc
|
|
FileSizeLimit(maxSize int64) gin.HandlerFunc
|
|
AuthShield() gin.HandlerFunc
|
|
}
|
|
|
|
type middleware struct {
|
|
metrics *metrics.Metrics
|
|
logger zerolog.Logger
|
|
config *config.AppConfig
|
|
tokenService jwt.TokenService
|
|
}
|
|
|
|
type Param struct {
|
|
Metrics *metrics.Metrics
|
|
Logger zerolog.Logger
|
|
Config *config.AppConfig
|
|
|
|
fx.In
|
|
}
|
|
|
|
// UserIDKey is the gin context key under which AuthShield stores the subject
// of a successfully verified access token.
const UserIDKey = "userID"
|
|
|
|
func NewMiddleware(lc fx.Lifecycle, param Param) Middleware {
|
|
lc.Append(fx.Hook{})
|
|
|
|
return &middleware{
|
|
metrics: param.Metrics,
|
|
logger: param.Logger,
|
|
config: param.Config,
|
|
tokenService: jwt.New(param.Config.JWT.Secret, param.Config.JWT.AccessTokenExpiration, param.Config.JWT.RefreshTokenExpiration),
|
|
}
|
|
}
|
|
|
|
func (m *middleware) AuthShield() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var accessToken string
|
|
|
|
// Fallback to Authorization header
|
|
authorizationHeader := c.GetHeader("Authorization")
|
|
if authorizationHeader == "" {
|
|
m.logger.Warn().
|
|
Str("path", c.Request.URL.Path).
|
|
Msg("Authorization header is empty")
|
|
c.JSON(http.StatusUnauthorized, gin.H{
|
|
"message": "unauthorized",
|
|
"status": http.StatusUnauthorized,
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
parts := strings.SplitN(authorizationHeader, " ", 2)
|
|
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
|
m.logger.Warn().
|
|
Str("header", authorizationHeader).
|
|
Str("path", c.Request.URL.Path).
|
|
Msg("Authorization header format is invalid")
|
|
c.JSON(http.StatusUnauthorized, gin.H{
|
|
"message": "unauthorized",
|
|
"status": http.StatusUnauthorized,
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
accessToken = parts[1]
|
|
if accessToken == "" {
|
|
m.logger.Warn().
|
|
Str("path", c.Request.URL.Path).
|
|
Msg("Authorization token is empty")
|
|
c.JSON(http.StatusUnauthorized, gin.H{
|
|
"message": "unauthorized",
|
|
"status": http.StatusUnauthorized,
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
m.logger.Debug().
|
|
Str("path", c.Request.URL.Path).
|
|
Msg("Using access token from Authorization header")
|
|
|
|
// Verify token
|
|
token, err := m.tokenService.VerifyToken(c.Request.Context(), accessToken)
|
|
if err != nil {
|
|
m.logger.Warn().
|
|
Err(err).
|
|
Str("path", c.Request.URL.Path).
|
|
Msg("Authorization token is invalid")
|
|
c.JSON(http.StatusUnauthorized, gin.H{
|
|
"message": "unauthorized",
|
|
"status": http.StatusUnauthorized,
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
m.logger.Debug().
|
|
Str("sub", token.Sub).
|
|
Str("path", c.Request.URL.Path).
|
|
Msg("Authorization token is valid")
|
|
|
|
c.Set(UserIDKey, token.Sub)
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func (m *middleware) Metrics() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
start := time.Now()
|
|
|
|
recorder := &StatusRecorder{
|
|
ResponseWriter: c.Writer,
|
|
statusCode: http.StatusOK, // Default status code
|
|
}
|
|
|
|
// Replace the original ResponseWriter with the StatusRecorder
|
|
c.Writer = recorder
|
|
|
|
c.Next()
|
|
|
|
statusCode := recorder.GetStatusCode()
|
|
|
|
path := c.Request.URL.Path
|
|
if path == "/health" || path == "/metrics" || path == "/health/live" || strings.Contains(path, "/swagger/") {
|
|
return
|
|
}
|
|
|
|
// Normalize path to prevent metric cardinality explosion
|
|
normalizedPath := m.metrics.NormalizePath(path)
|
|
m.metrics.RecordHTTPRequest(c.Request.Method, normalizedPath, strconv.Itoa(statusCode), time.Since(start))
|
|
}
|
|
}
|
|
|
|
func (m *middleware) FileSizeLimit(maxSize int64) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Check if this is a multipart form request
|
|
if c.Request.MultipartForm == nil {
|
|
// Parse multipart form to get file size
|
|
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
|
|
if err.Error() == "http: request body too large" {
|
|
m.logger.Warn().
|
|
Int64("maxSize", maxSize).
|
|
Str("path", c.Request.URL.Path).
|
|
Str("ip", c.ClientIP()).
|
|
Msg("File size limit exceeded")
|
|
|
|
c.JSON(
|
|
http.StatusRequestEntityTooLarge,
|
|
gin.H{
|
|
"error": fmt.Sprintf("File size exceeds the maximum allowed size of %d bytes", maxSize),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
// Other parsing errors should not block the request
|
|
m.logger.Error().Err(err).Msg("Failed to parse multipart form")
|
|
}
|
|
}
|
|
|
|
// Check individual file sizes
|
|
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
|
for fieldName, files := range c.Request.MultipartForm.File {
|
|
for _, file := range files {
|
|
if file.Size > maxSize {
|
|
m.logger.Warn().
|
|
Int64("fileSize", file.Size).
|
|
Int64("maxSize", maxSize).
|
|
Str("filename", file.Filename).
|
|
Str("fieldName", fieldName).
|
|
Str("path", c.Request.URL.Path).
|
|
Str("ip", c.ClientIP()).
|
|
Msg("File size limit exceeded")
|
|
|
|
c.JSON(
|
|
http.StatusRequestEntityTooLarge,
|
|
gin.H{
|
|
"error": fmt.Sprintf("File '%s' size (%d bytes) exceeds the maximum allowed size of %d bytes",
|
|
file.Filename, file.Size, maxSize),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|