feat(ssrf): add missing transport wrappings and more checks BE-13021 (#2968)

This commit is contained in:
andres-portainer
2026-06-19 20:26:03 -03:00
committed by GitHub
parent cc45af2873
commit 26334e9088
38 changed files with 349 additions and 130 deletions
+3
View File
@@ -108,6 +108,9 @@ linters:
linters:
- gocritic
text: ruleguard
- path: pkg/libhttp/ssrf/builder\.go
linters:
- forbidigo
paths:
- third_party$
- builtin$
+58 -12
View File
@@ -4,26 +4,72 @@ package gorules
import "github.com/quasilyte/go-ruleguard/dsl"
// unwrappedHTTPTransport flags http.Transport composite literals that are not
// the direct argument to ssrf.WrapTransport.
// unwrappedHTTPTransport flags any bare http.Transport composite literal.
// All transports must be created via ssrf.NewTransport or ssrf.NewInternalTransport,
// which clone http.DefaultTransport and handle SSRF protection internally.
func unwrappedHTTPTransport(m dsl.Matcher) {
// Inline construction passed to a function call.
m.Match(`$f(&http.Transport{$*_})`).
Where(m["f"].Text != "ssrf.WrapTransport" && m["f"].Text != "WrapTransport" &&
m["f"].Text != "ssrf.WrapTransportInternal" && m["f"].Text != "WrapTransportInternal").
Report(`$f receives a bare *http.Transport; wrap with ssrf.WrapTransport() to enforce the SSRF protection policy`)
Report(`$f receives a bare *http.Transport; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
// Variable assigned a bare transport (cannot be tracked to a later WrapTransport call).
m.Match(`$_ := &http.Transport{$*_}`).
Report(`bare *http.Transport variable; use ssrf.WrapTransport(&http.Transport{...}) inline instead`)
Report(`bare *http.Transport variable; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
m.Match(`$_.Transport = &http.Transport{$*_}`).
Report(`bare *http.Transport field assignment; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
}
// internalTransportMisuse flags calls to WrapTransportInternal outside the four proxy
// helmGetterTransport flags getter.WithTransport calls that receive a bare *http.Transport.
// Helm v4 installs its own transport and bypasses http.DefaultTransport, so the transport
// passed here must be created via ssrf.NewTransport.
func helmGetterTransport(m dsl.Matcher) {
m.Match(`getter.WithTransport(&http.Transport{$*_})`).
Report(`getter.WithTransport called with a bare *http.Transport; use ssrf.NewTransport(tlsConfig) as Helm v4 bypasses http.DefaultTransport`)
}
// cloneDefaultTransport flags direct clones of *http.Transport outside main.go.
// The one legitimate clone is in main.go where http.DefaultTransport is globally
// wrapped with SSRF protection at server startup.
func cloneDefaultTransport(m dsl.Matcher) {
m.Match(`$_.(*http.Transport).Clone()`).
Where(!m.File().Name.Matches(`^main\.go$`)).
Report(`cloning *http.Transport directly is forbidden; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
}
// internalTransportMisuse flags calls to NewInternalTransport outside the proxy
// factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions.
func internalTransportMisuse(m dsl.Matcher) {
m.Match(`ssrf.WrapTransportInternal($*_)`).
m.Match(`ssrf.NewInternalTransport($*_)`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport)\.go$`))).
Report(`WrapTransportInternal bypasses SSRF validation; only valid in the kubernetes local/edge transport constructors and the docker/agent proxy factories`)
m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport|docker_unix|docker_windows)\.go$`))).
Report(`NewInternalTransport bypasses SSRF validation; only valid in the proxy factory files for local sockets and internally-routed endpoints`)
}
// dialerOverride flags direct assignments to any of the dialer fields on a transport.
// The only valid assignments are in docker_unix.go and docker_windows.go where a
// custom dialer is required for unix sockets and named pipes.
func dialerOverride(m dsl.Matcher) {
m.Match(`$_.DialContext = $*_`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
Report(`direct DialContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
m.Match(`$_.Dial = $*_`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
Report(`direct Dial assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
m.Match(`$_.DialTLSContext = $*_`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
Report(`direct DialTLSContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
m.Match(`$_.DialTLS = $*_`).
Where(
!(m.File().PkgPath.Matches(`proxy/factory`) &&
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
Report(`direct DialTLS assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
}
+7 -1
View File
@@ -1,6 +1,7 @@
package agent
import (
"context"
"crypto/tls"
"errors"
"fmt"
@@ -11,6 +12,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/url"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log"
)
@@ -19,10 +21,14 @@ import (
//
// it sends a ping to the agent and parses the version and platform from the headers
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo
if err := ssrf.CheckURL(context.Background(), endpointUrl); err != nil {
return 0, "", err
}
httpCli := &http.Client{Timeout: 3 * time.Second}
if tlsConfig != nil {
httpCli.Transport = &http.Transport{TLSClientConfig: tlsConfig}
httpCli.Transport = ssrf.NewTransport(tlsConfig)
}
parsedURL, err := url.ParseURL(endpointUrl + "/ping")
+8 -2
View File
@@ -58,7 +58,10 @@ import (
libswarm "github.com/portainer/portainer/pkg/libstack/swarm"
"github.com/portainer/portainer/pkg/validate"
gogitclient "github.com/go-git/go-git/v5/plumbing/transport/client"
gogitraw "github.com/go-git/go-git/v5/plumbing/transport/git"
gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http"
gogitssh "github.com/go-git/go-git/v5/plumbing/transport/ssh"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
@@ -408,11 +411,14 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
log.Fatal().Err(err).Msg("failed initializing ssrf service")
}
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok {
nethttp.DefaultTransport = ssrf.WrapTransport(dt)
if !ssrf.WrapDefaultTransport() {
log.Fatal().Msg("failed to wrap default HTTP transport with SSRF protection")
}
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
gogitclient.InstallProtocol("git", git.NewSSRFGitTransport(gogitraw.DefaultClient))
gogitclient.InstallProtocol("ssh", git.NewSSRFGitTransport(gogitssh.DefaultClient))
gogitclient.InstallProtocol("file", nil)
instanceID, err := dataStore.Version().InstanceID()
if err != nil {
+2 -2
View File
@@ -194,11 +194,11 @@ func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Cli
}
transport = &NodeNameTransport{
Transport: ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}),
Transport: ssrf.NewTransport(tlsConfig),
}
} else {
transport = &NodeNameTransport{
Transport: ssrf.WrapTransport(&http.Transport{}),
Transport: ssrf.NewTransport(nil),
}
}
+2 -5
View File
@@ -66,11 +66,8 @@ func NewAzureClient() *azureClient {
func newHttpClientForAzure(insecureSkipVerify bool) *http.Client {
return &http.Client{
Transport: ssrf.WrapTransport(&http.Transport{
TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify),
Proxy: http.ProxyFromEnvironment,
}),
Timeout: 300 * time.Second,
Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(insecureSkipVerify)),
Timeout: 300 * time.Second,
}
}
+53
View File
@@ -0,0 +1,53 @@
package git
import (
"context"
"fmt"
"net"
"strconv"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
gittransport "github.com/go-git/go-git/v5/plumbing/transport"
)
const gitDefaultPort = 9418
// ssrfGitTransport wraps a git:// transport and validates the resolved IP
// against the SSRF policy before establishing connections.
type ssrfGitTransport struct {
inner gittransport.Transport
}
// NewSSRFGitTransport wraps inner and blocks connections to private IP ranges
// according to the active SSRF policy.
func NewSSRFGitTransport(inner gittransport.Transport) gittransport.Transport {
return &ssrfGitTransport{inner: inner}
}
func (t *ssrfGitTransport) NewUploadPackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.UploadPackSession, error) {
if err := checkEndpointSSRF(ep); err != nil {
return nil, err
}
return t.inner.NewUploadPackSession(ep, auth)
}
func (t *ssrfGitTransport) NewReceivePackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.ReceivePackSession, error) {
if err := checkEndpointSSRF(ep); err != nil {
return nil, err
}
return t.inner.NewReceivePackSession(ep, auth)
}
func checkEndpointSSRF(ep *gittransport.Endpoint) error {
port := ep.Port
if port <= 0 {
port = gitDefaultPort
}
rawURL := fmt.Sprintf("git://%s/", net.JoinHostPort(ep.Host, strconv.Itoa(port)))
return ssrf.CheckURL(context.Background(), rawURL)
}
+2 -2
View File
@@ -125,9 +125,9 @@ func ExecutePingOperation(host string, tlsConfiguration portainer.TLSConfigurati
}
scheme = "https"
transport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
transport = ssrf.NewTransport(tlsConfig)
} else {
transport = ssrf.WrapTransport(&http.Transport{})
transport = ssrf.NewTransport(nil)
}
client := &http.Client{
@@ -19,6 +19,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate"
"github.com/rs/zerolog/log"
@@ -315,6 +316,10 @@ func (handler *Handler) createCustomTemplateFromGitRepository(r *http.Request) (
return nil, httpErr
}
if err := ssrf.CheckURL(r.Context(), gitConfig.URL); err != nil {
return nil, err
}
commitHash, err := stackutils.DownloadGitRepository(context.TODO(), gitConfig, handler.GitService, getProjectPath)
if err != nil {
return nil, err
@@ -14,6 +14,7 @@ import (
"github.com/portainer/portainer/api/stacks/stackutils"
"github.com/portainer/portainer/pkg/edge"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate"
"github.com/pkg/errors"
@@ -69,7 +70,6 @@ func (payload *edgeStackFromGitRepositoryPayload) Validate(r *http.Request) erro
if len(payload.RepositoryURL) == 0 || !validate.IsURL(payload.RepositoryURL) {
return httperrors.NewInvalidPayloadError("Invalid repository URL. Must correspond to a valid URL format")
}
if payload.RepositoryAuthentication && len(payload.RepositoryPassword) == 0 {
return httperrors.NewInvalidPayloadError("Invalid repository credentials. Password must be specified when authentication is enabled")
}
@@ -138,6 +138,10 @@ func (handler *Handler) createEdgeStackFromGitRepository(r *http.Request, tx dat
return nil, httpErr
}
if err := ssrf.CheckURL(r.Context(), repoConfig.URL); err != nil {
return nil, errors.Wrap(err, "repository URL blocked by SSRF policy")
}
stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID)
stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username)
@@ -12,6 +12,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate"
"github.com/rs/zerolog/log"
)
@@ -100,6 +101,10 @@ func (handler *Handler) gitOperationRepoFilePreview(w http.ResponseWriter, r *ht
tlsSkipVerify = src.Git.TLSSkipVerify
}
if err := ssrf.CheckURL(r.Context(), repoURL); err != nil {
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
}
projectPath, err := handler.fileService.GetTemporaryPath()
if err != nil {
return httperror.InternalServerError("Unable to create temporary folder", err)
@@ -12,6 +12,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/validate"
)
// GitAuthenticationPayload holds authentication parameters for a git source
@@ -30,8 +31,8 @@ type GitSourceCreatePayload struct {
// Validate implements the portainer.Validatable interface
func (payload *GitSourceCreatePayload) Validate(_ *http.Request) error {
if strings.TrimSpace(payload.URL) == "" {
return errors.New("url is required")
if !validate.IsURL(payload.URL) {
return errors.New("invalid repository URL. Must correspond to a valid URL format")
}
return nil
@@ -12,6 +12,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/validate"
)
var (
@@ -35,6 +36,10 @@ type GitAuthenticationUpdatePayload struct {
// Validate implements the portainer.Validatable interface
func (payload *GitSourceUpdatePayload) Validate(_ *http.Request) error {
if payload.URL != nil && !validate.IsURL(*payload.URL) {
return errors.New("invalid repository URL. Must correspond to a valid URL format")
}
return nil
}
+7 -1
View File
@@ -8,6 +8,7 @@ import (
"github.com/portainer/portainer/pkg/libhelm/options"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log"
"github.com/pkg/errors"
@@ -45,6 +46,10 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) *
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
}
if err := ssrf.CheckURL(r.Context(), repo); err != nil {
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
}
searchOpts := options.SearchRepoOptions{
Repo: repo,
Chart: chart,
@@ -53,7 +58,8 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) *
result, err := handler.helmPackageManager.SearchRepo(searchOpts)
if err != nil {
return httperror.InternalServerError("Search failed", err)
log.Warn().Err(err).Str("repo", repo).Msg("helm repo search failed")
return httperror.InternalServerError("Search failed", errors.New("failed to search Helm repository"))
}
w.Header().Set("Content-Type", "text/plain")
+7 -1
View File
@@ -8,6 +8,7 @@ import (
"github.com/portainer/portainer/pkg/libhelm/options"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
@@ -41,6 +42,10 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
}
if err := ssrf.CheckURL(r.Context(), repo); err != nil {
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
}
chart := r.URL.Query().Get("chart")
if chart == "" {
return httperror.BadRequest("Bad request", errors.New("missing `chart` query parameter"))
@@ -65,7 +70,8 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper
}
result, err := handler.helmPackageManager.Show(showOptions)
if err != nil {
return httperror.InternalServerError("Unable to show chart", err)
log.Warn().Err(err).Str("repo", repo).Str("chart", chart).Msg("helm show failed")
return httperror.InternalServerError("Unable to show chart", errors.New("failed to retrieve Helm chart information"))
}
w.Header().Set("Content-Type", "text/plain")
@@ -14,6 +14,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate"
"github.com/pkg/errors"
@@ -74,6 +75,12 @@ func (payload *settingsUpdatePayload) Validate(r *http.Request) error {
return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format")
}
if payload.HelmRepositoryURL != nil && *payload.HelmRepositoryURL != "" {
if err := ssrf.CheckURL(r.Context(), *payload.HelmRepositoryURL); err != nil {
return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format")
}
}
if payload.UserSessionTimeout != nil {
if _, err := time.ParseDuration(*payload.UserSessionTimeout); err != nil {
return errors.New("Invalid user session timeout")
@@ -14,6 +14,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/validate"
"github.com/pkg/errors"
@@ -125,6 +126,10 @@ func (payload *kubernetesManifestURLDeploymentPayload) Validate(r *http.Request)
return errors.New("Invalid manifest URL")
}
if err := ssrf.CheckURL(r.Context(), payload.ManifestURL); err != nil {
return err
}
return nil
}
+6 -1
View File
@@ -11,6 +11,7 @@ import (
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/pkg/errors"
)
@@ -24,7 +25,11 @@ type addHelmRepoUrlPayload struct {
URL string `json:"url"`
}
func (p *addHelmRepoUrlPayload) Validate(_ *http.Request) error {
func (p *addHelmRepoUrlPayload) Validate(r *http.Request) error {
if err := ssrf.CheckURL(r.Context(), p.URL); err != nil {
return err
}
return libhelm.ValidateHelmRepositoryURL(p.URL, nil)
}
+4 -4
View File
@@ -52,14 +52,14 @@ func (factory *ProxyFactory) NewAgentProxy(endpoint *portainer.Endpoint) (*Proxy
endpointURL.Scheme = "https"
if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
innerTransport = ssrf.NewInternalTransport(tlsConfig)
} else {
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
innerTransport = ssrf.NewTransport(tlsConfig)
}
} else if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
innerTransport = ssrf.NewInternalTransport(nil)
} else {
innerTransport = ssrf.WrapTransport(&http.Transport{})
innerTransport = ssrf.NewTransport(nil)
}
proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL)
+4 -4
View File
@@ -68,14 +68,14 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h
endpointURL.Scheme = "https"
if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
innerTransport = ssrf.NewInternalTransport(tlsConfig)
} else {
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
innerTransport = ssrf.NewTransport(tlsConfig)
}
} else if endpointutils.IsEdgeEndpoint(endpoint) {
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
innerTransport = ssrf.NewInternalTransport(nil)
} else {
innerTransport = ssrf.WrapTransport(&http.Transport{})
innerTransport = ssrf.NewTransport(nil)
}
dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService)
@@ -22,6 +22,7 @@ import (
"github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/logs"
"github.com/portainer/portainer/api/slicesx"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/api/types/swarm"
@@ -506,6 +507,11 @@ func (transport *Transport) updateDefaultGitBranch(request *http.Request) error
}
repositoryURL := remote[:len(remote)-4]
if err := ssrf.CheckURL(request.Context(), repositoryURL); err != nil {
return err
}
latestCommitID, err := transport.gitService.LatestCommitID(
request.Context(),
repositoryURL,
+8 -4
View File
@@ -3,11 +3,13 @@
package factory
import (
"context"
"net"
"net/http"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/proxy/factory/docker"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
)
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
@@ -31,9 +33,11 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine
}
func newSocketTransport(socketPath string) *http.Transport {
return &http.Transport{
Dial: func(proto, addr string) (conn net.Conn, err error) {
return net.Dial("unix", socketPath)
},
d := &net.Dialer{}
t := ssrf.NewInternalTransport(nil)
t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return d.DialContext(ctx, "unix", socketPath)
}
return t
}
+7 -4
View File
@@ -3,12 +3,14 @@
package factory
import (
"context"
"net"
"net/http"
"github.com/Microsoft/go-winio"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/proxy/factory/docker"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
)
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
@@ -32,9 +34,10 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine
}
func newNamedPipeTransport(namedPipePath string) *http.Transport {
return &http.Transport{
Dial: func(proto, addr string) (conn net.Conn, err error) {
return winio.DialPipe(namedPipePath, nil)
},
t := ssrf.NewInternalTransport(nil)
t.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
return winio.DialPipe(namedPipePath, nil)
}
return t
}
+1 -1
View File
@@ -94,7 +94,7 @@ func NewHTTPClient(token string) *http.Client {
return &http.Client{
Transport: &tokenTransport{
token: token,
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling
},
Timeout: 1 * time.Minute,
}
+2 -2
View File
@@ -94,7 +94,7 @@ type Transport struct {
// interface for proxying requests to the Gitlab API.
func NewTransport() *Transport {
return &Transport{
httpTransport: ssrf.WrapTransport(&http.Transport{}),
httpTransport: ssrf.NewTransport(nil),
}
}
@@ -119,7 +119,7 @@ func NewHTTPClient(token string) *http.Client {
return &http.Client{
Transport: &tokenTransport{
token: token,
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling
},
Timeout: 1 * time.Minute,
}
@@ -25,9 +25,7 @@ func NewAgentTransport(signatureService portainer.DigitalSignatureService, token
transport := &agentTransport{
baseTransport: newBaseTransport(
ssrf.WrapTransport(&http.Transport{
TLSClientConfig: tlsConfig,
}),
ssrf.NewTransport(tlsConfig),
tokenManager,
endpoint,
k8sClientFactory,
@@ -22,7 +22,7 @@ func NewEdgeTransport(dataStore dataservices.DataStore, signatureService portain
reverseTunnelService: reverseTunnelService,
signatureService: signatureService,
baseTransport: newBaseTransport(
ssrf.WrapTransportInternal(&http.Transport{}),
ssrf.NewInternalTransport(nil),
tokenManager,
endpoint,
k8sClientFactory,
@@ -23,9 +23,7 @@ func NewLocalTransport(tokenManager *tokenManager, endpoint *portainer.Endpoint,
transport := &localTransport{
baseTransport: newBaseTransport(
ssrf.WrapTransportInternal(&http.Transport{
TLSClientConfig: config,
}),
ssrf.NewInternalTransport(config),
tokenManager,
endpoint,
k8sClientFactory,
+20 -10
View File
@@ -2,6 +2,8 @@ package factory
import (
"context"
"net/http"
"net/http/httptest"
"net/http/httputil"
"testing"
"time"
@@ -65,8 +67,6 @@ func enableSSRF(t *testing.T) {
})
}
// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint
// uses WrapTransport, setting DialContext on the inner transport.
func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
enableSSRF(t)
@@ -84,8 +84,6 @@ func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
require.NotNil(t, dt.HTTPTransport.DialContext)
}
// TestNewDockerHTTPProxy_NonEdgeTLS verifies that a TLS non-edge endpoint
// uses WrapTransport, setting DialContext on the inner transport.
func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
enableSSRF(t)
@@ -107,11 +105,14 @@ func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
require.NotNil(t, dt.HTTPTransport.DialContext)
}
// TestNewDockerHTTPProxy_EdgeNoTLS verifies that an edge endpoint without TLS
// uses WrapTransportInternal, leaving DialContext nil.
func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
enableSSRF(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.EdgeAgentOnDockerEnvironment,
@@ -123,14 +124,20 @@ func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
proxy := handler.(*httputil.ReverseProxy)
dt := proxy.Transport.(*docker.Transport)
require.Nil(t, dt.HTTPTransport.DialContext)
resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
}
// TestNewDockerHTTPProxy_EdgeTLS verifies that an edge endpoint with TLS
// uses WrapTransportInternal, leaving DialContext nil.
func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
enableSSRF(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
endpoint := &portainer.Endpoint{
Type: portainer.EdgeAgentOnDockerEnvironment,
@@ -146,7 +153,10 @@ func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
proxy := handler.(*httputil.ReverseProxy)
dt := proxy.Transport.(*docker.Transport)
require.Nil(t, dt.HTTPTransport.DialContext)
resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
}
func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) {
@@ -13,6 +13,7 @@ import (
"github.com/portainer/portainer/api/scheduler"
"github.com/portainer/portainer/api/stacks/deployments"
"github.com/portainer/portainer/api/stacks/stackutils"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
)
type GitMethodStackBuilder struct {
@@ -78,6 +79,10 @@ func (b *GitMethodStackBuilder) prepare(ctx context.Context, payload *StackPaylo
return b.fileService.GetStackProjectPath(stackFolder)
}
if err := ssrf.CheckURL(ctx, repoConfig.URL); err != nil {
return fmt.Errorf("repository URL blocked by SSRF policy: %w", err)
}
commitHash, err := stackutils.DownloadGitRepository(ctx, repoConfig, b.gitService, getProjectPath)
if err != nil {
return fmt.Errorf("failed to download git repository: %w", err)
+4 -1
View File
@@ -14,6 +14,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/logs"
"github.com/portainer/portainer/pkg/libhelm/options"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/portainer/portainer/pkg/liboras"
"github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json"
@@ -216,13 +217,15 @@ func downloadRepoIndexFromHttpRepo(repoURLString string, repoSettings *cli.EnvSe
Str("repo_name", repoName).
Msg("Creating chart repository object")
ssrfTransport := ssrf.NewTransport(nil)
// Create chart repository object
rep, err := repo.NewChartRepository(
&repo.Entry{
Name: repoName,
URL: repoURLString,
},
getter.All(repoSettings),
getter.All(repoSettings, getter.WithTransport(ssrfTransport)),
)
if err != nil {
log.Error().
+5 -2
View File
@@ -8,6 +8,7 @@ import (
"strings"
"github.com/portainer/portainer/pkg/libhelm/sdk"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"helm.sh/helm/v4/pkg/cli"
"helm.sh/helm/v4/pkg/getter"
repo "helm.sh/helm/v4/pkg/repo/v1"
@@ -40,12 +41,14 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error {
return fmt.Errorf("failed to derive repo name: %w", err)
}
ssrfTransport := ssrf.NewTransport(nil)
r, err := repo.NewChartRepository(
&repo.Entry{
Name: repoName,
URL: repoUrl,
},
getter.All(settings),
getter.All(settings, getter.WithTransport(ssrfTransport)),
)
if err != nil {
return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err)
@@ -53,7 +56,7 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error {
indexPath, err := r.DownloadIndexFile()
if err != nil {
return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err)
return fmt.Errorf("%s is not a valid chart repository or cannot be reached", repoUrl)
}
// Best-effort: load and seed in-memory cache for future SearchRepo calls
+53
View File
@@ -0,0 +1,53 @@
package ssrf
import (
"crypto/tls"
"net/http"
)
// NewTransport creates an SSRF-protected transport for user-influenced destinations.
// It clones http.DefaultTransport as its base (inheriting pool and timeout defaults)
// and applies the global SSRF dial-context filter so mode changes take effect without
// restarting. tlsConfig may be nil to preserve standard TLS behavior (system CAs).
func NewTransport(tlsConfig *tls.Config) *http.Transport {
base := http.DefaultTransport.(*http.Transport).Clone()
base.TLSClientConfig = tlsConfig
applySSRF(base)
return base
}
// NewInternalTransport creates a plain transport for destinations chosen by Portainer,
// not by the user (Docker socket proxy, Chisel tunnels, in-cluster Kubernetes API).
// It clones http.DefaultTransport as its base. tlsConfig may be nil.
// Using this function instead of NewTransport makes the exemption explicit and
// satisfies the ruleguard lint rule.
func NewInternalTransport(tlsConfig *tls.Config) *http.Transport {
base := http.DefaultTransport.(*http.Transport).Clone()
base.TLSClientConfig = tlsConfig
return base
}
// WrapDefaultTransport replaces http.DefaultTransport with an SSRF-protected version.
// Must be called after Configure. Returns false if DefaultTransport is not an *http.Transport.
func WrapDefaultTransport() bool {
dt, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return false
}
cloned := dt.Clone()
applySSRF(cloned)
http.DefaultTransport = cloned
return true
}
// applySSRF sets the SSRF-filtering DialContext on t when the global dialer is active.
func applySSRF(t *http.Transport) {
d := globalDialer.Load()
if d != nil {
t.DialContext = d.DialContext
}
}
-26
View File
@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"sync/atomic"
@@ -112,31 +111,6 @@ func CheckURL(ctx context.Context, rawURL string) error {
return d.checkHost(ctx, host)
}
// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering
// dialer. The dialer checks the mode on every connection, so the transport is always
// wrapped and mode changes take effect without restarting.
func WrapTransport(t *http.Transport) *http.Transport {
d := globalDialer.Load()
if d == nil {
return t
}
cloned := t.Clone()
cloned.DialContext = d.DialContext
return cloned
}
// WrapTransportInternal is a documented no-op for transports that connect to
// internally computed destinations (local Docker socket proxy, Chisel tunnels,
// in-cluster Kubernetes API). The destination is chosen by Portainer, not
// supplied by any user, so SSRF validation is not applicable. Using this
// function instead of WrapTransport makes the exemption explicit and
// satisfies the ruleguard lint rule.
func WrapTransportInternal(t *http.Transport) *http.Transport {
return t
}
// DialContext resolves addr, validates all resolved IPs against the allowlist policy,
// then dials using the first resolved IP to prevent DNS rebinding attacks.
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+22 -17
View File
@@ -116,22 +116,27 @@ func TestConfigure_NilServicesReturnsError(t *testing.T) {
require.Error(t, err)
}
func TestWrapTransport_NoPolicy(t *testing.T) {
func TestNewTransport_NoPolicy(t *testing.T) {
globalDialer.Store(nil)
t.Cleanup(func() { globalDialer.Store(nil) })
base := &http.Transport{}
result := WrapTransport(base)
require.Equal(t, base, result)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
client := &http.Client{Transport: NewTransport(nil)}
resp, err := client.Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
}
func TestWrapTransport_WithPolicy(t *testing.T) {
func TestNewTransport_WithPolicy(t *testing.T) {
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
base := &http.Transport{}
result := WrapTransport(base)
require.NotEqual(t, base, result)
result := NewTransport(nil)
require.NotNil(t, result.DialContext)
}
@@ -267,12 +272,12 @@ func TestIsEnabled(t *testing.T) {
require.False(t, IsEnabled())
}
func TestWrapTransportInternal(t *testing.T) {
func TestNewInternalTransport(t *testing.T) {
t.Parallel()
base := &http.Transport{}
result := WrapTransportInternal(base)
require.Equal(t, base, result)
result := NewInternalTransport(nil)
require.NotNil(t, result)
require.Nil(t, result.TLSClientConfig)
}
// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP
@@ -288,7 +293,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
blocked := &http.Client{Transport: WrapTransport(&http.Transport{})}
blocked := &http.Client{Transport: NewTransport(nil)}
resp, err := blocked.Get(srv.URL)
require.Error(t, err)
require.Contains(t, err.Error(), "ssrf")
@@ -300,7 +305,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
err = Configure(newStaticService(portainer.SSRFModeOff, nil))
require.NoError(t, err)
open := &http.Client{Transport: WrapTransport(&http.Transport{})}
open := &http.Client{Transport: NewTransport(nil)}
resp, err = open.Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
@@ -318,7 +323,7 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
client := &http.Client{Transport: NewTransport(nil)}
resp, err := client.Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
@@ -364,7 +369,7 @@ func TestDialContext_AllowedByCIDR(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
client := &http.Client{Transport: NewTransport(nil)}
resp, err := client.Get(srv.URL)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
@@ -398,7 +403,7 @@ func TestDialContext_AllowedByExactHostname(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { globalDialer.Store(nil) })
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
client := &http.Client{Transport: NewTransport(nil)}
resp, err := client.Get("http://localhost:" + portStr)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
+3 -4
View File
@@ -10,6 +10,7 @@ import (
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/logs"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
"github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json"
@@ -73,10 +74,8 @@ func ProbeTelnetConnection(url string) string {
// ignores errors for the http request since we want to know if the host is reachable
func DetectProxy(url string) string {
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: crypto.CreateTLSConfiguration(true),
},
Timeout: 10 * time.Second,
Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(true)),
Timeout: 10 * time.Second,
}
result := map[string]string{
+1 -2
View File
@@ -16,8 +16,7 @@ import (
func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) {
switch registry.Type {
case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry:
base := http.DefaultTransport.(*http.Transport).Clone()
return &http.Client{Transport: retry.NewTransport(ssrf.WrapTransport(base))}, false, nil
return &http.Client{Transport: retry.NewTransport(ssrf.NewTransport(nil))}, false, nil
default:
// For all other registry types, use shared helper to build transport and scheme
+11 -12
View File
@@ -6,28 +6,27 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/pkg/fips"
"github.com/portainer/portainer/pkg/libhttp/ssrf"
)
// BuildTransportAndSchemeFromTLSConfig returns a base HTTP transport configured
// with ProxyFromEnvironment and, when needed, a TLSClientConfig derived from the
// provided TLS settings. It also returns the scheme ("http" or "https") that
// should be used to contact the registry based on the TLS settings.
// BuildTransportAndSchemeFromTLSConfig returns an SSRF-protected HTTP transport and the
// scheme ("http" or "https") to use when contacting the registry. The transport is based on
// the TLS settings from tlsCfg; pass a zero-value TLSConfiguration for plaintext.
func BuildTransportAndSchemeFromTLSConfig(tlsCfg portainer.TLSConfiguration) (*http.Transport, string, error) {
baseTransport := http.DefaultTransport.(*http.Transport).Clone()
baseTransport.Proxy = http.ProxyFromEnvironment
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsCfg)
if err != nil {
return nil, "", err
}
baseTransport.TLSClientConfig = tlsConfig
if tlsConfig == nil && fips.FIPSMode() {
return nil, "", fips.ErrTLSRequired
} else if tlsConfig == nil {
return baseTransport, "http", nil
}
return baseTransport, "https", nil
transport := ssrf.NewTransport(tlsConfig)
if tlsConfig == nil {
return transport, "http", nil
}
return transport, "https", nil
}