mirror of
https://github.com/portainer/portainer.git
synced 2026-06-23 04:10:29 +00:00
feat(ssrf): add ssrf allow list to settings [BE-13021] (#2858)
This commit is contained in:
@@ -56,8 +56,6 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
|
||||
CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(),
|
||||
CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(),
|
||||
SSRFMode: kingpin.Flag("ssrf-mode", "SSRF protection mode: off (disabled), audit (log violations but allow), enforce (block violations)").Envar("PORTAINER_SSRF_MODE").Default("off").Enum("off", "audit", "enforce"),
|
||||
SSRFAllowedHosts: kingpin.Flag("ssrf-allowed-hosts", "Allowlist of hostnames (with optional wildcards), IPs, or CIDRs permitted for outbound requests. When empty and mode is enforce, all outbound connections are blocked").Envar("PORTAINER_SSRF_ALLOWED_HOSTS").Strings(),
|
||||
NoSetupToken: kingpin.Flag("no-setup-token", "Disable the setup token requirement for admin initialization and restore on an uninitialized instance").Envar(portainer.NoSetupTokenEnvVar).Bool(),
|
||||
SetupToken: kingpin.Flag("setup-token", "Set a custom setup token for admin initialization and restore on an uninitialized instance (overrides auto-generation)").Envar(portainer.SetupTokenEnvVar).String(),
|
||||
}
|
||||
|
||||
+10
-13
@@ -387,19 +387,6 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
|
||||
// -ce can not ever be run in FIPS mode
|
||||
fips.InitFIPS(false)
|
||||
|
||||
ssrf.Configure(ssrf.Policy{
|
||||
Mode: ssrf.Mode(*flags.SSRFMode),
|
||||
AllowedHosts: *flags.SSRFAllowedHosts,
|
||||
})
|
||||
|
||||
if ssrf.IsEnabled() {
|
||||
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok {
|
||||
nethttp.DefaultTransport = ssrf.WrapTransport(dt)
|
||||
}
|
||||
|
||||
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
|
||||
}
|
||||
|
||||
fileService := initFileService(*flags.Data)
|
||||
encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName))
|
||||
if encryptionKey == nil {
|
||||
@@ -417,6 +404,16 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
|
||||
log.Fatal().Msg("The database schema version does not align with the server version. Please consider reverting to the previous server version or addressing the database migration issue.")
|
||||
}
|
||||
|
||||
if err := ssrf.Configure(dataStore.AllowList()); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing ssrf service")
|
||||
}
|
||||
|
||||
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok {
|
||||
nethttp.DefaultTransport = ssrf.WrapTransport(dt)
|
||||
}
|
||||
|
||||
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
|
||||
|
||||
instanceID, err := dataStore.Version().InstanceID()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed getting instance id")
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
package allowlist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
)
|
||||
|
||||
const (
|
||||
BucketName = "allowlist"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
baseService dataservices.BaseDataService[portainer.AllowList, portainer.AllowListKey]
|
||||
cache *lru.Cache
|
||||
}
|
||||
|
||||
func (service *Service) BucketName() string {
|
||||
return service.baseService.BucketName()
|
||||
}
|
||||
|
||||
func NewService(connection portainer.Connection) (*Service, error) {
|
||||
err := connection.SetServiceName(BucketName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
baseService: dataservices.BaseDataService[portainer.AllowList, portainer.AllowListKey]{
|
||||
Bucket: BucketName,
|
||||
Connection: connection,
|
||||
},
|
||||
}
|
||||
|
||||
err = service.populateCache()
|
||||
|
||||
return service, err
|
||||
}
|
||||
|
||||
func (service *Service) populateCache() error {
|
||||
allowListKeys := []portainer.AllowListKey{portainer.AllowListSSRF}
|
||||
cache, err := lru.New(len(allowListKeys))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, k := range allowListKeys {
|
||||
allowList, err := service.baseService.Read(k)
|
||||
if dataservices.IsErrObjectNotFound(err) {
|
||||
allowList = &portainer.AllowList{
|
||||
ID: k,
|
||||
Mode: portainer.SSRFModeOff,
|
||||
Entries: []string{},
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parsedAllowList := ssrf.ParseAllowedHosts(allowList.Entries)
|
||||
parsedAllowList.Mode = allowList.Mode
|
||||
|
||||
cache.Add(k, &parsedAllowList)
|
||||
}
|
||||
|
||||
service.cache = cache
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service *Service) Tx(tx portainer.Transaction) *ServiceTx {
|
||||
return &ServiceTx{
|
||||
baseService: service.baseService.Tx(tx),
|
||||
cache: service.cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (service *Service) Read(id portainer.AllowListKey) (*portainer.AllowList, error) {
|
||||
var result *portainer.AllowList
|
||||
if err := service.baseService.Connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
result, err = service.Tx(tx).Read(id)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (service *Service) ReadAll() ([]portainer.AllowList, error) {
|
||||
var result []portainer.AllowList
|
||||
if err := service.baseService.Connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
result, err = service.Tx(tx).ReadAll()
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (service *Service) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
|
||||
allowListAny, ok := service.cache.Get(id)
|
||||
if ok {
|
||||
allowList, ok := allowListAny.(*portainer.ParsedAllowList)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected ParsedAllowList in cache but got %T", allowListAny)
|
||||
}
|
||||
|
||||
return allowList, nil
|
||||
}
|
||||
|
||||
var result *portainer.ParsedAllowList
|
||||
err := service.baseService.Connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
result, err = service.Tx(tx).ReadParsed(id)
|
||||
return err
|
||||
})
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (service *Service) Update(id portainer.AllowListKey, allowList *portainer.AllowList) error {
|
||||
return service.baseService.Connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).Update(id, allowList)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package allowlist_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAllowListReadEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ds := datastore.MustNewTestStore(t, false, false)
|
||||
got, err := ds.AllowList().Read(portainer.AllowListSSRF)
|
||||
expected := &portainer.AllowList{
|
||||
ID: portainer.AllowListSSRF,
|
||||
Mode: portainer.SSRFModeOff,
|
||||
Entries: []string{},
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, got)
|
||||
}
|
||||
|
||||
func TestAllowListUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ds := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
expected := &portainer.AllowList{
|
||||
ID: portainer.AllowListSSRF,
|
||||
Mode: portainer.SSRFModeEnforce,
|
||||
Entries: []string{"example.com", "10.0.0.0/8"},
|
||||
}
|
||||
|
||||
require.NoError(t, ds.AllowList().Update(portainer.AllowListSSRF, expected))
|
||||
|
||||
got, err := ds.AllowList().Read(portainer.AllowListSSRF)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, got)
|
||||
}
|
||||
|
||||
func TestAllowListReadAllEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ds := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
got, err := ds.AllowList().ReadAll()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []portainer.AllowList{}, got)
|
||||
}
|
||||
|
||||
func TestAllowListReadAllAfterUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ds := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
expected := portainer.AllowList{
|
||||
ID: portainer.AllowListSSRF,
|
||||
Mode: portainer.SSRFModeEnforce,
|
||||
Entries: []string{"example.com", "10.0.0.0/8"},
|
||||
}
|
||||
|
||||
require.NoError(t, ds.AllowList().Update(portainer.AllowListSSRF, &expected))
|
||||
|
||||
got, err := ds.AllowList().ReadAll()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []portainer.AllowList{expected}, got)
|
||||
}
|
||||
|
||||
func TestAllowListReadParsedAfterUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ds := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
require.NoError(t, ds.AllowList().Update(portainer.AllowListSSRF, &portainer.AllowList{
|
||||
ID: portainer.AllowListSSRF,
|
||||
Mode: portainer.SSRFModeEnforce,
|
||||
Entries: []string{"example.com"},
|
||||
}))
|
||||
|
||||
expected := &portainer.ParsedAllowList{
|
||||
Mode: portainer.SSRFModeEnforce,
|
||||
Nets: []*net.IPNet{},
|
||||
Hosts: map[string]bool{
|
||||
"example.com": true,
|
||||
},
|
||||
}
|
||||
|
||||
got, err := ds.AllowList().ReadParsed(portainer.AllowListSSRF)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, got)
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package allowlist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
)
|
||||
|
||||
type ServiceTx struct {
|
||||
baseService dataservices.BaseDataServiceTx[portainer.AllowList, portainer.AllowListKey]
|
||||
cache *lru.Cache
|
||||
}
|
||||
|
||||
func (service *ServiceTx) BucketName() string {
|
||||
return service.baseService.BucketName()
|
||||
}
|
||||
|
||||
func (service *ServiceTx) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
|
||||
allowListAny, ok := service.cache.Get(id)
|
||||
if ok {
|
||||
allowList, ok := allowListAny.(*portainer.ParsedAllowList)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected ParsedAllowList in cache but got %T", allowListAny)
|
||||
}
|
||||
|
||||
return allowList, nil
|
||||
}
|
||||
|
||||
allowList, err := service.Read(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsed := ssrf.ParseAllowedHosts(allowList.Entries)
|
||||
parsed.Mode = allowList.Mode
|
||||
service.cache.Add(id, &parsed)
|
||||
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
func (service *ServiceTx) Read(id portainer.AllowListKey) (*portainer.AllowList, error) {
|
||||
allowList, err := service.baseService.Read(id)
|
||||
if dataservices.IsErrObjectNotFound(err) {
|
||||
allowList = &portainer.AllowList{
|
||||
ID: id,
|
||||
Mode: portainer.SSRFModeOff,
|
||||
Entries: []string{},
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return allowList, nil
|
||||
}
|
||||
|
||||
func (service *ServiceTx) ReadAll() ([]portainer.AllowList, error) {
|
||||
allowLists, err := service.baseService.ReadAll()
|
||||
if err != nil && !dataservices.IsErrObjectNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return allowLists, nil
|
||||
}
|
||||
|
||||
func (service *ServiceTx) Update(id portainer.AllowListKey, allowList *portainer.AllowList) error {
|
||||
if err := service.baseService.Update(id, allowList); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parsed := ssrf.ParseAllowedHosts(allowList.Entries)
|
||||
parsed.Mode = allowList.Mode
|
||||
service.cache.Add(id, &parsed)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package allowlist_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAllowListReadTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ds := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
var got *portainer.AllowList
|
||||
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
var err error
|
||||
got, err = tx.AllowList().Read(portainer.AllowListSSRF)
|
||||
return err
|
||||
}))
|
||||
|
||||
expected := &portainer.AllowList{
|
||||
ID: portainer.AllowListSSRF,
|
||||
Mode: portainer.SSRFModeOff,
|
||||
Entries: []string{},
|
||||
}
|
||||
|
||||
require.Equal(t, expected, got)
|
||||
}
|
||||
|
||||
func TestAllowListReadAllEmptyTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ds := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
var got []portainer.AllowList
|
||||
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
var err error
|
||||
got, err = tx.AllowList().ReadAll()
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Equal(t, []portainer.AllowList{}, got)
|
||||
}
|
||||
|
||||
func TestAllowListReadAllAfterUpdateTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ds := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
expected := portainer.AllowList{
|
||||
ID: portainer.AllowListSSRF,
|
||||
Mode: portainer.SSRFModeEnforce,
|
||||
Entries: []string{"example.com"},
|
||||
}
|
||||
|
||||
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return tx.AllowList().Update(portainer.AllowListSSRF, &expected)
|
||||
}))
|
||||
|
||||
var got []portainer.AllowList
|
||||
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
var err error
|
||||
got, err = tx.AllowList().ReadAll()
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Equal(t, []portainer.AllowList{expected}, got)
|
||||
}
|
||||
|
||||
func TestAllowListUpdateTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ds := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
expected := &portainer.AllowList{
|
||||
ID: portainer.AllowListSSRF,
|
||||
Mode: portainer.SSRFModeEnforce,
|
||||
Entries: []string{"example.com"},
|
||||
}
|
||||
|
||||
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return tx.AllowList().Update(portainer.AllowListSSRF, expected)
|
||||
}))
|
||||
|
||||
var got *portainer.AllowList
|
||||
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
var err error
|
||||
got, err = tx.AllowList().Read(portainer.AllowListSSRF)
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Equal(t, expected, got)
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
type (
|
||||
DataStoreTx interface {
|
||||
IsErrObjectNotFound(err error) bool
|
||||
AllowList() AllowListService
|
||||
CustomTemplate() CustomTemplateService
|
||||
EdgeGroup() EdgeGroupService
|
||||
EdgeJob() EdgeJobService
|
||||
@@ -53,6 +54,15 @@ type (
|
||||
DataStoreTx
|
||||
}
|
||||
|
||||
// AllowListService represents a service for managing the URL allow list
|
||||
AllowListService interface {
|
||||
Read(id portainer.AllowListKey) (*portainer.AllowList, error)
|
||||
ReadAll() ([]portainer.AllowList, error)
|
||||
ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error)
|
||||
Update(id portainer.AllowListKey, allowList *portainer.AllowList) error
|
||||
BucketName() string
|
||||
}
|
||||
|
||||
// CustomTemplateService represents a service to manage custom templates
|
||||
CustomTemplateService interface {
|
||||
BaseCRUD[portainer.CustomTemplate, portainer.CustomTemplateID]
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/models"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/dataservices/allowlist"
|
||||
"github.com/portainer/portainer/api/dataservices/apikeyrepository"
|
||||
"github.com/portainer/portainer/api/dataservices/customtemplate"
|
||||
"github.com/portainer/portainer/api/dataservices/dockerhub"
|
||||
@@ -51,6 +52,7 @@ type Store struct {
|
||||
connection portainer.Connection
|
||||
|
||||
fileService portainer.FileService
|
||||
AllowListService *allowlist.Service
|
||||
CustomTemplateService *customtemplate.Service
|
||||
DockerHubService *dockerhub.Service
|
||||
EdgeGroupService *edgegroup.Service
|
||||
@@ -84,6 +86,12 @@ type Store struct {
|
||||
}
|
||||
|
||||
func (store *Store) initServices() error {
|
||||
allowListService, err := allowlist.NewService(store.connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
store.AllowListService = allowListService
|
||||
|
||||
authorizationsetService, err := role.NewService(store.connection)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -275,6 +283,11 @@ func (store *Store) PendingActions() dataservices.PendingActionsService {
|
||||
return store.PendingActionsService
|
||||
}
|
||||
|
||||
// AllowList gives access to the AllowList data management layer
|
||||
func (store *Store) AllowList() dataservices.AllowListService {
|
||||
return store.AllowListService
|
||||
}
|
||||
|
||||
// CustomTemplate gives access to the CustomTemplate data management layer
|
||||
func (store *Store) CustomTemplate() dataservices.CustomTemplateService {
|
||||
return store.CustomTemplateService
|
||||
@@ -654,7 +667,7 @@ func (store *Store) Export(filename string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filename, b, 0600)
|
||||
return os.WriteFile(filename, b, 0o600)
|
||||
}
|
||||
|
||||
func (store *Store) Import(filename string) (err error) {
|
||||
|
||||
@@ -14,6 +14,10 @@ func (tx *StoreTx) IsErrObjectNotFound(err error) bool {
|
||||
return tx.store.IsErrObjectNotFound(err)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) AllowList() dataservices.AllowListService {
|
||||
return tx.store.AllowListService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService {
|
||||
return tx.store.CustomTemplateService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
{
|
||||
"allowlist": null,
|
||||
"api_key": null,
|
||||
"customtemplates": null,
|
||||
"dockerhub": [
|
||||
|
||||
@@ -90,7 +90,7 @@ func createTCPClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*cli
|
||||
client.WithHTTPClient(httpCli),
|
||||
}
|
||||
|
||||
if nnTransport, ok := httpCli.Transport.(*NodeNameTransport); ok && nnTransport.TLSClientConfig != nil {
|
||||
if endpoint.TLSConfig.TLS {
|
||||
opts = append(opts, client.WithScheme("https"))
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ func createAgentClient(endpoint *portainer.Endpoint, endpointURL string, signatu
|
||||
client.WithHTTPHeaders(headers),
|
||||
}
|
||||
|
||||
if nnTransport, ok := httpCli.Transport.(*NodeNameTransport); ok && nnTransport.TLSClientConfig != nil {
|
||||
if endpoint.TLSConfig.TLS {
|
||||
opts = append(opts, client.WithScheme("https"))
|
||||
}
|
||||
|
||||
|
||||
@@ -45,10 +45,24 @@ func (s *stubTunnelService) UpdateLastActivity(endpointID portainer.EndpointID)
|
||||
func (s *stubTunnelService) KeepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxKeepAlive time.Duration) {
|
||||
}
|
||||
|
||||
type staticAllowListService struct {
|
||||
parsed portainer.ParsedAllowList
|
||||
}
|
||||
|
||||
func (s *staticAllowListService) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
|
||||
return &s.parsed, nil
|
||||
}
|
||||
|
||||
func enableSSRF(t *testing.T) {
|
||||
t.Helper()
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
parsed := ssrf.ParseAllowedHosts([]string{"example.com"})
|
||||
parsed.Mode = portainer.SSRFModeEnforce
|
||||
err := ssrf.Configure(&staticAllowListService{parsed: parsed})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := ssrf.Configure(&staticAllowListService{})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
var _ dataservices.DataStore = &testDatastore{}
|
||||
|
||||
type testDatastore struct {
|
||||
allowList dataservices.AllowListService
|
||||
customTemplate dataservices.CustomTemplateService
|
||||
edgeGroup dataservices.EdgeGroupService
|
||||
edgeJob dataservices.EdgeJobService
|
||||
@@ -53,6 +54,7 @@ func (d *testDatastore) ViewTx(func(dataservices.DataStoreTx) error) error { r
|
||||
func (d *testDatastore) CheckCurrentEdition() error { return nil }
|
||||
func (d *testDatastore) MigrateData() error { return nil }
|
||||
func (d *testDatastore) Rollback(force bool) error { return nil }
|
||||
func (d *testDatastore) AllowList() dataservices.AllowListService { return d.allowList }
|
||||
func (d *testDatastore) CustomTemplate() dataservices.CustomTemplateService { return d.customTemplate }
|
||||
func (d *testDatastore) EdgeGroup() dataservices.EdgeGroupService { return d.edgeGroup }
|
||||
func (d *testDatastore) EdgeJob() dataservices.EdgeJobService { return d.edgeJob }
|
||||
|
||||
+30
-2
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -111,8 +112,6 @@ type (
|
||||
KubectlShellImageSet bool
|
||||
PullLimitCheckDisabled *bool
|
||||
TrustedOrigins *string
|
||||
SSRFMode *string
|
||||
SSRFAllowedHosts *[]string
|
||||
NoSetupToken *bool
|
||||
SetupToken *string
|
||||
}
|
||||
@@ -1215,6 +1214,21 @@ type (
|
||||
// SoftwareEdition represents an edition of Portainer
|
||||
SoftwareEdition int
|
||||
|
||||
// AllowList holds the list of permitted outbound proxy destinations.
|
||||
AllowList struct {
|
||||
ID AllowListKey `json:"Id"`
|
||||
Mode SSRFMode `json:"Mode"`
|
||||
Entries []string `json:"Entries"`
|
||||
}
|
||||
|
||||
// ParsedAllowList holds the three parsed forms of allow list entries.
|
||||
ParsedAllowList struct {
|
||||
Mode SSRFMode
|
||||
Nets []*net.IPNet
|
||||
Hosts map[string]bool
|
||||
Wilds []string // stored as ".foo.com" ("*." prefix stripped)
|
||||
}
|
||||
|
||||
// SSLSettings represents a pair of SSL certificate and key
|
||||
SSLSettings struct {
|
||||
CertPath string `json:"certPath"`
|
||||
@@ -2670,3 +2684,17 @@ func DefaultEndpointSecuritySettings() EndpointSecuritySettings {
|
||||
AllowStackManagementForRegularUsers: true,
|
||||
}
|
||||
}
|
||||
|
||||
type AllowListKey int
|
||||
|
||||
const (
|
||||
AllowListSSRF AllowListKey = iota
|
||||
)
|
||||
|
||||
type SSRFMode int
|
||||
|
||||
const (
|
||||
SSRFModeOff SSRFMode = iota
|
||||
SSRFModeAudit
|
||||
SSRFModeEnforce
|
||||
)
|
||||
|
||||
@@ -139,7 +139,6 @@ services:
|
||||
`)
|
||||
|
||||
stack := &portainer.Stack{
|
||||
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
EntryPoint: "docker-compose.yml",
|
||||
Env: []portainer.Pair{{Name: "API_PORT", Value: "3000"}},
|
||||
@@ -186,7 +185,7 @@ func TestValidateStackFiles_DotEnvFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
err := os.WriteFile(filesystem.JoinPaths(tmpDir, ".env"), []byte("HOST_PORT=3000\n"), 0600)
|
||||
err := os.WriteFile(filesystem.JoinPaths(tmpDir, ".env"), []byte("HOST_PORT=3000\n"), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
fileContent := []byte(`
|
||||
@@ -217,7 +216,7 @@ func TestValidateStackFiles_EnvFileAttribute(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
err := os.WriteFile(filesystem.JoinPaths(tmpDir, "web.env"), []byte("HOST_PORT=3000\n"), 0600)
|
||||
err := os.WriteFile(filesystem.JoinPaths(tmpDir, "web.env"), []byte("HOST_PORT=3000\n"), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
fileContent := []byte(`
|
||||
@@ -298,8 +297,29 @@ func TestExtractImageRegistry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type staticAllowListService struct {
|
||||
parsed portainer.ParsedAllowList
|
||||
}
|
||||
|
||||
func (s *staticAllowListService) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
|
||||
return &s.parsed, nil
|
||||
}
|
||||
|
||||
func configureSSRF(t *testing.T, mode portainer.SSRFMode, entries []string) {
|
||||
t.Helper()
|
||||
|
||||
parsed := ssrf.ParseAllowedHosts(entries)
|
||||
parsed.Mode = mode
|
||||
err := ssrf.Configure(&staticAllowListService{parsed: parsed})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := ssrf.Configure(&staticAllowListService{})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_DisabledSSRF(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{})
|
||||
configureSSRF(t, portainer.SSRFModeOff, nil)
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
@@ -322,8 +342,7 @@ services:
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_BuildContextBlocked(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"})
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
@@ -347,8 +366,7 @@ services:
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_BuildContextPath(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"})
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
@@ -372,8 +390,7 @@ services:
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_ImageRegistryBlocked(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"})
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
@@ -395,8 +412,7 @@ services:
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_ImageRegistryAllowed(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"myregistry.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
configureSSRF(t, portainer.SSRFModeEnforce, []string{"myregistry.com"})
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
@@ -418,8 +434,7 @@ services:
|
||||
}
|
||||
|
||||
func TestValidateComposeURLs_DockerHubImageAllowed(t *testing.T) {
|
||||
ssrf.Configure(ssrf.Policy{Mode: ssrf.ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { ssrf.Configure(ssrf.Policy{}) })
|
||||
configureSSRF(t, portainer.SSRFModeEnforce, []string{"example.com"})
|
||||
|
||||
stack := &portainer.Stack{
|
||||
ProjectPath: "/tmp/stack/1",
|
||||
|
||||
+92
-88
@@ -2,6 +2,7 @@ package ssrf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -9,61 +10,84 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Mode controls how the SSRF policy is applied.
|
||||
type Mode string
|
||||
// ParseAllowedHosts parses raw allow list entries into their three canonical
|
||||
// forms. Accepted formats: exact hostname, wildcard hostname (*.example.com),
|
||||
// single IP, or CIDR range.
|
||||
func ParseAllowedHosts(entries []string) portainer.ParsedAllowList {
|
||||
nets := make([]*net.IPNet, 0, len(entries))
|
||||
hosts := make(map[string]bool, len(entries))
|
||||
var wilds []string
|
||||
|
||||
const (
|
||||
// ModeOff disables SSRF protection entirely. All connections pass through unchanged.
|
||||
ModeOff Mode = "off"
|
||||
// ModeAudit resolves and checks destinations but only logs violations; connections are allowed.
|
||||
ModeAudit Mode = "audit"
|
||||
// ModeEnforce blocks connections that violate the policy.
|
||||
ModeEnforce Mode = "enforce"
|
||||
)
|
||||
for _, entry := range entries {
|
||||
if _, network, err := net.ParseCIDR(entry); err == nil {
|
||||
nets = append(nets, network)
|
||||
continue
|
||||
}
|
||||
|
||||
// Policy defines the SSRF protection policy for outbound HTTP connections.
|
||||
type Policy struct {
|
||||
// Mode controls whether protection is off, in audit-only mode, or enforcing.
|
||||
Mode Mode
|
||||
if ip := net.ParseIP(entry); ip != nil {
|
||||
bits := 32
|
||||
if ip.To4() == nil {
|
||||
bits = 128
|
||||
}
|
||||
|
||||
// AllowedHosts is the allowlist of permitted destinations.
|
||||
// Accepted formats:
|
||||
// - Exact hostname: "example.com"
|
||||
// - Wildcard hostname: "*.example.com" (matches any subdomain at any depth)
|
||||
// - Single IP: "1.2.3.4"
|
||||
// - CIDR range: "10.0.0.0/8"
|
||||
//
|
||||
// When Mode is ModeEnforce and AllowedHosts is empty, all outbound connections are blocked.
|
||||
AllowedHosts []string
|
||||
mask := net.CIDRMask(bits, bits)
|
||||
nets = append(nets, &net.IPNet{IP: ip.Mask(mask), Mask: mask})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(entry, "*.") {
|
||||
wilds = append(wilds, entry[1:]) // "*.foo.com" -> ".foo.com"
|
||||
continue
|
||||
}
|
||||
|
||||
hosts[entry] = true
|
||||
}
|
||||
|
||||
return portainer.ParsedAllowList{Nets: nets, Hosts: hosts, Wilds: wilds}
|
||||
}
|
||||
|
||||
// AllowListService is implemented by the allowlist data service.
|
||||
// ReadParsed is called on every dial to pick up runtime changes.
|
||||
type AllowListService interface {
|
||||
ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error)
|
||||
}
|
||||
|
||||
type safeDialer struct {
|
||||
base net.Dialer
|
||||
mode Mode
|
||||
allowedNets []*net.IPNet
|
||||
allowedHosts map[string]bool
|
||||
allowedWilds []string // derived from "*.foo.com" entries; stored as ".foo.com"
|
||||
base net.Dialer
|
||||
service AllowListService
|
||||
}
|
||||
|
||||
var globalDialer atomic.Pointer[safeDialer]
|
||||
|
||||
// Configure initializes the global SSRF policy. Intended to be called once
|
||||
// at startup before any outbound HTTP connections are established.
|
||||
func Configure(policy Policy) {
|
||||
if policy.Mode == ModeOff || policy.Mode == "" {
|
||||
globalDialer.Store(nil)
|
||||
return
|
||||
// Configure initializes the global SSRF policy with the allow list data service.
|
||||
func Configure(svc AllowListService) error {
|
||||
if svc == nil {
|
||||
return errors.New("unable to configure ssrf: service must not be nil")
|
||||
}
|
||||
|
||||
globalDialer.Store(newSafeDialer(policy))
|
||||
globalDialer.Store(&safeDialer{service: svc})
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsEnabled reports whether SSRF protection is currently active (audit or enforce).
|
||||
func IsEnabled() bool {
|
||||
return globalDialer.Load() != nil
|
||||
d := globalDialer.Load()
|
||||
if d == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
allowList, err := d.service.ReadParsed(portainer.AllowListSSRF)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("unable to check SSRF protection mode")
|
||||
return false
|
||||
}
|
||||
|
||||
return allowList.Mode != portainer.SSRFModeOff
|
||||
}
|
||||
|
||||
// CheckURL validates rawURL against the active SSRF policy without making a
|
||||
@@ -89,7 +113,8 @@ func CheckURL(ctx context.Context, rawURL string) error {
|
||||
}
|
||||
|
||||
// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering
|
||||
// dialer. Returns t unchanged when SSRF protection is not configured.
|
||||
// dialer. The dialer checks the mode on every connection, so the transport is always
|
||||
// wrapped and mode changes take effect without restarting.
|
||||
func WrapTransport(t *http.Transport) *http.Transport {
|
||||
d := globalDialer.Load()
|
||||
if d == nil {
|
||||
@@ -112,48 +137,18 @@ func WrapTransportInternal(t *http.Transport) *http.Transport {
|
||||
return t
|
||||
}
|
||||
|
||||
func newSafeDialer(policy Policy) *safeDialer {
|
||||
allowedNets := make([]*net.IPNet, 0, len(policy.AllowedHosts))
|
||||
allowedHosts := make(map[string]bool, len(policy.AllowedHosts))
|
||||
var allowedWilds []string
|
||||
|
||||
for _, entry := range policy.AllowedHosts {
|
||||
if _, network, err := net.ParseCIDR(entry); err == nil {
|
||||
allowedNets = append(allowedNets, network)
|
||||
continue
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(entry); ip != nil {
|
||||
bits := 32
|
||||
if ip.To4() == nil {
|
||||
bits = 128
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(bits, bits)
|
||||
allowedNets = append(allowedNets, &net.IPNet{IP: ip.Mask(mask), Mask: mask})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(entry, "*.") {
|
||||
allowedWilds = append(allowedWilds, entry[1:]) // "*.foo.com" -> ".foo.com"
|
||||
continue
|
||||
}
|
||||
|
||||
allowedHosts[entry] = true
|
||||
}
|
||||
|
||||
return &safeDialer{
|
||||
mode: policy.Mode,
|
||||
allowedNets: allowedNets,
|
||||
allowedHosts: allowedHosts,
|
||||
allowedWilds: allowedWilds,
|
||||
}
|
||||
}
|
||||
|
||||
// DialContext resolves addr, validates all resolved IPs against the allowlist policy,
|
||||
// then dials using the first resolved IP to prevent DNS rebinding attacks.
|
||||
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
allowList, err := d.service.ReadParsed(portainer.AllowListSSRF)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssrf: reading allow list: %w", err)
|
||||
}
|
||||
|
||||
if allowList.Mode == portainer.SSRFModeOff {
|
||||
return d.base.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err)
|
||||
@@ -172,13 +167,13 @@ func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net
|
||||
// window between DNS validation and the TCP handshake (DNS rebinding).
|
||||
dialTarget := net.JoinHostPort(resolved[0].IP.String(), port)
|
||||
|
||||
if d.allowedHosts[host] || d.matchesWildcard(host) {
|
||||
if allowList.Hosts[host] || matchesWildcard(host, allowList.Wilds) {
|
||||
return d.base.DialContext(ctx, network, dialTarget)
|
||||
}
|
||||
|
||||
for _, a := range resolved {
|
||||
if !d.ipAllowed(a.IP) {
|
||||
if d.mode == ModeAudit {
|
||||
if !ipAllowed(a.IP, allowList.Nets) {
|
||||
if allowList.Mode == portainer.SSRFModeAudit {
|
||||
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
|
||||
continue
|
||||
}
|
||||
@@ -191,13 +186,22 @@ func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net
|
||||
}
|
||||
|
||||
func (d *safeDialer) checkHost(ctx context.Context, host string) error {
|
||||
if d.allowedHosts[host] || d.matchesWildcard(host) {
|
||||
allowList, err := d.service.ReadParsed(portainer.AllowListSSRF)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ssrf: reading allow list: %w", err)
|
||||
}
|
||||
|
||||
if allowList.Mode == portainer.SSRFModeOff {
|
||||
return nil
|
||||
}
|
||||
|
||||
if allowList.Hosts[host] || matchesWildcard(host, allowList.Wilds) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if !d.ipAllowed(ip) {
|
||||
if d.mode == ModeAudit {
|
||||
if !ipAllowed(ip, allowList.Nets) {
|
||||
if allowList.Mode == portainer.SSRFModeAudit {
|
||||
log.Warn().Str("host", host).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
|
||||
return nil
|
||||
}
|
||||
@@ -218,8 +222,8 @@ func (d *safeDialer) checkHost(ctx context.Context, host string) error {
|
||||
}
|
||||
|
||||
for _, a := range resolved {
|
||||
if !d.ipAllowed(a.IP) {
|
||||
if d.mode == ModeAudit {
|
||||
if !ipAllowed(a.IP, allowList.Nets) {
|
||||
if allowList.Mode == portainer.SSRFModeAudit {
|
||||
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
|
||||
continue
|
||||
}
|
||||
@@ -231,8 +235,8 @@ func (d *safeDialer) checkHost(ctx context.Context, host string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *safeDialer) matchesWildcard(host string) bool {
|
||||
for _, suffix := range d.allowedWilds {
|
||||
func matchesWildcard(host string, wilds []string) bool {
|
||||
for _, suffix := range wilds {
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
return true
|
||||
}
|
||||
@@ -241,8 +245,8 @@ func (d *safeDialer) matchesWildcard(host string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *safeDialer) ipAllowed(ip net.IP) bool {
|
||||
for _, network := range d.allowedNets {
|
||||
func ipAllowed(ip net.IP, nets []*net.IPNet) bool {
|
||||
for _, network := range nets {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
|
||||
+232
-184
@@ -7,74 +7,113 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIpAllowed_CIDR(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"8.8.0.0/16", "2001:4860::/32"},
|
||||
})
|
||||
|
||||
require.True(t, d.ipAllowed(net.ParseIP("8.8.8.8")))
|
||||
require.True(t, d.ipAllowed(net.ParseIP("8.8.4.4")))
|
||||
require.True(t, d.ipAllowed(net.ParseIP("2001:4860:4860::8888")))
|
||||
|
||||
require.False(t, d.ipAllowed(net.ParseIP("1.1.1.1")))
|
||||
require.False(t, d.ipAllowed(net.ParseIP("127.0.0.1")))
|
||||
require.False(t, d.ipAllowed(net.ParseIP("169.254.169.254")))
|
||||
// staticService is a simple in-memory AllowListService for testing.
|
||||
type staticService struct {
|
||||
parsed portainer.ParsedAllowList
|
||||
}
|
||||
|
||||
func TestIpAllowed_SingleIP(t *testing.T) {
|
||||
func (s *staticService) ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error) {
|
||||
return &s.parsed, nil
|
||||
}
|
||||
|
||||
func newStaticService(mode portainer.SSRFMode, entries []string) *staticService {
|
||||
parsed := ParseAllowedHosts(entries)
|
||||
parsed.Mode = mode
|
||||
return &staticService{parsed: parsed}
|
||||
}
|
||||
|
||||
func TestParseAllowedHosts_ipAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"1.2.3.4"},
|
||||
})
|
||||
testCases := []struct {
|
||||
name string
|
||||
hostEntries []string
|
||||
allowed []string
|
||||
denied []string
|
||||
}{
|
||||
{
|
||||
name: "CIDR",
|
||||
hostEntries: []string{"8.8.0.0/16", "2001:4860::/32"},
|
||||
allowed: []string{"8.8.8.8", "8.8.4.4", "2001:4860:4860::8888"},
|
||||
denied: []string{"1.1.1.1", "127.0.0.1", "169.254.169.254"},
|
||||
},
|
||||
{
|
||||
name: "Single IP",
|
||||
hostEntries: []string{"1.2.3.4"},
|
||||
allowed: []string{"1.2.3.4"},
|
||||
denied: []string{"1.2.3.5"},
|
||||
},
|
||||
{
|
||||
name: "Single IPv6",
|
||||
hostEntries: []string{"::1"},
|
||||
allowed: []string{"::1"},
|
||||
denied: []string{"::2"},
|
||||
},
|
||||
}
|
||||
|
||||
require.True(t, d.ipAllowed(net.ParseIP("1.2.3.4")))
|
||||
require.False(t, d.ipAllowed(net.ParseIP("1.2.3.5")))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parsed := ParseAllowedHosts(tc.hostEntries)
|
||||
for _, a := range tc.allowed {
|
||||
require.True(t, ipAllowed(net.ParseIP(a), parsed.Nets))
|
||||
}
|
||||
|
||||
for _, d := range tc.denied {
|
||||
require.False(t, ipAllowed(net.ParseIP(d), parsed.Nets))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAllowedHosts_MixedEntries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parsed := ParseAllowedHosts([]string{"example.com", "*.internal.net", "10.0.0.0/8", "1.2.3.4"})
|
||||
|
||||
require.True(t, parsed.Hosts["example.com"])
|
||||
require.Contains(t, parsed.Wilds, ".internal.net")
|
||||
require.Len(t, parsed.Nets, 2) // 10.0.0.0/8 and 1.2.3.4/32
|
||||
}
|
||||
|
||||
func TestMatchesWildcard(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"*.example.com", "exact.host.com"},
|
||||
})
|
||||
parsed := ParseAllowedHosts([]string{"*.example.com", "exact.host.com"})
|
||||
|
||||
require.True(t, d.matchesWildcard("foo.example.com"))
|
||||
require.True(t, d.matchesWildcard("bar.example.com"))
|
||||
require.True(t, d.matchesWildcard("deep.nested.example.com"))
|
||||
tests := []struct {
|
||||
host string
|
||||
want bool
|
||||
}{
|
||||
{"foo.example.com", true},
|
||||
{"bar.example.com", true},
|
||||
{"deep.nested.example.com", true},
|
||||
{"example.com", false},
|
||||
{"notexample.com", false},
|
||||
{"exact.host.com", false},
|
||||
}
|
||||
|
||||
require.False(t, d.matchesWildcard("example.com"))
|
||||
require.False(t, d.matchesWildcard("notexample.com"))
|
||||
require.False(t, d.matchesWildcard("exact.host.com"))
|
||||
for _, tc := range tests {
|
||||
got := matchesWildcard(tc.host, parsed.Wilds)
|
||||
require.Equal(t, tc.want, got, "host: %s", tc.host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialer_MixedHosts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"example.com", "*.internal.net", "10.0.0.0/8", "1.2.3.4"},
|
||||
})
|
||||
|
||||
require.True(t, d.allowedHosts["example.com"])
|
||||
require.Contains(t, d.allowedWilds, ".internal.net")
|
||||
require.Len(t, d.allowedNets, 2) // 10.0.0.0/8 and 1.2.3.4/32
|
||||
}
|
||||
|
||||
func TestConfigure_Disabled(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
func TestConfigure_SetsDialer(t *testing.T) {
|
||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, globalDialer.Load())
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
}
|
||||
|
||||
Configure(Policy{})
|
||||
require.Nil(t, globalDialer.Load())
|
||||
func TestConfigure_NilServicesReturnsError(t *testing.T) {
|
||||
err := Configure(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWrapTransport_NoPolicy(t *testing.T) {
|
||||
@@ -86,7 +125,8 @@ func TestWrapTransport_NoPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWrapTransport_WithPolicy(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
base := &http.Transport{}
|
||||
@@ -95,36 +135,144 @@ func TestWrapTransport_WithPolicy(t *testing.T) {
|
||||
require.NotNil(t, result.DialContext)
|
||||
}
|
||||
|
||||
func TestCheckURL_Disabled(t *testing.T) {
|
||||
globalDialer.Store(nil)
|
||||
func TestCheckURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode portainer.SSRFMode
|
||||
entries []string
|
||||
url string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "disabled",
|
||||
mode: portainer.SSRFModeOff,
|
||||
url: "http://169.254.169.254/latest/meta-data/",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "blocks IP not in allowlist",
|
||||
mode: portainer.SSRFModeEnforce,
|
||||
entries: []string{"8.8.8.0/24"},
|
||||
url: "http://169.254.169.254/latest/meta-data/",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "allowed exact hostname",
|
||||
mode: portainer.SSRFModeEnforce,
|
||||
entries: []string{"example.com"},
|
||||
url: "https://example.com/path",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "audit mode allows blocked IP",
|
||||
mode: portainer.SSRFModeAudit,
|
||||
entries: []string{"8.8.8.0/24"},
|
||||
url: "http://169.254.169.254/latest/meta-data/",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IP in CIDR allowlist",
|
||||
mode: portainer.SSRFModeEnforce,
|
||||
entries: []string{"8.8.8.0/24"},
|
||||
url: "http://8.8.8.8/path",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard hostname",
|
||||
mode: portainer.SSRFModeEnforce,
|
||||
entries: []string{"*.example.com"},
|
||||
url: "https://api.example.com/path",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "hostname DNS resolves to allowed IP",
|
||||
mode: portainer.SSRFModeEnforce,
|
||||
entries: []string{"127.0.0.0/8", "::1/128"},
|
||||
url: "http://localhost/path",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "hostname DNS resolves to blocked IP",
|
||||
mode: portainer.SSRFModeEnforce,
|
||||
entries: []string{"8.8.8.0/24"},
|
||||
url: "http://localhost/path",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "audit mode allows hostname resolving to blocked IP",
|
||||
mode: portainer.SSRFModeAudit,
|
||||
entries: []string{"8.8.8.0/24"},
|
||||
url: "http://localhost/path",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid URL",
|
||||
mode: portainer.SSRFModeEnforce,
|
||||
entries: []string{"example.com"},
|
||||
url: "http://%gg",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty host",
|
||||
mode: portainer.SSRFModeEnforce,
|
||||
entries: []string{"example.com"},
|
||||
url: "http://",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
|
||||
require.NoError(t, err)
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := Configure(newStaticService(tc.mode, tc.entries))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err = CheckURL(t.Context(), tc.url)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckURL_BlocksIPNotInAllowlist(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
// TestCheckURL_HostnameDNSError verifies that a DNS resolution failure is
|
||||
// propagated as an SSRF-prefixed error. Kept separate because it needs a
|
||||
// cancelled context rather than t.Context().
|
||||
func TestCheckURL_HostnameDNSError(t *testing.T) {
|
||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"8.8.8.0/24"}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
err = CheckURL(ctx, "http://portainer-nonexistent.test.invalid/path")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
}
|
||||
|
||||
func TestCheckURL_AllowedHostname(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
func TestIsEnabled(t *testing.T) {
|
||||
globalDialer.Store(nil)
|
||||
require.False(t, IsEnabled())
|
||||
|
||||
err := CheckURL(t.Context(), "https://example.com/path")
|
||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
require.True(t, IsEnabled())
|
||||
|
||||
err = Configure(newStaticService(portainer.SSRFModeOff, nil))
|
||||
require.NoError(t, err)
|
||||
require.False(t, IsEnabled())
|
||||
}
|
||||
|
||||
func TestCheckURL_AuditMode_ReturnsNil(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
func TestWrapTransportInternal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := CheckURL(t.Context(), "http://169.254.169.254/latest/meta-data/")
|
||||
require.NoError(t, err)
|
||||
base := &http.Transport{}
|
||||
result := WrapTransportInternal(base)
|
||||
require.Equal(t, base, result)
|
||||
}
|
||||
|
||||
// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP
|
||||
@@ -136,7 +284,8 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.8"}})
|
||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"8.8.8.8"}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
blocked := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
@@ -147,7 +296,9 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
Configure(Policy{})
|
||||
// Switch to off mode — dialer stays configured but checks are skipped.
|
||||
err = Configure(newStaticService(portainer.SSRFModeOff, nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
open := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
resp, err = open.Get(srv.URL)
|
||||
@@ -163,7 +314,8 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.8"}})
|
||||
err := Configure(newStaticService(portainer.SSRFModeAudit, []string{"8.8.8.8"}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
@@ -172,122 +324,15 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
func TestIsEnabled(t *testing.T) {
|
||||
globalDialer.Store(nil)
|
||||
require.False(t, IsEnabled())
|
||||
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
require.True(t, IsEnabled())
|
||||
}
|
||||
|
||||
func TestWrapTransportInternal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := &http.Transport{}
|
||||
result := WrapTransportInternal(base)
|
||||
require.Equal(t, base, result)
|
||||
}
|
||||
|
||||
func TestNewSafeDialer_IPv6SingleIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := newSafeDialer(Policy{
|
||||
Mode: ModeEnforce,
|
||||
AllowedHosts: []string{"::1"},
|
||||
})
|
||||
|
||||
require.True(t, d.ipAllowed(net.ParseIP("::1")))
|
||||
require.False(t, d.ipAllowed(net.ParseIP("::2")))
|
||||
}
|
||||
|
||||
func TestCheckURL_InvalidURL(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://%gg")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
}
|
||||
|
||||
func TestCheckURL_EmptyHost(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestCheckURL_IPInAllowlist verifies that a literal IP address that falls
|
||||
// within an allowed CIDR range is permitted.
|
||||
func TestCheckURL_IPInAllowlist(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://8.8.8.8/path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCheckURL_WildcardHostname(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"*.example.com"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "https://api.example.com/path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestCheckURL_HostnameDNSResolvesToAllowedIP verifies that a hostname
|
||||
// resolving to an IP within the allowlist is permitted (DNS resolution path).
|
||||
func TestCheckURL_HostnameDNSResolvesToAllowedIP(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8", "::1/128"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://localhost/path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestCheckURL_HostnameDNSResolvesToBlockedIP verifies that a hostname
|
||||
// resolving to an IP outside the allowlist is blocked (DNS resolution path).
|
||||
func TestCheckURL_HostnameDNSResolvesToBlockedIP(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://localhost/path")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
}
|
||||
|
||||
// TestCheckURL_HostnameDNSAuditMode verifies that audit mode logs violations
|
||||
// from hostname DNS resolution but still returns nil.
|
||||
func TestCheckURL_HostnameDNSAuditMode(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeAudit, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
err := CheckURL(t.Context(), "http://localhost/path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestCheckURL_HostnameDNSError verifies that a DNS resolution failure is
|
||||
// propagated as an SSRF-prefixed error.
|
||||
func TestCheckURL_HostnameDNSError(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"8.8.8.0/24"}})
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
err := CheckURL(ctx, "http://portainer-nonexistent.test.invalid/path")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestDialContext_InvalidAddress verifies that an address without a port
|
||||
// returns an SSRF-prefixed error.
|
||||
func TestDialContext_InvalidAddress(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"example.com"}})
|
||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
d := globalDialer.Load()
|
||||
_, err := d.DialContext(t.Context(), "tcp", "no-port-here")
|
||||
_, err = d.DialContext(t.Context(), "tcp", "no-port-here")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
}
|
||||
@@ -295,14 +340,15 @@ func TestDialContext_InvalidAddress(t *testing.T) {
|
||||
// TestDialContext_DNSError verifies that a DNS resolution failure in
|
||||
// DialContext is propagated as an SSRF-prefixed error.
|
||||
func TestDialContext_DNSError(t *testing.T) {
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{}})
|
||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
|
||||
d := globalDialer.Load()
|
||||
_, err := d.DialContext(ctx, "tcp", "portainer-nonexistent.test.invalid:80")
|
||||
_, err = d.DialContext(ctx, "tcp", "portainer-nonexistent.test.invalid:80")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -314,7 +360,8 @@ func TestDialContext_AllowedByCIDR(t *testing.T) {
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"127.0.0.0/8"}})
|
||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"127.0.0.0/8"}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
@@ -347,7 +394,8 @@ func TestDialContext_AllowedByExactHostname(t *testing.T) {
|
||||
_, portStr, err := net.SplitHostPort(l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
Configure(Policy{Mode: ModeEnforce, AllowedHosts: []string{"localhost"}})
|
||||
err = Configure(newStaticService(portainer.SSRFModeEnforce, []string{"localhost"}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
|
||||
Reference in New Issue
Block a user