mirror of
https://github.com/portainer/portainer.git
synced 2025-07-20 13:59:40 +02:00
feat(csrf): add trusted origins cli flags [BE-11972] (#839)
Co-authored-by: oscarzhou <oscar.zhou@portainer.io> Co-authored-by: andres-portainer <andres-portainer@users.noreply.github.com>
This commit is contained in:
parent
973c99dcf4
commit
1e1998e269
9 changed files with 885 additions and 9 deletions
|
@ -61,6 +61,7 @@ func CLIFlags() *portainer.CLIFlags {
|
||||||
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
|
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
|
||||||
KubectlShellImage: kingpin.Flag("kubectl-shell-image", "Kubectl shell image").Envar(portainer.KubectlShellImageEnvVar).Default(portainer.DefaultKubectlShellImage).String(),
|
KubectlShellImage: kingpin.Flag("kubectl-shell-image", "Kubectl shell image").Envar(portainer.KubectlShellImageEnvVar).Default(portainer.DefaultKubectlShellImage).String(),
|
||||||
PullLimitCheckDisabled: kingpin.Flag("pull-limit-check-disabled", "Pull limit check").Envar(portainer.PullLimitCheckDisabledEnvVar).Default(defaultPullLimitCheckDisabled).Bool(),
|
PullLimitCheckDisabled: kingpin.Flag("pull-limit-check-disabled", "Pull limit check").Envar(portainer.PullLimitCheckDisabledEnvVar).Default(defaultPullLimitCheckDisabled).Bool(),
|
||||||
|
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -50,6 +50,7 @@ import (
|
||||||
"github.com/portainer/portainer/pkg/featureflags"
|
"github.com/portainer/portainer/pkg/featureflags"
|
||||||
"github.com/portainer/portainer/pkg/libhelm"
|
"github.com/portainer/portainer/pkg/libhelm"
|
||||||
"github.com/portainer/portainer/pkg/libstack/compose"
|
"github.com/portainer/portainer/pkg/libstack/compose"
|
||||||
|
"github.com/portainer/portainer/pkg/validate"
|
||||||
|
|
||||||
"github.com/gofrs/uuid"
|
"github.com/gofrs/uuid"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -328,6 +329,18 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||||
featureflags.Parse(*flags.FeatureFlags, portainer.SupportedFeatureFlags)
|
featureflags.Parse(*flags.FeatureFlags, portainer.SupportedFeatureFlags)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trustedOrigins := []string{}
|
||||||
|
if *flags.TrustedOrigins != "" {
|
||||||
|
// validate if the trusted origins are valid urls
|
||||||
|
for _, origin := range strings.Split(*flags.TrustedOrigins, ",") {
|
||||||
|
if !validate.IsTrustedOrigin(origin) {
|
||||||
|
log.Fatal().Str("trusted_origin", origin).Msg("invalid url for trusted origin. Please check the trusted origins flag.")
|
||||||
|
}
|
||||||
|
|
||||||
|
trustedOrigins = append(trustedOrigins, origin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fileService := initFileService(*flags.Data)
|
fileService := initFileService(*flags.Data)
|
||||||
encryptionKey := loadEncryptionSecretKey(*flags.SecretKeyName)
|
encryptionKey := loadEncryptionSecretKey(*flags.SecretKeyName)
|
||||||
if encryptionKey == nil {
|
if encryptionKey == nil {
|
||||||
|
@ -576,6 +589,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||||
PendingActionsService: pendingActionsService,
|
PendingActionsService: pendingActionsService,
|
||||||
PlatformService: platformService,
|
PlatformService: platformService,
|
||||||
PullLimitCheckDisabled: *flags.PullLimitCheckDisabled,
|
PullLimitCheckDisabled: *flags.PullLimitCheckDisabled,
|
||||||
|
TrustedOrigins: trustedOrigins,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package csrf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -9,7 +10,8 @@ import (
|
||||||
"github.com/portainer/portainer/api/http/security"
|
"github.com/portainer/portainer/api/http/security"
|
||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
|
|
||||||
gorillacsrf "github.com/gorilla/csrf"
|
gcsrf "github.com/gorilla/csrf"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/urfave/negroni"
|
"github.com/urfave/negroni"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,7 +21,7 @@ func SkipCSRFToken(w http.ResponseWriter) {
|
||||||
w.Header().Set(csrfSkipHeader, "1")
|
w.Header().Set(csrfSkipHeader, "1")
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithProtect(handler http.Handler) (http.Handler, error) {
|
func WithProtect(handler http.Handler, trustedOrigins []string) (http.Handler, error) {
|
||||||
// IsDockerDesktopExtension is used to check if we should skip csrf checks in the request bouncer (ShouldSkipCSRFCheck)
|
// IsDockerDesktopExtension is used to check if we should skip csrf checks in the request bouncer (ShouldSkipCSRFCheck)
|
||||||
// DOCKER_EXTENSION is set to '1' in build/docker-extension/docker-compose.yml
|
// DOCKER_EXTENSION is set to '1' in build/docker-extension/docker-compose.yml
|
||||||
isDockerDesktopExtension := false
|
isDockerDesktopExtension := false
|
||||||
|
@ -34,10 +36,12 @@ func WithProtect(handler http.Handler) (http.Handler, error) {
|
||||||
return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
|
return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler = gorillacsrf.Protect(
|
handler = gcsrf.Protect(
|
||||||
token,
|
token,
|
||||||
gorillacsrf.Path("/"),
|
gcsrf.Path("/"),
|
||||||
gorillacsrf.Secure(false),
|
gcsrf.Secure(false),
|
||||||
|
gcsrf.TrustedOrigins(trustedOrigins),
|
||||||
|
gcsrf.ErrorHandler(withErrorHandler(trustedOrigins)),
|
||||||
)(handler)
|
)(handler)
|
||||||
|
|
||||||
return withSkipCSRF(handler, isDockerDesktopExtension), nil
|
return withSkipCSRF(handler, isDockerDesktopExtension), nil
|
||||||
|
@ -55,7 +59,7 @@ func withSendCSRFToken(handler http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
if statusCode := sw.Status(); statusCode >= 200 && statusCode < 300 {
|
if statusCode := sw.Status(); statusCode >= 200 && statusCode < 300 {
|
||||||
sw.Header().Set("X-CSRF-Token", gorillacsrf.Token(r))
|
sw.Header().Set("X-CSRF-Token", gcsrf.Token(r))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -73,9 +77,33 @@ func withSkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Hand
|
||||||
}
|
}
|
||||||
|
|
||||||
if skip {
|
if skip {
|
||||||
r = gorillacsrf.UnsafeSkipCheck(r)
|
r = gcsrf.UnsafeSkipCheck(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func withErrorHandler(trustedOrigins []string) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := gcsrf.FailureReason(r)
|
||||||
|
|
||||||
|
if errors.Is(err, gcsrf.ErrBadOrigin) || errors.Is(err, gcsrf.ErrBadReferer) || errors.Is(err, gcsrf.ErrNoReferer) {
|
||||||
|
log.Error().Err(err).
|
||||||
|
Str("request_url", r.URL.String()).
|
||||||
|
Str("host", r.Host).
|
||||||
|
Str("x_forwarded_proto", r.Header.Get("X-Forwarded-Proto")).
|
||||||
|
Str("forwarded", r.Header.Get("Forwarded")).
|
||||||
|
Str("origin", r.Header.Get("Origin")).
|
||||||
|
Str("referer", r.Header.Get("Referer")).
|
||||||
|
Strs("trusted_origins", trustedOrigins).
|
||||||
|
Msg("Failed to validate Origin or Referer")
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Error(
|
||||||
|
w,
|
||||||
|
http.StatusText(http.StatusForbidden)+" - "+err.Error(),
|
||||||
|
http.StatusForbidden,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package middlewares
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gorilla/csrf"
|
"github.com/gorilla/csrf"
|
||||||
)
|
)
|
||||||
|
@ -16,6 +17,45 @@ type plainTextHTTPRequestHandler struct {
|
||||||
next http.Handler
|
next http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseForwardedHeaderProto parses the Forwarded header and extracts the protocol.
|
||||||
|
// The Forwarded header format supports:
|
||||||
|
// - Single proxy: Forwarded: by=<identifier>;for=<identifier>;host=<host>;proto=<http|https>
|
||||||
|
// - Multiple proxies: Forwarded: for=192.0.2.43, for=198.51.100.17
|
||||||
|
// We take the first (leftmost) entry as it represents the original client
|
||||||
|
func parseForwardedHeaderProto(forwarded string) string {
|
||||||
|
if forwarded == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the first part (leftmost proxy, closest to original client)
|
||||||
|
firstPart, _, _ := strings.Cut(forwarded, ",")
|
||||||
|
firstPart = strings.TrimSpace(firstPart)
|
||||||
|
|
||||||
|
// Split by semicolon to get key-value pairs within this proxy entry
|
||||||
|
// Format: key=value;key=value;key=value
|
||||||
|
pairs := strings.Split(firstPart, ";")
|
||||||
|
for _, pair := range pairs {
|
||||||
|
// Split by equals sign to separate key and value
|
||||||
|
key, value, found := strings.Cut(pair, "=")
|
||||||
|
if !found {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.EqualFold(strings.TrimSpace(key), "proto") {
|
||||||
|
return strings.Trim(strings.TrimSpace(value), `"'`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// isHTTPSRequest checks if the original request was made over HTTPS
|
||||||
|
// by examining both X-Forwarded-Proto and Forwarded headers
|
||||||
|
func isHTTPSRequest(r *http.Request) bool {
|
||||||
|
return strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") ||
|
||||||
|
strings.EqualFold(parseForwardedHeaderProto(r.Header.Get("Forwarded")), "https")
|
||||||
|
}
|
||||||
|
|
||||||
func (h *plainTextHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *plainTextHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if slices.Contains(safeMethods, r.Method) {
|
if slices.Contains(safeMethods, r.Method) {
|
||||||
h.next.ServeHTTP(w, r)
|
h.next.ServeHTTP(w, r)
|
||||||
|
@ -24,7 +64,7 @@ func (h *plainTextHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.R
|
||||||
|
|
||||||
req := r
|
req := r
|
||||||
// If original request was HTTPS (via proxy), keep CSRF checks.
|
// If original request was HTTPS (via proxy), keep CSRF checks.
|
||||||
if xfproto := r.Header.Get("X-Forwarded-Proto"); xfproto != "https" {
|
if !isHTTPSRequest(r) {
|
||||||
req = csrf.PlaintextHTTPRequest(r)
|
req = csrf.PlaintextHTTPRequest(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
173
api/http/middlewares/plaintext_http_request_test.go
Normal file
173
api/http/middlewares/plaintext_http_request_test.go
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
forwarded string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty header",
|
||||||
|
forwarded: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single proxy with proto=https",
|
||||||
|
forwarded: "proto=https",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single proxy with proto=http",
|
||||||
|
forwarded: "proto=http",
|
||||||
|
expected: "http",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single proxy with multiple directives",
|
||||||
|
forwarded: "for=192.0.2.60;proto=https;by=203.0.113.43",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single proxy with proto in middle",
|
||||||
|
forwarded: "for=192.0.2.60;proto=https;host=example.com",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single proxy with proto at end",
|
||||||
|
forwarded: "for=192.0.2.60;host=example.com;proto=https",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple proxies - takes first",
|
||||||
|
forwarded: "proto=https, proto=http",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple proxies with complex format",
|
||||||
|
forwarded: "for=192.0.2.43;proto=https, for=198.51.100.17;proto=http",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple proxies with for directive only",
|
||||||
|
forwarded: "for=192.0.2.43, for=198.51.100.17",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple proxies with proto only in second",
|
||||||
|
forwarded: "for=192.0.2.43, proto=https",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple proxies with proto only in first",
|
||||||
|
forwarded: "proto=https, for=198.51.100.17",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "quoted protocol value",
|
||||||
|
forwarded: "proto=\"https\"",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single quoted protocol value",
|
||||||
|
forwarded: "proto='https'",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case protocol",
|
||||||
|
forwarded: "proto=HTTPS",
|
||||||
|
expected: "HTTPS",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no proto directive",
|
||||||
|
forwarded: "for=192.0.2.60;by=203.0.113.43",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty proto value",
|
||||||
|
forwarded: "proto=",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace around values",
|
||||||
|
forwarded: " proto = https ",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace around semicolons",
|
||||||
|
forwarded: "for=192.0.2.60 ; proto=https ; by=203.0.113.43",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace around commas",
|
||||||
|
forwarded: "proto=https , proto=http",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address in for directive",
|
||||||
|
forwarded: "for=\"[2001:db8:cafe::17]:4711\";proto=https",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex multiple proxies with IPv6",
|
||||||
|
forwarded: "for=192.0.2.43;proto=https, for=\"[2001:db8:cafe::17]\";proto=http",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "obfuscated identifiers",
|
||||||
|
forwarded: "for=_mdn;proto=https",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown identifier",
|
||||||
|
forwarded: "for=unknown;proto=https",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed key-value pair",
|
||||||
|
forwarded: "proto",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed key-value pair with equals",
|
||||||
|
forwarded: "proto=",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple equals signs",
|
||||||
|
forwarded: "proto=https=extra",
|
||||||
|
expected: "https=extra",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case directive name",
|
||||||
|
forwarded: "PROTO=https",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case directive name with spaces",
|
||||||
|
forwarded: " Proto = https ",
|
||||||
|
expected: "https",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseForwardedHeaderProto(t *testing.T) {
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := parseForwardedHeaderProto(tt.forwarded)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("parseForwardedHeader(%q) = %q, want %q", tt.forwarded, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FuzzParseForwardedHeaderProto(f *testing.F) {
|
||||||
|
for _, t := range tests {
|
||||||
|
f.Add(t.forwarded)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.Fuzz(func(t *testing.T, forwarded string) {
|
||||||
|
parseForwardedHeaderProto(forwarded)
|
||||||
|
})
|
||||||
|
}
|
|
@ -113,6 +113,7 @@ type Server struct {
|
||||||
PendingActionsService *pendingactions.PendingActionsService
|
PendingActionsService *pendingactions.PendingActionsService
|
||||||
PlatformService platform.Service
|
PlatformService platform.Service
|
||||||
PullLimitCheckDisabled bool
|
PullLimitCheckDisabled bool
|
||||||
|
TrustedOrigins []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the HTTP server
|
// Start starts the HTTP server
|
||||||
|
@ -339,7 +340,7 @@ func (server *Server) Start() error {
|
||||||
|
|
||||||
handler = middlewares.WithSlowRequestsLogger(handler)
|
handler = middlewares.WithSlowRequestsLogger(handler)
|
||||||
|
|
||||||
handler, err := csrf.WithProtect(handler)
|
handler, err := csrf.WithProtect(handler, server.TrustedOrigins)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to create CSRF middleware")
|
return errors.Wrap(err, "failed to create CSRF middleware")
|
||||||
}
|
}
|
||||||
|
|
|
@ -135,6 +135,7 @@ type (
|
||||||
LogMode *string
|
LogMode *string
|
||||||
KubectlShellImage *string
|
KubectlShellImage *string
|
||||||
PullLimitCheckDisabled *bool
|
PullLimitCheckDisabled *bool
|
||||||
|
TrustedOrigins *string
|
||||||
}
|
}
|
||||||
|
|
||||||
// CustomTemplateVariableDefinition
|
// CustomTemplateVariableDefinition
|
||||||
|
@ -1692,6 +1693,13 @@ const (
|
||||||
KubectlShellImageEnvVar = "KUBECTL_SHELL_IMAGE"
|
KubectlShellImageEnvVar = "KUBECTL_SHELL_IMAGE"
|
||||||
// PullLimitCheckDisabledEnvVar is the environment variable used to disable the pull limit check
|
// PullLimitCheckDisabledEnvVar is the environment variable used to disable the pull limit check
|
||||||
PullLimitCheckDisabledEnvVar = "PULL_LIMIT_CHECK_DISABLED"
|
PullLimitCheckDisabledEnvVar = "PULL_LIMIT_CHECK_DISABLED"
|
||||||
|
// LicenseServerBaseURL represents the base URL of the API used to validate
|
||||||
|
// an extension license.
|
||||||
|
LicenseServerBaseURL = "https://api.portainer.io"
|
||||||
|
// URL to validate licenses along with system metadata.
|
||||||
|
LicenseCheckInURL = LicenseServerBaseURL + "/licenses/checkin"
|
||||||
|
// TrustedOriginsEnvVar is the environment variable used to set the trusted origins for CSRF protection
|
||||||
|
TrustedOriginsEnvVar = "TRUSTED_ORIGINS"
|
||||||
)
|
)
|
||||||
|
|
||||||
// List of supported features
|
// List of supported features
|
||||||
|
|
111
pkg/validate/validate.go
Normal file
111
pkg/validate/validate.go
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
package validate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
hexadecimalRegex = regexp.MustCompile(`^[0-9a-fA-F]+$`)
|
||||||
|
dnsNameRegex = regexp.MustCompile(`^([a-zA-Z0-9_]{1}[a-zA-Z0-9_-]{0,62}){1}(\.[a-zA-Z0-9_]{1}[a-zA-Z0-9_-]{0,62})*[\._]?$`)
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsURL(urlString string) bool {
|
||||||
|
if len(urlString) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
strTemp := urlString
|
||||||
|
if !strings.Contains(urlString, "://") {
|
||||||
|
// support no indicated urlscheme
|
||||||
|
// http:// is appended so url.Parse will succeed
|
||||||
|
strTemp = "http://" + urlString
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(strTemp)
|
||||||
|
return err == nil && u.Host != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsUUID(uuidString string) bool {
|
||||||
|
return uuid.Validate(uuidString) == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsHexadecimal(hexString string) bool {
|
||||||
|
return hexadecimalRegex.MatchString(hexString)
|
||||||
|
}
|
||||||
|
|
||||||
|
func HasWhitespaceOnly(s string) bool {
|
||||||
|
return len(s) > 0 && strings.TrimSpace(s) == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func MinStringLength(s string, len int) bool {
|
||||||
|
return utf8.RuneCountInString(s) >= len
|
||||||
|
}
|
||||||
|
|
||||||
|
func Matches(s, pattern string) bool {
|
||||||
|
match, err := regexp.MatchString(pattern, s)
|
||||||
|
return err == nil && match
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsNonPositive(f float64) bool {
|
||||||
|
return f <= 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func InRange(val, left, right float64) bool {
|
||||||
|
if left > right {
|
||||||
|
left, right = right, left
|
||||||
|
}
|
||||||
|
|
||||||
|
return val >= left && val <= right
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsHost(s string) bool {
|
||||||
|
return IsIP(s) || IsDNSName(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsIP(s string) bool {
|
||||||
|
return net.ParseIP(s) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsDNSName(s string) bool {
|
||||||
|
if s == "" || len(strings.ReplaceAll(s, ".", "")) > 255 {
|
||||||
|
// constraints already violated
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return !IsIP(s) && dnsNameRegex.MatchString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsTrustedOrigin(s string) bool {
|
||||||
|
// Reject if a scheme is present
|
||||||
|
if strings.Contains(s, "://") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepend http:// for parsing
|
||||||
|
strTemp := "http://" + s
|
||||||
|
parsedOrigin, err := url.Parse(strTemp)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate host, and ensure no user, path, query, fragment, port, etc.
|
||||||
|
if parsedOrigin.Host == "" ||
|
||||||
|
parsedOrigin.User != nil ||
|
||||||
|
parsedOrigin.Path != "" ||
|
||||||
|
parsedOrigin.RawQuery != "" ||
|
||||||
|
parsedOrigin.Fragment != "" ||
|
||||||
|
parsedOrigin.Opaque != "" ||
|
||||||
|
parsedOrigin.RawFragment != "" ||
|
||||||
|
parsedOrigin.RawPath != "" ||
|
||||||
|
parsedOrigin.Port() != "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
500
pkg/validate/validate_test.go
Normal file
500
pkg/validate/validate_test.go
Normal file
|
@ -0,0 +1,500 @@
|
||||||
|
package validate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_IsURL(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
url string
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple url",
|
||||||
|
url: "https://google.com",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
url: "",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no schema",
|
||||||
|
url: "google.com",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path",
|
||||||
|
url: "https://google.com/some/thing",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "query params",
|
||||||
|
url: "https://google.com/some/thing?a=5&b=6",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no top level domain",
|
||||||
|
url: "google",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unicode URL",
|
||||||
|
url: "www.xn--exampe-7db.ai",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := IsURL(tc.url)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_IsUUID(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
uuid string
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
uuid: "",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "version 3 UUID",
|
||||||
|
uuid: "060507eb-3b9a-362e-b850-d5f065eea403",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "version 4 UUID",
|
||||||
|
uuid: "63e695ee-48a9-498a-98b3-9472ff75e09f",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "version 5 UUID",
|
||||||
|
uuid: "5daabcd8-f17e-568c-aa6f-da9d92c7032c",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text",
|
||||||
|
uuid: "something like this",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := IsUUID(tc.uuid)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_IsHexadecimal(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
hex string
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
hex: "",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hex",
|
||||||
|
hex: "48656C6C6F20736F6D657468696E67",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text",
|
||||||
|
hex: "something like this",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := IsHexadecimal(tc.hex)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_HasWhitespaceOnly(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
s string
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
s: "",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "space",
|
||||||
|
s: " ",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tab",
|
||||||
|
s: "\t",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text",
|
||||||
|
s: "something like this",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all whitespace",
|
||||||
|
s: "\t\n\v\f\r ",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := HasWhitespaceOnly(tc.s)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_MinStringLength(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
s string
|
||||||
|
len int
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty + zero len",
|
||||||
|
s: "",
|
||||||
|
len: 0,
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty + non zero len",
|
||||||
|
s: "",
|
||||||
|
len: 10,
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "long text + non zero len",
|
||||||
|
s: "something else",
|
||||||
|
len: 10,
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multibyte characters - enough",
|
||||||
|
s: "X生",
|
||||||
|
len: 2,
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multibyte characters - not enough",
|
||||||
|
s: "X生",
|
||||||
|
len: 3,
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := MinStringLength(tc.s, tc.len)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Matches(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
s string
|
||||||
|
pattern string
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
s: "",
|
||||||
|
pattern: "",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "space",
|
||||||
|
s: "something else",
|
||||||
|
pattern: " ",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := Matches(tc.s, tc.pattern)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_IsNonPositive(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
f float64
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "zero",
|
||||||
|
f: 0,
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "positive",
|
||||||
|
f: 1,
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative",
|
||||||
|
f: -1,
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := IsNonPositive(tc.f)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_InRange(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
f float64
|
||||||
|
left float64
|
||||||
|
right float64
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "zero",
|
||||||
|
f: 0,
|
||||||
|
left: 0,
|
||||||
|
right: 0,
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "equal left",
|
||||||
|
f: 1,
|
||||||
|
left: 1,
|
||||||
|
right: 2,
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "equal right",
|
||||||
|
f: 2,
|
||||||
|
left: 1,
|
||||||
|
right: 2,
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "above",
|
||||||
|
f: 3,
|
||||||
|
left: 1,
|
||||||
|
right: 2,
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "below",
|
||||||
|
f: 0,
|
||||||
|
left: 1,
|
||||||
|
right: 2,
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := InRange(tc.f, tc.left, tc.right)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_IsHost(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
s string
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
s: "",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip address",
|
||||||
|
s: "192.168.1.1",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname",
|
||||||
|
s: "google.com",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text",
|
||||||
|
s: "Something like this",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := IsHost(tc.s)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_IsIP(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
s string
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
s: "",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip address",
|
||||||
|
s: "192.168.1.1",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname",
|
||||||
|
s: "google.com",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text",
|
||||||
|
s: "Something like this",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := IsIP(tc.s)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_IsDNSName(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
s string
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
s: "",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip address",
|
||||||
|
s: "192.168.1.1",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname",
|
||||||
|
s: "google.com",
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text",
|
||||||
|
s: "Something like this",
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := IsDNSName(tc.s)
|
||||||
|
require.Equal(t, tc.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_IsTrustedOrigin(t *testing.T) {
|
||||||
|
f := func(s string, expected bool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
result := IsTrustedOrigin(s)
|
||||||
|
if result != expected {
|
||||||
|
t.Fatalf("unexpected result for %q; got %t; want %t", s, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid trusted origins - host only
|
||||||
|
f("localhost", true)
|
||||||
|
f("example.com", true)
|
||||||
|
f("192.168.1.1", true)
|
||||||
|
f("api.example.com", true)
|
||||||
|
f("subdomain.example.org", true)
|
||||||
|
|
||||||
|
// Invalid trusted origins - host with port (no longer allowed)
|
||||||
|
f("localhost:8080", false)
|
||||||
|
f("example.com:3000", false)
|
||||||
|
f("192.168.1.1:443", false)
|
||||||
|
f("api.example.com:9000", false)
|
||||||
|
|
||||||
|
// Invalid trusted origins - empty or malformed
|
||||||
|
f("", false)
|
||||||
|
f("invalid url", false)
|
||||||
|
f("://example.com", false)
|
||||||
|
|
||||||
|
// Invalid trusted origins - with scheme
|
||||||
|
f("http://example.com", false)
|
||||||
|
f("https://localhost", false)
|
||||||
|
f("ftp://192.168.1.1", false)
|
||||||
|
|
||||||
|
// Invalid trusted origins - with user info
|
||||||
|
f("user@example.com", false)
|
||||||
|
f("user:pass@localhost", false)
|
||||||
|
|
||||||
|
// Invalid trusted origins - with path
|
||||||
|
f("example.com/path", false)
|
||||||
|
f("localhost/api", false)
|
||||||
|
f("192.168.1.1/static", false)
|
||||||
|
|
||||||
|
// Invalid trusted origins - with query parameters
|
||||||
|
f("example.com?param=value", false)
|
||||||
|
f("localhost:8080?query=test", false)
|
||||||
|
|
||||||
|
// Invalid trusted origins - with fragment
|
||||||
|
f("example.com#fragment", false)
|
||||||
|
f("localhost:3000#section", false)
|
||||||
|
|
||||||
|
// Invalid trusted origins - with multiple invalid components
|
||||||
|
f("https://user@example.com/path?query=value#fragment", false)
|
||||||
|
f("http://localhost:8080/api/v1?param=test", false)
|
||||||
|
|
||||||
|
// Edge cases - ports are no longer allowed
|
||||||
|
f("example.com:0", false) // port 0 is no longer valid
|
||||||
|
f("example.com:65535", false) // max port number is no longer valid
|
||||||
|
f("example.com:99999", false) // invalid port number
|
||||||
|
f("example.com:-1", false) // negative port
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue