From d7794a06b3889ca8d7c973a6cb123cd02283470d Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Wed, 2 Jul 2025 21:00:39 -0300 Subject: [PATCH] feat(csrf): add trusted origins cli flags [BE-11972] (#856) Co-authored-by: oscarzhou Co-authored-by: andres-portainer Co-authored-by: Malcolm Lockyer --- api/cli/cli.go | 1 + api/cmd/portainer/main.go | 14 ++ api/http/csrf/csrf.go | 42 ++++- .../middlewares/plaintext_http_request.go | 42 ++++- .../plaintext_http_request_test.go | 173 ++++++++++++++++++ api/http/server.go | 3 +- api/portainer.go | 3 + pkg/validate/validate.go | 29 +++ pkg/validate/validate_test.go | 61 ++++++ 9 files changed, 359 insertions(+), 9 deletions(-) create mode 100644 api/http/middlewares/plaintext_http_request_test.go diff --git a/api/cli/cli.go b/api/cli/cli.go index f6035f298..7f6d80e68 100644 --- a/api/cli/cli.go +++ b/api/cli/cli.go @@ -61,6 +61,7 @@ func CLIFlags() *portainer.CLIFlags { LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"), KubectlShellImage: kingpin.Flag("kubectl-shell-image", "Kubectl shell image").Envar(portainer.KubectlShellImageEnvVar).Default(portainer.DefaultKubectlShellImage).String(), PullLimitCheckDisabled: kingpin.Flag("pull-limit-check-disabled", "Pull limit check").Envar(portainer.PullLimitCheckDisabledEnvVar).Default(defaultPullLimitCheckDisabled).Bool(), + TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(), } } diff --git a/api/cmd/portainer/main.go b/api/cmd/portainer/main.go index 6261efbd9..90fd34fd2 100644 --- a/api/cmd/portainer/main.go +++ b/api/cmd/portainer/main.go @@ -52,6 +52,7 @@ import ( "github.com/portainer/portainer/pkg/libhelm" libhelmtypes "github.com/portainer/portainer/pkg/libhelm/types" "github.com/portainer/portainer/pkg/libstack/compose" + "github.com/portainer/portainer/pkg/validate" "github.com/gofrs/uuid" "github.com/rs/zerolog/log" @@ -330,6 +331,18 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { featureflags.Parse(*flags.FeatureFlags, portainer.SupportedFeatureFlags) } + trustedOrigins := []string{} + if *flags.TrustedOrigins != "" { + // validate if the trusted origins are valid urls + for _, origin := range strings.Split(*flags.TrustedOrigins, ",") { + if !validate.IsTrustedOrigin(origin) { + log.Fatal().Str("trusted_origin", origin).Msg("invalid url for trusted origin. Please check the trusted origins flag.") + } + + trustedOrigins = append(trustedOrigins, origin) + } + } + fileService := initFileService(*flags.Data) encryptionKey := loadEncryptionSecretKey(*flags.SecretKeyName) if encryptionKey == nil { @@ -578,6 +591,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { PendingActionsService: pendingActionsService, PlatformService: platformService, PullLimitCheckDisabled: *flags.PullLimitCheckDisabled, + TrustedOrigins: trustedOrigins, } } diff --git a/api/http/csrf/csrf.go b/api/http/csrf/csrf.go index 857d72c8b..6205c9290 100644 --- a/api/http/csrf/csrf.go +++ b/api/http/csrf/csrf.go @@ -2,6 +2,7 @@ package csrf import ( "crypto/rand" + "errors" "fmt" "net/http" "os" @@ -9,7 +10,8 @@ import ( "github.com/portainer/portainer/api/http/security" httperror "github.com/portainer/portainer/pkg/libhttp/error" - gorillacsrf "github.com/gorilla/csrf" + gcsrf "github.com/gorilla/csrf" + "github.com/rs/zerolog/log" "github.com/urfave/negroni" ) @@ -19,7 +21,7 @@ func SkipCSRFToken(w http.ResponseWriter) { w.Header().Set(csrfSkipHeader, "1") } -func WithProtect(handler http.Handler) (http.Handler, error) { +func WithProtect(handler http.Handler, trustedOrigins []string) (http.Handler, error) { // IsDockerDesktopExtension is used to check if we should skip csrf checks in the request bouncer (ShouldSkipCSRFCheck) // DOCKER_EXTENSION is set to '1' in build/docker-extension/docker-compose.yml isDockerDesktopExtension := false @@ -34,10 +36,12 @@ func WithProtect(handler http.Handler) (http.Handler, error) { return nil, fmt.Errorf("failed to generate CSRF token: %w", err) } - handler = gorillacsrf.Protect( + handler = gcsrf.Protect( token, - gorillacsrf.Path("/"), - gorillacsrf.Secure(false), + gcsrf.Path("/"), + gcsrf.Secure(false), + gcsrf.TrustedOrigins(trustedOrigins), + gcsrf.ErrorHandler(withErrorHandler(trustedOrigins)), )(handler) return withSkipCSRF(handler, isDockerDesktopExtension), nil @@ -55,7 +59,7 @@ func withSendCSRFToken(handler http.Handler) http.Handler { } if statusCode := sw.Status(); statusCode >= 200 && statusCode < 300 { - sw.Header().Set("X-CSRF-Token", gorillacsrf.Token(r)) + sw.Header().Set("X-CSRF-Token", gcsrf.Token(r)) } }) @@ -73,9 +77,33 @@ func withSkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Hand } if skip { - r = gorillacsrf.UnsafeSkipCheck(r) + r = gcsrf.UnsafeSkipCheck(r) } handler.ServeHTTP(w, r) }) } + +func withErrorHandler(trustedOrigins []string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := gcsrf.FailureReason(r) + + if errors.Is(err, gcsrf.ErrBadOrigin) || errors.Is(err, gcsrf.ErrBadReferer) || errors.Is(err, gcsrf.ErrNoReferer) { + log.Error().Err(err). + Str("request_url", r.URL.String()). + Str("host", r.Host). + Str("x_forwarded_proto", r.Header.Get("X-Forwarded-Proto")). + Str("forwarded", r.Header.Get("Forwarded")). + Str("origin", r.Header.Get("Origin")). + Str("referer", r.Header.Get("Referer")). + Strs("trusted_origins", trustedOrigins). + Msg("Failed to validate Origin or Referer") + } + + http.Error( + w, + http.StatusText(http.StatusForbidden)+" - "+err.Error(), + http.StatusForbidden, + ) + }) +} diff --git a/api/http/middlewares/plaintext_http_request.go b/api/http/middlewares/plaintext_http_request.go index 668346098..e746fd819 100644 --- a/api/http/middlewares/plaintext_http_request.go +++ b/api/http/middlewares/plaintext_http_request.go @@ -3,6 +3,7 @@ package middlewares import ( "net/http" "slices" + "strings" "github.com/gorilla/csrf" ) @@ -16,6 +17,45 @@ type plainTextHTTPRequestHandler struct { next http.Handler } +// parseForwardedHeaderProto parses the Forwarded header and extracts the protocol. +// The Forwarded header format supports: +// - Single proxy: Forwarded: by=;for=;host=;proto= +// - Multiple proxies: Forwarded: for=192.0.2.43, for=198.51.100.17 +// We take the first (leftmost) entry as it represents the original client +func parseForwardedHeaderProto(forwarded string) string { + if forwarded == "" { + return "" + } + + // Parse the first part (leftmost proxy, closest to original client) + firstPart, _, _ := strings.Cut(forwarded, ",") + firstPart = strings.TrimSpace(firstPart) + + // Split by semicolon to get key-value pairs within this proxy entry + // Format: key=value;key=value;key=value + pairs := strings.Split(firstPart, ";") + for _, pair := range pairs { + // Split by equals sign to separate key and value + key, value, found := strings.Cut(pair, "=") + if !found { + continue + } + + if strings.EqualFold(strings.TrimSpace(key), "proto") { + return strings.Trim(strings.TrimSpace(value), `"'`) + } + } + + return "" +} + +// isHTTPSRequest checks if the original request was made over HTTPS +// by examining both X-Forwarded-Proto and Forwarded headers +func isHTTPSRequest(r *http.Request) bool { + return strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") || + strings.EqualFold(parseForwardedHeaderProto(r.Header.Get("Forwarded")), "https") +} + func (h *plainTextHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if slices.Contains(safeMethods, r.Method) { h.next.ServeHTTP(w, r) @@ -24,7 +64,7 @@ func (h *plainTextHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.R req := r // If original request was HTTPS (via proxy), keep CSRF checks. - if xfproto := r.Header.Get("X-Forwarded-Proto"); xfproto != "https" { + if !isHTTPSRequest(r) { req = csrf.PlaintextHTTPRequest(r) } diff --git a/api/http/middlewares/plaintext_http_request_test.go b/api/http/middlewares/plaintext_http_request_test.go new file mode 100644 index 000000000..33912be80 --- /dev/null +++ b/api/http/middlewares/plaintext_http_request_test.go @@ -0,0 +1,173 @@ +package middlewares + +import ( + "testing" +) + +var tests = []struct { + name string + forwarded string + expected string +}{ + { + name: "empty header", + forwarded: "", + expected: "", + }, + { + name: "single proxy with proto=https", + forwarded: "proto=https", + expected: "https", + }, + { + name: "single proxy with proto=http", + forwarded: "proto=http", + expected: "http", + }, + { + name: "single proxy with multiple directives", + forwarded: "for=192.0.2.60;proto=https;by=203.0.113.43", + expected: "https", + }, + { + name: "single proxy with proto in middle", + forwarded: "for=192.0.2.60;proto=https;host=example.com", + expected: "https", + }, + { + name: "single proxy with proto at end", + forwarded: "for=192.0.2.60;host=example.com;proto=https", + expected: "https", + }, + { + name: "multiple proxies - takes first", + forwarded: "proto=https, proto=http", + expected: "https", + }, + { + name: "multiple proxies with complex format", + forwarded: "for=192.0.2.43;proto=https, for=198.51.100.17;proto=http", + expected: "https", + }, + { + name: "multiple proxies with for directive only", + forwarded: "for=192.0.2.43, for=198.51.100.17", + expected: "", + }, + { + name: "multiple proxies with proto only in second", + forwarded: "for=192.0.2.43, proto=https", + expected: "", + }, + { + name: "multiple proxies with proto only in first", + forwarded: "proto=https, for=198.51.100.17", + expected: "https", + }, + { + name: "quoted protocol value", + forwarded: "proto=\"https\"", + expected: "https", + }, + { + name: "single quoted protocol value", + forwarded: "proto='https'", + expected: "https", + }, + { + name: "mixed case protocol", + forwarded: "proto=HTTPS", + expected: "HTTPS", + }, + { + name: "no proto directive", + forwarded: "for=192.0.2.60;by=203.0.113.43", + expected: "", + }, + { + name: "empty proto value", + forwarded: "proto=", + expected: "", + }, + { + name: "whitespace around values", + forwarded: " proto = https ", + expected: "https", + }, + { + name: "whitespace around semicolons", + forwarded: "for=192.0.2.60 ; proto=https ; by=203.0.113.43", + expected: "https", + }, + { + name: "whitespace around commas", + forwarded: "proto=https , proto=http", + expected: "https", + }, + { + name: "IPv6 address in for directive", + forwarded: "for=\"[2001:db8:cafe::17]:4711\";proto=https", + expected: "https", + }, + { + name: "complex multiple proxies with IPv6", + forwarded: "for=192.0.2.43;proto=https, for=\"[2001:db8:cafe::17]\";proto=http", + expected: "https", + }, + { + name: "obfuscated identifiers", + forwarded: "for=_mdn;proto=https", + expected: "https", + }, + { + name: "unknown identifier", + forwarded: "for=unknown;proto=https", + expected: "https", + }, + { + name: "malformed key-value pair", + forwarded: "proto", + expected: "", + }, + { + name: "malformed key-value pair with equals", + forwarded: "proto=", + expected: "", + }, + { + name: "multiple equals signs", + forwarded: "proto=https=extra", + expected: "https=extra", + }, + { + name: "mixed case directive name", + forwarded: "PROTO=https", + expected: "https", + }, + { + name: "mixed case directive name with spaces", + forwarded: " Proto = https ", + expected: "https", + }, +} + +func TestParseForwardedHeaderProto(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseForwardedHeaderProto(tt.forwarded) + if result != tt.expected { + t.Errorf("parseForwardedHeader(%q) = %q, want %q", tt.forwarded, result, tt.expected) + } + }) + } +} + +func FuzzParseForwardedHeaderProto(f *testing.F) { + for _, t := range tests { + f.Add(t.forwarded) + } + + f.Fuzz(func(t *testing.T, forwarded string) { + parseForwardedHeaderProto(forwarded) + }) +} diff --git a/api/http/server.go b/api/http/server.go index 88d131650..e876ca5d2 100644 --- a/api/http/server.go +++ b/api/http/server.go @@ -113,6 +113,7 @@ type Server struct { PendingActionsService *pendingactions.PendingActionsService PlatformService platform.Service PullLimitCheckDisabled bool + TrustedOrigins []string } // Start starts the HTTP server @@ -336,7 +337,7 @@ func (server *Server) Start() error { handler = middlewares.WithPanicLogger(middlewares.WithSlowRequestsLogger(handler)) - handler, err := csrf.WithProtect(handler) + handler, err := csrf.WithProtect(handler, server.TrustedOrigins) if err != nil { return errors.Wrap(err, "failed to create CSRF middleware") } diff --git a/api/portainer.go b/api/portainer.go index 2b11fb380..ed308a271 100644 --- a/api/portainer.go +++ b/api/portainer.go @@ -139,6 +139,7 @@ type ( LogMode *string KubectlShellImage *string PullLimitCheckDisabled *bool + TrustedOrigins *string } // CustomTemplateVariableDefinition @@ -1787,6 +1788,8 @@ const ( LicenseServerBaseURL = "https://api.portainer.io" // URL to validate licenses along with system metadata. LicenseCheckInURL = LicenseServerBaseURL + "/licenses/checkin" + // TrustedOriginsEnvVar is the environment variable used to set the trusted origins for CSRF protection + TrustedOriginsEnvVar = "TRUSTED_ORIGINS" ) // List of supported features diff --git a/pkg/validate/validate.go b/pkg/validate/validate.go index f647d1e17..8ad69df72 100644 --- a/pkg/validate/validate.go +++ b/pkg/validate/validate.go @@ -80,3 +80,32 @@ func IsDNSName(s string) bool { return !IsIP(s) && dnsNameRegex.MatchString(s) } + +func IsTrustedOrigin(s string) bool { + // Reject if a scheme is present + if strings.Contains(s, "://") { + return false + } + + // Prepend http:// for parsing + strTemp := "http://" + s + parsedOrigin, err := url.Parse(strTemp) + if err != nil { + return false + } + + // Validate host, and ensure no user, path, query, fragment, port, etc. + if parsedOrigin.Host == "" || + parsedOrigin.User != nil || + parsedOrigin.Path != "" || + parsedOrigin.RawQuery != "" || + parsedOrigin.Fragment != "" || + parsedOrigin.Opaque != "" || + parsedOrigin.RawFragment != "" || + parsedOrigin.RawPath != "" || + parsedOrigin.Port() != "" { + return false + } + + return true +} diff --git a/pkg/validate/validate_test.go b/pkg/validate/validate_test.go index f3cb6a01c..ca054190d 100644 --- a/pkg/validate/validate_test.go +++ b/pkg/validate/validate_test.go @@ -437,3 +437,64 @@ func Test_IsDNSName(t *testing.T) { }) } } + +func Test_IsTrustedOrigin(t *testing.T) { + f := func(s string, expected bool) { + t.Helper() + + result := IsTrustedOrigin(s) + if result != expected { + t.Fatalf("unexpected result for %q; got %t; want %t", s, result, expected) + } + } + + // Valid trusted origins - host only + f("localhost", true) + f("example.com", true) + f("192.168.1.1", true) + f("api.example.com", true) + f("subdomain.example.org", true) + + // Invalid trusted origins - host with port (no longer allowed) + f("localhost:8080", false) + f("example.com:3000", false) + f("192.168.1.1:443", false) + f("api.example.com:9000", false) + + // Invalid trusted origins - empty or malformed + f("", false) + f("invalid url", false) + f("://example.com", false) + + // Invalid trusted origins - with scheme + f("http://example.com", false) + f("https://localhost", false) + f("ftp://192.168.1.1", false) + + // Invalid trusted origins - with user info + f("user@example.com", false) + f("user:pass@localhost", false) + + // Invalid trusted origins - with path + f("example.com/path", false) + f("localhost/api", false) + f("192.168.1.1/static", false) + + // Invalid trusted origins - with query parameters + f("example.com?param=value", false) + f("localhost:8080?query=test", false) + + // Invalid trusted origins - with fragment + f("example.com#fragment", false) + f("localhost:3000#section", false) + + // Invalid trusted origins - with multiple invalid components + f("https://user@example.com/path?query=value#fragment", false) + f("http://localhost:8080/api/v1?param=test", false) + + // Edge cases - ports are no longer allowed + f("example.com:0", false) // port 0 is no longer valid + f("example.com:65535", false) // max port number is no longer valid + f("example.com:99999", false) // invalid port number + f("example.com:-1", false) // negative port +}