mirror of
https://github.com/portainer/portainer.git
synced 2026-06-23 04:10:29 +00:00
fix(csrf): use the proper format for trusted origins BE-12810 (#2398)
This commit is contained in:
@@ -344,7 +344,7 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
|
||||
// validate if the trusted origins are valid urls
|
||||
for origin := range strings.SplitSeq(*flags.TrustedOrigins, ",") {
|
||||
if !validate.IsTrustedOrigin(origin) {
|
||||
log.Fatal().Str("trusted_origin", origin).Msg("invalid url for trusted origin. Please check the trusted origins flag.")
|
||||
log.Fatal().Str("trusted_origin", origin).Msg("invalid trusted origin: must be scheme://host or scheme://host:port (e.g. https://example.com)")
|
||||
}
|
||||
|
||||
trustedOrigins = append(trustedOrigins, origin)
|
||||
|
||||
+14
-1
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
@@ -83,11 +84,23 @@ func withLegacyProtect(handler http.Handler, trustedOrigins []string, isDockerDe
|
||||
return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
|
||||
}
|
||||
|
||||
// gorilla/csrf compares referer.Host against trusted origin entries, so it
|
||||
// needs bare host[:port] values rather than full scheme://host[:port] origins.
|
||||
legacyOrigins := make([]string, len(trustedOrigins))
|
||||
for i, origin := range trustedOrigins {
|
||||
parsed, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse trusted origin %q: %w", origin, err)
|
||||
}
|
||||
|
||||
legacyOrigins[i] = parsed.Host
|
||||
}
|
||||
|
||||
handler = gcsrf.Protect(
|
||||
token,
|
||||
gcsrf.Path("/"),
|
||||
gcsrf.Secure(false),
|
||||
gcsrf.TrustedOrigins(trustedOrigins),
|
||||
gcsrf.TrustedOrigins(legacyOrigins),
|
||||
gcsrf.ErrorHandler(withLegacyErrorHandler(trustedOrigins)),
|
||||
)(handler)
|
||||
|
||||
|
||||
@@ -15,11 +15,15 @@ var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
|
||||
func TestWithProtect_invalidTrustedOriginReturnsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := WithProtect(okHandler, []string{"not-a-valid-origin"})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWithProtect_safeMethodsAlwaysAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -34,6 +38,8 @@ func TestWithProtect_safeMethodsAlwaysAllowed(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostWithNoOriginHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -45,6 +51,8 @@ func TestWithProtect_allowsPostWithNoOriginHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostWithSameOriginSecFetchSite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -57,6 +65,8 @@ func TestWithProtect_allowsPostWithSameOriginSecFetchSite(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostWithNoneSecFetchSite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -69,6 +79,8 @@ func TestWithProtect_allowsPostWithNoneSecFetchSite(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_blocksCrossSiteSecFetchSite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -82,6 +94,8 @@ func TestWithProtect_blocksCrossSiteSecFetchSite(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_blocksSameSiteSecFetchSite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -95,6 +109,8 @@ func TestWithProtect_blocksSameSiteSecFetchSite(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostWithMatchingOriginHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -108,6 +124,8 @@ func TestWithProtect_allowsPostWithMatchingOriginHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_blocksMismatchedOriginHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -122,6 +140,8 @@ func TestWithProtect_blocksMismatchedOriginHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostFromTrustedOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, []string{"https://trusted.example.com"})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -135,6 +155,8 @@ func TestWithProtect_allowsPostFromTrustedOrigin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_skipsCsrfForApiKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -148,6 +170,8 @@ func TestWithProtect_skipsCsrfForApiKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_skipsCsrfForBearerToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -161,6 +185,8 @@ func TestWithProtect_skipsCsrfForBearerToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_forbidsBothApiKeyAndBearerToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -174,6 +200,8 @@ func TestWithProtect_forbidsBothApiKeyAndBearerToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithProtect_enforcesCsrfForCookieAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -185,3 +213,88 @@ func TestWithProtect_enforcesCsrfForCookieAuth(t *testing.T) {
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusForbidden, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithLegacyProtect_noError_noOrigins(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := withLegacyProtect(okHandler, nil, false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWithLegacyProtect_noError_schemeHostOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := withLegacyProtect(okHandler, []string{"https://example.com"}, false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWithLegacyProtect_noError_schemeHostPortOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := withLegacyProtect(okHandler, []string{"https://example.com:3000"}, false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWithLegacyProtect_noError_multipleOrigins(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := withLegacyProtect(okHandler, []string{"https://example.com", "http://internal.example.com:8080"}, false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWithLegacyProtect_safeMethodsAlwaysAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := withLegacyProtect(okHandler, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} {
|
||||
req := httptest.NewRequest(method, "/", nil)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code, "method %s should be allowed", method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLegacyProtect_blocksPostWithoutToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := withLegacyProtect(okHandler, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusForbidden, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithLegacyProtect_skipsCsrfForApiKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := withLegacyProtect(okHandler, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("X-API-KEY", "my-api-key")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithLegacyProtect_skipsCsrfForBearerToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := withLegacyProtect(okHandler, nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer some-token")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
@@ -82,19 +82,19 @@ func IsDNSName(s string) bool {
|
||||
}
|
||||
|
||||
func IsTrustedOrigin(s string) bool {
|
||||
// Reject if a scheme is present
|
||||
if strings.Contains(s, "://") {
|
||||
if !strings.Contains(s, "://") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Prepend http:// for parsing
|
||||
strTemp := "http://" + s
|
||||
parsedOrigin, err := url.Parse(strTemp)
|
||||
parsedOrigin, err := url.Parse(s)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate host, and ensure no user, path, query, fragment, port, etc.
|
||||
if parsedOrigin.Scheme != "http" && parsedOrigin.Scheme != "https" {
|
||||
return false
|
||||
}
|
||||
|
||||
if parsedOrigin.Host == "" ||
|
||||
parsedOrigin.User != nil ||
|
||||
parsedOrigin.Path != "" ||
|
||||
@@ -102,8 +102,7 @@ func IsTrustedOrigin(s string) bool {
|
||||
parsedOrigin.Fragment != "" ||
|
||||
parsedOrigin.Opaque != "" ||
|
||||
parsedOrigin.RawFragment != "" ||
|
||||
parsedOrigin.RawPath != "" ||
|
||||
parsedOrigin.Port() != "" {
|
||||
parsedOrigin.RawPath != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -460,53 +460,50 @@ func Test_IsTrustedOrigin(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Valid trusted origins - host only
|
||||
f("localhost", true)
|
||||
f("example.com", true)
|
||||
f("192.168.1.1", true)
|
||||
f("api.example.com", true)
|
||||
f("subdomain.example.org", true)
|
||||
// Valid trusted origins - scheme + host
|
||||
f("http://localhost", true)
|
||||
f("https://example.com", true)
|
||||
f("http://192.168.1.1", true)
|
||||
f("https://api.example.com", true)
|
||||
f("https://subdomain.example.org", true)
|
||||
|
||||
// Invalid trusted origins - host with port (no longer allowed)
|
||||
f("localhost:8080", false)
|
||||
f("example.com:3000", false)
|
||||
f("192.168.1.1:443", false)
|
||||
f("api.example.com:9000", false)
|
||||
// Valid trusted origins - scheme + host + port
|
||||
f("http://localhost:8080", true)
|
||||
f("https://example.com:3000", true)
|
||||
f("http://192.168.1.1:443", true)
|
||||
f("https://api.example.com:9000", true)
|
||||
|
||||
// Invalid trusted origins - bare hostname (no scheme)
|
||||
f("localhost", false)
|
||||
f("example.com", false)
|
||||
f("192.168.1.1", false)
|
||||
|
||||
// Invalid trusted origins - empty or malformed
|
||||
f("", false)
|
||||
f("invalid url", false)
|
||||
f("://example.com", false)
|
||||
|
||||
// Invalid trusted origins - with scheme
|
||||
f("http://example.com", false)
|
||||
f("https://localhost", false)
|
||||
// Invalid trusted origins - unsupported scheme
|
||||
f("ftp://192.168.1.1", false)
|
||||
|
||||
// Invalid trusted origins - with user info
|
||||
f("user@example.com", false)
|
||||
f("user:pass@localhost", false)
|
||||
f("http://user@example.com", false)
|
||||
f("http://user:pass@localhost", false)
|
||||
|
||||
// Invalid trusted origins - with path
|
||||
f("example.com/path", false)
|
||||
f("localhost/api", false)
|
||||
f("192.168.1.1/static", false)
|
||||
f("https://example.com/path", false)
|
||||
f("http://localhost/api", false)
|
||||
f("http://192.168.1.1/static", false)
|
||||
|
||||
// Invalid trusted origins - with query parameters
|
||||
f("example.com?param=value", false)
|
||||
f("localhost:8080?query=test", false)
|
||||
f("https://example.com?param=value", false)
|
||||
f("http://localhost:8080?query=test", false)
|
||||
|
||||
// Invalid trusted origins - with fragment
|
||||
f("example.com#fragment", false)
|
||||
f("localhost:3000#section", false)
|
||||
f("https://example.com#fragment", false)
|
||||
f("http://localhost:3000#section", false)
|
||||
|
||||
// Invalid trusted origins - with multiple invalid components
|
||||
f("https://user@example.com/path?query=value#fragment", false)
|
||||
f("http://localhost:8080/api/v1?param=test", false)
|
||||
|
||||
// Edge cases - ports are no longer allowed
|
||||
f("example.com:0", false) // port 0 is no longer valid
|
||||
f("example.com:65535", false) // max port number is no longer valid
|
||||
f("example.com:99999", false) // invalid port number
|
||||
f("example.com:-1", false) // negative port
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user