66 lines
1.3 KiB
Go
66 lines
1.3 KiB
Go
package middleware
|
||
|
||
import (
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
|
||
"golang.org/x/time/rate"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// RateLimiter 按 IP 的限流器
|
||
type RateLimiter struct {
|
||
mu sync.Mutex
|
||
clients map[string]*rate.Limiter
|
||
r rate.Limit
|
||
b int
|
||
}
|
||
|
||
// NewRateLimiter 创建限流中间件,r 每秒请求数,b 突发容量
|
||
func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
|
||
return &RateLimiter{
|
||
clients: make(map[string]*rate.Limiter),
|
||
r: r,
|
||
b: b,
|
||
}
|
||
}
|
||
|
||
// getLimiter 获取或创建该 key 的 limiter
|
||
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
|
||
rl.mu.Lock()
|
||
defer rl.mu.Unlock()
|
||
if lim, ok := rl.clients[key]; ok {
|
||
return lim
|
||
}
|
||
lim := rate.NewLimiter(rl.r, rl.b)
|
||
rl.clients[key] = lim
|
||
return lim
|
||
}
|
||
|
||
// Middleware 返回 Gin 限流中间件(按客户端 IP)
|
||
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
key := c.ClientIP()
|
||
lim := rl.getLimiter(key)
|
||
if !lim.Allow() {
|
||
c.AbortWithStatus(http.StatusTooManyRequests)
|
||
return
|
||
}
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// Cleanup 定期清理过期 limiter(可选,避免 map 无限增长)
|
||
func (rl *RateLimiter) Cleanup(interval time.Duration) {
|
||
ticker := time.NewTicker(interval)
|
||
go func() {
|
||
for range ticker.C {
|
||
rl.mu.Lock()
|
||
rl.clients = make(map[string]*rate.Limiter)
|
||
rl.mu.Unlock()
|
||
}
|
||
}()
|
||
}
|