diff --git a/.golangci.yaml b/.golangci.yaml index 2fedf1ead4..4659ba4710 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -5,6 +5,7 @@ run: linters: default: none enable: + - gocritic - bodyclose - copyloopvar - depguard @@ -76,6 +77,13 @@ linters: desc: use github.com/Masterminds/semver/v3 - pkg: github.com/hashicorp/go-version desc: use github.com/Masterminds/semver/v3 + gocritic: + disable-all: true + enabled-checks: + - ruleguard + settings: + ruleguard: + rules: "./analysis/ssrf.go" forbidigo: forbid: - pattern: ^tls\.Config$ @@ -93,6 +101,11 @@ linters: - comments - common-false-positives - legacy + rules: + - path: pkg/libhttp/ssrf + linters: + - gocritic + text: ruleguard paths: - third_party$ - builtin$ diff --git a/analysis/ssrf.go b/analysis/ssrf.go new file mode 100644 index 0000000000..3af95653ba --- /dev/null +++ b/analysis/ssrf.go @@ -0,0 +1,29 @@ +//go:build ignore + +package gorules + +import "github.com/quasilyte/go-ruleguard/dsl" + +// unwrappedHTTPTransport flags http.Transport composite literals that are not +// the direct argument to ssrf.WrapTransport. +func unwrappedHTTPTransport(m dsl.Matcher) { + // Inline construction passed to a function call. + m.Match(`$f(&http.Transport{$*_})`). + Where(m["f"].Text != "ssrf.WrapTransport" && m["f"].Text != "WrapTransport" && + m["f"].Text != "ssrf.WrapTransportInternal" && m["f"].Text != "WrapTransportInternal"). + Report(`$f receives a bare *http.Transport; wrap with ssrf.WrapTransport() to enforce the SSRF protection policy`) + + // Variable assigned a bare transport (cannot be tracked to a later WrapTransport call). + m.Match(`$_ := &http.Transport{$*_}`). + Report(`bare *http.Transport variable; use ssrf.WrapTransport(&http.Transport{...}) inline instead`) +} + +// internalTransportMisuse flags calls to WrapTransportInternal outside the four proxy +// factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions. +func internalTransportMisuse(m dsl.Matcher) { + m.Match(`ssrf.WrapTransportInternal($*_)`). + Where( + !(m.File().PkgPath.Matches(`proxy/factory`) && + m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport)\.go$`))). + Report(`WrapTransportInternal bypasses SSRF validation; only valid in the kubernetes local/edge transport constructors and the docker/agent proxy factories`) +} diff --git a/analysis/tools.go b/analysis/tools.go new file mode 100644 index 0000000000..55f01d8f8b --- /dev/null +++ b/analysis/tools.go @@ -0,0 +1,5 @@ +//go:build tools + +package gorules + +import _ "github.com/quasilyte/go-ruleguard/dsl" diff --git a/api/cli/cli.go b/api/cli/cli.go index d7fa93cc63..5e1148ee51 100644 --- a/api/cli/cli.go +++ b/api/cli/cli.go @@ -56,6 +56,8 @@ func CLIFlags() *portainer.CLIFlags { TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(), CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(), CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(), + SSRFMode: kingpin.Flag("ssrf-mode", "SSRF protection mode: off (disabled), audit (log violations but allow), enforce (block violations)").Envar("PORTAINER_SSRF_MODE").Default("off").Enum("off", "audit", "enforce"), + SSRFAllowedHosts: kingpin.Flag("ssrf-allowed-hosts", "Allowlist of hostnames (with optional wildcards), IPs, or CIDRs permitted for outbound requests. When empty and mode is enforce, all outbound connections are blocked").Envar("PORTAINER_SSRF_ALLOWED_HOSTS").Strings(), NoSetupToken: kingpin.Flag("no-setup-token", "Disable the setup token requirement for admin initialization and restore on an uninitialized instance").Envar(portainer.NoSetupTokenEnvVar).Bool(), SetupToken: kingpin.Flag("setup-token", "Set a custom setup token for admin initialization and restore on an uninitialized instance (overrides auto-generation)").Envar(portainer.SetupTokenEnvVar).String(), } diff --git a/api/cmd/portainer/main.go b/api/cmd/portainer/main.go index 91f94ba1f2..1d77814e45 100644 --- a/api/cmd/portainer/main.go +++ b/api/cmd/portainer/main.go @@ -4,6 +4,7 @@ import ( "cmp" "context" "crypto/sha256" + nethttp "net/http" "os" "path" "strings" @@ -52,10 +53,12 @@ import ( "github.com/portainer/portainer/pkg/featureflags" "github.com/portainer/portainer/pkg/fips" "github.com/portainer/portainer/pkg/libhelm" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/portainer/portainer/pkg/libstack/compose" libswarm "github.com/portainer/portainer/pkg/libstack/swarm" "github.com/portainer/portainer/pkg/validate" + gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http" "github.com/google/uuid" "github.com/rs/zerolog/log" ) @@ -384,6 +387,19 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow // -ce can not ever be run in FIPS mode fips.InitFIPS(false) + ssrf.Configure(ssrf.Policy{ + Mode: ssrf.Mode(*flags.SSRFMode), + AllowedHosts: *flags.SSRFAllowedHosts, + }) + + if ssrf.IsEnabled() { + if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok { + nethttp.DefaultTransport = ssrf.WrapTransport(dt) + } + + gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport}) + } + fileService := initFileService(*flags.Data) encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName)) if encryptionKey == nil { diff --git a/api/docker/client/client.go b/api/docker/client/client.go index 86cc1fa4ba..6ef575b755 100644 --- a/api/docker/client/client.go +++ b/api/docker/client/client.go @@ -11,6 +11,8 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/crypto" + "github.com/portainer/portainer/pkg/libhttp/ssrf" + "github.com/rs/zerolog/log" "github.com/docker/docker/api/types/image" @@ -184,17 +186,20 @@ func (t *NodeNameTransport) RoundTrip(req *http.Request) (*http.Response, error) } func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Client, error) { - transport := &NodeNameTransport{ - Transport: &http.Transport{}, - } - + var transport *NodeNameTransport if endpoint.TLSConfig.TLS { tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig) if err != nil { return nil, err } - transport.TLSClientConfig = tlsConfig + transport = &NodeNameTransport{ + Transport: ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}), + } + } else { + transport = &NodeNameTransport{ + Transport: ssrf.WrapTransport(&http.Transport{}), + } } clientTimeout := defaultDockerRequestTimeout diff --git a/api/git/azure.go b/api/git/azure.go index 79e5c340c9..2acf14b233 100644 --- a/api/git/azure.go +++ b/api/git/azure.go @@ -14,12 +14,13 @@ import ( "github.com/portainer/portainer/api/crypto" gittypes "github.com/portainer/portainer/api/git/types" "github.com/portainer/portainer/api/logs" - "github.com/rs/zerolog/log" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/filemode" githttp "github.com/go-git/go-git/v5/plumbing/transport/http" "github.com/pkg/errors" + "github.com/rs/zerolog/log" "github.com/segmentio/encoding/json" ) @@ -64,15 +65,13 @@ func NewAzureClient() *azureClient { } func newHttpClientForAzure(insecureSkipVerify bool) *http.Client { - httpsCli := &http.Client{ - Transport: &http.Transport{ + return &http.Client{ + Transport: ssrf.WrapTransport(&http.Transport{ TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify), Proxy: http.ProxyFromEnvironment, - }, + }), Timeout: 300 * time.Second, } - - return httpsCli } func (a *azureClient) Download(ctx context.Context, destination string, opt *git.CloneOptions) error { diff --git a/api/http/client/client.go b/api/http/client/client.go index 194131c576..68c4ecd6f7 100644 --- a/api/http/client/client.go +++ b/api/http/client/client.go @@ -11,6 +11,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/crypto" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/rs/zerolog/log" "github.com/segmentio/encoding/json" @@ -114,18 +115,19 @@ func Get(url string, timeout int) ([]byte, error) { // using the specified host and optional TLS configuration. // It uses a new Http.Client for each operation. func ExecutePingOperation(host string, tlsConfiguration portainer.TLSConfiguration) (bool, error) { - transport := &http.Transport{} - scheme := "http" + var transport *http.Transport if tlsConfiguration.TLS { tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsConfiguration) if err != nil { return false, err } - transport.TLSClientConfig = tlsConfig scheme = "https" + transport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}) + } else { + transport = ssrf.WrapTransport(&http.Transport{}) } client := &http.Client{ diff --git a/api/http/handler/edgestacks/edgestack_create_file.go b/api/http/handler/edgestacks/edgestack_create_file.go index 93c6809d68..0129ccdca3 100644 --- a/api/http/handler/edgestacks/edgestack_create_file.go +++ b/api/http/handler/edgestacks/edgestack_create_file.go @@ -115,6 +115,11 @@ func (handler *Handler) createEdgeStackFromFileUpload(r *http.Request, tx datase if dryrun { return stack, nil } + + if err := stackutils.ValidateEdgeStackComposeContent(r.Context(), payload.DeploymentType, payload.StackFileContent); err != nil { + return nil, httperrors.NewInvalidPayloadError(err.Error()) + } + stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID) stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username) diff --git a/api/http/handler/edgestacks/edgestack_create_string.go b/api/http/handler/edgestacks/edgestack_create_string.go index 32d53713b8..3a573ead0f 100644 --- a/api/http/handler/edgestacks/edgestack_create_string.go +++ b/api/http/handler/edgestacks/edgestack_create_string.go @@ -93,6 +93,10 @@ func (handler *Handler) createEdgeStackFromFileContent(r *http.Request, tx datas return stack, nil } + if err := stackutils.ValidateEdgeStackComposeContent(r.Context(), payload.DeploymentType, []byte(payload.StackFileContent)); err != nil { + return nil, httperrors.NewInvalidPayloadError(err.Error()) + } + return handler.edgeStacksService.PersistEdgeStack(tx, stack, func(stackFolder string, relatedEndpointIds []portainer.EndpointID) (composePath string, manifestPath string, projectPath string, err error) { return handler.storeFileContent(tx, stackFolder, payload.DeploymentType, relatedEndpointIds, []byte(payload.StackFileContent)) }) diff --git a/api/http/handler/edgestacks/edgestack_update.go b/api/http/handler/edgestacks/edgestack_update.go index d859a0c058..efb4d6d202 100644 --- a/api/http/handler/edgestacks/edgestack_update.go +++ b/api/http/handler/edgestacks/edgestack_update.go @@ -7,6 +7,7 @@ import ( "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/set" + "github.com/portainer/portainer/api/stacks/stackutils" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" @@ -61,6 +62,10 @@ func (handler *Handler) edgeStackUpdate(w http.ResponseWriter, r *http.Request) return httperror.BadRequest("Invalid request payload", err) } + if err := stackutils.ValidateEdgeStackComposeContent(r.Context(), payload.DeploymentType, []byte(payload.StackFileContent)); err != nil { + return httperror.BadRequest("Stack file contains a URL blocked by the SSRF policy", err) + } + var stack *portainer.EdgeStack if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { stack, err = handler.updateEdgeStack(tx, portainer.EdgeStackID(stackID), payload) diff --git a/api/http/proxy/factory/agent.go b/api/http/proxy/factory/agent.go index 821b7ada0e..ac6d6892ed 100644 --- a/api/http/proxy/factory/agent.go +++ b/api/http/proxy/factory/agent.go @@ -10,6 +10,7 @@ import ( "github.com/portainer/portainer/api/http/proxy/factory/agent" "github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/url" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/pkg/errors" "github.com/rs/zerolog/log" @@ -40,21 +41,30 @@ func (factory *ProxyFactory) NewAgentProxy(endpoint *portainer.Endpoint) (*Proxy } endpointURL.Scheme = "http" - httpTransport := &http.Transport{} + var innerTransport *http.Transport if endpoint.TLSConfig.TLS || endpoint.TLSConfig.TLSSkipVerify { - config, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig) + tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig) if err != nil { return nil, errors.WithMessage(err, "failed generating tls configuration") } - httpTransport.TLSClientConfig = config endpointURL.Scheme = "https" + + if endpointutils.IsEdgeEndpoint(endpoint) { + innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig}) + } else { + innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}) + } + } else if endpointutils.IsEdgeEndpoint(endpoint) { + innerTransport = ssrf.WrapTransportInternal(&http.Transport{}) + } else { + innerTransport = ssrf.WrapTransport(&http.Transport{}) } proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL) - proxy.Transport = agent.NewTransport(factory.signatureService, httpTransport) + proxy.Transport = agent.NewTransport(factory.signatureService, innerTransport) proxyServer := &ProxyServer{ server: &http.Server{ diff --git a/api/http/proxy/factory/docker.go b/api/http/proxy/factory/docker.go index 0db22aa4c8..f49b093742 100644 --- a/api/http/proxy/factory/docker.go +++ b/api/http/proxy/factory/docker.go @@ -8,8 +8,10 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/http/proxy/factory/docker" + "github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/url" httperror "github.com/portainer/portainer/pkg/libhttp/error" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/rs/zerolog/log" ) @@ -47,17 +49,6 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h } endpointURL.Scheme = "http" - httpTransport := &http.Transport{} - - if endpoint.TLSConfig.TLS || endpoint.TLSConfig.TLSSkipVerify { - config, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig) - if err != nil { - return nil, err - } - - httpTransport.TLSClientConfig = config - endpointURL.Scheme = "https" - } transportParameters := &docker.TransportParameters{ Endpoint: endpoint, @@ -67,7 +58,27 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h DockerClientFactory: factory.dockerClientFactory, } - dockerTransport, err := docker.NewTransport(transportParameters, httpTransport, factory.gitService, factory.snapshotService) + var innerTransport *http.Transport + if endpoint.TLSConfig.TLS || endpoint.TLSConfig.TLSSkipVerify { + tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig) + if err != nil { + return nil, err + } + + endpointURL.Scheme = "https" + + if endpointutils.IsEdgeEndpoint(endpoint) { + innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig}) + } else { + innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}) + } + } else if endpointutils.IsEdgeEndpoint(endpoint) { + innerTransport = ssrf.WrapTransportInternal(&http.Transport{}) + } else { + innerTransport = ssrf.WrapTransport(&http.Transport{}) + } + + dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService) if err != nil { return nil, err } diff --git a/api/http/proxy/factory/github/client.go b/api/http/proxy/factory/github/client.go index e075e12b99..d455d79f27 100644 --- a/api/http/proxy/factory/github/client.go +++ b/api/http/proxy/factory/github/client.go @@ -8,6 +8,8 @@ import ( "strings" "time" + "github.com/portainer/portainer/pkg/libhttp/ssrf" + "github.com/rs/zerolog/log" "github.com/segmentio/encoding/json" "oras.land/oras-go/v2/registry/remote/retry" @@ -92,7 +94,7 @@ func NewHTTPClient(token string) *http.Client { return &http.Client{ Transport: &tokenTransport{ token: token, - transport: retry.NewTransport(&http.Transport{}), // Use ORAS retry transport for consistent rate limiting and error handling + transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling }, Timeout: 1 * time.Minute, } diff --git a/api/http/proxy/factory/gitlab/client.go b/api/http/proxy/factory/gitlab/client.go index 050dab8817..d4deb18f38 100644 --- a/api/http/proxy/factory/gitlab/client.go +++ b/api/http/proxy/factory/gitlab/client.go @@ -8,6 +8,8 @@ import ( "net/http" "time" + "github.com/portainer/portainer/pkg/libhttp/ssrf" + "github.com/rs/zerolog/log" "github.com/segmentio/encoding/json" "oras.land/oras-go/v2/registry/remote/retry" @@ -92,7 +94,7 @@ type Transport struct { // interface for proxying requests to the Gitlab API. func NewTransport() *Transport { return &Transport{ - httpTransport: &http.Transport{}, + httpTransport: ssrf.WrapTransport(&http.Transport{}), } } @@ -117,7 +119,7 @@ func NewHTTPClient(token string) *http.Client { return &http.Client{ Transport: &tokenTransport{ token: token, - transport: retry.NewTransport(&http.Transport{}), // Use ORAS retry transport for consistent rate limiting and error handling + transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling }, Timeout: 1 * time.Minute, } diff --git a/api/http/proxy/factory/kubernetes/agent_transport.go b/api/http/proxy/factory/kubernetes/agent_transport.go index 4a62e23678..c8232b1b88 100644 --- a/api/http/proxy/factory/kubernetes/agent_transport.go +++ b/api/http/proxy/factory/kubernetes/agent_transport.go @@ -8,6 +8,7 @@ import ( "github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/kubernetes/cli" + "github.com/portainer/portainer/pkg/libhttp/ssrf" ) type agentTransport struct { @@ -24,9 +25,9 @@ func NewAgentTransport(signatureService portainer.DigitalSignatureService, token transport := &agentTransport{ baseTransport: newBaseTransport( - &http.Transport{ + ssrf.WrapTransport(&http.Transport{ TLSClientConfig: tlsConfig, - }, + }), tokenManager, endpoint, k8sClientFactory, diff --git a/api/http/proxy/factory/kubernetes/edge_transport.go b/api/http/proxy/factory/kubernetes/edge_transport.go index 73946114ec..cd817eb6b4 100644 --- a/api/http/proxy/factory/kubernetes/edge_transport.go +++ b/api/http/proxy/factory/kubernetes/edge_transport.go @@ -7,6 +7,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/kubernetes/cli" + "github.com/portainer/portainer/pkg/libhttp/ssrf" ) type edgeTransport struct { @@ -21,7 +22,7 @@ func NewEdgeTransport(dataStore dataservices.DataStore, signatureService portain reverseTunnelService: reverseTunnelService, signatureService: signatureService, baseTransport: newBaseTransport( - &http.Transport{}, + ssrf.WrapTransportInternal(&http.Transport{}), tokenManager, endpoint, k8sClientFactory, diff --git a/api/http/proxy/factory/kubernetes/edge_transport_test.go b/api/http/proxy/factory/kubernetes/edge_transport_test.go new file mode 100644 index 0000000000..ddf40d8391 --- /dev/null +++ b/api/http/proxy/factory/kubernetes/edge_transport_test.go @@ -0,0 +1,14 @@ +package kubernetes + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewEdgeTransport(t *testing.T) { + t.Parallel() + + transport := NewEdgeTransport(nil, nil, nil, nil, nil, nil, nil) + require.NotNil(t, transport) +} diff --git a/api/http/proxy/factory/kubernetes/local_transport.go b/api/http/proxy/factory/kubernetes/local_transport.go index bc832f35c3..6c22c5dc5c 100644 --- a/api/http/proxy/factory/kubernetes/local_transport.go +++ b/api/http/proxy/factory/kubernetes/local_transport.go @@ -7,6 +7,7 @@ import ( "github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/kubernetes/cli" + "github.com/portainer/portainer/pkg/libhttp/ssrf" ) type localTransport struct { @@ -22,9 +23,9 @@ func NewLocalTransport(tokenManager *tokenManager, endpoint *portainer.Endpoint, transport := &localTransport{ baseTransport: newBaseTransport( - &http.Transport{ + ssrf.WrapTransportInternal(&http.Transport{ TLSClientConfig: config, - }, + }), tokenManager, endpoint, k8sClientFactory, diff --git a/api/http/proxy/factory/transport_test.go b/api/http/proxy/factory/transport_test.go new file mode 100644 index 0000000000..d6c623a666 --- /dev/null +++ b/api/http/proxy/factory/transport_test.go @@ -0,0 +1,200 @@ +package factory + +import ( + "context" + "net/http/httputil" + "testing" + "time" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/http/proxy/factory/docker" + "github.com/portainer/portainer/pkg/fips" + "github.com/portainer/portainer/pkg/libhttp/ssrf" + + "github.com/stretchr/testify/require" +) + +func init() { + fips.InitFIPS(false) +} + +type stubTunnelService struct{} + +func (s *stubTunnelService) StartTunnelServer(addr, port string, snapshotService portainer.SnapshotService) error { + return nil +} + +func (s *stubTunnelService) StopTunnelServer() error { return nil } + +func (s *stubTunnelService) GenerateEdgeKey(apiURL, tunnelAddr string, endpointIdentifier int) string { + return "" +} + +func (s *stubTunnelService) Open(endpoint *portainer.Endpoint) error { return nil } + +func (s *stubTunnelService) Config(endpointID portainer.EndpointID) portainer.TunnelDetails { + return portainer.TunnelDetails{} +} + +func (s *stubTunnelService) TunnelAddr(endpoint *portainer.Endpoint) (string, error) { + return "127.0.0.1:9999", nil +} + +func (s *stubTunnelService) UpdateLastActivity(endpointID portainer.EndpointID) {} + +func (s *stubTunnelService) KeepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxKeepAlive time.Duration) { +} + +func enableSSRF(t *testing.T) { + t.Helper() + ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) +} + +// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint +// uses WrapTransport, setting DialContext on the inner transport. +func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) { + enableSSRF(t) + + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} + endpoint := &portainer.Endpoint{ + Type: portainer.DockerEnvironment, + URL: "tcp://192.168.1.100:2376", + } + + handler, err := f.newDockerHTTPProxy(endpoint) + require.NoError(t, err) + + proxy := handler.(*httputil.ReverseProxy) + dt := proxy.Transport.(*docker.Transport) + require.NotNil(t, dt.HTTPTransport.DialContext) +} + +// TestNewDockerHTTPProxy_NonEdgeTLS verifies that a TLS non-edge endpoint +// uses WrapTransport, setting DialContext on the inner transport. +func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) { + enableSSRF(t) + + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} + endpoint := &portainer.Endpoint{ + Type: portainer.DockerEnvironment, + URL: "tcp://192.168.1.100:2376", + TLSConfig: portainer.TLSConfiguration{ + TLS: true, + TLSSkipVerify: true, + }, + } + + handler, err := f.newDockerHTTPProxy(endpoint) + require.NoError(t, err) + + proxy := handler.(*httputil.ReverseProxy) + dt := proxy.Transport.(*docker.Transport) + require.NotNil(t, dt.HTTPTransport.DialContext) +} + +// TestNewDockerHTTPProxy_EdgeNoTLS verifies that an edge endpoint without TLS +// uses WrapTransportInternal, leaving DialContext nil. +func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) { + enableSSRF(t) + + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} + endpoint := &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + URL: "tcp://192.168.1.100:2376", + } + + handler, err := f.newDockerHTTPProxy(endpoint) + require.NoError(t, err) + + proxy := handler.(*httputil.ReverseProxy) + dt := proxy.Transport.(*docker.Transport) + require.Nil(t, dt.HTTPTransport.DialContext) +} + +// TestNewDockerHTTPProxy_EdgeTLS verifies that an edge endpoint with TLS +// uses WrapTransportInternal, leaving DialContext nil. +func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) { + enableSSRF(t) + + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} + endpoint := &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + URL: "tcp://192.168.1.100:2376", + TLSConfig: portainer.TLSConfiguration{ + TLS: true, + TLSSkipVerify: true, + }, + } + + handler, err := f.newDockerHTTPProxy(endpoint) + require.NoError(t, err) + + proxy := handler.(*httputil.ReverseProxy) + dt := proxy.Transport.(*docker.Transport) + require.Nil(t, dt.HTTPTransport.DialContext) +} + +func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) { + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} + endpoint := &portainer.Endpoint{ + Type: portainer.AgentOnDockerEnvironment, + URL: "tcp://192.168.1.100:9001", + } + + proxyServer, err := f.NewAgentProxy(endpoint) + require.NoError(t, err) + defer proxyServer.Close() + + require.Positive(t, proxyServer.Port) +} + +func TestNewAgentProxy_NonEdgeTLS(t *testing.T) { + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} + endpoint := &portainer.Endpoint{ + Type: portainer.AgentOnDockerEnvironment, + URL: "tcp://192.168.1.100:9001", + TLSConfig: portainer.TLSConfiguration{ + TLS: true, + TLSSkipVerify: true, + }, + } + + proxyServer, err := f.NewAgentProxy(endpoint) + require.NoError(t, err) + defer proxyServer.Close() + + require.Positive(t, proxyServer.Port) +} + +func TestNewAgentProxy_EdgeNoTLS(t *testing.T) { + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} + endpoint := &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + URL: "tcp://192.168.1.100:9001", + } + + proxyServer, err := f.NewAgentProxy(endpoint) + require.NoError(t, err) + defer proxyServer.Close() + + require.Positive(t, proxyServer.Port) +} + +func TestNewAgentProxy_EdgeTLS(t *testing.T) { + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} + endpoint := &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + URL: "tcp://192.168.1.100:9001", + TLSConfig: portainer.TLSConfiguration{ + TLS: true, + TLSSkipVerify: true, + }, + } + + proxyServer, err := f.NewAgentProxy(endpoint) + require.NoError(t, err) + defer proxyServer.Close() + + require.Positive(t, proxyServer.Port) +} diff --git a/api/portainer.go b/api/portainer.go index 9d0536c4ac..e30f1a7b6a 100644 --- a/api/portainer.go +++ b/api/portainer.go @@ -111,6 +111,8 @@ type ( KubectlShellImageSet bool PullLimitCheckDisabled *bool TrustedOrigins *string + SSRFMode *string + SSRFAllowedHosts *[]string NoSetupToken *bool SetupToken *string } diff --git a/api/stacks/deployments/deployment_compose_config.go b/api/stacks/deployments/deployment_compose_config.go index 85f5e163c5..a34865c600 100644 --- a/api/stacks/deployments/deployment_compose_config.go +++ b/api/stacks/deployments/deployment_compose_config.go @@ -71,6 +71,10 @@ func (config *ComposeStackDeploymentConfig) Deploy(ctx context.Context) error { } } + if err := stackutils.ValidateComposeURLs(ctx, config.stack, config.FileService); err != nil { + return err + } + if stackutils.IsRelativePathStack(config.stack) { return config.StackDeployer.DeployRemoteComposeStack(ctx, config.stack, config.endpoint, config.registries, config.prune, config.forcePullImage, config.ForceCreate) } diff --git a/api/stacks/deployments/deployment_swarm_config.go b/api/stacks/deployments/deployment_swarm_config.go index d99dc83d5d..51d9a3f09b 100644 --- a/api/stacks/deployments/deployment_swarm_config.go +++ b/api/stacks/deployments/deployment_swarm_config.go @@ -71,6 +71,10 @@ func (config *SwarmStackDeploymentConfig) Deploy(ctx context.Context) error { } } + if err := stackutils.ValidateComposeURLs(ctx, config.stack, config.FileService); err != nil { + return err + } + if stackutils.IsRelativePathStack(config.stack) { return config.StackDeployer.DeployRemoteSwarmStack(ctx, config.stack, config.endpoint, config.registries, config.prune, config.pullImage) } diff --git a/api/stacks/stackutils/validation.go b/api/stacks/stackutils/validation.go index 75bb5015d7..9f74ac149b 100644 --- a/api/stacks/stackutils/validation.go +++ b/api/stacks/stackutils/validation.go @@ -2,10 +2,13 @@ package stackutils import ( "context" + "fmt" "path" + "strings" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/filesystem" + "github.com/portainer/portainer/pkg/libhttp/ssrf" composeloader "github.com/compose-spec/compose-go/v2/loader" composetypes "github.com/compose-spec/compose-go/v2/types" @@ -68,6 +71,97 @@ func IsValidStackFile(config StackFileValidationConfig) error { return nil } +// ValidateComposeURLs parses each stack file and checks that every external URL +// (build contexts and image registry hostnames) is permitted by the active SSRF +// policy. It is a no-op when SSRF protection is disabled. +func ValidateComposeURLs(ctx context.Context, stack *portainer.Stack, fileService portainer.FileService) error { + if !ssrf.IsEnabled() { + return nil + } + + env := BuildEnvMap(stack) + workingDir := filesystem.JoinPaths(stack.ProjectPath, path.Dir(stack.EntryPoint)) + + for _, file := range GetStackFilePaths(stack, false) { + stackContent, err := fileService.GetFileContent(stack.ProjectPath, file) + if err != nil { + return errors.Wrap(err, "failed to get stack file content") + } + + if err := checkComposeFileURLs(ctx, stackContent, env, workingDir); err != nil { + return errors.Wrap(err, "stack file contains a URL blocked by the SSRF policy") + } + } + + return nil +} + +// ValidateEdgeStackComposeContent checks that every external URL in an edge +// stack's Compose file is permitted by the active SSRF policy. It is a no-op +// when SSRF protection is disabled or the deployment type is not compose. +func ValidateEdgeStackComposeContent(ctx context.Context, deploymentType portainer.EdgeStackDeploymentType, content []byte) error { + if !ssrf.IsEnabled() || deploymentType != portainer.EdgeStackDeploymentCompose { + return nil + } + + if err := checkComposeFileURLs(ctx, content, nil, ""); err != nil { + return errors.Wrap(err, "stack file contains a URL blocked by the SSRF policy") + } + + return nil +} + +func checkComposeFileURLs(ctx context.Context, content []byte, env map[string]string, workingDir string) error { + composeConfigDetails := composetypes.ConfigDetails{ + ConfigFiles: []composetypes.ConfigFile{{Content: content}}, + Environment: env, + WorkingDir: workingDir, + } + + composeConfig, err := composeloader.LoadWithContext(ctx, composeConfigDetails, composeloader.WithSkipValidation) + if err != nil { + return err + } + + for _, service := range composeConfig.Services { + if service.Build != nil { + buildCtx := service.Build.Context + if strings.HasPrefix(buildCtx, "http://") || strings.HasPrefix(buildCtx, "https://") { + if err := ssrf.CheckURL(ctx, buildCtx); err != nil { + return fmt.Errorf("service %q: build context URL blocked: %w", service.Name, err) + } + } + } + + if service.Image != "" { + if registry := extractImageRegistry(service.Image); registry != "" { + if err := ssrf.CheckURL(ctx, "https://"+registry); err != nil { + return fmt.Errorf("service %q: image registry %q blocked: %w", service.Name, registry, err) + } + } + } + } + + return nil +} + +// extractImageRegistry returns the registry hostname from an OCI image reference, +// or an empty string if the image resolves to Docker Hub (no explicit registry). +func extractImageRegistry(imageRef string) string { + ref, _, _ := strings.Cut(imageRef, "@") + + first, _, hasSlash := strings.Cut(ref, "/") + if !hasSlash { + return "" + } + + if strings.ContainsAny(first, ".:") || first == "localhost" { + return first + } + + return "" +} + func ValidateStackFiles(stack *portainer.Stack, securitySettings *portainer.EndpointSecuritySettings, fileService portainer.FileService) error { env := BuildEnvMap(stack) workingDir := filesystem.JoinPaths(stack.ProjectPath, path.Dir(stack.EntryPoint)) diff --git a/api/stacks/stackutils/validation_test.go b/api/stacks/stackutils/validation_test.go index 2aca9357b6..893f3925b4 100644 --- a/api/stacks/stackutils/validation_test.go +++ b/api/stacks/stackutils/validation_test.go @@ -6,6 +6,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/filesystem" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/stretchr/testify/require" ) @@ -271,3 +272,170 @@ services: err := ValidateStackFiles(stack, securitySettings, fileService) require.ErrorContains(t, err, "bind-mount disabled for non administrator users") } + +func TestExtractImageRegistry(t *testing.T) { + t.Parallel() + + cases := []struct { + image string + expected string + }{ + {"nginx", ""}, + {"nginx:latest", ""}, + {"library/nginx", ""}, + {"ghcr.io/owner/image:tag", "ghcr.io"}, + {"myregistry.com/image:tag", "myregistry.com"}, + {"myregistry.com:5000/image:tag", "myregistry.com:5000"}, + {"localhost/image:tag", "localhost"}, + {"localhost:5000/image:tag", "localhost:5000"}, + {"myregistry.com/image@sha256:abc", "myregistry.com"}, + {"169.254.169.254/image:tag", "169.254.169.254"}, + } + + for _, tc := range cases { + got := extractImageRegistry(tc.image) + require.Equal(t, tc.expected, got, "image: %s", tc.image) + } +} + +func TestValidateComposeURLs_DisabledSSRF(t *testing.T) { + ssrf.Configure(ssrf.Policy{}) + + stack := &portainer.Stack{ + ProjectPath: "/tmp/stack/1", + EntryPoint: "docker-compose.yml", + } + + fileService := mockFileService{ + fileContent: []byte(` +version: "3" +services: + web: + build: + context: http://169.254.169.254/repo.tar.gz +`), + projectVersionPath: "/tmp/stack/1", + } + + err := ValidateComposeURLs(t.Context(), stack, fileService) + require.NoError(t, err) +} + +func TestValidateComposeURLs_BuildContextBlocked(t *testing.T) { + ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + + stack := &portainer.Stack{ + ProjectPath: "/tmp/stack/1", + EntryPoint: "docker-compose.yml", + } + + fileService := mockFileService{ + fileContent: []byte(` +version: "3" +services: + web: + build: + context: http://169.254.169.254/repo.tar.gz + image: nginx +`), + projectVersionPath: "/tmp/stack/1", + } + + err := ValidateComposeURLs(t.Context(), stack, fileService) + require.ErrorContains(t, err, "SSRF policy") +} + +func TestValidateComposeURLs_BuildContextPath(t *testing.T) { + ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + + stack := &portainer.Stack{ + ProjectPath: "/tmp/stack/1", + EntryPoint: "docker-compose.yml", + } + + fileService := mockFileService{ + fileContent: []byte(` +version: "3" +services: + web: + build: + context: ./app + image: nginx +`), + projectVersionPath: "/tmp/stack/1", + } + + err := ValidateComposeURLs(t.Context(), stack, fileService) + require.NoError(t, err) +} + +func TestValidateComposeURLs_ImageRegistryBlocked(t *testing.T) { + ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + + stack := &portainer.Stack{ + ProjectPath: "/tmp/stack/1", + EntryPoint: "docker-compose.yml", + } + + fileService := mockFileService{ + fileContent: []byte(` +version: "3" +services: + web: + image: 169.254.169.254/myimage:latest +`), + projectVersionPath: "/tmp/stack/1", + } + + err := ValidateComposeURLs(t.Context(), stack, fileService) + require.ErrorContains(t, err, "SSRF policy") +} + +func TestValidateComposeURLs_ImageRegistryAllowed(t *testing.T) { + ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"myregistry.com"}}) + t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + + stack := &portainer.Stack{ + ProjectPath: "/tmp/stack/1", + EntryPoint: "docker-compose.yml", + } + + fileService := mockFileService{ + fileContent: []byte(` +version: "3" +services: + web: + image: myregistry.com/myimage:latest +`), + projectVersionPath: "/tmp/stack/1", + } + + err := ValidateComposeURLs(t.Context(), stack, fileService) + require.NoError(t, err) +} + +func TestValidateComposeURLs_DockerHubImageAllowed(t *testing.T) { + ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + + stack := &portainer.Stack{ + ProjectPath: "/tmp/stack/1", + EntryPoint: "docker-compose.yml", + } + + fileService := mockFileService{ + fileContent: []byte(` +version: "3" +services: + web: + image: nginx:latest +`), + projectVersionPath: "/tmp/stack/1", + } + + err := ValidateComposeURLs(t.Context(), stack, fileService) + require.NoError(t, err) +} diff --git a/go.mod b/go.mod index fef9e26c7b..93ad5a24ef 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.67.5 github.com/prometheus/prometheus v0.311.3 + github.com/quasilyte/go-ruleguard/dsl v0.3.23 github.com/robfig/cron/v3 v3.0.1 github.com/rs/zerolog v1.34.0 github.com/samber/slog-zerolog/v2 v2.9.1 diff --git a/go.sum b/go.sum index cb9702aa1e..17ab4f1599 100644 --- a/go.sum +++ b/go.sum @@ -841,6 +841,8 @@ github.com/prometheus/sigv4 v0.4.1 h1:EIc3j+8NBea9u1iV6O5ZAN8uvPq2xOIUPcqCTivHuX github.com/prometheus/sigv4 v0.4.1/go.mod h1:eu+ZbRvsc5TPiHwqh77OWuCnWK73IdkETYY46P4dXOU= github.com/puzpuzpuz/xsync/v4 v4.4.0 h1:vlSN6/CkEY0pY8KaB0yqo/pCLZvp9nhdbBdjipT4gWo= github.com/puzpuzpuz/xsync/v4 v4.4.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo= +github.com/quasilyte/go-ruleguard/dsl v0.3.23 h1:lxjt5B6ZCiBeeNO8/oQsegE6fLeCzuMRoVWSkXC4uvY= +github.com/quasilyte/go-ruleguard/dsl v0.3.23/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU= github.com/redis/go-redis/extra/rediscmd/v9 v9.0.5 h1:EaDatTxkdHG+U3Bk4EUr+DZ7fOGwTfezUiUJMaIcaho= github.com/redis/go-redis/extra/rediscmd/v9 v9.0.5/go.mod h1:fyalQWdtzDBECAQFBJuQe5bzQ02jGd5Qcbgb97Flm7U= github.com/redis/go-redis/extra/redisotel/v9 v9.0.5 h1:EfpWLLCyXw8PSM2/XNJLjI3Pb27yVE+gIAfeqp8LUCc= diff --git a/pkg/libhttp/ssrf/ssrf.go b/pkg/libhttp/ssrf/ssrf.go new file mode 100644 index 0000000000..15a3ca1e18 --- /dev/null +++ b/pkg/libhttp/ssrf/ssrf.go @@ -0,0 +1,252 @@ +package ssrf + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "sync/atomic" + + "github.com/rs/zerolog/log" +) + +// Mode controls how the SSRF policy is applied. +type Mode string + +const ( + // ModeOff disables SSRF protection entirely. All connections pass through unchanged. + ModeOff Mode = "off" + // ModeAudit resolves and checks destinations but only logs violations; connections are allowed. + ModeAudit Mode = "audit" + // ModeEnforce blocks connections that violate the policy. + ModeEnforce Mode = "enforce" +) + +// Policy defines the SSRF protection policy for outbound HTTP connections. +type Policy struct { + // Mode controls whether protection is off, in audit-only mode, or enforcing. + Mode Mode + + // AllowedHosts is the allowlist of permitted destinations. + // Accepted formats: + // - Exact hostname: "example.com" + // - Wildcard hostname: "*.example.com" (matches any subdomain at any depth) + // - Single IP: "1.2.3.4" + // - CIDR range: "10.0.0.0/8" + // + // When Mode is ModeEnforce and AllowedHosts is empty, all outbound connections are blocked. + AllowedHosts []string +} + +type safeDialer struct { + base net.Dialer + mode Mode + allowedNets []*net.IPNet + allowedHosts map[string]bool + allowedWilds []string // derived from "*.foo.com" entries; stored as ".foo.com" +} + +var globalDialer atomic.Pointer[safeDialer] + +// Configure initializes the global SSRF policy. Intended to be called once +// at startup before any outbound HTTP connections are established. +func Configure(policy Policy) { + if policy.Mode == ModeOff || policy.Mode == "" { + globalDialer.Store(nil) + return + } + + globalDialer.Store(newSafeDialer(policy)) +} + +// IsEnabled reports whether SSRF protection is currently active (audit or enforce). +func IsEnabled() bool { + return globalDialer.Load() != nil +} + +// CheckURL validates rawURL against the active SSRF policy without making a +// connection. Returns nil when protection is disabled or the destination is +// permitted. In audit mode, logs a warning on violations and returns nil. +func CheckURL(ctx context.Context, rawURL string) error { + d := globalDialer.Load() + if d == nil { + return nil + } + + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("ssrf: invalid URL %q: %w", rawURL, err) + } + + host := u.Hostname() + if host == "" { + return nil + } + + return d.checkHost(ctx, host) +} + +// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering +// dialer. Returns t unchanged when SSRF protection is not configured. +func WrapTransport(t *http.Transport) *http.Transport { + d := globalDialer.Load() + if d == nil { + return t + } + + cloned := t.Clone() + cloned.DialContext = d.DialContext + + return cloned +} + +// WrapTransportInternal is a documented no-op for transports that connect to +// internally computed destinations (local Docker socket proxy, Chisel tunnels, +// in-cluster Kubernetes API). The destination is chosen by Portainer, not +// supplied by any user, so SSRF validation is not applicable. Using this +// function instead of WrapTransport makes the exemption explicit and +// satisfies the ruleguard lint rule. +func WrapTransportInternal(t *http.Transport) *http.Transport { + return t +} + +func newSafeDialer(policy Policy) *safeDialer { + allowedNets := make([]*net.IPNet, 0, len(policy.AllowedHosts)) + allowedHosts := make(map[string]bool, len(policy.AllowedHosts)) + var allowedWilds []string + + for _, entry := range policy.AllowedHosts { + if _, network, err := net.ParseCIDR(entry); err == nil { + allowedNets = append(allowedNets, network) + continue + } + + if ip := net.ParseIP(entry); ip != nil { + bits := 32 + if ip.To4() == nil { + bits = 128 + } + + mask := net.CIDRMask(bits, bits) + allowedNets = append(allowedNets, &net.IPNet{IP: ip.Mask(mask), Mask: mask}) + + continue + } + + if strings.HasPrefix(entry, "*.") { + allowedWilds = append(allowedWilds, entry[1:]) // "*.foo.com" -> ".foo.com" + continue + } + + allowedHosts[entry] = true + } + + return &safeDialer{ + mode: policy.Mode, + allowedNets: allowedNets, + allowedHosts: allowedHosts, + allowedWilds: allowedWilds, + } +} + +// DialContext resolves addr, validates all resolved IPs against the allowlist policy, +// then dials using the first resolved IP to prevent DNS rebinding attacks. +func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err) + } + + resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("ssrf: resolving %q: %w", host, err) + } + + if len(resolved) == 0 { + return nil, fmt.Errorf("ssrf: no addresses resolved for %q", host) + } + + // Dial by resolved IP regardless of how the host was allowed to close the + // window between DNS validation and the TCP handshake (DNS rebinding). + dialTarget := net.JoinHostPort(resolved[0].IP.String(), port) + + if d.allowedHosts[host] || d.matchesWildcard(host) { + return d.base.DialContext(ctx, network, dialTarget) + } + + for _, a := range resolved { + if !d.ipAllowed(a.IP) { + if d.mode == ModeAudit { + log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)") + continue + } + + return nil, fmt.Errorf("ssrf: destination %s is not in the allowlist", a.IP) + } + } + + return d.base.DialContext(ctx, network, dialTarget) +} + +func (d *safeDialer) checkHost(ctx context.Context, host string) error { + if d.allowedHosts[host] || d.matchesWildcard(host) { + return nil + } + + if ip := net.ParseIP(host); ip != nil { + if !d.ipAllowed(ip) { + if d.mode == ModeAudit { + log.Warn().Str("host", host).Msg("ssrf: destination not in allowlist (audit mode, allowing)") + return nil + } + + return fmt.Errorf("ssrf: destination %s is not in the allowlist", ip) + } + + return nil + } + + resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("ssrf: resolving %q: %w", host, err) + } + + if len(resolved) == 0 { + return fmt.Errorf("ssrf: no addresses resolved for %q", host) + } + + for _, a := range resolved { + if !d.ipAllowed(a.IP) { + if d.mode == ModeAudit { + log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)") + continue + } + + return fmt.Errorf("ssrf: destination %s is not in the allowlist", a.IP) + } + } + + return nil +} + +func (d *safeDialer) matchesWildcard(host string) bool { + for _, suffix := range d.allowedWilds { + if strings.HasSuffix(host, suffix) { + return true + } + } + + return false +} + +func (d *safeDialer) ipAllowed(ip net.IP) bool { + for _, network := range d.allowedNets { + if network.Contains(ip) { + return true + } + } + + return false +} diff --git a/pkg/libhttp/ssrf/ssrf_test.go b/pkg/libhttp/ssrf/ssrf_test.go new file mode 100644 index 0000000000..329d632257 --- /dev/null +++ b/pkg/libhttp/ssrf/ssrf_test.go @@ -0,0 +1,357 @@ +package ssrf + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIpAllowed_CIDR(t *testing.T) { + t.Parallel() + + d := newSafeDialer(Policy{ + Mode: ModeEnforce, + AllowedHosts: []string{"8.8.0.0/16", "2001:4860::/32"}, + }) + + require.True(t, d.ipAllowed(net.ParseIP("8.8.8.8"))) + require.True(t, d.ipAllowed(net.ParseIP("8.8.4.4"))) + require.True(t, d.ipAllowed(net.ParseIP("2001:4860:4860::8888"))) + + require.False(t, d.ipAllowed(net.ParseIP("1.1.1.1"))) + require.False(t, d.ipAllowed(net.ParseIP("127.0.0.1"))) + require.False(t, d.ipAllowed(net.ParseIP("169.254.169.254"))) +} + +func TestIpAllowed_SingleIP(t *testing.T) { + t.Parallel() + + d := newSafeDialer(Policy{ + Mode: ModeEnforce, + AllowedHosts: []string{"1.2.3.4"}, + }) + + require.True(t, d.ipAllowed(net.ParseIP("1.2.3.4"))) + require.False(t, d.ipAllowed(net.ParseIP("1.2.3.5"))) +} + +func TestMatchesWildcard(t *testing.T) { + t.Parallel() + + d := newSafeDialer(Policy{ + Mode: ModeEnforce, + AllowedHosts: []string{"*.example.com", "exact.host.com"}, + }) + + require.True(t, d.matchesWildcard("foo.example.com")) + require.True(t, d.matchesWildcard("bar.example.com")) + require.True(t, d.matchesWildcard("deep.nested.example.com")) + + require.False(t, d.matchesWildcard("example.com")) + require.False(t, d.matchesWildcard("notexample.com")) + require.False(t, d.matchesWildcard("exact.host.com")) +} + +func TestNewSafeDialer_MixedHosts(t *testing.T) { + t.Parallel() + + d := newSafeDialer(Policy{ + Mode: ModeEnforce, + AllowedHosts: []string{"example.com", "*.internal.net", "10.0.0.0/8", "1.2.3.4"}, + }) + + require.True(t, d.allowedHosts["example.com"]) + require.Contains(t, d.allowedWilds, ".internal.net") + require.Len(t, d.allowedNets, 2) // 10.0.0.0/8 and 1.2.3.4/32 +} + +func TestConfigure_Disabled(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) + require.NotNil(t, globalDialer.Load()) + + Configure(Policy{}) + require.Nil(t, globalDialer.Load()) +} + +func TestWrapTransport_NoPolicy(t *testing.T) { + globalDialer.Store(nil) + + base := &http.Transport{} + result := WrapTransport(base) + require.Equal(t, base, result) +} + +func TestWrapTransport_WithPolicy(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + base := &http.Transport{} + result := WrapTransport(base) + require.NotEqual(t, base, result) + require.NotNil(t, result.DialContext) +} + +func TestCheckURL_Disabled(t *testing.T) { + globalDialer.Store(nil) + + err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/") + require.NoError(t, err) +} + +func TestCheckURL_BlocksIPNotInAllowlist(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/") + require.Error(t, err) + require.Contains(t, err.Error(), "ssrf") +} + +func TestCheckURL_AllowedHostname(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "https://example.com/path") + require.NoError(t, err) +} + +func TestCheckURL_AuditMode_ReturnsNil(t *testing.T) { + Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/") + require.NoError(t, err) +} + +// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP +// server on 127.0.0.1, enables SSRF protection with an allowlist that does not +// include loopback, and verifies that the wrapped transport blocks the request. +func TestDialContext_BlocksLoopback(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.8"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + blocked := &http.Client{Transport: WrapTransport(&http.Transport{})} + resp, err := blocked.Get(srv.URL) + require.Error(t, err) + require.Contains(t, err.Error(), "ssrf") + if resp != nil { + require.NoError(t, resp.Body.Close()) + } + + Configure(Policy{}) + + open := &http.Client{Transport: WrapTransport(&http.Transport{})} + resp, err = open.Get(srv.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) +} + +// TestDialContext_AuditMode_AllowsLoopback verifies that audit mode logs the +// violation but still allows the connection through. +func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.8"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + client := &http.Client{Transport: WrapTransport(&http.Transport{})} + resp, err := client.Get(srv.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) +} + +func TestIsEnabled(t *testing.T) { + globalDialer.Store(nil) + require.False(t, IsEnabled()) + + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + require.True(t, IsEnabled()) +} + +func TestWrapTransportInternal(t *testing.T) { + t.Parallel() + + base := &http.Transport{} + result := WrapTransportInternal(base) + require.Equal(t, base, result) +} + +func TestNewSafeDialer_IPv6SingleIP(t *testing.T) { + t.Parallel() + + d := newSafeDialer(Policy{ + Mode: ModeEnforce, + AllowedHosts: []string{"::1"}, + }) + + require.True(t, d.ipAllowed(net.ParseIP("::1"))) + require.False(t, d.ipAllowed(net.ParseIP("::2"))) +} + +func TestCheckURL_InvalidURL(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "http://%gg") + require.Error(t, err) + require.Contains(t, err.Error(), "ssrf") +} + +func TestCheckURL_EmptyHost(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "http://") + require.NoError(t, err) +} + +// TestCheckURL_IPInAllowlist verifies that a literal IP address that falls +// within an allowed CIDR range is permitted. +func TestCheckURL_IPInAllowlist(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "http://8.8.8.8/path") + require.NoError(t, err) +} + +func TestCheckURL_WildcardHostname(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"*.example.com"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "https://api.example.com/path") + require.NoError(t, err) +} + +// TestCheckURL_HostnameDNSResolvesToAllowedIP verifies that a hostname +// resolving to an IP within the allowlist is permitted (DNS resolution path). +func TestCheckURL_HostnameDNSResolvesToAllowedIP(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8", "::1/128"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "http://localhost/path") + require.NoError(t, err) +} + +// TestCheckURL_HostnameDNSResolvesToBlockedIP verifies that a hostname +// resolving to an IP outside the allowlist is blocked (DNS resolution path). +func TestCheckURL_HostnameDNSResolvesToBlockedIP(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "http://localhost/path") + require.Error(t, err) + require.Contains(t, err.Error(), "ssrf") +} + +// TestCheckURL_HostnameDNSAuditMode verifies that audit mode logs violations +// from hostname DNS resolution but still returns nil. +func TestCheckURL_HostnameDNSAuditMode(t *testing.T) { + Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err := CheckURL(t.Context(), "http://localhost/path") + require.NoError(t, err) +} + +// TestCheckURL_HostnameDNSError verifies that a DNS resolution failure is +// propagated as an SSRF-prefixed error. +func TestCheckURL_HostnameDNSError(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + err := CheckURL(ctx, "http://portainer-nonexistent.test.invalid/path") + require.Error(t, err) +} + +// TestDialContext_InvalidAddress verifies that an address without a port +// returns an SSRF-prefixed error. +func TestDialContext_InvalidAddress(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + d := globalDialer.Load() + _, err := d.DialContext(t.Context(), "tcp", "no-port-here") + require.Error(t, err) + require.Contains(t, err.Error(), "ssrf") +} + +// TestDialContext_DNSError verifies that a DNS resolution failure in +// DialContext is propagated as an SSRF-prefixed error. +func TestDialContext_DNSError(t *testing.T) { + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + d := globalDialer.Load() + _, err := d.DialContext(ctx, "tcp", "portainer-nonexistent.test.invalid:80") + require.Error(t, err) +} + +// TestDialContext_AllowedByCIDR is an end-to-end test verifying that +// connections to IPs within an allowed CIDR range are permitted. +func TestDialContext_AllowedByCIDR(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + client := &http.Client{Transport: WrapTransport(&http.Transport{})} + resp, err := client.Get(srv.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) +} + +// TestDialContext_AllowedByExactHostname verifies that when a hostname is in +// the allowed-hosts list, the connection is permitted even though the resolved +// IP is not covered by any CIDR entry. +// +// The server is bound to whatever IP "localhost" resolves to first so that the +// dialTarget computed by DialContext (resolved[0]) matches the listening address. +func TestDialContext_AllowedByExactHostname(t *testing.T) { + addrs, err := net.DefaultResolver.LookupIPAddr(t.Context(), "localhost") + require.NoError(t, err) + require.NotEmpty(t, addrs, "localhost must resolve to at least one address") + + l, err := net.Listen("tcp", net.JoinHostPort(addrs[0].IP.String(), "0")) + require.NoError(t, err) + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + srv.Listener = l + srv.Start() + defer srv.Close() + + _, portStr, err := net.SplitHostPort(l.Addr().String()) + require.NoError(t, err) + + Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"localhost"}}) + t.Cleanup(func() { globalDialer.Store(nil) }) + + client := &http.Client{Transport: WrapTransport(&http.Transport{})} + resp, err := client.Get("http://localhost:" + portStr) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) +} diff --git a/pkg/registryhttp/client.go b/pkg/registryhttp/client.go index 519d2fd303..cedab57b11 100644 --- a/pkg/registryhttp/client.go +++ b/pkg/registryhttp/client.go @@ -4,6 +4,7 @@ import ( "net/http" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/rs/zerolog/log" "oras.land/oras-go/v2/registry/remote/retry" @@ -15,8 +16,8 @@ import ( func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) { switch registry.Type { case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry: - // Cloud registries use the default retry client with built-in TLS - return retry.DefaultClient, false, nil + base := http.DefaultTransport.(*http.Transport).Clone() + return &http.Client{Transport: retry.NewTransport(ssrf.WrapTransport(base))}, false, nil default: // For all other registry types, use shared helper to build transport and scheme diff --git a/pkg/registryhttp/client_test.go b/pkg/registryhttp/client_test.go index b6d494c9f2..abe7806580 100644 --- a/pkg/registryhttp/client_test.go +++ b/pkg/registryhttp/client_test.go @@ -11,6 +11,11 @@ import ( "oras.land/oras-go/v2/registry/remote/retry" ) +func isRetryTransport(t *http.Client) bool { + _, ok := t.Transport.(*retry.Transport) + return ok +} + func init() { fips.InitFIPS(false) } @@ -102,8 +107,7 @@ func TestCreateClient(t *testing.T) { // Verify client type based on registry configuration switch tt.registry.Type { case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry: - // Cloud registries should use the default retry client - assert.Equal(t, retry.DefaultClient, client) + assert.True(t, isRetryTransport(client), "Cloud registries should use a retry transport") } }) } @@ -133,7 +137,7 @@ func TestCreateClient_CloudRegistries(t *testing.T) { require.NoError(t, err) assert.NotNil(t, client) assert.False(t, usePlainHTTP, "Cloud registries should use HTTPS") - assert.Equal(t, retry.DefaultClient, client, "Cloud registries should use default retry client") + assert.True(t, isRetryTransport(client), "Cloud registries should use a retry transport") }) } }