mirror of
https://github.com/portainer/portainer.git
synced 2026-06-23 07:40:11 +00:00
fix(csrf): add CSRF protection from the stdlib BE-12810 (#2250)
This commit is contained in:
+58
-8
@@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/pkg/featureflags"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
|
||||
gcsrf "github.com/gorilla/csrf"
|
||||
@@ -17,19 +18,65 @@ import (
|
||||
|
||||
const csrfSkipHeader = "X-CSRF-Token-Skip"
|
||||
|
||||
// SkipCSRFToken signals that the X-CSRF-Token header should not be sent in the response.
|
||||
// Deprecated: only meaningful when the "legacy-csrf" feature flag is enabled.
|
||||
func SkipCSRFToken(w http.ResponseWriter) {
|
||||
w.Header().Set(csrfSkipHeader, "1")
|
||||
}
|
||||
|
||||
func WithProtect(handler http.Handler, trustedOrigins []string) (http.Handler, error) {
|
||||
// IsDockerDesktopExtension is used to check if we should skip csrf checks in the request bouncer (ShouldSkipCSRFCheck)
|
||||
// DOCKER_EXTENSION is set to '1' in build/docker-extension/docker-compose.yml
|
||||
// DOCKER_EXTENSION=1 is set in build/docker-extension/docker-compose.yml
|
||||
isDockerDesktopExtension := false
|
||||
if val, ok := os.LookupEnv("DOCKER_EXTENSION"); ok && val == "1" {
|
||||
isDockerDesktopExtension = true
|
||||
}
|
||||
|
||||
handler = withSendCSRFToken(handler)
|
||||
if featureflags.IsEnabled("legacy-csrf") {
|
||||
return withLegacyProtect(handler, trustedOrigins, isDockerDesktopExtension)
|
||||
}
|
||||
|
||||
cop := http.NewCrossOriginProtection()
|
||||
for _, origin := range trustedOrigins {
|
||||
if err := cop.AddTrustedOrigin(origin); err != nil {
|
||||
return nil, fmt.Errorf("failed to add trusted origin %q: %w", origin, err)
|
||||
}
|
||||
}
|
||||
|
||||
cop.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Error().Err(cop.Check(r)).
|
||||
Str("request_url", r.URL.String()).
|
||||
Str("host", r.Host).
|
||||
Str("origin", r.Header.Get("Origin")).
|
||||
Str("sec_fetch_site", r.Header.Get("Sec-Fetch-Site")).
|
||||
Strs("trusted_origins", trustedOrigins).
|
||||
Msg("CSRF check failed")
|
||||
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
}))
|
||||
|
||||
protected := cop.Handler(handler)
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
skip, err := security.ShouldSkipCSRFCheck(r, isDockerDesktopExtension)
|
||||
if err != nil {
|
||||
httperror.WriteError(w, http.StatusForbidden, err.Error(), err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if skip {
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
protected.ServeHTTP(w, r)
|
||||
}), nil
|
||||
}
|
||||
|
||||
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
|
||||
func withLegacyProtect(handler http.Handler, trustedOrigins []string, isDockerDesktopExtension bool) (http.Handler, error) {
|
||||
handler = withLegacySendCSRFToken(handler)
|
||||
|
||||
token := make([]byte, 32)
|
||||
if _, err := rand.Read(token); err != nil {
|
||||
@@ -41,13 +88,14 @@ func WithProtect(handler http.Handler, trustedOrigins []string) (http.Handler, e
|
||||
gcsrf.Path("/"),
|
||||
gcsrf.Secure(false),
|
||||
gcsrf.TrustedOrigins(trustedOrigins),
|
||||
gcsrf.ErrorHandler(withErrorHandler(trustedOrigins)),
|
||||
gcsrf.ErrorHandler(withLegacyErrorHandler(trustedOrigins)),
|
||||
)(handler)
|
||||
|
||||
return withSkipCSRF(handler, isDockerDesktopExtension), nil
|
||||
return withLegacySkipCSRF(handler, isDockerDesktopExtension), nil
|
||||
}
|
||||
|
||||
func withSendCSRFToken(handler http.Handler) http.Handler {
|
||||
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
|
||||
func withLegacySendCSRFToken(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sw := negroni.NewResponseWriter(w)
|
||||
|
||||
@@ -67,7 +115,8 @@ func withSendCSRFToken(handler http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func withSkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Handler {
|
||||
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
|
||||
func withLegacySkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
skip, err := security.ShouldSkipCSRFCheck(r, isDockerDesktopExtension)
|
||||
if err != nil {
|
||||
@@ -84,7 +133,8 @@ func withSkipCSRF(handler http.Handler, isDockerDesktopExtension bool) http.Hand
|
||||
})
|
||||
}
|
||||
|
||||
func withErrorHandler(trustedOrigins []string) http.Handler {
|
||||
// Deprecated: use WithProtect without the "legacy-csrf" feature flag instead.
|
||||
func withLegacyErrorHandler(trustedOrigins []string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := gcsrf.FailureReason(r)
|
||||
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
func TestWithProtect_invalidTrustedOriginReturnsError(t *testing.T) {
|
||||
_, err := WithProtect(okHandler, []string{"not-a-valid-origin"})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWithProtect_safeMethodsAlwaysAllowed(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} {
|
||||
req := httptest.NewRequest(method, "/", nil)
|
||||
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code, "method %s should be allowed", method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostWithNoOriginHeaders(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostWithSameOriginSecFetchSite(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Sec-Fetch-Site", "same-origin")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostWithNoneSecFetchSite(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Sec-Fetch-Site", "none")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_blocksCrossSiteSecFetchSite(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||
req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusForbidden, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_blocksSameSiteSecFetchSite(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Sec-Fetch-Site", "same-site")
|
||||
req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusForbidden, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostWithMatchingOriginHeader(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Host = "portainer.example.com"
|
||||
req.Header.Set("Origin", "https://portainer.example.com")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_blocksMismatchedOriginHeader(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Host = "portainer.example.com"
|
||||
req.Header.Set("Origin", "https://evil.example.com")
|
||||
req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusForbidden, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_allowsPostFromTrustedOrigin(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, []string{"https://trusted.example.com"})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Host = "portainer.example.com"
|
||||
req.Header.Set("Origin", "https://trusted.example.com")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_skipsCsrfForApiKey(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||
req.Header.Set("X-API-KEY", "my-api-key")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_skipsCsrfForBearerToken(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||
req.Header.Set("Authorization", "Bearer some-token")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_forbidsBothApiKeyAndBearerToken(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("X-API-KEY", "my-api-key")
|
||||
req.Header.Set("Authorization", "Bearer some-token")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusForbidden, rr.Code)
|
||||
}
|
||||
|
||||
func TestWithProtect_enforcesCsrfForCookieAuth(t *testing.T) {
|
||||
handler, err := WithProtect(okHandler, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||
req.AddCookie(&http.Cookie{Name: portainer.AuthCookieKey, Value: "some-token"})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusForbidden, rr.Code)
|
||||
}
|
||||
+1
-1
@@ -2007,7 +2007,7 @@ const (
|
||||
)
|
||||
|
||||
// List of supported features
|
||||
var SupportedFeatureFlags = []featureflags.Feature{"hsts", "csp"}
|
||||
var SupportedFeatureFlags = []featureflags.Feature{"hsts", "csp", "legacy-csrf"}
|
||||
|
||||
const (
|
||||
_ AuthenticationMethod = iota
|
||||
|
||||
Reference in New Issue
Block a user