package middleware
import (
"net/http"
"time"
)
// Response bodies written by the throttle middleware when it rejects a
// request instead of passing it to the next handler.
const (
// errCapacityExceeded is written when no backlog token is available, i.e.
// both the processing slots and the backlog are fully occupied.
errCapacityExceeded = "Server capacity exceeded."
// errTimedOut is written when a backlogged request did not obtain a
// processing token before the backlog timeout elapsed.
errTimedOut = "Timed out while waiting for a pending request to complete."
// errContextCanceled is written when the request's context is done before
// the request was handed to the next handler.
errContextCanceled = "Context was canceled."
)
var (
// defaultBacklogTimeout is the backlog timeout that Throttle passes to
// ThrottleBacklog (60 seconds).
defaultBacklogTimeout = time.Second * 60
)
// Throttle returns a middleware that caps the number of requests being
// processed concurrently, counted across all users, from the point where the
// middleware is mounted. It is not a per-user rate limiter: it only puts a
// ceiling on the number of in-flight requests.
//
// Throttle is shorthand for ThrottleBacklog with a backlog limit of zero and
// the package's default backlog timeout.
func Throttle(limit int) func(http.Handler) http.Handler {
	// No extra backlog slots are requested on top of the processing limit.
	const noBacklog = 0

	return ThrottleBacklog(limit, noBacklog, defaultBacklogTimeout)
}
// ThrottleBacklog returns a middleware that limits the number of requests
// processed at the same time and provides a backlog that holds a finite number
// of pending requests.
//
// At most limit requests run the next handler concurrently. Up to backlogLimit
// additional requests wait for a processing slot; each waits at most
// backlogTimeout. A request is rejected with 503 Service Unavailable when:
//   - the backlog is full (body: "Server capacity exceeded."),
//   - it waited longer than backlogTimeout for a slot, or
//   - its context is done before it reaches the next handler.
//
// A request that finds a processing slot free is served immediately and is
// never subject to backlogTimeout, so a zero or negative timeout means
// "do not wait" rather than "reject at random".
//
// ThrottleBacklog panics if limit < 1 or backlogLimit < 0.
func ThrottleBacklog(limit int, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler {
	if limit < 1 {
		panic("chi/middleware: Throttle expects limit > 0")
	}
	if backlogLimit < 0 {
		panic("chi/middleware: Throttle expects backlogLimit to be positive")
	}

	t := throttler{
		tokens:         make(chan token, limit),
		backlogTokens:  make(chan token, limit+backlogLimit),
		backlogTimeout: backlogTimeout,
	}

	// Fill both channels to capacity: every token sitting in a channel is a
	// free slot.
	for i := 0; i < limit; i++ {
		t.tokens <- token{}
	}
	for i := 0; i < limit+backlogLimit; i++ {
		t.backlogTokens <- token{}
	}

	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			select {
			case <-ctx.Done():
				http.Error(w, errContextCanceled, http.StatusServiceUnavailable)
				return

			case btok := <-t.backlogTokens:
				// Registered first so it runs last: the backlog slot is
				// released only after the processing token has been returned.
				defer func() {
					t.backlogTokens <- btok
				}()

				// Fast path: a processing slot is free right now. Taking it
				// before arming the timer avoids a timer allocation per
				// request and guarantees that an already-expired timeout
				// cannot win the select below against an available token.
				select {
				case tok := <-t.tokens:
					defer func() {
						t.tokens <- tok
					}()
					next.ServeHTTP(w, r)
					return
				default:
				}

				// Slow path: wait in the backlog for a slot, the timeout, or
				// cancellation, whichever comes first.
				timer := time.NewTimer(t.backlogTimeout)

				select {
				case <-timer.C:
					http.Error(w, errTimedOut, http.StatusServiceUnavailable)
					return
				case <-ctx.Done():
					timer.Stop()
					http.Error(w, errContextCanceled, http.StatusServiceUnavailable)
					return
				case tok := <-t.tokens:
					// The wait is over; do not keep the timer armed while the
					// handler runs.
					timer.Stop()
					defer func() {
						t.tokens <- tok
					}()
					next.ServeHTTP(w, r)
				}
				return

			default:
				// No backlog token available: processing slots and backlog
				// are both full.
				http.Error(w, errCapacityExceeded, http.StatusServiceUnavailable)
				return
			}
		}

		return http.HandlerFunc(fn)
	}
}
// token represents a request that is being processed. It carries no data;
// it is only used as the element type of the throttler's channels.
type token struct{}
// throttler limits number of currently processed requests at a time.
type throttler struct {
// tokens is created with capacity limit and pre-filled; a request must
// receive a token from it before the next handler runs, and sends the
// token back when the handler returns.
tokens chan token
// backlogTokens is created with capacity limit+backlogLimit and pre-filled;
// a request that cannot receive one immediately is rejected.
backlogTokens chan token
// backlogTimeout is the duration of the timer a request races against
// while waiting for a token from tokens.
backlogTimeout time.Duration
}