From 1e1998e269acc58e28d4dd1cbfc74b0bf36ee525 Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Tue, 1 Jul 2025 21:38:02 -0300 Subject: [PATCH] feat(csrf): add trusted origins cli flags [BE-11972] (#839) Co-authored-by: oscarzhou Co-authored-by: andres-portainer --- api/cli/cli.go | 1 + api/cmd/portainer/main.go | 14 + api/http/csrf/csrf.go | 42 +- .../middlewares/plaintext_http_request.go | 42 +- .../plaintext_http_request_test.go | 173 ++++++ api/http/server.go | 3 +- api/portainer.go | 8 + pkg/validate/validate.go | 111 ++++ pkg/validate/validate_test.go | 500 ++++++++++++++++++ 9 files changed, 885 insertions(+), 9 deletions(-) create mode 100644 api/http/middlewares/plaintext_http_request_test.go create mode 100644 pkg/validate/validate.go create mode 100644 pkg/validate/validate_test.go diff --git a/api/cli/cli.go b/api/cli/cli.go index f6035f298..7f6d80e68 100644 --- a/api/cli/cli.go +++ b/api/cli/cli.go @@ -61,6 +61,7 @@ func CLIFlags() *portainer.CLIFlags { LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"), KubectlShellImage: kingpin.Flag("kubectl-shell-image", "Kubectl shell image").Envar(portainer.KubectlShellImageEnvVar).Default(portainer.DefaultKubectlShellImage).String(), PullLimitCheckDisabled: kingpin.Flag("pull-limit-check-disabled", "Pull limit check").Envar(portainer.PullLimitCheckDisabledEnvVar).Default(defaultPullLimitCheckDisabled).Bool(), + TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(), } } diff --git a/api/cmd/portainer/main.go b/api/cmd/portainer/main.go index bdf2aabb1..b1d3051e9 100644 --- a/api/cmd/portainer/main.go +++ b/api/cmd/portainer/main.go @@ -50,6 +50,7 @@ import ( "github.com/portainer/portainer/pkg/featureflags" "github.com/portainer/portainer/pkg/libhelm" "github.com/portainer/portainer/pkg/libstack/compose" + "github.com/portainer/portainer/pkg/validate" "github.com/gofrs/uuid" "github.com/rs/zerolog/log" @@ -328,6 +329,18 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { featureflags.Parse(*flags.FeatureFlags, portainer.SupportedFeatureFlags) } + trustedOrigins := []string{} + if *flags.TrustedOrigins != "" { + // validate if the trusted origins are valid urls + for _, origin := range strings.Split(*flags.TrustedOrigins, ",") { + if !validate.IsTrustedOrigin(origin) { + log.Fatal().Str("trusted_origin", origin).Msg("invalid url for trusted origin. Please check the trusted origins flag.") + } + + trustedOrigins = append(trustedOrigins, origin) + } + } + fileService := initFileService(*flags.Data) encryptionKey := loadEncryptionSecretKey(*flags.SecretKeyName) if encryptionKey == nil { @@ -576,6 +589,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server { PendingActionsService: pendingActionsService, PlatformService: platformService, PullLimitCheckDisabled: *flags.PullLimitCheckDisabled, + TrustedOrigins: trustedOrigins, } } diff --git a/api/http/csrf/csrf.go b/api/http/csrf/csrf.go index 857d72c8b..6205c9290 100644 --- a/api/http/csrf/csrf.go +++ b/api/http/csrf/csrf.go @@ -2,6 +2,7 @@ package csrf import ( "crypto/rand" + "errors" "fmt" "net/http" "os" @@ -9,7 +10,8 @@ import ( "github.com/portainer/portainer/api/http/security" httperror "github.com/portainer/portainer/pkg/libhttp/error" - gorillacsrf "github.com/gorilla/csrf" + gcsrf "github.com/gorilla/csrf" + "github.com/rs/zerolog/log" "github.com/urfave/negroni" ) @@ -19,7 +21,7 @@ func SkipCSRFToken(w http.ResponseWriter) { w.Header().Set(csrfSkipHeader, "1") } -func WithProtect(handler http.Handler) (http.Handler, error) { +func WithProtect(handler http.Handler, trustedOrigins []string) (http.Handler, error) { // IsDockerDesktopExtension is used to check if we should skip csrf checks in the request bouncer (ShouldSkipCSRFCheck) // DOCKER_EXTENSION is set to '1' in build/docker-extension/docker-compose.yml isDockerDesktopExtension := false @@ -34,10 +36,12 @@ func WithProtect(handler http.Handler) (http.Handler, error) { return nil, fmt.Errorf("failed to generate CSRF token: %w", err) } - handler = gorillacsrf.Protect( + handler = gcsrf.Protect( token, - gorillacsrf.Path("/"), - gorillacsrf.Secure(false), + gcsrf.Path("/"), + gcsrf.Secure(false), + gcsrf.TrustedOrigins(trustedOrigins), + gcsrf.ErrorHandler(withErrorHandler(trustedOrigins)), )(handler) return withSkipCSRF(handler, isDockerDesktopExtension), nil @@ -55,7 +59,7 @@ func withSendCSRFToken(handler http.Handler) http.Handler { } if statusCode := sw.Status(); statusCode >= 200 && statusCode < 300 { - sw.Header().Set("X-CSRF-Token", gorillacsrf.Token(r)) + sw.Header().Set("X-CSRF-Token", gcsrf.Token(r)) } }) @@ -73,9 +77,33 @@ func withSkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Hand } if skip { - r = gorillacsrf.UnsafeSkipCheck(r) + r = gcsrf.UnsafeSkipCheck(r) } handler.ServeHTTP(w, r) }) } + +func withErrorHandler(trustedOrigins []string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := gcsrf.FailureReason(r) + + if errors.Is(err, gcsrf.ErrBadOrigin) || errors.Is(err, gcsrf.ErrBadReferer) || errors.Is(err, gcsrf.ErrNoReferer) { + log.Error().Err(err). + Str("request_url", r.URL.String()). + Str("host", r.Host). + Str("x_forwarded_proto", r.Header.Get("X-Forwarded-Proto")). + Str("forwarded", r.Header.Get("Forwarded")). + Str("origin", r.Header.Get("Origin")). + Str("referer", r.Header.Get("Referer")). + Strs("trusted_origins", trustedOrigins). + Msg("Failed to validate Origin or Referer") + } + + http.Error( + w, + http.StatusText(http.StatusForbidden)+" - "+err.Error(), + http.StatusForbidden, + ) + }) +} diff --git a/api/http/middlewares/plaintext_http_request.go b/api/http/middlewares/plaintext_http_request.go index 668346098..e746fd819 100644 --- a/api/http/middlewares/plaintext_http_request.go +++ b/api/http/middlewares/plaintext_http_request.go @@ -3,6 +3,7 @@ package middlewares import ( "net/http" "slices" + "strings" "github.com/gorilla/csrf" ) @@ -16,6 +17,45 @@ type plainTextHTTPRequestHandler struct { next http.Handler } +// parseForwardedHeaderProto parses the Forwarded header and extracts the protocol. +// The Forwarded header format supports: +// - Single proxy: Forwarded: by=;for=;host=;proto= +// - Multiple proxies: Forwarded: for=192.0.2.43, for=198.51.100.17 +// We take the first (leftmost) entry as it represents the original client +func parseForwardedHeaderProto(forwarded string) string { + if forwarded == "" { + return "" + } + + // Parse the first part (leftmost proxy, closest to original client) + firstPart, _, _ := strings.Cut(forwarded, ",") + firstPart = strings.TrimSpace(firstPart) + + // Split by semicolon to get key-value pairs within this proxy entry + // Format: key=value;key=value;key=value + pairs := strings.Split(firstPart, ";") + for _, pair := range pairs { + // Split by equals sign to separate key and value + key, value, found := strings.Cut(pair, "=") + if !found { + continue + } + + if strings.EqualFold(strings.TrimSpace(key), "proto") { + return strings.Trim(strings.TrimSpace(value), `"'`) + } + } + + return "" +} + +// isHTTPSRequest checks if the original request was made over HTTPS +// by examining both X-Forwarded-Proto and Forwarded headers +func isHTTPSRequest(r *http.Request) bool { + return strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") || + strings.EqualFold(parseForwardedHeaderProto(r.Header.Get("Forwarded")), "https") +} + func (h *plainTextHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if slices.Contains(safeMethods, r.Method) { h.next.ServeHTTP(w, r) @@ -24,7 +64,7 @@ func (h *plainTextHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.R req := r // If original request was HTTPS (via proxy), keep CSRF checks. - if xfproto := r.Header.Get("X-Forwarded-Proto"); xfproto != "https" { + if !isHTTPSRequest(r) { req = csrf.PlaintextHTTPRequest(r) } diff --git a/api/http/middlewares/plaintext_http_request_test.go b/api/http/middlewares/plaintext_http_request_test.go new file mode 100644 index 000000000..33912be80 --- /dev/null +++ b/api/http/middlewares/plaintext_http_request_test.go @@ -0,0 +1,173 @@ +package middlewares + +import ( + "testing" +) + +var tests = []struct { + name string + forwarded string + expected string +}{ + { + name: "empty header", + forwarded: "", + expected: "", + }, + { + name: "single proxy with proto=https", + forwarded: "proto=https", + expected: "https", + }, + { + name: "single proxy with proto=http", + forwarded: "proto=http", + expected: "http", + }, + { + name: "single proxy with multiple directives", + forwarded: "for=192.0.2.60;proto=https;by=203.0.113.43", + expected: "https", + }, + { + name: "single proxy with proto in middle", + forwarded: "for=192.0.2.60;proto=https;host=example.com", + expected: "https", + }, + { + name: "single proxy with proto at end", + forwarded: "for=192.0.2.60;host=example.com;proto=https", + expected: "https", + }, + { + name: "multiple proxies - takes first", + forwarded: "proto=https, proto=http", + expected: "https", + }, + { + name: "multiple proxies with complex format", + forwarded: "for=192.0.2.43;proto=https, for=198.51.100.17;proto=http", + expected: "https", + }, + { + name: "multiple proxies with for directive only", + forwarded: "for=192.0.2.43, for=198.51.100.17", + expected: "", + }, + { + name: "multiple proxies with proto only in second", + forwarded: "for=192.0.2.43, proto=https", + expected: "", + }, + { + name: "multiple proxies with proto only in first", + forwarded: "proto=https, for=198.51.100.17", + expected: "https", + }, + { + name: "quoted protocol value", + forwarded: "proto=\"https\"", + expected: "https", + }, + { + name: "single quoted protocol value", + forwarded: "proto='https'", + expected: "https", + }, + { + name: "mixed case protocol", + forwarded: "proto=HTTPS", + expected: "HTTPS", + }, + { + name: "no proto directive", + forwarded: "for=192.0.2.60;by=203.0.113.43", + expected: "", + }, + { + name: "empty proto value", + forwarded: "proto=", + expected: "", + }, + { + name: "whitespace around values", + forwarded: " proto = https ", + expected: "https", + }, + { + name: "whitespace around semicolons", + forwarded: "for=192.0.2.60 ; proto=https ; by=203.0.113.43", + expected: "https", + }, + { + name: "whitespace around commas", + forwarded: "proto=https , proto=http", + expected: "https", + }, + { + name: "IPv6 address in for directive", + forwarded: "for=\"[2001:db8:cafe::17]:4711\";proto=https", + expected: "https", + }, + { + name: "complex multiple proxies with IPv6", + forwarded: "for=192.0.2.43;proto=https, for=\"[2001:db8:cafe::17]\";proto=http", + expected: "https", + }, + { + name: "obfuscated identifiers", + forwarded: "for=_mdn;proto=https", + expected: "https", + }, + { + name: "unknown identifier", + forwarded: "for=unknown;proto=https", + expected: "https", + }, + { + name: "malformed key-value pair", + forwarded: "proto", + expected: "", + }, + { + name: "malformed key-value pair with equals", + forwarded: "proto=", + expected: "", + }, + { + name: "multiple equals signs", + forwarded: "proto=https=extra", + expected: "https=extra", + }, + { + name: "mixed case directive name", + forwarded: "PROTO=https", + expected: "https", + }, + { + name: "mixed case directive name with spaces", + forwarded: " Proto = https ", + expected: "https", + }, +} + +func TestParseForwardedHeaderProto(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseForwardedHeaderProto(tt.forwarded) + if result != tt.expected { + t.Errorf("parseForwardedHeader(%q) = %q, want %q", tt.forwarded, result, tt.expected) + } + }) + } +} + +func FuzzParseForwardedHeaderProto(f *testing.F) { + for _, t := range tests { + f.Add(t.forwarded) + } + + f.Fuzz(func(t *testing.T, forwarded string) { + parseForwardedHeaderProto(forwarded) + }) +} diff --git a/api/http/server.go b/api/http/server.go index 8f2c52d5c..b79ac455b 100644 --- a/api/http/server.go +++ b/api/http/server.go @@ -113,6 +113,7 @@ type Server struct { PendingActionsService *pendingactions.PendingActionsService PlatformService platform.Service PullLimitCheckDisabled bool + TrustedOrigins []string } // Start starts the HTTP server @@ -339,7 +340,7 @@ func (server *Server) Start() error { handler = middlewares.WithSlowRequestsLogger(handler) - handler, err := csrf.WithProtect(handler) + handler, err := csrf.WithProtect(handler, server.TrustedOrigins) if err != nil { return errors.Wrap(err, "failed to create CSRF middleware") } diff --git a/api/portainer.go b/api/portainer.go index 0cb906f28..c41c2988c 100644 --- a/api/portainer.go +++ b/api/portainer.go @@ -135,6 +135,7 @@ type ( LogMode *string KubectlShellImage *string PullLimitCheckDisabled *bool + TrustedOrigins *string } // CustomTemplateVariableDefinition @@ -1692,6 +1693,13 @@ const ( KubectlShellImageEnvVar = "KUBECTL_SHELL_IMAGE" // PullLimitCheckDisabledEnvVar is the environment variable used to disable the pull limit check PullLimitCheckDisabledEnvVar = "PULL_LIMIT_CHECK_DISABLED" + // LicenseServerBaseURL represents the base URL of the API used to validate + // an extension license. + LicenseServerBaseURL = "https://api.portainer.io" + // URL to validate licenses along with system metadata. + LicenseCheckInURL = LicenseServerBaseURL + "/licenses/checkin" + // TrustedOriginsEnvVar is the environment variable used to set the trusted origins for CSRF protection + TrustedOriginsEnvVar = "TRUSTED_ORIGINS" ) // List of supported features diff --git a/pkg/validate/validate.go b/pkg/validate/validate.go new file mode 100644 index 000000000..8ad69df72 --- /dev/null +++ b/pkg/validate/validate.go @@ -0,0 +1,111 @@ +package validate + +import ( + "net" + "net/url" + "regexp" + "strings" + "unicode/utf8" + + "github.com/google/uuid" +) + +var ( + hexadecimalRegex = regexp.MustCompile(`^[0-9a-fA-F]+$`) + dnsNameRegex = regexp.MustCompile(`^([a-zA-Z0-9_]{1}[a-zA-Z0-9_-]{0,62}){1}(\.[a-zA-Z0-9_]{1}[a-zA-Z0-9_-]{0,62})*[\._]?$`) +) + +func IsURL(urlString string) bool { + if len(urlString) == 0 { + return false + } + + strTemp := urlString + if !strings.Contains(urlString, "://") { + // support no indicated urlscheme + // http:// is appended so url.Parse will succeed + strTemp = "http://" + urlString + } + + u, err := url.Parse(strTemp) + return err == nil && u.Host != "" +} + +func IsUUID(uuidString string) bool { + return uuid.Validate(uuidString) == nil +} + +func IsHexadecimal(hexString string) bool { + return hexadecimalRegex.MatchString(hexString) +} + +func HasWhitespaceOnly(s string) bool { + return len(s) > 0 && strings.TrimSpace(s) == "" +} + +func MinStringLength(s string, len int) bool { + return utf8.RuneCountInString(s) >= len +} + +func Matches(s, pattern string) bool { + match, err := regexp.MatchString(pattern, s) + return err == nil && match +} + +func IsNonPositive(f float64) bool { + return f <= 0 +} + +func InRange(val, left, right float64) bool { + if left > right { + left, right = right, left + } + + return val >= left && val <= right +} + +func IsHost(s string) bool { + return IsIP(s) || IsDNSName(s) +} + +func IsIP(s string) bool { + return net.ParseIP(s) != nil +} + +func IsDNSName(s string) bool { + if s == "" || len(strings.ReplaceAll(s, ".", "")) > 255 { + // constraints already violated + return false + } + + return !IsIP(s) && dnsNameRegex.MatchString(s) +} + +func IsTrustedOrigin(s string) bool { + // Reject if a scheme is present + if strings.Contains(s, "://") { + return false + } + + // Prepend http:// for parsing + strTemp := "http://" + s + parsedOrigin, err := url.Parse(strTemp) + if err != nil { + return false + } + + // Validate host, and ensure no user, path, query, fragment, port, etc. + if parsedOrigin.Host == "" || + parsedOrigin.User != nil || + parsedOrigin.Path != "" || + parsedOrigin.RawQuery != "" || + parsedOrigin.Fragment != "" || + parsedOrigin.Opaque != "" || + parsedOrigin.RawFragment != "" || + parsedOrigin.RawPath != "" || + parsedOrigin.Port() != "" { + return false + } + + return true +} diff --git a/pkg/validate/validate_test.go b/pkg/validate/validate_test.go new file mode 100644 index 000000000..ca054190d --- /dev/null +++ b/pkg/validate/validate_test.go @@ -0,0 +1,500 @@ +package validate + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_IsURL(t *testing.T) { + testCases := []struct { + name string + url string + expectedResult bool + }{ + { + name: "simple url", + url: "https://google.com", + expectedResult: true, + }, + { + name: "empty", + url: "", + expectedResult: false, + }, + { + name: "no schema", + url: "google.com", + expectedResult: true, + }, + { + name: "path", + url: "https://google.com/some/thing", + expectedResult: true, + }, + { + name: "query params", + url: "https://google.com/some/thing?a=5&b=6", + expectedResult: true, + }, + { + name: "no top level domain", + url: "google", + expectedResult: true, + }, + { + name: "Unicode URL", + url: "www.xn--exampe-7db.ai", + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := IsURL(tc.url) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_IsUUID(t *testing.T) { + testCases := []struct { + name string + uuid string + expectedResult bool + }{ + { + name: "empty", + uuid: "", + expectedResult: false, + }, + { + name: "version 3 UUID", + uuid: "060507eb-3b9a-362e-b850-d5f065eea403", + expectedResult: true, + }, + { + name: "version 4 UUID", + uuid: "63e695ee-48a9-498a-98b3-9472ff75e09f", + expectedResult: true, + }, + { + name: "version 5 UUID", + uuid: "5daabcd8-f17e-568c-aa6f-da9d92c7032c", + expectedResult: true, + }, + { + name: "text", + uuid: "something like this", + expectedResult: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := IsUUID(tc.uuid) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_IsHexadecimal(t *testing.T) { + testCases := []struct { + name string + hex string + expectedResult bool + }{ + { + name: "empty", + hex: "", + expectedResult: false, + }, + { + name: "hex", + hex: "48656C6C6F20736F6D657468696E67", + expectedResult: true, + }, + { + name: "text", + hex: "something like this", + expectedResult: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := IsHexadecimal(tc.hex) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_HasWhitespaceOnly(t *testing.T) { + testCases := []struct { + name string + s string + expectedResult bool + }{ + { + name: "empty", + s: "", + expectedResult: false, + }, + { + name: "space", + s: " ", + expectedResult: true, + }, + { + name: "tab", + s: "\t", + expectedResult: true, + }, + { + name: "text", + s: "something like this", + expectedResult: false, + }, + { + name: "all whitespace", + s: "\t\n\v\f\r ", + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := HasWhitespaceOnly(tc.s) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_MinStringLength(t *testing.T) { + testCases := []struct { + name string + s string + len int + expectedResult bool + }{ + { + name: "empty + zero len", + s: "", + len: 0, + expectedResult: true, + }, + { + name: "empty + non zero len", + s: "", + len: 10, + expectedResult: false, + }, + { + name: "long text + non zero len", + s: "something else", + len: 10, + expectedResult: true, + }, + { + name: "multibyte characters - enough", + s: "X生", + len: 2, + expectedResult: true, + }, + { + name: "multibyte characters - not enough", + s: "X生", + len: 3, + expectedResult: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := MinStringLength(tc.s, tc.len) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_Matches(t *testing.T) { + testCases := []struct { + name string + s string + pattern string + expectedResult bool + }{ + { + name: "empty", + s: "", + pattern: "", + expectedResult: true, + }, + { + name: "space", + s: "something else", + pattern: " ", + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := Matches(tc.s, tc.pattern) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_IsNonPositive(t *testing.T) { + testCases := []struct { + name string + f float64 + expectedResult bool + }{ + { + name: "zero", + f: 0, + expectedResult: true, + }, + { + name: "positive", + f: 1, + expectedResult: false, + }, + { + name: "negative", + f: -1, + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := IsNonPositive(tc.f) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_InRange(t *testing.T) { + testCases := []struct { + name string + f float64 + left float64 + right float64 + expectedResult bool + }{ + { + name: "zero", + f: 0, + left: 0, + right: 0, + expectedResult: true, + }, + { + name: "equal left", + f: 1, + left: 1, + right: 2, + expectedResult: true, + }, + { + name: "equal right", + f: 2, + left: 1, + right: 2, + expectedResult: true, + }, + { + name: "above", + f: 3, + left: 1, + right: 2, + expectedResult: false, + }, + { + name: "below", + f: 0, + left: 1, + right: 2, + expectedResult: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := InRange(tc.f, tc.left, tc.right) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_IsHost(t *testing.T) { + testCases := []struct { + name string + s string + expectedResult bool + }{ + { + name: "empty", + s: "", + expectedResult: false, + }, + { + name: "ip address", + s: "192.168.1.1", + expectedResult: true, + }, + { + name: "hostname", + s: "google.com", + expectedResult: true, + }, + { + name: "text", + s: "Something like this", + expectedResult: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := IsHost(tc.s) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_IsIP(t *testing.T) { + testCases := []struct { + name string + s string + expectedResult bool + }{ + { + name: "empty", + s: "", + expectedResult: false, + }, + { + name: "ip address", + s: "192.168.1.1", + expectedResult: true, + }, + { + name: "hostname", + s: "google.com", + expectedResult: false, + }, + { + name: "text", + s: "Something like this", + expectedResult: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := IsIP(tc.s) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_IsDNSName(t *testing.T) { + testCases := []struct { + name string + s string + expectedResult bool + }{ + { + name: "empty", + s: "", + expectedResult: false, + }, + { + name: "ip address", + s: "192.168.1.1", + expectedResult: false, + }, + { + name: "hostname", + s: "google.com", + expectedResult: true, + }, + { + name: "text", + s: "Something like this", + expectedResult: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := IsDNSName(tc.s) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +func Test_IsTrustedOrigin(t *testing.T) { + f := func(s string, expected bool) { + t.Helper() + + result := IsTrustedOrigin(s) + if result != expected { + t.Fatalf("unexpected result for %q; got %t; want %t", s, result, expected) + } + } + + // Valid trusted origins - host only + f("localhost", true) + f("example.com", true) + f("192.168.1.1", true) + f("api.example.com", true) + f("subdomain.example.org", true) + + // Invalid trusted origins - host with port (no longer allowed) + f("localhost:8080", false) + f("example.com:3000", false) + f("192.168.1.1:443", false) + f("api.example.com:9000", false) + + // Invalid trusted origins - empty or malformed + f("", false) + f("invalid url", false) + f("://example.com", false) + + // Invalid trusted origins - with scheme + f("http://example.com", false) + f("https://localhost", false) + f("ftp://192.168.1.1", false) + + // Invalid trusted origins - with user info + f("user@example.com", false) + f("user:pass@localhost", false) + + // Invalid trusted origins - with path + f("example.com/path", false) + f("localhost/api", false) + f("192.168.1.1/static", false) + + // Invalid trusted origins - with query parameters + f("example.com?param=value", false) + f("localhost:8080?query=test", false) + + // Invalid trusted origins - with fragment + f("example.com#fragment", false) + f("localhost:3000#section", false) + + // Invalid trusted origins - with multiple invalid components + f("https://user@example.com/path?query=value#fragment", false) + f("http://localhost:8080/api/v1?param=test", false) + + // Edge cases - ports are no longer allowed + f("example.com:0", false) // port 0 is no longer valid + f("example.com:65535", false) // max port number is no longer valid + f("example.com:99999", false) // invalid port number + f("example.com:-1", false) // negative port +}