From 55a96767bb12b6e4172ce6b741c051556b33b805 Mon Sep 17 00:00:00 2001 From: Konstantin Azizov Date: Mon, 7 May 2018 20:01:39 +0200 Subject: [PATCH] feat(security): add request rate limiter on authentication endpoint (#1866) --- api/http/handler/auth.go | 4 +- api/http/security/rate_limiter.go | 47 ++++++++++++++++++ api/http/security/rate_limiter_test.go | 69 ++++++++++++++++++++++++++ api/http/server.go | 5 +- 4 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 api/http/security/rate_limiter.go create mode 100644 api/http/security/rate_limiter_test.go diff --git a/api/http/handler/auth.go b/api/http/handler/auth.go index eb5e86c00..4b75967e9 100644 --- a/api/http/handler/auth.go +++ b/api/http/handler/auth.go @@ -37,14 +37,14 @@ const ( ) // NewAuthHandler returns a new instance of AuthHandler. -func NewAuthHandler(bouncer *security.RequestBouncer, authDisabled bool) *AuthHandler { +func NewAuthHandler(bouncer *security.RequestBouncer, rateLimiter *security.RateLimiter, authDisabled bool) *AuthHandler { h := &AuthHandler{ Router: mux.NewRouter(), Logger: log.New(os.Stderr, "", log.LstdFlags), authDisabled: authDisabled, } h.Handle("/auth", - bouncer.PublicAccess(http.HandlerFunc(h.handlePostAuth))).Methods(http.MethodPost) + rateLimiter.LimitAccess(bouncer.PublicAccess(http.HandlerFunc(h.handlePostAuth)))).Methods(http.MethodPost) return h } diff --git a/api/http/security/rate_limiter.go b/api/http/security/rate_limiter.go new file mode 100644 index 000000000..0eb89e0c1 --- /dev/null +++ b/api/http/security/rate_limiter.go @@ -0,0 +1,47 @@ +package security + +import ( + "net/http" + "strings" + "time" + + "github.com/g07cha/defender" + "github.com/portainer/portainer" + httperror "github.com/portainer/portainer/http/error" +) + +// RateLimiter represents an entity that manages request rate limiting +type RateLimiter struct { + *defender.Defender +} + +// NewRateLimiter initializes a new RateLimiter +func NewRateLimiter(maxRequests int, duration time.Duration, banDuration time.Duration) *RateLimiter { + messages := make(chan struct{}) + limiter := defender.New(maxRequests, duration, banDuration) + go limiter.CleanupTask(messages) + return &RateLimiter{ + limiter, + } +} + +// LimitAccess wraps current request with check if remote address does not goes above the defined limits +func (limiter *RateLimiter) LimitAccess(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := StripAddrPort(r.RemoteAddr) + if banned := limiter.Inc(ip); banned == true { + httperror.WriteErrorResponse(w, portainer.ErrResourceAccessDenied, http.StatusForbidden, nil) + return + } + next.ServeHTTP(w, r) + }) +} + +// StripAddrPort removes port from IP address +func StripAddrPort(addr string) string { + portIndex := strings.LastIndex(addr, ":") + if portIndex != -1 { + addr = addr[:portIndex] + } + return addr +} diff --git a/api/http/security/rate_limiter_test.go b/api/http/security/rate_limiter_test.go new file mode 100644 index 000000000..49fc79030 --- /dev/null +++ b/api/http/security/rate_limiter_test.go @@ -0,0 +1,69 @@ +package security + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestLimitAccess(t *testing.T) { + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + t.Run("Request below the limit", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + rateLimiter := NewRateLimiter(10, 1*time.Second, 1*time.Hour) + handler := rateLimiter.LimitAccess(testHandler) + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + }) + + t.Run("Request above the limit", func(t *testing.T) { + rateLimiter := NewRateLimiter(1, 1*time.Second, 1*time.Hour) + handler := rateLimiter.LimitAccess(testHandler) + + ts := httptest.NewServer(handler) + defer ts.Close() + http.Get(ts.URL) + resp, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + if status := resp.StatusCode; status != http.StatusForbidden { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusForbidden) + } + }) +} + +func TestStripAddrPort(t *testing.T) { + t.Run("IP with port", func(t *testing.T) { + result := StripAddrPort("127.0.0.1:1000") + if result != "127.0.0.1" { + t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result) + } + }) + + t.Run("IP without port", func(t *testing.T) { + result := StripAddrPort("127.0.0.1") + if result != "127.0.0.1" { + t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result) + } + }) + + t.Run("Local IP", func(t *testing.T) { + result := StripAddrPort("[::1]:1000") + if result != "[::1]" { + t.Errorf("Expected IP with address to be '[::1]', but it was %s instead", result) + } + }) +} diff --git a/api/http/server.go b/api/http/server.go index d893a5eaf..5e85f8efa 100644 --- a/api/http/server.go +++ b/api/http/server.go @@ -1,6 +1,8 @@ package http import ( + "time" + "github.com/portainer/portainer" "github.com/portainer/portainer/http/handler" "github.com/portainer/portainer/http/handler/extensions" @@ -53,9 +55,10 @@ func (server *Server) Start() error { SignatureService: server.SignatureService, } proxyManager := proxy.NewManager(proxyManagerParameters) + rateLimiter := security.NewRateLimiter(10, 1*time.Second, 1*time.Hour) var fileHandler = handler.NewFileHandler(filepath.Join(server.AssetsPath, "public")) - var authHandler = handler.NewAuthHandler(requestBouncer, server.AuthDisabled) + var authHandler = handler.NewAuthHandler(requestBouncer, rateLimiter, server.AuthDisabled) authHandler.UserService = server.UserService authHandler.CryptoService = server.CryptoService authHandler.JWTService = server.JWTService