mirror of
https://github.com/portainer/portainer.git
synced 2026-06-23 04:40:13 +00:00
feat(ssrf): add missing transport wrappings and more checks BE-13021 (#2968)
This commit is contained in:
@@ -108,6 +108,9 @@ linters:
|
|||||||
linters:
|
linters:
|
||||||
- gocritic
|
- gocritic
|
||||||
text: ruleguard
|
text: ruleguard
|
||||||
|
- path: pkg/libhttp/ssrf/builder\.go
|
||||||
|
linters:
|
||||||
|
- forbidigo
|
||||||
paths:
|
paths:
|
||||||
- third_party$
|
- third_party$
|
||||||
- builtin$
|
- builtin$
|
||||||
|
|||||||
+58
-12
@@ -4,26 +4,72 @@ package gorules
|
|||||||
|
|
||||||
import "github.com/quasilyte/go-ruleguard/dsl"
|
import "github.com/quasilyte/go-ruleguard/dsl"
|
||||||
|
|
||||||
// unwrappedHTTPTransport flags http.Transport composite literals that are not
|
// unwrappedHTTPTransport flags any bare http.Transport composite literal.
|
||||||
// the direct argument to ssrf.WrapTransport.
|
// All transports must be created via ssrf.NewTransport or ssrf.NewInternalTransport,
|
||||||
|
// which clone http.DefaultTransport and handle SSRF protection internally.
|
||||||
func unwrappedHTTPTransport(m dsl.Matcher) {
|
func unwrappedHTTPTransport(m dsl.Matcher) {
|
||||||
// Inline construction passed to a function call.
|
|
||||||
m.Match(`$f(&http.Transport{$*_})`).
|
m.Match(`$f(&http.Transport{$*_})`).
|
||||||
Where(m["f"].Text != "ssrf.WrapTransport" && m["f"].Text != "WrapTransport" &&
|
Report(`$f receives a bare *http.Transport; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
|
||||||
m["f"].Text != "ssrf.WrapTransportInternal" && m["f"].Text != "WrapTransportInternal").
|
|
||||||
Report(`$f receives a bare *http.Transport; wrap with ssrf.WrapTransport() to enforce the SSRF protection policy`)
|
|
||||||
|
|
||||||
// Variable assigned a bare transport (cannot be tracked to a later WrapTransport call).
|
|
||||||
m.Match(`$_ := &http.Transport{$*_}`).
|
m.Match(`$_ := &http.Transport{$*_}`).
|
||||||
Report(`bare *http.Transport variable; use ssrf.WrapTransport(&http.Transport{...}) inline instead`)
|
Report(`bare *http.Transport variable; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
|
||||||
|
|
||||||
|
m.Match(`$_.Transport = &http.Transport{$*_}`).
|
||||||
|
Report(`bare *http.Transport field assignment; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
|
||||||
}
|
}
|
||||||
|
|
||||||
// internalTransportMisuse flags calls to WrapTransportInternal outside the four proxy
|
// helmGetterTransport flags getter.WithTransport calls that receive a bare *http.Transport.
|
||||||
|
// Helm v4 installs its own transport and bypasses http.DefaultTransport, so the transport
|
||||||
|
// passed here must be created via ssrf.NewTransport.
|
||||||
|
func helmGetterTransport(m dsl.Matcher) {
|
||||||
|
m.Match(`getter.WithTransport(&http.Transport{$*_})`).
|
||||||
|
Report(`getter.WithTransport called with a bare *http.Transport; use ssrf.NewTransport(tlsConfig) as Helm v4 bypasses http.DefaultTransport`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneDefaultTransport flags direct clones of *http.Transport outside main.go.
|
||||||
|
// The one legitimate clone is in main.go where http.DefaultTransport is globally
|
||||||
|
// wrapped with SSRF protection at server startup.
|
||||||
|
func cloneDefaultTransport(m dsl.Matcher) {
|
||||||
|
m.Match(`$_.(*http.Transport).Clone()`).
|
||||||
|
Where(!m.File().Name.Matches(`^main\.go$`)).
|
||||||
|
Report(`cloning *http.Transport directly is forbidden; use ssrf.NewTransport(tlsConfig) or ssrf.NewInternalTransport(tlsConfig) instead`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// internalTransportMisuse flags calls to NewInternalTransport outside the proxy
|
||||||
// factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions.
|
// factory files where Chisel-tunnel and in-cluster K8s destinations are valid exemptions.
|
||||||
func internalTransportMisuse(m dsl.Matcher) {
|
func internalTransportMisuse(m dsl.Matcher) {
|
||||||
m.Match(`ssrf.WrapTransportInternal($*_)`).
|
m.Match(`ssrf.NewInternalTransport($*_)`).
|
||||||
Where(
|
Where(
|
||||||
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||||
m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport)\.go$`))).
|
m.File().Name.Matches(`^(docker|agent|local_transport|edge_transport|docker_unix|docker_windows)\.go$`))).
|
||||||
Report(`WrapTransportInternal bypasses SSRF validation; only valid in the kubernetes local/edge transport constructors and the docker/agent proxy factories`)
|
Report(`NewInternalTransport bypasses SSRF validation; only valid in the proxy factory files for local sockets and internally-routed endpoints`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialerOverride flags direct assignments to any of the dialer fields on a transport.
|
||||||
|
// The only valid assignments are in docker_unix.go and docker_windows.go where a
|
||||||
|
// custom dialer is required for unix sockets and named pipes.
|
||||||
|
func dialerOverride(m dsl.Matcher) {
|
||||||
|
m.Match(`$_.DialContext = $*_`).
|
||||||
|
Where(
|
||||||
|
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||||
|
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
|
||||||
|
Report(`direct DialContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
|
||||||
|
|
||||||
|
m.Match(`$_.Dial = $*_`).
|
||||||
|
Where(
|
||||||
|
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||||
|
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
|
||||||
|
Report(`direct Dial assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
|
||||||
|
|
||||||
|
m.Match(`$_.DialTLSContext = $*_`).
|
||||||
|
Where(
|
||||||
|
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||||
|
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
|
||||||
|
Report(`direct DialTLSContext assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
|
||||||
|
|
||||||
|
m.Match(`$_.DialTLS = $*_`).
|
||||||
|
Where(
|
||||||
|
!(m.File().PkgPath.Matches(`proxy/factory`) &&
|
||||||
|
m.File().Name.Matches(`^(docker_unix|docker_windows)\.go$`))).
|
||||||
|
Report(`direct DialTLS assignment replaces the transport dialer; use ssrf.NewTransport or ssrf.NewInternalTransport instead`)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
|
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
"github.com/portainer/portainer/api/url"
|
"github.com/portainer/portainer/api/url"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
@@ -19,10 +21,14 @@ import (
|
|||||||
//
|
//
|
||||||
// it sends a ping to the agent and parses the version and platform from the headers
|
// it sends a ping to the agent and parses the version and platform from the headers
|
||||||
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo
|
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo
|
||||||
|
if err := ssrf.CheckURL(context.Background(), endpointUrl); err != nil {
|
||||||
|
return 0, "", err
|
||||||
|
}
|
||||||
|
|
||||||
httpCli := &http.Client{Timeout: 3 * time.Second}
|
httpCli := &http.Client{Timeout: 3 * time.Second}
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
httpCli.Transport = &http.Transport{TLSClientConfig: tlsConfig}
|
httpCli.Transport = ssrf.NewTransport(tlsConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedURL, err := url.ParseURL(endpointUrl + "/ping")
|
parsedURL, err := url.ParseURL(endpointUrl + "/ping")
|
||||||
|
|||||||
@@ -58,7 +58,10 @@ import (
|
|||||||
libswarm "github.com/portainer/portainer/pkg/libstack/swarm"
|
libswarm "github.com/portainer/portainer/pkg/libstack/swarm"
|
||||||
"github.com/portainer/portainer/pkg/validate"
|
"github.com/portainer/portainer/pkg/validate"
|
||||||
|
|
||||||
|
gogitclient "github.com/go-git/go-git/v5/plumbing/transport/client"
|
||||||
|
gogitraw "github.com/go-git/go-git/v5/plumbing/transport/git"
|
||||||
gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http"
|
gogithttp "github.com/go-git/go-git/v5/plumbing/transport/http"
|
||||||
|
gogitssh "github.com/go-git/go-git/v5/plumbing/transport/ssh"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
@@ -408,11 +411,14 @@ func buildServer(flags *portainer.CLIFlags, shutdownCtx context.Context, shutdow
|
|||||||
log.Fatal().Err(err).Msg("failed initializing ssrf service")
|
log.Fatal().Err(err).Msg("failed initializing ssrf service")
|
||||||
}
|
}
|
||||||
|
|
||||||
if dt, ok := nethttp.DefaultTransport.(*nethttp.Transport); ok {
|
if !ssrf.WrapDefaultTransport() {
|
||||||
nethttp.DefaultTransport = ssrf.WrapTransport(dt)
|
log.Fatal().Msg("failed to wrap default HTTP transport with SSRF protection")
|
||||||
}
|
}
|
||||||
|
|
||||||
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
|
gogithttp.DefaultClient = gogithttp.NewClient(&nethttp.Client{Transport: nethttp.DefaultTransport})
|
||||||
|
gogitclient.InstallProtocol("git", git.NewSSRFGitTransport(gogitraw.DefaultClient))
|
||||||
|
gogitclient.InstallProtocol("ssh", git.NewSSRFGitTransport(gogitssh.DefaultClient))
|
||||||
|
gogitclient.InstallProtocol("file", nil)
|
||||||
|
|
||||||
instanceID, err := dataStore.Version().InstanceID()
|
instanceID, err := dataStore.Version().InstanceID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -194,11 +194,11 @@ func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Cli
|
|||||||
}
|
}
|
||||||
|
|
||||||
transport = &NodeNameTransport{
|
transport = &NodeNameTransport{
|
||||||
Transport: ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig}),
|
Transport: ssrf.NewTransport(tlsConfig),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
transport = &NodeNameTransport{
|
transport = &NodeNameTransport{
|
||||||
Transport: ssrf.WrapTransport(&http.Transport{}),
|
Transport: ssrf.NewTransport(nil),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+2
-5
@@ -66,11 +66,8 @@ func NewAzureClient() *azureClient {
|
|||||||
|
|
||||||
func newHttpClientForAzure(insecureSkipVerify bool) *http.Client {
|
func newHttpClientForAzure(insecureSkipVerify bool) *http.Client {
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: ssrf.WrapTransport(&http.Transport{
|
Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(insecureSkipVerify)),
|
||||||
TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify),
|
Timeout: 300 * time.Second,
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
}),
|
|
||||||
Timeout: 300 * time.Second,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package git
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
|
|
||||||
|
gittransport "github.com/go-git/go-git/v5/plumbing/transport"
|
||||||
|
)
|
||||||
|
|
||||||
|
const gitDefaultPort = 9418
|
||||||
|
|
||||||
|
// ssrfGitTransport wraps a git:// transport and validates the resolved IP
|
||||||
|
// against the SSRF policy before establishing connections.
|
||||||
|
type ssrfGitTransport struct {
|
||||||
|
inner gittransport.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSSRFGitTransport wraps inner and blocks connections to private IP ranges
|
||||||
|
// according to the active SSRF policy.
|
||||||
|
func NewSSRFGitTransport(inner gittransport.Transport) gittransport.Transport {
|
||||||
|
return &ssrfGitTransport{inner: inner}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ssrfGitTransport) NewUploadPackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.UploadPackSession, error) {
|
||||||
|
if err := checkEndpointSSRF(ep); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.inner.NewUploadPackSession(ep, auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ssrfGitTransport) NewReceivePackSession(ep *gittransport.Endpoint, auth gittransport.AuthMethod) (gittransport.ReceivePackSession, error) {
|
||||||
|
if err := checkEndpointSSRF(ep); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.inner.NewReceivePackSession(ep, auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkEndpointSSRF(ep *gittransport.Endpoint) error {
|
||||||
|
port := ep.Port
|
||||||
|
if port <= 0 {
|
||||||
|
port = gitDefaultPort
|
||||||
|
}
|
||||||
|
|
||||||
|
rawURL := fmt.Sprintf("git://%s/", net.JoinHostPort(ep.Host, strconv.Itoa(port)))
|
||||||
|
|
||||||
|
return ssrf.CheckURL(context.Background(), rawURL)
|
||||||
|
}
|
||||||
@@ -125,9 +125,9 @@ func ExecutePingOperation(host string, tlsConfiguration portainer.TLSConfigurati
|
|||||||
}
|
}
|
||||||
|
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
transport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
|
transport = ssrf.NewTransport(tlsConfig)
|
||||||
} else {
|
} else {
|
||||||
transport = ssrf.WrapTransport(&http.Transport{})
|
transport = ssrf.NewTransport(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
"github.com/portainer/portainer/pkg/validate"
|
"github.com/portainer/portainer/pkg/validate"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@@ -315,6 +316,10 @@ func (handler *Handler) createCustomTemplateFromGitRepository(r *http.Request) (
|
|||||||
return nil, httpErr
|
return nil, httpErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := ssrf.CheckURL(r.Context(), gitConfig.URL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
commitHash, err := stackutils.DownloadGitRepository(context.TODO(), gitConfig, handler.GitService, getProjectPath)
|
commitHash, err := stackutils.DownloadGitRepository(context.TODO(), gitConfig, handler.GitService, getProjectPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/portainer/portainer/api/stacks/stackutils"
|
"github.com/portainer/portainer/api/stacks/stackutils"
|
||||||
"github.com/portainer/portainer/pkg/edge"
|
"github.com/portainer/portainer/pkg/edge"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
"github.com/portainer/portainer/pkg/validate"
|
"github.com/portainer/portainer/pkg/validate"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -69,7 +70,6 @@ func (payload *edgeStackFromGitRepositoryPayload) Validate(r *http.Request) erro
|
|||||||
if len(payload.RepositoryURL) == 0 || !validate.IsURL(payload.RepositoryURL) {
|
if len(payload.RepositoryURL) == 0 || !validate.IsURL(payload.RepositoryURL) {
|
||||||
return httperrors.NewInvalidPayloadError("Invalid repository URL. Must correspond to a valid URL format")
|
return httperrors.NewInvalidPayloadError("Invalid repository URL. Must correspond to a valid URL format")
|
||||||
}
|
}
|
||||||
|
|
||||||
if payload.RepositoryAuthentication && len(payload.RepositoryPassword) == 0 {
|
if payload.RepositoryAuthentication && len(payload.RepositoryPassword) == 0 {
|
||||||
return httperrors.NewInvalidPayloadError("Invalid repository credentials. Password must be specified when authentication is enabled")
|
return httperrors.NewInvalidPayloadError("Invalid repository credentials. Password must be specified when authentication is enabled")
|
||||||
}
|
}
|
||||||
@@ -138,6 +138,10 @@ func (handler *Handler) createEdgeStackFromGitRepository(r *http.Request, tx dat
|
|||||||
return nil, httpErr
|
return nil, httpErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := ssrf.CheckURL(r.Context(), repoConfig.URL); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "repository URL blocked by SSRF policy")
|
||||||
|
}
|
||||||
|
|
||||||
stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID)
|
stack.CreatedByUserId = fmt.Sprintf("%d", tokenData.ID)
|
||||||
stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username)
|
stack.CreatedBy = stackutils.SanitizeLabel(tokenData.Username)
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
"github.com/portainer/portainer/pkg/validate"
|
"github.com/portainer/portainer/pkg/validate"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
@@ -100,6 +101,10 @@ func (handler *Handler) gitOperationRepoFilePreview(w http.ResponseWriter, r *ht
|
|||||||
tlsSkipVerify = src.Git.TLSSkipVerify
|
tlsSkipVerify = src.Git.TLSSkipVerify
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := ssrf.CheckURL(r.Context(), repoURL); err != nil {
|
||||||
|
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
|
||||||
|
}
|
||||||
|
|
||||||
projectPath, err := handler.fileService.GetTemporaryPath()
|
projectPath, err := handler.fileService.GetTemporaryPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httperror.InternalServerError("Unable to create temporary folder", err)
|
return httperror.InternalServerError("Unable to create temporary folder", err)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||||
|
"github.com/portainer/portainer/pkg/validate"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitAuthenticationPayload holds authentication parameters for a git source
|
// GitAuthenticationPayload holds authentication parameters for a git source
|
||||||
@@ -30,8 +31,8 @@ type GitSourceCreatePayload struct {
|
|||||||
|
|
||||||
// Validate implements the portainer.Validatable interface
|
// Validate implements the portainer.Validatable interface
|
||||||
func (payload *GitSourceCreatePayload) Validate(_ *http.Request) error {
|
func (payload *GitSourceCreatePayload) Validate(_ *http.Request) error {
|
||||||
if strings.TrimSpace(payload.URL) == "" {
|
if !validate.IsURL(payload.URL) {
|
||||||
return errors.New("url is required")
|
return errors.New("invalid repository URL. Must correspond to a valid URL format")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||||
|
"github.com/portainer/portainer/pkg/validate"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -35,6 +36,10 @@ type GitAuthenticationUpdatePayload struct {
|
|||||||
|
|
||||||
// Validate implements the portainer.Validatable interface
|
// Validate implements the portainer.Validatable interface
|
||||||
func (payload *GitSourceUpdatePayload) Validate(_ *http.Request) error {
|
func (payload *GitSourceUpdatePayload) Validate(_ *http.Request) error {
|
||||||
|
if payload.URL != nil && !validate.IsURL(*payload.URL) {
|
||||||
|
return errors.New("invalid repository URL. Must correspond to a valid URL format")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/portainer/portainer/pkg/libhelm/options"
|
"github.com/portainer/portainer/pkg/libhelm/options"
|
||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -45,6 +46,10 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) *
|
|||||||
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
|
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := ssrf.CheckURL(r.Context(), repo); err != nil {
|
||||||
|
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
|
||||||
|
}
|
||||||
|
|
||||||
searchOpts := options.SearchRepoOptions{
|
searchOpts := options.SearchRepoOptions{
|
||||||
Repo: repo,
|
Repo: repo,
|
||||||
Chart: chart,
|
Chart: chart,
|
||||||
@@ -53,7 +58,8 @@ func (handler *Handler) helmRepoSearch(w http.ResponseWriter, r *http.Request) *
|
|||||||
|
|
||||||
result, err := handler.helmPackageManager.SearchRepo(searchOpts)
|
result, err := handler.helmPackageManager.SearchRepo(searchOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httperror.InternalServerError("Search failed", err)
|
log.Warn().Err(err).Str("repo", repo).Msg("helm repo search failed")
|
||||||
|
return httperror.InternalServerError("Search failed", errors.New("failed to search Helm repository"))
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/portainer/portainer/pkg/libhelm/options"
|
"github.com/portainer/portainer/pkg/libhelm/options"
|
||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@@ -41,6 +42,10 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper
|
|||||||
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
|
return httperror.BadRequest("Bad request", errors.Wrap(err, fmt.Sprintf("provided URL %q is not valid", repo)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := ssrf.CheckURL(r.Context(), repo); err != nil {
|
||||||
|
return httperror.BadRequest("Repository URL blocked by SSRF policy", err)
|
||||||
|
}
|
||||||
|
|
||||||
chart := r.URL.Query().Get("chart")
|
chart := r.URL.Query().Get("chart")
|
||||||
if chart == "" {
|
if chart == "" {
|
||||||
return httperror.BadRequest("Bad request", errors.New("missing `chart` query parameter"))
|
return httperror.BadRequest("Bad request", errors.New("missing `chart` query parameter"))
|
||||||
@@ -65,7 +70,8 @@ func (handler *Handler) helmShow(w http.ResponseWriter, r *http.Request) *httper
|
|||||||
}
|
}
|
||||||
result, err := handler.helmPackageManager.Show(showOptions)
|
result, err := handler.helmPackageManager.Show(showOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httperror.InternalServerError("Unable to show chart", err)
|
log.Warn().Err(err).Str("repo", repo).Str("chart", chart).Msg("helm show failed")
|
||||||
|
return httperror.InternalServerError("Unable to show chart", errors.New("failed to retrieve Helm chart information"))
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
"github.com/portainer/portainer/pkg/validate"
|
"github.com/portainer/portainer/pkg/validate"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -74,6 +75,12 @@ func (payload *settingsUpdatePayload) Validate(r *http.Request) error {
|
|||||||
return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format")
|
return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if payload.HelmRepositoryURL != nil && *payload.HelmRepositoryURL != "" {
|
||||||
|
if err := ssrf.CheckURL(r.Context(), *payload.HelmRepositoryURL); err != nil {
|
||||||
|
return errors.New("Invalid Helm repository URL. Must correspond to a valid URL format")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if payload.UserSessionTimeout != nil {
|
if payload.UserSessionTimeout != nil {
|
||||||
if _, err := time.ParseDuration(*payload.UserSessionTimeout); err != nil {
|
if _, err := time.ParseDuration(*payload.UserSessionTimeout); err != nil {
|
||||||
return errors.New("Invalid user session timeout")
|
return errors.New("Invalid user session timeout")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
"github.com/portainer/portainer/pkg/validate"
|
"github.com/portainer/portainer/pkg/validate"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -125,6 +126,10 @@ func (payload *kubernetesManifestURLDeploymentPayload) Validate(r *http.Request)
|
|||||||
return errors.New("Invalid manifest URL")
|
return errors.New("Invalid manifest URL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := ssrf.CheckURL(r.Context(), payload.ManifestURL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
@@ -24,7 +25,11 @@ type addHelmRepoUrlPayload struct {
|
|||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *addHelmRepoUrlPayload) Validate(_ *http.Request) error {
|
func (p *addHelmRepoUrlPayload) Validate(r *http.Request) error {
|
||||||
|
if err := ssrf.CheckURL(r.Context(), p.URL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return libhelm.ValidateHelmRepositoryURL(p.URL, nil)
|
return libhelm.ValidateHelmRepositoryURL(p.URL, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -52,14 +52,14 @@ func (factory *ProxyFactory) NewAgentProxy(endpoint *portainer.Endpoint) (*Proxy
|
|||||||
endpointURL.Scheme = "https"
|
endpointURL.Scheme = "https"
|
||||||
|
|
||||||
if endpointutils.IsEdgeEndpoint(endpoint) {
|
if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
|
innerTransport = ssrf.NewInternalTransport(tlsConfig)
|
||||||
} else {
|
} else {
|
||||||
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
|
innerTransport = ssrf.NewTransport(tlsConfig)
|
||||||
}
|
}
|
||||||
} else if endpointutils.IsEdgeEndpoint(endpoint) {
|
} else if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
|
innerTransport = ssrf.NewInternalTransport(nil)
|
||||||
} else {
|
} else {
|
||||||
innerTransport = ssrf.WrapTransport(&http.Transport{})
|
innerTransport = ssrf.NewTransport(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL)
|
proxy := NewSingleHostReverseProxyWithHostHeader(endpointURL)
|
||||||
|
|||||||
@@ -68,14 +68,14 @@ func (factory *ProxyFactory) newDockerHTTPProxy(endpoint *portainer.Endpoint) (h
|
|||||||
endpointURL.Scheme = "https"
|
endpointURL.Scheme = "https"
|
||||||
|
|
||||||
if endpointutils.IsEdgeEndpoint(endpoint) {
|
if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{TLSClientConfig: tlsConfig})
|
innerTransport = ssrf.NewInternalTransport(tlsConfig)
|
||||||
} else {
|
} else {
|
||||||
innerTransport = ssrf.WrapTransport(&http.Transport{TLSClientConfig: tlsConfig})
|
innerTransport = ssrf.NewTransport(tlsConfig)
|
||||||
}
|
}
|
||||||
} else if endpointutils.IsEdgeEndpoint(endpoint) {
|
} else if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||||
innerTransport = ssrf.WrapTransportInternal(&http.Transport{})
|
innerTransport = ssrf.NewInternalTransport(nil)
|
||||||
} else {
|
} else {
|
||||||
innerTransport = ssrf.WrapTransport(&http.Transport{})
|
innerTransport = ssrf.NewTransport(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService)
|
dockerTransport, err := docker.NewTransport(transportParameters, innerTransport, factory.gitService, factory.snapshotService)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/portainer/portainer/api/internal/authorization"
|
"github.com/portainer/portainer/api/internal/authorization"
|
||||||
"github.com/portainer/portainer/api/logs"
|
"github.com/portainer/portainer/api/logs"
|
||||||
"github.com/portainer/portainer/api/slicesx"
|
"github.com/portainer/portainer/api/slicesx"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
|
|
||||||
"github.com/docker/docker/api/types/network"
|
"github.com/docker/docker/api/types/network"
|
||||||
"github.com/docker/docker/api/types/swarm"
|
"github.com/docker/docker/api/types/swarm"
|
||||||
@@ -506,6 +507,11 @@ func (transport *Transport) updateDefaultGitBranch(request *http.Request) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
repositoryURL := remote[:len(remote)-4]
|
repositoryURL := remote[:len(remote)-4]
|
||||||
|
|
||||||
|
if err := ssrf.CheckURL(request.Context(), repositoryURL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
latestCommitID, err := transport.gitService.LatestCommitID(
|
latestCommitID, err := transport.gitService.LatestCommitID(
|
||||||
request.Context(),
|
request.Context(),
|
||||||
repositoryURL,
|
repositoryURL,
|
||||||
|
|||||||
@@ -3,11 +3,13 @@
|
|||||||
package factory
|
package factory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
"github.com/portainer/portainer/api/http/proxy/factory/docker"
|
"github.com/portainer/portainer/api/http/proxy/factory/docker"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
|
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
|
||||||
@@ -31,9 +33,11 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newSocketTransport(socketPath string) *http.Transport {
|
func newSocketTransport(socketPath string) *http.Transport {
|
||||||
return &http.Transport{
|
d := &net.Dialer{}
|
||||||
Dial: func(proto, addr string) (conn net.Conn, err error) {
|
t := ssrf.NewInternalTransport(nil)
|
||||||
return net.Dial("unix", socketPath)
|
t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
},
|
return d.DialContext(ctx, "unix", socketPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return t
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,12 +3,14 @@
|
|||||||
package factory
|
package factory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/Microsoft/go-winio"
|
"github.com/Microsoft/go-winio"
|
||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
"github.com/portainer/portainer/api/http/proxy/factory/docker"
|
"github.com/portainer/portainer/api/http/proxy/factory/docker"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
|
func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portainer.Endpoint) (http.Handler, error) {
|
||||||
@@ -32,9 +34,10 @@ func (factory ProxyFactory) newOSBasedLocalProxy(path string, endpoint *portaine
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newNamedPipeTransport(namedPipePath string) *http.Transport {
|
func newNamedPipeTransport(namedPipePath string) *http.Transport {
|
||||||
return &http.Transport{
|
t := ssrf.NewInternalTransport(nil)
|
||||||
Dial: func(proto, addr string) (conn net.Conn, err error) {
|
t.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||||
return winio.DialPipe(namedPipePath, nil)
|
return winio.DialPipe(namedPipePath, nil)
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return t
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func NewHTTPClient(token string) *http.Client {
|
|||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: &tokenTransport{
|
Transport: &tokenTransport{
|
||||||
token: token,
|
token: token,
|
||||||
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
|
transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||||
},
|
},
|
||||||
Timeout: 1 * time.Minute,
|
Timeout: 1 * time.Minute,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ type Transport struct {
|
|||||||
// interface for proxying requests to the Gitlab API.
|
// interface for proxying requests to the Gitlab API.
|
||||||
func NewTransport() *Transport {
|
func NewTransport() *Transport {
|
||||||
return &Transport{
|
return &Transport{
|
||||||
httpTransport: ssrf.WrapTransport(&http.Transport{}),
|
httpTransport: ssrf.NewTransport(nil),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,7 +119,7 @@ func NewHTTPClient(token string) *http.Client {
|
|||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: &tokenTransport{
|
Transport: &tokenTransport{
|
||||||
token: token,
|
token: token,
|
||||||
transport: retry.NewTransport(ssrf.WrapTransport(&http.Transport{})), // Use ORAS retry transport for consistent rate limiting and error handling
|
transport: retry.NewTransport(ssrf.NewTransport(nil)), // Use ORAS retry transport for consistent rate limiting and error handling
|
||||||
},
|
},
|
||||||
Timeout: 1 * time.Minute,
|
Timeout: 1 * time.Minute,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,9 +25,7 @@ func NewAgentTransport(signatureService portainer.DigitalSignatureService, token
|
|||||||
|
|
||||||
transport := &agentTransport{
|
transport := &agentTransport{
|
||||||
baseTransport: newBaseTransport(
|
baseTransport: newBaseTransport(
|
||||||
ssrf.WrapTransport(&http.Transport{
|
ssrf.NewTransport(tlsConfig),
|
||||||
TLSClientConfig: tlsConfig,
|
|
||||||
}),
|
|
||||||
tokenManager,
|
tokenManager,
|
||||||
endpoint,
|
endpoint,
|
||||||
k8sClientFactory,
|
k8sClientFactory,
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func NewEdgeTransport(dataStore dataservices.DataStore, signatureService portain
|
|||||||
reverseTunnelService: reverseTunnelService,
|
reverseTunnelService: reverseTunnelService,
|
||||||
signatureService: signatureService,
|
signatureService: signatureService,
|
||||||
baseTransport: newBaseTransport(
|
baseTransport: newBaseTransport(
|
||||||
ssrf.WrapTransportInternal(&http.Transport{}),
|
ssrf.NewInternalTransport(nil),
|
||||||
tokenManager,
|
tokenManager,
|
||||||
endpoint,
|
endpoint,
|
||||||
k8sClientFactory,
|
k8sClientFactory,
|
||||||
|
|||||||
@@ -23,9 +23,7 @@ func NewLocalTransport(tokenManager *tokenManager, endpoint *portainer.Endpoint,
|
|||||||
|
|
||||||
transport := &localTransport{
|
transport := &localTransport{
|
||||||
baseTransport: newBaseTransport(
|
baseTransport: newBaseTransport(
|
||||||
ssrf.WrapTransportInternal(&http.Transport{
|
ssrf.NewInternalTransport(config),
|
||||||
TLSClientConfig: config,
|
|
||||||
}),
|
|
||||||
tokenManager,
|
tokenManager,
|
||||||
endpoint,
|
endpoint,
|
||||||
k8sClientFactory,
|
k8sClientFactory,
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package factory
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -65,8 +67,6 @@ func enableSSRF(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestNewDockerHTTPProxy_NonEdgeNoTLS verifies that a plain non-edge endpoint
|
|
||||||
// uses WrapTransport, setting DialContext on the inner transport.
|
|
||||||
func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
|
func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
|
||||||
enableSSRF(t)
|
enableSSRF(t)
|
||||||
|
|
||||||
@@ -84,8 +84,6 @@ func TestNewDockerHTTPProxy_NonEdgeNoTLS(t *testing.T) {
|
|||||||
require.NotNil(t, dt.HTTPTransport.DialContext)
|
require.NotNil(t, dt.HTTPTransport.DialContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestNewDockerHTTPProxy_NonEdgeTLS verifies that a TLS non-edge endpoint
|
|
||||||
// uses WrapTransport, setting DialContext on the inner transport.
|
|
||||||
func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
|
func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
|
||||||
enableSSRF(t)
|
enableSSRF(t)
|
||||||
|
|
||||||
@@ -107,11 +105,14 @@ func TestNewDockerHTTPProxy_NonEdgeTLS(t *testing.T) {
|
|||||||
require.NotNil(t, dt.HTTPTransport.DialContext)
|
require.NotNil(t, dt.HTTPTransport.DialContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestNewDockerHTTPProxy_EdgeNoTLS verifies that an edge endpoint without TLS
|
|
||||||
// uses WrapTransportInternal, leaving DialContext nil.
|
|
||||||
func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
|
func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
|
||||||
enableSSRF(t)
|
enableSSRF(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||||
endpoint := &portainer.Endpoint{
|
endpoint := &portainer.Endpoint{
|
||||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||||
@@ -123,14 +124,20 @@ func TestNewDockerHTTPProxy_EdgeNoTLS(t *testing.T) {
|
|||||||
|
|
||||||
proxy := handler.(*httputil.ReverseProxy)
|
proxy := handler.(*httputil.ReverseProxy)
|
||||||
dt := proxy.Transport.(*docker.Transport)
|
dt := proxy.Transport.(*docker.Transport)
|
||||||
require.Nil(t, dt.HTTPTransport.DialContext)
|
|
||||||
|
resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, resp.Body.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestNewDockerHTTPProxy_EdgeTLS verifies that an edge endpoint with TLS
|
|
||||||
// uses WrapTransportInternal, leaving DialContext nil.
|
|
||||||
func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
|
func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
|
||||||
enableSSRF(t)
|
enableSSRF(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
f := &ProxyFactory{reverseTunnelService: &stubTunnelService{}}
|
||||||
endpoint := &portainer.Endpoint{
|
endpoint := &portainer.Endpoint{
|
||||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||||
@@ -146,7 +153,10 @@ func TestNewDockerHTTPProxy_EdgeTLS(t *testing.T) {
|
|||||||
|
|
||||||
proxy := handler.(*httputil.ReverseProxy)
|
proxy := handler.(*httputil.ReverseProxy)
|
||||||
dt := proxy.Transport.(*docker.Transport)
|
dt := proxy.Transport.(*docker.Transport)
|
||||||
require.Nil(t, dt.HTTPTransport.DialContext)
|
|
||||||
|
resp, err := (&http.Client{Transport: dt.HTTPTransport}).Get(srv.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, resp.Body.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) {
|
func TestNewAgentProxy_NonEdgeNoTLS(t *testing.T) {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/portainer/portainer/api/scheduler"
|
"github.com/portainer/portainer/api/scheduler"
|
||||||
"github.com/portainer/portainer/api/stacks/deployments"
|
"github.com/portainer/portainer/api/stacks/deployments"
|
||||||
"github.com/portainer/portainer/api/stacks/stackutils"
|
"github.com/portainer/portainer/api/stacks/stackutils"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GitMethodStackBuilder struct {
|
type GitMethodStackBuilder struct {
|
||||||
@@ -78,6 +79,10 @@ func (b *GitMethodStackBuilder) prepare(ctx context.Context, payload *StackPaylo
|
|||||||
return b.fileService.GetStackProjectPath(stackFolder)
|
return b.fileService.GetStackProjectPath(stackFolder)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := ssrf.CheckURL(ctx, repoConfig.URL); err != nil {
|
||||||
|
return fmt.Errorf("repository URL blocked by SSRF policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
commitHash, err := stackutils.DownloadGitRepository(ctx, repoConfig, b.gitService, getProjectPath)
|
commitHash, err := stackutils.DownloadGitRepository(ctx, repoConfig, b.gitService, getProjectPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to download git repository: %w", err)
|
return fmt.Errorf("failed to download git repository: %w", err)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
"github.com/portainer/portainer/api/logs"
|
"github.com/portainer/portainer/api/logs"
|
||||||
"github.com/portainer/portainer/pkg/libhelm/options"
|
"github.com/portainer/portainer/pkg/libhelm/options"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
"github.com/portainer/portainer/pkg/liboras"
|
"github.com/portainer/portainer/pkg/liboras"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/segmentio/encoding/json"
|
"github.com/segmentio/encoding/json"
|
||||||
@@ -216,13 +217,15 @@ func downloadRepoIndexFromHttpRepo(repoURLString string, repoSettings *cli.EnvSe
|
|||||||
Str("repo_name", repoName).
|
Str("repo_name", repoName).
|
||||||
Msg("Creating chart repository object")
|
Msg("Creating chart repository object")
|
||||||
|
|
||||||
|
ssrfTransport := ssrf.NewTransport(nil)
|
||||||
|
|
||||||
// Create chart repository object
|
// Create chart repository object
|
||||||
rep, err := repo.NewChartRepository(
|
rep, err := repo.NewChartRepository(
|
||||||
&repo.Entry{
|
&repo.Entry{
|
||||||
Name: repoName,
|
Name: repoName,
|
||||||
URL: repoURLString,
|
URL: repoURLString,
|
||||||
},
|
},
|
||||||
getter.All(repoSettings),
|
getter.All(repoSettings, getter.WithTransport(ssrfTransport)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/portainer/portainer/pkg/libhelm/sdk"
|
"github.com/portainer/portainer/pkg/libhelm/sdk"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
"helm.sh/helm/v4/pkg/cli"
|
"helm.sh/helm/v4/pkg/cli"
|
||||||
"helm.sh/helm/v4/pkg/getter"
|
"helm.sh/helm/v4/pkg/getter"
|
||||||
repo "helm.sh/helm/v4/pkg/repo/v1"
|
repo "helm.sh/helm/v4/pkg/repo/v1"
|
||||||
@@ -40,12 +41,14 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error {
|
|||||||
return fmt.Errorf("failed to derive repo name: %w", err)
|
return fmt.Errorf("failed to derive repo name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ssrfTransport := ssrf.NewTransport(nil)
|
||||||
|
|
||||||
r, err := repo.NewChartRepository(
|
r, err := repo.NewChartRepository(
|
||||||
&repo.Entry{
|
&repo.Entry{
|
||||||
Name: repoName,
|
Name: repoName,
|
||||||
URL: repoUrl,
|
URL: repoUrl,
|
||||||
},
|
},
|
||||||
getter.All(settings),
|
getter.All(settings, getter.WithTransport(ssrfTransport)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err)
|
return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err)
|
||||||
@@ -53,7 +56,7 @@ func ValidateHelmRepositoryURL(repoUrl string, _ *http.Client) error {
|
|||||||
|
|
||||||
indexPath, err := r.DownloadIndexFile()
|
indexPath, err := r.DownloadIndexFile()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s is not a valid chart repository or cannot be reached: %w", repoUrl, err)
|
return fmt.Errorf("%s is not a valid chart repository or cannot be reached", repoUrl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Best-effort: load and seed in-memory cache for future SearchRepo calls
|
// Best-effort: load and seed in-memory cache for future SearchRepo calls
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package ssrf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewTransport creates an SSRF-protected transport for user-influenced destinations.
|
||||||
|
// It clones http.DefaultTransport as its base (inheriting pool and timeout defaults)
|
||||||
|
// and applies the global SSRF dial-context filter so mode changes take effect without
|
||||||
|
// restarting. tlsConfig may be nil to preserve standard TLS behavior (system CAs).
|
||||||
|
func NewTransport(tlsConfig *tls.Config) *http.Transport {
|
||||||
|
base := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
base.TLSClientConfig = tlsConfig
|
||||||
|
applySSRF(base)
|
||||||
|
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInternalTransport creates a plain transport for destinations chosen by Portainer,
|
||||||
|
// not by the user (Docker socket proxy, Chisel tunnels, in-cluster Kubernetes API).
|
||||||
|
// It clones http.DefaultTransport as its base. tlsConfig may be nil.
|
||||||
|
// Using this function instead of NewTransport makes the exemption explicit and
|
||||||
|
// satisfies the ruleguard lint rule.
|
||||||
|
func NewInternalTransport(tlsConfig *tls.Config) *http.Transport {
|
||||||
|
base := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
base.TLSClientConfig = tlsConfig
|
||||||
|
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapDefaultTransport replaces http.DefaultTransport with an SSRF-protected version.
|
||||||
|
// Must be called after Configure. Returns false if DefaultTransport is not an *http.Transport.
|
||||||
|
func WrapDefaultTransport() bool {
|
||||||
|
dt, ok := http.DefaultTransport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned := dt.Clone()
|
||||||
|
applySSRF(cloned)
|
||||||
|
http.DefaultTransport = cloned
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// applySSRF sets the SSRF-filtering DialContext on t when the global dialer is active.
|
||||||
|
func applySSRF(t *http.Transport) {
|
||||||
|
d := globalDialer.Load()
|
||||||
|
if d != nil {
|
||||||
|
t.DialContext = d.DialContext
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -112,31 +111,6 @@ func CheckURL(ctx context.Context, rawURL string) error {
|
|||||||
return d.checkHost(ctx, host)
|
return d.checkHost(ctx, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WrapTransport clones t and replaces its DialContext with the global SSRF-filtering
|
|
||||||
// dialer. The dialer checks the mode on every connection, so the transport is always
|
|
||||||
// wrapped and mode changes take effect without restarting.
|
|
||||||
func WrapTransport(t *http.Transport) *http.Transport {
|
|
||||||
d := globalDialer.Load()
|
|
||||||
if d == nil {
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
cloned := t.Clone()
|
|
||||||
cloned.DialContext = d.DialContext
|
|
||||||
|
|
||||||
return cloned
|
|
||||||
}
|
|
||||||
|
|
||||||
// WrapTransportInternal is a documented no-op for transports that connect to
|
|
||||||
// internally computed destinations (local Docker socket proxy, Chisel tunnels,
|
|
||||||
// in-cluster Kubernetes API). The destination is chosen by Portainer, not
|
|
||||||
// supplied by any user, so SSRF validation is not applicable. Using this
|
|
||||||
// function instead of WrapTransport makes the exemption explicit and
|
|
||||||
// satisfies the ruleguard lint rule.
|
|
||||||
func WrapTransportInternal(t *http.Transport) *http.Transport {
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext resolves addr, validates all resolved IPs against the allowlist policy,
|
// DialContext resolves addr, validates all resolved IPs against the allowlist policy,
|
||||||
// then dials using the first resolved IP to prevent DNS rebinding attacks.
|
// then dials using the first resolved IP to prevent DNS rebinding attacks.
|
||||||
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
|||||||
@@ -116,22 +116,27 @@ func TestConfigure_NilServicesReturnsError(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapTransport_NoPolicy(t *testing.T) {
|
func TestNewTransport_NoPolicy(t *testing.T) {
|
||||||
globalDialer.Store(nil)
|
globalDialer.Store(nil)
|
||||||
|
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||||
|
|
||||||
base := &http.Transport{}
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
result := WrapTransport(base)
|
w.WriteHeader(http.StatusOK)
|
||||||
require.Equal(t, base, result)
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &http.Client{Transport: NewTransport(nil)}
|
||||||
|
resp, err := client.Get(srv.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, resp.Body.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapTransport_WithPolicy(t *testing.T) {
|
func TestNewTransport_WithPolicy(t *testing.T) {
|
||||||
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
|
err := Configure(newStaticService(portainer.SSRFModeEnforce, []string{"example.com"}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||||
|
|
||||||
base := &http.Transport{}
|
result := NewTransport(nil)
|
||||||
result := WrapTransport(base)
|
|
||||||
require.NotEqual(t, base, result)
|
|
||||||
require.NotNil(t, result.DialContext)
|
require.NotNil(t, result.DialContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,12 +272,12 @@ func TestIsEnabled(t *testing.T) {
|
|||||||
require.False(t, IsEnabled())
|
require.False(t, IsEnabled())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapTransportInternal(t *testing.T) {
|
func TestNewInternalTransport(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
base := &http.Transport{}
|
result := NewInternalTransport(nil)
|
||||||
result := WrapTransportInternal(base)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, base, result)
|
require.Nil(t, result.TLSClientConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP
|
// TestDialContext_BlocksLoopback is an end-to-end test: it starts a real HTTP
|
||||||
@@ -288,7 +293,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||||
|
|
||||||
blocked := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
blocked := &http.Client{Transport: NewTransport(nil)}
|
||||||
resp, err := blocked.Get(srv.URL)
|
resp, err := blocked.Get(srv.URL)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "ssrf")
|
require.Contains(t, err.Error(), "ssrf")
|
||||||
@@ -300,7 +305,7 @@ func TestDialContext_BlocksLoopback(t *testing.T) {
|
|||||||
err = Configure(newStaticService(portainer.SSRFModeOff, nil))
|
err = Configure(newStaticService(portainer.SSRFModeOff, nil))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
open := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
open := &http.Client{Transport: NewTransport(nil)}
|
||||||
resp, err = open.Get(srv.URL)
|
resp, err = open.Get(srv.URL)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, resp.Body.Close())
|
require.NoError(t, resp.Body.Close())
|
||||||
@@ -318,7 +323,7 @@ func TestDialContext_AuditMode_AllowsLoopback(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||||
|
|
||||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
client := &http.Client{Transport: NewTransport(nil)}
|
||||||
resp, err := client.Get(srv.URL)
|
resp, err := client.Get(srv.URL)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, resp.Body.Close())
|
require.NoError(t, resp.Body.Close())
|
||||||
@@ -364,7 +369,7 @@ func TestDialContext_AllowedByCIDR(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||||
|
|
||||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
client := &http.Client{Transport: NewTransport(nil)}
|
||||||
resp, err := client.Get(srv.URL)
|
resp, err := client.Get(srv.URL)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, resp.Body.Close())
|
require.NoError(t, resp.Body.Close())
|
||||||
@@ -398,7 +403,7 @@ func TestDialContext_AllowedByExactHostname(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { globalDialer.Store(nil) })
|
t.Cleanup(func() { globalDialer.Store(nil) })
|
||||||
|
|
||||||
client := &http.Client{Transport: WrapTransport(&http.Transport{})}
|
client := &http.Client{Transport: NewTransport(nil)}
|
||||||
resp, err := client.Get("http://localhost:" + portStr)
|
resp, err := client.Get("http://localhost:" + portStr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, resp.Body.Close())
|
require.NoError(t, resp.Body.Close())
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/portainer/portainer/api/crypto"
|
"github.com/portainer/portainer/api/crypto"
|
||||||
"github.com/portainer/portainer/api/logs"
|
"github.com/portainer/portainer/api/logs"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/segmentio/encoding/json"
|
"github.com/segmentio/encoding/json"
|
||||||
@@ -73,10 +74,8 @@ func ProbeTelnetConnection(url string) string {
|
|||||||
// ignores errors for the http request since we want to know if the host is reachable
|
// ignores errors for the http request since we want to know if the host is reachable
|
||||||
func DetectProxy(url string) string {
|
func DetectProxy(url string) string {
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: ssrf.NewTransport(crypto.CreateTLSConfiguration(true)),
|
||||||
TLSClientConfig: crypto.CreateTLSConfiguration(true),
|
Timeout: 10 * time.Second,
|
||||||
},
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result := map[string]string{
|
result := map[string]string{
|
||||||
|
|||||||
@@ -16,8 +16,7 @@ import (
|
|||||||
func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) {
|
func CreateClient(registry *portainer.Registry) (httpClient *http.Client, usePlainHttp bool, err error) {
|
||||||
switch registry.Type {
|
switch registry.Type {
|
||||||
case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry:
|
case portainer.AzureRegistry, portainer.EcrRegistry, portainer.GithubRegistry, portainer.GitlabRegistry, portainer.DockerHubRegistry:
|
||||||
base := http.DefaultTransport.(*http.Transport).Clone()
|
return &http.Client{Transport: retry.NewTransport(ssrf.NewTransport(nil))}, false, nil
|
||||||
return &http.Client{Transport: retry.NewTransport(ssrf.WrapTransport(base))}, false, nil
|
|
||||||
default:
|
default:
|
||||||
// For all other registry types, use shared helper to build transport and scheme
|
// For all other registry types, use shared helper to build transport and scheme
|
||||||
|
|
||||||
|
|||||||
@@ -6,28 +6,27 @@ import (
|
|||||||
portainer "github.com/portainer/portainer/api"
|
portainer "github.com/portainer/portainer/api"
|
||||||
"github.com/portainer/portainer/api/crypto"
|
"github.com/portainer/portainer/api/crypto"
|
||||||
"github.com/portainer/portainer/pkg/fips"
|
"github.com/portainer/portainer/pkg/fips"
|
||||||
|
"github.com/portainer/portainer/pkg/libhttp/ssrf"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BuildTransportAndSchemeFromTLSConfig returns a base HTTP transport configured
|
// BuildTransportAndSchemeFromTLSConfig returns an SSRF-protected HTTP transport and the
|
||||||
// with ProxyFromEnvironment and, when needed, a TLSClientConfig derived from the
|
// scheme ("http" or "https") to use when contacting the registry. The transport is based on
|
||||||
// provided TLS settings. It also returns the scheme ("http" or "https") that
|
// the TLS settings from tlsCfg; pass a zero-value TLSConfiguration for plaintext.
|
||||||
// should be used to contact the registry based on the TLS settings.
|
|
||||||
func BuildTransportAndSchemeFromTLSConfig(tlsCfg portainer.TLSConfiguration) (*http.Transport, string, error) {
|
func BuildTransportAndSchemeFromTLSConfig(tlsCfg portainer.TLSConfiguration) (*http.Transport, string, error) {
|
||||||
baseTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
baseTransport.Proxy = http.ProxyFromEnvironment
|
|
||||||
|
|
||||||
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsCfg)
|
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
baseTransport.TLSClientConfig = tlsConfig
|
|
||||||
|
|
||||||
if tlsConfig == nil && fips.FIPSMode() {
|
if tlsConfig == nil && fips.FIPSMode() {
|
||||||
return nil, "", fips.ErrTLSRequired
|
return nil, "", fips.ErrTLSRequired
|
||||||
} else if tlsConfig == nil {
|
|
||||||
return baseTransport, "http", nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return baseTransport, "https", nil
|
transport := ssrf.NewTransport(tlsConfig)
|
||||||
|
|
||||||
|
if tlsConfig == nil {
|
||||||
|
return transport, "http", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return transport, "https", nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user