Gin 框架中间件详细介绍

时间:2024-11-18 12:08:15
  1. 基本中间件结构:
// MiddlewareName shows the basic shape of a Gin middleware: a factory
// function that returns a gin.HandlerFunc closure.
func MiddlewareName() gin.HandlerFunc {
    handler := func(c *gin.Context) {
        // Pre-processing: code here runs before the rest of the chain.

        // Hand control to the next middleware or the route handler.
        c.Next()

        // Post-processing: code here runs after the rest of the chain returns.
    }
    return handler
}
  2. 常用中间件示例:
package middleware

import (
    "errors"
    "log"
    "reflect"
    "sync"
    "time"

    "github.com/gin-gonic/gin"
)

// Logger logs the request path, the response status and the time spent in
// the rest of the handler chain once that chain has finished.
func Logger() gin.HandlerFunc {
    return func(c *gin.Context) {
        // The path is captured before the chain runs, so a later rewrite of
        // c.Request.URL.Path does not change what gets logged.
        began := time.Now()
        requestPath := c.Request.URL.Path

        c.Next()

        elapsed := time.Since(began)
        code := c.Writer.Status()
        log.Printf("Path: %s | Status: %d | Latency: %v", requestPath, code, elapsed)
    }
}

// Auth rejects the request with 401 and aborts the chain when the
// Authorization header is missing or is not accepted by validateToken;
// otherwise it passes control on with c.Next().
func Auth() gin.HandlerFunc {
    return func(c *gin.Context) {
        token := c.GetHeader("Authorization")

        // The switch short-circuits: validateToken is only called for a
        // non-empty token.
        var problem string
        switch {
        case token == "":
            problem = "Authorization token required"
        case !validateToken(token):
            problem = "Invalid token"
        }

        if problem != "" {
            c.JSON(401, gin.H{"error": problem})
            c.Abort()
            return
        }
        c.Next()
    }
}

// ErrorHandler recovers from panics raised further down the handler chain and
// turns them into a 500 JSON response instead of crashing the request.
//
// The recovered value is logged, so the failure is not silently swallowed.
// If the handler had already written a response before panicking, the
// headers are already sent and a second JSON body would only be appended to
// the first one. In that case the chain is just aborted.
func ErrorHandler() gin.HandlerFunc {
    return func(c *gin.Context) {
        defer func() {
            if err := recover(); err != nil {
                log.Printf("panic recovered: %v", err)
                if c.Writer.Written() {
                    c.Abort()
                    return
                }
                c.AbortWithStatusJSON(500, gin.H{
                    "error": "Internal Server Error",
                })
            }
        }()
        c.Next()
    }
}

// CORS adds permissive cross-origin headers to every response. It answers
// OPTIONS (preflight) requests directly with 204 No Content and does not run
// the rest of the chain for them.
func CORS() gin.HandlerFunc {
    // The header set is fixed, so it is built once per CORS() call.
    allow := []struct{ name, value string }{
        {"Access-Control-Allow-Origin", "*"},
        {"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"},
        {"Access-Control-Allow-Headers", "Content-Type, Authorization"},
    }

    return func(c *gin.Context) {
        h := c.Writer.Header()
        for _, kv := range allow {
            h.Set(kv.name, kv.value)
        }

        if c.Request.Method != "OPTIONS" {
            c.Next()
            return
        }
        // Preflight: the headers above are the whole answer.
        c.AbortWithStatus(204)
    }
}
  3. 中间件的使用方式:
package main

import (
    "github.com/gin-gonic/gin"
    "your-project/middleware"
)

// main wires the example middlewares into a Gin engine and starts the server.
func main() {
    // gin.New() rather than gin.Default(): Default already installs Gin's own
    // Logger and Recovery. Together with middleware.Logger that would log every
    // request twice, and middleware.ErrorHandler (registered inside Recovery)
    // would make Recovery redundant.
    r := gin.New()

    // Global middleware: applies to every route registered below.
    r.Use(middleware.Logger())
    r.Use(middleware.ErrorHandler())

    // Route-group middleware: only routes under /api require authentication.
    authorized := r.Group("/api")
    authorized.Use(middleware.Auth())
    {
        authorized.GET("/users", GetUsers)
        authorized.POST("/users", CreateUser)
    }

    // Per-route middleware: CORS headers only on GET /public.
    // Note: route-level middleware runs only for the matched GET route, so an
    // OPTIONS preflight to /public never reaches CORS(). Register it with
    // r.Use(...) if preflight support is needed.
    r.GET("/public", middleware.CORS(), PublicHandler)

    // Run blocks while serving. A returned error (e.g. port already in use)
    // must not be silently dropped.
    if err := r.Run(":8080"); err != nil {
        panic(err)
    }
}
  4. 带配置的中间件:
// RateLimit limits the routes it is attached to to about limit requests per
// duration, allowing bursts of up to limit requests. All clients share one
// token bucket. Requests over the limit get a 429 JSON response and the chain
// is aborted.
//
// The bucket refills one token every duration/limit. The earlier version
// passed rate.Every(duration), which refills only ONE token per duration.
// After the initial burst, RateLimit(100, time.Minute) therefore sustained
// 1 request per minute instead of 100.
//
// Edge cases: limit <= 0 with duration > 0 rejects every request (burst <= 0).
// A non-positive duration, or duration/limit rounding down to 0, disables
// limiting, because rate.Every of a non-positive interval is rate.Inf.
func RateLimit(limit int, duration time.Duration) gin.HandlerFunc {
    interval := duration
    if limit > 0 {
        interval = duration / time.Duration(limit)
    }
    limiter := rate.NewLimiter(rate.Every(interval), limit)

    return func(c *gin.Context) {
        if !limiter.Allow() {
            c.JSON(429, gin.H{"error": "Too many requests"})
            c.Abort()
            return
        }
        c.Next()
    }
}

// Cache replays a stored response for the same GET request URI for up to
// duration after it was stored.
//
// It relies on the downstream handler publishing its payload with
// c.Set("response", data). Only GET requests that finished with status 200
// and set that key are cached. A hit is replayed with c.JSON(200, data) and
// the chain is aborted.
//
// The store is an in-memory map guarded by a mutex, because handlers run
// concurrently. An expired entry is dropped when its key is next looked up.
// Keys that are never requested again stay in memory (the map is unbounded).
func Cache(duration time.Duration) gin.HandlerFunc {
    // A local type must be declared before its first use in the function body.
    type cacheEntry struct {
        data      interface{}
        timestamp time.Time
    }

    var mu sync.Mutex
    cache := make(map[string]cacheEntry)

    return func(c *gin.Context) {
        // Only GETs are cached. Replaying a stored body for POST/PUT/DELETE
        // would skip the handler's side effects.
        if c.Request.Method != "GET" {
            c.Next()
            return
        }

        // RequestURI includes the query string, so /items?page=1 and
        // /items?page=2 do not share an entry.
        key := c.Request.URL.RequestURI()

        mu.Lock()
        entry, exists := cache[key]
        if exists && time.Since(entry.timestamp) >= duration {
            delete(cache, key)
            exists = false
        }
        mu.Unlock()

        if exists {
            c.JSON(200, entry.data)
            c.Abort()
            return
        }

        c.Next()

        // Store only successful responses whose payload was published.
        data, ok := c.Get("response")
        if !ok || c.Writer.Status() != 200 {
            return
        }
        mu.Lock()
        cache[key] = cacheEntry{
            data:      data,
            timestamp: time.Now(),
        }
        mu.Unlock()
    }
}
  5. 自定义上下文:
// CustomContext bundles a *gin.Context with the request's user so handlers can
// reach both through a single value. Because *gin.Context is embedded, its
// methods are promoted onto CustomContext.
type CustomContext struct {
    *gin.Context
    User *models.User // filled from getCurrentUser(c) by UserContext; may be nil when no user is resolved — TODO confirm
}

// UserContext resolves the current user with getCurrentUser and stores a
// *CustomContext wrapping c and that user under the "custom_context" key
// for downstream handlers. It then continues the chain.
func UserContext() gin.HandlerFunc {
    return func(c *gin.Context) {
        c.Set("custom_context", &CustomContext{
            Context: c,
            User:    getCurrentUser(c),
        })
        c.Next()
    }
}
  6. 链式中间件:
// RequestTracing generates a trace ID for every request, stores it under the
// "trace_id" key for downstream handlers, and logs it once the rest of the
// chain has returned.
func RequestTracing() gin.HandlerFunc {
    return func(c *gin.Context) {
        id := generateTraceID()
        c.Set("trace_id", id)

        c.Next()

        // Reached only after every later handler has returned normally.
        log.Printf("Request completed: %s", id)
    }
}

// Performance logs requests whose downstream handling took longer than
// one second.
func Performance() gin.HandlerFunc {
    const slowThreshold = time.Second

    return func(c *gin.Context) {
        began := time.Now()
        c.Next()

        if took := time.Since(began); took > slowThreshold {
            log.Printf("Slow request: %s took %v", c.Request.URL.Path, took)
        }
    }
}
  7. 中间件通信:
// DataMiddleware demonstrates passing data through the request context. It
// stores a value before the chain runs and reads it back afterwards.
//
// The earlier version declared `value` without using it, which Go rejects
// at compile time ("declared and not used").
func DataMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        // Make the value visible to every later handler in this request.
        c.Set("key", "value")

        c.Next()

        // Read it back; a downstream handler may have overwritten it.
        if value, exists := c.Get("key"); exists {
            _ = value // use the data here
        }
    }
}

// ResponseModifier wraps error responses (status >= 400) in a uniform JSON
// envelope: {"error": true, "message": <value stored under the "error" key>}.
//
// It only acts when the downstream handler set an error status without
// writing anything, e.g. via c.Status(404). Once a response has been written,
// the status and headers are already sent. A second c.JSON would then only
// append another JSON document to the body, so such responses are left
// untouched.
func ResponseModifier() gin.HandlerFunc {
    return func(c *gin.Context) {
        c.Next()

        status := c.Writer.Status()
        if status < 400 || c.Writer.Written() {
            return
        }

        // A missing key yields a nil message, as a direct c.Keys lookup would.
        message, _ := c.Get("error")
        c.JSON(status, gin.H{
            "error":   true,
            "message": message,
        })
    }
}
  8. 高级中间件示例:
// JWTAuth requires an Authorization header and parses it with parseToken. On
// success it stores claims.UserID under "user_id" and continues the chain.
// A missing or unparsable token gets a 401 JSON response and aborts the chain.
//
// NOTE(review): the raw header value is passed to parseToken as-is; assumes
// parseToken copes with any "Bearer " prefix — TODO confirm.
func JWTAuth() gin.HandlerFunc {
    deny := func(c *gin.Context, msg string) {
        c.JSON(401, gin.H{"error": msg})
        c.Abort()
    }

    return func(c *gin.Context) {
        raw := c.GetHeader("Authorization")
        if raw == "" {
            deny(c, "No token provided")
            return
        }

        claims, err := parseToken(raw)
        if err != nil {
            deny(c, "Invalid token")
            return
        }

        c.Set("user_id", claims.UserID)
        c.Next()
    }
}

// ValidateRequest binds the JSON request body into a fresh value of schema's
// type. If binding or validation fails, it responds 400 with the error message
// and aborts the chain. On success the bound value is stored under
// "validated_data" as a pointer, e.g. *CreateUserRequest.
//
// schema is used only as a type template: pass &T{} (or T{}). The earlier
// version bound directly into schema, so all requests shared one value and
// overwrote it concurrently, leaking data between requests. A nil schema
// carries no type and is passed through to the binder unchanged, as before.
func ValidateRequest(schema interface{}) gin.HandlerFunc {
    // Resolve the target type once, at construction time.
    schemaType := reflect.TypeOf(schema)
    if schemaType != nil && schemaType.Kind() == reflect.Pointer {
        schemaType = schemaType.Elem()
    }

    return func(c *gin.Context) {
        target := schema
        if schemaType != nil {
            // A new *T per request: nothing is shared between requests.
            target = reflect.New(schemaType).Interface()
        }

        if err := c.ShouldBindJSON(target); err != nil {
            c.JSON(400, gin.H{"error": err.Error()})
            c.Abort()
            return
        }
        c.Set("validated_data", target)
        c.Next()
    }
}

// RequestLimit is a per-client sliding-window limiter keyed by c.ClientIP().
// Each client may make at most maxRequests requests within any trailing
// window of length duration. Requests over the limit get a 429 JSON response
// and the chain is aborted. maxRequests <= 0 rejects every request.
//
// Fixes over the earlier version:
//   - The mutex is released before c.Next(). Previously the deferred Unlock
//     held it for the whole downstream handler, serialising every request
//     from every client.
//   - Clients whose timestamps have all expired are swept from the map at
//     most once per window, so memory no longer grows with every IP ever seen.
func RequestLimit(maxRequests int, duration time.Duration) gin.HandlerFunc {
    clients := make(map[string][]time.Time)
    var (
        mu        sync.Mutex
        lastSweep time.Time
    )

    return func(c *gin.Context) {
        ip := c.ClientIP()

        mu.Lock()
        // Reading the clock under the lock keeps each client's timestamps
        // appended in non-decreasing order.
        now := time.Now()
        cutoff := now.Add(-duration)

        // Amortised cleanup: forget clients with no request newer than cutoff.
        if now.Sub(lastSweep) >= duration {
            for key, times := range clients {
                if len(times) == 0 || !times[len(times)-1].After(cutoff) {
                    delete(clients, key)
                }
            }
            lastSweep = now
        }

        // Keep only this client's timestamps still inside the window,
        // filtering in place to reuse the backing array.
        history := clients[ip]
        recent := history[:0]
        for _, ts := range history {
            if ts.After(cutoff) {
                recent = append(recent, ts)
            }
        }

        allowed := len(recent) < maxRequests
        if allowed {
            recent = append(recent, now)
        }
        if len(recent) == 0 {
            delete(clients, ip)
        } else {
            clients[ip] = recent
        }
        mu.Unlock()

        if !allowed {
            c.JSON(429, gin.H{"error": "Too many requests"})
            c.Abort()
            return
        }
        c.Next()
    }
}
  9. 测试中间件:
// TestAuthMiddleware checks two cases. Auth answers 401 when the
// Authorization header is absent, and lets a request with a valid token reach
// the handler (200).
func TestAuthMiddleware(t *testing.T) {
    // Minimal engine: only Auth in front of a trivial handler.
    router := gin.New()
    router.Use(Auth())
    router.GET("/test", func(c *gin.Context) {
        c.JSON(200, gin.H{"message": "success"})
    })

    // An empty token means "send no Authorization header".
    // NOTE(review): the second case assumes validateToken("valid-token")
    // returns true — confirm against validateToken.
    cases := []struct {
        name  string
        token string
        want  int
    }{
        {name: "without token", token: "", want: 401},
        {name: "with valid token", token: "valid-token", want: 200},
    }

    for _, tc := range cases {
        t.Run(tc.name, func(t *testing.T) {
            // NewRequest cannot fail for this constant method and URL.
            req, _ := http.NewRequest("GET", "/test", nil)
            if tc.token != "" {
                req.Header.Set("Authorization", tc.token)
            }

            rec := httptest.NewRecorder()
            router.ServeHTTP(rec, req)
            assert.Equal(t, tc.want, rec.Code)
        })
    }
}
  10. 中间件最佳实践:
// MiddlewareConfig configures the middleware built by NewMiddleware.
type MiddlewareConfig struct {
    // EnableLogging turns on one log line per request (method, path, status,
    // latency), written after the chain has finished.
    EnableLogging bool

    // LogLevel is not consulted by NewMiddleware yet.
    LogLevel string

    // Timeout sets a deadline on the request's context.Context.
    // A value <= 0 disables the timeout.
    Timeout time.Duration
}

// NewMiddleware builds a middleware from config.
//
// When config.Timeout > 0, the request's context gets that deadline.
// Downstream handlers must honour c.Request.Context() for work to stop early.
// If the deadline has passed when the chain returns and nothing has been
// written yet, a 504 JSON response is sent.
//
// The chain runs synchronously on the request goroutine. The earlier version
// ran c.Next() in a separate goroutine, which is unsafe:
//   - *gin.Context is not safe for concurrent use and is recycled by Gin once
//     the middleware returns, so on timeout the goroutine kept using a reused
//     context while the 504 was being written.
//   - The goroutine then blocked forever sending on the unbuffered done channel.
//   - A zero Timeout produced an immediate 504.
//
// For pre-emptive timeouts, use a response-buffering wrapper such as
// github.com/gin-contrib/timeout.
func NewMiddleware(config MiddlewareConfig) gin.HandlerFunc {
    return func(c *gin.Context) {
        start := time.Now()

        var ctx context.Context
        if config.Timeout > 0 {
            var cancel context.CancelFunc
            ctx, cancel = context.WithTimeout(c.Request.Context(), config.Timeout)
            defer cancel()
            c.Request = c.Request.WithContext(ctx)
        }

        c.Next()

        // Report the timeout only if no handler has written a response yet.
        if ctx != nil && errors.Is(ctx.Err(), context.DeadlineExceeded) && !c.Writer.Written() {
            c.AbortWithStatusJSON(504, gin.H{"error": "request timeout"})
        }

        if config.EnableLogging {
            log.Printf("%s %s | Status: %d | Latency: %v",
                c.Request.Method, c.Request.URL.Path, c.Writer.Status(), time.Since(start))
        }
    }
}

使用这些中间件时,需要注意:

  • 中间件的执行顺序很重要
  • 合理使用 c.Next() 和 c.Abort()(c.Abort() 只阻止后续处理函数执行,不会结束当前函数,调用后通常需要显式 return)
  • 避免在中间件中存储太多数据
  • 注意性能影响
  • 做好错误处理
  • 保持中间件的独立性和可复用性
  • 在中间件中启动 goroutine 时,不要直接使用原始的 *gin.Context,应使用 c.Copy() 得到的只读副本