mirror of
https://github.com/portainer/portainer.git
synced 2026-06-23 04:30:16 +00:00
231 lines
5.8 KiB
Go
231 lines
5.8 KiB
Go
package ssrf
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
"sync/atomic"
|
|
|
|
portainer "github.com/portainer/portainer/api"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// ParseAllowedHosts parses raw allow list entries into their three canonical
|
|
// forms. Accepted formats: exact hostname, wildcard hostname (*.example.com),
|
|
// single IP, or CIDR range.
|
|
func ParseAllowedHosts(entries []string) portainer.ParsedAllowList {
|
|
nets := make([]*net.IPNet, 0, len(entries))
|
|
hosts := make(map[string]bool, len(entries))
|
|
var wilds []string
|
|
|
|
for _, entry := range entries {
|
|
if _, network, err := net.ParseCIDR(entry); err == nil {
|
|
nets = append(nets, network)
|
|
continue
|
|
}
|
|
|
|
if ip := net.ParseIP(entry); ip != nil {
|
|
bits := 32
|
|
if ip.To4() == nil {
|
|
bits = 128
|
|
}
|
|
|
|
mask := net.CIDRMask(bits, bits)
|
|
nets = append(nets, &net.IPNet{IP: ip.Mask(mask), Mask: mask})
|
|
|
|
continue
|
|
}
|
|
|
|
if strings.HasPrefix(entry, "*.") {
|
|
wilds = append(wilds, entry[1:]) // "*.foo.com" -> ".foo.com"
|
|
continue
|
|
}
|
|
|
|
hosts[entry] = true
|
|
}
|
|
|
|
return portainer.ParsedAllowList{Nets: nets, Hosts: hosts, Wilds: wilds}
|
|
}
|
|
|
|
// AllowListService is implemented by the allowlist data service.
|
|
// ReadParsed is called on every dial to pick up runtime changes.
|
|
type AllowListService interface {
|
|
ReadParsed(id portainer.AllowListKey) (*portainer.ParsedAllowList, error)
|
|
}
|
|
|
|
type safeDialer struct {
|
|
base net.Dialer
|
|
service AllowListService
|
|
}
|
|
|
|
var globalDialer atomic.Pointer[safeDialer]
|
|
|
|
// Configure initializes the global SSRF policy with the allow list data service.
|
|
func Configure(svc AllowListService) error {
|
|
if svc == nil {
|
|
return errors.New("unable to configure ssrf: service must not be nil")
|
|
}
|
|
|
|
globalDialer.Store(&safeDialer{service: svc})
|
|
return nil
|
|
}
|
|
|
|
// IsEnabled reports whether SSRF protection is currently active (audit or enforce).
|
|
func IsEnabled() bool {
|
|
d := globalDialer.Load()
|
|
if d == nil {
|
|
return false
|
|
}
|
|
|
|
allowList, err := d.service.ReadParsed(portainer.AllowListSSRF)
|
|
if err != nil {
|
|
log.Err(err).Msg("unable to check SSRF protection mode")
|
|
return false
|
|
}
|
|
|
|
return allowList.Mode != portainer.SSRFModeOff
|
|
}
|
|
|
|
// CheckURL validates rawURL against the active SSRF policy without making a
|
|
// connection. Returns nil when protection is disabled or the destination is
|
|
// permitted. In audit mode, logs a warning on violations and returns nil.
|
|
func CheckURL(ctx context.Context, rawURL string) error {
|
|
d := globalDialer.Load()
|
|
if d == nil {
|
|
return nil
|
|
}
|
|
|
|
u, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
return fmt.Errorf("ssrf: invalid URL %q: %w", rawURL, err)
|
|
}
|
|
|
|
host := u.Hostname()
|
|
if host == "" {
|
|
return nil
|
|
}
|
|
|
|
return d.checkHost(ctx, host)
|
|
}
|
|
|
|
// DialContext resolves addr, validates all resolved IPs against the allowlist policy,
|
|
// then dials using the first resolved IP to prevent DNS rebinding attacks.
|
|
func (d *safeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
allowList, err := d.service.ReadParsed(portainer.AllowListSSRF)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ssrf: reading allow list: %w", err)
|
|
}
|
|
|
|
if allowList.Mode == portainer.SSRFModeOff {
|
|
return d.base.DialContext(ctx, network, addr)
|
|
}
|
|
|
|
host, port, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err)
|
|
}
|
|
|
|
resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ssrf: resolving %q: %w", host, err)
|
|
}
|
|
|
|
if len(resolved) == 0 {
|
|
return nil, fmt.Errorf("ssrf: no addresses resolved for %q", host)
|
|
}
|
|
|
|
// Dial by resolved IP regardless of how the host was allowed to close the
|
|
// window between DNS validation and the TCP handshake (DNS rebinding).
|
|
dialTarget := net.JoinHostPort(resolved[0].IP.String(), port)
|
|
|
|
if allowList.Hosts[host] || matchesWildcard(host, allowList.Wilds) {
|
|
return d.base.DialContext(ctx, network, dialTarget)
|
|
}
|
|
|
|
for _, a := range resolved {
|
|
if !ipAllowed(a.IP, allowList.Nets) {
|
|
if allowList.Mode == portainer.SSRFModeAudit {
|
|
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
|
|
continue
|
|
}
|
|
|
|
return nil, fmt.Errorf("ssrf: destination %s is not in the allowlist", a.IP)
|
|
}
|
|
}
|
|
|
|
return d.base.DialContext(ctx, network, dialTarget)
|
|
}
|
|
|
|
func (d *safeDialer) checkHost(ctx context.Context, host string) error {
|
|
allowList, err := d.service.ReadParsed(portainer.AllowListSSRF)
|
|
if err != nil {
|
|
return fmt.Errorf("ssrf: reading allow list: %w", err)
|
|
}
|
|
|
|
if allowList.Mode == portainer.SSRFModeOff {
|
|
return nil
|
|
}
|
|
|
|
if allowList.Hosts[host] || matchesWildcard(host, allowList.Wilds) {
|
|
return nil
|
|
}
|
|
|
|
if ip := net.ParseIP(host); ip != nil {
|
|
if !ipAllowed(ip, allowList.Nets) {
|
|
if allowList.Mode == portainer.SSRFModeAudit {
|
|
log.Warn().Str("host", host).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("ssrf: destination %s is not in the allowlist", ip)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
|
if err != nil {
|
|
return fmt.Errorf("ssrf: resolving %q: %w", host, err)
|
|
}
|
|
|
|
if len(resolved) == 0 {
|
|
return fmt.Errorf("ssrf: no addresses resolved for %q", host)
|
|
}
|
|
|
|
for _, a := range resolved {
|
|
if !ipAllowed(a.IP, allowList.Nets) {
|
|
if allowList.Mode == portainer.SSRFModeAudit {
|
|
log.Warn().Str("host", host).Str("ip", a.IP.String()).Msg("ssrf: destination not in allowlist (audit mode, allowing)")
|
|
continue
|
|
}
|
|
|
|
return fmt.Errorf("ssrf: destination %s is not in the allowlist", a.IP)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func matchesWildcard(host string, wilds []string) bool {
|
|
for _, suffix := range wilds {
|
|
if strings.HasSuffix(host, suffix) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func ipAllowed(ip net.IP, nets []*net.IPNet) bool {
|
|
for _, network := range nets {
|
|
if network.Contains(ip) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|