mirror of
https://github.com/portainer/portainer.git
synced 2026-06-23 04:10:29 +00:00
feat(ssrf): implement an SSRF protection mechanism BE-13021 (#2818)
This commit is contained in:
@@ -5,6 +5,7 @@ run:
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- gocritic
|
||||
- bodyclose
|
||||
- copyloopvar
|
||||
- depguard
|
||||
@@ -76,6 +77,13 @@ linters:
|
||||
desc: use github.com/Masterminds/semver/v3
|
||||
- pkg: github.com/hashicorp/go-version
|
||||
desc: use github.com/Masterminds/semver/v3
|
||||
gocritic:
|
||||
disable-all: true
|
||||
enabled-checks:
|
||||
- ruleguard
|
||||
settings:
|
||||
ruleguard:
|
||||
rules: "./analysis/ssrf.go"
|
||||
forbidigo:
|
||||
forbid:
|
||||
- pattern: ^tls\.Config$
|
||||
@@ -93,6 +101,11 @@ linters:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
rules:
|
||||
- path: pkg/libhttp/ssrf
|
||||
linters:
|
||||
- gocritic
|
||||
text: ruleguard
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
//go:build ignore
|
||||
|
||||
package gorules
|
||||
|
||||
import "github.com/quasilyte/go-ruleguard/dsl"
|
||||
|
||||
// unwrappedHTTPTransport flags http.Transport composite literals that are not
|
||||
// the direct argument to ssrf.WrapTransport.
|
||||
func unwrappedHTTPTransport(m dsl.Matcher) {
|
||||
// Inline construction passed to a function call.
|
||||
m.Match(`$f(&http.Transport{$*_})`).
|
||||
Where(m["f"].Text != "ssrf.WrapTransport" && m["f"].Text != "WrapTransport" &&
|
||||
m["f"].Text != "ssrf.WrapTransportInternal" && m["f"].Text != "WrapTransportInternal").
|
||||
Report(`$f receives a bare *http.Transport; wrap with ssrf.WrapTransport() to enforce the SSRF protection policy`)
|
||||
|
||||
// Variable assigned a bare transport (cannot be tracked to a later WrapTransport call).
|
||||
m.Match(`$_ := &http.Transport{$*_}`).
|
||||
Report(`bare *http.Transport variable; use ssrf.WrapTransport(&http.Transport{...}) inline instead`)
|
||||
}
|
||||
|
||||
// internalTransportMisuse flags calls to WrapTransportInternal outside the four proxy
|
||||
// factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions.
|
||||
func internalTransportMisuse(m dsl.Matcher) {
|
||||
m.Match(`ssrf.WrapTransportInternal($*_)`).
|
||||
Where(
|
||||
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||
m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport)\.go$`))).
|
||||
Report(`WrapTransportInternal bypasses SSRF validation; only valid in the kubernetes local/edge transport constructors and the docker/agent proxy factories`)
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
//go:build tools
|
||||
|
||||
package gorules
|
||||
|
||||
import _ "github.com/quasilyte/go-ruleguard/dsl"
|
||||
@@ -56,6 +56,8 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
|
||||
CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(),
|
||||
CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(),
|
||||
SSRFMode: kingpin.Flag("ssrf-mode", "SSRF protection mode: off (disabled), audit (log violations but allow), enforce (block violations)").Envar("PORTAINER_SSRF_MODE").Default("off").Enum("off", "audit", "enforce"),
|
||||
SSRFAllowedHosts: kingpin.Flag("ssrf-allowed-hosts", "Allowlist of hostnames (with optional wildcards), IPs, or CIDRs permitted for outbound requests. When empty and mode is enforce, all outbound connections are blocked").Envar("PORTAINER_SSRF_ALLOWED_HOSTS").Strings(),
|
||||
NoSetupToken: kingpin.Flag("no-setup-token", "Disable the setup token requirement for admin initialization and restore on an uninitialized instance").Envar(portainer.NoSetupTokenEnvVar).Bool(),
|
||||
SetupToken: kingpin.Flag("setup-token", "Set a custom setup token for admin initialization and restore on an uninitialized instance (overrides auto-generation)").Envar(portainer.SetupTokenEnvVar).String(),
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
nethttp "net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
@@ -52,10 +53,12 @@ import (
|
||||
"github.com/portainer/portainer/pkg/featureflags"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
"github.com/portainer/portainer/pkg/libhelm"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/portainer/portainer/pkg/libstack/compose"
|
||||
libswarm "github.com/portainer/portainer/pkg/libstack/swarm"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
|
||||
gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -384,6 +387,19 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
|
||||
// -ce can not ever be run in FIPS mode
|
||||
fips.InitFIPS(false)
|
||||
|
||||
ssrf.Configure(ssrf.Policy{
|
||||
Mode: ssrf.Mode(*flags.SSRFMode),
|
||||
AllowedHosts: *flags.SSRFAllowedHosts,
|
||||
})
|
||||
|
||||
if ssrf.IsEnabled() {
|
||||
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok {
|
||||
nethttp.DefaultTransport = ssrf.WrapTransport(dt)
|
||||
}
|
||||
|
||||
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
|
||||
}
|
||||
|
||||
fileService := initFileService(*flags.Data)
|
||||
encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName))
|
||||
if encryptionKey == nil {
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/docker/docker/api/types/image"
|
||||
@@ -184,17 +186,20 @@ func (t *NodeNameTransport) RoundTrip(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Client, error) {
|
||||
transport := &NodeNameTransport{
|
||||
Transport: &http.Transport{},
|
||||
}
|
||||
|
||||
var transport *NodeNameTransport
|
||||
if endpoint.TLSConfig.TLS {
|
||||
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
transport = &NodeNameTransport{
|
||||
Transport: ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}),
|
||||
}
|
||||
} else {
|
||||
transport = &NodeNameTransport{
|
||||
Transport: ssrf.WrapTransport(&http.Transport{}),
|
||||
}
|
||||
}
|
||||
|
||||
clientTimeout := defaultDockerRequestTimeout
|
||||
|
||||
+5
-6
@@ -14,12 +14,13 @@ import (
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/go-git/go-git/v5"
|
||||
"github.com/go-git/go-git/v5/plumbing/filemode"
|
||||
githttp "github.com/go-git/go-git/v5/plumbing/transport/http"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/segmentio/encoding/json"
|
||||
)
|
||||
|
||||
@@ -64,15 +65,13 @@ func NewAzureClient() *azureClient {
|
||||
}
|
||||
|
||||
func newHttpClientForAzure(insecureSkipVerify bool) *http.Client {
|
||||
httpsCli := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
return &http.Client{
|
||||
Transport: ssrf.WrapTransport(&http.Transport{
|
||||
TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify),
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
}),
|
||||
Timeout: 300 * time.Second,
|
||||
}
|
||||
|
||||
return httpsCli
|
||||
}
|
||||
|
||||
func (a *azureClient) Download(ctx context.Context, destination string, opt *git.CloneOptions) error {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/segmentio/encoding/json"
|
||||
@@ -114,18 +115,19 @@ func Get(url string, timeout int) ([]byte, error) {
|
||||
// using the specified host and optional TLS configuration.
|
||||
// It uses a new Http.Client for each operation.
|
||||
func ExecutePingOperation(host string, tlsConfiguration portainer.TLSConfiguration) (bool, error) {
|
||||
transport := &http.Transport{}
|
||||
|
||||
scheme := "http"
|
||||
|
||||
var transport *http.Transport
|
||||
if tlsConfiguration.TLS {
|
||||
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsConfiguration)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
scheme = "https"
|
||||
transport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
} else {
|
||||
transport = ssrf.WrapTransport(&http.Transport{})
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
|
||||
@@ -115,6 +115,11 @@ func (handler *Handler) createEdgeStackFromFileUpload(r *http.Request, tx datase
|
||||
if dryrun {
|
||||
return stack, nil
|
||||
}
|
||||
|
||||
if err := stackutils.ValidateEdgeStackComposeContent(r.Context(), payload.DeploymentType, payload.StackFileContent); err != nil {
|
||||
return nil, httperrors.NewInvalidPayloadError(err.Error())
|
||||
}
|
||||
|
||||
stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID)
|
||||
stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username)
|
||||
|
||||
|
||||
@@ -93,6 +93,10 @@ func (handler *Handler) createEdgeStackFromFileContent(r *http.Request, tx datas
|
||||
return stack, nil
|
||||
}
|
||||
|
||||
if err := stackutils.ValidateEdgeStackComposeContent(r.Context(), payload.DeploymentType, []byte(payload.StackFileContent)); err != nil {
|
||||
return nil, httperrors.NewInvalidPayloadError(err.Error())
|
||||
}
|
||||
|
||||
return handler.edgeStacksService.PersistEdgeStack(tx, stack, func(stackFolder string, relatedEndpointIds []portainer.EndpointID) (composePath string, manifestPath string, projectPath string, err error) {
|
||||
return handler.storeFileContent(tx, stackFolder, payload.DeploymentType, relatedEndpointIds, []byte(payload.StackFileContent))
|
||||
})
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/edge"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
"github.com/portainer/portainer/api/stacks/stackutils"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
@@ -61,6 +62,10 @@ func (handler *Handler) edgeStackUpdate(w http.ResponseWriter, r *http.Request)
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
if err := stackutils.ValidateEdgeStackComposeContent(r.Context(), payload.DeploymentType, []byte(payload.StackFileContent)); err != nil {
|
||||
return httperror.BadRequest("Stack file contains a URL blocked by the SSRF policy", err)
|
||||
}
|
||||
|
||||
var stack *portainer.EdgeStack
|
||||
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
stack, err = handler.updateEdgeStack(tx, portainer.EdgeStackID(stackID), payload)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/agent"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/url"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -40,21 +41,30 @@ func (factory *ProxyFactory) NewAgentProxy(endpoint *portainer.Endpoint) (*Proxy
|
||||
}
|
||||
|
||||
endpointURL.Scheme = "http"
|
||||
httpTransport := &http.Transport{}
|
||||
|
||||
var innerTransport *http.Transport
|
||||
if endpoint.TLSConfig.TLS || endpoint.TLSConfig.TLSSkipVerify {
|
||||
config, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
|
||||
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "failed generating tls configuration")
|
||||
}
|
||||
|
||||
httpTransport.TLSClientConfig = config
|
||||
endpointURL.Scheme = "https"
|
||||
|
||||
if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
} else {
|
||||
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
}
|
||||
} else if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
|
||||
} else {
|
||||
innerTransport = ssrf.WrapTransport(&http.Transport{})
|
||||
}
|
||||
|
||||
proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL)
|
||||
|
||||
proxy.Transport = agent.NewTransport(factory.signatureService, httpTransport)
|
||||
proxy.Transport = agent.NewTransport(factory.signatureService, innerTransport)
|
||||
|
||||
proxyServer := &ProxyServer{
|
||||
server: &http.Server{
|
||||
|
||||
@@ -8,8 +8,10 @@ import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/docker"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/url"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -47,17 +49,6 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h
|
||||
}
|
||||
|
||||
endpointURL.Scheme = "http"
|
||||
httpTransport := &http.Transport{}
|
||||
|
||||
if endpoint.TLSConfig.TLS || endpoint.TLSConfig.TLSSkipVerify {
|
||||
config, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpTransport.TLSClientConfig = config
|
||||
endpointURL.Scheme = "https"
|
||||
}
|
||||
|
||||
transportParameters := &docker.TransportParameters{
|
||||
Endpoint: endpoint,
|
||||
@@ -67,7 +58,27 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h
|
||||
DockerClientFactory: factory.dockerClientFactory,
|
||||
}
|
||||
|
||||
dockerTransport, err := docker.NewTransport(transportParameters, httpTransport, factory.gitService, factory.snapshotService)
|
||||
var innerTransport *http.Transport
|
||||
if endpoint.TLSConfig.TLS || endpoint.TLSConfig.TLSSkipVerify {
|
||||
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endpointURL.Scheme = "https"
|
||||
|
||||
if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
} else {
|
||||
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
}
|
||||
} else if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
|
||||
} else {
|
||||
innerTransport = ssrf.WrapTransport(&http.Transport{})
|
||||
}
|
||||
|
||||
dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"oras.land/oras-go/v2/registry/remote/retry"
|
||||
@@ -92,7 +94,7 @@ func NewHTTPClient(token string) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &tokenTransport{
|
||||
token: token,
|
||||
transport: retry.NewTransport(&http.Transport{}), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||
},
|
||||
Timeout: 1 * time.Minute,
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"oras.land/oras-go/v2/registry/remote/retry"
|
||||
@@ -92,7 +94,7 @@ type Transport struct {
|
||||
// interface for proxying requests to the Gitlab API.
|
||||
func NewTransport() *Transport {
|
||||
return &Transport{
|
||||
httpTransport: &http.Transport{},
|
||||
httpTransport: ssrf.WrapTransport(&http.Transport{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,7 +119,7 @@ func NewHTTPClient(token string) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &tokenTransport{
|
||||
token: token,
|
||||
transport: retry.NewTransport(&http.Transport{}), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||
},
|
||||
Timeout: 1 * time.Minute,
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/kubernetes/cli"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
)
|
||||
|
||||
type agentTransport struct {
|
||||
@@ -24,9 +25,9 @@ func NewAgentTransport(signatureService portainer.DigitalSignatureService, token
|
||||
|
||||
transport := &agentTransport{
|
||||
baseTransport: newBaseTransport(
|
||||
&http.Transport{
|
||||
ssrf.WrapTransport(&http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
},
|
||||
}),
|
||||
tokenManager,
|
||||
endpoint,
|
||||
k8sClientFactory,
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/kubernetes/cli"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
)
|
||||
|
||||
type edgeTransport struct {
|
||||
@@ -21,7 +22,7 @@ func NewEdgeTransport(dataStore dataservices.DataStore, signatureService portain
|
||||
reverseTunnelService: reverseTunnelService,
|
||||
signatureService: signatureService,
|
||||
baseTransport: newBaseTransport(
|
||||
&http.Transport{},
|
||||
ssrf.WrapTransportInternal(&http.Transport{}),
|
||||
tokenManager,
|
||||
endpoint,
|
||||
k8sClientFactory,
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
package kubernetes
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewEdgeTransport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transport := NewEdgeTransport(nil, nil, nil, nil, nil, nil, nil)
|
||||
require.NotNil(t, transport)
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/kubernetes/cli"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
)
|
||||
|
||||
type localTransport struct {
|
||||
@@ -22,9 +23,9 @@ func NewLocalTransport(tokenManager *tokenManager, endpoint *portainer.Endpoint,
|
||||
|
||||
transport := &localTransport{
|
||||
baseTransport: newBaseTransport(
|
||||
&http.Transport{
|
||||
ssrf.WrapTransportInternal(&http.Transport{
|
||||
TLSClientConfig: config,
|
||||
},
|
||||
}),
|
||||
tokenManager,
|
||||
endpoint,
|
||||
k8sClientFactory,
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
package factory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/docker"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
fips.InitFIPS(false)
|
||||
}
|
||||
|
||||
type stubTunnelService struct{}
|
||||
|
||||
func (s *stubTunnelService) StartTunnelServer(addr, port string, snapshotService portainer.SnapshotService) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubTunnelService) StopTunnelServer() error { return nil }
|
||||
|
||||
func (s *stubTunnelService) GenerateEdgeKey(apiURL, tunnelAddr string, endpointIdentifier int) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *stubTunnelService) Open(endpoint *portainer.Endpoint) error { return nil }
|
||||
|
||||
func (s *stubTunnelService) Config(endpointID portainer.EndpointID) portainer.TunnelDetails {
|
||||
return portainer.TunnelDetails{}
|
||||
}
|
||||
|
||||
func (s *stubTunnelService) TunnelAddr(endpoint *portainer.Endpoint) (string, error) {
|
||||
return "127.0.0.1:9999", nil
|
||||
}
|
||||
|
||||
func (s *stubTunnelService) UpdateLastActivity(endpointID portainer.EndpointID) {}
|
||||
|
||||
func (s *stubTunnelService) KeepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxKeepAlive time.Duration) {
|
||||
}
|
||||
|
||||
func enableSSRF(t *testing.T) {
|
||||
t.Helper()
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
}
|
||||
|
||||
// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint
|
||||
// uses WrapTransport, setting DialContext on the inner transport.
|
||||
func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
|
||||
enableSSRF(t)
|
||||
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.DockerEnvironment,
|
||||
URL: "tcp://192.168.1.100:2376",
|
||||
}
|
||||
|
||||
handler, err := f.newDockerHTTPProxy(endpoint)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := handler.(*httputil.ReverseProxy)
|
||||
dt := proxy.Transport.(*docker.Transport)
|
||||
require.NotNil(t, dt.HTTPTransport.DialContext)
|
||||
}
|
||||
|
||||
// TestNewDockerHTTPProxy_NonEdgeTLS verifies that a TLS non-edge endpoint
|
||||
// uses WrapTransport, setting DialContext on the inner transport.
|
||||
func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
|
||||
enableSSRF(t)
|
||||
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.DockerEnvironment,
|
||||
URL: "tcp://192.168.1.100:2376",
|
||||
TLSConfig: portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
},
|
||||
}
|
||||
|
||||
handler, err := f.newDockerHTTPProxy(endpoint)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := handler.(*httputil.ReverseProxy)
|
||||
dt := proxy.Transport.(*docker.Transport)
|
||||
require.NotNil(t, dt.HTTPTransport.DialContext)
|
||||
}
|
||||
|
||||
// TestNewDockerHTTPProxy_EdgeNoTLS verifies that an edge endpoint without TLS
|
||||
// uses WrapTransportInternal, leaving DialContext nil.
|
||||
func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
|
||||
enableSSRF(t)
|
||||
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
URL: "tcp://192.168.1.100:2376",
|
||||
}
|
||||
|
||||
handler, err := f.newDockerHTTPProxy(endpoint)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := handler.(*httputil.ReverseProxy)
|
||||
dt := proxy.Transport.(*docker.Transport)
|
||||
require.Nil(t, dt.HTTPTransport.DialContext)
|
||||
}
|
||||
|
||||
// TestNewDockerHTTPProxy_EdgeTLS verifies that an edge endpoint with TLS
|
||||
// uses WrapTransportInternal, leaving DialContext nil.
|
||||
func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
|
||||
enableSSRF(t)
|
||||
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
URL: "tcp://192.168.1.100:2376",
|
||||
TLSConfig: portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
},
|
||||
}
|
||||
|
||||
handler, err := f.newDockerHTTPProxy(endpoint)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := handler.(*httputil.ReverseProxy)
|
||||
dt := proxy.Transport.(*docker.Transport)
|
||||
require.Nil(t, dt.HTTPTransport.DialContext)
|
||||
}
|
||||
|
||||
func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) {
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.AgentOnDockerEnvironment,
|
||||
URL: "tcp://192.168.1.100:9001",
|
||||
}
|
||||
|
||||
proxyServer, err := f.NewAgentProxy(endpoint)
|
||||
require.NoError(t, err)
|
||||
defer proxyServer.Close()
|
||||
|
||||
require.Positive(t, proxyServer.Port)
|
||||
}
|
||||
|
||||
func TestNewAgentProxy_NonEdgeTLS(t *testing.T) {
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.AgentOnDockerEnvironment,
|
||||
URL: "tcp://192.168.1.100:9001",
|
||||
TLSConfig: portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
},
|
||||
}
|
||||
|
||||
proxyServer, err := f.NewAgentProxy(endpoint)
|
||||
require.NoError(t, err)
|
||||
defer proxyServer.Close()
|
||||
|
||||
require.Positive(t, proxyServer.Port)
|
||||
}
|
||||
|
||||
func TestNewAgentProxy_EdgeNoTLS(t *testing.T) {
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
URL: "tcp://192.168.1.100:9001",
|
||||
}
|
||||
|
||||
proxyServer, err := f.NewAgentProxy(endpoint)
|
||||
require.NoError(t, err)
|
||||
defer proxyServer.Close()
|
||||
|
||||
require.Positive(t, proxyServer.Port)
|
||||
}
|
||||
|
||||
func TestNewAgentProxy_EdgeTLS(t *testing.T) {
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
URL: "tcp://192.168.1.100:9001",
|
||||
TLSConfig: portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
},
|
||||
}
|
||||
|
||||
proxyServer, err := f.NewAgentProxy(endpoint)
|
||||
require.NoError(t, err)
|
||||
defer proxyServer.Close()
|
||||
|
||||
require.Positive(t, proxyServer.Port)
|
||||
}
|
||||
@@ -111,6 +111,8 @@ type (
|
||||
KubectlShellImageSet bool
|
||||
PullLimitCheckDisabled *bool
|
||||
TrustedOrigins *string
|
||||
SSRFMode *string
|
||||
SSRFAllowedHosts *[]string
|
||||
NoSetupToken *bool
|
||||
SetupToken *string
|
||||
}
|
||||
|
||||
@@ -71,6 +71,10 @@ func (config *ComposeStackDeploymentConfig) Deploy(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := stackutils.ValidateComposeURLs(ctx, config.stack, config.FileService); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stackutils.IsRelativePathStack(config.stack) {
|
||||
return config.StackDeployer.DeployRemoteComposeStack(ctx, config.stack, config.endpoint, config.registries, config.prune, config.forcePullImage, config.ForceCreate)
|
||||
}
|
||||
|
||||
@@ -71,6 +71,10 @@ func (config *SwarmStackDeploymentConfig) Deploy(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := stackutils.ValidateComposeURLs(ctx, config.stack, config.FileService); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stackutils.IsRelativePathStack(config.stack) {
|
||||
return config.StackDeployer.DeployRemoteSwarmStack(ctx, config.stack, config.endpoint, config.registries, config.prune, config.pullImage)
|
||||
}
|
||||
|
||||
@@ -2,10 +2,13 @@ package stackutils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
composeloader "github.com/compose-spec/compose-go/v2/loader"
|
||||
composetypes "github.com/compose-spec/compose-go/v2/types"
|
||||
@@ -68,6 +71,97 @@ func IsValidStackFile(config StackFileValidationConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateComposeURLs parses each stack file and checks that every external URL
|
||||
// (build contexts and image registry hostnames) is permitted by the active SSRF
|
||||
// policy. It is a no-op when SSRF protection is disabled.
|
||||
func ValidateComposeURLs(ctx context.Context, stack *portainer.Stack, fileService portainer.FileService) error {
|
||||
if !ssrf.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
env := BuildEnvMap(stack)
|
||||
workingDir := filesystem.JoinPaths(stack.ProjectPath, path.Dir(stack.EntryPoint))
|
||||
|
||||
for _, file := range GetStackFilePaths(stack, false) {
|
||||
stackContent, err := fileService.GetFileContent(stack.ProjectPath, file)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get stack file content")
|
||||
}
|
||||
|
||||
if err := checkComposeFileURLs(ctx, stackContent, env, workingDir); err != nil {
|
||||
return errors.Wrap(err, "stack file contains a URL blocked by the SSRF policy")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateEdgeStackComposeContent checks that every external URL in an edge
|
||||
// stack's Compose file is permitted by the active SSRF policy. It is a no-op
|
||||
// when SSRF protection is disabled or the deployment type is not compose.
|
||||
func ValidateEdgeStackComposeContent(ctx context.Context, deploymentType portainer.EdgeStackDeploymentType, content []byte) error {
|
||||
if !ssrf.IsEnabled() || deploymentType != portainer.EdgeStackDeploymentCompose {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := checkComposeFileURLs(ctx, content, nil, ""); err != nil {
|
||||
return errors.Wrap(err, "stack file contains a URL blocked by the SSRF policy")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkComposeFileURLs(ctx context.Context, content []byte, env map[string]string, workingDir string) error {
|
||||
composeConfigDetails := composetypes.ConfigDetails{
|
||||
ConfigFiles: []composetypes.ConfigFile{{Content: content}},
|
||||
Environment: env,
|
||||
WorkingDir: workingDir,
|
||||
}
|
||||
|
||||
composeConfig, err := composeloader.LoadWithContext(ctx, composeConfigDetails, composeloader.WithSkipValidation)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, service := range composeConfig.Services {
|
||||
if service.Build != nil {
|
||||
buildCtx := service.Build.Context
|
||||
if strings.HasPrefix(buildCtx, "http://") || strings.HasPrefix(buildCtx, "https://") {
|
||||
if err := ssrf.CheckURL(ctx, buildCtx); err != nil {
|
||||
return fmt.Errorf("service %q: build context URL blocked: %w", service.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if service.Image != "" {
|
||||
if registry := extractImageRegistry(service.Image); registry != "" {
|
||||
if err := ssrf.CheckURL(ctx, "https://"+registry); err != nil {
|
||||
return fmt.Errorf("service %q: image registry %q blocked: %w", service.Name, registry, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractImageRegistry returns the registry hostname from an OCI image reference,
|
||||
// or an empty string if the image resolves to Docker Hub (no explicit registry).
|
||||
func extractImageRegistry(imageRef string) string {
|
||||
ref, _, _ := strings.Cut(imageRef, "@")
|
||||
|
||||
first, _, hasSlash := strings.Cut(ref, "/")
|
||||
if !hasSlash {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.ContainsAny(first, ".:") || first == "localhost" {
|
||||
return first
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func ValidateStackFiles(stack *portainer.Stack, securitySettings *portainer.EndpointSecuritySettings, fileService portainer.FileService) error {
|
||||
env := BuildEnvMap(stack)
|
||||
workingDir := filesystem.JoinPaths(stack.ProjectPath, path.Dir(stack.EntryPoint))
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -271,3 +272,170 @@ services:
|
||||
err := ValidateStackFiles(stack, securitySettings, fileService)
|
||||
require.ErrorContains(t, err, "bind-mount disabled for non administrator users")
|
||||
}
|
||||
|
||||
func TestExtractImageRegistry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
image string
|
||||
expected string
|
||||
}{
|
||||
{"nginx", ""},
|
||||
{"nginx:latest", ""},
|
||||
{"library/nginx", ""},
|
||||
{"ghcr.io/owner/image:tag", "ghcr.io"},
|
||||
{"myregistry.com/image:tag", "myregistry.com"},
|
||||
{"myregistry.com:5000/image:tag", "myregistry.com:5000"},
|
||||
{"localhost/image:tag", "localhost"},
|
||||
{"localhost:5000/image:tag", "localhost:5000"},
|
||||
{"myregistry.com/image@sha256:abc", "myregistry.com"},
|
||||
{"169.254.169.254/image:tag", "169.254.169.254"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
got := extractImageRegistry(tc.image)
|
||||
require.Equal(t, tc.expected, got, "image: %s", tc.image)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_DisabledSSRF(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{})
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
EntryPoint: "docker-compose.yml",
|
||||
}
|
||||
|
||||
fileService := mockFileService{
|
||||
fileContent: []byte(`
|
||||
version: "3"
|
||||
services:
|
||||
web:
|
||||
build:
|
||||
context: http://169.254.169.254/repo.tar.gz
|
||||
`),
|
||||
projectVersionPath: "/tmp/stack/1",
|
||||
}
|
||||
|
||||
err := ValidateComposeURLs(t.Context(), stack, fileService)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_BuildContextBlocked(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
EntryPoint: "docker-compose.yml",
|
||||
}
|
||||
|
||||
fileService := mockFileService{
|
||||
fileContent: []byte(`
|
||||
version: "3"
|
||||
services:
|
||||
web:
|
||||
build:
|
||||
context: http://169.254.169.254/repo.tar.gz
|
||||
image: nginx
|
||||
`),
|
||||
projectVersionPath: "/tmp/stack/1",
|
||||
}
|
||||
|
||||
err := ValidateComposeURLs(t.Context(), stack, fileService)
|
||||
require.ErrorContains(t, err, "SSRF policy")
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_BuildContextPath(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
EntryPoint: "docker-compose.yml",
|
||||
}
|
||||
|
||||
fileService := mockFileService{
|
||||
fileContent: []byte(`
|
||||
version: "3"
|
||||
services:
|
||||
web:
|
||||
build:
|
||||
context: ./app
|
||||
image: nginx
|
||||
`),
|
||||
projectVersionPath: "/tmp/stack/1",
|
||||
}
|
||||
|
||||
err := ValidateComposeURLs(t.Context(), stack, fileService)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_ImageRegistryBlocked(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
EntryPoint: "docker-compose.yml",
|
||||
}
|
||||
|
||||
fileService := mockFileService{
|
||||
fileContent: []byte(`
|
||||
version: "3"
|
||||
services:
|
||||
web:
|
||||
image: 169.254.169.254/myimage:latest
|
||||
`),
|
||||
projectVersionPath: "/tmp/stack/1",
|
||||
}
|
||||
|
||||
err := ValidateComposeURLs(t.Context(), stack, fileService)
|
||||
require.ErrorContains(t, err, "SSRF policy")
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_ImageRegistryAllowed(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"myregistry.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
EntryPoint: "docker-compose.yml",
|
||||
}
|
||||
|
||||
fileService := mockFileService{
|
||||
fileContent: []byte(`
|
||||
version: "3"
|
||||
services:
|
||||
web:
|
||||
image: myregistry.com/myimage:latest
|
||||
`),
|
||||
projectVersionPath: "/tmp/stack/1",
|
||||
}
|
||||
|
||||
err := ValidateComposeURLs(t.Context(), stack, fileService)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_DockerHubImageAllowed(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
EntryPoint: "docker-compose.yml",
|
||||
}
|
||||
|
||||
fileService := mockFileService{
|
||||
fileContent: []byte(`
|
||||
version: "3"
|
||||
services:
|
||||
web:
|
||||
image: nginx:latest
|
||||
`),
|
||||
projectVersionPath: "/tmp/stack/1",
|
||||
}
|
||||
|
||||
err := ValidateComposeURLs(t.Context(), stack, fileService)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ require (
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
github.com/prometheus/common v0.67.5
|
||||
github.com/prometheus/prometheus v0.311.3
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.23
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/samber/slog-zerolog/v2 v2.9.1
|
||||
|
||||
@@ -841,6 +841,8 @@ github.com/prometheus/sigv4 v0.4.1 h1:EIc3j+8NBea9u1iV6O5ZAN8uvPq2xOIUPcqCTivHuX
|
||||
github.com/prometheus/sigv4 v0.4.1/go.mod h1:eu+ZbRvsc5TPiHwqh77OWuCnWK73IdkETYY46P4dXOU=
|
||||
github.com/puzpuzpuz/xsync/v4 v4.4.0 h1:vlSN6/CkEY0pY8KaB0yqo/pCLZvp9nhdbBdjipT4gWo=
|
||||
github.com/puzpuzpuz/xsync/v4 v4.4.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo=
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.23 h1:lxjt5B6ZCiBeeNO8/oQsegE6fLeCzuMRoVWSkXC4uvY=
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.23/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
|
||||
github.com/redis/go-redis/extra/rediscmd/v9 v9.0.5 h1:EaDatTxkdHG+U3Bk4EUr+DZ7fOGwTfezUiUJMaIcaho=
|
||||
github.com/redis/go-redis/extra/rediscmd/v9 v9.0.5/go.mod h1:fyalQWdtzDBECAQFBJuQe5bzQ02jGd5Qcbgb97Flm7U=
|
||||
github.com/redis/go-redis/extra/redisotel/v9 v9.0.5 h1:EfpWLLCyXw8PSM2/XNJLjI3Pb27yVE+gIAfeqp8LUCc=
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
package ssrf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Mode controls how the SSRF policy is applied.
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
// ModeOff disables SSRF protection entirely. All connections pass through unchanged.
|
||||
ModeOff Mode = "off"
|
||||
// ModeAudit resolves and checks destinations but only logs violations; connections are allowed.
|
||||
ModeAudit Mode = "audit"
|
||||
// ModeEnforce blocks connections that violate the policy.
|
||||
ModeEnforce Mode = "enforce"
|
||||
)
|
||||
|
||||
// Policy defines the SSRF protection policy for outbound HTTP connections.
|
||||
type Policy struct {
|
||||
// Mode controls whether protection is off, in audit-only mode, or enforcing.
|
||||
Mode Mode
|
||||
|
||||
// AllowedHosts is the allowlist of permitted destinations.
|
||||
// Accepted formats:
|
||||
// - Exact hostname: "example.com"
|
||||
// - Wildcard hostname: "*.example.com" (matches any subdomain at any depth)
|
||||
// - Single IP: "1.2.3.4"
|
||||
// - CIDR range: "10.0.0.0/8"
|
||||
//
|
||||
// When Mode is ModeEnforce and AllowedHosts is empty, all outbound connections are blocked.
|
||||
AllowedHosts []string
|
||||
}
|
||||
|
||||
type safeDialer struct {
|
||||
base net.Dialer
|
||||
mode Mode
|
||||
allowedNets []*net.IPNet
|
||||
allowedHosts map[string]bool
|
||||
allowedWilds []string // derived from "*.foo.com" entries; stored as ".foo.com"
|
||||
}
|
||||
|
||||
var globalDialer atomic.Pointer[safeDialer]
|
||||
|
||||
// Configure initializes the global SSRF policy. Intended to be called once
|
||||
// at startup before any outbound HTTP connections are established.
|
||||
func Configure(policy Policy) {
|
||||
if policy.Mode == ModeOff || policy.Mode == "" {
|
||||
globalDialer.Store(nil)
|
||||
return
|
||||
}
|
||||
|
||||
globalDialer.Store(newSafeDialer(policy))
|
||||
}
|
||||
|
||||
// IsEnabled reports whether SSRF protection is currently active (audit or enforce).
|
||||
func IsEnabled() bool {
|
||||
return globalDialer.Load() != nil
|
||||
}
|
||||
|
||||
// CheckURL validates rawURL against the active SSRF policy without making a
|
||||
// connection. Returns nil when protection is disabled or the destination is
|
||||
// permitted. In audit mode, logs a warning on violations and returns nil.
|
||||
func CheckURL(ctx context.Context, rawURL string) error {
|
||||
d := globalDialer.Load()
|
||||
if d == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ssrf: invalid URL %q: %w", rawURL, err)
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return d.checkHost(ctx, host)
|
||||
}
|
||||
|
||||
// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering
|
||||
// dialer. Returns t unchanged when SSRF protection is not configured.
|
||||
func WrapTransport(t *http.Transport) *http.Transport {
|
||||
d := globalDialer.Load()
|
||||
if d == nil {
|
||||
return t
|
||||
}
|
||||
|
||||
cloned := t.Clone()
|
||||
cloned.DialContext = d.DialContext
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
// WrapTransportInternal is a documented no-op for transports that connect to
|
||||
// internally computed destinations (local Docker socket proxy, Chisel tunnels,
|
||||
// in-cluster Kubernetes API). The destination is chosen by Portainer, not
|
||||
// supplied by any user, so SSRF validation is not applicable. Using this
|
||||
// function instead of WrapTransport makes the exemption explicit and
|
||||
// satisfies the ruleguard lint rule.
|
||||
func WrapTransportInternal(t *http.Transport) *http.Transport {
|
||||
return t
|
||||
}
|
||||
|
||||
func newSafeDialer(policy Policy) *safeDialer {
|
||||
allowedNets := make([]*net.IPNet, 0, len(policy.AllowedHosts))
|
||||
allowedHosts := make(map[string]bool, len(policy.AllowedHosts))
|
||||
var allowedWilds []string
|
||||
|
||||
for _, entry := range policy.AllowedHosts {
|
||||
if _, network, err := net.ParseCIDR(entry); err == nil {
|
||||
allowedNets = append(allowedNets, network)
|
||||
continue
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(entry); ip != nil {
|
||||
bits := 32
|
||||
if ip.To4() == nil {
|
||||
bits = 128
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(bits, bits)
|
||||
allowedNets = append(allowedNets, &net.IPNet{IP: ip.Mask(mask), Mask: mask})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(entry, "*.") {
|
||||
allowedWilds = append(allowedWilds, entry[1:]) // "*.foo.com" -> ".foo.com"
|
||||
continue
|
||||
}
|
||||
|
||||
allowedHosts[entry] = true
|
||||
}
|
||||
|
||||
return &safeDialer{
|
||||
mode: policy.Mode,
|
||||
allowedNets: allowedNets,
|
||||
allowedHosts: allowedHosts,
|
||||
allowedWilds: allowedWilds,
|
||||
}
|
||||
}
|
||||
|
||||
// DialContext resolves addr, validates all resolved IPs against the allowlist policy,
|
||||
// then dials using the first resolved IP to prevent DNS rebinding attacks.
|
||||
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err)
|
||||
}
|
||||
|
||||
resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssrf: resolving %q: %w", host, err)
|
||||
}
|
||||
|
||||
if len(resolved) == 0 {
|
||||
return nil, fmt.Errorf("ssrf: no addresses resolved for %q", host)
|
||||
}
|
||||
|
||||
// Dial by resolved IP regardless of how the host was allowed to close the
|
||||
// window between DNS validation and the TCP handshake (DNS rebinding).
|
||||
dialTarget := net.JoinHostPort(resolved[0].IP.String(), port)
|
||||
|
||||
if d.allowedHosts[host] || d.matchesWildcard(host) {
|
||||
return d.base.DialContext(ctx, network, dialTarget)
|
||||
}
|
||||
|
||||
for _, a := range resolved {
|
||||
if !d.ipAllowed(a.IP) {
|
||||
if d.mode == ModeAudit {
|
||||
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("ssrf: destination %s is not in the allowlist", a.IP)
|
||||
}
|
||||
}
|
||||
|
||||
return d.base.DialContext(ctx, network, dialTarget)
|
||||
}
|
||||
|
||||
func (d *safeDialer) checkHost(ctx context.Context, host string) error {
|
||||
if d.allowedHosts[host] || d.matchesWildcard(host) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if !d.ipAllowed(ip) {
|
||||
if d.mode == ModeAudit {
|
||||
log.Warn().Str("host", host).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("ssrf: destination %s is not in the allowlist", ip)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ssrf: resolving %q: %w", host, err)
|
||||
}
|
||||
|
||||
if len(resolved) == 0 {
|
||||
return fmt.Errorf("ssrf: no addresses resolved for %q", host)
|
||||
}
|
||||
|
||||
for _, a := range resolved {
|
||||
if !d.ipAllowed(a.IP) {
|
||||
if d.mode == ModeAudit {
|
||||
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
|
||||
continue
|
||||
}
|
||||
|
||||
return fmt.Errorf("ssrf: destination %s is not in the allowlist", a.IP)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *safeDialer) matchesWildcard(host string) bool {
|
||||
for _, suffix := range d.allowedWilds {
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *safeDialer) ipAllowed(ip net.IP) bool {
|
||||
for _, network := range d.allowedNets {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,357 @@
|
||||
package ssrf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIpAllowed_CIDR(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"8.8.0.0/16", "2001:4860::/32"},
|
||||
})
|
||||
|
||||
require.True(t, d.ipAllowed(net.ParseIP("8.8.8.8")))
|
||||
require.True(t, d.ipAllowed(net.ParseIP("8.8.4.4")))
|
||||
require.True(t, d.ipAllowed(net.ParseIP("2001:4860:4860::8888")))
|
||||
|
||||
require.False(t, d.ipAllowed(net.ParseIP("1.1.1.1")))
|
||||
require.False(t, d.ipAllowed(net.ParseIP("127.0.0.1")))
|
||||
require.False(t, d.ipAllowed(net.ParseIP("169.254.169.254")))
|
||||
}
|
||||
|
||||
func TestIpAllowed_SingleIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"1.2.3.4"},
|
||||
})
|
||||
|
||||
require.True(t, d.ipAllowed(net.ParseIP("1.2.3.4")))
|
||||
require.False(t, d.ipAllowed(net.ParseIP("1.2.3.5")))
|
||||
}
|
||||
|
||||
func TestMatchesWildcard(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"*.example.com", "exact.host.com"},
|
||||
})
|
||||
|
||||
require.True(t, d.matchesWildcard("foo.example.com"))
|
||||
require.True(t, d.matchesWildcard("bar.example.com"))
|
||||
require.True(t, d.matchesWildcard("deep.nested.example.com"))
|
||||
|
||||
require.False(t, d.matchesWildcard("example.com"))
|
||||
require.False(t, d.matchesWildcard("notexample.com"))
|
||||
require.False(t, d.matchesWildcard("exact.host.com"))
|
||||
}
|
||||
|
||||
func TestNewSafeDialer_MixedHosts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"example.com", "*.internal.net", "10.0.0.0/8", "1.2.3.4"},
|
||||
})
|
||||
|
||||
require.True(t, d.allowedHosts["example.com"])
|
||||
require.Contains(t, d.allowedWilds, ".internal.net")
|
||||
require.Len(t, d.allowedNets, 2) // 10.0.0.0/8 and 1.2.3.4/32
|
||||
}
|
||||
|
||||
func TestConfigure_Disabled(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
require.NotNil(t, globalDialer.Load())
|
||||
|
||||
Configure(Policy{})
|
||||
require.Nil(t, globalDialer.Load())
|
||||
}
|
||||
|
||||
func TestWrapTransport_NoPolicy(t *testing.T) {
|
||||
globalDialer.Store(nil)
|
||||
|
||||
base := &http.Transport{}
|
||||
result := WrapTransport(base)
|
||||
require.Equal(t, base, result)
|
||||
}
|
||||
|
||||
func TestWrapTransport_WithPolicy(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
base := &http.Transport{}
|
||||
result := WrapTransport(base)
|
||||
require.NotEqual(t, base, result)
|
||||
require.NotNil(t, result.DialContext)
|
||||
}
|
||||
|
||||
func TestCheckURL_Disabled(t *testing.T) {
|
||||
globalDialer.Store(nil)
|
||||
|
||||
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckURL_BlocksIPNotInAllowlist(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
}
|
||||
|
||||
func TestCheckURL_AllowedHostname(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "https://example.com/path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckURL_AuditMode_ReturnsNil(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP
|
||||
// server on 127.0.0.1, enables SSRF protection with an allowlist that does not
|
||||
// include loopback, and verifies that the wrapped transport blocks the request.
|
||||
func TestDialContext_BlocksLoopback(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.8"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
blocked := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
resp, err := blocked.Get(srv.URL)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
if resp != nil {
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
Configure(Policy{})
|
||||
|
||||
open := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
resp, err = open.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
// TestDialContext_AuditMode_AllowsLoopback verifies that audit mode logs the
|
||||
// violation but still allows the connection through.
|
||||
func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.8"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
resp, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
func TestIsEnabled(t *testing.T) {
|
||||
globalDialer.Store(nil)
|
||||
require.False(t, IsEnabled())
|
||||
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
require.True(t, IsEnabled())
|
||||
}
|
||||
|
||||
func TestWrapTransportInternal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := &http.Transport{}
|
||||
result := WrapTransportInternal(base)
|
||||
require.Equal(t, base, result)
|
||||
}
|
||||
|
||||
func TestNewSafeDialer_IPv6SingleIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"::1"},
|
||||
})
|
||||
|
||||
require.True(t, d.ipAllowed(net.ParseIP("::1")))
|
||||
require.False(t, d.ipAllowed(net.ParseIP("::2")))
|
||||
}
|
||||
|
||||
func TestCheckURL_InvalidURL(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://%gg")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
}
|
||||
|
||||
func TestCheckURL_EmptyHost(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestCheckURL_IPInAllowlist verifies that a literal IP address that falls
|
||||
// within an allowed CIDR range is permitted.
|
||||
func TestCheckURL_IPInAllowlist(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://8.8.8.8/path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckURL_WildcardHostname(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"*.example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "https://api.example.com/path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestCheckURL_HostnameDNSResolvesToAllowedIP verifies that a hostname
|
||||
// resolving to an IP within the allowlist is permitted (DNS resolution path).
|
||||
func TestCheckURL_HostnameDNSResolvesToAllowedIP(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8", "::1/128"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://localhost/path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestCheckURL_HostnameDNSResolvesToBlockedIP verifies that a hostname
|
||||
// resolving to an IP outside the allowlist is blocked (DNS resolution path).
|
||||
func TestCheckURL_HostnameDNSResolvesToBlockedIP(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://localhost/path")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
}
|
||||
|
||||
// TestCheckURL_HostnameDNSAuditMode verifies that audit mode logs violations
|
||||
// from hostname DNS resolution but still returns nil.
|
||||
func TestCheckURL_HostnameDNSAuditMode(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://localhost/path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestCheckURL_HostnameDNSError verifies that a DNS resolution failure is
|
||||
// propagated as an SSRF-prefixed error.
|
||||
func TestCheckURL_HostnameDNSError(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
err := CheckURL(ctx, "http://portainer-nonexistent.test.invalid/path")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestDialContext_InvalidAddress verifies that an address without a port
|
||||
// returns an SSRF-prefixed error.
|
||||
func TestDialContext_InvalidAddress(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
d := globalDialer.Load()
|
||||
_, err := d.DialContext(t.Context(), "tcp", "no-port-here")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
}
|
||||
|
||||
// TestDialContext_DNSError verifies that a DNS resolution failure in
|
||||
// DialContext is propagated as an SSRF-prefixed error.
|
||||
func TestDialContext_DNSError(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
d := globalDialer.Load()
|
||||
_, err := d.DialContext(ctx, "tcp", "portainer-nonexistent.test.invalid:80")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestDialContext_AllowedByCIDR is an end-to-end test verifying that
|
||||
// connections to IPs within an allowed CIDR range are permitted.
|
||||
func TestDialContext_AllowedByCIDR(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
resp, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
// TestDialContext_AllowedByExactHostname verifies that when a hostname is in
|
||||
// the allowed-hosts list, the connection is permitted even though the resolved
|
||||
// IP is not covered by any CIDR entry.
|
||||
//
|
||||
// The server is bound to whatever IP "localhost" resolves to first so that the
|
||||
// dialTarget computed by DialContext (resolved[0]) matches the listening address.
|
||||
func TestDialContext_AllowedByExactHostname(t *testing.T) {
|
||||
addrs, err := net.DefaultResolver.LookupIPAddr(t.Context(), "localhost")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, addrs, "localhost must resolve to at least one address")
|
||||
|
||||
l, err := net.Listen("tcp", net.JoinHostPort(addrs[0].IP.String(), "0"))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
srv.Listener = l
|
||||
srv.Start()
|
||||
defer srv.Close()
|
||||
|
||||
_, portStr, err := net.SplitHostPort(l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"localhost"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
resp, err := client.Get("http://localhost:" + portStr)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"oras.land/oras-go/v2/registry/remote/retry"
|
||||
@@ -15,8 +16,8 @@ import (
|
||||
func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) {
|
||||
switch registry.Type {
|
||||
case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry:
|
||||
// Cloud registries use the default retry client with built-in TLS
|
||||
return retry.DefaultClient, false, nil
|
||||
base := http.DefaultTransport.(*http.Transport).Clone()
|
||||
return &http.Client{Transport: retry.NewTransport(ssrf.WrapTransport(base))}, false, nil
|
||||
default:
|
||||
// For all other registry types, use shared helper to build transport and scheme
|
||||
|
||||
|
||||
@@ -11,6 +11,11 @@ import (
|
||||
"oras.land/oras-go/v2/registry/remote/retry"
|
||||
)
|
||||
|
||||
func isRetryTransport(t *http.Client) bool {
|
||||
_, ok := t.Transport.(*retry.Transport)
|
||||
return ok
|
||||
}
|
||||
|
||||
func init() {
|
||||
fips.InitFIPS(false)
|
||||
}
|
||||
@@ -102,8 +107,7 @@ func TestCreateClient(t *testing.T) {
|
||||
// Verify client type based on registry configuration
|
||||
switch tt.registry.Type {
|
||||
case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry:
|
||||
// Cloud registries should use the default retry client
|
||||
assert.Equal(t, retry.DefaultClient, client)
|
||||
assert.True(t, isRetryTransport(client), "Cloud registries should use a retry transport")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -133,7 +137,7 @@ func TestCreateClient_CloudRegistries(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
assert.False(t, usePlainHTTP, "Cloud registries should use HTTPS")
|
||||
assert.Equal(t, retry.DefaultClient, client, "Cloud registries should use default retry client")
|
||||
assert.True(t, isRetryTransport(client), "Cloud registries should use a retry transport")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user