diff --git a/api/cli/cli.go b/api/cli/cli.go index 5e1148ee51..d7fa93cc63 100644 --- a/api/cli/cli.go +++ b/api/cli/cli.go @@ -56,8 +56,6 @@ func CLIFlags() *portainer.CLIFlags { TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(), CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(), CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(), - SSRFMode: kingpin.Flag("ssrf-mode", "SSRF protection mode: off (disabled), audit (log violations but allow), enforce (block violations)").Envar("PORTAINER_SSRF_MODE").Default("off").Enum("off", "audit", "enforce"), - SSRFAllowedHosts: kingpin.Flag("ssrf-allowed-hosts", "Allowlist of hostnames (with optional wildcards), IPs, or CIDRs permitted for outbound requests. When empty and mode is enforce, all outbound connections are blocked").Envar("PORTAINER_SSRF_ALLOWED_HOSTS").Strings(), NoSetupToken: kingpin.Flag("no-setup-token", "Disable the setup token requirement for admin initialization and restore on an uninitialized instance").Envar(portainer.NoSetupTokenEnvVar).Bool(), SetupToken: kingpin.Flag("setup-token", "Set a custom setup token for admin initialization and restore on an uninitialized instance (overrides auto-generation)").Envar(portainer.SetupTokenEnvVar).String(), } diff --git a/api/cmd/portainer/main.go b/api/cmd/portainer/main.go index 1d77814e45..8b92aee116 100644 --- a/api/cmd/portainer/main.go +++ b/api/cmd/portainer/main.go @@ -387,19 +387,6 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow // -ce can not ever be run in FIPS mode fips.InitFIPS(false) - ssrf.Configure(ssrf.Policy{ - Mode: ssrf.Mode(*flags.SSRFMode), - AllowedHosts: *flags.SSRFAllowedHosts, - }) - - if ssrf.IsEnabled() { - if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok { - nethttp.DefaultTransport = ssrf.WrapTransport(dt) - } - - gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport}) - } - fileService := initFileService(*flags.Data) encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName)) if encryptionKey == nil { @@ -417,6 +404,16 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow log.Fatal().Msg("The database schema version does not align with the server version. Please consider reverting to the previous server version or addressing the database migration issue.") } + if err := ssrf.Configure(dataStore.AllowList()); err != nil { + log.Fatal().Err(err).Msg("failed initializing ssrf service") + } + + if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok { + nethttp.DefaultTransport = ssrf.WrapTransport(dt) + } + + gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport}) + instanceID, err := dataStore.Version().InstanceID() if err != nil { log.Fatal().Err(err).Msg("failed getting instance id") diff --git a/api/dataservices/allowlist/allowlist.go b/api/dataservices/allowlist/allowlist.go new file mode 100644 index 0000000000..adaa8856be --- /dev/null +++ b/api/dataservices/allowlist/allowlist.go @@ -0,0 +1,131 @@ +package allowlist + +import ( + "fmt" + + lru "github.com/hashicorp/golang-lru" + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/pkg/libhttp/ssrf" +) + +const ( + BucketName = "allowlist" +) + +type Service struct { + baseService dataservices.BaseDataService[portainer.AllowList, portainer.AllowListKey] + cache *lru.Cache +} + +func (service *Service) BucketName() string { + return service.baseService.BucketName() +} + +func NewService(connection portainer.Connection) (*Service, error) { + err := connection.SetServiceName(BucketName) + if err != nil { + return nil, err + } + + service := &Service{ + baseService: dataservices.BaseDataService[portainer.AllowList, portainer.AllowListKey]{ + Bucket: BucketName, + Connection: connection, + }, + } + + err = service.populateCache() + + return service, err +} + +func (service *Service) populateCache() error { + allowListKeys := []portainer.AllowListKey{portainer.AllowListSSRF} + cache, err := lru.New(len(allowListKeys)) + if err != nil { + return err + } + + for _, k := range allowListKeys { + allowList, err := service.baseService.Read(k) + if dataservices.IsErrObjectNotFound(err) { + allowList = &portainer.AllowList{ + ID: k, + Mode: portainer.SSRFModeOff, + Entries: []string{}, + } + } else if err != nil { + return err + } + + parsedAllowList := ssrf.ParseAllowedHosts(allowList.Entries) + parsedAllowList.Mode = allowList.Mode + + cache.Add(k, &parsedAllowList) + } + + service.cache = cache + + return nil +} + +func (service *Service) Tx(tx portainer.Transaction) *ServiceTx { + return &ServiceTx{ + baseService: service.baseService.Tx(tx), + cache: service.cache, + } +} + +func (service *Service) Read(id portainer.AllowListKey) (*portainer.AllowList, error) { + var result *portainer.AllowList + if err := service.baseService.Connection.ViewTx(func(tx portainer.Transaction) error { + var err error + result, err = service.Tx(tx).Read(id) + return err + }); err != nil { + return nil, err + } + + return result, nil +} + +func (service *Service) ReadAll() ([]portainer.AllowList, error) { + var result []portainer.AllowList + if err := service.baseService.Connection.ViewTx(func(tx portainer.Transaction) error { + var err error + result, err = service.Tx(tx).ReadAll() + return err + }); err != nil { + return nil, err + } + + return result, nil +} + +func (service *Service) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) { + allowListAny, ok := service.cache.Get(id) + if ok { + allowList, ok := allowListAny.(*portainer.ParsedAllowList) + if !ok { + return nil, fmt.Errorf("expected ParsedAllowList in cache but got %T", allowListAny) + } + + return allowList, nil + } + + var result *portainer.ParsedAllowList + err := service.baseService.Connection.ViewTx(func(tx portainer.Transaction) error { + var err error + result, err = service.Tx(tx).ReadParsed(id) + return err + }) + + return result, err +} + +func (service *Service) Update(id portainer.AllowListKey, allowList *portainer.AllowList) error { + return service.baseService.Connection.UpdateTx(func(tx portainer.Transaction) error { + return service.Tx(tx).Update(id, allowList) + }) +} diff --git a/api/dataservices/allowlist/allowlist_test.go b/api/dataservices/allowlist/allowlist_test.go new file mode 100644 index 0000000000..480f369784 --- /dev/null +++ b/api/dataservices/allowlist/allowlist_test.go @@ -0,0 +1,89 @@ +package allowlist_test + +import ( + "net" + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/datastore" + "github.com/stretchr/testify/require" +) + +func TestAllowListReadEmpty(t *testing.T) { + t.Parallel() + _, ds := datastore.MustNewTestStore(t, false, false) + got, err := ds.AllowList().Read(portainer.AllowListSSRF) + expected := &portainer.AllowList{ + ID: portainer.AllowListSSRF, + Mode: portainer.SSRFModeOff, + Entries: []string{}, + } + require.NoError(t, err) + require.Equal(t, expected, got) +} + +func TestAllowListUpdate(t *testing.T) { + t.Parallel() + _, ds := datastore.MustNewTestStore(t, false, false) + + expected := &portainer.AllowList{ + ID: portainer.AllowListSSRF, + Mode: portainer.SSRFModeEnforce, + Entries: []string{"example.com", "10.0.0.0/8"}, + } + + require.NoError(t, ds.AllowList().Update(portainer.AllowListSSRF, expected)) + + got, err := ds.AllowList().Read(portainer.AllowListSSRF) + require.NoError(t, err) + require.Equal(t, expected, got) +} + +func TestAllowListReadAllEmpty(t *testing.T) { + t.Parallel() + _, ds := datastore.MustNewTestStore(t, false, false) + + got, err := ds.AllowList().ReadAll() + require.NoError(t, err) + require.Equal(t, []portainer.AllowList{}, got) +} + +func TestAllowListReadAllAfterUpdate(t *testing.T) { + t.Parallel() + _, ds := datastore.MustNewTestStore(t, false, false) + + expected := portainer.AllowList{ + ID: portainer.AllowListSSRF, + Mode: portainer.SSRFModeEnforce, + Entries: []string{"example.com", "10.0.0.0/8"}, + } + + require.NoError(t, ds.AllowList().Update(portainer.AllowListSSRF, &expected)) + + got, err := ds.AllowList().ReadAll() + require.NoError(t, err) + require.Equal(t, []portainer.AllowList{expected}, got) +} + +func TestAllowListReadParsedAfterUpdate(t *testing.T) { + t.Parallel() + _, ds := datastore.MustNewTestStore(t, false, false) + + require.NoError(t, ds.AllowList().Update(portainer.AllowListSSRF, &portainer.AllowList{ + ID: portainer.AllowListSSRF, + Mode: portainer.SSRFModeEnforce, + Entries: []string{"example.com"}, + })) + + expected := &portainer.ParsedAllowList{ + Mode: portainer.SSRFModeEnforce, + Nets: []*net.IPNet{}, + Hosts: map[string]bool{ + "example.com": true, + }, + } + + got, err := ds.AllowList().ReadParsed(portainer.AllowListSSRF) + require.NoError(t, err) + require.Equal(t, expected, got) +} diff --git a/api/dataservices/allowlist/tx.go b/api/dataservices/allowlist/tx.go new file mode 100644 index 0000000000..dcf49a593b --- /dev/null +++ b/api/dataservices/allowlist/tx.go @@ -0,0 +1,77 @@ +package allowlist + +import ( + "fmt" + + lru "github.com/hashicorp/golang-lru" + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/pkg/libhttp/ssrf" +) + +type ServiceTx struct { + baseService dataservices.BaseDataServiceTx[portainer.AllowList, portainer.AllowListKey] + cache *lru.Cache +} + +func (service *ServiceTx) BucketName() string { + return service.baseService.BucketName() +} + +func (service *ServiceTx) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) { + allowListAny, ok := service.cache.Get(id) + if ok { + allowList, ok := allowListAny.(*portainer.ParsedAllowList) + if !ok { + return nil, fmt.Errorf("expected ParsedAllowList in cache but got %T", allowListAny) + } + + return allowList, nil + } + + allowList, err := service.Read(id) + if err != nil { + return nil, err + } + + parsed := ssrf.ParseAllowedHosts(allowList.Entries) + parsed.Mode = allowList.Mode + service.cache.Add(id, &parsed) + + return &parsed, nil +} + +func (service *ServiceTx) Read(id portainer.AllowListKey) (*portainer.AllowList, error) { + allowList, err := service.baseService.Read(id) + if dataservices.IsErrObjectNotFound(err) { + allowList = &portainer.AllowList{ + ID: id, + Mode: portainer.SSRFModeOff, + Entries: []string{}, + } + } else if err != nil { + return nil, err + } + + return allowList, nil +} + +func (service *ServiceTx) ReadAll() ([]portainer.AllowList, error) { + allowLists, err := service.baseService.ReadAll() + if err != nil && !dataservices.IsErrObjectNotFound(err) { + return nil, err + } + + return allowLists, nil +} + +func (service *ServiceTx) Update(id portainer.AllowListKey, allowList *portainer.AllowList) error { + if err := service.baseService.Update(id, allowList); err != nil { + return err + } + + parsed := ssrf.ParseAllowedHosts(allowList.Entries) + parsed.Mode = allowList.Mode + service.cache.Add(id, &parsed) + return nil +} diff --git a/api/dataservices/allowlist/tx_test.go b/api/dataservices/allowlist/tx_test.go new file mode 100644 index 0000000000..3b003459a8 --- /dev/null +++ b/api/dataservices/allowlist/tx_test.go @@ -0,0 +1,92 @@ +package allowlist_test + +import ( + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/api/datastore" + "github.com/stretchr/testify/require" +) + +func TestAllowListReadTx(t *testing.T) { + t.Parallel() + _, ds := datastore.MustNewTestStore(t, false, false) + + var got *portainer.AllowList + require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error { + var err error + got, err = tx.AllowList().Read(portainer.AllowListSSRF) + return err + })) + + expected := &portainer.AllowList{ + ID: portainer.AllowListSSRF, + Mode: portainer.SSRFModeOff, + Entries: []string{}, + } + + require.Equal(t, expected, got) +} + +func TestAllowListReadAllEmptyTx(t *testing.T) { + t.Parallel() + _, ds := datastore.MustNewTestStore(t, false, false) + + var got []portainer.AllowList + require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error { + var err error + got, err = tx.AllowList().ReadAll() + return err + })) + + require.Equal(t, []portainer.AllowList{}, got) +} + +func TestAllowListReadAllAfterUpdateTx(t *testing.T) { + t.Parallel() + _, ds := datastore.MustNewTestStore(t, false, false) + + expected := portainer.AllowList{ + ID: portainer.AllowListSSRF, + Mode: portainer.SSRFModeEnforce, + Entries: []string{"example.com"}, + } + + require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error { + return tx.AllowList().Update(portainer.AllowListSSRF, &expected) + })) + + var got []portainer.AllowList + require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error { + var err error + got, err = tx.AllowList().ReadAll() + return err + })) + + require.Equal(t, []portainer.AllowList{expected}, got) +} + +func TestAllowListUpdateTx(t *testing.T) { + t.Parallel() + _, ds := datastore.MustNewTestStore(t, false, false) + + expected := &portainer.AllowList{ + ID: portainer.AllowListSSRF, + Mode: portainer.SSRFModeEnforce, + Entries: []string{"example.com"}, + } + + require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error { + return tx.AllowList().Update(portainer.AllowListSSRF, expected) + })) + + var got *portainer.AllowList + require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error { + var err error + got, err = tx.AllowList().Read(portainer.AllowListSSRF) + return err + })) + + require.Equal(t, expected, got) +} diff --git a/api/dataservices/interface.go b/api/dataservices/interface.go index 4819f667f7..6b883496e3 100644 --- a/api/dataservices/interface.go +++ b/api/dataservices/interface.go @@ -8,6 +8,7 @@ import ( type ( DataStoreTx interface { IsErrObjectNotFound(err error) bool + AllowList() AllowListService CustomTemplate() CustomTemplateService EdgeGroup() EdgeGroupService EdgeJob() EdgeJobService @@ -53,6 +54,15 @@ type ( DataStoreTx } + // AllowListService represents a service for managing the URL allow list + AllowListService interface { + Read(id portainer.AllowListKey) (*portainer.AllowList, error) + ReadAll() ([]portainer.AllowList, error) + ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) + Update(id portainer.AllowListKey, allowList *portainer.AllowList) error + BucketName() string + } + // CustomTemplateService represents a service to manage custom templates CustomTemplateService interface { BaseCRUD[portainer.CustomTemplate, portainer.CustomTemplateID] diff --git a/api/datastore/services.go b/api/datastore/services.go index 163cf92f47..84a5749231 100644 --- a/api/datastore/services.go +++ b/api/datastore/services.go @@ -7,6 +7,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/database/models" "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/api/dataservices/allowlist" "github.com/portainer/portainer/api/dataservices/apikeyrepository" "github.com/portainer/portainer/api/dataservices/customtemplate" "github.com/portainer/portainer/api/dataservices/dockerhub" @@ -51,6 +52,7 @@ type Store struct { connection portainer.Connection fileService portainer.FileService + AllowListService *allowlist.Service CustomTemplateService *customtemplate.Service DockerHubService *dockerhub.Service EdgeGroupService *edgegroup.Service @@ -84,6 +86,12 @@ type Store struct { } func (store *Store) initServices() error { + allowListService, err := allowlist.NewService(store.connection) + if err != nil { + return err + } + store.AllowListService = allowListService + authorizationsetService, err := role.NewService(store.connection) if err != nil { return err @@ -275,6 +283,11 @@ func (store *Store) PendingActions() dataservices.PendingActionsService { return store.PendingActionsService } +// AllowList gives access to the AllowList data management layer +func (store *Store) AllowList() dataservices.AllowListService { + return store.AllowListService +} + // CustomTemplate gives access to the CustomTemplate data management layer func (store *Store) CustomTemplate() dataservices.CustomTemplateService { return store.CustomTemplateService @@ -654,7 +667,7 @@ func (store *Store) Export(filename string) (err error) { return err } - return os.WriteFile(filename, b, 0600) + return os.WriteFile(filename, b, 0o600) } func (store *Store) Import(filename string) (err error) { diff --git a/api/datastore/services_tx.go b/api/datastore/services_tx.go index 31fcd53706..bfb9dc7f4e 100644 --- a/api/datastore/services_tx.go +++ b/api/datastore/services_tx.go @@ -14,6 +14,10 @@ func (tx *StoreTx) IsErrObjectNotFound(err error) bool { return tx.store.IsErrObjectNotFound(err) } +func (tx *StoreTx) AllowList() dataservices.AllowListService { + return tx.store.AllowListService.Tx(tx.tx) +} + func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService { return tx.store.CustomTemplateService.Tx(tx.tx) } diff --git a/api/datastore/test_data/output_24_to_latest.json b/api/datastore/test_data/output_24_to_latest.json index 309a41095f..16d6c3c06d 100644 --- a/api/datastore/test_data/output_24_to_latest.json +++ b/api/datastore/test_data/output_24_to_latest.json @@ -1,4 +1,5 @@ { + "allowlist": null, "api_key": null, "customtemplates": null, "dockerhub": [ diff --git a/api/docker/client/client.go b/api/docker/client/client.go index 6ef575b755..5609b290d9 100644 --- a/api/docker/client/client.go +++ b/api/docker/client/client.go @@ -90,7 +90,7 @@ func createTCPClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*cli client.WithHTTPClient(httpCli), } - if nnTransport, ok := httpCli.Transport.(*NodeNameTransport); ok && nnTransport.TLSClientConfig != nil { + if endpoint.TLSConfig.TLS { opts = append(opts, client.WithScheme("https")) } @@ -124,7 +124,7 @@ func createAgentClient(endpoint *portainer.Endpoint, endpointURL string, signatu client.WithHTTPHeaders(headers), } - if nnTransport, ok := httpCli.Transport.(*NodeNameTransport); ok && nnTransport.TLSClientConfig != nil { + if endpoint.TLSConfig.TLS { opts = append(opts, client.WithScheme("https")) } diff --git a/api/http/proxy/factory/transport_test.go b/api/http/proxy/factory/transport_test.go index d6c623a666..44e5bc6c29 100644 --- a/api/http/proxy/factory/transport_test.go +++ b/api/http/proxy/factory/transport_test.go @@ -45,10 +45,24 @@ func (s *stubTunnelService) UpdateLastActivity(endpointID portainer.EndpointID) func (s *stubTunnelService) KeepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxKeepAlive time.Duration) { } +type staticAllowListService struct { + parsed portainer.ParsedAllowList +} + +func (s *staticAllowListService) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) { + return &s.parsed, nil +} + func enableSSRF(t *testing.T) { t.Helper() - ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) - t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + parsed := ssrf.ParseAllowedHosts([]string{"example.com"}) + parsed.Mode = portainer.SSRFModeEnforce + err := ssrf.Configure(&staticAllowListService{parsed: parsed}) + require.NoError(t, err) + t.Cleanup(func() { + err := ssrf.Configure(&staticAllowListService{}) + require.NoError(t, err) + }) } // TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint diff --git a/api/internal/testhelpers/datastore.go b/api/internal/testhelpers/datastore.go index 3a28baf307..80b7e19cba 100644 --- a/api/internal/testhelpers/datastore.go +++ b/api/internal/testhelpers/datastore.go @@ -13,6 +13,7 @@ import ( var _ dataservices.DataStore = &testDatastore{} type testDatastore struct { + allowList dataservices.AllowListService customTemplate dataservices.CustomTemplateService edgeGroup dataservices.EdgeGroupService edgeJob dataservices.EdgeJobService @@ -53,6 +54,7 @@ func (d *testDatastore) ViewTx(func(dataservices.DataStoreTx) error) error { r func (d *testDatastore) CheckCurrentEdition() error { return nil } func (d *testDatastore) MigrateData() error { return nil } func (d *testDatastore) Rollback(force bool) error { return nil } +func (d *testDatastore) AllowList() dataservices.AllowListService { return d.allowList } func (d *testDatastore) CustomTemplate() dataservices.CustomTemplateService { return d.customTemplate } func (d *testDatastore) EdgeGroup() dataservices.EdgeGroupService { return d.edgeGroup } func (d *testDatastore) EdgeJob() dataservices.EdgeJobService { return d.edgeJob } diff --git a/api/portainer.go b/api/portainer.go index acd16edd7e..0a4aa8977b 100644 --- a/api/portainer.go +++ b/api/portainer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net" "net/http" "time" @@ -111,8 +112,6 @@ type ( KubectlShellImageSet bool PullLimitCheckDisabled *bool TrustedOrigins *string - SSRFMode *string - SSRFAllowedHosts *[]string NoSetupToken *bool SetupToken *string } @@ -1215,6 +1214,21 @@ type ( // SoftwareEdition represents an edition of Portainer SoftwareEdition int + // AllowList holds the list of permitted outbound proxy destinations. + AllowList struct { + ID AllowListKey `json:"Id"` + Mode SSRFMode `json:"Mode"` + Entries []string `json:"Entries"` + } + + // ParsedAllowList holds the three parsed forms of allow list entries. + ParsedAllowList struct { + Mode SSRFMode + Nets []*net.IPNet + Hosts map[string]bool + Wilds []string // stored as ".foo.com" ("*." prefix stripped) + } + // SSLSettings represents a pair of SSL certificate and key SSLSettings struct { CertPath string `json:"certPath"` @@ -2670,3 +2684,17 @@ func DefaultEndpointSecuritySettings() EndpointSecuritySettings { AllowStackManagementForRegularUsers: true, } } + +type AllowListKey int + +const ( + AllowListSSRF AllowListKey = iota +) + +type SSRFMode int + +const ( + SSRFModeOff SSRFMode = iota + SSRFModeAudit + SSRFModeEnforce +) diff --git a/api/stacks/stackutils/validation_test.go b/api/stacks/stackutils/validation_test.go index 893f3925b4..9f020738c2 100644 --- a/api/stacks/stackutils/validation_test.go +++ b/api/stacks/stackutils/validation_test.go @@ -139,7 +139,6 @@ services: `) stack := &portainer.Stack{ - ProjectPath: "/tmp/stack/1", EntryPoint: "docker-compose.yml", Env: []portainer.Pair{{Name: "API_PORT", Value: "3000"}}, @@ -186,7 +185,7 @@ func TestValidateStackFiles_DotEnvFile(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - err := os.WriteFile(filesystem.JoinPaths(tmpDir, ".env"), []byte("HOST_PORT=3000\n"), 0600) + err := os.WriteFile(filesystem.JoinPaths(tmpDir, ".env"), []byte("HOST_PORT=3000\n"), 0o600) require.NoError(t, err) fileContent := []byte(` @@ -217,7 +216,7 @@ func TestValidateStackFiles_EnvFileAttribute(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - err := os.WriteFile(filesystem.JoinPaths(tmpDir, "web.env"), []byte("HOST_PORT=3000\n"), 0600) + err := os.WriteFile(filesystem.JoinPaths(tmpDir, "web.env"), []byte("HOST_PORT=3000\n"), 0o600) require.NoError(t, err) fileContent := []byte(` @@ -298,8 +297,29 @@ func TestExtractImageRegistry(t *testing.T) { } } +type staticAllowListService struct { + parsed portainer.ParsedAllowList +} + +func (s *staticAllowListService) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) { + return &s.parsed, nil +} + +func configureSSRF(t *testing.T, mode portainer.SSRFMode, entries []string) { + t.Helper() + + parsed := ssrf.ParseAllowedHosts(entries) + parsed.Mode = mode + err := ssrf.Configure(&staticAllowListService{parsed: parsed}) + require.NoError(t, err) + t.Cleanup(func() { + err := ssrf.Configure(&staticAllowListService{}) + require.NoError(t, err) + }) +} + func TestValidateComposeURLs_DisabledSSRF(t *testing.T) { - ssrf.Configure(ssrf.Policy{}) + configureSSRF(t, portainer.SSRFModeOff, nil) stack := &portainer.Stack{ ProjectPath: "/tmp/stack/1", @@ -322,8 +342,7 @@ services: } func TestValidateComposeURLs_BuildContextBlocked(t *testing.T) { - ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) - t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"}) stack := &portainer.Stack{ ProjectPath: "/tmp/stack/1", @@ -347,8 +366,7 @@ services: } func TestValidateComposeURLs_BuildContextPath(t *testing.T) { - ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) - t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"}) stack := &portainer.Stack{ ProjectPath: "/tmp/stack/1", @@ -372,8 +390,7 @@ services: } func TestValidateComposeURLs_ImageRegistryBlocked(t *testing.T) { - ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) - t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"}) stack := &portainer.Stack{ ProjectPath: "/tmp/stack/1", @@ -395,8 +412,7 @@ services: } func TestValidateComposeURLs_ImageRegistryAllowed(t *testing.T) { - ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"myregistry.com"}}) - t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + configureSSRF(t, portainer.SSRFModeEnforce, []string{"myregistry.com"}) stack := &portainer.Stack{ ProjectPath: "/tmp/stack/1", @@ -418,8 +434,7 @@ services: } func TestValidateComposeURLs_DockerHubImageAllowed(t *testing.T) { - ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}}) - t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) }) + configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"}) stack := &portainer.Stack{ ProjectPath: "/tmp/stack/1", diff --git a/pkg/libhttp/ssrf/ssrf.go b/pkg/libhttp/ssrf/ssrf.go index 15a3ca1e18..1a35f95700 100644 --- a/pkg/libhttp/ssrf/ssrf.go +++ b/pkg/libhttp/ssrf/ssrf.go @@ -2,6 +2,7 @@ package ssrf import ( "context" + "errors" "fmt" "net" "net/http" @@ -9,61 +10,84 @@ import ( "strings" "sync/atomic" + portainer "github.com/portainer/portainer/api" "github.com/rs/zerolog/log" ) -// Mode controls how the SSRF policy is applied. -type Mode string +// ParseAllowedHosts parses raw allow list entries into their three canonical +// forms. Accepted formats: exact hostname, wildcard hostname (*.example.com), +// single IP, or CIDR range. +func ParseAllowedHosts(entries []string) portainer.ParsedAllowList { + nets := make([]*net.IPNet, 0, len(entries)) + hosts := make(map[string]bool, len(entries)) + var wilds []string -const ( - // ModeOff disables SSRF protection entirely. All connections pass through unchanged. - ModeOff Mode = "off" - // ModeAudit resolves and checks destinations but only logs violations; connections are allowed. - ModeAudit Mode = "audit" - // ModeEnforce blocks connections that violate the policy. - ModeEnforce Mode = "enforce" -) + for _, entry := range entries { + if _, network, err := net.ParseCIDR(entry); err == nil { + nets = append(nets, network) + continue + } -// Policy defines the SSRF protection policy for outbound HTTP connections. -type Policy struct { - // Mode controls whether protection is off, in audit-only mode, or enforcing. - Mode Mode + if ip := net.ParseIP(entry); ip != nil { + bits := 32 + if ip.To4() == nil { + bits = 128 + } - // AllowedHosts is the allowlist of permitted destinations. - // Accepted formats: - // - Exact hostname: "example.com" - // - Wildcard hostname: "*.example.com" (matches any subdomain at any depth) - // - Single IP: "1.2.3.4" - // - CIDR range: "10.0.0.0/8" - // - // When Mode is ModeEnforce and AllowedHosts is empty, all outbound connections are blocked. - AllowedHosts []string + mask := net.CIDRMask(bits, bits) + nets = append(nets, &net.IPNet{IP: ip.Mask(mask), Mask: mask}) + + continue + } + + if strings.HasPrefix(entry, "*.") { + wilds = append(wilds, entry[1:]) // "*.foo.com" -> ".foo.com" + continue + } + + hosts[entry] = true + } + + return portainer.ParsedAllowList{Nets: nets, Hosts: hosts, Wilds: wilds} +} + +// AllowListService is implemented by the allowlist data service. +// ReadParsed is called on every dial to pick up runtime changes. +type AllowListService interface { + ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) } type safeDialer struct { - base net.Dialer - mode Mode - allowedNets []*net.IPNet - allowedHosts map[string]bool - allowedWilds []string // derived from "*.foo.com" entries; stored as ".foo.com" + base net.Dialer + service AllowListService } var globalDialer atomic.Pointer[safeDialer] -// Configure initializes the global SSRF policy. Intended to be called once -// at startup before any outbound HTTP connections are established. -func Configure(policy Policy) { - if policy.Mode == ModeOff || policy.Mode == "" { - globalDialer.Store(nil) - return +// Configure initializes the global SSRF policy with the allow list data service. +func Configure(svc AllowListService) error { + if svc == nil { + return errors.New("unable to configure ssrf: service must not be nil") } - globalDialer.Store(newSafeDialer(policy)) + globalDialer.Store(&safeDialer{service: svc}) + return nil } // IsEnabled reports whether SSRF protection is currently active (audit or enforce). func IsEnabled() bool { - return globalDialer.Load() != nil + d := globalDialer.Load() + if d == nil { + return false + } + + allowList, err := d.service.ReadParsed(portainer.AllowListSSRF) + if err != nil { + log.Err(err).Msg("unable to check SSRF protection mode") + return false + } + + return allowList.Mode != portainer.SSRFModeOff } // CheckURL validates rawURL against the active SSRF policy without making a @@ -89,7 +113,8 @@ func CheckURL(ctx context.Context, rawURL string) error { } // WrapTransport clones t and replaces its DialContext with the global SSRF-filtering -// dialer. Returns t unchanged when SSRF protection is not configured. +// dialer. The dialer checks the mode on every connection, so the transport is always +// wrapped and mode changes take effect without restarting. func WrapTransport(t *http.Transport) *http.Transport { d := globalDialer.Load() if d == nil { @@ -112,48 +137,18 @@ func WrapTransportInternal(t *http.Transport) *http.Transport { return t } -func newSafeDialer(policy Policy) *safeDialer { - allowedNets := make([]*net.IPNet, 0, len(policy.AllowedHosts)) - allowedHosts := make(map[string]bool, len(policy.AllowedHosts)) - var allowedWilds []string - - for _, entry := range policy.AllowedHosts { - if _, network, err := net.ParseCIDR(entry); err == nil { - allowedNets = append(allowedNets, network) - continue - } - - if ip := net.ParseIP(entry); ip != nil { - bits := 32 - if ip.To4() == nil { - bits = 128 - } - - mask := net.CIDRMask(bits, bits) - allowedNets = append(allowedNets, &net.IPNet{IP: ip.Mask(mask), Mask: mask}) - - continue - } - - if strings.HasPrefix(entry, "*.") { - allowedWilds = append(allowedWilds, entry[1:]) // "*.foo.com" -> ".foo.com" - continue - } - - allowedHosts[entry] = true - } - - return &safeDialer{ - mode: policy.Mode, - allowedNets: allowedNets, - allowedHosts: allowedHosts, - allowedWilds: allowedWilds, - } -} - // DialContext resolves addr, validates all resolved IPs against the allowlist policy, // then dials using the first resolved IP to prevent DNS rebinding attacks. func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + allowList, err := d.service.ReadParsed(portainer.AllowListSSRF) + if err != nil { + return nil, fmt.Errorf("ssrf: reading allow list: %w", err) + } + + if allowList.Mode == portainer.SSRFModeOff { + return d.base.DialContext(ctx, network, addr) + } + host, port, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err) @@ -172,13 +167,13 @@ func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net // window between DNS validation and the TCP handshake (DNS rebinding). dialTarget := net.JoinHostPort(resolved[0].IP.String(), port) - if d.allowedHosts[host] || d.matchesWildcard(host) { + if allowList.Hosts[host] || matchesWildcard(host, allowList.Wilds) { return d.base.DialContext(ctx, network, dialTarget) } for _, a := range resolved { - if !d.ipAllowed(a.IP) { - if d.mode == ModeAudit { + if !ipAllowed(a.IP, allowList.Nets) { + if allowList.Mode == portainer.SSRFModeAudit { log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)") continue } @@ -191,13 +186,22 @@ func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net } func (d *safeDialer) checkHost(ctx context.Context, host string) error { - if d.allowedHosts[host] || d.matchesWildcard(host) { + allowList, err := d.service.ReadParsed(portainer.AllowListSSRF) + if err != nil { + return fmt.Errorf("ssrf: reading allow list: %w", err) + } + + if allowList.Mode == portainer.SSRFModeOff { + return nil + } + + if allowList.Hosts[host] || matchesWildcard(host, allowList.Wilds) { return nil } if ip := net.ParseIP(host); ip != nil { - if !d.ipAllowed(ip) { - if d.mode == ModeAudit { + if !ipAllowed(ip, allowList.Nets) { + if allowList.Mode == portainer.SSRFModeAudit { log.Warn().Str("host", host).Msg("ssrf: destination not in allowlist (audit mode, allowing)") return nil } @@ -218,8 +222,8 @@ func (d *safeDialer) checkHost(ctx context.Context, host string) error { } for _, a := range resolved { - if !d.ipAllowed(a.IP) { - if d.mode == ModeAudit { + if !ipAllowed(a.IP, allowList.Nets) { + if allowList.Mode == portainer.SSRFModeAudit { log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)") continue } @@ -231,8 +235,8 @@ func (d *safeDialer) checkHost(ctx context.Context, host string) error { return nil } -func (d *safeDialer) matchesWildcard(host string) bool { - for _, suffix := range d.allowedWilds { +func matchesWildcard(host string, wilds []string) bool { + for _, suffix := range wilds { if strings.HasSuffix(host, suffix) { return true } @@ -241,8 +245,8 @@ func (d *safeDialer) matchesWildcard(host string) bool { return false } -func (d *safeDialer) ipAllowed(ip net.IP) bool { - for _, network := range d.allowedNets { +func ipAllowed(ip net.IP, nets []*net.IPNet) bool { + for _, network := range nets { if network.Contains(ip) { return true } diff --git a/pkg/libhttp/ssrf/ssrf_test.go b/pkg/libhttp/ssrf/ssrf_test.go index 329d632257..2439957b29 100644 --- a/pkg/libhttp/ssrf/ssrf_test.go +++ b/pkg/libhttp/ssrf/ssrf_test.go @@ -7,74 +7,113 @@ import ( "net/http/httptest" "testing" + portainer "github.com/portainer/portainer/api" "github.com/stretchr/testify/require" ) -func TestIpAllowed_CIDR(t *testing.T) { - t.Parallel() - - d := newSafeDialer(Policy{ - Mode: ModeEnforce, - AllowedHosts: []string{"8.8.0.0/16", "2001:4860::/32"}, - }) - - require.True(t, d.ipAllowed(net.ParseIP("8.8.8.8"))) - require.True(t, d.ipAllowed(net.ParseIP("8.8.4.4"))) - require.True(t, d.ipAllowed(net.ParseIP("2001:4860:4860::8888"))) - - require.False(t, d.ipAllowed(net.ParseIP("1.1.1.1"))) - require.False(t, d.ipAllowed(net.ParseIP("127.0.0.1"))) - require.False(t, d.ipAllowed(net.ParseIP("169.254.169.254"))) +// staticService is a simple in-memory AllowListService for testing. +type staticService struct { + parsed portainer.ParsedAllowList } -func TestIpAllowed_SingleIP(t *testing.T) { +func (s *staticService) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) { + return &s.parsed, nil +} + +func newStaticService(mode portainer.SSRFMode, entries []string) *staticService { + parsed := ParseAllowedHosts(entries) + parsed.Mode = mode + return &staticService{parsed: parsed} +} + +func TestParseAllowedHosts_ipAllowed(t *testing.T) { t.Parallel() - d := newSafeDialer(Policy{ - Mode: ModeEnforce, - AllowedHosts: []string{"1.2.3.4"}, - }) + testCases := []struct { + name string + hostEntries []string + allowed []string + denied []string + }{ + { + name: "CIDR", + hostEntries: []string{"8.8.0.0/16", "2001:4860::/32"}, + allowed: []string{"8.8.8.8", "8.8.4.4", "2001:4860:4860::8888"}, + denied: []string{"1.1.1.1", "127.0.0.1", "169.254.169.254"}, + }, + { + name: "Single IP", + hostEntries: []string{"1.2.3.4"}, + allowed: []string{"1.2.3.4"}, + denied: []string{"1.2.3.5"}, + }, + { + name: "Single IPv6", + hostEntries: []string{"::1"}, + allowed: []string{"::1"}, + denied: []string{"::2"}, + }, + } - require.True(t, d.ipAllowed(net.ParseIP("1.2.3.4"))) - require.False(t, d.ipAllowed(net.ParseIP("1.2.3.5"))) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + parsed := ParseAllowedHosts(tc.hostEntries) + for _, a := range tc.allowed { + require.True(t, ipAllowed(net.ParseIP(a), parsed.Nets)) + } + + for _, d := range tc.denied { + require.False(t, ipAllowed(net.ParseIP(d), parsed.Nets)) + } + }) + } +} + +func TestParseAllowedHosts_MixedEntries(t *testing.T) { + t.Parallel() + + parsed := ParseAllowedHosts([]string{"example.com", "*.internal.net", "10.0.0.0/8", "1.2.3.4"}) + + require.True(t, parsed.Hosts["example.com"]) + require.Contains(t, parsed.Wilds, ".internal.net") + require.Len(t, parsed.Nets, 2) // 10.0.0.0/8 and 1.2.3.4/32 } func TestMatchesWildcard(t *testing.T) { t.Parallel() - d := newSafeDialer(Policy{ - Mode: ModeEnforce, - AllowedHosts: []string{"*.example.com", "exact.host.com"}, - }) + parsed := ParseAllowedHosts([]string{"*.example.com", "exact.host.com"}) - require.True(t, d.matchesWildcard("foo.example.com")) - require.True(t, d.matchesWildcard("bar.example.com")) - require.True(t, d.matchesWildcard("deep.nested.example.com")) + tests := []struct { + host string + want bool + }{ + {"foo.example.com", true}, + {"bar.example.com", true}, + {"deep.nested.example.com", true}, + {"example.com", false}, + {"notexample.com", false}, + {"exact.host.com", false}, + } - require.False(t, d.matchesWildcard("example.com")) - require.False(t, d.matchesWildcard("notexample.com")) - require.False(t, d.matchesWildcard("exact.host.com")) + for _, tc := range tests { + got := matchesWildcard(tc.host, parsed.Wilds) + require.Equal(t, tc.want, got, "host: %s", tc.host) + } } -func TestNewSafeDialer_MixedHosts(t *testing.T) { - t.Parallel() - - d := newSafeDialer(Policy{ - Mode: ModeEnforce, - AllowedHosts: []string{"example.com", "*.internal.net", "10.0.0.0/8", "1.2.3.4"}, - }) - - require.True(t, d.allowedHosts["example.com"]) - require.Contains(t, d.allowedWilds, ".internal.net") - require.Len(t, d.allowedNets, 2) // 10.0.0.0/8 and 1.2.3.4/32 -} - -func TestConfigure_Disabled(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) +func TestConfigure_SetsDialer(t *testing.T) { + err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"})) + require.NoError(t, err) require.NotNil(t, globalDialer.Load()) + t.Cleanup(func() { globalDialer.Store(nil) }) +} - Configure(Policy{}) - require.Nil(t, globalDialer.Load()) +func TestConfigure_NilServicesReturnsError(t *testing.T) { + err := Configure(nil) + require.Error(t, err) } func TestWrapTransport_NoPolicy(t *testing.T) { @@ -86,7 +125,8 @@ func TestWrapTransport_NoPolicy(t *testing.T) { } func TestWrapTransport_WithPolicy(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) + err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"})) + require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) base := &http.Transport{} @@ -95,36 +135,144 @@ func TestWrapTransport_WithPolicy(t *testing.T) { require.NotNil(t, result.DialContext) } -func TestCheckURL_Disabled(t *testing.T) { - globalDialer.Store(nil) +func TestCheckURL(t *testing.T) { + tests := []struct { + name string + mode portainer.SSRFMode + entries []string + url string + wantErr bool + }{ + { + name: "disabled", + mode: portainer.SSRFModeOff, + url: "http://169.254.169.254/latest/meta-data/", + wantErr: false, + }, + { + name: "blocks IP not in allowlist", + mode: portainer.SSRFModeEnforce, + entries: []string{"8.8.8.0/24"}, + url: "http://169.254.169.254/latest/meta-data/", + wantErr: true, + }, + { + name: "allowed exact hostname", + mode: portainer.SSRFModeEnforce, + entries: []string{"example.com"}, + url: "https://example.com/path", + wantErr: false, + }, + { + name: "audit mode allows blocked IP", + mode: portainer.SSRFModeAudit, + entries: []string{"8.8.8.0/24"}, + url: "http://169.254.169.254/latest/meta-data/", + wantErr: false, + }, + { + name: "IP in CIDR allowlist", + mode: portainer.SSRFModeEnforce, + entries: []string{"8.8.8.0/24"}, + url: "http://8.8.8.8/path", + wantErr: false, + }, + { + name: "wildcard hostname", + mode: portainer.SSRFModeEnforce, + entries: []string{"*.example.com"}, + url: "https://api.example.com/path", + wantErr: false, + }, + { + name: "hostname DNS resolves to allowed IP", + mode: portainer.SSRFModeEnforce, + entries: []string{"127.0.0.0/8", "::1/128"}, + url: "http://localhost/path", + wantErr: false, + }, + { + name: "hostname DNS resolves to blocked IP", + mode: portainer.SSRFModeEnforce, + entries: []string{"8.8.8.0/24"}, + url: "http://localhost/path", + wantErr: true, + }, + { + name: "audit mode allows hostname resolving to blocked IP", + mode: portainer.SSRFModeAudit, + entries: []string{"8.8.8.0/24"}, + url: "http://localhost/path", + wantErr: false, + }, + { + name: "invalid URL", + mode: portainer.SSRFModeEnforce, + entries: []string{"example.com"}, + url: "http://%gg", + wantErr: true, + }, + { + name: "empty host", + mode: portainer.SSRFModeEnforce, + entries: []string{"example.com"}, + url: "http://", + wantErr: false, + }, + } - err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/") - require.NoError(t, err) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := Configure(newStaticService(tc.mode, tc.entries)) + require.NoError(t, err) + t.Cleanup(func() { globalDialer.Store(nil) }) + + err = CheckURL(t.Context(), tc.url) + if tc.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "ssrf") + } else { + require.NoError(t, err) + } + }) + } } -func TestCheckURL_BlocksIPNotInAllowlist(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}}) +// TestCheckURL_HostnameDNSError verifies that a DNS resolution failure is +// propagated as an SSRF-prefixed error. Kept separate because it needs a +// cancelled context rather than t.Context(). +func TestCheckURL_HostnameDNSError(t *testing.T) { + err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"8.8.8.0/24"})) + require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) - err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/") + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + err = CheckURL(ctx, "http://portainer-nonexistent.test.invalid/path") require.Error(t, err) - require.Contains(t, err.Error(), "ssrf") } -func TestCheckURL_AllowedHostname(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) +func TestIsEnabled(t *testing.T) { + globalDialer.Store(nil) + require.False(t, IsEnabled()) - err := CheckURL(t.Context(), "https://example.com/path") + err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"})) require.NoError(t, err) + t.Cleanup(func() { globalDialer.Store(nil) }) + require.True(t, IsEnabled()) + + err = Configure(newStaticService(portainer.SSRFModeOff, nil)) + require.NoError(t, err) + require.False(t, IsEnabled()) } -func TestCheckURL_AuditMode_ReturnsNil(t *testing.T) { - Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) +func TestWrapTransportInternal(t *testing.T) { + t.Parallel() - err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/") - require.NoError(t, err) + base := &http.Transport{} + result := WrapTransportInternal(base) + require.Equal(t, base, result) } // TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP @@ -136,7 +284,8 @@ func TestDialContext_BlocksLoopback(t *testing.T) { })) defer srv.Close() - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.8"}}) + err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"8.8.8.8"})) + require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) blocked := &http.Client{Transport: WrapTransport(&http.Transport{})} @@ -147,7 +296,9 @@ func TestDialContext_BlocksLoopback(t *testing.T) { require.NoError(t, resp.Body.Close()) } - Configure(Policy{}) + // Switch to off mode — dialer stays configured but checks are skipped. + err = Configure(newStaticService(portainer.SSRFModeOff, nil)) + require.NoError(t, err) open := &http.Client{Transport: WrapTransport(&http.Transport{})} resp, err = open.Get(srv.URL) @@ -163,7 +314,8 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) { })) defer srv.Close() - Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.8"}}) + err := Configure(newStaticService(portainer.SSRFModeAudit, []string{"8.8.8.8"})) + require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) client := &http.Client{Transport: WrapTransport(&http.Transport{})} @@ -172,122 +324,15 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) { require.NoError(t, resp.Body.Close()) } -func TestIsEnabled(t *testing.T) { - globalDialer.Store(nil) - require.False(t, IsEnabled()) - - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) - require.True(t, IsEnabled()) -} - -func TestWrapTransportInternal(t *testing.T) { - t.Parallel() - - base := &http.Transport{} - result := WrapTransportInternal(base) - require.Equal(t, base, result) -} - -func TestNewSafeDialer_IPv6SingleIP(t *testing.T) { - t.Parallel() - - d := newSafeDialer(Policy{ - Mode: ModeEnforce, - AllowedHosts: []string{"::1"}, - }) - - require.True(t, d.ipAllowed(net.ParseIP("::1"))) - require.False(t, d.ipAllowed(net.ParseIP("::2"))) -} - -func TestCheckURL_InvalidURL(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) - - err := CheckURL(t.Context(), "http://%gg") - require.Error(t, err) - require.Contains(t, err.Error(), "ssrf") -} - -func TestCheckURL_EmptyHost(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) - - err := CheckURL(t.Context(), "http://") - require.NoError(t, err) -} - -// TestCheckURL_IPInAllowlist verifies that a literal IP address that falls -// within an allowed CIDR range is permitted. -func TestCheckURL_IPInAllowlist(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) - - err := CheckURL(t.Context(), "http://8.8.8.8/path") - require.NoError(t, err) -} - -func TestCheckURL_WildcardHostname(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"*.example.com"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) - - err := CheckURL(t.Context(), "https://api.example.com/path") - require.NoError(t, err) -} - -// TestCheckURL_HostnameDNSResolvesToAllowedIP verifies that a hostname -// resolving to an IP within the allowlist is permitted (DNS resolution path). -func TestCheckURL_HostnameDNSResolvesToAllowedIP(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8", "::1/128"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) - - err := CheckURL(t.Context(), "http://localhost/path") - require.NoError(t, err) -} - -// TestCheckURL_HostnameDNSResolvesToBlockedIP verifies that a hostname -// resolving to an IP outside the allowlist is blocked (DNS resolution path). -func TestCheckURL_HostnameDNSResolvesToBlockedIP(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) - - err := CheckURL(t.Context(), "http://localhost/path") - require.Error(t, err) - require.Contains(t, err.Error(), "ssrf") -} - -// TestCheckURL_HostnameDNSAuditMode verifies that audit mode logs violations -// from hostname DNS resolution but still returns nil. -func TestCheckURL_HostnameDNSAuditMode(t *testing.T) { - Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) - - err := CheckURL(t.Context(), "http://localhost/path") - require.NoError(t, err) -} - -// TestCheckURL_HostnameDNSError verifies that a DNS resolution failure is -// propagated as an SSRF-prefixed error. -func TestCheckURL_HostnameDNSError(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}}) - t.Cleanup(func() { globalDialer.Store(nil) }) - - ctx, cancel := context.WithCancel(t.Context()) - cancel() - - err := CheckURL(ctx, "http://portainer-nonexistent.test.invalid/path") - require.Error(t, err) -} - // TestDialContext_InvalidAddress verifies that an address without a port // returns an SSRF-prefixed error. func TestDialContext_InvalidAddress(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}}) + err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"})) + require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) d := globalDialer.Load() - _, err := d.DialContext(t.Context(), "tcp", "no-port-here") + _, err = d.DialContext(t.Context(), "tcp", "no-port-here") require.Error(t, err) require.Contains(t, err.Error(), "ssrf") } @@ -295,14 +340,15 @@ func TestDialContext_InvalidAddress(t *testing.T) { // TestDialContext_DNSError verifies that a DNS resolution failure in // DialContext is propagated as an SSRF-prefixed error. func TestDialContext_DNSError(t *testing.T) { - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{}}) + err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{})) + require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) ctx, cancel := context.WithCancel(t.Context()) cancel() d := globalDialer.Load() - _, err := d.DialContext(ctx, "tcp", "portainer-nonexistent.test.invalid:80") + _, err = d.DialContext(ctx, "tcp", "portainer-nonexistent.test.invalid:80") require.Error(t, err) } @@ -314,7 +360,8 @@ func TestDialContext_AllowedByCIDR(t *testing.T) { })) defer srv.Close() - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8"}}) + err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"127.0.0.0/8"})) + require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) client := &http.Client{Transport: WrapTransport(&http.Transport{})} @@ -347,7 +394,8 @@ func TestDialContext_AllowedByExactHostname(t *testing.T) { _, portStr, err := net.SplitHostPort(l.Addr().String()) require.NoError(t, err) - Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"localhost"}}) + err = Configure(newStaticService(portainer.SSRFModeEnforce, []string{"localhost"})) + require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) client := &http.Client{Transport: WrapTransport(&http.Transport{})}