feat(ssrf): add ssrf allow list to settings [BE-13021] (#2858)

This commit is contained in:
Devon Steenberg
2026-06-12 15:16:06 +12:00
committed by GitHub
parent f87fec6d61
commit 8b21dfc318
17 changed files with 831 additions and 308 deletions
-2
View File
@@ -56,8 +56,6 @@ func CLIFlags() *portainer.CLIFlags {
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(),
CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(),
SSRFMode: kingpin.Flag("ssrf-mode", "SSRF protection mode: off (disabled), audit (log violations but allow), enforce (block violations)").Envar("PORTAINER_SSRF_MODE").Default("off").Enum("off", "audit", "enforce"),
SSRFAllowedHosts: kingpin.Flag("ssrf-allowed-hosts", "Allowlist of hostnames (with optional wildcards), IPs, or CIDRs permitted for outbound requests. When empty and mode is enforce, all outbound connections are blocked").Envar("PORTAINER_SSRF_ALLOWED_HOSTS").Strings(),
NoSetupToken: kingpin.Flag("no-setup-token", "Disable the setup token requirement for admin initialization and restore on an uninitialized instance").Envar(portainer.NoSetupTokenEnvVar).Bool(),
SetupToken: kingpin.Flag("setup-token", "Set a custom setup token for admin initialization and restore on an uninitialized instance (overrides auto-generation)").Envar(portainer.SetupTokenEnvVar).String(),
}
+10 -13
View File
@@ -387,19 +387,6 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
// -ce can not ever be run in FIPS mode
fips.InitFIPS(false)
ssrf.Configure(ssrf.Policy{
Mode: ssrf.Mode(*flags.SSRFMode),
AllowedHosts: *flags.SSRFAllowedHosts,
})
if ssrf.IsEnabled() {
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok {
nethttp.DefaultTransport = ssrf.WrapTransport(dt)
}
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
}
fileService := initFileService(*flags.Data)
encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName))
if encryptionKey == nil {
@@ -417,6 +404,16 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
log.Fatal().Msg("The database schema version does not align with the server version. Please consider reverting to the previous server version or addressing the database migration issue.")
}
if err := ssrf.Configure(dataStore.AllowList()); err != nil {
log.Fatal().Err(err).Msg("failed initializing ssrf service")
}
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok {
nethttp.DefaultTransport = ssrf.WrapTransport(dt)
}
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
instanceID, err := dataStore.Version().InstanceID()
if err != nil {
log.Fatal().Err(err).Msg("failed getting instance id")
+131
View File
@@ -0,0 +1,131 @@
package allowlist
import (
"fmt"
lru "github.com/hashicorp/golang-lru"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
)
const (
BucketName = "allowlist"
)
type Service struct {
baseService dataservices.BaseDataService[portainer.AllowList, portainer.AllowListKey]
cache *lru.Cache
}
func (service *Service) BucketName() string {
return service.baseService.BucketName()
}
func NewService(connection portainer.Connection) (*Service, error) {
err := connection.SetServiceName(BucketName)
if err != nil {
return nil, err
}
service := &Service{
baseService: dataservices.BaseDataService[portainer.AllowList, portainer.AllowListKey]{
Bucket: BucketName,
Connection: connection,
},
}
err = service.populateCache()
return service, err
}
func (service *Service) populateCache() error {
allowListKeys := []portainer.AllowListKey{portainer.AllowListSSRF}
cache, err := lru.New(len(allowListKeys))
if err != nil {
return err
}
for _, k := range allowListKeys {
allowList, err := service.baseService.Read(k)
if dataservices.IsErrObjectNotFound(err) {
allowList = &portainer.AllowList{
ID: k,
Mode: portainer.SSRFModeOff,
Entries: []string{},
}
} else if err != nil {
return err
}
parsedAllowList := ssrf.ParseAllowedHosts(allowList.Entries)
parsedAllowList.Mode = allowList.Mode
cache.Add(k, &parsedAllowList)
}
service.cache = cache
return nil
}
func (service *Service) Tx(tx portainer.Transaction) *ServiceTx {
return &ServiceTx{
baseService: service.baseService.Tx(tx),
cache: service.cache,
}
}
func (service *Service) Read(id portainer.AllowListKey) (*portainer.AllowList, error) {
var result *portainer.AllowList
if err := service.baseService.Connection.ViewTx(func(tx portainer.Transaction) error {
var err error
result, err = service.Tx(tx).Read(id)
return err
}); err != nil {
return nil, err
}
return result, nil
}
func (service *Service) ReadAll() ([]portainer.AllowList, error) {
var result []portainer.AllowList
if err := service.baseService.Connection.ViewTx(func(tx portainer.Transaction) error {
var err error
result, err = service.Tx(tx).ReadAll()
return err
}); err != nil {
return nil, err
}
return result, nil
}
func (service *Service) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
allowListAny, ok := service.cache.Get(id)
if ok {
allowList, ok := allowListAny.(*portainer.ParsedAllowList)
if !ok {
return nil, fmt.Errorf("expected ParsedAllowList in cache but got %T", allowListAny)
}
return allowList, nil
}
var result *portainer.ParsedAllowList
err := service.baseService.Connection.ViewTx(func(tx portainer.Transaction) error {
var err error
result, err = service.Tx(tx).ReadParsed(id)
return err
})
return result, err
}
func (service *Service) Update(id portainer.AllowListKey, allowList *portainer.AllowList) error {
return service.baseService.Connection.UpdateTx(func(tx portainer.Transaction) error {
return service.Tx(tx).Update(id, allowList)
})
}
@@ -0,0 +1,89 @@
package allowlist_test
import (
"net"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/require"
)
func TestAllowListReadEmpty(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, false, false)
got, err := ds.AllowList().Read(portainer.AllowListSSRF)
expected := &portainer.AllowList{
ID: portainer.AllowListSSRF,
Mode: portainer.SSRFModeOff,
Entries: []string{},
}
require.NoError(t, err)
require.Equal(t, expected, got)
}
func TestAllowListUpdate(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, false, false)
expected := &portainer.AllowList{
ID: portainer.AllowListSSRF,
Mode: portainer.SSRFModeEnforce,
Entries: []string{"example.com", "10.0.0.0/8"},
}
require.NoError(t, ds.AllowList().Update(portainer.AllowListSSRF, expected))
got, err := ds.AllowList().Read(portainer.AllowListSSRF)
require.NoError(t, err)
require.Equal(t, expected, got)
}
func TestAllowListReadAllEmpty(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, false, false)
got, err := ds.AllowList().ReadAll()
require.NoError(t, err)
require.Equal(t, []portainer.AllowList{}, got)
}
func TestAllowListReadAllAfterUpdate(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, false, false)
expected := portainer.AllowList{
ID: portainer.AllowListSSRF,
Mode: portainer.SSRFModeEnforce,
Entries: []string{"example.com", "10.0.0.0/8"},
}
require.NoError(t, ds.AllowList().Update(portainer.AllowListSSRF, &expected))
got, err := ds.AllowList().ReadAll()
require.NoError(t, err)
require.Equal(t, []portainer.AllowList{expected}, got)
}
func TestAllowListReadParsedAfterUpdate(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, false, false)
require.NoError(t, ds.AllowList().Update(portainer.AllowListSSRF, &portainer.AllowList{
ID: portainer.AllowListSSRF,
Mode: portainer.SSRFModeEnforce,
Entries: []string{"example.com"},
}))
expected := &portainer.ParsedAllowList{
Mode: portainer.SSRFModeEnforce,
Nets: []*net.IPNet{},
Hosts: map[string]bool{
"example.com": true,
},
}
got, err := ds.AllowList().ReadParsed(portainer.AllowListSSRF)
require.NoError(t, err)
require.Equal(t, expected, got)
}
+77
View File
@@ -0,0 +1,77 @@
package allowlist
import (
"fmt"
lru "github.com/hashicorp/golang-lru"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
)
type ServiceTx struct {
baseService dataservices.BaseDataServiceTx[portainer.AllowList, portainer.AllowListKey]
cache *lru.Cache
}
func (service *ServiceTx) BucketName() string {
return service.baseService.BucketName()
}
func (service *ServiceTx) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
allowListAny, ok := service.cache.Get(id)
if ok {
allowList, ok := allowListAny.(*portainer.ParsedAllowList)
if !ok {
return nil, fmt.Errorf("expected ParsedAllowList in cache but got %T", allowListAny)
}
return allowList, nil
}
allowList, err := service.Read(id)
if err != nil {
return nil, err
}
parsed := ssrf.ParseAllowedHosts(allowList.Entries)
parsed.Mode = allowList.Mode
service.cache.Add(id, &parsed)
return &parsed, nil
}
func (service *ServiceTx) Read(id portainer.AllowListKey) (*portainer.AllowList, error) {
allowList, err := service.baseService.Read(id)
if dataservices.IsErrObjectNotFound(err) {
allowList = &portainer.AllowList{
ID: id,
Mode: portainer.SSRFModeOff,
Entries: []string{},
}
} else if err != nil {
return nil, err
}
return allowList, nil
}
func (service *ServiceTx) ReadAll() ([]portainer.AllowList, error) {
allowLists, err := service.baseService.ReadAll()
if err != nil && !dataservices.IsErrObjectNotFound(err) {
return nil, err
}
return allowLists, nil
}
func (service *ServiceTx) Update(id portainer.AllowListKey, allowList *portainer.AllowList) error {
if err := service.baseService.Update(id, allowList); err != nil {
return err
}
parsed := ssrf.ParseAllowedHosts(allowList.Entries)
parsed.Mode = allowList.Mode
service.cache.Add(id, &parsed)
return nil
}
+92
View File
@@ -0,0 +1,92 @@
package allowlist_test
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/require"
)
func TestAllowListReadTx(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, false, false)
var got *portainer.AllowList
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
var err error
got, err = tx.AllowList().Read(portainer.AllowListSSRF)
return err
}))
expected := &portainer.AllowList{
ID: portainer.AllowListSSRF,
Mode: portainer.SSRFModeOff,
Entries: []string{},
}
require.Equal(t, expected, got)
}
func TestAllowListReadAllEmptyTx(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, false, false)
var got []portainer.AllowList
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
var err error
got, err = tx.AllowList().ReadAll()
return err
}))
require.Equal(t, []portainer.AllowList{}, got)
}
func TestAllowListReadAllAfterUpdateTx(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, false, false)
expected := portainer.AllowList{
ID: portainer.AllowListSSRF,
Mode: portainer.SSRFModeEnforce,
Entries: []string{"example.com"},
}
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
return tx.AllowList().Update(portainer.AllowListSSRF, &expected)
}))
var got []portainer.AllowList
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
var err error
got, err = tx.AllowList().ReadAll()
return err
}))
require.Equal(t, []portainer.AllowList{expected}, got)
}
func TestAllowListUpdateTx(t *testing.T) {
t.Parallel()
_, ds := datastore.MustNewTestStore(t, false, false)
expected := &portainer.AllowList{
ID: portainer.AllowListSSRF,
Mode: portainer.SSRFModeEnforce,
Entries: []string{"example.com"},
}
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
return tx.AllowList().Update(portainer.AllowListSSRF, expected)
}))
var got *portainer.AllowList
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
var err error
got, err = tx.AllowList().Read(portainer.AllowListSSRF)
return err
}))
require.Equal(t, expected, got)
}
+10
View File
@@ -8,6 +8,7 @@ import (
type (
DataStoreTx interface {
IsErrObjectNotFound(err error) bool
AllowList() AllowListService
CustomTemplate() CustomTemplateService
EdgeGroup() EdgeGroupService
EdgeJob() EdgeJobService
@@ -53,6 +54,15 @@ type (
DataStoreTx
}
// AllowListService represents a service for managing the URL allow list
AllowListService interface {
Read(id portainer.AllowListKey) (*portainer.AllowList, error)
ReadAll() ([]portainer.AllowList, error)
ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error)
Update(id portainer.AllowListKey, allowList *portainer.AllowList) error
BucketName() string
}
// CustomTemplateService represents a service to manage custom templates
CustomTemplateService interface {
BaseCRUD[portainer.CustomTemplate, portainer.CustomTemplateID]
+14 -1
View File
@@ -7,6 +7,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/database/models"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/dataservices/allowlist"
"github.com/portainer/portainer/api/dataservices/apikeyrepository"
"github.com/portainer/portainer/api/dataservices/customtemplate"
"github.com/portainer/portainer/api/dataservices/dockerhub"
@@ -51,6 +52,7 @@ type Store struct {
connection portainer.Connection
fileService portainer.FileService
AllowListService *allowlist.Service
CustomTemplateService *customtemplate.Service
DockerHubService *dockerhub.Service
EdgeGroupService *edgegroup.Service
@@ -84,6 +86,12 @@ type Store struct {
}
func (store *Store) initServices() error {
allowListService, err := allowlist.NewService(store.connection)
if err != nil {
return err
}
store.AllowListService = allowListService
authorizationsetService, err := role.NewService(store.connection)
if err != nil {
return err
@@ -275,6 +283,11 @@ func (store *Store) PendingActions() dataservices.PendingActionsService {
return store.PendingActionsService
}
// AllowList gives access to the AllowList data management layer
func (store *Store) AllowList() dataservices.AllowListService {
return store.AllowListService
}
// CustomTemplate gives access to the CustomTemplate data management layer
func (store *Store) CustomTemplate() dataservices.CustomTemplateService {
return store.CustomTemplateService
@@ -654,7 +667,7 @@ func (store *Store) Export(filename string) (err error) {
return err
}
return os.WriteFile(filename, b, 0600)
return os.WriteFile(filename, b, 0o600)
}
func (store *Store) Import(filename string) (err error) {
+4
View File
@@ -14,6 +14,10 @@ func (tx *StoreTx) IsErrObjectNotFound(err error) bool {
return tx.store.IsErrObjectNotFound(err)
}
func (tx *StoreTx) AllowList() dataservices.AllowListService {
return tx.store.AllowListService.Tx(tx.tx)
}
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService {
return tx.store.CustomTemplateService.Tx(tx.tx)
}
@@ -1,4 +1,5 @@
{
"allowlist": null,
"api_key": null,
"customtemplates": null,
"dockerhub": [
+2 -2
View File
@@ -90,7 +90,7 @@ func createTCPClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*cli
client.WithHTTPClient(httpCli),
}
if nnTransport, ok := httpCli.Transport.(*NodeNameTransport); ok && nnTransport.TLSClientConfig != nil {
if endpoint.TLSConfig.TLS {
opts = append(opts, client.WithScheme("https"))
}
@@ -124,7 +124,7 @@ func createAgentClient(endpoint *portainer.Endpoint, endpointURL string, signatu
client.WithHTTPHeaders(headers),
}
if nnTransport, ok := httpCli.Transport.(*NodeNameTransport); ok && nnTransport.TLSClientConfig != nil {
if endpoint.TLSConfig.TLS {
opts = append(opts, client.WithScheme("https"))
}
+16 -2
View File
@@ -45,10 +45,24 @@ func (s *stubTunnelService) UpdateLastActivity(endpointID portainer.EndpointID)
func (s *stubTunnelService) KeepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxKeepAlive time.Duration) {
}
type staticAllowListService struct {
parsed portainer.ParsedAllowList
}
func (s *staticAllowListService) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
return &s.parsed, nil
}
func enableSSRF(t *testing.T) {
t.Helper()
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
parsed := ssrf.ParseAllowedHosts([]string{"example.com"})
parsed.Mode = portainer.SSRFModeEnforce
err := ssrf.Configure(&staticAllowListService{parsed: parsed})
require.NoError(t, err)
t.Cleanup(func() {
err := ssrf.Configure(&staticAllowListService{})
require.NoError(t, err)
})
}
// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint
+2
View File
@@ -13,6 +13,7 @@ import (
var _ dataservices.DataStore = &testDatastore{}
type testDatastore struct {
allowList dataservices.AllowListService
customTemplate dataservices.CustomTemplateService
edgeGroup dataservices.EdgeGroupService
edgeJob dataservices.EdgeJobService
@@ -53,6 +54,7 @@ func (d *testDatastore) ViewTx(func(dataservices.DataStoreTx) error) error { r
func (d *testDatastore) CheckCurrentEdition() error { return nil }
func (d *testDatastore) MigrateData() error { return nil }
func (d *testDatastore) Rollback(force bool) error { return nil }
func (d *testDatastore) AllowList() dataservices.AllowListService { return d.allowList }
func (d *testDatastore) CustomTemplate() dataservices.CustomTemplateService { return d.customTemplate }
func (d *testDatastore) EdgeGroup() dataservices.EdgeGroupService { return d.edgeGroup }
func (d *testDatastore) EdgeJob() dataservices.EdgeJobService { return d.edgeJob }
+30 -2
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"time"
@@ -111,8 +112,6 @@ type (
KubectlShellImageSet bool
PullLimitCheckDisabled *bool
TrustedOrigins *string
SSRFMode *string
SSRFAllowedHosts *[]string
NoSetupToken *bool
SetupToken *string
}
@@ -1215,6 +1214,21 @@ type (
// SoftwareEdition represents an edition of Portainer
SoftwareEdition int
// AllowList holds the list of permitted outbound proxy destinations.
AllowList struct {
ID AllowListKey `json:"Id"`
Mode SSRFMode `json:"Mode"`
Entries []string `json:"Entries"`
}
// ParsedAllowList holds the three parsed forms of allow list entries.
ParsedAllowList struct {
Mode SSRFMode
Nets []*net.IPNet
Hosts map[string]bool
Wilds []string // stored as ".foo.com" ("*." prefix stripped)
}
// SSLSettings represents a pair of SSL certificate and key
SSLSettings struct {
CertPath string `json:"certPath"`
@@ -2670,3 +2684,17 @@ func DefaultEndpointSecuritySettings() EndpointSecuritySettings {
AllowStackManagementForRegularUsers: true,
}
}
type AllowListKey int
const (
AllowListSSRF AllowListKey = iota
)
type SSRFMode int
const (
SSRFModeOff SSRFMode = iota
SSRFModeAudit
SSRFModeEnforce
)
+29 -14
View File
@@ -139,7 +139,6 @@ services:
`)
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
EntryPoint: "docker-compose.yml",
Env: []portainer.Pair{{Name: "API_PORT", Value: "3000"}},
@@ -186,7 +185,7 @@ func TestValidateStackFiles_DotEnvFile(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
err := os.WriteFile(filesystem.JoinPaths(tmpDir, ".env"), []byte("HOST_PORT=3000\n"), 0600)
err := os.WriteFile(filesystem.JoinPaths(tmpDir, ".env"), []byte("HOST_PORT=3000\n"), 0o600)
require.NoError(t, err)
fileContent := []byte(`
@@ -217,7 +216,7 @@ func TestValidateStackFiles_EnvFileAttribute(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
err := os.WriteFile(filesystem.JoinPaths(tmpDir, "web.env"), []byte("HOST_PORT=3000\n"), 0600)
err := os.WriteFile(filesystem.JoinPaths(tmpDir, "web.env"), []byte("HOST_PORT=3000\n"), 0o600)
require.NoError(t, err)
fileContent := []byte(`
@@ -298,8 +297,29 @@ func TestExtractImageRegistry(t *testing.T) {
}
}
type staticAllowListService struct {
parsed portainer.ParsedAllowList
}
func (s *staticAllowListService) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
return &s.parsed, nil
}
func configureSSRF(t *testing.T, mode portainer.SSRFMode, entries []string) {
t.Helper()
parsed := ssrf.ParseAllowedHosts(entries)
parsed.Mode = mode
err := ssrf.Configure(&staticAllowListService{parsed: parsed})
require.NoError(t, err)
t.Cleanup(func() {
err := ssrf.Configure(&staticAllowListService{})
require.NoError(t, err)
})
}
func TestValidateComposeURLs_DisabledSSRF(t *testing.T) {
ssrf.Configure(ssrf.Policy{})
configureSSRF(t, portainer.SSRFModeOff, nil)
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
@@ -322,8 +342,7 @@ services:
}
func TestValidateComposeURLs_BuildContextBlocked(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"})
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
@@ -347,8 +366,7 @@ services:
}
func TestValidateComposeURLs_BuildContextPath(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"})
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
@@ -372,8 +390,7 @@ services:
}
func TestValidateComposeURLs_ImageRegistryBlocked(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"})
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
@@ -395,8 +412,7 @@ services:
}
func TestValidateComposeURLs_ImageRegistryAllowed(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"myregistry.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
configureSSRF(t, portainer.SSRFModeEnforce, []string{"myregistry.com"})
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
@@ -418,8 +434,7 @@ services:
}
func TestValidateComposeURLs_DockerHubImageAllowed(t *testing.T) {
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"})
stack := &portainer.Stack{
ProjectPath: "/tmp/stack/1",
+92 -88
View File
@@ -2,6 +2,7 @@ package ssrf
import (
"context"
"errors"
"fmt"
"net"
"net/http"
@@ -9,61 +10,84 @@ import (
"strings"
"sync/atomic"
portainer "github.com/portainer/portainer/api"
"github.com/rs/zerolog/log"
)
// Mode controls how the SSRF policy is applied.
type Mode string
// ParseAllowedHosts parses raw allow list entries into their three canonical
// forms. Accepted formats: exact hostname, wildcard hostname (*.example.com),
// single IP, or CIDR range.
func ParseAllowedHosts(entries []string) portainer.ParsedAllowList {
nets := make([]*net.IPNet, 0, len(entries))
hosts := make(map[string]bool, len(entries))
var wilds []string
const (
// ModeOff disables SSRF protection entirely. All connections pass through unchanged.
ModeOff Mode = "off"
// ModeAudit resolves and checks destinations but only logs violations; connections are allowed.
ModeAudit Mode = "audit"
// ModeEnforce blocks connections that violate the policy.
ModeEnforce Mode = "enforce"
)
for _, entry := range entries {
if _, network, err := net.ParseCIDR(entry); err == nil {
nets = append(nets, network)
continue
}
// Policy defines the SSRF protection policy for outbound HTTP connections.
type Policy struct {
// Mode controls whether protection is off, in audit-only mode, or enforcing.
Mode Mode
if ip := net.ParseIP(entry); ip != nil {
bits := 32
if ip.To4() == nil {
bits = 128
}
// AllowedHosts is the allowlist of permitted destinations.
// Accepted formats:
// - Exact hostname: "example.com"
// - Wildcard hostname: "*.example.com" (matches any subdomain at any depth)
// - Single IP: "1.2.3.4"
// - CIDR range: "10.0.0.0/8"
//
// When Mode is ModeEnforce and AllowedHosts is empty, all outbound connections are blocked.
AllowedHosts []string
mask := net.CIDRMask(bits, bits)
nets = append(nets, &net.IPNet{IP: ip.Mask(mask), Mask: mask})
continue
}
if strings.HasPrefix(entry, "*.") {
wilds = append(wilds, entry[1:]) // "*.foo.com" -> ".foo.com"
continue
}
hosts[entry] = true
}
return portainer.ParsedAllowList{Nets: nets, Hosts: hosts, Wilds: wilds}
}
// AllowListService is implemented by the allowlist data service.
// ReadParsed is called on every dial to pick up runtime changes.
type AllowListService interface {
ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error)
}
type safeDialer struct {
base net.Dialer
mode Mode
allowedNets []*net.IPNet
allowedHosts map[string]bool
allowedWilds []string // derived from "*.foo.com" entries; stored as ".foo.com"
base net.Dialer
service AllowListService
}
var globalDialer atomic.Pointer[safeDialer]
// Configure initializes the global SSRF policy. Intended to be called once
// at startup before any outbound HTTP connections are established.
func Configure(policy Policy) {
if policy.Mode == ModeOff || policy.Mode == "" {
globalDialer.Store(nil)
return
// Configure initializes the global SSRF policy with the allow list data service.
func Configure(svc AllowListService) error {
if svc == nil {
return errors.New("unable to configure ssrf: service must not be nil")
}
globalDialer.Store(newSafeDialer(policy))
globalDialer.Store(&safeDialer{service: svc})
return nil
}
// IsEnabled reports whether SSRF protection is currently active (audit or enforce).
func IsEnabled() bool {
return globalDialer.Load() != nil
d := globalDialer.Load()
if d == nil {
return false
}
allowList, err := d.service.ReadParsed(portainer.AllowListSSRF)
if err != nil {
log.Err(err).Msg("unable to check SSRF protection mode")
return false
}
return allowList.Mode != portainer.SSRFModeOff
}
// CheckURL validates rawURL against the active SSRF policy without making a
@@ -89,7 +113,8 @@ func CheckURL(ctx context.Context, rawURL string) error {
}
// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering
// dialer. Returns t unchanged when SSRF protection is not configured.
// dialer. The dialer checks the mode on every connection, so the transport is always
// wrapped and mode changes take effect without restarting.
func WrapTransport(t *http.Transport) *http.Transport {
d := globalDialer.Load()
if d == nil {
@@ -112,48 +137,18 @@ func WrapTransportInternal(t *http.Transport) *http.Transport {
return t
}
func newSafeDialer(policy Policy) *safeDialer {
allowedNets := make([]*net.IPNet, 0, len(policy.AllowedHosts))
allowedHosts := make(map[string]bool, len(policy.AllowedHosts))
var allowedWilds []string
for _, entry := range policy.AllowedHosts {
if _, network, err := net.ParseCIDR(entry); err == nil {
allowedNets = append(allowedNets, network)
continue
}
if ip := net.ParseIP(entry); ip != nil {
bits := 32
if ip.To4() == nil {
bits = 128
}
mask := net.CIDRMask(bits, bits)
allowedNets = append(allowedNets, &net.IPNet{IP: ip.Mask(mask), Mask: mask})
continue
}
if strings.HasPrefix(entry, "*.") {
allowedWilds = append(allowedWilds, entry[1:]) // "*.foo.com" -> ".foo.com"
continue
}
allowedHosts[entry] = true
}
return &safeDialer{
mode: policy.Mode,
allowedNets: allowedNets,
allowedHosts: allowedHosts,
allowedWilds: allowedWilds,
}
}
// DialContext resolves addr, validates all resolved IPs against the allowlist policy,
// then dials using the first resolved IP to prevent DNS rebinding attacks.
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
allowList, err := d.service.ReadParsed(portainer.AllowListSSRF)
if err != nil {
return nil, fmt.Errorf("ssrf: reading allow list: %w", err)
}
if allowList.Mode == portainer.SSRFModeOff {
return d.base.DialContext(ctx, network, addr)
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err)
@@ -172,13 +167,13 @@ func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net
// window between DNS validation and the TCP handshake (DNS rebinding).
dialTarget := net.JoinHostPort(resolved[0].IP.String(), port)
if d.allowedHosts[host] || d.matchesWildcard(host) {
if allowList.Hosts[host] || matchesWildcard(host, allowList.Wilds) {
return d.base.DialContext(ctx, network, dialTarget)
}
for _, a := range resolved {
if !d.ipAllowed(a.IP) {
if d.mode == ModeAudit {
if !ipAllowed(a.IP, allowList.Nets) {
if allowList.Mode == portainer.SSRFModeAudit {
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
continue
}
@@ -191,13 +186,22 @@ func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net
}
func (d *safeDialer) checkHost(ctx context.Context, host string) error {
if d.allowedHosts[host] || d.matchesWildcard(host) {
allowList, err := d.service.ReadParsed(portainer.AllowListSSRF)
if err != nil {
return fmt.Errorf("ssrf: reading allow list: %w", err)
}
if allowList.Mode == portainer.SSRFModeOff {
return nil
}
if allowList.Hosts[host] || matchesWildcard(host, allowList.Wilds) {
return nil
}
if ip := net.ParseIP(host); ip != nil {
if !d.ipAllowed(ip) {
if d.mode == ModeAudit {
if !ipAllowed(ip, allowList.Nets) {
if allowList.Mode == portainer.SSRFModeAudit {
log.Warn().Str("host", host).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
return nil
}
@@ -218,8 +222,8 @@ func (d *safeDialer) checkHost(ctx context.Context, host string) error {
}
for _, a := range resolved {
if !d.ipAllowed(a.IP) {
if d.mode == ModeAudit {
if !ipAllowed(a.IP, allowList.Nets) {
if allowList.Mode == portainer.SSRFModeAudit {
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
continue
}
@@ -231,8 +235,8 @@ func (d *safeDialer) checkHost(ctx context.Context, host string) error {
return nil
}
func (d *safeDialer) matchesWildcard(host string) bool {
for _, suffix := range d.allowedWilds {
func matchesWildcard(host string, wilds []string) bool {
for _, suffix := range wilds {
if strings.HasSuffix(host, suffix) {
return true
}
@@ -241,8 +245,8 @@ func (d *safeDialer) matchesWildcard(host string) bool {
return false
}
func (d *safeDialer) ipAllowed(ip net.IP) bool {
for _, network := range d.allowedNets {
func ipAllowed(ip net.IP, nets []*net.IPNet) bool {
for _, network := range nets {
if network.Contains(ip) {
return true
}
+232 -184
View File
@@ -7,74 +7,113 @@ import (
"net/http/httptest"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/require"
)
func TestIpAllowed_CIDR(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"8.8.0.0/16", "2001:4860::/32"},
})
require.True(t, d.ipAllowed(net.ParseIP("8.8.8.8")))
require.True(t, d.ipAllowed(net.ParseIP("8.8.4.4")))
require.True(t, d.ipAllowed(net.ParseIP("2001:4860:4860::8888")))
require.False(t, d.ipAllowed(net.ParseIP("1.1.1.1")))
require.False(t, d.ipAllowed(net.ParseIP("127.0.0.1")))
require.False(t, d.ipAllowed(net.ParseIP("169.254.169.254")))
// staticService is a simple in-memory AllowListService for testing.
type staticService struct {
parsed portainer.ParsedAllowList
}
func TestIpAllowed_SingleIP(t *testing.T) {
func (s *staticService) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
return &s.parsed, nil
}
func newStaticService(mode portainer.SSRFMode, entries []string) *staticService {
parsed := ParseAllowedHosts(entries)
parsed.Mode = mode
return &staticService{parsed: parsed}
}
func TestParseAllowedHosts_ipAllowed(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"1.2.3.4"},
})
testCases := []struct {
name string
hostEntries []string
allowed []string
denied []string
}{
{
name: "CIDR",
hostEntries: []string{"8.8.0.0/16", "2001:4860::/32"},
allowed: []string{"8.8.8.8", "8.8.4.4", "2001:4860:4860::8888"},
denied: []string{"1.1.1.1", "127.0.0.1", "169.254.169.254"},
},
{
name: "Single IP",
hostEntries: []string{"1.2.3.4"},
allowed: []string{"1.2.3.4"},
denied: []string{"1.2.3.5"},
},
{
name: "Single IPv6",
hostEntries: []string{"::1"},
allowed: []string{"::1"},
denied: []string{"::2"},
},
}
require.True(t, d.ipAllowed(net.ParseIP("1.2.3.4")))
require.False(t, d.ipAllowed(net.ParseIP("1.2.3.5")))
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
parsed := ParseAllowedHosts(tc.hostEntries)
for _, a := range tc.allowed {
require.True(t, ipAllowed(net.ParseIP(a), parsed.Nets))
}
for _, d := range tc.denied {
require.False(t, ipAllowed(net.ParseIP(d), parsed.Nets))
}
})
}
}
func TestParseAllowedHosts_MixedEntries(t *testing.T) {
t.Parallel()
parsed := ParseAllowedHosts([]string{"example.com", "*.internal.net", "10.0.0.0/8", "1.2.3.4"})
require.True(t, parsed.Hosts["example.com"])
require.Contains(t, parsed.Wilds, ".internal.net")
require.Len(t, parsed.Nets, 2) // 10.0.0.0/8 and 1.2.3.4/32
}
func TestMatchesWildcard(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"*.example.com", "exact.host.com"},
})
parsed := ParseAllowedHosts([]string{"*.example.com", "exact.host.com"})
require.True(t, d.matchesWildcard("foo.example.com"))
require.True(t, d.matchesWildcard("bar.example.com"))
require.True(t, d.matchesWildcard("deep.nested.example.com"))
tests := []struct {
host string
want bool
}{
{"foo.example.com", true},
{"bar.example.com", true},
{"deep.nested.example.com", true},
{"example.com", false},
{"notexample.com", false},
{"exact.host.com", false},
}
require.False(t, d.matchesWildcard("example.com"))
require.False(t, d.matchesWildcard("notexample.com"))
require.False(t, d.matchesWildcard("exact.host.com"))
for _, tc := range tests {
got := matchesWildcard(tc.host, parsed.Wilds)
require.Equal(t, tc.want, got, "host: %s", tc.host)
}
}
func TestNewSafeDialer_MixedHosts(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"example.com", "*.internal.net", "10.0.0.0/8", "1.2.3.4"},
})
require.True(t, d.allowedHosts["example.com"])
require.Contains(t, d.allowedWilds, ".internal.net")
require.Len(t, d.allowedNets, 2) // 10.0.0.0/8 and 1.2.3.4/32
}
func TestConfigure_Disabled(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
func TestConfigure_SetsDialer(t *testing.T) {
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
require.NoError(t, err)
require.NotNil(t, globalDialer.Load())
t.Cleanup(func() { globalDialer.Store(nil) })
}
Configure(Policy{})
require.Nil(t, globalDialer.Load())
func TestConfigure_NilServicesReturnsError(t *testing.T) {
err := Configure(nil)
require.Error(t, err)
}
func TestWrapTransport_NoPolicy(t *testing.T) {
@@ -86,7 +125,8 @@ func TestWrapTransport_NoPolicy(t *testing.T) {
}
func TestWrapTransport_WithPolicy(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
base := &http.Transport{}
@@ -95,36 +135,144 @@ func TestWrapTransport_WithPolicy(t *testing.T) {
require.NotNil(t, result.DialContext)
}
func TestCheckURL_Disabled(t *testing.T) {
globalDialer.Store(nil)
func TestCheckURL(t *testing.T) {
tests := []struct {
name string
mode portainer.SSRFMode
entries []string
url string
wantErr bool
}{
{
name: "disabled",
mode: portainer.SSRFModeOff,
url: "http://169.254.169.254/latest/meta-data/",
wantErr: false,
},
{
name: "blocks IP not in allowlist",
mode: portainer.SSRFModeEnforce,
entries: []string{"8.8.8.0/24"},
url: "http://169.254.169.254/latest/meta-data/",
wantErr: true,
},
{
name: "allowed exact hostname",
mode: portainer.SSRFModeEnforce,
entries: []string{"example.com"},
url: "https://example.com/path",
wantErr: false,
},
{
name: "audit mode allows blocked IP",
mode: portainer.SSRFModeAudit,
entries: []string{"8.8.8.0/24"},
url: "http://169.254.169.254/latest/meta-data/",
wantErr: false,
},
{
name: "IP in CIDR allowlist",
mode: portainer.SSRFModeEnforce,
entries: []string{"8.8.8.0/24"},
url: "http://8.8.8.8/path",
wantErr: false,
},
{
name: "wildcard hostname",
mode: portainer.SSRFModeEnforce,
entries: []string{"*.example.com"},
url: "https://api.example.com/path",
wantErr: false,
},
{
name: "hostname DNS resolves to allowed IP",
mode: portainer.SSRFModeEnforce,
entries: []string{"127.0.0.0/8", "::1/128"},
url: "http://localhost/path",
wantErr: false,
},
{
name: "hostname DNS resolves to blocked IP",
mode: portainer.SSRFModeEnforce,
entries: []string{"8.8.8.0/24"},
url: "http://localhost/path",
wantErr: true,
},
{
name: "audit mode allows hostname resolving to blocked IP",
mode: portainer.SSRFModeAudit,
entries: []string{"8.8.8.0/24"},
url: "http://localhost/path",
wantErr: false,
},
{
name: "invalid URL",
mode: portainer.SSRFModeEnforce,
entries: []string{"example.com"},
url: "http://%gg",
wantErr: true,
},
{
name: "empty host",
mode: portainer.SSRFModeEnforce,
entries: []string{"example.com"},
url: "http://",
wantErr: false,
},
}
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
require.NoError(t, err)
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := Configure(newStaticService(tc.mode, tc.entries))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
err = CheckURL(t.Context(), tc.url)
if tc.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
} else {
require.NoError(t, err)
}
})
}
}
func TestCheckURL_BlocksIPNotInAllowlist(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
// TestCheckURL_HostnameDNSError verifies that a DNS resolution failure is
// propagated as an SSRF-prefixed error. Kept separate because it needs a
// cancelled context rather than t.Context().
func TestCheckURL_HostnameDNSError(t *testing.T) {
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"8.8.8.0/24"}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
ctx, cancel := context.WithCancel(t.Context())
cancel()
err = CheckURL(ctx, "http://portainer-nonexistent.test.invalid/path")
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
}
func TestCheckURL_AllowedHostname(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
func TestIsEnabled(t *testing.T) {
globalDialer.Store(nil)
require.False(t, IsEnabled())
err := CheckURL(t.Context(), "https://example.com/path")
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
require.True(t, IsEnabled())
err = Configure(newStaticService(portainer.SSRFModeOff, nil))
require.NoError(t, err)
require.False(t, IsEnabled())
}
func TestCheckURL_AuditMode_ReturnsNil(t *testing.T) {
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
func TestWrapTransportInternal(t *testing.T) {
t.Parallel()
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
require.NoError(t, err)
base := &http.Transport{}
result := WrapTransportInternal(base)
require.Equal(t, base, result)
}
// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP
@@ -136,7 +284,8 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
}))
defer srv.Close()
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.8"}})
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"8.8.8.8"}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
blocked := &http.Client{Transport: WrapTransport(&http.Transport{})}
@@ -147,7 +296,9 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
require.NoError(t, resp.Body.Close())
}
Configure(Policy{})
// Switch to off mode — dialer stays configured but checks are skipped.
err = Configure(newStaticService(portainer.SSRFModeOff, nil))
require.NoError(t, err)
open := &http.Client{Transport: WrapTransport(&http.Transport{})}
resp, err = open.Get(srv.URL)
@@ -163,7 +314,8 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
}))
defer srv.Close()
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.8"}})
err := Configure(newStaticService(portainer.SSRFModeAudit, []string{"8.8.8.8"}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
@@ -172,122 +324,15 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
require.NoError(t, resp.Body.Close())
}
func TestIsEnabled(t *testing.T) {
globalDialer.Store(nil)
require.False(t, IsEnabled())
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
require.True(t, IsEnabled())
}
func TestWrapTransportInternal(t *testing.T) {
t.Parallel()
base := &http.Transport{}
result := WrapTransportInternal(base)
require.Equal(t, base, result)
}
func TestNewSafeDialer_IPv6SingleIP(t *testing.T) {
t.Parallel()
d := newSafeDialer(Policy{
Mode: ModeEnforce,
AllowedHosts: []string{"::1"},
})
require.True(t, d.ipAllowed(net.ParseIP("::1")))
require.False(t, d.ipAllowed(net.ParseIP("::2")))
}
func TestCheckURL_InvalidURL(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://%gg")
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
}
func TestCheckURL_EmptyHost(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://")
require.NoError(t, err)
}
// TestCheckURL_IPInAllowlist verifies that a literal IP address that falls
// within an allowed CIDR range is permitted.
func TestCheckURL_IPInAllowlist(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://8.8.8.8/path")
require.NoError(t, err)
}
func TestCheckURL_WildcardHostname(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"*.example.com"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "https://api.example.com/path")
require.NoError(t, err)
}
// TestCheckURL_HostnameDNSResolvesToAllowedIP verifies that a hostname
// resolving to an IP within the allowlist is permitted (DNS resolution path).
func TestCheckURL_HostnameDNSResolvesToAllowedIP(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8", "::1/128"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://localhost/path")
require.NoError(t, err)
}
// TestCheckURL_HostnameDNSResolvesToBlockedIP verifies that a hostname
// resolving to an IP outside the allowlist is blocked (DNS resolution path).
func TestCheckURL_HostnameDNSResolvesToBlockedIP(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://localhost/path")
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
}
// TestCheckURL_HostnameDNSAuditMode verifies that audit mode logs violations
// from hostname DNS resolution but still returns nil.
func TestCheckURL_HostnameDNSAuditMode(t *testing.T) {
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
err := CheckURL(t.Context(), "http://localhost/path")
require.NoError(t, err)
}
// TestCheckURL_HostnameDNSError verifies that a DNS resolution failure is
// propagated as an SSRF-prefixed error.
func TestCheckURL_HostnameDNSError(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
t.Cleanup(func() { globalDialer.Store(nil) })
ctx, cancel := context.WithCancel(t.Context())
cancel()
err := CheckURL(ctx, "http://portainer-nonexistent.test.invalid/path")
require.Error(t, err)
}
// TestDialContext_InvalidAddress verifies that an address without a port
// returns an SSRF-prefixed error.
func TestDialContext_InvalidAddress(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
d := globalDialer.Load()
_, err := d.DialContext(t.Context(), "tcp", "no-port-here")
_, err = d.DialContext(t.Context(), "tcp", "no-port-here")
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
}
@@ -295,14 +340,15 @@ func TestDialContext_InvalidAddress(t *testing.T) {
// TestDialContext_DNSError verifies that a DNS resolution failure in
// DialContext is propagated as an SSRF-prefixed error.
func TestDialContext_DNSError(t *testing.T) {
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{}})
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
ctx, cancel := context.WithCancel(t.Context())
cancel()
d := globalDialer.Load()
_, err := d.DialContext(ctx, "tcp", "portainer-nonexistent.test.invalid:80")
_, err = d.DialContext(ctx, "tcp", "portainer-nonexistent.test.invalid:80")
require.Error(t, err)
}
@@ -314,7 +360,8 @@ func TestDialContext_AllowedByCIDR(t *testing.T) {
}))
defer srv.Close()
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8"}})
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"127.0.0.0/8"}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
@@ -347,7 +394,8 @@ func TestDialContext_AllowedByExactHostname(t *testing.T) {
_, portStr, err := net.SplitHostPort(l.Addr().String())
require.NoError(t, err)
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"localhost"}})
err = Configure(newStaticService(portainer.SSRFModeEnforce, []string{"localhost"}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})}