- 基本中间件结构:
// MiddlewareName shows the basic shape of a Gin middleware: a factory
// function that returns a gin.HandlerFunc closure.
func MiddlewareName() gin.HandlerFunc {
	handler := func(c *gin.Context) {
		// Work that must happen before the rest of the chain goes here.

		// Hand control to the next middleware or the final handler.
		c.Next()

		// Work that must happen after the rest of the chain returns goes here.
	}
	return handler
}
- 常用中间件示例:
package middleware
import (
	"context"
	"errors"
	"log"
	"reflect"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)
// Logger logs the request path, final status code and total handling time
// of every request once the rest of the chain has finished.
func Logger() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Capture the path before the chain runs, in case a handler rewrites it.
		begin, reqPath := time.Now(), c.Request.URL.Path

		c.Next()

		elapsed := time.Since(begin)
		log.Printf("Path: %s | Status: %d | Latency: %v", reqPath, c.Writer.Status(), elapsed)
	}
}
// Auth rejects requests whose Authorization header is empty or is not
// accepted by validateToken. Rejected requests get 401 with a JSON error
// and the rest of the chain is skipped; accepted requests continue.
func Auth() gin.HandlerFunc {
	return func(c *gin.Context) {
		switch token := c.GetHeader("Authorization"); {
		case token == "":
			c.AbortWithStatusJSON(401, gin.H{"error": "Authorization token required"})
		case !validateToken(token):
			c.AbortWithStatusJSON(401, gin.H{"error": "Invalid token"})
		default:
			c.Next()
		}
	}
}
// ErrorHandler recovers from panics raised further down the handler chain,
// logs them, and answers with a generic 500 JSON error so internal details
// are not exposed to the client.
func ErrorHandler() gin.HandlerFunc {
	return func(c *gin.Context) {
		defer func() {
			if err := recover(); err != nil {
				// Record the panic; swallowing it silently makes failures
				// impossible to diagnose.
				log.Printf("panic recovered: %v", err)

				// Abort is required: once this middleware returns normally,
				// gin's outer Next loop would otherwise continue with the
				// handlers that follow the one that panicked.
				if c.Writer.Written() {
					// Part of the response is already on the wire; appending
					// a JSON body would corrupt it.
					c.Abort()
					return
				}
				c.AbortWithStatusJSON(500, gin.H{
					"error": "Internal Server Error",
				})
			}
		}()
		c.Next()
	}
}
// CORS adds permissive cross-origin headers (any origin; GET, POST, PUT,
// DELETE, OPTIONS; Content-Type and Authorization headers) to every response.
// It answers OPTIONS requests directly with 204 and lets all other methods
// continue down the chain.
func CORS() gin.HandlerFunc {
	return func(c *gin.Context) {
		h := c.Writer.Header()
		h.Set("Access-Control-Allow-Origin", "*")
		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

		if c.Request.Method != "OPTIONS" {
			c.Next()
			return
		}
		// Preflight request: the headers above are the whole answer.
		c.AbortWithStatus(204)
	}
}
- 中间件的使用方式:
package main
import (
"github.com/gin-gonic/gin"
"your-project/middleware"
)
// main wires the example middlewares into a Gin engine and starts the
// HTTP server on :8080.
func main() {
	// Per the Gin docs, gin.Default also installs gin's built-in Logger and
	// Recovery middlewares; the ones registered below run after them.
	r := gin.Default()

	// Global middlewares: apply to every route registered on r.
	r.Use(middleware.Logger())
	r.Use(middleware.ErrorHandler())

	// Group middleware: every route under /api requires authentication.
	authorized := r.Group("/api")
	authorized.Use(middleware.Auth())
	{
		authorized.GET("/users", GetUsers)
		authorized.POST("/users", CreateUser)
	}

	// Per-route middleware. Browsers send the CORS preflight as an OPTIONS
	// request, which would never reach a middleware attached only to the GET
	// route, so an OPTIONS route is registered too; CORS answers it with 204.
	r.GET("/public", middleware.CORS(), PublicHandler)
	r.OPTIONS("/public", middleware.CORS())

	// Run blocks while serving; a startup failure (e.g. port already in use)
	// must not be silently ignored.
	if err := r.Run(":8080"); err != nil {
		panic(err)
	}
}
- 带配置的中间件:
// RateLimit returns a middleware backed by one token-bucket limiter that is
// shared by all clients. It allows on average limit requests per duration,
// with bursts of up to limit requests; requests beyond that get 429 and the
// chain is aborted.
//
// Because the limiter is global, a single busy client can exhaust it for
// everyone; RequestLimit tracks clients individually by IP.
func RateLimit(limit int, duration time.Duration) gin.HandlerFunc {
	// rate.Every(d) produces one token per d, so the window is divided by
	// limit to refill limit tokens per duration (not just one).
	interval := duration
	if limit > 0 {
		interval = duration / time.Duration(limit)
	}
	limiter := rate.NewLimiter(rate.Every(interval), limit)
	return func(c *gin.Context) {
		if !limiter.Allow() {
			c.JSON(429, gin.H{"error": "Too many requests"})
			c.Abort()
			return
		}
		c.Next()
	}
}
// Cache returns a middleware that keeps successful GET responses in memory
// for the given duration, keyed by path plus query string.
//
// Handlers opt in by storing the value to cache under the "response" key
// (c.Set("response", v)). A request is cached only when that key is set and
// the final status is 200. A cache hit is served as JSON with status 200
// and the rest of the chain is skipped.
func Cache(duration time.Duration) gin.HandlerFunc {
	// Local types must be declared before use; Go does not hoist them.
	type cacheEntry struct {
		data      interface{}
		timestamp time.Time
	}

	var (
		// mu guards cache: the returned handler may run concurrently for
		// parallel requests.
		mu    sync.Mutex
		cache = make(map[string]cacheEntry)
	)

	return func(c *gin.Context) {
		// Only replay idempotent reads; caching a POST would skip its side effects.
		if c.Request.Method != "GET" {
			c.Next()
			return
		}
		// Include the query string so /items?page=1 and /items?page=2 get
		// separate entries.
		key := c.Request.URL.RequestURI()

		mu.Lock()
		entry, exists := cache[key]
		if exists && time.Since(entry.timestamp) >= duration {
			// Expired: evict so stale entries do not accumulate.
			delete(cache, key)
			exists = false
		}
		mu.Unlock()

		if exists {
			c.JSON(200, entry.data)
			c.Abort()
			return
		}

		c.Next()

		// Do not cache handlers that produced nothing to cache, or errors.
		data, ok := c.Get("response")
		if !ok || c.Writer.Status() != 200 {
			return
		}
		mu.Lock()
		cache[key] = cacheEntry{data: data, timestamp: time.Now()}
		mu.Unlock()
	}
}
- 自定义上下文:
// CustomContext embeds *gin.Context, so gin.Context's methods are promoted
// onto it, and carries the current user alongside the request context.
type CustomContext struct {
	*gin.Context
	User *models.User // NOTE(review): may be nil if no user was resolved — confirm against getCurrentUser
}
// UserContext resolves the current user with getCurrentUser, wraps it and
// the request's gin.Context in a *CustomContext, and stores that under the
// "custom_context" key before continuing the chain.
func UserContext() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Set("custom_context", &CustomContext{
			Context: c,
			User:    getCurrentUser(c),
		})
		c.Next()
	}
}
- 链式中间件:
// RequestTracing tags each request with an ID from generateTraceID, exposes
// it to downstream handlers under the "trace_id" context key, and logs the
// ID after the rest of the chain returns.
func RequestTracing() gin.HandlerFunc {
	const traceKey = "trace_id"
	return func(c *gin.Context) {
		id := generateTraceID()
		c.Set(traceKey, id)

		// Downstream handlers can read the ID with c.Get(traceKey).
		c.Next()

		log.Printf("Request completed: %s", id)
	}
}
// Performance logs the path of any request whose handling (the rest of the
// chain) takes longer than one second.
func Performance() gin.HandlerFunc {
	const slowThreshold = time.Second
	return func(c *gin.Context) {
		begin := time.Now()
		c.Next()
		if elapsed := time.Since(begin); elapsed > slowThreshold {
			log.Printf("Slow request: %s took %v", c.Request.URL.Path, elapsed)
		}
	}
}
- 中间件通信:
// DataMiddleware shows how to pass data through a gin.Context. It stores
// "value" under "key" before the chain runs and reads the key back
// afterwards; downstream handlers may have replaced it in between.
func DataMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Make the value visible to every later middleware and handler.
		c.Set("key", "value")

		c.Next()

		// Read it back once the chain has run. The bound variable must be
		// used: an unused local is a compile error in Go.
		if value, exists := c.Get("key"); exists {
			_ = value // use the data here
		}
	}
}
// ResponseModifier turns a bare error status into a uniform JSON error body.
//
// After the chain runs, it acts only if the status is >= 400 and no body
// has been written yet (for example, the handler only called c.Status). In
// that case it writes {"error": true, "message": ...}. The message comes
// from the "error" context key, falling back to the last entry in c.Errors.
// Responses whose body was already written are left untouched, because
// their headers and body can no longer be replaced; writing again would
// append a second body.
func ResponseModifier() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Next()

		status := c.Writer.Status()
		if status < 400 || c.Writer.Written() {
			return
		}

		// c.Get reads Keys under gin's lock, unlike indexing c.Keys directly.
		message, ok := c.Get("error")
		if !ok && len(c.Errors) > 0 {
			message = c.Errors.Last().Error()
		}
		c.JSON(status, gin.H{
			"error":   true,
			"message": message,
		})
	}
}
- 高级中间件示例:
// JWTAuth requires an Authorization header and parses it with parseToken.
// On success it stores claims.UserID under "user_id" and continues the
// chain. A missing or unparsable token gets 401 with a JSON error and the
// chain is aborted.
//
// NOTE(review): the raw header value is passed to parseToken; if clients
// send "Bearer <token>", confirm that parseToken strips the prefix.
func JWTAuth() gin.HandlerFunc {
	reject := func(c *gin.Context, msg string) {
		c.AbortWithStatusJSON(401, gin.H{"error": msg})
	}
	return func(c *gin.Context) {
		raw := c.GetHeader("Authorization")
		if raw == "" {
			reject(c, "No token provided")
			return
		}
		claims, err := parseToken(raw)
		if err != nil {
			reject(c, "Invalid token")
			return
		}
		c.Set("user_id", claims.UserID)
		c.Next()
	}
}
// ValidateRequest binds the JSON request body into a value shaped like
// schema and stores the bound value under "validated_data". Binding or
// validation failures get 400 with the error text and abort the chain.
//
// schema is treated as a prototype. When it is a pointer (e.g.
// &CreateUserReq{}), every request binds into a freshly allocated value of
// the same type. Concurrent requests therefore never share or overwrite one
// another's data, and fields a request omits are zero rather than left over
// from a previous request. Handlers still receive the same pointer type,
// e.g. c.MustGet("validated_data").(*CreateUserReq).
func ValidateRequest(schema interface{}) gin.HandlerFunc {
	schemaType := reflect.TypeOf(schema)
	return func(c *gin.Context) {
		target := schema
		if schemaType != nil && schemaType.Kind() == reflect.Ptr {
			target = reflect.New(schemaType.Elem()).Interface()
		}
		if err := c.ShouldBindJSON(target); err != nil {
			c.JSON(400, gin.H{"error": err.Error()})
			c.Abort()
			return
		}
		c.Set("validated_data", target)
		c.Next()
	}
}
// RequestLimit enforces, per client IP (as reported by c.ClientIP), at most
// maxRequests requests within any trailing window of the given duration.
// Requests over the limit get 429 and the chain is aborted.
//
// NOTE(review): entries for IPs that stop sending requests are never
// evicted, so memory grows with the number of distinct clients. Add a
// periodic cleanup if that matters.
func RequestLimit(maxRequests int, duration time.Duration) gin.HandlerFunc {
	var (
		mu      sync.Mutex // guards clients
		clients = make(map[string][]time.Time)
	)

	// allow records a request from ip at now and reports whether it fits in
	// the window. The lock covers only this bookkeeping, never the
	// downstream handlers, so requests are not serialized behind one another.
	allow := func(ip string, now time.Time) bool {
		mu.Lock()
		defer mu.Unlock()

		// Drop timestamps that fell out of the window, reusing the backing array.
		cutoff := now.Add(-duration)
		recent := clients[ip][:0]
		for _, ts := range clients[ip] {
			if ts.After(cutoff) {
				recent = append(recent, ts)
			}
		}
		if len(recent) >= maxRequests {
			clients[ip] = recent
			return false
		}
		clients[ip] = append(recent, now)
		return true
	}

	return func(c *gin.Context) {
		if !allow(c.ClientIP(), time.Now()) {
			c.JSON(429, gin.H{"error": "Too many requests"})
			c.Abort()
			return
		}
		c.Next()
	}
}
- 测试中间件:
// TestAuthMiddleware checks that Auth rejects a request without an
// Authorization header (401) and lets the same request through once the
// header is set to "valid-token" (200).
// NOTE(review): the 200 case relies on validateToken accepting "valid-token".
func TestAuthMiddleware(t *testing.T) {
	router := gin.New()
	router.Use(Auth())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(200, gin.H{"message": "success"})
	})

	req, _ := http.NewRequest("GET", "/test", nil)
	// serve runs req through the router with a fresh recorder and returns the status.
	serve := func() int {
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)
		return rec.Code
	}

	// No token: rejected.
	assert.Equal(t, 401, serve())

	// Same request with a valid token: accepted.
	req.Header.Set("Authorization", "valid-token")
	assert.Equal(t, 200, serve())
}
- 中间件最佳实践:
// MiddlewareConfig configures the middleware factory NewMiddleware.
type MiddlewareConfig struct {
	EnableLogging bool // checked by NewMiddleware; its logging branch is a placeholder in this example
	LogLevel string // NOTE(review): not read by NewMiddleware as shown
	Timeout time.Duration // deadline applied to each request's context
}
// NewMiddleware builds a middleware from config.
//
// When config.Timeout is positive, each request's context gets that
// deadline, so downstream code that honours c.Request.Context() can stop
// early. If the deadline has passed when the chain returns and nothing has
// been written yet, the middleware responds 504 with a JSON error. A zero
// or negative Timeout disables the deadline.
//
// The chain runs on the request goroutine. gin.Context is not safe for
// concurrent use, so running c.Next in another goroutine while this one
// writes a timeout response would race on the context and response writer.
// It would also leave that goroutine blocked after a timeout.
func NewMiddleware(config MiddlewareConfig) gin.HandlerFunc {
	return func(c *gin.Context) {
		if config.EnableLogging {
			// Logging hook: placeholder in this example.
		}

		if config.Timeout <= 0 {
			// A zero deadline would expire before any handler ran.
			c.Next()
			return
		}

		ctx, cancel := context.WithTimeout(c.Request.Context(), config.Timeout)
		defer cancel()
		c.Request = c.Request.WithContext(ctx)

		c.Next()

		// Only report the timeout if the handlers did not already respond.
		if errors.Is(ctx.Err(), context.DeadlineExceeded) && !c.Writer.Written() {
			c.AbortWithStatusJSON(504, gin.H{"error": "request timeout"})
		}
	}
}
使用这些中间件时,需要注意:
- 中间件的执行顺序很重要
- 合理使用 `c.Next()` 和 `c.Abort()`:`c.Abort()` 只会阻止后续处理函数执行,不会终止当前函数,调用后通常需要立即 `return`
- 避免在中间件中存储太多数据
- 注意性能影响
- 做好错误处理
- 保持中间件的独立性和可复用性