mirror of
https://github.com/portainer/portainer.git
synced 2026-06-23 04:50:12 +00:00
feat(ssrf): add missing transport wrappings and more checks BE-13021 (#2968)
This commit is contained in:
@@ -108,6 +108,9 @@ linters:
|
||||
linters:
|
||||
- gocritic
|
||||
text: ruleguard
|
||||
- path: pkg/libhttp/ssrf/builder\.go
|
||||
linters:
|
||||
- forbidigo
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
|
||||
+58
-12
@@ -4,26 +4,72 @@ package gorules
|
||||
|
||||
import "github.com/quasilyte/go-ruleguard/dsl"
|
||||
|
||||
// unwrappedHTTPTransport flags http.Transport composite literals that are not
|
||||
// the direct argument to ssrf.WrapTransport.
|
||||
// unwrappedHTTPTransport flags any bare http.Transport composite literal.
|
||||
// All transports must be created via ssrf.NewTransport or ssrf.NewInternalTransport,
|
||||
// which clone http.DefaultTransport and handle SSRF protection internally.
|
||||
func unwrappedHTTPTransport(m dsl.Matcher) {
|
||||
// Inline construction passed to a function call.
|
||||
m.Match(`$f(&http.Transport{$*_})`).
|
||||
Where(m["f"].Text != "ssrf.WrapTransport" && m["f"].Text != "WrapTransport" &&
|
||||
m["f"].Text != "ssrf.WrapTransportInternal" && m["f"].Text != "WrapTransportInternal").
|
||||
Report(`$f receives a bare *http.Transport; wrap with ssrf.WrapTransport() to enforce the SSRF protection policy`)
|
||||
Report(`$f receives a bare *http.Transport; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
|
||||
|
||||
// Variable assigned a bare transport (cannot be tracked to a later WrapTransport call).
|
||||
m.Match(`$_ := &http.Transport{$*_}`).
|
||||
Report(`bare *http.Transport variable; use ssrf.WrapTransport(&http.Transport{...}) inline instead`)
|
||||
Report(`bare *http.Transport variable; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
|
||||
|
||||
m.Match(`$_.Transport = &http.Transport{$*_}`).
|
||||
Report(`bare *http.Transport field assignment; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
|
||||
}
|
||||
|
||||
// internalTransportMisuse flags calls to WrapTransportInternal outside the four proxy
|
||||
// helmGetterTransport flags getter.WithTransport calls that receive a bare *http.Transport.
|
||||
// Helm v4 installs its own transport and bypasses http.DefaultTransport, so the transport
|
||||
// passed here must be created via ssrf.NewTransport.
|
||||
func helmGetterTransport(m dsl.Matcher) {
|
||||
m.Match(`getter.WithTransport(&http.Transport{$*_})`).
|
||||
Report(`getter.WithTransport called with a bare *http.Transport; use ssrf.NewTransport(tlsConfig) as Helm v4 bypasses http.DefaultTransport`)
|
||||
}
|
||||
|
||||
// cloneDefaultTransport flags direct clones of *http.Transport outside main.go.
|
||||
// The one legitimate clone is in main.go where http.DefaultTransport is globally
|
||||
// wrapped with SSRF protection at server startup.
|
||||
func cloneDefaultTransport(m dsl.Matcher) {
|
||||
m.Match(`$_.(*http.Transport).Clone()`).
|
||||
Where(!m.File().Name.Matches(`^main\.go$`)).
|
||||
Report(`cloning *http.Transport directly is forbidden; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
|
||||
}
|
||||
|
||||
// internalTransportMisuse flags calls to NewInternalTransport outside the proxy
|
||||
// factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions.
|
||||
func internalTransportMisuse(m dsl.Matcher) {
|
||||
m.Match(`ssrf.WrapTransportInternal($*_)`).
|
||||
m.Match(`ssrf.NewInternalTransport($*_)`).
|
||||
Where(
|
||||
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||
m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport)\.go$`))).
|
||||
Report(`WrapTransportInternal bypasses SSRF validation; only valid in the kubernetes local/edge transport constructors and the docker/agent proxy factories`)
|
||||
m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport|docker_unix|docker_windows)\.go$`))).
|
||||
Report(`NewInternalTransport bypasses SSRF validation; only valid in the proxy factory files for local sockets and internally-routed endpoints`)
|
||||
}
|
||||
|
||||
// dialerOverride flags direct assignments to any of the dialer fields on a transport.
|
||||
// The only valid assignments are in docker_unix.go and docker_windows.go where a
|
||||
// custom dialer is required for unix sockets and named pipes.
|
||||
func dialerOverride(m dsl.Matcher) {
|
||||
m.Match(`$_.DialContext = $*_`).
|
||||
Where(
|
||||
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
|
||||
Report(`direct DialContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
|
||||
|
||||
m.Match(`$_.Dial = $*_`).
|
||||
Where(
|
||||
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
|
||||
Report(`direct Dial assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
|
||||
|
||||
m.Match(`$_.DialTLSContext = $*_`).
|
||||
Where(
|
||||
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
|
||||
Report(`direct DialTLSContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
|
||||
|
||||
m.Match(`$_.DialTLS = $*_`).
|
||||
Where(
|
||||
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
|
||||
Report(`direct DialTLS assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/url"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -19,10 +21,14 @@ import (
|
||||
//
|
||||
// it sends a ping to the agent and parses the version and platform from the headers
|
||||
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo
|
||||
if err := ssrf.CheckURL(context.Background(), endpointUrl); err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
httpCli := &http.Client{Timeout: 3 * time.Second}
|
||||
|
||||
if tlsConfig != nil {
|
||||
httpCli.Transport = &http.Transport{TLSClientConfig: tlsConfig}
|
||||
httpCli.Transport = ssrf.NewTransport(tlsConfig)
|
||||
}
|
||||
|
||||
parsedURL, err := url.ParseURL(endpointUrl + "/ping")
|
||||
|
||||
@@ -58,7 +58,10 @@ import (
|
||||
libswarm "github.com/portainer/portainer/pkg/libstack/swarm"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
|
||||
gogitclient "github.com/go-git/go-git/v5/plumbing/transport/client"
|
||||
gogitraw "github.com/go-git/go-git/v5/plumbing/transport/git"
|
||||
gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http"
|
||||
gogitssh "github.com/go-git/go-git/v5/plumbing/transport/ssh"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -408,11 +411,14 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
|
||||
log.Fatal().Err(err).Msg("failed initializing ssrf service")
|
||||
}
|
||||
|
||||
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok {
|
||||
nethttp.DefaultTransport = ssrf.WrapTransport(dt)
|
||||
if !ssrf.WrapDefaultTransport() {
|
||||
log.Fatal().Msg("failed to wrap default HTTP transport with SSRF protection")
|
||||
}
|
||||
|
||||
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
|
||||
gogitclient.InstallProtocol("git", git.NewSSRFGitTransport(gogitraw.DefaultClient))
|
||||
gogitclient.InstallProtocol("ssh", git.NewSSRFGitTransport(gogitssh.DefaultClient))
|
||||
gogitclient.InstallProtocol("file", nil)
|
||||
|
||||
instanceID, err := dataStore.Version().InstanceID()
|
||||
if err != nil {
|
||||
|
||||
@@ -194,11 +194,11 @@ func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Cli
|
||||
}
|
||||
|
||||
transport = &NodeNameTransport{
|
||||
Transport: ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}),
|
||||
Transport: ssrf.NewTransport(tlsConfig),
|
||||
}
|
||||
} else {
|
||||
transport = &NodeNameTransport{
|
||||
Transport: ssrf.WrapTransport(&http.Transport{}),
|
||||
Transport: ssrf.NewTransport(nil),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+1
-4
@@ -66,10 +66,7 @@ func NewAzureClient() *azureClient {
|
||||
|
||||
func newHttpClientForAzure(insecureSkipVerify bool) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: ssrf.WrapTransport(&http.Transport{
|
||||
TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify),
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}),
|
||||
Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(insecureSkipVerify)),
|
||||
Timeout: 300 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
package git
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
gittransport "github.com/go-git/go-git/v5/plumbing/transport"
|
||||
)
|
||||
|
||||
const gitDefaultPort = 9418
|
||||
|
||||
// ssrfGitTransport wraps a git:// transport and validates the resolved IP
|
||||
// against the SSRF policy before establishing connections.
|
||||
type ssrfGitTransport struct {
|
||||
inner gittransport.Transport
|
||||
}
|
||||
|
||||
// NewSSRFGitTransport wraps inner and blocks connections to private IP ranges
|
||||
// according to the active SSRF policy.
|
||||
func NewSSRFGitTransport(inner gittransport.Transport) gittransport.Transport {
|
||||
return &ssrfGitTransport{inner: inner}
|
||||
}
|
||||
|
||||
func (t *ssrfGitTransport) NewUploadPackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.UploadPackSession, error) {
|
||||
if err := checkEndpointSSRF(ep); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return t.inner.NewUploadPackSession(ep, auth)
|
||||
}
|
||||
|
||||
func (t *ssrfGitTransport) NewReceivePackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.ReceivePackSession, error) {
|
||||
if err := checkEndpointSSRF(ep); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return t.inner.NewReceivePackSession(ep, auth)
|
||||
}
|
||||
|
||||
func checkEndpointSSRF(ep *gittransport.Endpoint) error {
|
||||
port := ep.Port
|
||||
if port <= 0 {
|
||||
port = gitDefaultPort
|
||||
}
|
||||
|
||||
rawURL := fmt.Sprintf("git://%s/", net.JoinHostPort(ep.Host, strconv.Itoa(port)))
|
||||
|
||||
return ssrf.CheckURL(context.Background(), rawURL)
|
||||
}
|
||||
@@ -125,9 +125,9 @@ func ExecutePingOperation(host string, tlsConfiguration portainer.TLSConfigurati
|
||||
}
|
||||
|
||||
scheme = "https"
|
||||
transport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
transport = ssrf.NewTransport(tlsConfig)
|
||||
} else {
|
||||
transport = ssrf.WrapTransport(&http.Transport{})
|
||||
transport = ssrf.NewTransport(nil)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -315,6 +316,10 @@ func (handler *Handler) createCustomTemplateFromGitRepository(r *http.Request) (
|
||||
return nil, httpErr
|
||||
}
|
||||
|
||||
if err := ssrf.CheckURL(r.Context(), gitConfig.URL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commitHash, err := stackutils.DownloadGitRepository(context.TODO(), gitConfig, handler.GitService, getProjectPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/portainer/portainer/api/stacks/stackutils"
|
||||
"github.com/portainer/portainer/pkg/edge"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -69,7 +70,6 @@ func (payload *edgeStackFromGitRepositoryPayload) Validate(r *http.Request) erro
|
||||
if len(payload.RepositoryURL) == 0 || !validate.IsURL(payload.RepositoryURL) {
|
||||
return httperrors.NewInvalidPayloadError("Invalid repository URL. Must correspond to a valid URL format")
|
||||
}
|
||||
|
||||
if payload.RepositoryAuthentication && len(payload.RepositoryPassword) == 0 {
|
||||
return httperrors.NewInvalidPayloadError("Invalid repository credentials. Password must be specified when authentication is enabled")
|
||||
}
|
||||
@@ -138,6 +138,10 @@ func (handler *Handler) createEdgeStackFromGitRepository(r *http.Request, tx dat
|
||||
return nil, httpErr
|
||||
}
|
||||
|
||||
if err := ssrf.CheckURL(r.Context(), repoConfig.URL); err != nil {
|
||||
return nil, errors.Wrap(err, "repository URL blocked by SSRF policy")
|
||||
}
|
||||
|
||||
stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID)
|
||||
stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -100,6 +101,10 @@ func (handler *Handler) gitOperationRepoFilePreview(w http.ResponseWriter, r *ht
|
||||
tlsSkipVerify = src.Git.TLSSkipVerify
|
||||
}
|
||||
|
||||
if err := ssrf.CheckURL(r.Context(), repoURL); err != nil {
|
||||
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
|
||||
}
|
||||
|
||||
projectPath, err := handler.fileService.GetTemporaryPath()
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to create temporary folder", err)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
)
|
||||
|
||||
// GitAuthenticationPayload holds authentication parameters for a git source
|
||||
@@ -30,8 +31,8 @@ type GitSourceCreatePayload struct {
|
||||
|
||||
// Validate implements the portainer.Validatable interface
|
||||
func (payload *GitSourceCreatePayload) Validate(_ *http.Request) error {
|
||||
if strings.TrimSpace(payload.URL) == "" {
|
||||
return errors.New("url is required")
|
||||
if !validate.IsURL(payload.URL) {
|
||||
return errors.New("invalid repository URL. Must correspond to a valid URL format")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -35,6 +36,10 @@ type GitAuthenticationUpdatePayload struct {
|
||||
|
||||
// Validate implements the portainer.Validatable interface
|
||||
func (payload *GitSourceUpdatePayload) Validate(_ *http.Request) error {
|
||||
if payload.URL != nil && !validate.IsURL(*payload.URL) {
|
||||
return errors.New("invalid repository URL. Must correspond to a valid URL format")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/portainer/portainer/pkg/libhelm/options"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -45,6 +46,10 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) *
|
||||
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
|
||||
}
|
||||
|
||||
if err := ssrf.CheckURL(r.Context(), repo); err != nil {
|
||||
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
|
||||
}
|
||||
|
||||
searchOpts := options.SearchRepoOptions{
|
||||
Repo: repo,
|
||||
Chart: chart,
|
||||
@@ -53,7 +58,8 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) *
|
||||
|
||||
result, err := handler.helmPackageManager.SearchRepo(searchOpts)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Search failed", err)
|
||||
log.Warn().Err(err).Str("repo", repo).Msg("helm repo search failed")
|
||||
return httperror.InternalServerError("Search failed", errors.New("failed to search Helm repository"))
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/portainer/portainer/pkg/libhelm/options"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -41,6 +42,10 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper
|
||||
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
|
||||
}
|
||||
|
||||
if err := ssrf.CheckURL(r.Context(), repo); err != nil {
|
||||
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
|
||||
}
|
||||
|
||||
chart := r.URL.Query().Get("chart")
|
||||
if chart == "" {
|
||||
return httperror.BadRequest("Bad request", errors.New("missing `chart` query parameter"))
|
||||
@@ -65,7 +70,8 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper
|
||||
}
|
||||
result, err := handler.helmPackageManager.Show(showOptions)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to show chart", err)
|
||||
log.Warn().Err(err).Str("repo", repo).Str("chart", chart).Msg("helm show failed")
|
||||
return httperror.InternalServerError("Unable to show chart", errors.New("failed to retrieve Helm chart information"))
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -74,6 +75,12 @@ func (payload *settingsUpdatePayload) Validate(r *http.Request) error {
|
||||
return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format")
|
||||
}
|
||||
|
||||
if payload.HelmRepositoryURL != nil && *payload.HelmRepositoryURL != "" {
|
||||
if err := ssrf.CheckURL(r.Context(), *payload.HelmRepositoryURL); err != nil {
|
||||
return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format")
|
||||
}
|
||||
}
|
||||
|
||||
if payload.UserSessionTimeout != nil {
|
||||
if _, err := time.ParseDuration(*payload.UserSessionTimeout); err != nil {
|
||||
return errors.New("Invalid user session timeout")
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -125,6 +126,10 @@ func (payload *kubernetesManifestURLDeploymentPayload) Validate(r *http.Request)
|
||||
return errors.New("Invalid manifest URL")
|
||||
}
|
||||
|
||||
if err := ssrf.CheckURL(r.Context(), payload.ManifestURL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
@@ -24,7 +25,11 @@ type addHelmRepoUrlPayload struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
func (p *addHelmRepoUrlPayload) Validate(_ *http.Request) error {
|
||||
func (p *addHelmRepoUrlPayload) Validate(r *http.Request) error {
|
||||
if err := ssrf.CheckURL(r.Context(), p.URL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return libhelm.ValidateHelmRepositoryURL(p.URL, nil)
|
||||
}
|
||||
|
||||
|
||||
@@ -52,14 +52,14 @@ func (factory *ProxyFactory) NewAgentProxy(endpoint *portainer.Endpoint) (*Proxy
|
||||
endpointURL.Scheme = "https"
|
||||
|
||||
if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
innerTransport = ssrf.NewInternalTransport(tlsConfig)
|
||||
} else {
|
||||
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
innerTransport = ssrf.NewTransport(tlsConfig)
|
||||
}
|
||||
} else if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
|
||||
innerTransport = ssrf.NewInternalTransport(nil)
|
||||
} else {
|
||||
innerTransport = ssrf.WrapTransport(&http.Transport{})
|
||||
innerTransport = ssrf.NewTransport(nil)
|
||||
}
|
||||
|
||||
proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL)
|
||||
|
||||
@@ -68,14 +68,14 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h
|
||||
endpointURL.Scheme = "https"
|
||||
|
||||
if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
innerTransport = ssrf.NewInternalTransport(tlsConfig)
|
||||
} else {
|
||||
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
|
||||
innerTransport = ssrf.NewTransport(tlsConfig)
|
||||
}
|
||||
} else if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
|
||||
innerTransport = ssrf.NewInternalTransport(nil)
|
||||
} else {
|
||||
innerTransport = ssrf.WrapTransport(&http.Transport{})
|
||||
innerTransport = ssrf.NewTransport(nil)
|
||||
}
|
||||
|
||||
dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService)
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
|
||||
"github.com/docker/docker/api/types/network"
|
||||
"github.com/docker/docker/api/types/swarm"
|
||||
@@ -506,6 +507,11 @@ func (transport *Transport) updateDefaultGitBranch(request *http.Request) error
|
||||
}
|
||||
|
||||
repositoryURL := remote[:len(remote)-4]
|
||||
|
||||
if err := ssrf.CheckURL(request.Context(), repositoryURL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
latestCommitID, err := transport.gitService.LatestCommitID(
|
||||
request.Context(),
|
||||
repositoryURL,
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
package factory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/docker"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
)
|
||||
|
||||
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
|
||||
@@ -31,9 +33,11 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine
|
||||
}
|
||||
|
||||
func newSocketTransport(socketPath string) *http.Transport {
|
||||
return &http.Transport{
|
||||
Dial: func(proto, addr string) (conn net.Conn, err error) {
|
||||
return net.Dial("unix", socketPath)
|
||||
},
|
||||
d := &net.Dialer{}
|
||||
t := ssrf.NewInternalTransport(nil)
|
||||
t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return d.DialContext(ctx, "unix", socketPath)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
package factory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/docker"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
)
|
||||
|
||||
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
|
||||
@@ -32,9 +34,10 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine
|
||||
}
|
||||
|
||||
func newNamedPipeTransport(namedPipePath string) *http.Transport {
|
||||
return &http.Transport{
|
||||
Dial: func(proto, addr string) (conn net.Conn, err error) {
|
||||
t := ssrf.NewInternalTransport(nil)
|
||||
t.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return winio.DialPipe(namedPipePath, nil)
|
||||
},
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ func NewHTTPClient(token string) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &tokenTransport{
|
||||
token: token,
|
||||
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||
transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||
},
|
||||
Timeout: 1 * time.Minute,
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ type Transport struct {
|
||||
// interface for proxying requests to the Gitlab API.
|
||||
func NewTransport() *Transport {
|
||||
return &Transport{
|
||||
httpTransport: ssrf.WrapTransport(&http.Transport{}),
|
||||
httpTransport: ssrf.NewTransport(nil),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ func NewHTTPClient(token string) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &tokenTransport{
|
||||
token: token,
|
||||
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||
transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||
},
|
||||
Timeout: 1 * time.Minute,
|
||||
}
|
||||
|
||||
@@ -25,9 +25,7 @@ func NewAgentTransport(signatureService portainer.DigitalSignatureService, token
|
||||
|
||||
transport := &agentTransport{
|
||||
baseTransport: newBaseTransport(
|
||||
ssrf.WrapTransport(&http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}),
|
||||
ssrf.NewTransport(tlsConfig),
|
||||
tokenManager,
|
||||
endpoint,
|
||||
k8sClientFactory,
|
||||
|
||||
@@ -22,7 +22,7 @@ func NewEdgeTransport(dataStore dataservices.DataStore, signatureService portain
|
||||
reverseTunnelService: reverseTunnelService,
|
||||
signatureService: signatureService,
|
||||
baseTransport: newBaseTransport(
|
||||
ssrf.WrapTransportInternal(&http.Transport{}),
|
||||
ssrf.NewInternalTransport(nil),
|
||||
tokenManager,
|
||||
endpoint,
|
||||
k8sClientFactory,
|
||||
|
||||
@@ -23,9 +23,7 @@ func NewLocalTransport(tokenManager *tokenManager, endpoint *portainer.Endpoint,
|
||||
|
||||
transport := &localTransport{
|
||||
baseTransport: newBaseTransport(
|
||||
ssrf.WrapTransportInternal(&http.Transport{
|
||||
TLSClientConfig: config,
|
||||
}),
|
||||
ssrf.NewInternalTransport(config),
|
||||
tokenManager,
|
||||
endpoint,
|
||||
k8sClientFactory,
|
||||
|
||||
@@ -2,6 +2,8 @@ package factory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -65,8 +67,6 @@ func enableSSRF(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint
|
||||
// uses WrapTransport, setting DialContext on the inner transport.
|
||||
func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
|
||||
enableSSRF(t)
|
||||
|
||||
@@ -84,8 +84,6 @@ func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
|
||||
require.NotNil(t, dt.HTTPTransport.DialContext)
|
||||
}
|
||||
|
||||
// TestNewDockerHTTPProxy_NonEdgeTLS verifies that a TLS non-edge endpoint
|
||||
// uses WrapTransport, setting DialContext on the inner transport.
|
||||
func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
|
||||
enableSSRF(t)
|
||||
|
||||
@@ -107,11 +105,14 @@ func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
|
||||
require.NotNil(t, dt.HTTPTransport.DialContext)
|
||||
}
|
||||
|
||||
// TestNewDockerHTTPProxy_EdgeNoTLS verifies that an edge endpoint without TLS
|
||||
// uses WrapTransportInternal, leaving DialContext nil.
|
||||
func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
|
||||
enableSSRF(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
@@ -123,14 +124,20 @@ func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
|
||||
|
||||
proxy := handler.(*httputil.ReverseProxy)
|
||||
dt := proxy.Transport.(*docker.Transport)
|
||||
require.Nil(t, dt.HTTPTransport.DialContext)
|
||||
|
||||
resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
// TestNewDockerHTTPProxy_EdgeTLS verifies that an edge endpoint with TLS
|
||||
// uses WrapTransportInternal, leaving DialContext nil.
|
||||
func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
|
||||
enableSSRF(t)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||
endpoint := &portainer.Endpoint{
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
@@ -146,7 +153,10 @@ func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
|
||||
|
||||
proxy := handler.(*httputil.ReverseProxy)
|
||||
dt := proxy.Transport.(*docker.Transport)
|
||||
require.Nil(t, dt.HTTPTransport.DialContext)
|
||||
|
||||
resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/portainer/portainer/api/scheduler"
|
||||
"github.com/portainer/portainer/api/stacks/deployments"
|
||||
"github.com/portainer/portainer/api/stacks/stackutils"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
)
|
||||
|
||||
type GitMethodStackBuilder struct {
|
||||
@@ -78,6 +79,10 @@ func (b *GitMethodStackBuilder) prepare(ctx context.Context, payload *StackPaylo
|
||||
return b.fileService.GetStackProjectPath(stackFolder)
|
||||
}
|
||||
|
||||
if err := ssrf.CheckURL(ctx, repoConfig.URL); err != nil {
|
||||
return fmt.Errorf("repository URL blocked by SSRF policy: %w", err)
|
||||
}
|
||||
|
||||
commitHash, err := stackutils.DownloadGitRepository(ctx, repoConfig, b.gitService, getProjectPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download git repository: %w", err)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/portainer/portainer/pkg/libhelm/options"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/portainer/portainer/pkg/liboras"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/segmentio/encoding/json"
|
||||
@@ -216,13 +217,15 @@ func downloadRepoIndexFromHttpRepo(repoURLString string, repoSettings *cli.EnvSe
|
||||
Str("repo_name", repoName).
|
||||
Msg("Creating chart repository object")
|
||||
|
||||
ssrfTransport := ssrf.NewTransport(nil)
|
||||
|
||||
// Create chart repository object
|
||||
rep, err := repo.NewChartRepository(
|
||||
&repo.Entry{
|
||||
Name: repoName,
|
||||
URL: repoURLString,
|
||||
},
|
||||
getter.All(repoSettings),
|
||||
getter.All(repoSettings, getter.WithTransport(ssrfTransport)),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/portainer/portainer/pkg/libhelm/sdk"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"helm.sh/helm/v4/pkg/cli"
|
||||
"helm.sh/helm/v4/pkg/getter"
|
||||
repo "helm.sh/helm/v4/pkg/repo/v1"
|
||||
@@ -40,12 +41,14 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error {
|
||||
return fmt.Errorf("failed to derive repo name: %w", err)
|
||||
}
|
||||
|
||||
ssrfTransport := ssrf.NewTransport(nil)
|
||||
|
||||
r, err := repo.NewChartRepository(
|
||||
&repo.Entry{
|
||||
Name: repoName,
|
||||
URL: repoUrl,
|
||||
},
|
||||
getter.All(settings),
|
||||
getter.All(settings, getter.WithTransport(ssrfTransport)),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err)
|
||||
@@ -53,7 +56,7 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error {
|
||||
|
||||
indexPath, err := r.DownloadIndexFile()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err)
|
||||
return fmt.Errorf("%s is not a valid chart repository or cannot be reached", repoUrl)
|
||||
}
|
||||
|
||||
// Best-effort: load and seed in-memory cache for future SearchRepo calls
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
package ssrf
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// NewTransport creates an SSRF-protected transport for user-influenced destinations.
|
||||
// It clones http.DefaultTransport as its base (inheriting pool and timeout defaults)
|
||||
// and applies the global SSRF dial-context filter so mode changes take effect without
|
||||
// restarting. tlsConfig may be nil to preserve standard TLS behavior (system CAs).
|
||||
func NewTransport(tlsConfig *tls.Config) *http.Transport {
|
||||
base := http.DefaultTransport.(*http.Transport).Clone()
|
||||
base.TLSClientConfig = tlsConfig
|
||||
applySSRF(base)
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// NewInternalTransport creates a plain transport for destinations chosen by Portainer,
|
||||
// not by the user (Docker socket proxy, Chisel tunnels, in-cluster Kubernetes API).
|
||||
// It clones http.DefaultTransport as its base. tlsConfig may be nil.
|
||||
// Using this function instead of NewTransport makes the exemption explicit and
|
||||
// satisfies the ruleguard lint rule.
|
||||
func NewInternalTransport(tlsConfig *tls.Config) *http.Transport {
|
||||
base := http.DefaultTransport.(*http.Transport).Clone()
|
||||
base.TLSClientConfig = tlsConfig
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// WrapDefaultTransport replaces http.DefaultTransport with an SSRF-protected version.
|
||||
// Must be called after Configure. Returns false if DefaultTransport is not an *http.Transport.
|
||||
func WrapDefaultTransport() bool {
|
||||
dt, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
cloned := dt.Clone()
|
||||
applySSRF(cloned)
|
||||
http.DefaultTransport = cloned
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// applySSRF sets the SSRF-filtering DialContext on t when the global dialer is active.
|
||||
func applySSRF(t *http.Transport) {
|
||||
d := globalDialer.Load()
|
||||
if d != nil {
|
||||
t.DialContext = d.DialContext
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -112,31 +111,6 @@ func CheckURL(ctx context.Context, rawURL string) error {
|
||||
return d.checkHost(ctx, host)
|
||||
}
|
||||
|
||||
// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering
|
||||
// dialer. The dialer checks the mode on every connection, so the transport is always
|
||||
// wrapped and mode changes take effect without restarting.
|
||||
func WrapTransport(t *http.Transport) *http.Transport {
|
||||
d := globalDialer.Load()
|
||||
if d == nil {
|
||||
return t
|
||||
}
|
||||
|
||||
cloned := t.Clone()
|
||||
cloned.DialContext = d.DialContext
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
// WrapTransportInternal is a documented no-op for transports that connect to
|
||||
// internally computed destinations (local Docker socket proxy, Chisel tunnels,
|
||||
// in-cluster Kubernetes API). The destination is chosen by Portainer, not
|
||||
// supplied by any user, so SSRF validation is not applicable. Using this
|
||||
// function instead of WrapTransport makes the exemption explicit and
|
||||
// satisfies the ruleguard lint rule.
|
||||
func WrapTransportInternal(t *http.Transport) *http.Transport {
|
||||
return t
|
||||
}
|
||||
|
||||
// DialContext resolves addr, validates all resolved IPs against the allowlist policy,
|
||||
// then dials using the first resolved IP to prevent DNS rebinding attacks.
|
||||
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
|
||||
@@ -116,22 +116,27 @@ func TestConfigure_NilServicesReturnsError(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWrapTransport_NoPolicy(t *testing.T) {
|
||||
func TestNewTransport_NoPolicy(t *testing.T) {
|
||||
globalDialer.Store(nil)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
base := &http.Transport{}
|
||||
result := WrapTransport(base)
|
||||
require.Equal(t, base, result)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &http.Client{Transport: NewTransport(nil)}
|
||||
resp, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
func TestWrapTransport_WithPolicy(t *testing.T) {
|
||||
func TestNewTransport_WithPolicy(t *testing.T) {
|
||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
base := &http.Transport{}
|
||||
result := WrapTransport(base)
|
||||
require.NotEqual(t, base, result)
|
||||
result := NewTransport(nil)
|
||||
require.NotNil(t, result.DialContext)
|
||||
}
|
||||
|
||||
@@ -267,12 +272,12 @@ func TestIsEnabled(t *testing.T) {
|
||||
require.False(t, IsEnabled())
|
||||
}
|
||||
|
||||
func TestWrapTransportInternal(t *testing.T) {
|
||||
func TestNewInternalTransport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := &http.Transport{}
|
||||
result := WrapTransportInternal(base)
|
||||
require.Equal(t, base, result)
|
||||
result := NewInternalTransport(nil)
|
||||
require.NotNil(t, result)
|
||||
require.Nil(t, result.TLSClientConfig)
|
||||
}
|
||||
|
||||
// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP
|
||||
@@ -288,7 +293,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
blocked := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
blocked := &http.Client{Transport: NewTransport(nil)}
|
||||
resp, err := blocked.Get(srv.URL)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "ssrf")
|
||||
@@ -300,7 +305,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
|
||||
err = Configure(newStaticService(portainer.SSRFModeOff, nil))
|
||||
require.NoError(t, err)
|
||||
|
||||
open := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
open := &http.Client{Transport: NewTransport(nil)}
|
||||
resp, err = open.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
@@ -318,7 +323,7 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
client := &http.Client{Transport: NewTransport(nil)}
|
||||
resp, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
@@ -364,7 +369,7 @@ func TestDialContext_AllowedByCIDR(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
client := &http.Client{Transport: NewTransport(nil)}
|
||||
resp, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
@@ -398,7 +403,7 @@ func TestDialContext_AllowedByExactHostname(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||
|
||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
||||
client := &http.Client{Transport: NewTransport(nil)}
|
||||
resp, err := client.Get("http://localhost:" + portStr)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/segmentio/encoding/json"
|
||||
@@ -73,9 +74,7 @@ func ProbeTelnetConnection(url string) string {
|
||||
// ignores errors for the http request since we want to know if the host is reachable
|
||||
func DetectProxy(url string) string {
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: crypto.CreateTLSConfiguration(true),
|
||||
},
|
||||
Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(true)),
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
|
||||
@@ -16,8 +16,7 @@ import (
|
||||
func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) {
|
||||
switch registry.Type {
|
||||
case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry:
|
||||
base := http.DefaultTransport.(*http.Transport).Clone()
|
||||
return &http.Client{Transport: retry.NewTransport(ssrf.WrapTransport(base))}, false, nil
|
||||
return &http.Client{Transport: retry.NewTransport(ssrf.NewTransport(nil))}, false, nil
|
||||
default:
|
||||
// For all other registry types, use shared helper to build transport and scheme
|
||||
|
||||
|
||||
@@ -6,28 +6,27 @@ import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||
)
|
||||
|
||||
// BuildTransportAndSchemeFromTLSConfig returns a base HTTP transport configured
|
||||
// with ProxyFromEnvironment and, when needed, a TLSClientConfig derived from the
|
||||
// provided TLS settings. It also returns the scheme ("http" or "https") that
|
||||
// should be used to contact the registry based on the TLS settings.
|
||||
// BuildTransportAndSchemeFromTLSConfig returns an SSRF-protected HTTP transport and the
|
||||
// scheme ("http" or "https") to use when contacting the registry. The transport is based on
|
||||
// the TLS settings from tlsCfg; pass a zero-value TLSConfiguration for plaintext.
|
||||
func BuildTransportAndSchemeFromTLSConfig(tlsCfg portainer.TLSConfiguration) (*http.Transport, string, error) {
|
||||
baseTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
baseTransport.Proxy = http.ProxyFromEnvironment
|
||||
|
||||
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsCfg)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
baseTransport.TLSClientConfig = tlsConfig
|
||||
|
||||
if tlsConfig == nil && fips.FIPSMode() {
|
||||
return nil, "", fips.ErrTLSRequired
|
||||
} else if tlsConfig == nil {
|
||||
return baseTransport, "http", nil
|
||||
}
|
||||
|
||||
return baseTransport, "https", nil
|
||||
transport := ssrf.NewTransport(tlsConfig)
|
||||
|
||||
if tlsConfig == nil {
|
||||
return transport, "http", nil
|
||||
}
|
||||
|
||||
return transport, "https", nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user