feat(ssrf): add missing transport wrappings and more checks BE-13021 (#2968)

This commit is contained in:
andres-portainer
2026-06-19 20:26:03 -03:00
committed by GitHub
parent cc45af2873
commit 26334e9088
38 changed files with 349 additions and 130 deletions
+3
View File
@@ -108,6 +108,9 @@ linters:
linters: linters:
- gocritic - gocritic
text: ruleguard text: ruleguard
- path: pkg/libhttp/ssrf/builder\.go
linters:
- forbidigo
paths: paths:
- third_party$ - third_party$
- builtin$ - builtin$
+58 -12
View File
@@ -4,26 +4,72 @@ package gorules
import "github.com/quasilyte/go-ruleguard/dsl" import "github.com/quasilyte/go-ruleguard/dsl"
// unwrappedHTTPTransport flags http.Transport composite literals that are not // unwrappedHTTPTransport flags any bare http.Transport composite literal.
// the direct argument to ssrf.WrapTransport. // All transports must be created via ssrf.NewTransport or ssrf.NewInternalTransport,
// which clone http.DefaultTransport and handle SSRF protection internally.
func unwrappedHTTPTransport(m dsl.Matcher) { func unwrappedHTTPTransport(m dsl.Matcher) {
// Inline construction passed to a function call.
m.Match(`$f(&http.Transport{$*_})`). m.Match(`$f(&http.Transport{$*_})`).
Where(m["f"].Text != "ssrf.WrapTransport" && m["f"].Text != "WrapTransport" && Report(`$f receives a bare *http.Transport; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
m["f"].Text != "ssrf.WrapTransportInternal" && m["f"].Text != "WrapTransportInternal").
Report(`$f receives a bare *http.Transport; wrap with ssrf.WrapTransport() to enforce the SSRF protection policy`)
// Variable assigned a bare transport (cannot be tracked to a later WrapTransport call).
m.Match(`$_ := &http.Transport{$*_}`). m.Match(`$_ := &http.Transport{$*_}`).
Report(`bare *http.Transport variable; use ssrf.WrapTransport(&http.Transport{...}) inline instead`) Report(`bare *http.Transport variable; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
m.Match(`$_.Transport = &http.Transport{$*_}`).
Report(`bare *http.Transport field assignment; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
} }
// internalTransportMisuse flags calls to WrapTransportInternal outside the four proxy // helmGetterTransport flags getter.WithTransport calls that receive a bare *http.Transport.
// Helm v4 installs its own transport and bypasses http.DefaultTransport, so the transport
// passed here must be created via ssrf.NewTransport.
func helmGetterTransport(m dsl.Matcher) {
m.Match(`getter.WithTransport(&http.Transport{$*_})`).
Report(`getter.WithTransport called with a bare *http.Transport; use ssrf.NewTransport(tlsConfig) as Helm v4 bypasses http.DefaultTransport`)
}
// cloneDefaultTransport flags direct clones of *http.Transport outside main.go.
// The one legitimate clone is in main.go where http.DefaultTransport is globally
// wrapped with SSRF protection at server startup.
func cloneDefaultTransport(m dsl.Matcher) {
m.Match(`$_.(*http.Transport).Clone()`).
Where(!m.File().Name.Matches(`^main\.go$`)).
Report(`cloning *http.Transport directly is forbidden; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
}
// internalTransportMisuse flags calls to NewInternalTransport outside the proxy
// factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions. // factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions.
func internalTransportMisuse(m dsl.Matcher) { func internalTransportMisuse(m dsl.Matcher) {
m.Match(`ssrf.WrapTransportInternal($*_)`). m.Match(`ssrf.NewInternalTransport($*_)`).
Where( Where(
!(m.File().PkgPath.Matches(`proxy/factory`) && !(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport)\.go$`))). m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport|docker_unix|docker_windows)\.go$`))).
Report(`WrapTransportInternal bypasses SSRF validation; only valid in the kubernetes local/edge transport constructors and the docker/agent proxy factories`) Report(`NewInternalTransport bypasses SSRF validation; only valid in the proxy factory files for local sockets and internally-routed endpoints`)
}
// dialerOverride flags direct assignments to any of the dialer fields on a transport.
// The only valid assignments are in docker_unix.go and docker_windows.go where a
// custom dialer is required for unix sockets and named pipes.
func dialerOverride(m dsl.Matcher) {
m.Match(`$_.DialContext = $*_`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
Report(`direct DialContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
m.Match(`$_.Dial = $*_`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
Report(`direct Dial assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
m.Match(`$_.DialTLSContext = $*_`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
Report(`direct DialTLSContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
m.Match(`$_.DialTLS = $*_`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
Report(`direct DialTLS assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
} }
+7 -1
View File
@@ -1,6 +1,7 @@
package agent package agent
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@@ -11,6 +12,7 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/url" "github.com/portainer/portainer/api/url"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -19,10 +21,14 @@ import (
// //
// it sends a ping to the agent and parses the version and platform from the headers // it sends a ping to the agent and parses the version and platform from the headers
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo
if err := ssrf.CheckURL(context.Background(), endpointUrl); err != nil {
return 0, "", err
}
httpCli := &http.Client{Timeout: 3 * time.Second} httpCli := &http.Client{Timeout: 3 * time.Second}
if tlsConfig != nil { if tlsConfig != nil {
httpCli.Transport = &http.Transport{TLSClientConfig: tlsConfig} httpCli.Transport = ssrf.NewTransport(tlsConfig)
} }
parsedURL, err := url.ParseURL(endpointUrl + "/ping") parsedURL, err := url.ParseURL(endpointUrl + "/ping")
+8 -2
View File
@@ -58,7 +58,10 @@ import (
libswarm "github.com/portainer/portainer/pkg/libstack/swarm" libswarm "github.com/portainer/portainer/pkg/libstack/swarm"
"github.com/portainer/portainer/pkg/validate" "github.com/portainer/portainer/pkg/validate"
gogitclient "github.com/go-git/go-git/v5/plumbing/transport/client"
gogitraw "github.com/go-git/go-git/v5/plumbing/transport/git"
gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http" gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http"
gogitssh "github.com/go-git/go-git/v5/plumbing/transport/ssh"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -408,11 +411,14 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
log.Fatal().Err(err).Msg("failed initializing ssrf service") log.Fatal().Err(err).Msg("failed initializing ssrf service")
} }
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok { if !ssrf.WrapDefaultTransport() {
nethttp.DefaultTransport = ssrf.WrapTransport(dt) log.Fatal().Msg("failed to wrap default HTTP transport with SSRF protection")
} }
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport}) gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
gogitclient.InstallProtocol("git", git.NewSSRFGitTransport(gogitraw.DefaultClient))
gogitclient.InstallProtocol("ssh", git.NewSSRFGitTransport(gogitssh.DefaultClient))
gogitclient.InstallProtocol("file", nil)
instanceID, err := dataStore.Version().InstanceID() instanceID, err := dataStore.Version().InstanceID()
if err != nil { if err != nil {
+2 -2
View File
@@ -194,11 +194,11 @@ func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Cli
} }
transport = &NodeNameTransport{ transport = &NodeNameTransport{
Transport: ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}), Transport: ssrf.NewTransport(tlsConfig),
} }
} else { } else {
transport = &NodeNameTransport{ transport = &NodeNameTransport{
Transport: ssrf.WrapTransport(&http.Transport{}), Transport: ssrf.NewTransport(nil),
} }
} }
+2 -5
View File
@@ -66,11 +66,8 @@ func NewAzureClient() *azureClient {
func newHttpClientForAzure(insecureSkipVerify bool) *http.Client { func newHttpClientForAzure(insecureSkipVerify bool) *http.Client {
return &http.Client{ return &http.Client{
Transport: ssrf.WrapTransport(&http.Transport{ Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(insecureSkipVerify)),
TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify), Timeout: 300 * time.Second,
Proxy: http.ProxyFromEnvironment,
}),
Timeout: 300 * time.Second,
} }
} }
+53
View File
@@ -0,0 +1,53 @@
package git
import (
"context"
"fmt"
"net"
"strconv"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
gittransport "github.com/go-git/go-git/v5/plumbing/transport"
)
const gitDefaultPort = 9418
// ssrfGitTransport wraps a git:// transport and validates the resolved IP
// against the SSRF policy before establishing connections.
type ssrfGitTransport struct {
inner gittransport.Transport
}
// NewSSRFGitTransport wraps inner and blocks connections to private IP ranges
// according to the active SSRF policy.
func NewSSRFGitTransport(inner gittransport.Transport) gittransport.Transport {
return &ssrfGitTransport{inner: inner}
}
func (t *ssrfGitTransport) NewUploadPackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.UploadPackSession, error) {
if err := checkEndpointSSRF(ep); err != nil {
return nil, err
}
return t.inner.NewUploadPackSession(ep, auth)
}
func (t *ssrfGitTransport) NewReceivePackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.ReceivePackSession, error) {
if err := checkEndpointSSRF(ep); err != nil {
return nil, err
}
return t.inner.NewReceivePackSession(ep, auth)
}
func checkEndpointSSRF(ep *gittransport.Endpoint) error {
port := ep.Port
if port <= 0 {
port = gitDefaultPort
}
rawURL := fmt.Sprintf("git://%s/", net.JoinHostPort(ep.Host, strconv.Itoa(port)))
return ssrf.CheckURL(context.Background(), rawURL)
}
+2 -2
View File
@@ -125,9 +125,9 @@ func ExecutePingOperation(host string, tlsConfiguration portainer.TLSConfigurati
} }
scheme = "https" scheme = "https"
transport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}) transport = ssrf.NewTransport(tlsConfig)
} else { } else {
transport = ssrf.WrapTransport(&http.Transport{}) transport = ssrf.NewTransport(nil)
} }
client := &http.Client{ client := &http.Client{
@@ -19,6 +19,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate" "github.com/portainer/portainer/pkg/validate"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@@ -315,6 +316,10 @@ func (handler *Handler) createCustomTemplateFromGitRepository(r *http.Request) (
return nil, httpErr return nil, httpErr
} }
if err := ssrf.CheckURL(r.Context(), gitConfig.URL); err != nil {
return nil, err
}
commitHash, err := stackutils.DownloadGitRepository(context.TODO(), gitConfig, handler.GitService, getProjectPath) commitHash, err := stackutils.DownloadGitRepository(context.TODO(), gitConfig, handler.GitService, getProjectPath)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -14,6 +14,7 @@ import (
"github.com/portainer/portainer/api/stacks/stackutils" "github.com/portainer/portainer/api/stacks/stackutils"
"github.com/portainer/portainer/pkg/edge" "github.com/portainer/portainer/pkg/edge"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate" "github.com/portainer/portainer/pkg/validate"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -69,7 +70,6 @@ func (payload *edgeStackFromGitRepositoryPayload) Validate(r *http.Request) erro
if len(payload.RepositoryURL) == 0 || !validate.IsURL(payload.RepositoryURL) { if len(payload.RepositoryURL) == 0 || !validate.IsURL(payload.RepositoryURL) {
return httperrors.NewInvalidPayloadError("Invalid repository URL. Must correspond to a valid URL format") return httperrors.NewInvalidPayloadError("Invalid repository URL. Must correspond to a valid URL format")
} }
if payload.RepositoryAuthentication && len(payload.RepositoryPassword) == 0 { if payload.RepositoryAuthentication && len(payload.RepositoryPassword) == 0 {
return httperrors.NewInvalidPayloadError("Invalid repository credentials. Password must be specified when authentication is enabled") return httperrors.NewInvalidPayloadError("Invalid repository credentials. Password must be specified when authentication is enabled")
} }
@@ -138,6 +138,10 @@ func (handler *Handler) createEdgeStackFromGitRepository(r *http.Request, tx dat
return nil, httpErr return nil, httpErr
} }
if err := ssrf.CheckURL(r.Context(), repoConfig.URL); err != nil {
return nil, errors.Wrap(err, "repository URL blocked by SSRF policy")
}
stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID) stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID)
stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username) stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username)
@@ -12,6 +12,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate" "github.com/portainer/portainer/pkg/validate"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -100,6 +101,10 @@ func (handler *Handler) gitOperationRepoFilePreview(w http.ResponseWriter, r *ht
tlsSkipVerify = src.Git.TLSSkipVerify tlsSkipVerify = src.Git.TLSSkipVerify
} }
if err := ssrf.CheckURL(r.Context(), repoURL); err != nil {
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
}
projectPath, err := handler.fileService.GetTemporaryPath() projectPath, err := handler.fileService.GetTemporaryPath()
if err != nil { if err != nil {
return httperror.InternalServerError("Unable to create temporary folder", err) return httperror.InternalServerError("Unable to create temporary folder", err)
@@ -12,6 +12,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/validate"
) )
// GitAuthenticationPayload holds authentication parameters for a git source // GitAuthenticationPayload holds authentication parameters for a git source
@@ -30,8 +31,8 @@ type GitSourceCreatePayload struct {
// Validate implements the portainer.Validatable interface // Validate implements the portainer.Validatable interface
func (payload *GitSourceCreatePayload) Validate(_ *http.Request) error { func (payload *GitSourceCreatePayload) Validate(_ *http.Request) error {
if strings.TrimSpace(payload.URL) == "" { if !validate.IsURL(payload.URL) {
return errors.New("url is required") return errors.New("invalid repository URL. Must correspond to a valid URL format")
} }
return nil return nil
@@ -12,6 +12,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/validate"
) )
var ( var (
@@ -35,6 +36,10 @@ type GitAuthenticationUpdatePayload struct {
// Validate implements the portainer.Validatable interface // Validate implements the portainer.Validatable interface
func (payload *GitSourceUpdatePayload) Validate(_ *http.Request) error { func (payload *GitSourceUpdatePayload) Validate(_ *http.Request) error {
if payload.URL != nil && !validate.IsURL(*payload.URL) {
return errors.New("invalid repository URL. Must correspond to a valid URL format")
}
return nil return nil
} }
+7 -1
View File
@@ -8,6 +8,7 @@ import (
"github.com/portainer/portainer/pkg/libhelm/options" "github.com/portainer/portainer/pkg/libhelm/options"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -45,6 +46,10 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) *
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo))) return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
} }
if err := ssrf.CheckURL(r.Context(), repo); err != nil {
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
}
searchOpts := options.SearchRepoOptions{ searchOpts := options.SearchRepoOptions{
Repo: repo, Repo: repo,
Chart: chart, Chart: chart,
@@ -53,7 +58,8 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) *
result, err := handler.helmPackageManager.SearchRepo(searchOpts) result, err := handler.helmPackageManager.SearchRepo(searchOpts)
if err != nil { if err != nil {
return httperror.InternalServerError("Search failed", err) log.Warn().Err(err).Str("repo", repo).Msg("helm repo search failed")
return httperror.InternalServerError("Search failed", errors.New("failed to search Helm repository"))
} }
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
+7 -1
View File
@@ -8,6 +8,7 @@ import (
"github.com/portainer/portainer/pkg/libhelm/options" "github.com/portainer/portainer/pkg/libhelm/options"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@@ -41,6 +42,10 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo))) return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
} }
if err := ssrf.CheckURL(r.Context(), repo); err != nil {
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
}
chart := r.URL.Query().Get("chart") chart := r.URL.Query().Get("chart")
if chart == "" { if chart == "" {
return httperror.BadRequest("Bad request", errors.New("missing `chart` query parameter")) return httperror.BadRequest("Bad request", errors.New("missing `chart` query parameter"))
@@ -65,7 +70,8 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper
} }
result, err := handler.helmPackageManager.Show(showOptions) result, err := handler.helmPackageManager.Show(showOptions)
if err != nil { if err != nil {
return httperror.InternalServerError("Unable to show chart", err) log.Warn().Err(err).Str("repo", repo).Str("chart", chart).Msg("helm show failed")
return httperror.InternalServerError("Unable to show chart", errors.New("failed to retrieve Helm chart information"))
} }
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
@@ -14,6 +14,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate" "github.com/portainer/portainer/pkg/validate"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -74,6 +75,12 @@ func (payload *settingsUpdatePayload) Validate(r *http.Request) error {
return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format") return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format")
} }
if payload.HelmRepositoryURL != nil && *payload.HelmRepositoryURL != "" {
if err := ssrf.CheckURL(r.Context(), *payload.HelmRepositoryURL); err != nil {
return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format")
}
}
if payload.UserSessionTimeout != nil { if payload.UserSessionTimeout != nil {
if _, err := time.ParseDuration(*payload.UserSessionTimeout); err != nil { if _, err := time.ParseDuration(*payload.UserSessionTimeout); err != nil {
return errors.New("Invalid user session timeout") return errors.New("Invalid user session timeout")
@@ -14,6 +14,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate" "github.com/portainer/portainer/pkg/validate"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -125,6 +126,10 @@ func (payload *kubernetesManifestURLDeploymentPayload) Validate(r *http.Request)
return errors.New("Invalid manifest URL") return errors.New("Invalid manifest URL")
} }
if err := ssrf.CheckURL(r.Context(), payload.ManifestURL); err != nil {
return err
}
return nil return nil
} }
+6 -1
View File
@@ -11,6 +11,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -24,7 +25,11 @@ type addHelmRepoUrlPayload struct {
URL string `json:"url"` URL string `json:"url"`
} }
func (p *addHelmRepoUrlPayload) Validate(_ *http.Request) error { func (p *addHelmRepoUrlPayload) Validate(r *http.Request) error {
if err := ssrf.CheckURL(r.Context(), p.URL); err != nil {
return err
}
return libhelm.ValidateHelmRepositoryURL(p.URL, nil) return libhelm.ValidateHelmRepositoryURL(p.URL, nil)
} }
+4 -4
View File
@@ -52,14 +52,14 @@ func (factory *ProxyFactory) NewAgentProxy(endpoint *portainer.Endpoint) (*Proxy
endpointURL.Scheme = "https" endpointURL.Scheme = "https"
if endpointutils.IsEdgeEndpoint(endpoint) { if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig}) innerTransport = ssrf.NewInternalTransport(tlsConfig)
} else { } else {
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}) innerTransport = ssrf.NewTransport(tlsConfig)
} }
} else if endpointutils.IsEdgeEndpoint(endpoint) { } else if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{}) innerTransport = ssrf.NewInternalTransport(nil)
} else { } else {
innerTransport = ssrf.WrapTransport(&http.Transport{}) innerTransport = ssrf.NewTransport(nil)
} }
proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL) proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL)
+4 -4
View File
@@ -68,14 +68,14 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h
endpointURL.Scheme = "https" endpointURL.Scheme = "https"
if endpointutils.IsEdgeEndpoint(endpoint) { if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig}) innerTransport = ssrf.NewInternalTransport(tlsConfig)
} else { } else {
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}) innerTransport = ssrf.NewTransport(tlsConfig)
} }
} else if endpointutils.IsEdgeEndpoint(endpoint) { } else if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{}) innerTransport = ssrf.NewInternalTransport(nil)
} else { } else {
innerTransport = ssrf.WrapTransport(&http.Transport{}) innerTransport = ssrf.NewTransport(nil)
} }
dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService) dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService)
@@ -22,6 +22,7 @@ import (
"github.com/portainer/portainer/api/internal/authorization" "github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/logs" "github.com/portainer/portainer/api/logs"
"github.com/portainer/portainer/api/slicesx" "github.com/portainer/portainer/api/slicesx"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/docker/docker/api/types/network" "github.com/docker/docker/api/types/network"
"github.com/docker/docker/api/types/swarm" "github.com/docker/docker/api/types/swarm"
@@ -506,6 +507,11 @@ func (transport *Transport) updateDefaultGitBranch(request *http.Request) error
} }
repositoryURL := remote[:len(remote)-4] repositoryURL := remote[:len(remote)-4]
if err := ssrf.CheckURL(request.Context(), repositoryURL); err != nil {
return err
}
latestCommitID, err := transport.gitService.LatestCommitID( latestCommitID, err := transport.gitService.LatestCommitID(
request.Context(), request.Context(),
repositoryURL, repositoryURL,
+8 -4
View File
@@ -3,11 +3,13 @@
package factory package factory
import ( import (
"context"
"net" "net"
"net/http" "net/http"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/proxy/factory/docker" "github.com/portainer/portainer/api/http/proxy/factory/docker"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
) )
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) { func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
@@ -31,9 +33,11 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine
} }
func newSocketTransport(socketPath string) *http.Transport { func newSocketTransport(socketPath string) *http.Transport {
return &http.Transport{ d := &net.Dialer{}
Dial: func(proto, addr string) (conn net.Conn, err error) { t := ssrf.NewInternalTransport(nil)
return net.Dial("unix", socketPath) t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
}, return d.DialContext(ctx, "unix", socketPath)
} }
return t
} }
+7 -4
View File
@@ -3,12 +3,14 @@
package factory package factory
import ( import (
"context"
"net" "net"
"net/http" "net/http"
"github.com/Microsoft/go-winio" "github.com/Microsoft/go-winio"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/proxy/factory/docker" "github.com/portainer/portainer/api/http/proxy/factory/docker"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
) )
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) { func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
@@ -32,9 +34,10 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine
} }
func newNamedPipeTransport(namedPipePath string) *http.Transport { func newNamedPipeTransport(namedPipePath string) *http.Transport {
return &http.Transport{ t := ssrf.NewInternalTransport(nil)
Dial: func(proto, addr string) (conn net.Conn, err error) { t.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
return winio.DialPipe(namedPipePath, nil) return winio.DialPipe(namedPipePath, nil)
},
} }
return t
} }
+1 -1
View File
@@ -94,7 +94,7 @@ func NewHTTPClient(token string) *http.Client {
return &http.Client{ return &http.Client{
Transport: &tokenTransport{ Transport: &tokenTransport{
token: token, token: token,
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling
}, },
Timeout: 1 * time.Minute, Timeout: 1 * time.Minute,
} }
+2 -2
View File
@@ -94,7 +94,7 @@ type Transport struct {
// interface for proxying requests to the Gitlab API. // interface for proxying requests to the Gitlab API.
func NewTransport() *Transport { func NewTransport() *Transport {
return &Transport{ return &Transport{
httpTransport: ssrf.WrapTransport(&http.Transport{}), httpTransport: ssrf.NewTransport(nil),
} }
} }
@@ -119,7 +119,7 @@ func NewHTTPClient(token string) *http.Client {
return &http.Client{ return &http.Client{
Transport: &tokenTransport{ Transport: &tokenTransport{
token: token, token: token,
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling
}, },
Timeout: 1 * time.Minute, Timeout: 1 * time.Minute,
} }
@@ -25,9 +25,7 @@ func NewAgentTransport(signatureService portainer.DigitalSignatureService, token
transport := &agentTransport{ transport := &agentTransport{
baseTransport: newBaseTransport( baseTransport: newBaseTransport(
ssrf.WrapTransport(&http.Transport{ ssrf.NewTransport(tlsConfig),
TLSClientConfig: tlsConfig,
}),
tokenManager, tokenManager,
endpoint, endpoint,
k8sClientFactory, k8sClientFactory,
@@ -22,7 +22,7 @@ func NewEdgeTransport(dataStore dataservices.DataStore, signatureService portain
reverseTunnelService: reverseTunnelService, reverseTunnelService: reverseTunnelService,
signatureService: signatureService, signatureService: signatureService,
baseTransport: newBaseTransport( baseTransport: newBaseTransport(
ssrf.WrapTransportInternal(&http.Transport{}), ssrf.NewInternalTransport(nil),
tokenManager, tokenManager,
endpoint, endpoint,
k8sClientFactory, k8sClientFactory,
@@ -23,9 +23,7 @@ func NewLocalTransport(tokenManager *tokenManager, endpoint *portainer.Endpoint,
transport := &localTransport{ transport := &localTransport{
baseTransport: newBaseTransport( baseTransport: newBaseTransport(
ssrf.WrapTransportInternal(&http.Transport{ ssrf.NewInternalTransport(config),
TLSClientConfig: config,
}),
tokenManager, tokenManager,
endpoint, endpoint,
k8sClientFactory, k8sClientFactory,
+20 -10
View File
@@ -2,6 +2,8 @@ package factory
import ( import (
"context" "context"
"net/http"
"net/http/httptest"
"net/http/httputil" "net/http/httputil"
"testing" "testing"
"time" "time"
@@ -65,8 +67,6 @@ func enableSSRF(t *testing.T) {
}) })
} }
// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint
// uses WrapTransport, setting DialContext on the inner transport.
func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) { func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
enableSSRF(t) enableSSRF(t)
@@ -84,8 +84,6 @@ func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
require.NotNil(t, dt.HTTPTransport.DialContext) require.NotNil(t, dt.HTTPTransport.DialContext)
} }
// TestNewDockerHTTPProxy_NonEdgeTLS verifies that a TLS non-edge endpoint
// uses WrapTransport, setting DialContext on the inner transport.
func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) { func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
enableSSRF(t) enableSSRF(t)
@@ -107,11 +105,14 @@ func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
require.NotNil(t, dt.HTTPTransport.DialContext) require.NotNil(t, dt.HTTPTransport.DialContext)
} }
// TestNewDockerHTTPProxy_EdgeNoTLS verifies that an edge endpoint without TLS
// uses WrapTransportInternal, leaving DialContext nil.
func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) { func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
enableSSRF(t) enableSSRF(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{ endpoint := &portainer.Endpoint{
Type: portainer.EdgeAgentOnDockerEnvironment, Type: portainer.EdgeAgentOnDockerEnvironment,
@@ -123,14 +124,20 @@ func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
proxy := handler.(*httputil.ReverseProxy) proxy := handler.(*httputil.ReverseProxy)
dt := proxy.Transport.(*docker.Transport) dt := proxy.Transport.(*docker.Transport)
require.Nil(t, dt.HTTPTransport.DialContext)
resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
} }
// TestNewDockerHTTPProxy_EdgeTLS verifies that an edge endpoint with TLS
// uses WrapTransportInternal, leaving DialContext nil.
func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) { func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
enableSSRF(t) enableSSRF(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{ endpoint := &portainer.Endpoint{
Type: portainer.EdgeAgentOnDockerEnvironment, Type: portainer.EdgeAgentOnDockerEnvironment,
@@ -146,7 +153,10 @@ func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
proxy := handler.(*httputil.ReverseProxy) proxy := handler.(*httputil.ReverseProxy)
dt := proxy.Transport.(*docker.Transport) dt := proxy.Transport.(*docker.Transport)
require.Nil(t, dt.HTTPTransport.DialContext)
resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
} }
func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) { func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) {
@@ -13,6 +13,7 @@ import (
"github.com/portainer/portainer/api/scheduler" "github.com/portainer/portainer/api/scheduler"
"github.com/portainer/portainer/api/stacks/deployments" "github.com/portainer/portainer/api/stacks/deployments"
"github.com/portainer/portainer/api/stacks/stackutils" "github.com/portainer/portainer/api/stacks/stackutils"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
) )
type GitMethodStackBuilder struct { type GitMethodStackBuilder struct {
@@ -78,6 +79,10 @@ func (b *GitMethodStackBuilder) prepare(ctx context.Context, payload *StackPaylo
return b.fileService.GetStackProjectPath(stackFolder) return b.fileService.GetStackProjectPath(stackFolder)
} }
if err := ssrf.CheckURL(ctx, repoConfig.URL); err != nil {
return fmt.Errorf("repository URL blocked by SSRF policy: %w", err)
}
commitHash, err := stackutils.DownloadGitRepository(ctx, repoConfig, b.gitService, getProjectPath) commitHash, err := stackutils.DownloadGitRepository(ctx, repoConfig, b.gitService, getProjectPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to download git repository: %w", err) return fmt.Errorf("failed to download git repository: %w", err)
+4 -1
View File
@@ -14,6 +14,7 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/logs" "github.com/portainer/portainer/api/logs"
"github.com/portainer/portainer/pkg/libhelm/options" "github.com/portainer/portainer/pkg/libhelm/options"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/liboras" "github.com/portainer/portainer/pkg/liboras"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json" "github.com/segmentio/encoding/json"
@@ -216,13 +217,15 @@ func downloadRepoIndexFromHttpRepo(repoURLString string, repoSettings *cli.EnvSe
Str("repo_name", repoName). Str("repo_name", repoName).
Msg("Creating chart repository object") Msg("Creating chart repository object")
ssrfTransport := ssrf.NewTransport(nil)
// Create chart repository object // Create chart repository object
rep, err := repo.NewChartRepository( rep, err := repo.NewChartRepository(
&repo.Entry{ &repo.Entry{
Name: repoName, Name: repoName,
URL: repoURLString, URL: repoURLString,
}, },
getter.All(repoSettings), getter.All(repoSettings, getter.WithTransport(ssrfTransport)),
) )
if err != nil { if err != nil {
log.Error(). log.Error().
+5 -2
View File
@@ -8,6 +8,7 @@ import (
"strings" "strings"
"github.com/portainer/portainer/pkg/libhelm/sdk" "github.com/portainer/portainer/pkg/libhelm/sdk"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"helm.sh/helm/v4/pkg/cli" "helm.sh/helm/v4/pkg/cli"
"helm.sh/helm/v4/pkg/getter" "helm.sh/helm/v4/pkg/getter"
repo "helm.sh/helm/v4/pkg/repo/v1" repo "helm.sh/helm/v4/pkg/repo/v1"
@@ -40,12 +41,14 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error {
return fmt.Errorf("failed to derive repo name: %w", err) return fmt.Errorf("failed to derive repo name: %w", err)
} }
ssrfTransport := ssrf.NewTransport(nil)
r, err := repo.NewChartRepository( r, err := repo.NewChartRepository(
&repo.Entry{ &repo.Entry{
Name: repoName, Name: repoName,
URL: repoUrl, URL: repoUrl,
}, },
getter.All(settings), getter.All(settings, getter.WithTransport(ssrfTransport)),
) )
if err != nil { if err != nil {
return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err) return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err)
@@ -53,7 +56,7 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error {
indexPath, err := r.DownloadIndexFile() indexPath, err := r.DownloadIndexFile()
if err != nil { if err != nil {
return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err) return fmt.Errorf("%s is not a valid chart repository or cannot be reached", repoUrl)
} }
// Best-effort: load and seed in-memory cache for future SearchRepo calls // Best-effort: load and seed in-memory cache for future SearchRepo calls
+53
View File
@@ -0,0 +1,53 @@
package ssrf
import (
"crypto/tls"
"net/http"
)
// NewTransport creates an SSRF-protected transport for user-influenced destinations.
// It clones http.DefaultTransport as its base (inheriting pool and timeout defaults)
// and applies the global SSRF dial-context filter so mode changes take effect without
// restarting. tlsConfig may be nil to preserve standard TLS behavior (system CAs).
func NewTransport(tlsConfig *tls.Config) *http.Transport {
base := http.DefaultTransport.(*http.Transport).Clone()
base.TLSClientConfig = tlsConfig
applySSRF(base)
return base
}
// NewInternalTransport creates a plain transport for destinations chosen by Portainer,
// not by the user (Docker socket proxy, Chisel tunnels, in-cluster Kubernetes API).
// It clones http.DefaultTransport as its base. tlsConfig may be nil.
// Using this function instead of NewTransport makes the exemption explicit and
// satisfies the ruleguard lint rule.
func NewInternalTransport(tlsConfig *tls.Config) *http.Transport {
base := http.DefaultTransport.(*http.Transport).Clone()
base.TLSClientConfig = tlsConfig
return base
}
// WrapDefaultTransport replaces http.DefaultTransport with an SSRF-protected version.
// Must be called after Configure. Returns false if DefaultTransport is not an *http.Transport.
func WrapDefaultTransport() bool {
dt, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return false
}
cloned := dt.Clone()
applySSRF(cloned)
http.DefaultTransport = cloned
return true
}
// applySSRF sets the SSRF-filtering DialContext on t when the global dialer is active.
func applySSRF(t *http.Transport) {
d := globalDialer.Load()
if d != nil {
t.DialContext = d.DialContext
}
}
-26
View File
@@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -112,31 +111,6 @@ func CheckURL(ctx context.Context, rawURL string) error {
return d.checkHost(ctx, host) return d.checkHost(ctx, host)
} }
// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering
// dialer. The dialer checks the mode on every connection, so the transport is always
// wrapped and mode changes take effect without restarting.
func WrapTransport(t *http.Transport) *http.Transport {
d := globalDialer.Load()
if d == nil {
return t
}
cloned := t.Clone()
cloned.DialContext = d.DialContext
return cloned
}
// WrapTransportInternal is a documented no-op for transports that connect to
// internally computed destinations (local Docker socket proxy, Chisel tunnels,
// in-cluster Kubernetes API). The destination is chosen by Portainer, not
// supplied by any user, so SSRF validation is not applicable. Using this
// function instead of WrapTransport makes the exemption explicit and
// satisfies the ruleguard lint rule.
func WrapTransportInternal(t *http.Transport) *http.Transport {
return t
}
// DialContext resolves addr, validates all resolved IPs against the allowlist policy, // DialContext resolves addr, validates all resolved IPs against the allowlist policy,
// then dials using the first resolved IP to prevent DNS rebinding attacks. // then dials using the first resolved IP to prevent DNS rebinding attacks.
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+22 -17
View File
@@ -116,22 +116,27 @@ func TestConfigure_NilServicesReturnsError(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestWrapTransport_NoPolicy(t *testing.T) { func TestNewTransport_NoPolicy(t *testing.T) {
globalDialer.Store(nil) globalDialer.Store(nil)
t.Cleanup(func() { globalDialer.Store(nil) })
base := &http.Transport{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
result := WrapTransport(base) w.WriteHeader(http.StatusOK)
require.Equal(t, base, result) }))
defer srv.Close()
client := &http.Client{Transport: NewTransport(nil)}
resp, err := client.Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
} }
func TestWrapTransport_WithPolicy(t *testing.T) { func TestNewTransport_WithPolicy(t *testing.T) {
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"})) err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) }) t.Cleanup(func() { globalDialer.Store(nil) })
base := &http.Transport{} result := NewTransport(nil)
result := WrapTransport(base)
require.NotEqual(t, base, result)
require.NotNil(t, result.DialContext) require.NotNil(t, result.DialContext)
} }
@@ -267,12 +272,12 @@ func TestIsEnabled(t *testing.T) {
require.False(t, IsEnabled()) require.False(t, IsEnabled())
} }
func TestWrapTransportInternal(t *testing.T) { func TestNewInternalTransport(t *testing.T) {
t.Parallel() t.Parallel()
base := &http.Transport{} result := NewInternalTransport(nil)
result := WrapTransportInternal(base) require.NotNil(t, result)
require.Equal(t, base, result) require.Nil(t, result.TLSClientConfig)
} }
// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP // TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP
@@ -288,7 +293,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) }) t.Cleanup(func() { globalDialer.Store(nil) })
blocked := &http.Client{Transport: WrapTransport(&http.Transport{})} blocked := &http.Client{Transport: NewTransport(nil)}
resp, err := blocked.Get(srv.URL) resp, err := blocked.Get(srv.URL)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "ssrf") require.Contains(t, err.Error(), "ssrf")
@@ -300,7 +305,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
err = Configure(newStaticService(portainer.SSRFModeOff, nil)) err = Configure(newStaticService(portainer.SSRFModeOff, nil))
require.NoError(t, err) require.NoError(t, err)
open := &http.Client{Transport: WrapTransport(&http.Transport{})} open := &http.Client{Transport: NewTransport(nil)}
resp, err = open.Get(srv.URL) resp, err = open.Get(srv.URL)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, resp.Body.Close()) require.NoError(t, resp.Body.Close())
@@ -318,7 +323,7 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) }) t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})} client := &http.Client{Transport: NewTransport(nil)}
resp, err := client.Get(srv.URL) resp, err := client.Get(srv.URL)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, resp.Body.Close()) require.NoError(t, resp.Body.Close())
@@ -364,7 +369,7 @@ func TestDialContext_AllowedByCIDR(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) }) t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})} client := &http.Client{Transport: NewTransport(nil)}
resp, err := client.Get(srv.URL) resp, err := client.Get(srv.URL)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, resp.Body.Close()) require.NoError(t, resp.Body.Close())
@@ -398,7 +403,7 @@ func TestDialContext_AllowedByExactHostname(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) }) t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})} client := &http.Client{Transport: NewTransport(nil)}
resp, err := client.Get("http://localhost:" + portStr) resp, err := client.Get("http://localhost:" + portStr)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, resp.Body.Close()) require.NoError(t, resp.Body.Close())
+3 -4
View File
@@ -10,6 +10,7 @@ import (
"github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/logs" "github.com/portainer/portainer/api/logs"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json" "github.com/segmentio/encoding/json"
@@ -73,10 +74,8 @@ func ProbeTelnetConnection(url string) string {
// ignores errors for the http request since we want to know if the host is reachable // ignores errors for the http request since we want to know if the host is reachable
func DetectProxy(url string) string { func DetectProxy(url string) string {
client := &http.Client{ client := &http.Client{
Transport: &http.Transport{ Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(true)),
TLSClientConfig: crypto.CreateTLSConfiguration(true), Timeout: 10 * time.Second,
},
Timeout: 10 * time.Second,
} }
result := map[string]string{ result := map[string]string{
+1 -2
View File
@@ -16,8 +16,7 @@ import (
func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) { func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) {
switch registry.Type { switch registry.Type {
case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry: case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry:
base := http.DefaultTransport.(*http.Transport).Clone() return &http.Client{Transport: retry.NewTransport(ssrf.NewTransport(nil))}, false, nil
return &http.Client{Transport: retry.NewTransport(ssrf.WrapTransport(base))}, false, nil
default: default:
// For all other registry types, use shared helper to build transport and scheme // For all other registry types, use shared helper to build transport and scheme
+11 -12
View File
@@ -6,28 +6,27 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/pkg/fips" "github.com/portainer/portainer/pkg/fips"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
) )
// BuildTransportAndSchemeFromTLSConfig returns a base HTTP transport configured // BuildTransportAndSchemeFromTLSConfig returns an SSRF-protected HTTP transport and the
// with ProxyFromEnvironment and, when needed, a TLSClientConfig derived from the // scheme ("http" or "https") to use when contacting the registry. The transport is based on
// provided TLS settings. It also returns the scheme ("http" or "https") that // the TLS settings from tlsCfg; pass a zero-value TLSConfiguration for plaintext.
// should be used to contact the registry based on the TLS settings.
func BuildTransportAndSchemeFromTLSConfig(tlsCfg portainer.TLSConfiguration) (*http.Transport, string, error) { func BuildTransportAndSchemeFromTLSConfig(tlsCfg portainer.TLSConfiguration) (*http.Transport, string, error) {
baseTransport := http.DefaultTransport.(*http.Transport).Clone()
baseTransport.Proxy = http.ProxyFromEnvironment
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsCfg) tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsCfg)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
baseTransport.TLSClientConfig = tlsConfig
if tlsConfig == nil && fips.FIPSMode() { if tlsConfig == nil && fips.FIPSMode() {
return nil, "", fips.ErrTLSRequired return nil, "", fips.ErrTLSRequired
} else if tlsConfig == nil {
return baseTransport, "http", nil
} }
return baseTransport, "https", nil transport := ssrf.NewTransport(tlsConfig)
if tlsConfig == nil {
return transport, "http", nil
}
return transport, "https", nil
} }