From 825a7669a665b77fd97db4da2aa31eff30589368 Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Tue, 21 Apr 2026 11:52:58 -0300 Subject: [PATCH] fix(csrf): use the proper format for trusted origins BE-12810 (#2398) --- api/cmd/portainer/main.go | 2 +- api/http/csrf/csrf.go | 15 ++++- api/http/csrf/csrf_test.go | 113 ++++++++++++++++++++++++++++++++++ pkg/validate/validate.go | 15 +++-- pkg/validate/validate_test.go | 55 ++++++++--------- 5 files changed, 161 insertions(+), 39 deletions(-) diff --git a/api/cmd/portainer/main.go b/api/cmd/portainer/main.go index 094ef4fda6..76a9db4ba9 100644 --- a/api/cmd/portainer/main.go +++ b/api/cmd/portainer/main.go @@ -344,7 +344,7 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow // validate if the trusted origins are valid urls for origin := range strings.SplitSeq(*flags.TrustedOrigins, ",") { if !validate.IsTrustedOrigin(origin) { - log.Fatal().Str("trusted_origin", origin).Msg("invalid url for trusted origin. Please check the trusted origins flag.") + log.Fatal().Str("trusted_origin", origin).Msg("invalid trusted origin: must be scheme://host or scheme://host:port (e.g. https://example.com)") } trustedOrigins = append(trustedOrigins, origin) diff --git a/api/http/csrf/csrf.go b/api/http/csrf/csrf.go index 93ac2fe58b..0616eb69c1 100644 --- a/api/http/csrf/csrf.go +++ b/api/http/csrf/csrf.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "os" "github.com/portainer/portainer/api/http/security" @@ -83,11 +84,23 @@ func withLegacyProtect(handler http.Handler, trustedOrigins []string, isDockerDe return nil, fmt.Errorf("failed to generate CSRF token: %w", err) } + // gorilla/csrf compares referer.Host against trusted origin entries, so it + // needs bare host[:port] values rather than full scheme://host[:port] origins. + legacyOrigins := make([]string, len(trustedOrigins)) + for i, origin := range trustedOrigins { + parsed, err := url.Parse(origin) + if err != nil { + return nil, fmt.Errorf("failed to parse trusted origin %q: %w", origin, err) + } + + legacyOrigins[i] = parsed.Host + } + handler = gcsrf.Protect( token, gcsrf.Path("/"), gcsrf.Secure(false), - gcsrf.TrustedOrigins(trustedOrigins), + gcsrf.TrustedOrigins(legacyOrigins), gcsrf.ErrorHandler(withLegacyErrorHandler(trustedOrigins)), )(handler) diff --git a/api/http/csrf/csrf_test.go b/api/http/csrf/csrf_test.go index 57890166a9..034a882830 100644 --- a/api/http/csrf/csrf_test.go +++ b/api/http/csrf/csrf_test.go @@ -15,11 +15,15 @@ var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { }) func TestWithProtect_invalidTrustedOriginReturnsError(t *testing.T) { + t.Parallel() + _, err := WithProtect(okHandler, []string{"not-a-valid-origin"}) require.Error(t, err) } func TestWithProtect_safeMethodsAlwaysAllowed(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -34,6 +38,8 @@ func TestWithProtect_safeMethodsAlwaysAllowed(t *testing.T) { } func TestWithProtect_allowsPostWithNoOriginHeaders(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -45,6 +51,8 @@ func TestWithProtect_allowsPostWithNoOriginHeaders(t *testing.T) { } func TestWithProtect_allowsPostWithSameOriginSecFetchSite(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -57,6 +65,8 @@ func TestWithProtect_allowsPostWithSameOriginSecFetchSite(t *testing.T) { } func TestWithProtect_allowsPostWithNoneSecFetchSite(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -69,6 +79,8 @@ func TestWithProtect_allowsPostWithNoneSecFetchSite(t *testing.T) { } func TestWithProtect_blocksCrossSiteSecFetchSite(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -82,6 +94,8 @@ func TestWithProtect_blocksCrossSiteSecFetchSite(t *testing.T) { } func TestWithProtect_blocksSameSiteSecFetchSite(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -95,6 +109,8 @@ func TestWithProtect_blocksSameSiteSecFetchSite(t *testing.T) { } func TestWithProtect_allowsPostWithMatchingOriginHeader(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -108,6 +124,8 @@ func TestWithProtect_allowsPostWithMatchingOriginHeader(t *testing.T) { } func TestWithProtect_blocksMismatchedOriginHeader(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -122,6 +140,8 @@ func TestWithProtect_blocksMismatchedOriginHeader(t *testing.T) { } func TestWithProtect_allowsPostFromTrustedOrigin(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, []string{"https://trusted.example.com"}) require.NoError(t, err) @@ -135,6 +155,8 @@ func TestWithProtect_allowsPostFromTrustedOrigin(t *testing.T) { } func TestWithProtect_skipsCsrfForApiKey(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -148,6 +170,8 @@ func TestWithProtect_skipsCsrfForApiKey(t *testing.T) { } func TestWithProtect_skipsCsrfForBearerToken(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -161,6 +185,8 @@ func TestWithProtect_skipsCsrfForBearerToken(t *testing.T) { } func TestWithProtect_forbidsBothApiKeyAndBearerToken(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -174,6 +200,8 @@ func TestWithProtect_forbidsBothApiKeyAndBearerToken(t *testing.T) { } func TestWithProtect_enforcesCsrfForCookieAuth(t *testing.T) { + t.Parallel() + handler, err := WithProtect(okHandler, nil) require.NoError(t, err) @@ -185,3 +213,88 @@ func TestWithProtect_enforcesCsrfForCookieAuth(t *testing.T) { handler.ServeHTTP(rr, req) require.Equal(t, http.StatusForbidden, rr.Code) } + +func TestWithLegacyProtect_noError_noOrigins(t *testing.T) { + t.Parallel() + + _, err := withLegacyProtect(okHandler, nil, false) + require.NoError(t, err) +} + +func TestWithLegacyProtect_noError_schemeHostOrigin(t *testing.T) { + t.Parallel() + + _, err := withLegacyProtect(okHandler, []string{"https://example.com"}, false) + require.NoError(t, err) +} + +func TestWithLegacyProtect_noError_schemeHostPortOrigin(t *testing.T) { + t.Parallel() + + _, err := withLegacyProtect(okHandler, []string{"https://example.com:3000"}, false) + require.NoError(t, err) +} + +func TestWithLegacyProtect_noError_multipleOrigins(t *testing.T) { + t.Parallel() + + _, err := withLegacyProtect(okHandler, []string{"https://example.com", "http://internal.example.com:8080"}, false) + require.NoError(t, err) +} + +func TestWithLegacyProtect_safeMethodsAlwaysAllowed(t *testing.T) { + t.Parallel() + + handler, err := withLegacyProtect(okHandler, nil, false) + require.NoError(t, err) + + for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} { + req := httptest.NewRequest(method, "/", nil) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code, "method %s should be allowed", method) + } +} + +func TestWithLegacyProtect_blocksPostWithoutToken(t *testing.T) { + t.Parallel() + + handler, err := withLegacyProtect(okHandler, nil, false) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"}) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusForbidden, rr.Code) +} + +func TestWithLegacyProtect_skipsCsrfForApiKey(t *testing.T) { + t.Parallel() + + handler, err := withLegacyProtect(okHandler, nil, false) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set("X-API-KEY", "my-api-key") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code) +} + +func TestWithLegacyProtect_skipsCsrfForBearerToken(t *testing.T) { + t.Parallel() + + handler, err := withLegacyProtect(okHandler, nil, false) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set("Authorization", "Bearer some-token") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code) +} diff --git a/pkg/validate/validate.go b/pkg/validate/validate.go index 8ad69df72a..d9f481122c 100644 --- a/pkg/validate/validate.go +++ b/pkg/validate/validate.go @@ -82,19 +82,19 @@ func IsDNSName(s string) bool { } func IsTrustedOrigin(s string) bool { - // Reject if a scheme is present - if strings.Contains(s, "://") { + if !strings.Contains(s, "://") { return false } - // Prepend http:// for parsing - strTemp := "http://" + s - parsedOrigin, err := url.Parse(strTemp) + parsedOrigin, err := url.Parse(s) if err != nil { return false } - // Validate host, and ensure no user, path, query, fragment, port, etc. + if parsedOrigin.Scheme != "http" && parsedOrigin.Scheme != "https" { + return false + } + if parsedOrigin.Host == "" || parsedOrigin.User != nil || parsedOrigin.Path != "" || @@ -102,8 +102,7 @@ func IsTrustedOrigin(s string) bool { parsedOrigin.Fragment != "" || parsedOrigin.Opaque != "" || parsedOrigin.RawFragment != "" || - parsedOrigin.RawPath != "" || - parsedOrigin.Port() != "" { + parsedOrigin.RawPath != "" { return false } diff --git a/pkg/validate/validate_test.go b/pkg/validate/validate_test.go index c086d62a78..b506da0d5a 100644 --- a/pkg/validate/validate_test.go +++ b/pkg/validate/validate_test.go @@ -460,53 +460,50 @@ func Test_IsTrustedOrigin(t *testing.T) { } } - // Valid trusted origins - host only - f("localhost", true) - f("example.com", true) - f("192.168.1.1", true) - f("api.example.com", true) - f("subdomain.example.org", true) + // Valid trusted origins - scheme + host + f("http://localhost", true) + f("https://example.com", true) + f("http://192.168.1.1", true) + f("https://api.example.com", true) + f("https://subdomain.example.org", true) - // Invalid trusted origins - host with port (no longer allowed) - f("localhost:8080", false) - f("example.com:3000", false) - f("192.168.1.1:443", false) - f("api.example.com:9000", false) + // Valid trusted origins - scheme + host + port + f("http://localhost:8080", true) + f("https://example.com:3000", true) + f("http://192.168.1.1:443", true) + f("https://api.example.com:9000", true) + + // Invalid trusted origins - bare hostname (no scheme) + f("localhost", false) + f("example.com", false) + f("192.168.1.1", false) // Invalid trusted origins - empty or malformed f("", false) f("invalid url", false) f("://example.com", false) - // Invalid trusted origins - with scheme - f("http://example.com", false) - f("https://localhost", false) + // Invalid trusted origins - unsupported scheme f("ftp://192.168.1.1", false) // Invalid trusted origins - with user info - f("user@example.com", false) - f("user:pass@localhost", false) + f("http://user@example.com", false) + f("http://user:pass@localhost", false) // Invalid trusted origins - with path - f("example.com/path", false) - f("localhost/api", false) - f("192.168.1.1/static", false) + f("https://example.com/path", false) + f("http://localhost/api", false) + f("http://192.168.1.1/static", false) // Invalid trusted origins - with query parameters - f("example.com?param=value", false) - f("localhost:8080?query=test", false) + f("https://example.com?param=value", false) + f("http://localhost:8080?query=test", false) // Invalid trusted origins - with fragment - f("example.com#fragment", false) - f("localhost:3000#section", false) + f("https://example.com#fragment", false) + f("http://localhost:3000#section", false) // Invalid trusted origins - with multiple invalid components f("https://user@example.com/path?query=value#fragment", false) f("http://localhost:8080/api/v1?param=test", false) - - // Edge cases - ports are no longer allowed - f("example.com:0", false) // port 0 is no longer valid - f("example.com:65535", false) // max port number is no longer valid - f("example.com:99999", false) // invalid port number - f("example.com:-1", false) // negative port }