diff --git a/.golangci.yaml b/.golangci.yaml index 7523eb1300..0bdc00e4ed 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -108,6 +108,9 @@ linters: linters: - gocritic text: ruleguard + - path: pkg/libhttp/ssrf/builder\.go + linters: + - forbidigo paths: - third_party$ - builtin$ diff --git a/analysis/ssrf.go b/analysis/ssrf.go index 3af95653ba..b183a7f185 100644 --- a/analysis/ssrf.go +++ b/analysis/ssrf.go @@ -4,26 +4,72 @@ package gorules import "github.com/quasilyte/go-ruleguard/dsl" -// unwrappedHTTPTransport flags http.Transport composite literals that are not -// the direct argument to ssrf.WrapTransport. +// unwrappedHTTPTransport flags any bare http.Transport composite literal. +// All transports must be created via ssrf.NewTransport or ssrf.NewInternalTransport, +// which clone http.DefaultTransport and handle SSRF protection internally. func unwrappedHTTPTransport(m dsl.Matcher) { - // Inline construction passed to a function call. m.Match(`$f(&http.Transport{$*_})`). - Where(m["f"].Text != "ssrf.WrapTransport" && m["f"].Text != "WrapTransport" && - m["f"].Text != "ssrf.WrapTransportInternal" && m["f"].Text != "WrapTransportInternal"). - Report(`$f receives a bare *http.Transport; wrap with ssrf.WrapTransport() to enforce the SSRF protection policy`) + Report(`$f receives a bare *http.Transport; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`) - // Variable assigned a bare transport (cannot be tracked to a later WrapTransport call). m.Match(`$_ := &http.Transport{$*_}`). - Report(`bare *http.Transport variable; use ssrf.WrapTransport(&http.Transport{...}) inline instead`) + Report(`bare *http.Transport variable; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`) + + m.Match(`$_.Transport = &http.Transport{$*_}`). + Report(`bare *http.Transport field assignment; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`) } -// internalTransportMisuse flags calls to WrapTransportInternal outside the four proxy +// helmGetterTransport flags getter.WithTransport calls that receive a bare *http.Transport. +// Helm v4 installs its own transport and bypasses http.DefaultTransport, so the transport +// passed here must be created via ssrf.NewTransport. +func helmGetterTransport(m dsl.Matcher) { + m.Match(`getter.WithTransport(&http.Transport{$*_})`). + Report(`getter.WithTransport called with a bare *http.Transport; use ssrf.NewTransport(tlsConfig) as Helm v4 bypasses http.DefaultTransport`) +} + +// cloneDefaultTransport flags direct clones of *http.Transport outside main.go. +// The one legitimate clone is in main.go where http.DefaultTransport is globally +// wrapped with SSRF protection at server startup. +func cloneDefaultTransport(m dsl.Matcher) { + m.Match(`$_.(*http.Transport).Clone()`). + Where(!m.File().Name.Matches(`^main\.go$`)). + Report(`cloning *http.Transport directly is forbidden; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`) +} + +// internalTransportMisuse flags calls to NewInternalTransport outside the proxy // factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions. func internalTransportMisuse(m dsl.Matcher) { - m.Match(`ssrf.WrapTransportInternal($*_)`). + m.Match(`ssrf.NewInternalTransport($*_)`). Where( !(m.File().PkgPath.Matches(`proxy/factory`) && - m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport)\.go$`))). - Report(`WrapTransportInternal bypasses SSRF validation; only valid in the kubernetes local/edge transport constructors and the docker/agent proxy factories`) + m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport|docker_unix|docker_windows)\.go$`))). + Report(`NewInternalTransport bypasses SSRF validation; only valid in the proxy factory files for local sockets and internally-routed endpoints`) +} + +// dialerOverride flags direct assignments to any of the dialer fields on a transport. +// The only valid assignments are in docker_unix.go and docker_windows.go where a +// custom dialer is required for unix sockets and named pipes. +func dialerOverride(m dsl.Matcher) { + m.Match(`$_.DialContext = $*_`). + Where( + !(m.File().PkgPath.Matches(`proxy/factory`) && + m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))). + Report(`direct DialContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`) + + m.Match(`$_.Dial = $*_`). + Where( + !(m.File().PkgPath.Matches(`proxy/factory`) && + m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))). + Report(`direct Dial assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`) + + m.Match(`$_.DialTLSContext = $*_`). + Where( + !(m.File().PkgPath.Matches(`proxy/factory`) && + m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))). + Report(`direct DialTLSContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`) + + m.Match(`$_.DialTLS = $*_`). + Where( + !(m.File().PkgPath.Matches(`proxy/factory`) && + m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))). + Report(`direct DialTLS assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`) } diff --git a/api/agent/version.go b/api/agent/version.go index ea8a22b9ca..ff61d03eeb 100644 --- a/api/agent/version.go +++ b/api/agent/version.go @@ -1,6 +1,7 @@ package agent import ( + "context" "crypto/tls" "errors" "fmt" @@ -11,6 +12,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/url" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/rs/zerolog/log" ) @@ -19,10 +21,14 @@ import ( // // it sends a ping to the agent and parses the version and platform from the headers func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo + if err := ssrf.CheckURL(context.Background(), endpointUrl); err != nil { + return 0, "", err + } + httpCli := &http.Client{Timeout: 3 * time.Second} if tlsConfig != nil { - httpCli.Transport = &http.Transport{TLSClientConfig: tlsConfig} + httpCli.Transport = ssrf.NewTransport(tlsConfig) } parsedURL, err := url.ParseURL(endpointUrl + "/ping") diff --git a/api/cmd/portainer/main.go b/api/cmd/portainer/main.go index 8b92aee116..d5104db6b2 100644 --- a/api/cmd/portainer/main.go +++ b/api/cmd/portainer/main.go @@ -58,7 +58,10 @@ import ( libswarm "github.com/portainer/portainer/pkg/libstack/swarm" "github.com/portainer/portainer/pkg/validate" + gogitclient "github.com/go-git/go-git/v5/plumbing/transport/client" + gogitraw "github.com/go-git/go-git/v5/plumbing/transport/git" gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http" + gogitssh "github.com/go-git/go-git/v5/plumbing/transport/ssh" "github.com/google/uuid" "github.com/rs/zerolog/log" ) @@ -408,11 +411,14 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow log.Fatal().Err(err).Msg("failed initializing ssrf service") } - if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok { - nethttp.DefaultTransport = ssrf.WrapTransport(dt) + if !ssrf.WrapDefaultTransport() { + log.Fatal().Msg("failed to wrap default HTTP transport with SSRF protection") } gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport}) + gogitclient.InstallProtocol("git", git.NewSSRFGitTransport(gogitraw.DefaultClient)) + gogitclient.InstallProtocol("ssh", git.NewSSRFGitTransport(gogitssh.DefaultClient)) + gogitclient.InstallProtocol("file", nil) instanceID, err := dataStore.Version().InstanceID() if err != nil { diff --git a/api/docker/client/client.go b/api/docker/client/client.go index 5609b290d9..abd7b246e4 100644 --- a/api/docker/client/client.go +++ b/api/docker/client/client.go @@ -194,11 +194,11 @@ func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Cli } transport = &NodeNameTransport{ - Transport: ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}), + Transport: ssrf.NewTransport(tlsConfig), } } else { transport = &NodeNameTransport{ - Transport: ssrf.WrapTransport(&http.Transport{}), + Transport: ssrf.NewTransport(nil), } } diff --git a/api/git/azure.go b/api/git/azure.go index 2acf14b233..fa46f4bdd7 100644 --- a/api/git/azure.go +++ b/api/git/azure.go @@ -66,11 +66,8 @@ func NewAzureClient() *azureClient { func newHttpClientForAzure(insecureSkipVerify bool) *http.Client { return &http.Client{ - Transport: ssrf.WrapTransport(&http.Transport{ - TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify), - Proxy: http.ProxyFromEnvironment, - }), - Timeout: 300 * time.Second, + Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(insecureSkipVerify)), + Timeout: 300 * time.Second, } } diff --git a/api/git/ssrf_transport.go b/api/git/ssrf_transport.go new file mode 100644 index 0000000000..6ef461488d --- /dev/null +++ b/api/git/ssrf_transport.go @@ -0,0 +1,53 @@ +package git + +import ( + "context" + "fmt" + "net" + "strconv" + + "github.com/portainer/portainer/pkg/libhttp/ssrf" + + gittransport "github.com/go-git/go-git/v5/plumbing/transport" +) + +const gitDefaultPort = 9418 + +// ssrfGitTransport wraps a git:// transport and validates the resolved IP +// against the SSRF policy before establishing connections. +type ssrfGitTransport struct { + inner gittransport.Transport +} + +// NewSSRFGitTransport wraps inner and blocks connections to private IP ranges +// according to the active SSRF policy. +func NewSSRFGitTransport(inner gittransport.Transport) gittransport.Transport { + return &ssrfGitTransport{inner: inner} +} + +func (t *ssrfGitTransport) NewUploadPackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.UploadPackSession, error) { + if err := checkEndpointSSRF(ep); err != nil { + return nil, err + } + + return t.inner.NewUploadPackSession(ep, auth) +} + +func (t *ssrfGitTransport) NewReceivePackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.ReceivePackSession, error) { + if err := checkEndpointSSRF(ep); err != nil { + return nil, err + } + + return t.inner.NewReceivePackSession(ep, auth) +} + +func checkEndpointSSRF(ep *gittransport.Endpoint) error { + port := ep.Port + if port <= 0 { + port = gitDefaultPort + } + + rawURL := fmt.Sprintf("git://%s/", net.JoinHostPort(ep.Host, strconv.Itoa(port))) + + return ssrf.CheckURL(context.Background(), rawURL) +} diff --git a/api/http/client/client.go b/api/http/client/client.go index 68c4ecd6f7..8b6577796a 100644 --- a/api/http/client/client.go +++ b/api/http/client/client.go @@ -125,9 +125,9 @@ func ExecutePingOperation(host string, tlsConfiguration portainer.TLSConfigurati } scheme = "https" - transport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}) + transport = ssrf.NewTransport(tlsConfig) } else { - transport = ssrf.WrapTransport(&http.Transport{}) + transport = ssrf.NewTransport(nil) } client := &http.Client{ diff --git a/api/http/handler/customtemplates/customtemplate_create.go b/api/http/handler/customtemplates/customtemplate_create.go index 31c99cd952..2d95ecfaae 100644 --- a/api/http/handler/customtemplates/customtemplate_create.go +++ b/api/http/handler/customtemplates/customtemplate_create.go @@ -19,6 +19,7 @@ import ( httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/portainer/portainer/pkg/validate" "github.com/rs/zerolog/log" @@ -315,6 +316,10 @@ func (handler *Handler) createCustomTemplateFromGitRepository(r *http.Request) ( return nil, httpErr } + if err := ssrf.CheckURL(r.Context(), gitConfig.URL); err != nil { + return nil, err + } + commitHash, err := stackutils.DownloadGitRepository(context.TODO(), gitConfig, handler.GitService, getProjectPath) if err != nil { return nil, err diff --git a/api/http/handler/edgestacks/edgestack_create_git.go b/api/http/handler/edgestacks/edgestack_create_git.go index 09f902f0c6..83901a2985 100644 --- a/api/http/handler/edgestacks/edgestack_create_git.go +++ b/api/http/handler/edgestacks/edgestack_create_git.go @@ -14,6 +14,7 @@ import ( "github.com/portainer/portainer/api/stacks/stackutils" "github.com/portainer/portainer/pkg/edge" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/portainer/portainer/pkg/validate" "github.com/pkg/errors" @@ -69,7 +70,6 @@ func (payload *edgeStackFromGitRepositoryPayload) Validate(r *http.Request) erro if len(payload.RepositoryURL) == 0 || !validate.IsURL(payload.RepositoryURL) { return httperrors.NewInvalidPayloadError("Invalid repository URL. Must correspond to a valid URL format") } - if payload.RepositoryAuthentication && len(payload.RepositoryPassword) == 0 { return httperrors.NewInvalidPayloadError("Invalid repository credentials. Password must be specified when authentication is enabled") } @@ -138,6 +138,10 @@ func (handler *Handler) createEdgeStackFromGitRepository(r *http.Request, tx dat return nil, httpErr } + if err := ssrf.CheckURL(r.Context(), repoConfig.URL); err != nil { + return nil, errors.Wrap(err, "repository URL blocked by SSRF policy") + } + stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID) stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username) diff --git a/api/http/handler/gitops/git_repo_file_preview.go b/api/http/handler/gitops/git_repo_file_preview.go index e4aa405598..93545a80ee 100644 --- a/api/http/handler/gitops/git_repo_file_preview.go +++ b/api/http/handler/gitops/git_repo_file_preview.go @@ -12,6 +12,7 @@ import ( httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/portainer/portainer/pkg/validate" "github.com/rs/zerolog/log" ) @@ -100,6 +101,10 @@ func (handler *Handler) gitOperationRepoFilePreview(w http.ResponseWriter, r *ht tlsSkipVerify = src.Git.TLSSkipVerify } + if err := ssrf.CheckURL(r.Context(), repoURL); err != nil { + return httperror.BadRequest("Repository URL blocked by SSRF policy", err) + } + projectPath, err := handler.fileService.GetTemporaryPath() if err != nil { return httperror.InternalServerError("Unable to create temporary folder", err) diff --git a/api/http/handler/gitops/sources/create_git.go b/api/http/handler/gitops/sources/create_git.go index fcde7aaa21..6329e377e8 100644 --- a/api/http/handler/gitops/sources/create_git.go +++ b/api/http/handler/gitops/sources/create_git.go @@ -12,6 +12,7 @@ import ( httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" + "github.com/portainer/portainer/pkg/validate" ) // GitAuthenticationPayload holds authentication parameters for a git source @@ -30,8 +31,8 @@ type GitSourceCreatePayload struct { // Validate implements the portainer.Validatable interface func (payload *GitSourceCreatePayload) Validate(_ *http.Request) error { - if strings.TrimSpace(payload.URL) == "" { - return errors.New("url is required") + if !validate.IsURL(payload.URL) { + return errors.New("invalid repository URL. Must correspond to a valid URL format") } return nil diff --git a/api/http/handler/gitops/sources/update_git.go b/api/http/handler/gitops/sources/update_git.go index 867baa3b56..4731c9d715 100644 --- a/api/http/handler/gitops/sources/update_git.go +++ b/api/http/handler/gitops/sources/update_git.go @@ -12,6 +12,7 @@ import ( httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" + "github.com/portainer/portainer/pkg/validate" ) var ( @@ -35,6 +36,10 @@ type GitAuthenticationUpdatePayload struct { // Validate implements the portainer.Validatable interface func (payload *GitSourceUpdatePayload) Validate(_ *http.Request) error { + if payload.URL != nil && !validate.IsURL(*payload.URL) { + return errors.New("invalid repository URL. Must correspond to a valid URL format") + } + return nil } diff --git a/api/http/handler/helm/helm_repo_search.go b/api/http/handler/helm/helm_repo_search.go index 42600ca4bf..7f940b01ed 100644 --- a/api/http/handler/helm/helm_repo_search.go +++ b/api/http/handler/helm/helm_repo_search.go @@ -8,6 +8,7 @@ import ( "github.com/portainer/portainer/pkg/libhelm/options" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/rs/zerolog/log" "github.com/pkg/errors" @@ -45,6 +46,10 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) * return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo))) } + if err := ssrf.CheckURL(r.Context(), repo); err != nil { + return httperror.BadRequest("Repository URL blocked by SSRF policy", err) + } + searchOpts := options.SearchRepoOptions{ Repo: repo, Chart: chart, @@ -53,7 +58,8 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) * result, err := handler.helmPackageManager.SearchRepo(searchOpts) if err != nil { - return httperror.InternalServerError("Search failed", err) + log.Warn().Err(err).Str("repo", repo).Msg("helm repo search failed") + return httperror.InternalServerError("Search failed", errors.New("failed to search Helm repository")) } w.Header().Set("Content-Type", "text/plain") diff --git a/api/http/handler/helm/helm_show.go b/api/http/handler/helm/helm_show.go index 3e706558da..1c5860ba4a 100644 --- a/api/http/handler/helm/helm_show.go +++ b/api/http/handler/helm/helm_show.go @@ -8,6 +8,7 @@ import ( "github.com/portainer/portainer/pkg/libhelm/options" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/pkg/errors" "github.com/rs/zerolog/log" @@ -41,6 +42,10 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo))) } + if err := ssrf.CheckURL(r.Context(), repo); err != nil { + return httperror.BadRequest("Repository URL blocked by SSRF policy", err) + } + chart := r.URL.Query().Get("chart") if chart == "" { return httperror.BadRequest("Bad request", errors.New("missing `chart` query parameter")) @@ -65,7 +70,8 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper } result, err := handler.helmPackageManager.Show(showOptions) if err != nil { - return httperror.InternalServerError("Unable to show chart", err) + log.Warn().Err(err).Str("repo", repo).Str("chart", chart).Msg("helm show failed") + return httperror.InternalServerError("Unable to show chart", errors.New("failed to retrieve Helm chart information")) } w.Header().Set("Content-Type", "text/plain") diff --git a/api/http/handler/settings/settings_update.go b/api/http/handler/settings/settings_update.go index f0f0064dbd..4a24bf93f2 100644 --- a/api/http/handler/settings/settings_update.go +++ b/api/http/handler/settings/settings_update.go @@ -14,6 +14,7 @@ import ( httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/portainer/portainer/pkg/validate" "github.com/pkg/errors" @@ -74,6 +75,12 @@ func (payload *settingsUpdatePayload) Validate(r *http.Request) error { return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format") } + if payload.HelmRepositoryURL != nil && *payload.HelmRepositoryURL != "" { + if err := ssrf.CheckURL(r.Context(), *payload.HelmRepositoryURL); err != nil { + return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format") + } + } + if payload.UserSessionTimeout != nil { if _, err := time.ParseDuration(*payload.UserSessionTimeout); err != nil { return errors.New("Invalid user session timeout") diff --git a/api/http/handler/stacks/create_kubernetes_stack.go b/api/http/handler/stacks/create_kubernetes_stack.go index b179641526..d3daada50a 100644 --- a/api/http/handler/stacks/create_kubernetes_stack.go +++ b/api/http/handler/stacks/create_kubernetes_stack.go @@ -14,6 +14,7 @@ import ( httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/portainer/portainer/pkg/validate" "github.com/pkg/errors" @@ -125,6 +126,10 @@ func (payload *kubernetesManifestURLDeploymentPayload) Validate(r *http.Request) return errors.New("Invalid manifest URL") } + if err := ssrf.CheckURL(r.Context(), payload.ManifestURL); err != nil { + return err + } + return nil } diff --git a/api/http/handler/users/user_helm_repos.go b/api/http/handler/users/user_helm_repos.go index e8490b2e01..0bb37d0c24 100644 --- a/api/http/handler/users/user_helm_repos.go +++ b/api/http/handler/users/user_helm_repos.go @@ -11,6 +11,7 @@ import ( httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/pkg/errors" ) @@ -24,7 +25,11 @@ type addHelmRepoUrlPayload struct { URL string `json:"url"` } -func (p *addHelmRepoUrlPayload) Validate(_ *http.Request) error { +func (p *addHelmRepoUrlPayload) Validate(r *http.Request) error { + if err := ssrf.CheckURL(r.Context(), p.URL); err != nil { + return err + } + return libhelm.ValidateHelmRepositoryURL(p.URL, nil) } diff --git a/api/http/proxy/factory/agent.go b/api/http/proxy/factory/agent.go index ac6d6892ed..fc252536df 100644 --- a/api/http/proxy/factory/agent.go +++ b/api/http/proxy/factory/agent.go @@ -52,14 +52,14 @@ func (factory *ProxyFactory) NewAgentProxy(endpoint *portainer.Endpoint) (*Proxy endpointURL.Scheme = "https" if endpointutils.IsEdgeEndpoint(endpoint) { - innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig}) + innerTransport = ssrf.NewInternalTransport(tlsConfig) } else { - innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}) + innerTransport = ssrf.NewTransport(tlsConfig) } } else if endpointutils.IsEdgeEndpoint(endpoint) { - innerTransport = ssrf.WrapTransportInternal(&http.Transport{}) + innerTransport = ssrf.NewInternalTransport(nil) } else { - innerTransport = ssrf.WrapTransport(&http.Transport{}) + innerTransport = ssrf.NewTransport(nil) } proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL) diff --git a/api/http/proxy/factory/docker.go b/api/http/proxy/factory/docker.go index f49b093742..b5d92df0df 100644 --- a/api/http/proxy/factory/docker.go +++ b/api/http/proxy/factory/docker.go @@ -68,14 +68,14 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h endpointURL.Scheme = "https" if endpointutils.IsEdgeEndpoint(endpoint) { - innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig}) + innerTransport = ssrf.NewInternalTransport(tlsConfig) } else { - innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}) + innerTransport = ssrf.NewTransport(tlsConfig) } } else if endpointutils.IsEdgeEndpoint(endpoint) { - innerTransport = ssrf.WrapTransportInternal(&http.Transport{}) + innerTransport = ssrf.NewInternalTransport(nil) } else { - innerTransport = ssrf.WrapTransport(&http.Transport{}) + innerTransport = ssrf.NewTransport(nil) } dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService) diff --git a/api/http/proxy/factory/docker/transport.go b/api/http/proxy/factory/docker/transport.go index e1e8004130..bef0ac7b0a 100644 --- a/api/http/proxy/factory/docker/transport.go +++ b/api/http/proxy/factory/docker/transport.go @@ -22,6 +22,7 @@ import ( "github.com/portainer/portainer/api/internal/authorization" "github.com/portainer/portainer/api/logs" "github.com/portainer/portainer/api/slicesx" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/docker/docker/api/types/network" "github.com/docker/docker/api/types/swarm" @@ -506,6 +507,11 @@ func (transport *Transport) updateDefaultGitBranch(request *http.Request) error } repositoryURL := remote[:len(remote)-4] + + if err := ssrf.CheckURL(request.Context(), repositoryURL); err != nil { + return err + } + latestCommitID, err := transport.gitService.LatestCommitID( request.Context(), repositoryURL, diff --git a/api/http/proxy/factory/docker_unix.go b/api/http/proxy/factory/docker_unix.go index b10ef56b29..5cc9621d00 100644 --- a/api/http/proxy/factory/docker_unix.go +++ b/api/http/proxy/factory/docker_unix.go @@ -3,11 +3,13 @@ package factory import ( + "context" "net" "net/http" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/http/proxy/factory/docker" + "github.com/portainer/portainer/pkg/libhttp/ssrf" ) func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) { @@ -31,9 +33,11 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine } func newSocketTransport(socketPath string) *http.Transport { - return &http.Transport{ - Dial: func(proto, addr string) (conn net.Conn, err error) { - return net.Dial("unix", socketPath) - }, + d := &net.Dialer{} + t := ssrf.NewInternalTransport(nil) + t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return d.DialContext(ctx, "unix", socketPath) } + + return t } diff --git a/api/http/proxy/factory/docker_windows.go b/api/http/proxy/factory/docker_windows.go index d6f6f6a278..08feaf2174 100644 --- a/api/http/proxy/factory/docker_windows.go +++ b/api/http/proxy/factory/docker_windows.go @@ -3,12 +3,14 @@ package factory import ( + "context" "net" "net/http" "github.com/Microsoft/go-winio" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/http/proxy/factory/docker" + "github.com/portainer/portainer/pkg/libhttp/ssrf" ) func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) { @@ -32,9 +34,10 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine } func newNamedPipeTransport(namedPipePath string) *http.Transport { - return &http.Transport{ - Dial: func(proto, addr string) (conn net.Conn, err error) { - return winio.DialPipe(namedPipePath, nil) - }, + t := ssrf.NewInternalTransport(nil) + t.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return winio.DialPipe(namedPipePath, nil) } + + return t } diff --git a/api/http/proxy/factory/github/client.go b/api/http/proxy/factory/github/client.go index d455d79f27..84291fb176 100644 --- a/api/http/proxy/factory/github/client.go +++ b/api/http/proxy/factory/github/client.go @@ -94,7 +94,7 @@ func NewHTTPClient(token string) *http.Client { return &http.Client{ Transport: &tokenTransport{ token: token, - transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling + transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling }, Timeout: 1 * time.Minute, } diff --git a/api/http/proxy/factory/gitlab/client.go b/api/http/proxy/factory/gitlab/client.go index d4deb18f38..e3b5b2efa6 100644 --- a/api/http/proxy/factory/gitlab/client.go +++ b/api/http/proxy/factory/gitlab/client.go @@ -94,7 +94,7 @@ type Transport struct { // interface for proxying requests to the Gitlab API. func NewTransport() *Transport { return &Transport{ - httpTransport: ssrf.WrapTransport(&http.Transport{}), + httpTransport: ssrf.NewTransport(nil), } } @@ -119,7 +119,7 @@ func NewHTTPClient(token string) *http.Client { return &http.Client{ Transport: &tokenTransport{ token: token, - transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling + transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling }, Timeout: 1 * time.Minute, } diff --git a/api/http/proxy/factory/kubernetes/agent_transport.go b/api/http/proxy/factory/kubernetes/agent_transport.go index c8232b1b88..1cdfd126c5 100644 --- a/api/http/proxy/factory/kubernetes/agent_transport.go +++ b/api/http/proxy/factory/kubernetes/agent_transport.go @@ -25,9 +25,7 @@ func NewAgentTransport(signatureService portainer.DigitalSignatureService, token transport := &agentTransport{ baseTransport: newBaseTransport( - ssrf.WrapTransport(&http.Transport{ - TLSClientConfig: tlsConfig, - }), + ssrf.NewTransport(tlsConfig), tokenManager, endpoint, k8sClientFactory, diff --git a/api/http/proxy/factory/kubernetes/edge_transport.go b/api/http/proxy/factory/kubernetes/edge_transport.go index cd817eb6b4..3d07df613d 100644 --- a/api/http/proxy/factory/kubernetes/edge_transport.go +++ b/api/http/proxy/factory/kubernetes/edge_transport.go @@ -22,7 +22,7 @@ func NewEdgeTransport(dataStore dataservices.DataStore, signatureService portain reverseTunnelService: reverseTunnelService, signatureService: signatureService, baseTransport: newBaseTransport( - ssrf.WrapTransportInternal(&http.Transport{}), + ssrf.NewInternalTransport(nil), tokenManager, endpoint, k8sClientFactory, diff --git a/api/http/proxy/factory/kubernetes/local_transport.go b/api/http/proxy/factory/kubernetes/local_transport.go index 6c22c5dc5c..bd177df491 100644 --- a/api/http/proxy/factory/kubernetes/local_transport.go +++ b/api/http/proxy/factory/kubernetes/local_transport.go @@ -23,9 +23,7 @@ func NewLocalTransport(tokenManager *tokenManager, endpoint *portainer.Endpoint, transport := &localTransport{ baseTransport: newBaseTransport( - ssrf.WrapTransportInternal(&http.Transport{ - TLSClientConfig: config, - }), + ssrf.NewInternalTransport(config), tokenManager, endpoint, k8sClientFactory, diff --git a/api/http/proxy/factory/transport_test.go b/api/http/proxy/factory/transport_test.go index 44e5bc6c29..ff5b38b804 100644 --- a/api/http/proxy/factory/transport_test.go +++ b/api/http/proxy/factory/transport_test.go @@ -2,6 +2,8 @@ package factory import ( "context" + "net/http" + "net/http/httptest" "net/http/httputil" "testing" "time" @@ -65,8 +67,6 @@ func enableSSRF(t *testing.T) { }) } -// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint -// uses WrapTransport, setting DialContext on the inner transport. func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) { enableSSRF(t) @@ -84,8 +84,6 @@ func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) { require.NotNil(t, dt.HTTPTransport.DialContext) } -// TestNewDockerHTTPProxy_NonEdgeTLS verifies that a TLS non-edge endpoint -// uses WrapTransport, setting DialContext on the inner transport. func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) { enableSSRF(t) @@ -107,11 +105,14 @@ func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) { require.NotNil(t, dt.HTTPTransport.DialContext) } -// TestNewDockerHTTPProxy_EdgeNoTLS verifies that an edge endpoint without TLS -// uses WrapTransportInternal, leaving DialContext nil. func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) { enableSSRF(t) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} endpoint := &portainer.Endpoint{ Type: portainer.EdgeAgentOnDockerEnvironment, @@ -123,14 +124,20 @@ func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) { proxy := handler.(*httputil.ReverseProxy) dt := proxy.Transport.(*docker.Transport) - require.Nil(t, dt.HTTPTransport.DialContext) + + resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) } -// TestNewDockerHTTPProxy_EdgeTLS verifies that an edge endpoint with TLS -// uses WrapTransportInternal, leaving DialContext nil. func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) { enableSSRF(t) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}} endpoint := &portainer.Endpoint{ Type: portainer.EdgeAgentOnDockerEnvironment, @@ -146,7 +153,10 @@ func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) { proxy := handler.(*httputil.ReverseProxy) dt := proxy.Transport.(*docker.Transport) - require.Nil(t, dt.HTTPTransport.DialContext) + + resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) } func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) { diff --git a/api/stacks/stackbuilders/stack_git_builder.go b/api/stacks/stackbuilders/stack_git_builder.go index 66b9ddd77e..f5c9721cd1 100644 --- a/api/stacks/stackbuilders/stack_git_builder.go +++ b/api/stacks/stackbuilders/stack_git_builder.go @@ -13,6 +13,7 @@ import ( "github.com/portainer/portainer/api/scheduler" "github.com/portainer/portainer/api/stacks/deployments" "github.com/portainer/portainer/api/stacks/stackutils" + "github.com/portainer/portainer/pkg/libhttp/ssrf" ) type GitMethodStackBuilder struct { @@ -78,6 +79,10 @@ func (b *GitMethodStackBuilder) prepare(ctx context.Context, payload *StackPaylo return b.fileService.GetStackProjectPath(stackFolder) } + if err := ssrf.CheckURL(ctx, repoConfig.URL); err != nil { + return fmt.Errorf("repository URL blocked by SSRF policy: %w", err) + } + commitHash, err := stackutils.DownloadGitRepository(ctx, repoConfig, b.gitService, getProjectPath) if err != nil { return fmt.Errorf("failed to download git repository: %w", err) diff --git a/pkg/libhelm/sdk/search_repo.go b/pkg/libhelm/sdk/search_repo.go index dee1443d27..506b75401d 100644 --- a/pkg/libhelm/sdk/search_repo.go +++ b/pkg/libhelm/sdk/search_repo.go @@ -14,6 +14,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/logs" "github.com/portainer/portainer/pkg/libhelm/options" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/portainer/portainer/pkg/liboras" "github.com/rs/zerolog/log" "github.com/segmentio/encoding/json" @@ -216,13 +217,15 @@ func downloadRepoIndexFromHttpRepo(repoURLString string, repoSettings *cli.EnvSe Str("repo_name", repoName). Msg("Creating chart repository object") + ssrfTransport := ssrf.NewTransport(nil) + // Create chart repository object rep, err := repo.NewChartRepository( &repo.Entry{ Name: repoName, URL: repoURLString, }, - getter.All(repoSettings), + getter.All(repoSettings, getter.WithTransport(ssrfTransport)), ) if err != nil { log.Error(). diff --git a/pkg/libhelm/validate_repo.go b/pkg/libhelm/validate_repo.go index a9d3a810a6..9491233033 100644 --- a/pkg/libhelm/validate_repo.go +++ b/pkg/libhelm/validate_repo.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/portainer/portainer/pkg/libhelm/sdk" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "helm.sh/helm/v4/pkg/cli" "helm.sh/helm/v4/pkg/getter" repo "helm.sh/helm/v4/pkg/repo/v1" @@ -40,12 +41,14 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error { return fmt.Errorf("failed to derive repo name: %w", err) } + ssrfTransport := ssrf.NewTransport(nil) + r, err := repo.NewChartRepository( &repo.Entry{ Name: repoName, URL: repoUrl, }, - getter.All(settings), + getter.All(settings, getter.WithTransport(ssrfTransport)), ) if err != nil { return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err) @@ -53,7 +56,7 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error { indexPath, err := r.DownloadIndexFile() if err != nil { - return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err) + return fmt.Errorf("%s is not a valid chart repository or cannot be reached", repoUrl) } // Best-effort: load and seed in-memory cache for future SearchRepo calls diff --git a/pkg/libhttp/ssrf/builder.go b/pkg/libhttp/ssrf/builder.go new file mode 100644 index 0000000000..efcd1fb544 --- /dev/null +++ b/pkg/libhttp/ssrf/builder.go @@ -0,0 +1,53 @@ +package ssrf + +import ( + "crypto/tls" + "net/http" +) + +// NewTransport creates an SSRF-protected transport for user-influenced destinations. +// It clones http.DefaultTransport as its base (inheriting pool and timeout defaults) +// and applies the global SSRF dial-context filter so mode changes take effect without +// restarting. tlsConfig may be nil to preserve standard TLS behavior (system CAs). +func NewTransport(tlsConfig *tls.Config) *http.Transport { + base := http.DefaultTransport.(*http.Transport).Clone() + base.TLSClientConfig = tlsConfig + applySSRF(base) + + return base +} + +// NewInternalTransport creates a plain transport for destinations chosen by Portainer, +// not by the user (Docker socket proxy, Chisel tunnels, in-cluster Kubernetes API). +// It clones http.DefaultTransport as its base. tlsConfig may be nil. +// Using this function instead of NewTransport makes the exemption explicit and +// satisfies the ruleguard lint rule. +func NewInternalTransport(tlsConfig *tls.Config) *http.Transport { + base := http.DefaultTransport.(*http.Transport).Clone() + base.TLSClientConfig = tlsConfig + + return base +} + +// WrapDefaultTransport replaces http.DefaultTransport with an SSRF-protected version. +// Must be called after Configure. Returns false if DefaultTransport is not an *http.Transport. +func WrapDefaultTransport() bool { + dt, ok := http.DefaultTransport.(*http.Transport) + if !ok { + return false + } + + cloned := dt.Clone() + applySSRF(cloned) + http.DefaultTransport = cloned + + return true +} + +// applySSRF sets the SSRF-filtering DialContext on t when the global dialer is active. +func applySSRF(t *http.Transport) { + d := globalDialer.Load() + if d != nil { + t.DialContext = d.DialContext + } +} diff --git a/pkg/libhttp/ssrf/ssrf.go b/pkg/libhttp/ssrf/ssrf.go index 1a35f95700..dcacf54fdd 100644 --- a/pkg/libhttp/ssrf/ssrf.go +++ b/pkg/libhttp/ssrf/ssrf.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net" - "net/http" "net/url" "strings" "sync/atomic" @@ -112,31 +111,6 @@ func CheckURL(ctx context.Context, rawURL string) error { return d.checkHost(ctx, host) } -// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering -// dialer. The dialer checks the mode on every connection, so the transport is always -// wrapped and mode changes take effect without restarting. -func WrapTransport(t *http.Transport) *http.Transport { - d := globalDialer.Load() - if d == nil { - return t - } - - cloned := t.Clone() - cloned.DialContext = d.DialContext - - return cloned -} - -// WrapTransportInternal is a documented no-op for transports that connect to -// internally computed destinations (local Docker socket proxy, Chisel tunnels, -// in-cluster Kubernetes API). The destination is chosen by Portainer, not -// supplied by any user, so SSRF validation is not applicable. Using this -// function instead of WrapTransport makes the exemption explicit and -// satisfies the ruleguard lint rule. -func WrapTransportInternal(t *http.Transport) *http.Transport { - return t -} - // DialContext resolves addr, validates all resolved IPs against the allowlist policy, // then dials using the first resolved IP to prevent DNS rebinding attacks. func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { diff --git a/pkg/libhttp/ssrf/ssrf_test.go b/pkg/libhttp/ssrf/ssrf_test.go index 2439957b29..b3298e7cc6 100644 --- a/pkg/libhttp/ssrf/ssrf_test.go +++ b/pkg/libhttp/ssrf/ssrf_test.go @@ -116,22 +116,27 @@ func TestConfigure_NilServicesReturnsError(t *testing.T) { require.Error(t, err) } -func TestWrapTransport_NoPolicy(t *testing.T) { +func TestNewTransport_NoPolicy(t *testing.T) { globalDialer.Store(nil) + t.Cleanup(func() { globalDialer.Store(nil) }) - base := &http.Transport{} - result := WrapTransport(base) - require.Equal(t, base, result) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := &http.Client{Transport: NewTransport(nil)} + resp, err := client.Get(srv.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) } -func TestWrapTransport_WithPolicy(t *testing.T) { +func TestNewTransport_WithPolicy(t *testing.T) { err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"})) require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) - base := &http.Transport{} - result := WrapTransport(base) - require.NotEqual(t, base, result) + result := NewTransport(nil) require.NotNil(t, result.DialContext) } @@ -267,12 +272,12 @@ func TestIsEnabled(t *testing.T) { require.False(t, IsEnabled()) } -func TestWrapTransportInternal(t *testing.T) { +func TestNewInternalTransport(t *testing.T) { t.Parallel() - base := &http.Transport{} - result := WrapTransportInternal(base) - require.Equal(t, base, result) + result := NewInternalTransport(nil) + require.NotNil(t, result) + require.Nil(t, result.TLSClientConfig) } // TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP @@ -288,7 +293,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) - blocked := &http.Client{Transport: WrapTransport(&http.Transport{})} + blocked := &http.Client{Transport: NewTransport(nil)} resp, err := blocked.Get(srv.URL) require.Error(t, err) require.Contains(t, err.Error(), "ssrf") @@ -300,7 +305,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) { err = Configure(newStaticService(portainer.SSRFModeOff, nil)) require.NoError(t, err) - open := &http.Client{Transport: WrapTransport(&http.Transport{})} + open := &http.Client{Transport: NewTransport(nil)} resp, err = open.Get(srv.URL) require.NoError(t, err) require.NoError(t, resp.Body.Close()) @@ -318,7 +323,7 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) - client := &http.Client{Transport: WrapTransport(&http.Transport{})} + client := &http.Client{Transport: NewTransport(nil)} resp, err := client.Get(srv.URL) require.NoError(t, err) require.NoError(t, resp.Body.Close()) @@ -364,7 +369,7 @@ func TestDialContext_AllowedByCIDR(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) - client := &http.Client{Transport: WrapTransport(&http.Transport{})} + client := &http.Client{Transport: NewTransport(nil)} resp, err := client.Get(srv.URL) require.NoError(t, err) require.NoError(t, resp.Body.Close()) @@ -398,7 +403,7 @@ func TestDialContext_AllowedByExactHostname(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { globalDialer.Store(nil) }) - client := &http.Client{Transport: WrapTransport(&http.Transport{})} + client := &http.Client{Transport: NewTransport(nil)} resp, err := client.Get("http://localhost:" + portStr) require.NoError(t, err) require.NoError(t, resp.Body.Close()) diff --git a/pkg/networking/diagnostics.go b/pkg/networking/diagnostics.go index 84491c8758..da1be7fdcc 100644 --- a/pkg/networking/diagnostics.go +++ b/pkg/networking/diagnostics.go @@ -10,6 +10,7 @@ import ( "github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/logs" + "github.com/portainer/portainer/pkg/libhttp/ssrf" "github.com/rs/zerolog/log" "github.com/segmentio/encoding/json" @@ -73,10 +74,8 @@ func ProbeTelnetConnection(url string) string { // ignores errors for the http request since we want to know if the host is reachable func DetectProxy(url string) string { client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: crypto.CreateTLSConfiguration(true), - }, - Timeout: 10 * time.Second, + Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(true)), + Timeout: 10 * time.Second, } result := map[string]string{ diff --git a/pkg/registryhttp/client.go b/pkg/registryhttp/client.go index cedab57b11..0c13710c46 100644 --- a/pkg/registryhttp/client.go +++ b/pkg/registryhttp/client.go @@ -16,8 +16,7 @@ import ( func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) { switch registry.Type { case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry: - base := http.DefaultTransport.(*http.Transport).Clone() - return &http.Client{Transport: retry.NewTransport(ssrf.WrapTransport(base))}, false, nil + return &http.Client{Transport: retry.NewTransport(ssrf.NewTransport(nil))}, false, nil default: // For all other registry types, use shared helper to build transport and scheme diff --git a/pkg/registryhttp/transport.go b/pkg/registryhttp/transport.go index 9a3255ac66..8ce68e4e10 100644 --- a/pkg/registryhttp/transport.go +++ b/pkg/registryhttp/transport.go @@ -6,28 +6,27 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/pkg/fips" + "github.com/portainer/portainer/pkg/libhttp/ssrf" ) -// BuildTransportAndSchemeFromTLSConfig returns a base HTTP transport configured -// with ProxyFromEnvironment and, when needed, a TLSClientConfig derived from the -// provided TLS settings. It also returns the scheme ("http" or "https") that -// should be used to contact the registry based on the TLS settings. +// BuildTransportAndSchemeFromTLSConfig returns an SSRF-protected HTTP transport and the +// scheme ("http" or "https") to use when contacting the registry. The transport is based on +// the TLS settings from tlsCfg; pass a zero-value TLSConfiguration for plaintext. func BuildTransportAndSchemeFromTLSConfig(tlsCfg portainer.TLSConfiguration) (*http.Transport, string, error) { - baseTransport := http.DefaultTransport.(*http.Transport).Clone() - baseTransport.Proxy = http.ProxyFromEnvironment - tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsCfg) if err != nil { return nil, "", err } - baseTransport.TLSClientConfig = tlsConfig - if tlsConfig == nil && fips.FIPSMode() { return nil, "", fips.ErrTLSRequired - } else if tlsConfig == nil { - return baseTransport, "http", nil } - return baseTransport, "https", nil + transport := ssrf.NewTransport(tlsConfig) + + if tlsConfig == nil { + return transport, "http", nil + } + + return transport, "https", nil }