fix(csrf): add CSRF protection from the stdlib BE-12810 (#2250)

This commit is contained in:
andres-portainer
2026-04-17 10:51:04 -03:00
committed by GitHub
parent 544e302fe1
commit 8d5675a7d7
3 changed files with 246 additions and 9 deletions
+58 -8
View File
@@ -8,6 +8,7 @@ import (
"os"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/pkg/featureflags"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
gcsrf "github.com/gorilla/csrf"
@@ -17,19 +18,65 @@ import (
const csrfSkipHeader = "X-CSRF-Token-Skip"
// SkipCSRFToken signals that the X-CSRF-Token header should not be sent in the response.
// Deprecated: only meaningful when the "legacy-csrf" feature flag is enabled.
func SkipCSRFToken(w http.ResponseWriter) {
w.Header().Set(csrfSkipHeader, "1")
}
func WithProtect(handler http.Handler, trustedOrigins []string) (http.Handler, error) {
// IsDockerDesktopExtension is used to check if we should skip csrf checks in the request bouncer (ShouldSkipCSRFCheck)
// DOCKER_EXTENSION is set to '1' in build/docker-extension/docker-compose.yml
// DOCKER_EXTENSION=1 is set in build/docker-extension/docker-compose.yml
isDockerDesktopExtension := false
if val, ok := os.LookupEnv("DOCKER_EXTENSION"); ok && val == "1" {
isDockerDesktopExtension = true
}
handler = withSendCSRFToken(handler)
if featureflags.IsEnabled("legacy-csrf") {
return withLegacyProtect(handler, trustedOrigins, isDockerDesktopExtension)
}
cop := http.NewCrossOriginProtection()
for _, origin := range trustedOrigins {
if err := cop.AddTrustedOrigin(origin); err != nil {
return nil, fmt.Errorf("failed to add trusted origin %q: %w", origin, err)
}
}
cop.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Error().Err(cop.Check(r)).
Str("request_url", r.URL.String()).
Str("host", r.Host).
Str("origin", r.Header.Get("Origin")).
Str("sec_fetch_site", r.Header.Get("Sec-Fetch-Site")).
Strs("trusted_origins", trustedOrigins).
Msg("CSRF check failed")
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
}))
protected := cop.Handler(handler)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
skip, err := security.ShouldSkipCSRFCheck(r, isDockerDesktopExtension)
if err != nil {
httperror.WriteError(w, http.StatusForbidden, err.Error(), err)
return
}
if skip {
handler.ServeHTTP(w, r)
return
}
protected.ServeHTTP(w, r)
}), nil
}
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
func withLegacyProtect(handler http.Handler, trustedOrigins []string, isDockerDesktopExtension bool) (http.Handler, error) {
handler = withLegacySendCSRFToken(handler)
token := make([]byte, 32)
if _, err := rand.Read(token); err != nil {
@@ -41,13 +88,14 @@ func WithProtect(handler http.Handler, trustedOrigins []string) (http.Handler, e
gcsrf.Path("/"),
gcsrf.Secure(false),
gcsrf.TrustedOrigins(trustedOrigins),
gcsrf.ErrorHandler(withErrorHandler(trustedOrigins)),
gcsrf.ErrorHandler(withLegacyErrorHandler(trustedOrigins)),
)(handler)
return withSkipCSRF(handler, isDockerDesktopExtension), nil
return withLegacySkipCSRF(handler, isDockerDesktopExtension), nil
}
func withSendCSRFToken(handler http.Handler) http.Handler {
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
func withLegacySendCSRFToken(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sw := negroni.NewResponseWriter(w)
@@ -67,7 +115,8 @@ func withSendCSRFToken(handler http.Handler) http.Handler {
})
}
func withSkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Handler {
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
func withLegacySkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
skip, err := security.ShouldSkipCSRFCheck(r, isDockerDesktopExtension)
if err != nil {
@@ -84,7 +133,8 @@ func withSkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Hand
})
}
func withErrorHandler(trustedOrigins []string) http.Handler {
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
func withLegacyErrorHandler(trustedOrigins []string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := gcsrf.FailureReason(r)
+187
View File
@@ -0,0 +1,187 @@
package csrf
import (
"net/http"
"net/http/httptest"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/require"
)
var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
func TestWithProtect_invalidTrustedOriginReturnsError(t *testing.T) {
_, err := WithProtect(okHandler, []string{"not-a-valid-origin"})
require.Error(t, err)
}
func TestWithProtect_safeMethodsAlwaysAllowed(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} {
req := httptest.NewRequest(method, "/", nil)
req.Header.Set("Sec-Fetch-Site", "cross-site")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code, "method %s should be allowed", method)
}
}
func TestWithProtect_allowsPostWithNoOriginHeaders(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
}
func TestWithProtect_allowsPostWithSameOriginSecFetchSite(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Sec-Fetch-Site", "same-origin")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
}
func TestWithProtect_allowsPostWithNoneSecFetchSite(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Sec-Fetch-Site", "none")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
}
func TestWithProtect_blocksCrossSiteSecFetchSite(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Sec-Fetch-Site", "cross-site")
req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusForbidden, rr.Code)
}
func TestWithProtect_blocksSameSiteSecFetchSite(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Sec-Fetch-Site", "same-site")
req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusForbidden, rr.Code)
}
func TestWithProtect_allowsPostWithMatchingOriginHeader(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Host = "portainer.example.com"
req.Header.Set("Origin", "https://portainer.example.com")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
}
func TestWithProtect_blocksMismatchedOriginHeader(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Host = "portainer.example.com"
req.Header.Set("Origin", "https://evil.example.com")
req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusForbidden, rr.Code)
}
func TestWithProtect_allowsPostFromTrustedOrigin(t *testing.T) {
handler, err := WithProtect(okHandler, []string{"https://trusted.example.com"})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Host = "portainer.example.com"
req.Header.Set("Origin", "https://trusted.example.com")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
}
func TestWithProtect_skipsCsrfForApiKey(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Sec-Fetch-Site", "cross-site")
req.Header.Set("X-API-KEY", "my-api-key")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
}
func TestWithProtect_skipsCsrfForBearerToken(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Sec-Fetch-Site", "cross-site")
req.Header.Set("Authorization", "Bearer some-token")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
}
func TestWithProtect_forbidsBothApiKeyAndBearerToken(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("X-API-KEY", "my-api-key")
req.Header.Set("Authorization", "Bearer some-token")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusForbidden, rr.Code)
}
func TestWithProtect_enforcesCsrfForCookieAuth(t *testing.T) {
handler, err := WithProtect(okHandler, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Sec-Fetch-Site", "cross-site")
req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"})
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusForbidden, rr.Code)
}
+1 -1
View File
@@ -2007,7 +2007,7 @@ const (
)
// List of supported features
var SupportedFeatureFlags = []featureflags.Feature{"hsts", "csp"}
var SupportedFeatureFlags = []featureflags.Feature{"hsts", "csp", "legacy-csrf"}
const (
_ AuthenticationMethod = iota