feat(ssrf): implement an SSRF protection mechanism BE-13021 (#2818)

This commit is contained in:
andres-portainer
2026-06-09 00:41:42 -03:00
committed by GitHub
parent d34ee82754
commit 1765e41fd4
31 changed files with 1259 additions and 43 deletions
+13
View File
@@ -5,6 +5,7 @@ run:
linters:
default: none
enable:
- gocritic
- bodyclose
- copyloopvar
- depguard
@@ -76,6 +77,13 @@ linters:
desc: use github.com/Masterminds/semver/v3
- pkg: github.com/hashicorp/go-version
desc: use github.com/Masterminds/semver/v3
gocritic:
disable-all: true
enabled-checks:
- ruleguard
settings:
ruleguard:
rules: "./analysis/ssrf.go"
forbidigo:
forbid:
- pattern: ^tls\.Config$
@@ -93,6 +101,11 @@ linters:
- comments
- common-false-positives
- legacy
rules:
- path: pkg/libhttp/ssrf
linters:
- gocritic
text: ruleguard
paths:
- third_party$
- builtin$
+29
View File
@@ -0,0 +1,29 @@
//go:build ignore
package gorules
import "github.com/quasilyte/go-ruleguard/dsl"
// unwrappedHTTPTransport flags http.Transport composite literals that are not
// the direct argument to ssrf.WrapTransport.
func unwrappedHTTPTransport(m dsl.Matcher) {
// Inline construction passed to a function call.
m.Match(`$f(&http.Transport{$*_})`).
Where(m["f"].Text != "ssrf.WrapTransport" && m["f"].Text != "WrapTransport" &&
m["f"].Text != "ssrf.WrapTransportInternal" && m["f"].Text != "WrapTransportInternal").
Report(`$f receives a bare *http.Transport; wrap with ssrf.WrapTransport() to enforce the SSRF protection policy`)
// Variable assigned a bare transport (cannot be tracked to a later WrapTransport call).
m.Match(`$_ := &http.Transport{$*_}`).
Report(`bare *http.Transport variable; use ssrf.WrapTransport(&http.Transport{...}) inline instead`)
}
// internalTransportMisuse flags calls to WrapTransportInternal outside the four proxy
// factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions.
func internalTransportMisuse(m dsl.Matcher) {
m.Match(`ssrf.WrapTransportInternal($*_)`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport)\.go$`))).
Report(`WrapTransportInternal bypasses SSRF validation; only valid in the kubernetes local/edge transport constructors and the docker/agent proxy factories`)
}
+5
View File
@@ -0,0 +1,5 @@
//go:build tools
package gorules
import _ "github.com/quasilyte/go-ruleguard/dsl"
+2
View File
@@ -56,6 +56,8 @@ func CLIFlags() *portainer.CLIFlags {
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(),
CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(),
SSRFMode: kingpin.Flag("ssrf-mode", "SSRF protection mode: off (disabled), audit (log violations but allow), enforce (block violations)").Envar("PORTAINER_SSRF_MODE").Default("off").Enum("off", "audit", "enforce"),
SSRFAllowedHosts: kingpin.Flag("ssrf-allowed-hosts", "Allowlist of hostnames (with optional wildcards), IPs, or CIDRs permitted for outbound requests. When empty and mode is enforce, all outbound connections are blocked").Envar("PORTAINER_SSRF_ALLOWED_HOSTS").Strings(),
NoSetupToken: kingpin.Flag("no-setup-token", "Disable the setup token requirement for admin initialization and restore on an uninitialized instance").Envar(portainer.NoSetupTokenEnvVar).Bool(),
SetupToken: kingpin.Flag("setup-token", "Set a custom setup token for admin initialization and restore on an uninitialized instance (overrides auto-generation)").Envar(portainer.SetupTokenEnvVar).String(),
}
+16
View File
@@ -4,6 +4,7 @@ import (
"cmp"
"context"
"crypto/sha256"
nethttp "net/http"
"os"
"path"
"strings"
@@ -52,10 +53,12 @@ import (
"github.com/portainer/portainer/pkg/featureflags"
"github.com/portainer/portainer/pkg/fips"
"github.com/portainer/portainer/pkg/libhelm"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/libstack/compose"
libswarm "github.com/portainer/portainer/pkg/libstack/swarm"
"github.com/portainer/portainer/pkg/validate"
gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
@@ -384,6 +387,19 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
// -ce can not ever be run in FIPS mode
fips.InitFIPS(false)
ssrf.Configure(ssrf.Policy{
Mode: ssrf.Mode(*flags.SSRFMode),
AllowedHosts: *flags.SSRFAllowedHosts,
})
if ssrf.IsEnabled() {
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok {
nethttp.DefaultTransport = ssrf.WrapTransport(dt)
}
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
}
fileService := initFileService(*flags.Data)
encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName))
if encryptionKey == nil {
+10 -5
View File
@@ -11,6 +11,8 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log"
"github.com/docker/docker/api/types/image"
@@ -184,17 +186,20 @@ func (t *NodeNameTransport) RoundTrip(req *http.Request) (*http.Response, error)
}
func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Client, error) {
transport := &NodeNameTransport{
Transport: &http.Transport{},
}
var transport *NodeNameTransport
if endpoint.TLSConfig.TLS {
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
if err != nil {
return nil, err
}
transport.TLSClientConfig = tlsConfig
transport = &NodeNameTransport{
Transport: ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}),
}
} else {
transport = &NodeNameTransport{
Transport: ssrf.WrapTransport(&http.Transport{}),
}
}
clientTimeout := defaultDockerRequestTimeout
+5 -6
View File
@@ -14,12 +14,13 @@ import (
"github.com/portainer/portainer/api/crypto"
gittypes "github.com/portainer/portainer/api/git/types"
"github.com/portainer/portainer/api/logs"
"github.com/rs/zerolog/log"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing/filemode"
githttp "github.com/go-git/go-git/v5/plumbing/transport/http"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json"
)
@@ -64,15 +65,13 @@ func NewAzureClient() *azureClient {
}
func newHttpClientForAzure(insecureSkipVerify bool) *http.Client {
httpsCli := &http.Client{
Transport: &http.Transport{
return &http.Client{
Transport: ssrf.WrapTransport(&http.Transport{
TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify),
Proxy: http.ProxyFromEnvironment,
},
}),
Timeout: 300 * time.Second,
}
return httpsCli
}
func (a *azureClient) Download(ctx context.Context, destination string, opt *git.CloneOptions) error {
+5 -3
View File
@@ -11,6 +11,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json"
@@ -114,18 +115,19 @@ func Get(url string, timeout int) ([]byte, error) {
// using the specified host and optional TLS configuration.
// It uses a new Http.Client for each operation.
func ExecutePingOperation(host string, tlsConfiguration portainer.TLSConfiguration) (bool, error) {
transport := &http.Transport{}
scheme := "http"
var transport *http.Transport
if tlsConfiguration.TLS {
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsConfiguration)
if err != nil {
return false, err
}
transport.TLSClientConfig = tlsConfig
scheme = "https"
transport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
} else {
transport = ssrf.WrapTransport(&http.Transport{})
}
client := &http.Client{
@@ -115,6 +115,11 @@ func (handler *Handler) createEdgeStackFromFileUpload(r *http.Request, tx datase
if dryrun {
return stack, nil
}
if err := stackutils.ValidateEdgeStackComposeContent(r.Context(), payload.DeploymentType, payload.StackFileContent); err != nil {
return nil, httperrors.NewInvalidPayloadError(err.Error())
}
stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID)
stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username)
@@ -93,6 +93,10 @@ func (handler *Handler) createEdgeStackFromFileContent(r *http.Request, tx datas
return stack, nil
}
if err := stackutils.ValidateEdgeStackComposeContent(r.Context(), payload.DeploymentType, []byte(payload.StackFileContent)); err != nil {
return nil, httperrors.NewInvalidPayloadError(err.Error())
}
return handler.edgeStacksService.PersistEdgeStack(tx, stack, func(stackFolder string, relatedEndpointIds []portainer.EndpointID) (composePath string, manifestPath string, projectPath string, err error) {
return handler.storeFileContent(tx, stackFolder, payload.DeploymentType, relatedEndpointIds, []byte(payload.StackFileContent))
})
@@ -7,6 +7,7 @@ import (
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/set"
"github.com/portainer/portainer/api/stacks/stackutils"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
@@ -61,6 +62,10 @@ func (handler *Handler) edgeStackUpdate(w http.ResponseWriter, r *http.Request)
return httperror.BadRequest("Invalid request payload", err)
}
if err := stackutils.ValidateEdgeStackComposeContent(r.Context(), payload.DeploymentType, []byte(payload.StackFileContent)); err != nil {
return httperror.BadRequest("Stack file contains a URL blocked by the SSRF policy", err)
}
var stack *portainer.EdgeStack
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
stack, err = handler.updateEdgeStack(tx, portainer.EdgeStackID(stackID), payload)
+14 -4
View File
@@ -10,6 +10,7 @@ import (
"github.com/portainer/portainer/api/http/proxy/factory/agent"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/url"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
@@ -40,21 +41,30 @@ func (factory *ProxyFactory) NewAgentProxy(endpoint *portainer.Endpoint) (*Proxy
}
endpointURL.Scheme = "http"
httpTransport := &http.Transport{}
var innerTransport *http.Transport
if endpoint.TLSConfig.TLS || endpoint.TLSConfig.TLSSkipVerify {
config, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
if err != nil {
return nil, errors.WithMessage(err, "failed generating tls configuration")
}
httpTransport.TLSClientConfig = config
endpointURL.Scheme = "https"
if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
} else {
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
}
} else if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
} else {
innerTransport = ssrf.WrapTransport(&http.Transport{})
}
proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL)
proxy.Transport = agent.NewTransport(factory.signatureService, httpTransport)
proxy.Transport = agent.NewTransport(factory.signatureService, innerTransport)
proxyServer := &ProxyServer{
server: &http.Server{
+23 -12
View File
@@ -8,8 +8,10 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/http/proxy/factory/docker"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/url"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log"
)
@@ -47,17 +49,6 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h
}
endpointURL.Scheme = "http"
httpTransport := &http.Transport{}
if endpoint.TLSConfig.TLS || endpoint.TLSConfig.TLSSkipVerify {
config, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
if err != nil {
return nil, err
}
httpTransport.TLSClientConfig = config
endpointURL.Scheme = "https"
}
transportParameters := &docker.TransportParameters{
Endpoint: endpoint,
@@ -67,7 +58,27 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h
DockerClientFactory: factory.dockerClientFactory,
}
dockerTransport, err := docker.NewTransport(transportParameters, httpTransport, factory.gitService, factory.snapshotService)
var innerTransport *http.Transport
if endpoint.TLSConfig.TLS || endpoint.TLSConfig.TLSSkipVerify {
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
if err != nil {
return nil, err
}
endpointURL.Scheme = "https"
if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
} else {
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
}
} else if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
} else {
innerTransport = ssrf.WrapTransport(&http.Transport{})
}
dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService)
if err != nil {
return nil, err
}
+3 -1
View File
@@ -8,6 +8,8 @@ import (
"strings"
"time"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json"
"oras.land/oras-go/v2/registry/remote/retry"
@@ -92,7 +94,7 @@ func NewHTTPClient(token string) *http.Client {
return &http.Client{
Transport: &tokenTransport{
token: token,
transport: retry.NewTransport(&http.Transport{}), // Use ORAS retry transport for consistent rate limiting and error handling
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
},
Timeout: 1 * time.Minute,
}
+4 -2
View File
@@ -8,6 +8,8 @@ import (
"net/http"
"time"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json"
"oras.land/oras-go/v2/registry/remote/retry"
@@ -92,7 +94,7 @@ type Transport struct {
// interface for proxying requests to the Gitlab API.
func NewTransport() *Transport {
return &Transport{
httpTransport: &http.Transport{},
httpTransport: ssrf.WrapTransport(&http.Transport{}),
}
}
@@ -117,7 +119,7 @@ func NewHTTPClient(token string) *http.Client {
return &http.Client{
Transport: &tokenTransport{
token: token,
transport: retry.NewTransport(&http.Transport{}), // Use ORAS retry transport for consistent rate limiting and error handling
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
},
Timeout: 1 * time.Minute,
}
@@ -8,6 +8,7 @@ import (
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/kubernetes/cli"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
)
type agentTransport struct {
@@ -24,9 +25,9 @@ func NewAgentTransport(signatureService portainer.DigitalSignatureService, token
transport := &agentTransport{
baseTransport: newBaseTransport(
&http.Transport{
ssrf.WrapTransport(&http.Transport{
TLSClientConfig: tlsConfig,
},
}),
tokenManager,
endpoint,
k8sClientFactory,
@@ -7,6 +7,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/kubernetes/cli"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
)
type edgeTransport struct {
@@ -21,7 +22,7 @@ func NewEdgeTransport(dataStore dataservices.DataStore, signatureService portain
reverseTunnelService: reverseTunnelService,
signatureService: signatureService,
baseTransport: newBaseTransport(
&http.Transport{},
ssrf.WrapTransportInternal(&http.Transport{}),
tokenManager,
endpoint,
k8sClientFactory,
@@ -0,0 +1,14 @@
package kubernetes
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNewEdgeTransport(t *testing.T) {
t.Parallel()
transport := NewEdgeTransport(nil, nil, nil, nil, nil, nil, nil)
require.NotNil(t, transport)
}
@@ -7,6 +7,7 @@ import (
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/kubernetes/cli"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
)
type localTransport struct {
@@ -22,9 +23,9 @@ func NewLocalTransport(tokenManager *tokenManager, endpoint *portainer.Endpoint,
transport := &localTransport{
baseTransport: newBaseTransport(
&http.Transport{
ssrf.WrapTransportInternal(&http.Transport{
TLSClientConfig: config,
},
}),
tokenManager,
endpoint,
k8sClientFactory,
+200
View File
@@ -0,0 +1,200 @@
package factory
import (
"context"
"net/http/httputil"
"testing"
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/proxy/factory/docker"
"github.com/portainer/portainer/pkg/fips"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/stretchr/testify/require"
)
func init() {
fips.InitFIPS(false)
}
type stubTunnelService struct{}
func (s *stubTunnelService) StartTunnelServer(addr, port string, snapshotService portainer.SnapshotService) error {
return nil
}
func (s *stubTunnelService) StopTunnelServer() error { return nil }
func (s *stubTunnelService) GenerateEdgeKey(apiURL, tunnelAddr string, endpointIdentifier int) string {
return ""
}
func (s *stubTunnelService) Open(endpoint *portainer.Endpoint) error { return nil }
func (s *stubTunnelService) Config(endpointID portainer.EndpointID) portainer.TunnelDetails {
return portainer.TunnelDetails{}
}
func (s *stubTunnelService) TunnelAddr(endpoint *portainer.Endpoint) (string, error) {
return "127.0.0.1:9999", nil
}
func (s *stubTunnelService) UpdateLastActivity(endpointID portainer.EndpointID) {}
func (s *stubTunnelService) KeepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxKeepAlive time.Duration) {
}
func enableSSRF(t *testing.T) {
t.Helper()
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
}
// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint
// uses WrapTransport, setting DialContext on the inner transport.
func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
enableSSRF(t)
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.DockerEnvironment,
URL: "tcp://192.168.1.100:2376",
}
handler, err := f.newDockerHTTPProxy(endpoint)
require.NoError(t, err)
proxy := handler.(*httputil.ReverseProxy)
dt := proxy.Transport.(*docker.Transport)
require.NotNil(t, dt.HTTPTransport.DialContext)
}
// TestNewDockerHTTPProxy_NonEdgeTLS verifies that a TLS non-edge endpoint
// uses WrapTransport, setting DialContext on the inner transport.
func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
enableSSRF(t)
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.DockerEnvironment,
URL: "tcp://192.168.1.100:2376",
TLSConfig: portainer.TLSConfiguration{
TLS: true,
TLSSkipVerify: true,
},
}
handler, err := f.newDockerHTTPProxy(endpoint)
require.NoError(t, err)
proxy := handler.(*httputil.ReverseProxy)
dt := proxy.Transport.(*docker.Transport)
require.NotNil(t, dt.HTTPTransport.DialContext)
}
// TestNewDockerHTTPProxy_EdgeNoTLS verifies that an edge endpoint without TLS
// uses WrapTransportInternal, leaving DialContext nil.
func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
enableSSRF(t)
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.EdgeAgentOnDockerEnvironment,
URL: "tcp://192.168.1.100:2376",
}
handler, err := f.newDockerHTTPProxy(endpoint)
require.NoError(t, err)
proxy := handler.(*httputil.ReverseProxy)
dt := proxy.Transport.(*docker.Transport)
require.Nil(t, dt.HTTPTransport.DialContext)
}
// TestNewDockerHTTPProxy_EdgeTLS verifies that an edge endpoint with TLS
// uses WrapTransportInternal, leaving DialContext nil.
func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
enableSSRF(t)
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.EdgeAgentOnDockerEnvironment,
URL: "tcp://192.168.1.100:2376",
TLSConfig: portainer.TLSConfiguration{
TLS: true,
TLSSkipVerify: true,
},
}
handler, err := f.newDockerHTTPProxy(endpoint)
require.NoError(t, err)
proxy := handler.(*httputil.ReverseProxy)
dt := proxy.Transport.(*docker.Transport)
require.Nil(t, dt.HTTPTransport.DialContext)
}
func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) {
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.AgentOnDockerEnvironment,
URL: "tcp://192.168.1.100:9001",
}
proxyServer, err := f.NewAgentProxy(endpoint)
require.NoError(t, err)
defer proxyServer.Close()
require.Positive(t, proxyServer.Port)
}
func TestNewAgentProxy_NonEdgeTLS(t *testing.T) {
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.AgentOnDockerEnvironment,
URL: "tcp://192.168.1.100:9001",
TLSConfig: portainer.TLSConfiguration{
TLS: true,
TLSSkipVerify: true,
},
}
proxyServer, err := f.NewAgentProxy(endpoint)
require.NoError(t, err)
defer proxyServer.Close()
require.Positive(t, proxyServer.Port)
}
func TestNewAgentProxy_EdgeNoTLS(t *testing.T) {
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.EdgeAgentOnDockerEnvironment,
URL: "tcp://192.168.1.100:9001",
}
proxyServer, err := f.NewAgentProxy(endpoint)
require.NoError(t, err)
defer proxyServer.Close()
require.Positive(t, proxyServer.Port)
}
func TestNewAgentProxy_EdgeTLS(t *testing.T) {
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.EdgeAgentOnDockerEnvironment,
URL: "tcp://192.168.1.100:9001",
TLSConfig: portainer.TLSConfiguration{
TLS: true,
TLSSkipVerify: true,
},
}
proxyServer, err := f.NewAgentProxy(endpoint)
require.NoError(t, err)
defer proxyServer.Close()
require.Positive(t, proxyServer.Port)
}
+2
View File
@@ -111,6 +111,8 @@ type (
KubectlShellImageSet bool
PullLimitCheckDisabled *bool
TrustedOrigins *string
SSRFMode *string
SSRFAllowedHosts *[]string
NoSetupToken *bool
SetupToken *string
}
@@ -71,6 +71,10 @@ func (config *ComposeStackDeploymentConfig) Deploy(ctx context.Context) error {
}
}
if err := stackutils.ValidateComposeURLs(ctx, config.stack, config.FileService); err != nil {
return err
}
if stackutils.IsRelativePathStack(config.stack) {
return config.StackDeployer.DeployRemoteComposeStack(ctx, config.stack, config.endpoint, config.registries, config.prune, config.forcePullImage, config.ForceCreate)
}
@@ -71,6 +71,10 @@ func (config *SwarmStackDeploymentConfig) Deploy(ctx context.Context) error {
}
}
if err := stackutils.ValidateComposeURLs(ctx, config.stack, config.FileService); err != nil {
return err
}
if stackutils.IsRelativePathStack(config.stack) {
return config.StackDeployer.DeployRemoteSwarmStack(ctx, config.stack, config.endpoint, config.registries, config.prune, config.pullImage)
}
+94
View File
@@ -2,10 +2,13 @@ package stackutils
import (
"context"
"fmt"
"path"
"strings"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/filesystem"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
composeloader "github.com/compose-spec/compose-go/v2/loader"
composetypes "github.com/compose-spec/compose-go/v2/types"
@@ -68,6 +71,97 @@ func IsValidStackFile(config StackFileValidationConfig) error {
return nil
}
// ValidateComposeURLs parses each stack file and checks that every external URL
// (build contexts and image registry hostnames) is permitted by the active SSRF
// policy. It is a no-op when SSRF protection is disabled.
func ValidateComposeURLs(ctx context.Context, stack *portainer.Stack, fileService portainer.FileService) error {
if !ssrf.IsEnabled() {
return nil
}
env := BuildEnvMap(stack)
workingDir := filesystem.JoinPaths(stack.ProjectPath, path.Dir(stack.EntryPoint))
for _, file := range GetStackFilePaths(stack, false) {
stackContent, err := fileService.GetFileContent(stack.ProjectPath, file)
if err != nil {
return errors.Wrap(err, "failed to get stack file content")
}
if err := checkComposeFileURLs(ctx, stackContent, env, workingDir); err != nil {
return errors.Wrap(err, "stack file contains a URL blocked by the SSRF policy")
}
}
return nil
}
// ValidateEdgeStackComposeContent checks that every external URL in an edge
// stack's Compose file is permitted by the active SSRF policy. It is a no-op
// when SSRF protection is disabled or the deployment type is not compose.
func ValidateEdgeStackComposeContent(ctx context.Context, deploymentType portainer.EdgeStackDeploymentType, content []byte) error {
if !ssrf.IsEnabled() || deploymentType != portainer.EdgeStackDeploymentCompose {
return nil
}
if err := checkComposeFileURLs(ctx, content, nil, ""); err != nil {
return errors.Wrap(err, "stack file contains a URL blocked by the SSRF policy")
}
return nil
}
func checkComposeFileURLs(ctx context.Context, content []byte, env map[string]string, workingDir string) error {
composeConfigDetails := composetypes.ConfigDetails{
ConfigFiles: []composetypes.ConfigFile{{Content: content}},
Environment: env,
WorkingDir: workingDir,
}
composeConfig, err := composeloader.LoadWithContext(ctx, composeConfigDetails, composeloader.WithSkipValidation)
if err != nil {
return err
}
for _, service := range composeConfig.Services {
if service.Build != nil {
buildCtx := service.Build.Context
if strings.HasPrefix(buildCtx, "http://") || strings.HasPrefix(buildCtx, "https://") {
if err := ssrf.CheckURL(ctx, buildCtx); err != nil {
return fmt.Errorf("service %q: build context URL blocked: %w", service.Name, err)
}
}
}
if service.Image != "" {
if registry := extractImageRegistry(service.Image); registry != "" {
if err := ssrf.CheckURL(ctx, "https://"+registry); err != nil {
return fmt.Errorf("service %q: image registry %q blocked: %w", service.Name, registry, err)
}
}
}
}
return nil
}
// extractImageRegistry returns the registry hostname from an OCI image reference,
// or an empty string if the image resolves to Docker Hub (no explicit registry).
func extractImageRegistry(imageRef string) string {
ref, _, _ := strings.Cut(imageRef, "@")
first, _, hasSlash := strings.Cut(ref, "/")
if !hasSlash {
return ""
}
if strings.ContainsAny(first, ".:") || first == "localhost" {
return first
}
return ""
}
func ValidateStackFiles(stack *portainer.Stack, securitySettings *portainer.EndpointSecuritySettings, fileService portainer.FileService) error {
env := BuildEnvMap(stack)
workingDir := filesystem.JoinPaths(stack.ProjectPath, path.Dir(stack.EntryPoint))
+168
View File
@@ -6,6 +6,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/filesystem"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/stretchr/testify/require"
)
@@ -271,3 +272,170 @@ services:
err := ValidateStackFiles(stack, securitySettings, fileService)
require.ErrorContains(t, err, "bind-mount disabled for non administrator users")
}
func TestExtractImageRegistry(t *testing.T) {
t.Parallel()
cases := []struct {
image string
expected string
}{
{"nginx", ""},
{"nginx:latest", ""},
{"library/nginx", ""},
{"ghcr.io/owner/image:tag", "ghcr.io"},
{"myregistry.com/image:tag", "myregistry.com"},
{"myregistry.com:5000/image:tag", "myregistry.com:5000"},
{"localhost/image:tag", "localhost"},
{"localhost:5000/image:tag", "localhost:5000"},
{"myregistry.com/image@sha256:abc", "myregistry.com"},
{"169.254.169.254/image:tag", "169.254.169.254"},
}
for _, tc := range cases {
got := extractImageRegistry(tc.image)
require.Equal(t, tc.expected, got, "image: %s", tc.image)
}
}
func TestValidateComposeURLs_DisabledSSRF(t *testing.T) {
ssrf.Configure(ssrf.Policy{})
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
EntryPoint: "docker-compose.yml",
}
fileService := mockFileService{
fileContent: []byte(`
version: "3"
services:
web:
build:
context: http://169.254.169.254/repo.tar.gz
`),
projectVersionPath: "/tmp/stack/1",
}
err := ValidateComposeURLs(t.Context(), stack, fileService)
require.NoError(t, err)
}
func TestValidateComposeURLs_BuildContextBlocked(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
EntryPoint: "docker-compose.yml",
}
fileService := mockFileService{
fileContent: []byte(`
version: "3"
services:
web:
build:
context: http://169.254.169.254/repo.tar.gz
image: nginx
`),
projectVersionPath: "/tmp/stack/1",
}
err := ValidateComposeURLs(t.Context(), stack, fileService)
require.ErrorContains(t, err, "SSRF policy")
}
func TestValidateComposeURLs_BuildContextPath(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
EntryPoint: "docker-compose.yml",
}
fileService := mockFileService{
fileContent: []byte(`
version: "3"
services:
web:
build:
context: ./app
image: nginx
`),
projectVersionPath: "/tmp/stack/1",
}
err := ValidateComposeURLs(t.Context(), stack, fileService)
require.NoError(t, err)
}
func TestValidateComposeURLs_ImageRegistryBlocked(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
EntryPoint: "docker-compose.yml",
}
fileService := mockFileService{
fileContent: []byte(`
version: "3"
services:
web:
image: 169.254.169.254/myimage:latest
`),
projectVersionPath: "/tmp/stack/1",
}
err := ValidateComposeURLs(t.Context(), stack, fileService)
require.ErrorContains(t, err, "SSRF policy")
}
func TestValidateComposeURLs_ImageRegistryAllowed(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"myregistry.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
EntryPoint: "docker-compose.yml",
}
fileService := mockFileService{
fileContent: []byte(`
version: "3"
services:
web:
image: myregistry.com/myimage:latest
`),
projectVersionPath: "/tmp/stack/1",
}
err := ValidateComposeURLs(t.Context(), stack, fileService)
require.NoError(t, err)
}
func TestValidateComposeURLs_DockerHubImageAllowed(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
EntryPoint: "docker-compose.yml",
}
fileService := mockFileService{
fileContent: []byte(`
version: "3"
services:
web:
image: nginx:latest
`),
projectVersionPath: "/tmp/stack/1",
}
err := ValidateComposeURLs(t.Context(), stack, fileService)
require.NoError(t, err)
}
+1
View File
@@ -48,6 +48,7 @@ require (
github.com/prometheus/client_model v0.6.2
github.com/prometheus/common v0.67.5
github.com/prometheus/prometheus v0.311.3
github.com/quasilyte/go-ruleguard/dsl v0.3.23
github.com/robfig/cron/v3 v3.0.1
github.com/rs/zerolog v1.34.0
github.com/samber/slog-zerolog/v2 v2.9.1
+2
View File
@@ -841,6 +841,8 @@ github.com/prometheus/sigv4 v0.4.1 h1:EIc3j+8NBea9u1iV6O5ZAN8uvPq2xOIUPcqCTivHuX
github.com/prometheus/sigv4 v0.4.1/go.mod h1:eu+ZbRvsc5TPiHwqh77OWuCnWK73IdkETYY46P4dXOU=
github.com/puzpuzpuz/xsync/v4 v4.4.0 h1:vlSN6/CkEY0pY8KaB0yqo/pCLZvp9nhdbBdjipT4gWo=
github.com/puzpuzpuz/xsync/v4 v4.4.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo=
github.com/quasilyte/go-ruleguard/dsl v0.3.23 h1:lxjt5B6ZCiBeeNO8/oQsegE6fLeCzuMRoVWSkXC4uvY=
github.com/quasilyte/go-ruleguard/dsl v0.3.23/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
github.com/redis/go-redis/extra/rediscmd/v9 v9.0.5 h1:EaDatTxkdHG+U3Bk4EUr+DZ7fOGwTfezUiUJMaIcaho=
github.com/redis/go-redis/extra/rediscmd/v9 v9.0.5/go.mod h1:fyalQWdtzDBECAQFBJuQe5bzQ02jGd5Qcbgb97Flm7U=
github.com/redis/go-redis/extra/redisotel/v9 v9.0.5 h1:EfpWLLCyXw8PSM2/XNJLjI3Pb27yVE+gIAfeqp8LUCc=
+252
View File
@@ -0,0 +1,252 @@
package ssrf
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"sync/atomic"
"github.com/rs/zerolog/log"
)
// Mode controls how the SSRF policy is applied.
type Mode string
const (
// ModeOff disables SSRF protection entirely. All connections pass through unchanged.
ModeOff Mode = "off"
// ModeAudit resolves and checks destinations but only logs violations; connections are allowed.
ModeAudit Mode = "audit"
// ModeEnforce blocks connections that violate the policy.
ModeEnforce Mode = "enforce"
)
// Policy defines the SSRF protection policy for outbound HTTP connections.
type Policy struct {
// Mode controls whether protection is off, in audit-only mode, or enforcing.
Mode Mode
// AllowedHosts is the allowlist of permitted destinations.
// Accepted formats:
// - Exact hostname: "example.com"
// - Wildcard hostname: "*.example.com" (matches any subdomain at any depth)
// - Single IP: "1.2.3.4"
// - CIDR range: "10.0.0.0/8"
//
// When Mode is ModeEnforce and AllowedHosts is empty, all outbound connections are blocked.
AllowedHosts []string
}
type safeDialer struct {
base net.Dialer
mode Mode
allowedNets []*net.IPNet
allowedHosts map[string]bool
allowedWilds []string // derived from "*.foo.com" entries; stored as ".foo.com"
}
var globalDialer atomic.Pointer[safeDialer]
// Configure initializes the global SSRF policy. Intended to be called once
// at startup before any outbound HTTP connections are established.
func Configure(policy Policy) {
if policy.Mode == ModeOff || policy.Mode == "" {
globalDialer.Store(nil)
return
}
globalDialer.Store(newSafeDialer(policy))
}
// IsEnabled reports whether SSRF protection is currently active (audit or enforce).
func IsEnabled() bool {
return globalDialer.Load() != nil
}
// CheckURL validates rawURL against the active SSRF policy without making a
// connection. Returns nil when protection is disabled or the destination is
// permitted. In audit mode, logs a warning on violations and returns nil.
func CheckURL(ctx context.Context, rawURL string) error {
d := globalDialer.Load()
if d == nil {
return nil
}
u, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("ssrf: invalid URL %q: %w", rawURL, err)
}
host := u.Hostname()
if host == "" {
return nil
}
return d.checkHost(ctx, host)
}
// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering
// dialer. Returns t unchanged when SSRF protection is not configured.
func WrapTransport(t *http.Transport) *http.Transport {
d := globalDialer.Load()
if d == nil {
return t
}
cloned := t.Clone()
cloned.DialContext = d.DialContext
return cloned
}
// WrapTransportInternal is a documented no-op for transports that connect to
// internally computed destinations (local Docker socket proxy, Chisel tunnels,
// in-cluster Kubernetes API). The destination is chosen by Portainer, not
// supplied by any user, so SSRF validation is not applicable. Using this
// function instead of WrapTransport makes the exemption explicit and
// satisfies the ruleguard lint rule.
func WrapTransportInternal(t *http.Transport) *http.Transport {
return t
}
func newSafeDialer(policy Policy) *safeDialer {
allowedNets := make([]*net.IPNet, 0, len(policy.AllowedHosts))
allowedHosts := make(map[string]bool, len(policy.AllowedHosts))
var allowedWilds []string
for _, entry := range policy.AllowedHosts {
if _, network, err := net.ParseCIDR(entry); err == nil {
allowedNets = append(allowedNets, network)
continue
}
if ip := net.ParseIP(entry); ip != nil {
bits := 32
if ip.To4() == nil {
bits = 128
}
mask := net.CIDRMask(bits, bits)
allowedNets = append(allowedNets, &net.IPNet{IP: ip.Mask(mask), Mask: mask})
continue
}
if strings.HasPrefix(entry, "*.") {
allowedWilds = append(allowedWilds, entry[1:]) // "*.foo.com" -> ".foo.com"
continue
}
allowedHosts[entry] = true
}
return &safeDialer{
mode: policy.Mode,
allowedNets: allowedNets,
allowedHosts: allowedHosts,
allowedWilds: allowedWilds,
}
}
// DialContext resolves addr, validates all resolved IPs against the allowlist policy,
// then dials using the first resolved IP to prevent DNS rebinding attacks.
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err)
}
resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, fmt.Errorf("ssrf: resolving %q: %w", host, err)
}
if len(resolved) == 0 {
return nil, fmt.Errorf("ssrf: no addresses resolved for %q", host)
}
// Dial by resolved IP regardless of how the host was allowed to close the
// window between DNS validation and the TCP handshake (DNS rebinding).
dialTarget := net.JoinHostPort(resolved[0].IP.String(), port)
if d.allowedHosts[host] || d.matchesWildcard(host) {
return d.base.DialContext(ctx, network, dialTarget)
}
for _, a := range resolved {
if !d.ipAllowed(a.IP) {
if d.mode == ModeAudit {
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
continue
}
return nil, fmt.Errorf("ssrf: destination %s is not in the allowlist", a.IP)
}
}
return d.base.DialContext(ctx, network, dialTarget)
}
func (d *safeDialer) checkHost(ctx context.Context, host string) error {
if d.allowedHosts[host] || d.matchesWildcard(host) {
return nil
}
if ip := net.ParseIP(host); ip != nil {
if !d.ipAllowed(ip) {
if d.mode == ModeAudit {
log.Warn().Str("host", host).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
return nil
}
return fmt.Errorf("ssrf: destination %s is not in the allowlist", ip)
}
return nil
}
resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("ssrf: resolving %q: %w", host, err)
}
if len(resolved) == 0 {
return fmt.Errorf("ssrf: no addresses resolved for %q", host)
}
for _, a := range resolved {
if !d.ipAllowed(a.IP) {
if d.mode == ModeAudit {
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
continue
}
return fmt.Errorf("ssrf: destination %s is not in the allowlist", a.IP)
}
}
return nil
}
func (d *safeDialer) matchesWildcard(host string) bool {
for _, suffix := range d.allowedWilds {
if strings.HasSuffix(host, suffix) {
return true
}
}
return false
}
func (d *safeDialer) ipAllowed(ip net.IP) bool {
for _, network := range d.allowedNets {
if network.Contains(ip) {
return true
}
}
return false
}
+357
View File
@@ -0,0 +1,357 @@
package ssrf
import (
"context"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestIpAllowed_CIDR(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"8.8.0.0/16", "2001:4860::/32"},
})
require.True(t, d.ipAllowed(net.ParseIP("8.8.8.8")))
require.True(t, d.ipAllowed(net.ParseIP("8.8.4.4")))
require.True(t, d.ipAllowed(net.ParseIP("2001:4860:4860::8888")))
require.False(t, d.ipAllowed(net.ParseIP("1.1.1.1")))
require.False(t, d.ipAllowed(net.ParseIP("127.0.0.1")))
require.False(t, d.ipAllowed(net.ParseIP("169.254.169.254")))
}
func TestIpAllowed_SingleIP(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"1.2.3.4"},
})
require.True(t, d.ipAllowed(net.ParseIP("1.2.3.4")))
require.False(t, d.ipAllowed(net.ParseIP("1.2.3.5")))
}
func TestMatchesWildcard(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"*.example.com", "exact.host.com"},
})
require.True(t, d.matchesWildcard("foo.example.com"))
require.True(t, d.matchesWildcard("bar.example.com"))
require.True(t, d.matchesWildcard("deep.nested.example.com"))
require.False(t, d.matchesWildcard("example.com"))
require.False(t, d.matchesWildcard("notexample.com"))
require.False(t, d.matchesWildcard("exact.host.com"))
}
func TestNewSafeDialer_MixedHosts(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"example.com", "*.internal.net", "10.0.0.0/8", "1.2.3.4"},
})
require.True(t, d.allowedHosts["example.com"])
require.Contains(t, d.allowedWilds, ".internal.net")
require.Len(t, d.allowedNets, 2) // 10.0.0.0/8 and 1.2.3.4/32
}
func TestConfigure_Disabled(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
require.NotNil(t, globalDialer.Load())
Configure(Policy{})
require.Nil(t, globalDialer.Load())
}
func TestWrapTransport_NoPolicy(t *testing.T) {
globalDialer.Store(nil)
base := &http.Transport{}
result := WrapTransport(base)
require.Equal(t, base, result)
}
func TestWrapTransport_WithPolicy(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
base := &http.Transport{}
result := WrapTransport(base)
require.NotEqual(t, base, result)
require.NotNil(t, result.DialContext)
}
func TestCheckURL_Disabled(t *testing.T) {
globalDialer.Store(nil)
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
require.NoError(t, err)
}
func TestCheckURL_BlocksIPNotInAllowlist(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
}
func TestCheckURL_AllowedHostname(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "https://example.com/path")
require.NoError(t, err)
}
func TestCheckURL_AuditMode_ReturnsNil(t *testing.T) {
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
require.NoError(t, err)
}
// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP
// server on 127.0.0.1, enables SSRF protection with an allowlist that does not
// include loopback, and verifies that the wrapped transport blocks the request.
func TestDialContext_BlocksLoopback(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.8"}})
t.Cleanup(func() { globalDialer.Store(nil) })
blocked := &http.Client{Transport: WrapTransport(&http.Transport{})}
resp, err := blocked.Get(srv.URL)
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
if resp != nil {
require.NoError(t, resp.Body.Close())
}
Configure(Policy{})
open := &http.Client{Transport: WrapTransport(&http.Transport{})}
resp, err = open.Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
}
// TestDialContext_AuditMode_AllowsLoopback verifies that audit mode logs the
// violation but still allows the connection through.
func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.8"}})
t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
resp, err := client.Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
}
func TestIsEnabled(t *testing.T) {
globalDialer.Store(nil)
require.False(t, IsEnabled())
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
require.True(t, IsEnabled())
}
func TestWrapTransportInternal(t *testing.T) {
t.Parallel()
base := &http.Transport{}
result := WrapTransportInternal(base)
require.Equal(t, base, result)
}
func TestNewSafeDialer_IPv6SingleIP(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"::1"},
})
require.True(t, d.ipAllowed(net.ParseIP("::1")))
require.False(t, d.ipAllowed(net.ParseIP("::2")))
}
func TestCheckURL_InvalidURL(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://%gg")
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
}
func TestCheckURL_EmptyHost(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://")
require.NoError(t, err)
}
// TestCheckURL_IPInAllowlist verifies that a literal IP address that falls
// within an allowed CIDR range is permitted.
func TestCheckURL_IPInAllowlist(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://8.8.8.8/path")
require.NoError(t, err)
}
func TestCheckURL_WildcardHostname(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"*.example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "https://api.example.com/path")
require.NoError(t, err)
}
// TestCheckURL_HostnameDNSResolvesToAllowedIP verifies that a hostname
// resolving to an IP within the allowlist is permitted (DNS resolution path).
func TestCheckURL_HostnameDNSResolvesToAllowedIP(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8", "::1/128"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://localhost/path")
require.NoError(t, err)
}
// TestCheckURL_HostnameDNSResolvesToBlockedIP verifies that a hostname
// resolving to an IP outside the allowlist is blocked (DNS resolution path).
func TestCheckURL_HostnameDNSResolvesToBlockedIP(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://localhost/path")
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
}
// TestCheckURL_HostnameDNSAuditMode verifies that audit mode logs violations
// from hostname DNS resolution but still returns nil.
func TestCheckURL_HostnameDNSAuditMode(t *testing.T) {
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://localhost/path")
require.NoError(t, err)
}
// TestCheckURL_HostnameDNSError verifies that a DNS resolution failure is
// propagated as an SSRF-prefixed error.
func TestCheckURL_HostnameDNSError(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
ctx, cancel := context.WithCancel(t.Context())
cancel()
err := CheckURL(ctx, "http://portainer-nonexistent.test.invalid/path")
require.Error(t, err)
}
// TestDialContext_InvalidAddress verifies that an address without a port
// returns an SSRF-prefixed error.
func TestDialContext_InvalidAddress(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
d := globalDialer.Load()
_, err := d.DialContext(t.Context(), "tcp", "no-port-here")
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
}
// TestDialContext_DNSError verifies that a DNS resolution failure in
// DialContext is propagated as an SSRF-prefixed error.
func TestDialContext_DNSError(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{}})
t.Cleanup(func() { globalDialer.Store(nil) })
ctx, cancel := context.WithCancel(t.Context())
cancel()
d := globalDialer.Load()
_, err := d.DialContext(ctx, "tcp", "portainer-nonexistent.test.invalid:80")
require.Error(t, err)
}
// TestDialContext_AllowedByCIDR is an end-to-end test verifying that
// connections to IPs within an allowed CIDR range are permitted.
func TestDialContext_AllowedByCIDR(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8"}})
t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
resp, err := client.Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
}
// TestDialContext_AllowedByExactHostname verifies that when a hostname is in
// the allowed-hosts list, the connection is permitted even though the resolved
// IP is not covered by any CIDR entry.
//
// The server is bound to whatever IP "localhost" resolves to first so that the
// dialTarget computed by DialContext (resolved[0]) matches the listening address.
func TestDialContext_AllowedByExactHostname(t *testing.T) {
addrs, err := net.DefaultResolver.LookupIPAddr(t.Context(), "localhost")
require.NoError(t, err)
require.NotEmpty(t, addrs, "localhost must resolve to at least one address")
l, err := net.Listen("tcp", net.JoinHostPort(addrs[0].IP.String(), "0"))
require.NoError(t, err)
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
srv.Listener = l
srv.Start()
defer srv.Close()
_, portStr, err := net.SplitHostPort(l.Addr().String())
require.NoError(t, err)
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"localhost"}})
t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
resp, err := client.Get("http://localhost:" + portStr)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
}
+3 -2
View File
@@ -4,6 +4,7 @@ import (
"net/http"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log"
"oras.land/oras-go/v2/registry/remote/retry"
@@ -15,8 +16,8 @@ import (
func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) {
switch registry.Type {
case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry:
// Cloud registries use the default retry client with built-in TLS
return retry.DefaultClient, false, nil
base := http.DefaultTransport.(*http.Transport).Clone()
return &http.Client{Transport: retry.NewTransport(ssrf.WrapTransport(base))}, false, nil
default:
// For all other registry types, use shared helper to build transport and scheme
+7 -3
View File
@@ -11,6 +11,11 @@ import (
"oras.land/oras-go/v2/registry/remote/retry"
)
func isRetryTransport(t *http.Client) bool {
_, ok := t.Transport.(*retry.Transport)
return ok
}
func init() {
fips.InitFIPS(false)
}
@@ -102,8 +107,7 @@ func TestCreateClient(t *testing.T) {
// Verify client type based on registry configuration
switch tt.registry.Type {
case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry:
// Cloud registries should use the default retry client
assert.Equal(t, retry.DefaultClient, client)
assert.True(t, isRetryTransport(client), "Cloud registries should use a retry transport")
}
})
}
@@ -133,7 +137,7 @@ func TestCreateClient_CloudRegistries(t *testing.T) {
require.NoError(t, err)
assert.NotNil(t, client)
assert.False(t, usePlainHTTP, "Cloud registries should use HTTPS")
assert.Equal(t, retry.DefaultClient, client, "Cloud registries should use default retry client")
assert.True(t, isRetryTransport(client), "Cloud registries should use a retry transport")
})
}
}