TUN-10558: Bump go to v1.24.4, x/crypto to v0.52.0 and google.golang.org/grpc to v1.81.1

Closes TUN-10558
This commit is contained in:
João "Pisco" Fernandes
2026-06-08 19:15:35 +01:00
parent 52519f67e8
commit ccffef1179
52 changed files with 1792 additions and 5207 deletions
+1 -1
View File
@@ -6,7 +6,7 @@ RUN apt-get update && \
apt-get install --no-install-recommends --allow-downgrades -y \ apt-get install --no-install-recommends --allow-downgrades -y \
build-essential \ build-essential \
git \ git \
go-boring=1.26.3-1 \ go-boring=1.26.4-1 \
libffi-dev \ libffi-dev \
procps \ procps \
python3-dev \ python3-dev \
+1 -1
View File
@@ -5,7 +5,7 @@
runner: linux-x86-8cpu-16gb runner: linux-x86-8cpu-16gb
stage: build stage: build
golangVersion: "boring-1.26" golangVersion: "boring-1.26"
imageVersion: "3605-596a300@sha256:19fa512630b4c5681082c68fd98902e2f92092fc216412df44f7dda31cfa57c3" imageVersion: "3625-1801d52@sha256:9261597bc2d229c997522848260de758567643d58ae1097196ae368db89a1d0f"
CGO_ENABLED: 1 CGO_ENABLED: 1
.default-packaging-job: &packaging-job-defaults .default-packaging-job: &packaging-job-defaults
+1 -1
View File
@@ -8,7 +8,7 @@ include:
rules: rules:
- !reference [.default-rules, run-always] - !reference [.default-rules, run-always]
tags: tags:
- windows-x86 - canary-windows-x86
cache: {} cache: {}
########################################## ##########################################
+1 -1
View File
@@ -1,5 +1,5 @@
variables: variables:
GO_VERSION: "1.26.3" GO_VERSION: "1.26.4"
MAC_GO_VERSION: "go@$GO_VERSION" MAC_GO_VERSION: "go@$GO_VERSION"
WIN_GO_VERSION: "go$GO_VERSION" WIN_GO_VERSION: "go$GO_VERSION"
GIT_DEPTH: "0" GIT_DEPTH: "0"
+1 -1
View File
@@ -1,7 +1,7 @@
# use a builder image for building cloudflare # use a builder image for building cloudflare
ARG TARGET_GOOS ARG TARGET_GOOS
ARG TARGET_GOARCH ARG TARGET_GOARCH
FROM golang:1.26.3 AS builder FROM golang:1.26.4 AS builder
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=0 \ CGO_ENABLED=0 \
TARGET_GOOS=${TARGET_GOOS} \ TARGET_GOOS=${TARGET_GOOS} \
+1 -1
View File
@@ -1,5 +1,5 @@
# use a builder image for building cloudflare # use a builder image for building cloudflare
FROM golang:1.26.3 AS builder FROM golang:1.26.4 AS builder
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=0 \ CGO_ENABLED=0 \
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual # the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
+1 -1
View File
@@ -1,5 +1,5 @@
# use a builder image for building cloudflare # use a builder image for building cloudflare
FROM golang:1.26.3 AS builder FROM golang:1.26.4 AS builder
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=0 \ CGO_ENABLED=0 \
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual # the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
+5 -5
View File
@@ -36,7 +36,7 @@ require (
go.opentelemetry.io/proto/otlp v1.10.0 go.opentelemetry.io/proto/otlp v1.10.0
go.uber.org/automaxprocs v1.6.0 go.uber.org/automaxprocs v1.6.0
go.uber.org/mock v0.5.1 go.uber.org/mock v0.5.1
golang.org/x/crypto v0.51.0 golang.org/x/crypto v0.52.0
golang.org/x/net v0.55.0 golang.org/x/net v0.55.0
golang.org/x/sync v0.20.0 golang.org/x/sync v0.20.0
golang.org/x/sys v0.45.0 golang.org/x/sys v0.45.0
@@ -92,12 +92,12 @@ require (
go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/otel/metric v1.43.0 // indirect
golang.org/x/arch v0.4.0 // indirect golang.org/x/arch v0.4.0 // indirect
golang.org/x/mod v0.35.0 // indirect golang.org/x/mod v0.35.0 // indirect
golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/text v0.37.0 // indirect golang.org/x/text v0.37.0 // indirect
golang.org/x/tools v0.44.0 // indirect golang.org/x/tools v0.44.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
google.golang.org/grpc v1.79.2 // indirect google.golang.org/grpc v1.81.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
) )
+12 -12
View File
@@ -245,8 +245,8 @@ golang.org/x/arch v0.4.0 h1:A8WCeEWhLwPBKNbFi5Wv5UTCBx5zzubnXDlMOFAzFMc=
golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
@@ -255,8 +255,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
@@ -286,14 +286,14 @@ golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:JLQynH/LBHfCTSbDWl+py8C+Rg/k1OVH3xfcaiANuF0= google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 h1:tu/dtnW1o3wfaxCOjSLn5IRX4YDcJrtlpzYkhHhGaC4=
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY= google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171/go.mod h1:M5krXqk4GhBKvB596udGL3UyjL4I1+cTbK0orROM9ng=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 h1:mWPCjDEyshlQYzBpMNHaEof6UX1PmHcaUODUywQ0uac= google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU= google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ=
google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-11
View File
@@ -1,11 +0,0 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.25
package blake2b
import "hash"
var _ hash.XOF = (*xof)(nil)
+3 -3
View File
@@ -20,7 +20,7 @@ func chacha20Poly1305Open(dst []byte, key []uint32, src, ad []byte) bool
func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte) func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte)
var ( var (
useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI2 useAVX2 = cpu.X86.HasSSSE3 && cpu.X86.HasAVX2 && cpu.X86.HasBMI2
) )
// setupState writes a ChaCha20 input matrix to state. See // setupState writes a ChaCha20 input matrix to state. See
@@ -47,7 +47,7 @@ func setupState(state *[16]uint32, key *[32]byte, nonce []byte) {
} }
func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte { func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte {
if !cpu.X86.HasSSSE3 { if !useAVX2 {
return c.sealGeneric(dst, nonce, plaintext, additionalData) return c.sealGeneric(dst, nonce, plaintext, additionalData)
} }
@@ -66,7 +66,7 @@ func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []
} }
func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
if !cpu.X86.HasSSSE3 { if !useAVX2 {
return c.openGeneric(dst, nonce, ciphertext, additionalData) return c.openGeneric(dst, nonce, ciphertext, additionalData)
} }
File diff suppressed because it is too large Load Diff
+17 -1
View File
@@ -348,6 +348,9 @@ func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey)
if cert.CertType != HostCert { if cert.CertType != HostCert {
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
} }
if c.IsHostAuthority == nil {
return errors.New("ssh: cannot verify certificate, IsHostAuthority not set")
}
if !c.IsHostAuthority(cert.SignatureKey, addr) { if !c.IsHostAuthority(cert.SignatureKey, addr) {
return fmt.Errorf("ssh: no authorities for hostname: %v", addr) return fmt.Errorf("ssh: no authorities for hostname: %v", addr)
} }
@@ -375,6 +378,9 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
if cert.CertType != UserCert { if cert.CertType != UserCert {
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
} }
if c.IsUserAuthority == nil {
return nil, errors.New("ssh: cannot verify certificate, IsUserAuthority not set")
}
if !c.IsUserAuthority(cert.SignatureKey) { if !c.IsUserAuthority(cert.SignatureKey) {
return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority") return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
} }
@@ -438,7 +444,17 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) {
return fmt.Errorf("ssh: cert has expired") return fmt.Errorf("ssh: cert has expired")
} }
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { // Match OpenSSH: the SK user-presence flag is never enforced on a
// certificate's CA signature. OpenSSH calls sshkey_verify with
// detailsp==NULL in sshkey.c:cert_parse, so the UP/UV flags are
// not even extracted. The UP bit on a CA signature reflects the
// CA operator's presence at signing time, which has no bearing on
// whether the user being authenticated is present now; enforcing
// it here would only break interop with certificates issued by
// non-interactive SK CAs. skKeyWithoutUP is a no-op for non-SK
// keys (the common case).
caKey := skKeyWithoutUP(cert.SignatureKey)
if err := caKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil {
return fmt.Errorf("ssh: certificate signature does not verify") return fmt.Errorf("ssh: certificate signature does not verify")
} }
+59 -6
View File
@@ -11,6 +11,7 @@ import (
"io" "io"
"log" "log"
"sync" "sync"
"sync/atomic"
) )
const ( const (
@@ -131,11 +132,17 @@ func (r RejectionReason) String() string {
return fmt.Sprintf("unknown reason %d", int(r)) return fmt.Sprintf("unknown reason %d", int(r))
} }
func min(a uint32, b int) uint32 { // minPayloadSize returns min(limit, length) clamped to a uint32. It is used
if a < uint32(b) { // to compute the size of the next channel data packet from the remaining
return a // payload. The comparison is done in int64 because length is an int — on
// 64-bit systems len(data) can exceed 2^32, and a direct uint32(length)
// cast would silently truncate to 0 at every multiple of 2^32, causing
// WriteExtended's loop to spin without making progress.
func minPayloadSize(limit uint32, length int) uint32 {
if int64(length) > int64(limit) {
return limit
} }
return uint32(b) return uint32(length)
} }
type channelDirection uint8 type channelDirection uint8
@@ -177,6 +184,12 @@ type channel struct {
// with WantReply=true outstanding. This lock is held by a // with WantReply=true outstanding. This lock is held by a
// goroutine that has such an outgoing request pending. // goroutine that has such an outgoing request pending.
sentRequestMu sync.Mutex sentRequestMu sync.Mutex
// sentRequestPending is set to true while a SendRequest call with
// WantReply=true is in flight. handlePacket uses it as a gate: responses
// arriving while no request is pending are dropped to prevent a
// misbehaving peer from stalling the mux read loop by filling ch.msg
// with unsolicited channelRequestSuccess/Failure messages.
sentRequestPending atomic.Bool
incomingRequests chan *Request incomingRequests chan *Request
@@ -251,7 +264,7 @@ func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err e
ch.writeMu.Unlock() ch.writeMu.Unlock()
for len(data) > 0 { for len(data) > 0 {
space := min(ch.maxRemotePayload, len(data)) space := minPayloadSize(ch.maxRemotePayload, len(data))
if space, err = ch.remoteWin.reserve(space); err != nil { if space, err = ch.remoteWin.reserve(space); err != nil {
return n, err return n, err
} }
@@ -460,6 +473,18 @@ func (ch *channel) handlePacket(packet []byte) error {
} }
ch.incomingRequests <- &req ch.incomingRequests <- &req
case *channelRequestSuccessMsg, *channelRequestFailureMsg:
// Drop responses that arrive when no SendRequest is waiting, to
// prevent a malicious peer from filling ch.msg and stalling the
// mux read loop. The non-blocking send additionally protects the
// loop if a well-behaved caller is slow to read.
if !ch.sentRequestPending.Load() {
return nil
}
select {
case ch.msg <- msg:
default:
}
default: default:
ch.msg <- msg ch.msg <- msg
} }
@@ -530,7 +555,17 @@ func (ch *channel) Reject(reason RejectionReason, message string) error {
Language: "en", Language: "en",
} }
ch.decided = true ch.decided = true
return ch.sendMessage(reject) err := ch.sendMessage(reject)
// Remove the channel from the mux to prevent memory leaks.
// Do not call ch.close() here: no goroutine holds a reference to a
// rejected channel's internal channels (msg, incomingRequests), so
// removing it from chanList is sufficient for GC. Calling close()
// would race with the mux loop goroutine (handlePacket or dropAll),
// causing a panic from closing an already-closed channel.
ch.mux.chanList.remove(ch.localId)
return err
} }
func (ch *channel) Read(data []byte) (int, error) { func (ch *channel) Read(data []byte) (int, error) {
@@ -586,6 +621,24 @@ func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (boo
if wantReply { if wantReply {
ch.sentRequestMu.Lock() ch.sentRequestMu.Lock()
defer ch.sentRequestMu.Unlock() defer ch.sentRequestMu.Unlock()
// Open the gate so that responses arriving while this request is in
// flight are allowed to reach ch.msg. Responses arriving while no
// request is pending are dropped by handlePacket.
ch.sentRequestPending.Store(true)
defer ch.sentRequestPending.Store(false)
// Drain any spurious responses that may have been buffered. This
// prevents a previously buffered unexpected response from being
// consumed instead of the actual response for this request.
drain:
for {
select {
case <-ch.msg:
default:
break drain
}
}
} }
msg := channelRequestMsg{ msg := channelRequestMsg{
+1 -1
View File
@@ -407,7 +407,7 @@ func (c *gcmCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error)
return nil, fmt.Errorf("ssh: illegal padding %d", padding) return nil, fmt.Errorf("ssh: illegal padding %d", padding)
} }
if int(padding+1) >= len(plain) { if int(padding)+1 >= len(plain) {
return nil, fmt.Errorf("ssh: padding %d too large", padding) return nil, fmt.Errorf("ssh: padding %d too large", padding)
} }
plain = plain[1 : length-uint32(padding)] plain = plain[1 : length-uint32(padding)]
+58
View File
@@ -469,6 +469,12 @@ func parseRSA(in []byte) (out PublicKey, rest []byte, err error) {
return nil, nil, err return nil, nil, err
} }
// 8192 bits is also the maximum RSA key size accepted by crypto/tls for
// signature verification:
// https://github.com/golang/go/blob/69801b25/src/crypto/tls/handshake_client.go#L1096
if w.N.BitLen() > 8192 {
return nil, nil, errors.New("ssh: rsa modulus too large")
}
if w.E.BitLen() > 24 { if w.E.BitLen() > 24 {
return nil, nil, errors.New("ssh: exponent too large") return nil, nil, errors.New("ssh: exponent too large")
} }
@@ -574,6 +580,24 @@ func checkDSAParams(param *dsa.Parameters) error {
return fmt.Errorf("ssh: unsupported DSA key size %d", l) return fmt.Errorf("ssh: unsupported DSA key size %d", l)
} }
// FIPS 186-2 specifies that Q must be exactly 160 bits. We must enforce
// this to prevent DoS attacks where an attacker sends a huge Q which makes
// verification slow.
if l := param.Q.BitLen(); l != 160 {
return fmt.Errorf("ssh: unsupported DSA sub-prime size %d", l)
}
// The generator G is an element of the group, so it must be strictly less
// than the modulus P.
if param.G.Cmp(param.P) >= 0 {
return errors.New("ssh: DSA generator larger than modulus")
}
// G must be positive.
if param.G.Sign() <= 0 {
return errors.New("ssh: DSA generator must be positive")
}
return nil return nil
} }
@@ -596,6 +620,14 @@ func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
return nil, nil, err return nil, nil, err
} }
// The public value Y must be a non-zero element of the group, i.e.
// strictly between 0 and P. crypto/dsa.Verify does not range-check Y,
// so we reject out-of-range values here to prevent a maliciously
// oversized Y from slowing verification.
if w.Y.Sign() <= 0 || w.Y.Cmp(w.P) >= 0 {
return nil, nil, errors.New("ssh: DSA public value Y out of range")
}
key := &dsaPublicKey{ key := &dsaPublicKey{
Parameters: param, Parameters: param,
Y: w.Y, Y: w.Y,
@@ -869,11 +901,25 @@ type skFields struct {
Counter uint32 Counter uint32
} }
// flagUserPresence is the "user present" bit (UP) in the SK signature
// flags, matching the FIDO CTAP2 authenticatorData UP flag. See
// openssh/PROTOCOL.u2f.
const flagUserPresence = 0x01
// errSKMissingUserPresence is returned by SK key Verify methods when
// the signature does not assert user presence and the key was not
// marked as no-touch-required.
var errSKMissingUserPresence = errors.New("ssh: signature missing required user presence flag")
type skECDSAPublicKey struct { type skECDSAPublicKey struct {
// application is a URL-like string, typically "ssh:" for SSH. // application is a URL-like string, typically "ssh:" for SSH.
// see openssh/PROTOCOL.u2f for details. // see openssh/PROTOCOL.u2f for details.
application string application string
ecdsa.PublicKey ecdsa.PublicKey
// noTouchRequired, when true, disables the default user-presence
// check in Verify. It is set by skKeyWithoutUP on a clone of the
// key, never on an instance shared across authentication attempts.
noTouchRequired bool
} }
func (k *skECDSAPublicKey) Type() string { func (k *skECDSAPublicKey) Type() string {
@@ -959,6 +1005,10 @@ func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error {
return err return err
} }
if skf.Flags&flagUserPresence == 0 && !k.noTouchRequired {
return errSKMissingUserPresence
}
blob := struct { blob := struct {
ApplicationDigest []byte `ssh:"rest"` ApplicationDigest []byte `ssh:"rest"`
Flags byte Flags byte
@@ -992,6 +1042,10 @@ type skEd25519PublicKey struct {
// see openssh/PROTOCOL.u2f for details. // see openssh/PROTOCOL.u2f for details.
application string application string
ed25519.PublicKey ed25519.PublicKey
// noTouchRequired, when true, disables the default user-presence
// check in Verify. It is set by skKeyWithoutUP on a clone of the
// key, never on an instance shared across authentication attempts.
noTouchRequired bool
} }
func (k *skEd25519PublicKey) Type() string { func (k *skEd25519PublicKey) Type() string {
@@ -1066,6 +1120,10 @@ func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error {
return err return err
} }
if skf.Flags&flagUserPresence == 0 && !k.noTouchRequired {
return errSKMissingUserPresence
}
blob := struct { blob := struct {
ApplicationDigest []byte `ssh:"rest"` ApplicationDigest []byte `ssh:"rest"`
Flags byte Flags byte
+32 -4
View File
@@ -91,9 +91,10 @@ type mux struct {
incomingChannels chan NewChannel incomingChannels chan NewChannel
globalSentMu sync.Mutex globalSentMu sync.Mutex
globalResponses chan interface{} globalSentPending atomic.Bool
incomingRequests chan *Request globalResponses chan interface{}
incomingRequests chan *Request
errCond *sync.Cond errCond *sync.Cond
err error err error
@@ -141,6 +142,24 @@ func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []
if wantReply { if wantReply {
m.globalSentMu.Lock() m.globalSentMu.Lock()
defer m.globalSentMu.Unlock() defer m.globalSentMu.Unlock()
// Open the gate so that responses arriving while this request is in
// flight are allowed to reach globalResponses. Any response arriving
// while no request is pending is dropped by handleGlobalPacket.
m.globalSentPending.Store(true)
defer m.globalSentPending.Store(false)
// Drain any spurious responses that may have been buffered. This prevents
// a previously buffered unexpected response from being consumed instead
// of the actual response for this request.
drain:
for {
select {
case <-m.globalResponses:
default:
break drain
}
}
} }
if err := m.sendMessage(globalRequestMsg{ if err := m.sendMessage(globalRequestMsg{
@@ -267,7 +286,16 @@ func (m *mux) handleGlobalPacket(packet []byte) error {
mux: m, mux: m,
} }
case *globalRequestSuccessMsg, *globalRequestFailureMsg: case *globalRequestSuccessMsg, *globalRequestFailureMsg:
m.globalResponses <- msg // Drop responses that arrive when no SendRequest is waiting, to
// prevent a malicious peer from staging responses for a future
// caller.
if !m.globalSentPending.Load() {
return nil
}
select {
case m.globalResponses <- msg:
default:
}
default: default:
panic(fmt.Sprintf("not a global message %#v", msg)) panic(fmt.Sprintf("not a global message %#v", msg))
} }
+114 -13
View File
@@ -34,15 +34,20 @@ type Permissions struct {
// or not supported. // or not supported.
CriticalOptions map[string]string CriticalOptions map[string]string
// Extensions are extra functionality that the server may // Extensions are extra functionality that the server may offer on
// offer on authenticated connections. Lack of support for an // authenticated connections. Lack of support for an extension does not
// extension does not preclude authenticating a user. Common // preclude authenticating a user. Common extensions are
// extensions are "permit-agent-forwarding", // "permit-agent-forwarding", "permit-X11-forwarding". In general the Go
// "permit-X11-forwarding". The Go SSH library currently does // SSH library does not act on extensions and it is up to server
// not act on any extension, and it is up to server // implementations to honor them; extensions can also be used to pass data
// implementations to honor them. Extensions can be used to // from the authentication callbacks to the server application layer.
// pass data from the authentication callbacks to the server //
// application layer. // The one extension acted upon by this library is "no-touch-required",
// which applies only to security-key public keys
// (sk-ecdsa-sha2-nistp256@openssh.com and sk-ssh-ed25519@openssh.com).
// When present, it waives the default requirement that SK signatures
// assert user presence (i.e. a physical touch of the authenticator)
// during signature verification.
Extensions map[string]string Extensions map[string]string
// ExtraData allows to store user defined data. // ExtraData allows to store user defined data.
@@ -84,6 +89,79 @@ type ServerPreAuthConn interface {
SendAuthBanner(string) error SendAuthBanner(string) error
} }
// noTouchRequiredExtension is the extension name used by OpenSSH in
// authorized_keys options and certificate extensions to mark keys
// whose signatures do not need to assert user presence (touch). See
// ssh-keygen(1) and sshd(8).
const noTouchRequiredExtension = "no-touch-required"
// noTouchAllowed reports whether the user presence requirement on
// SK signatures should be waived for this authentication attempt. The
// requirement is waived when the "no-touch-required" extension is
// present either in the Permissions returned by the auth callback
// (authorized_keys-level opt-out) or in the certificate's own
// Extensions (CA-level opt-out), matching OpenSSH behavior. OpenSSH
// reads the per-key opt-out only from cert Extensions and
// authorized_keys options (never from CriticalOptions); we follow the
// same rule.
func noTouchAllowed(pubKey PublicKey, perms *Permissions) bool {
if perms != nil {
if _, ok := perms.Extensions[noTouchRequiredExtension]; ok {
return true
}
}
if cert, ok := pubKey.(*Certificate); ok {
if _, ok := cert.Extensions[noTouchRequiredExtension]; ok {
return true
}
}
return false
}
// skKeyWithoutUP returns a PublicKey equivalent to pubKey but whose
// Verify accepts SK signatures with the user-presence flag clear. If
// pubKey is not (and does not wrap) an SK key, pubKey is returned
// unchanged. The returned value never mutates pubKey: for SK keys a
// shallow copy is made so that the noTouchRequired flag is set only on
// the clone.
//
// The implementation is iterative rather than recursive. When pubKey
// is a *Certificate we unwrap exactly one level to look at the inner
// key. The SSH cert format forbids Certificate.Key from being another
// Certificate (parseCert rejects it), but nothing stops callers from
// constructing such a value directly in Go; a recursive descent could
// otherwise be driven to unbounded depth by a hand-crafted or cyclic
// Certificate. A malformed input of that shape simply returns
// unchanged here.
func skKeyWithoutUP(pubKey PublicKey) PublicKey {
cert, isCert := pubKey.(*Certificate)
target := pubKey
if isCert {
target = cert.Key
}
var cloned PublicKey
switch k := target.(type) {
case *skECDSAPublicKey:
c := *k
c.noTouchRequired = true
cloned = &c
case *skEd25519PublicKey:
c := *k
c.noTouchRequired = true
cloned = &c
default:
// Not an SK key (or a pathological *Certificate wrapping
// another *Certificate): pubKey is already usable for Verify.
return pubKey
}
if !isCert {
return cloned
}
c := *cert
c.Key = cloned
return &c
}
// ServerConfig holds server specific configuration data. // ServerConfig holds server specific configuration data.
type ServerConfig struct { type ServerConfig struct {
// Config contains configuration shared between client and server. // Config contains configuration shared between client and server.
@@ -242,8 +320,10 @@ func (c *pubKeyCache) add(candidate cachedPubKey) {
type ServerConn struct { type ServerConn struct {
Conn Conn
// If the succeeding authentication callback returned a // If the succeeding authentication callback returned a non-nil Permissions
// non-nil Permissions pointer, it is stored here. // pointer, it is stored here. These are the permissions from the final,
// successful authentication method. Permissions returned by callbacks that
// return PartialSuccessError are not preserved and must be nil.
Permissions *Permissions Permissions *Permissions
} }
@@ -737,8 +817,15 @@ userAuthLoop:
} }
signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData) signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData)
// pubKey is reused below for VerifiedPublicKeyCallback and
if err := pubKey.Verify(signedData, sig); err != nil { // must remain the key as presented by the client; derive a
// separate value for Verify that carries any applicable
// no-touch-required opt-out.
pubKeyForVerify := pubKey
if noTouchAllowed(pubKey, candidate.perms) {
pubKeyForVerify = skKeyWithoutUP(pubKey)
}
if err := pubKeyForVerify.Verify(signedData, sig); err != nil {
return nil, err return nil, err
} }
@@ -750,6 +837,13 @@ userAuthLoop:
// considered verified and the callback must not run. // considered verified and the callback must not run.
perms, authErr = config.VerifiedPublicKeyCallback(s, pubKey, perms, algo) perms, authErr = config.VerifiedPublicKeyCallback(s, pubKey, perms, algo)
} }
if authErr == nil && perms != nil && perms.CriticalOptions != nil {
if saco := perms.CriticalOptions[sourceAddressCriticalOption]; saco != "" {
if err := checkSourceAddress(s.RemoteAddr(), saco); err != nil {
authErr = err
}
}
}
} }
case "gssapi-with-mic": case "gssapi-with-mic":
if authConfig.GSSAPIWithMICConfig == nil { if authConfig.GSSAPIWithMICConfig == nil {
@@ -824,6 +918,13 @@ userAuthLoop:
var failureMsg userAuthFailureMsg var failureMsg userAuthFailureMsg
if partialSuccess, ok := authErr.(*PartialSuccessError); ok { if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
// Permissions are not preserved between authentication steps. To
// avoid confusion about the final state of the connection, we
// disallow returning non-nil Permissions combined with
// PartialSuccessError.
if perms != nil {
return nil, errors.New("ssh: permissions must be nil when returning PartialSuccessError")
}
// After a partial success error we don't allow changing the user // After a partial success error we don't allow changing the user
// name and execute the NoClientAuthCallback. // name and execute the NoClientAuthCallback.
partialSuccessReturned = true partialSuccessReturned = true
+55 -22
View File
@@ -27,6 +27,8 @@ package attributes
import ( import (
"fmt" "fmt"
"iter"
"maps"
"strings" "strings"
) )
@@ -37,37 +39,46 @@ import (
// any) bool', it will be called by (*Attributes).Equal to determine whether // any) bool', it will be called by (*Attributes).Equal to determine whether
// two values with the same key should be considered equal. // two values with the same key should be considered equal.
type Attributes struct { type Attributes struct {
m map[any]any parent *Attributes
key, value any
} }
// New returns a new Attributes containing the key/value pair. // New returns a new Attributes containing the key/value pair.
func New(key, value any) *Attributes { func New(key, value any) *Attributes {
return &Attributes{m: map[any]any{key: value}} return &Attributes{
key: key,
value: value,
}
} }
// WithValue returns a new Attributes containing the previous keys and values // WithValue returns a new Attributes containing the previous keys and values
// and the new key/value pair. If the same key appears multiple times, the // and the new key/value pair. If the same key appears multiple times, the
// last value overwrites all previous values for that key. To remove an // last value overwrites all previous values for that key. value should not be
// existing key, use a nil value. value should not be modified later. // modified later.
//
// Note that Attributes do not support deletion. Avoid using untyped nil values.
// Since the Value method returns an untyped nil when a key is absent, it is
// impossible to distinguish between a missing key and a key explicitly set to
// an untyped nil. If you need to represent a value being unset, consider
// storing a specific sentinel type or a wrapper struct with a boolean field
// indicating presence.
func (a *Attributes) WithValue(key, value any) *Attributes { func (a *Attributes) WithValue(key, value any) *Attributes {
if a == nil { return &Attributes{
return New(key, value) parent: a,
key: key,
value: value,
} }
n := &Attributes{m: make(map[any]any, len(a.m)+1)}
for k, v := range a.m {
n.m[k] = v
}
n.m[key] = value
return n
} }
// Value returns the value associated with these attributes for key, or nil if // Value returns the value associated with these attributes for key, or nil if
// no value is associated with key. The returned value should not be modified. // no value is associated with key. The returned value should not be modified.
func (a *Attributes) Value(key any) any { func (a *Attributes) Value(key any) any {
if a == nil { for cur := a; cur != nil; cur = cur.parent {
return nil if cur.key == key {
return cur.value
}
} }
return a.m[key] return nil
} }
// Equal returns whether a and o are equivalent. If 'Equal(o any) bool' is // Equal returns whether a and o are equivalent. If 'Equal(o any) bool' is
@@ -83,11 +94,15 @@ func (a *Attributes) Equal(o *Attributes) bool {
if a == nil || o == nil { if a == nil || o == nil {
return false return false
} }
if len(a.m) != len(o.m) { if a == o {
return false return true
} }
for k, v := range a.m { m := maps.Collect(o.all())
ov, ok := o.m[k] lenA := 0
for k, v := range a.all() {
lenA++
ov, ok := m[k]
if !ok { if !ok {
// o missing element of a // o missing element of a
return false return false
@@ -101,7 +116,7 @@ func (a *Attributes) Equal(o *Attributes) bool {
return false return false
} }
} }
return true return lenA == len(m)
} }
// String prints the attribute map. If any key or values throughout the map // String prints the attribute map. If any key or values throughout the map
@@ -110,11 +125,11 @@ func (a *Attributes) String() string {
var sb strings.Builder var sb strings.Builder
sb.WriteString("{") sb.WriteString("{")
first := true first := true
for k, v := range a.m { for k, v := range a.all() {
if !first { if !first {
sb.WriteString(", ") sb.WriteString(", ")
} }
sb.WriteString(fmt.Sprintf("%q: %q ", str(k), str(v))) fmt.Fprintf(&sb, "%q: %q ", str(k), str(v))
first = false first = false
} }
sb.WriteString("}") sb.WriteString("}")
@@ -139,3 +154,21 @@ func str(x any) (s string) {
func (a *Attributes) MarshalJSON() ([]byte, error) { func (a *Attributes) MarshalJSON() ([]byte, error) {
return []byte(a.String()), nil return []byte(a.String()), nil
} }
// all returns an iterator that yields all key-value pairs in the Attributes
// chain. If a key appears multiple times, only the most recently added value
// is yielded.
func (a *Attributes) all() iter.Seq2[any, any] {
return func(yield func(any, any) bool) {
seen := map[any]bool{}
for cur := a; cur != nil; cur = cur.parent {
if seen[cur.key] {
continue
}
if !yield(cur.key, cur.value) {
return
}
seen[cur.key] = true
}
}
}
+17 -15
View File
@@ -33,6 +33,7 @@ import (
estats "google.golang.org/grpc/experimental/stats" estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
@@ -46,8 +47,8 @@ var (
) )
// Register registers the balancer builder to the balancer map. b.Name // Register registers the balancer builder to the balancer map. b.Name
// (lowercased) will be used as the name registered with this builder. If the // will be used as the name registered with this builder. If the Builder
// Builder implements ConfigParser, ParseConfig will be called when new service // implements ConfigParser, ParseConfig will be called when new service
// configs are received by the resolver, and the result will be provided to the // configs are received by the resolver, and the result will be provided to the
// Balancer in UpdateClientConnState. // Balancer in UpdateClientConnState.
// //
@@ -55,12 +56,12 @@ var (
// an init() function), and is not thread-safe. If multiple Balancers are // an init() function), and is not thread-safe. If multiple Balancers are
// registered with the same name, the one registered last will take effect. // registered with the same name, the one registered last will take effect.
func Register(b Builder) { func Register(b Builder) {
name := strings.ToLower(b.Name()) name := b.Name()
if name != b.Name() { if !envconfig.CaseSensitiveBalancerRegistries {
// TODO: Skip the use of strings.ToLower() to index the map after v1.59 name = strings.ToLower(name)
// is released to switch to case sensitive balancer registry. Also, if name != b.Name() {
// remove this warning and update the docstrings for Register and Get. logger.Warningf("Balancer registered with name %q. grpc-go will be switching to case sensitive balancer registries soon. After 2 releases, we will enable the env var by default.", b.Name())
logger.Warningf("Balancer registered with name %q. grpc-go will be switching to case sensitive balancer registries soon", b.Name()) }
} }
m[name] = b m[name] = b
} }
@@ -78,16 +79,17 @@ func init() {
} }
// Get returns the resolver builder registered with the given name. // Get returns the resolver builder registered with the given name.
// Note that the compare is done in a case-insensitive fashion. // Note that the compare is done in a case-sensitive fashion.
// If no builder is register with the name, nil will be returned. // If no builder is register with the name, nil will be returned.
func Get(name string) Builder { func Get(name string) Builder {
if strings.ToLower(name) != name { if !envconfig.CaseSensitiveBalancerRegistries {
// TODO: Skip the use of strings.ToLower() to index the map after v1.59 lowerName := strings.ToLower(name)
// is released to switch to case sensitive balancer registry. Also, if lowerName != name {
// remove this warning and update the docstrings for Register and Get. logger.Warningf("Balancer retrieved for name %q. grpc-go will be switching to case sensitive balancer registries soon. After 2 releases, we will enable the env var by default.", name)
logger.Warningf("Balancer retrieved for name %q. grpc-go will be switching to case sensitive balancer registries soon", name) }
name = lowerName
} }
if b, ok := m[strings.ToLower(name)]; ok { if b, ok := m[name]; ok {
return b return b
} }
return nil return nil
+2 -4
View File
@@ -121,8 +121,7 @@ func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
sc.Connect() sc.Connect()
} }
} }
for _, a := range b.subConns.Keys() { for a, sc := range b.subConns.All() {
sc, _ := b.subConns.Get(a)
// a was removed by resolver. // a was removed by resolver.
if _, ok := addrsSet.Get(a); !ok { if _, ok := addrsSet.Get(a); !ok {
sc.Shutdown() sc.Shutdown()
@@ -171,8 +170,7 @@ func (b *baseBalancer) regeneratePicker() {
readySCs := make(map[balancer.SubConn]SubConnInfo) readySCs := make(map[balancer.SubConn]SubConnInfo)
// Filter out all ready SCs from full subConn map. // Filter out all ready SCs from full subConn map.
for _, addr := range b.subConns.Keys() { for addr, sc := range b.subConns.All() {
sc, _ := b.subConns.Get(addr)
if st, ok := b.scStates[sc]; ok && st == connectivity.Ready { if st, ok := b.scStates[sc]; ok && st == connectivity.Ready {
readySCs[sc] = SubConnInfo{Address: addr} readySCs[sc] = SubConnInfo{Address: addr}
} }
@@ -187,8 +187,7 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState
} }
} }
// Delete old children that are no longer present. // Delete old children that are no longer present.
for _, e := range children.Keys() { for e, child := range children.All() {
child, _ := children.Get(e)
if _, ok := newChildren.Get(e); !ok { if _, ok := newChildren.Get(e); !ok {
child.closeLocked() child.closeLocked()
} }
@@ -212,7 +211,7 @@ func (es *endpointSharding) ResolverError(err error) {
es.updateState() es.updateState()
}() }()
children := es.children.Load() children := es.children.Load()
for _, child := range children.Values() { for _, child := range children.All() {
child.resolverErrorLocked(err) child.resolverErrorLocked(err)
} }
} }
@@ -225,7 +224,7 @@ func (es *endpointSharding) Close() {
es.childMu.Lock() es.childMu.Lock()
defer es.childMu.Unlock() defer es.childMu.Unlock()
children := es.children.Load() children := es.children.Load()
for _, child := range children.Values() { for _, child := range children.All() {
child.closeLocked() child.closeLocked()
} }
} }
@@ -233,7 +232,7 @@ func (es *endpointSharding) Close() {
func (es *endpointSharding) ExitIdle() { func (es *endpointSharding) ExitIdle() {
es.childMu.Lock() es.childMu.Lock()
defer es.childMu.Unlock() defer es.childMu.Unlock()
for _, bw := range es.children.Load().Values() { for _, bw := range es.children.Load().All() {
if !bw.isClosed { if !bw.isClosed {
bw.child.ExitIdle() bw.child.ExitIdle()
} }
@@ -255,7 +254,7 @@ func (es *endpointSharding) updateState() {
children := es.children.Load() children := es.children.Load()
childStates := make([]ChildState, 0, children.Len()) childStates := make([]ChildState, 0, children.Len())
for _, child := range children.Values() { for _, child := range children.All() {
childState := child.childState childState := child.childState
childStates = append(childStates, childState) childStates = append(childStates, childState)
childPicker := childState.State.Picker childPicker := childState.State.Picker
+6 -6
View File
@@ -399,14 +399,14 @@ func (b *pickfirstBalancer) startFirstPassLocked() {
b.firstPass = true b.firstPass = true
b.numTF = 0 b.numTF = 0
// Reset the connection attempt record for existing SubConns. // Reset the connection attempt record for existing SubConns.
for _, sd := range b.subConns.Values() { for _, sd := range b.subConns.All() {
sd.connectionFailedInFirstPass = false sd.connectionFailedInFirstPass = false
} }
b.requestConnectionLocked() b.requestConnectionLocked()
} }
func (b *pickfirstBalancer) closeSubConnsLocked() { func (b *pickfirstBalancer) closeSubConnsLocked() {
for _, sd := range b.subConns.Values() { for _, sd := range b.subConns.All() {
sd.subConn.Shutdown() sd.subConn.Shutdown()
} }
b.subConns = resolver.NewAddressMapV2[*scData]() b.subConns = resolver.NewAddressMapV2[*scData]()
@@ -506,7 +506,7 @@ func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address)
newAddrsMap.Set(addr, true) newAddrsMap.Set(addr, true)
} }
for _, oldAddr := range b.subConns.Keys() { for oldAddr := range b.subConns.All() {
if _, ok := newAddrsMap.Get(oldAddr); ok { if _, ok := newAddrsMap.Get(oldAddr); ok {
continue continue
} }
@@ -520,7 +520,7 @@ func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address)
// becomes ready, which means that all other subConn must be shutdown. // becomes ready, which means that all other subConn must be shutdown.
func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) { func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) {
b.cancelConnectionTimer() b.cancelConnectionTimer()
for _, sd := range b.subConns.Values() { for _, sd := range b.subConns.All() {
if sd.subConn != selected.subConn { if sd.subConn != selected.subConn {
sd.subConn.Shutdown() sd.subConn.Shutdown()
} }
@@ -771,7 +771,7 @@ func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) {
} }
// Connect() has been called on all the SubConns. The first pass can be // Connect() has been called on all the SubConns. The first pass can be
// ended if all the SubConns have reported a failure. // ended if all the SubConns have reported a failure.
for _, sd := range b.subConns.Values() { for _, sd := range b.subConns.All() {
if !sd.connectionFailedInFirstPass { if !sd.connectionFailedInFirstPass {
return return
} }
@@ -782,7 +782,7 @@ func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) {
Picker: &picker{err: lastErr}, Picker: &picker{err: lastErr},
}) })
// Start re-connecting all the SubConns that are already in IDLE. // Start re-connecting all the SubConns that are already in IDLE.
for _, sd := range b.subConns.Values() { for _, sd := range b.subConns.All() {
if sd.rawConnectivityState == connectivity.Idle { if sd.rawConnectivityState == connectivity.Idle {
sd.subConn.Connect() sd.subConn.Connect()
} }
+1 -1
View File
@@ -18,7 +18,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.36.10 // protoc-gen-go v1.36.11
// protoc v5.27.1 // protoc v5.27.1
// source: grpc/binlog/v1/binarylog.proto // source: grpc/binlog/v1/binarylog.proto
+43 -5
View File
@@ -24,10 +24,12 @@ import (
"fmt" "fmt"
"math" "math"
"net/url" "net/url"
"os"
"slices" "slices"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall"
"time" "time"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
@@ -1268,8 +1270,9 @@ type addrConn struct {
channelz *channelz.SubChannel channelz *channelz.SubChannel
localityLabel string localityLabel string
backendServiceLabel string backendServiceLabel string
disconnectErrorLabel string
} }
// Note: this requires a lock on ac.mu. // Note: this requires a lock on ac.mu.
@@ -1286,9 +1289,14 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
// TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second // TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second
// part of the if condition below once the issue is fixed. // part of the if condition below once the issue is fixed.
if ac.state == connectivity.Ready || (ac.state == connectivity.Connecting && s == connectivity.Idle) { if ac.state == connectivity.Ready || (ac.state == connectivity.Connecting && s == connectivity.Idle) {
disconnectionsMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel, "unknown") disconnectError := ac.disconnectErrorLabel
if disconnectError == "" {
disconnectError = "unknown"
}
disconnectionsMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel, disconnectError)
openConnectionsMetric.Record(ac.cc.metricsRecorderList, -1, ac.cc.target, ac.backendServiceLabel, ac.securityLevelLocked(), ac.localityLabel) openConnectionsMetric.Record(ac.cc.metricsRecorderList, -1, ac.cc.target, ac.backendServiceLabel, ac.securityLevelLocked(), ac.localityLabel)
} }
ac.disconnectErrorLabel = "" // Reset for next time
ac.state = s ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s) ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil { if lastErr == nil {
@@ -1483,11 +1491,11 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
addr.ServerName = ac.cc.getServerName(addr) addr.ServerName = ac.cc.getServerName(addr)
hctx, hcancel := context.WithCancel(ctx) hctx, hcancel := context.WithCancel(ctx)
onClose := func(r transport.GoAwayReason) { onClose := func(info transport.GoAwayInfo) {
ac.mu.Lock() ac.mu.Lock()
defer ac.mu.Unlock() defer ac.mu.Unlock()
// adjust params based on GoAwayReason // adjust params based on GoAwayReason
ac.adjustParams(r) ac.adjustParams(info.Reason)
if ctx.Err() != nil { if ctx.Err() != nil {
// Already shut down or connection attempt canceled. tearDown() or // Already shut down or connection attempt canceled. tearDown() or
// updateAddrs() already cleared the transport and canceled hctx // updateAddrs() already cleared the transport and canceled hctx
@@ -1504,6 +1512,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
return return
} }
ac.transport = nil ac.transport = nil
ac.disconnectErrorLabel = disconnectErrorString(info)
// Refresh the name resolver on any connection loss. // Refresh the name resolver on any connection loss.
ac.cc.resolveNow(resolver.ResolveNowOptions{}) ac.cc.resolveNow(resolver.ResolveNowOptions{})
// Always go idle and wait for the LB policy to initiate a new // Always go idle and wait for the LB policy to initiate a new
@@ -1560,6 +1569,32 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
return nil return nil
} }
// disconnectErrorString returns the grpc.disconnect_error metric label corresponding
// to the provided transport.GoAwayInfo, as specified by gRFC A94:
// https://github.com/grpc/proposal/blob/master/A94-grpc-subchannel-disconnections-metrics.md
func disconnectErrorString(info transport.GoAwayInfo) string {
err := info.Err
var sysErr syscall.Errno
switch {
case info.Reason != transport.GoAwayInvalid:
return fmt.Sprintf("GOAWAY %s", info.GoAwayCode.String())
case err == nil:
return "unknown"
case errors.Is(err, context.Canceled):
return "subchannel shutdown"
case errors.Is(err, syscall.ECONNRESET):
return "connection reset"
case errors.Is(err, syscall.ETIMEDOUT), errors.Is(err, context.DeadlineExceeded), errors.Is(err, os.ErrDeadlineExceeded):
return "connection timed out"
case errors.Is(err, syscall.ECONNABORTED):
return "connection aborted"
case errors.As(err, &sysErr):
return "socket error"
default:
return "unknown"
}
}
// startHealthCheck starts the health checking stream (RPC) to watch the health // startHealthCheck starts the health checking stream (RPC) to watch the health
// stats of this connection if health checking is requested and configured. // stats of this connection if health checking is requested and configured.
// //
@@ -1663,6 +1698,9 @@ func (ac *addrConn) tearDown(err error) {
} }
curTr := ac.transport curTr := ac.transport
ac.transport = nil ac.transport = nil
if ac.disconnectErrorLabel == "" {
ac.disconnectErrorLabel = "subchannel shutdown"
}
// We have to set the state to Shutdown before anything else to prevent races // We have to set the state to Shutdown before anything else to prevent races
// between setting the state and logic that waits on context cancellation / etc. // between setting the state and logic that waits on context cancellation / etc.
ac.updateConnectivityState(connectivity.Shutdown, nil) ac.updateConnectivityState(connectivity.Shutdown, nil)
+8 -10
View File
@@ -22,7 +22,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@@ -52,22 +51,21 @@ func (t TLSInfo) AuthType() string {
} }
// ValidateAuthority validates the provided authority being used to override the // ValidateAuthority validates the provided authority being used to override the
// :authority header by verifying it against the peer certificates. It returns a // :authority header by verifying it against the peer certificate. It returns a
// non-nil error if the validation fails. // non-nil error if the validation fails.
func (t TLSInfo) ValidateAuthority(authority string) error { func (t TLSInfo) ValidateAuthority(authority string) error {
var errs []error
host, _, err := net.SplitHostPort(authority) host, _, err := net.SplitHostPort(authority)
if err != nil { if err != nil {
host = authority host = authority
} }
for _, cert := range t.State.PeerCertificates {
var err error // Verify authority against the leaf certificate.
if err = cert.VerifyHostname(host); err == nil { if len(t.State.PeerCertificates) == 0 {
return nil // This is not expected to happen as the TLS handshake has already
} // completed and should have populated PeerCertificates.
errs = append(errs, err) return fmt.Errorf("credentials: no peer certificates found to verify authority %q", host)
} }
return fmt.Errorf("credentials: invalid authority %q: %v", authority, errors.Join(errs...)) return t.State.PeerCertificates[0].VerifyHostname(host)
} }
// cipherSuiteLookup returns the string version of a TLS cipher suite ID. // cipherSuiteLookup returns the string version of a TLS cipher suite ID.
+5 -4
View File
@@ -705,10 +705,11 @@ func WithDisableHealthCheck() DialOption {
func defaultDialOptions() dialOptions { func defaultDialOptions() dialOptions {
return dialOptions{ return dialOptions{
copts: transport.ConnectOptions{ copts: transport.ConnectOptions{
ReadBufferSize: defaultReadBufSize, ReadBufferSize: defaultReadBufSize,
WriteBufferSize: defaultWriteBufSize, WriteBufferSize: defaultWriteBufSize,
UserAgent: grpcUA, SharedWriteBuffer: true,
BufferPool: mem.DefaultBufferPool(), UserAgent: grpcUA,
BufferPool: mem.DefaultBufferPool(),
}, },
bs: internalbackoff.DefaultExponential, bs: internalbackoff.DefaultExponential,
idleTimeout: 30 * time.Minute, idleTimeout: 30 * time.Minute,
+17
View File
@@ -20,10 +20,27 @@
package stats package stats
import ( import (
"context"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
) )
type customLabelKey struct{}
// NewContextWithCustomLabel returns a new context with the provided custom label
// attached. The label will be propagated to all metric instruments specified in gRFC A108.
func NewContextWithCustomLabel(ctx context.Context, label string) context.Context {
return context.WithValue(ctx, customLabelKey{}, label)
}
// CustomLabelFromContext returns the custom label from the context if it exists.
// If the custom label is not present, it returns an empty string.
func CustomLabelFromContext(ctx context.Context) string {
label, _ := ctx.Value(customLabelKey{}).(string)
return label
}
// MetricsRecorder records on metrics derived from metric registry. // MetricsRecorder records on metrics derived from metric registry.
// Implementors must embed UnimplementedMetricsRecorder. // Implementors must embed UnimplementedMetricsRecorder.
type MetricsRecorder interface { type MetricsRecorder interface {
+1 -1
View File
@@ -17,7 +17,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.36.10 // protoc-gen-go v1.36.11
// protoc v5.27.1 // protoc v5.27.1
// source: grpc/health/v1/health.proto // source: grpc/health/v1/health.proto
+1 -1
View File
@@ -17,7 +17,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.6.0 // - protoc-gen-go-grpc v1.6.1
// - protoc v5.27.1 // - protoc v5.27.1
// source: grpc/health/v1/health.proto // source: grpc/health/v1/health.proto
+56 -8
View File
@@ -54,17 +54,16 @@ var (
// XDSEndpointHashKeyBackwardCompat controls the parsing of the endpoint hash // XDSEndpointHashKeyBackwardCompat controls the parsing of the endpoint hash
// key from EDS LbEndpoint metadata. Endpoint hash keys can be disabled by // key from EDS LbEndpoint metadata. Endpoint hash keys can be disabled by
// setting "GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT" to "true". When the // setting "GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT" to "true". A future
// implementation of A76 is stable, we will flip the default value to false // release will remove this environment variable, enabling the new behavior
// in a subsequent release. A final release will remove this environment // unconditionally.
// variable, enabling the new behavior unconditionally. XDSEndpointHashKeyBackwardCompat = boolFromEnv("GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT", false)
XDSEndpointHashKeyBackwardCompat = boolFromEnv("GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT", true)
// RingHashSetRequestHashKey is set if the ring hash balancer can get the // RingHashSetRequestHashKey is set if the ring hash balancer can get the
// request hash header by setting the "requestHashHeader" field, according // request hash header by setting the "requestHashHeader" field, according
// to gRFC A76. It can be enabled by setting the environment variable // to gRFC A76. It can be disabled by setting the environment variable
// "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY" to "true". // "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY" to "false".
RingHashSetRequestHashKey = boolFromEnv("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY", false) RingHashSetRequestHashKey = boolFromEnv("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY", true)
// ALTSHandshakerKeepaliveParams is set if we should add the // ALTSHandshakerKeepaliveParams is set if we should add the
// KeepaliveParams when dial the ALTS handshaker service. // KeepaliveParams when dial the ALTS handshaker service.
@@ -78,6 +77,14 @@ var (
// - The DNS resolver is being used. // - The DNS resolver is being used.
EnableDefaultPortForProxyTarget = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_DEFAULT_PORT_FOR_PROXY_TARGET", true) EnableDefaultPortForProxyTarget = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_DEFAULT_PORT_FOR_PROXY_TARGET", true)
// CaseSensitiveBalancerRegistries is set if the balancer registry should be
// case-sensitive. This is disabled by default, but can be enabled by setting
// the env variable "GRPC_GO_EXPERIMENTAL_CASE_SENSITIVE_BALANCER_REGISTRIES"
// to "true".
//
// TODO: After 2 releases, we will enable the env var by default.
CaseSensitiveBalancerRegistries = boolFromEnv("GRPC_GO_EXPERIMENTAL_CASE_SENSITIVE_BALANCER_REGISTRIES", false)
// XDSAuthorityRewrite indicates whether xDS authority rewriting is enabled. // XDSAuthorityRewrite indicates whether xDS authority rewriting is enabled.
// This feature is defined in gRFC A81 and is enabled by setting the // This feature is defined in gRFC A81 and is enabled by setting the
// environment variable GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE to "true". // environment variable GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE to "true".
@@ -88,6 +95,47 @@ var (
// feature can be disabled by setting the environment variable // feature can be disabled by setting the environment variable
// GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING to "false". // GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING to "false".
PickFirstWeightedShuffling = boolFromEnv("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true) PickFirstWeightedShuffling = boolFromEnv("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true)
// XDSRecoverPanicInResourceParsing indicates whether the xdsclient should
// recover from panics while parsing xDS resources.
//
// This feature can be disabled (e.g. for fuzz testing) by setting the
// environment variable "GRPC_GO_EXPERIMENTAL_XDS_RESOURCE_PANIC_RECOVERY"
// to "false".
XDSRecoverPanicInResourceParsing = boolFromEnv("GRPC_GO_EXPERIMENTAL_XDS_RESOURCE_PANIC_RECOVERY", true)
// DisableStrictPathChecking indicates whether strict path checking is
// disabled. This feature can be disabled by setting the environment
// variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING to "true".
//
// When strict path checking is enabled, gRPC will reject requests with
// paths that do not conform to the gRPC over HTTP/2 specification found at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md.
//
// When disabled, gRPC will allow paths that do not contain a leading slash.
// Enabling strict path checking is recommended for security reasons, as it
// prevents potential path traversal vulnerabilities.
//
// A future release will remove this environment variable, enabling strict
// path checking behavior unconditionally.
DisableStrictPathChecking = boolFromEnv("GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING", false)
// EnablePriorityLBChildPolicyCache controls whether the priority balancer
// should cache child balancers that are removed from the LB policy config,
// for a period of 15 minutes. This is disabled by default, but can be
// enabled by setting the env variable
// GRPC_EXPERIMENTAL_ENABLE_PRIORITY_LB_CHILD_POLICY_CACHE to true.
EnablePriorityLBChildPolicyCache = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_PRIORITY_LB_CHILD_POLICY_CACHE", false)
// EnableHTTPFramerReadBufferPooling enables the use of the
// readyreader.Reader interface to perform non-memory-pinning reads,
// provided the underlying net.Conn supports it. This reduces memory usage
// when subchannels are idle.
//
// This environment variable serves as an escape hatch to disable the
// feature if unforeseen issues arise, and it will be removed in a future
// release.
EnableHTTPFramerReadBufferPooling = boolFromEnv("GRPC_GO_EXPERIMENTAL_HTTP_FRAMER_READ_BUFFER_POOLING", true)
) )
func boolFromEnv(envVar string, def bool) bool { func boolFromEnv(envVar string, def bool) bool {
+10
View File
@@ -79,4 +79,14 @@ var (
// xDS bootstrap configuration via the `call_creds` field. For more details, // xDS bootstrap configuration via the `call_creds` field. For more details,
// see: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md // see: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md
XDSBootstrapCallCredsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_BOOTSTRAP_CALL_CREDS", false) XDSBootstrapCallCredsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_BOOTSTRAP_CALL_CREDS", false)
// XDSSNIEnabled controls if gRPC should send SNI information in xDS
// configured TLS handshakes. For more details, see:
// https://github.com/grpc/proposal/blob/master/A101-SNI-setting-and-SNI-SAN-validation.md
XDSSNIEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_SNI", false)
// XDSORCAToLRSPropEnabled controls whether ORCA metrics are explicitly
// filtered and prefix-propagated to the LRS server. For more details, see:
// https://github.com/grpc/proposal/blob/master/A85-lrs-custom-metrics-changes.md
XDSORCAToLRSPropEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_ORCA_LRS_PROPAGATION", false)
) )
+349
View File
@@ -0,0 +1,349 @@
/*
*
* Copyright 2026 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package mem provides utilities that facilitate memory reuse in byte slices
// that are used as buffers.
package mem
import (
"fmt"
"math/bits"
"slices"
"sort"
"sync"
)
const (
goPageSize = 4 * 1024 // 4KiB. N.B. this must be a power of 2.
)
var uintSize = bits.UintSize // use a variable for mocking during tests.
// bufferPool is a copy of the public bufferPool interface used to avoid
// circular dependencies.
type bufferPool interface {
// Get returns a buffer with specified length from the pool.
Get(length int) *[]byte
// Put returns a buffer to the pool.
//
// The provided pointer must hold a prefix of the buffer obtained via
// BufferPool.Get to ensure the buffer's entire capacity can be re-used.
Put(*[]byte)
}
// BinaryTieredBufferPool is a buffer pool that uses multiple sub-pools with
// power-of-two sizes.
type BinaryTieredBufferPool struct {
// exponentToNextLargestPoolMap maps a power-of-two exponent (e.g., 12 for
// 4KB) to the index of the next largest sizedBufferPool. This is used by
// Get() to find the smallest pool that can satisfy a request for a given
// size.
exponentToNextLargestPoolMap []int
// exponentToPreviousLargestPoolMap maps a power-of-two exponent to the
// index of the previous largest sizedBufferPool. This is used by Put()
// to return a buffer to the most appropriate pool based on its capacity.
exponentToPreviousLargestPoolMap []int
sizedPools []bufferPool
fallbackPool bufferPool
maxPoolCap int // Optimization: Cache max capacity
}
// NewBinaryTieredBufferPool returns a BufferPool backed by multiple sub-pools.
// This structure enables O(1) lookup time for Get and Put operations.
//
// The arguments provided are the exponents for the buffer capacities (powers
// of 2), not the raw byte sizes. For example, to create a pool of 16KB buffers
// (2^14 bytes), pass 14 as the argument.
func NewBinaryTieredBufferPool(powerOfTwoExponents ...uint8) (*BinaryTieredBufferPool, error) {
return newBinaryTiered(func(size int) bufferPool {
return newSizedBufferPool(size, true)
}, &SimpleBufferPool{shouldZero: true}, powerOfTwoExponents...)
}
// NewDirtyBinaryTieredBufferPool returns a BufferPool backed by multiple
// sub-pools. It is similar to NewBinaryTieredBufferPool but it does not
// initialize the buffers before returning them.
func NewDirtyBinaryTieredBufferPool(powerOfTwoExponents ...uint8) (*BinaryTieredBufferPool, error) {
return newBinaryTiered(func(size int) bufferPool {
return newSizedBufferPool(size, false)
}, NewDirtySimplePool(), powerOfTwoExponents...)
}
func newBinaryTiered(sizedPoolFactory func(int) bufferPool, fallbackPool bufferPool, powerOfTwoExponents ...uint8) (*BinaryTieredBufferPool, error) {
slices.Sort(powerOfTwoExponents)
powerOfTwoExponents = slices.Compact(powerOfTwoExponents)
// Determine the maximum exponent we need to support. This depends on the
// word size (32-bit vs 64-bit).
maxExponent := uintSize - 2
indexOfNextLargestBit := slices.Repeat([]int{-1}, maxExponent+1)
indexOfPreviousLargestBit := slices.Repeat([]int{-1}, maxExponent+1)
maxTier := 0
pools := make([]bufferPool, 0, len(powerOfTwoExponents))
for i, exp := range powerOfTwoExponents {
// Allocating slices of size > 2^maxExponent isn't possible on
// maxExponent-bit machines.
if int(exp) > maxExponent {
return nil, fmt.Errorf("mem: allocating slice of size 2^%d is not possible", exp)
}
tierSize := 1 << exp
pools = append(pools, sizedPoolFactory(tierSize))
maxTier = max(maxTier, tierSize)
// Map the exact power of 2 to this pool index.
indexOfNextLargestBit[exp] = i
indexOfPreviousLargestBit[exp] = i
}
// Fill gaps for Get() (Next Largest)
// We iterate backwards. If current is empty, take the value from the right (larger).
for i := maxExponent - 1; i >= 0; i-- {
if indexOfNextLargestBit[i] == -1 {
indexOfNextLargestBit[i] = indexOfNextLargestBit[i+1]
}
}
// Fill gaps for Put() (Previous Largest)
// We iterate forwards. If current is empty, take the value from the left (smaller).
for i := 1; i <= maxExponent; i++ {
if indexOfPreviousLargestBit[i] == -1 {
indexOfPreviousLargestBit[i] = indexOfPreviousLargestBit[i-1]
}
}
return &BinaryTieredBufferPool{
exponentToNextLargestPoolMap: indexOfNextLargestBit,
exponentToPreviousLargestPoolMap: indexOfPreviousLargestBit,
sizedPools: pools,
maxPoolCap: maxTier,
fallbackPool: fallbackPool,
}, nil
}
// Get returns a buffer with specified length from the pool.
func (b *BinaryTieredBufferPool) Get(size int) *[]byte {
return b.poolForGet(size).Get(size)
}
func (b *BinaryTieredBufferPool) poolForGet(size int) bufferPool {
if size == 0 || size > b.maxPoolCap {
return b.fallbackPool
}
// Calculate the exponent of the smallest power of 2 >= size.
// We subtract 1 from size to handle exact powers of 2 correctly.
//
// Examples:
// size=16 (0b10000) -> size-1=15 (0b01111) -> bits.Len=4 -> Pool for 2^4
// size=17 (0b10001) -> size-1=16 (0b10000) -> bits.Len=5 -> Pool for 2^5
querySize := uint(size - 1)
poolIdx := b.exponentToNextLargestPoolMap[bits.Len(querySize)]
return b.sizedPools[poolIdx]
}
// Put returns a buffer to the pool.
func (b *BinaryTieredBufferPool) Put(buf *[]byte) {
// We pass the capacity of the buffer, and not the size of the buffer here.
// If we did the latter, all buffers would eventually move to the smallest
// pool.
b.poolForPut(cap(*buf)).Put(buf)
}
func (b *BinaryTieredBufferPool) poolForPut(bCap int) bufferPool {
if bCap == 0 {
return NopBufferPool{}
}
if bCap > b.maxPoolCap {
return b.fallbackPool
}
// Find the pool with the largest capacity <= bCap.
//
// We calculate the exponent of the largest power of 2 <= bCap.
// bits.Len(x) returns the minimum number of bits required to represent x;
// i.e. the number of bits up to and including the most significant bit.
// Subtracting 1 gives the 0-based index of the most significant bit,
// which is the exponent of the largest power of 2 <= bCap.
//
// Examples:
// cap=16 (0b10000) -> Len=5 -> 5-1=4 -> 2^4
// cap=15 (0b01111) -> Len=4 -> 4-1=3 -> 2^3
largestPowerOfTwo := bits.Len(uint(bCap)) - 1
poolIdx := b.exponentToPreviousLargestPoolMap[largestPowerOfTwo]
// The buffer is smaller than the smallest power of 2, discard it.
if poolIdx == -1 {
// Buffer is smaller than our smallest pool bucket.
return NopBufferPool{}
}
return b.sizedPools[poolIdx]
}
// NopBufferPool is a buffer pool that returns new buffers without pooling.
type NopBufferPool struct{}
// Get returns a buffer with specified length from the pool.
func (NopBufferPool) Get(length int) *[]byte {
b := make([]byte, length)
return &b
}
// Put returns a buffer to the pool.
func (NopBufferPool) Put(*[]byte) {
}
// sizedBufferPool is a BufferPool implementation that is optimized for specific
// buffer sizes. For example, HTTP/2 frames within gRPC have a default max size
// of 16kb and a sizedBufferPool can be configured to only return buffers with a
// capacity of 16kb. Note that however it does not support returning larger
// buffers and in fact panics if such a buffer is requested. Because of this,
// this BufferPool implementation is not meant to be used on its own and rather
// is intended to be embedded in a TieredBufferPool such that Get is only
// invoked when the required size is smaller than or equal to defaultSize.
type sizedBufferPool struct {
pool sync.Pool
defaultSize int
shouldZero bool
}
func (p *sizedBufferPool) Get(size int) *[]byte {
buf, ok := p.pool.Get().(*[]byte)
if !ok {
buf := make([]byte, size, p.defaultSize)
return &buf
}
b := *buf
if p.shouldZero {
clear(b[:cap(b)])
}
*buf = b[:size]
return buf
}
func (p *sizedBufferPool) Put(buf *[]byte) {
if cap(*buf) < p.defaultSize {
// Ignore buffers that are too small to fit in the pool. Otherwise, when
// Get is called it will panic as it tries to index outside the bounds
// of the buffer.
return
}
p.pool.Put(buf)
}
func newSizedBufferPool(size int, zero bool) *sizedBufferPool {
return &sizedBufferPool{
defaultSize: size,
shouldZero: zero,
}
}
// TieredBufferPool implements the BufferPool interface with multiple tiers of
// buffer pools for different sizes of buffers.
type TieredBufferPool struct {
sizedPools []*sizedBufferPool
fallbackPool SimpleBufferPool
}
// NewTieredBufferPool returns a BufferPool implementation that uses multiple
// underlying pools of the given pool sizes.
func NewTieredBufferPool(poolSizes ...int) *TieredBufferPool {
sort.Ints(poolSizes)
pools := make([]*sizedBufferPool, len(poolSizes))
for i, s := range poolSizes {
pools[i] = newSizedBufferPool(s, true)
}
return &TieredBufferPool{
sizedPools: pools,
fallbackPool: SimpleBufferPool{shouldZero: true},
}
}
// Get returns a buffer with specified length from the pool.
func (p *TieredBufferPool) Get(size int) *[]byte {
return p.getPool(size).Get(size)
}
// Put returns a buffer to the pool.
func (p *TieredBufferPool) Put(buf *[]byte) {
p.getPool(cap(*buf)).Put(buf)
}
func (p *TieredBufferPool) getPool(size int) bufferPool {
poolIdx := sort.Search(len(p.sizedPools), func(i int) bool {
return p.sizedPools[i].defaultSize >= size
})
if poolIdx == len(p.sizedPools) {
return &p.fallbackPool
}
return p.sizedPools[poolIdx]
}
// SimpleBufferPool is an implementation of the mem.BufferPool interface that
// attempts to pool buffers with a sync.Pool. When Get is invoked, it tries to
// acquire a buffer from the pool but if that buffer is too small, it returns it
// to the pool and creates a new one.
type SimpleBufferPool struct {
pool sync.Pool
shouldZero bool
}
// NewDirtySimplePool constructs a [SimpleBufferPool]. It does not initialize
// the buffers before returning them. Callers must ensure they don't read the
// buffers before writing data to them.
func NewDirtySimplePool() *SimpleBufferPool {
return &SimpleBufferPool{
shouldZero: false,
}
}
// Get returns a buffer with specified length from the pool.
func (p *SimpleBufferPool) Get(size int) *[]byte {
bs, ok := p.pool.Get().(*[]byte)
if ok && cap(*bs) >= size {
if p.shouldZero {
clear((*bs)[:cap(*bs)])
}
*bs = (*bs)[:size]
return bs
}
// A buffer was pulled from the pool, but it is too small. Put it back in
// the pool and create one large enough.
if ok {
p.pool.Put(bs)
}
// If we're going to allocate, round up to the nearest page. This way if
// requests frequently arrive with small variation we don't allocate
// repeatedly if we get unlucky and they increase over time. By default we
// only allocate here if size > 1MiB. Because goPageSize is a power of 2, we
// can round up efficiently.
allocSize := (size + goPageSize - 1) & ^(goPageSize - 1)
b := make([]byte, size, allocSize)
return &b
}
// Put returns a buffer to the pool.
func (p *SimpleBufferPool) Put(buf *[]byte) {
p.pool.Put(buf)
}
+6
View File
@@ -115,6 +115,9 @@ type ClientInterceptor interface {
// ClientStream after done is called, since the interceptor is invoked by // ClientStream after done is called, since the interceptor is invoked by
// application-layer operations. done must never be nil when called. // application-layer operations. done must never be nil when called.
NewStream(ctx context.Context, ri RPCInfo, done func(), newStream func(ctx context.Context, done func()) (ClientStream, error)) (ClientStream, error) NewStream(ctx context.Context, ri RPCInfo, done func(), newStream func(ctx context.Context, done func()) (ClientStream, error)) (ClientStream, error)
// Close closes the interceptor. Once called, no new calls to NewStream are
// accepted. Ongoing calls to NewStream are allowed to complete.
Close()
} }
// ServerInterceptor is an interceptor for incoming RPC's on gRPC server side. // ServerInterceptor is an interceptor for incoming RPC's on gRPC server side.
@@ -123,6 +126,9 @@ type ServerInterceptor interface {
// information about connection RPC was received on, and HTTP Headers. This // information about connection RPC was received on, and HTTP Headers. This
// information will be piped into context. // information will be piped into context.
AllowRPC(ctx context.Context) error // TODO: Make this a real interceptor for filters such as rate limiting. AllowRPC(ctx context.Context) error // TODO: Make this a real interceptor for filters such as rate limiting.
// Close closes the interceptor. Once called, no new calls to NewStream are
// accepted. Ongoing calls to NewStream are allowed to complete.
Close()
} }
type csKeyType string type csKeyType string
+1
View File
@@ -46,6 +46,7 @@ const (
defaultWriteQuota = 64 * 1024 defaultWriteQuota = 64 * 1024
defaultClientMaxHeaderListSize = uint32(16 << 20) defaultClientMaxHeaderListSize = uint32(16 << 20)
defaultServerMaxHeaderListSize = uint32(16 << 20) defaultServerMaxHeaderListSize = uint32(16 << 20)
upcomingDefaultHeaderListSize = uint32(8 << 10)
) )
// MaxStreamID is the upper bound for the stream ID before the current // MaxStreamID is the upper bound for the stream ID before the current
+17 -7
View File
@@ -134,6 +134,8 @@ type http2Client struct {
// goAwayDebugMessage contains a detailed human readable string about a // goAwayDebugMessage contains a detailed human readable string about a
// GoAway frame, useful for error messages. // GoAway frame, useful for error messages.
goAwayDebugMessage string goAwayDebugMessage string
// goAwayCode records the http2.ErrCode received with the GoAway frame.
goAwayCode http2.ErrCode
// A condition variable used to signal when the keepalive goroutine should // A condition variable used to signal when the keepalive goroutine should
// go dormant. The condition for dormancy is based on the number of active // go dormant. The condition for dormancy is based on the number of active
// streams and the `PermitWithoutStream` keepalive client parameter. And // streams and the `PermitWithoutStream` keepalive client parameter. And
@@ -147,7 +149,7 @@ type http2Client struct {
channelz *channelz.Socket channelz *channelz.Socket
onClose func(GoAwayReason) onClose OnCloseFunc
bufferPool mem.BufferPool bufferPool mem.BufferPool
@@ -204,7 +206,7 @@ func isTemporary(err error) bool {
// NewHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // NewHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onClose func(GoAwayReason)) (_ ClientTransport, err error) { func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onClose OnCloseFunc) (_ ClientTransport, err error) {
scheme := "http" scheme := "http"
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer func() { defer func() {
@@ -871,11 +873,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr, handler s
} }
var sz int64 var sz int64
for _, f := range hdr.hf { for _, f := range hdr.hf {
if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { sz += int64(f.Size())
if sz > int64(*t.maxSendHeaderListSize) {
hdrListSizeErr = status.Errorf(codes.Internal, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize) hdrListSizeErr = status.Errorf(codes.Internal, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize)
return false return false
} }
} }
if sz > int64(upcomingDefaultHeaderListSize) {
t.logger.Warningf("Header list size to send (%d bytes) is larger than the upcoming default limit (%d bytes). In a future release, this will be restricted to %d bytes.", sz, upcomingDefaultHeaderListSize, upcomingDefaultHeaderListSize)
}
return true return true
} }
for { for {
@@ -1011,7 +1017,7 @@ func (t *http2Client) Close(err error) {
// Call t.onClose ASAP to prevent the client from attempting to create new // Call t.onClose ASAP to prevent the client from attempting to create new
// streams. // streams.
if t.state != draining { if t.state != draining {
t.onClose(GoAwayInvalid) t.onClose(GoAwayInfo{Reason: GoAwayInvalid, GoAwayCode: http2.ErrCodeNo, Err: err})
} }
t.state = closing t.state = closing
streams := t.activeStreams streams := t.activeStreams
@@ -1082,7 +1088,7 @@ func (t *http2Client) GracefulClose() {
if t.logger.V(logLevel) { if t.logger.V(logLevel) {
t.logger.Infof("GracefulClose called") t.logger.Infof("GracefulClose called")
} }
t.onClose(GoAwayInvalid) t.onClose(GoAwayInfo{Reason: GoAwayInvalid, GoAwayCode: http2.ErrCodeNo})
t.state = draining t.state = draining
active := len(t.activeStreams) active := len(t.activeStreams)
t.mu.Unlock() t.mu.Unlock()
@@ -1232,7 +1238,10 @@ func (t *http2Client) handleData(f *parsedDataFrame) {
// The server has closed the stream without sending trailers. Record that // The server has closed the stream without sending trailers. Record that
// the read direction is closed, and set the status appropriately. // the read direction is closed, and set the status appropriately.
if f.StreamEnded() { if f.StreamEnded() {
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) // If client received END_STREAM from server while stream was still
// active, send RST_STREAM.
rstStream := s.getState() == streamActive
t.closeStream(s, io.EOF, rstStream, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
} }
} }
@@ -1368,7 +1377,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) error {
// draining, to allow the client to stop attempting to create streams // draining, to allow the client to stop attempting to create streams
// before disallowing new streams on this connection. // before disallowing new streams on this connection.
if t.state != draining { if t.state != draining {
t.onClose(t.goAwayReason) t.onClose(GoAwayInfo{Reason: t.goAwayReason, GoAwayCode: t.goAwayCode})
t.state = draining t.state = draining
} }
} }
@@ -1418,6 +1427,7 @@ func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) {
} else { } else {
t.goAwayDebugMessage = fmt.Sprintf("code: %s, debug data: %q", f.ErrCode, string(f.DebugData())) t.goAwayDebugMessage = fmt.Sprintf("code: %s, debug data: %q", f.ErrCode, string(f.DebugData()))
} }
t.goAwayCode = f.ErrCode
} }
func (t *http2Client) GetGoAwayReason() (GoAwayReason, string) { func (t *http2Client) GetGoAwayReason() (GoAwayReason, string) {
+5 -1
View File
@@ -940,13 +940,17 @@ func (t *http2Server) checkForHeaderListSize(hf []hpack.HeaderField) bool {
} }
var sz int64 var sz int64
for _, f := range hf { for _, f := range hf {
if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { sz += int64(f.Size())
if sz > int64(*t.maxSendHeaderListSize) {
if t.logger.V(logLevel) { if t.logger.V(logLevel) {
t.logger.Infof("Header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize) t.logger.Infof("Header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize)
} }
return false return false
} }
} }
if sz > int64(upcomingDefaultHeaderListSize) {
t.logger.Warningf("Header list size to send (%d bytes) is larger than the upcoming default limit (%d bytes). In a future release, this will be restricted to %d bytes.", sz, upcomingDefaultHeaderListSize, upcomingDefaultHeaderListSize)
}
return true return true
} }
+31 -23
View File
@@ -36,6 +36,9 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/envconfig"
imem "google.golang.org/grpc/internal/mem"
"google.golang.org/grpc/internal/transport/readyreader"
"google.golang.org/grpc/mem" "google.golang.org/grpc/mem"
) )
@@ -296,7 +299,7 @@ func decodeGrpcMessageUnchecked(msg string) string {
} }
type bufWriter struct { type bufWriter struct {
pool *sync.Pool pool *imem.SimpleBufferPool
buf []byte buf []byte
offset int offset int
batchSize int batchSize int
@@ -304,7 +307,7 @@ type bufWriter struct {
err error err error
} }
func newBufWriter(conn io.Writer, batchSize int, pool *sync.Pool) *bufWriter { func newBufWriter(conn io.Writer, batchSize int, pool *imem.SimpleBufferPool) *bufWriter {
w := &bufWriter{ w := &bufWriter{
batchSize: batchSize, batchSize: batchSize,
conn: conn, conn: conn,
@@ -326,7 +329,7 @@ func (w *bufWriter) Write(b []byte) (int, error) {
return n, toIOError(err) return n, toIOError(err)
} }
if w.buf == nil { if w.buf == nil {
b := w.pool.Get().(*[]byte) b := w.pool.Get(w.batchSize)
w.buf = *b w.buf = *b
} }
written := 0 written := 0
@@ -407,22 +410,32 @@ type framer struct {
errDetail error errDetail error
} }
var writeBufferPoolMap = make(map[int]*sync.Pool) var ioBufferPoolMap = make(map[int]*imem.SimpleBufferPool)
var writeBufferMutex sync.Mutex var ioBufferMutex sync.Mutex
func bufferedReader(r io.Reader, bufSize int) io.Reader {
if bufSize <= 0 {
return r
}
if envconfig.EnableHTTPFramerReadBufferPooling {
if rr := readyreader.NewNonBlocking(r); rr != nil {
readPool := ioBufferPool(bufSize)
return readyreader.NewBuffered(rr, bufSize, readPool)
}
}
return bufio.NewReaderSize(r, bufSize)
}
func newFramer(conn io.ReadWriter, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32, memPool mem.BufferPool) *framer { func newFramer(conn io.ReadWriter, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32, memPool mem.BufferPool) *framer {
if writeBufferSize < 0 { if writeBufferSize < 0 {
writeBufferSize = 0 writeBufferSize = 0
} }
var r io.Reader = conn r := bufferedReader(conn, readBufferSize)
if readBufferSize > 0 { var writePool *imem.SimpleBufferPool
r = bufio.NewReaderSize(r, readBufferSize)
}
var pool *sync.Pool
if sharedWriteBuffer { if sharedWriteBuffer {
pool = getWriteBufferPool(writeBufferSize) writePool = ioBufferPool(writeBufferSize)
} }
w := newBufWriter(conn, writeBufferSize, pool) w := newBufWriter(conn, writeBufferSize, writePool)
f := &framer{ f := &framer{
writer: w, writer: w,
fr: http2.NewFramer(w, r), fr: http2.NewFramer(w, r),
@@ -578,20 +591,15 @@ func (df *parsedDataFrame) Header() http2.FrameHeader {
return df.FrameHeader return df.FrameHeader
} }
func getWriteBufferPool(size int) *sync.Pool { func ioBufferPool(size int) *imem.SimpleBufferPool {
writeBufferMutex.Lock() ioBufferMutex.Lock()
defer writeBufferMutex.Unlock() defer ioBufferMutex.Unlock()
pool, ok := writeBufferPoolMap[size] pool, ok := ioBufferPoolMap[size]
if ok { if ok {
return pool return pool
} }
pool = &sync.Pool{ pool = imem.NewDirtySimplePool()
New: func() any { ioBufferPoolMap[size] = pool
b := make([]byte, size)
return &b
},
}
writeBufferPoolMap[size] = pool
return pool return pool
} }
@@ -0,0 +1,39 @@
/*
*
* Copyright 2026 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package readyreader
import "syscall"
func isRawConnSupported() bool {
return true
}
// sysRead uses the standard syscall package rather than the modern unix package
// to avoid triggering the race detector. Because both packages perform sync
// operations on a local variable to satisfy the race detector, mixing them
// for read and write syscalls causes data races. We use syscall here to remain
// consistent with net.Conn implementations in standard library.
func sysRead(fd uintptr, p []byte) (int, error) {
return syscall.Read(int(fd), p)
}
// wouldBlock checks standard Unix non-blocking errors.
func wouldBlock(err error) bool {
return err == syscall.EAGAIN || err == syscall.EWOULDBLOCK
}
@@ -0,0 +1,35 @@
//go:build !linux
/*
*
* Copyright 2026 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package readyreader
func isRawConnSupported() bool {
return false
}
// sysRead is not implemented. Support can be added in the future if necessary.
func sysRead(uintptr, []byte) (int, error) {
panic("RawConn functionality is not implemented for non-unix platforms.")
}
// wouldBlock is not implemented. Support can be added in the future if necessary.
func wouldBlock(error) bool {
panic("RawConn functionality is not implemented for non-unix platforms.")
}
@@ -0,0 +1,253 @@
/*
*
* Copyright 2026 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package readyreader provides utilities to perform non-memory-pinning reads.
package readyreader
import (
"io"
"net"
"syscall"
"google.golang.org/grpc/mem"
)
// Reader is an optional interface that can be implemented by [net.Conn]
// implementations to enable gRPC to perform non-memory-pinning reads.
type Reader interface {
// ReadOnReady waits for data to arrive, fetches a buffer, and performs a
// read. When the underlying IO is readable, it allocates a buffer of size
// bufSize from the pool and reads up to bufSize bytes into the buffer.
//
// It returns a pointer to the buffer so it can be returned to the pool
// later, the number of bytes read, and an error.
//
// Callers should always process the n > 0 bytes returned before considering
// the error. Doing so correctly handles I/O errors that happen after
// reading some bytes, as well as both of the allowed EOF behaviors.
ReadOnReady(bufSize int, pool mem.BufferPool) (b *[]byte, n int, err error)
}
// nonBlockingReader is optimized for non-memory-pinning reads using the RawConn
// interface.
type nonBlockingReader struct {
raw syscall.RawConn
// The following fields are stored as field to avoid heap allocations.
state readState
doRead func(fd uintptr) bool
}
type readState struct {
// Request params.
bufSize int
pool mem.BufferPool
// Response params.
readError error
bytesRead int
buf *[]byte
}
// NewNonBlocking returns a ReadyReader if the passed reader supports
// non-memory-pinning reads, else nil.
func NewNonBlocking(r io.Reader) Reader {
if rr, ok := r.(Reader); ok {
return rr
}
if !isRawConnSupported() {
return nil
}
// We restrict the types before asserting syscall.Conn. The credentials
// package may return a wrapper that implements syscall.Conn by embedding
// both the raw connection and the encrypted connection. If the code
// attempts to read directly from the raw syscall.RawConn, it would read
// encrypted data.
switch r.(type) {
case *net.TCPConn, *net.UDPConn, *net.UnixConn, *net.IPConn:
default:
return nil
}
sysConn, ok := r.(syscall.Conn)
if !ok {
return nil
}
raw, err := sysConn.SyscallConn()
if err != nil {
return nil
}
rr := &nonBlockingReader{raw: raw}
rr.doRead = func(fd uintptr) bool {
s := &rr.state
s.buf = s.pool.Get(s.bufSize)
s.bytesRead, s.readError = sysRead(fd, *s.buf)
if s.readError != nil {
s.pool.Put(s.buf)
s.buf = nil
}
return !wouldBlock(s.readError)
}
return rr
}
func (c *nonBlockingReader) ReadOnReady(bufSize int, pool mem.BufferPool) (*[]byte, int, error) {
c.state = readState{
pool: pool,
bufSize: bufSize,
}
err := c.raw.Read(c.doRead)
buf := c.state.buf
n := c.state.bytesRead
readErr := c.state.readError
c.state = readState{}
if err != nil {
if buf != nil {
pool.Put(buf)
}
return nil, 0, err
}
if readErr != nil {
// buffer is already released in the callback.
return nil, 0, readErr
}
if n == 0 {
// syscall.Read doesn't consider a graceful socket closure to be an
// error condition, but Go's io.Reader expects an EOF error.
pool.Put(buf)
return nil, 0, io.EOF
}
return buf, n, nil
}
type blockingReader struct {
reader io.Reader
}
func (c *blockingReader) ReadOnReady(bufSize int, pool mem.BufferPool) (*[]byte, int, error) {
buf := pool.Get(bufSize)
n, err := c.reader.Read(*buf)
if err != nil {
pool.Put(buf)
return nil, 0, err
}
return buf, n, nil
}
// New detects if [syscall.RawConn] is available for non-memory-pinning reads.
// If [syscall.RawConn] is unavailable, it falls back to using the simpler
// [io.Reader] interface for reads.
func New(r io.Reader) Reader {
if r := NewNonBlocking(r); r != nil {
return r
}
return &blockingReader{reader: r}
}
// bufReadyReader implements buffering for a ReadyReader object.
// A new bufReadyReader is created by calling [NewBuffered].
type bufReadyReader struct {
buf *[]byte
pool mem.BufferPool
bufSize int
rd Reader // reader provided by the caller
r, w int // buf read and write positions
err error
constPool constBufferPool // stored as a field to avoid heap allocations.
}
// NewBuffered returns a new [io.Reader] with a buffer of the specified size
// which is allocated from the provided pool.
func NewBuffered(rd Reader, size int, pool mem.BufferPool) io.Reader {
return &bufReadyReader{
rd: rd,
pool: pool,
bufSize: size,
}
}
func (b *bufReadyReader) readErr() error {
err := b.err
b.err = nil
return err
}
func (b *bufReadyReader) buffered() int { return b.w - b.r }
// Read reads data into p. It returns the number of bytes read into p. The
// bytes are taken from at most one Read on the underlying [ReadyReader],
// hence n may be less than len(p). If the underlying [ReadyReader] can return
// a non-zero count with io.EOF, then this Read method can do so as well; see
// the [io.Reader] docs.
func (b *bufReadyReader) Read(p []byte) (n int, err error) {
n = len(p)
if n == 0 {
if b.buffered() > 0 {
return 0, nil
}
return 0, b.readErr()
}
if b.r == b.w {
if b.err != nil {
return 0, b.readErr()
}
if len(p) >= b.bufSize {
// Large read, empty buffer.
// Read directly into p to avoid copy.
b.constPool.buffer = p
_, n, b.err = b.rd.ReadOnReady(len(p), &b.constPool)
return n, b.readErr()
}
// One read.
b.r = 0
b.w = 0
b.buf, n, b.err = b.rd.ReadOnReady(b.bufSize, b.pool)
if n == 0 {
if b.buf != nil {
b.pool.Put(b.buf)
b.buf = nil
}
return 0, b.readErr()
}
b.w += n
}
// copy as much as we can
// b.buf must be non-nil since b.r != b.w.
buf := *b.buf
n = copy(p, buf[b.r:b.w])
b.r += n
if b.r == b.w {
// Consumed entire buffer, release it.
b.pool.Put(b.buf)
b.buf = nil
}
return n, nil
}
type constBufferPool struct {
buffer []byte
}
func (p *constBufferPool) Get(int) *[]byte {
return &p.buffer
}
func (p *constBufferPool) Put(*[]byte) {}
+17
View File
@@ -31,6 +31,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/net/http2"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
@@ -742,6 +743,22 @@ const (
GoAwayTooManyPings GoAwayReason = 2 GoAwayTooManyPings GoAwayReason = 2
) )
// GoAwayInfo contains metadata about why a connection was closed.
type GoAwayInfo struct {
// Reason is the parsed reason for an HTTP/2 GOAWAY frame.
Reason GoAwayReason
// GoAwayCode is the raw HTTP/2 error code received in a GOAWAY frame.
GoAwayCode http2.ErrCode
// Err is the underlying error that caused the connection to close. It is
// populated if the connection was closed due to a socket error or context
// cancellation without receiving a GOAWAY frame. If the connection was
// closed due to a GOAWAY frame, this field will be nil.
Err error
}
// OnCloseFunc is a callback invoked when a ClientTransport closes.
type OnCloseFunc func(GoAwayInfo)
// ContextErr converts the error from context package into a status error. // ContextErr converts the error from context package into a status error.
func ContextErr(err error) error { func ContextErr(err error) error {
switch err { switch err {
+29 -138
View File
@@ -19,10 +19,10 @@
package mem package mem
import ( import (
"sort" "fmt"
"sync"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/mem"
) )
// BufferPool is a pool of buffers that can be shared and reused, resulting in // BufferPool is a pool of buffers that can be shared and reused, resulting in
@@ -38,20 +38,23 @@ type BufferPool interface {
Put(*[]byte) Put(*[]byte)
} }
const goPageSize = 4 << 10 // 4KiB. N.B. this must be a power of 2. var (
defaultBufferPoolSizeExponents = []uint8{
var defaultBufferPoolSizes = []int{ 8,
256, 12, // Go page size, 4KB
goPageSize, 14, // 16KB (max HTTP/2 frame size used by gRPC)
16 << 10, // 16KB (max HTTP/2 frame size used by gRPC) 15, // 32KB (default buffer size for io.Copy)
32 << 10, // 32KB (default buffer size for io.Copy) 20, // 1MB
1 << 20, // 1MB }
} defaultBufferPool BufferPool
)
var defaultBufferPool BufferPool
func init() { func init() {
defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...) var err error
defaultBufferPool, err = NewBinaryTieredBufferPool(defaultBufferPoolSizeExponents...)
if err != nil {
panic(fmt.Sprintf("Failed to create default buffer pool: %v", err))
}
internal.SetDefaultBufferPool = func(pool BufferPool) { internal.SetDefaultBufferPool = func(pool BufferPool) {
defaultBufferPool = pool defaultBufferPool = pool
@@ -72,134 +75,22 @@ func DefaultBufferPool() BufferPool {
// NewTieredBufferPool returns a BufferPool implementation that uses multiple // NewTieredBufferPool returns a BufferPool implementation that uses multiple
// underlying pools of the given pool sizes. // underlying pools of the given pool sizes.
func NewTieredBufferPool(poolSizes ...int) BufferPool { func NewTieredBufferPool(poolSizes ...int) BufferPool {
sort.Ints(poolSizes) return mem.NewTieredBufferPool(poolSizes...)
pools := make([]*sizedBufferPool, len(poolSizes))
for i, s := range poolSizes {
pools[i] = newSizedBufferPool(s)
}
return &tieredBufferPool{
sizedPools: pools,
}
} }
// tieredBufferPool implements the BufferPool interface with multiple tiers of // NewBinaryTieredBufferPool returns a BufferPool backed by multiple sub-pools.
// buffer pools for different sizes of buffers. // This structure enables O(1) lookup time for Get and Put operations.
type tieredBufferPool struct { //
sizedPools []*sizedBufferPool // The arguments provided are the exponents for the buffer capacities (powers
fallbackPool simpleBufferPool // of 2), not the raw byte sizes. For example, to create a pool of 16KB buffers
// (2^14 bytes), pass 14 as the argument.
func NewBinaryTieredBufferPool(powerOfTwoExponents ...uint8) (BufferPool, error) {
return mem.NewBinaryTieredBufferPool(powerOfTwoExponents...)
} }
func (p *tieredBufferPool) Get(size int) *[]byte { // NopBufferPool is a buffer pool that returns new buffers without pooling.
return p.getPool(size).Get(size) type NopBufferPool struct {
} mem.NopBufferPool
func (p *tieredBufferPool) Put(buf *[]byte) {
p.getPool(cap(*buf)).Put(buf)
}
func (p *tieredBufferPool) getPool(size int) BufferPool {
poolIdx := sort.Search(len(p.sizedPools), func(i int) bool {
return p.sizedPools[i].defaultSize >= size
})
if poolIdx == len(p.sizedPools) {
return &p.fallbackPool
}
return p.sizedPools[poolIdx]
}
// sizedBufferPool is a BufferPool implementation that is optimized for specific
// buffer sizes. For example, HTTP/2 frames within gRPC have a default max size
// of 16kb and a sizedBufferPool can be configured to only return buffers with a
// capacity of 16kb. Note that however it does not support returning larger
// buffers and in fact panics if such a buffer is requested. Because of this,
// this BufferPool implementation is not meant to be used on its own and rather
// is intended to be embedded in a tieredBufferPool such that Get is only
// invoked when the required size is smaller than or equal to defaultSize.
type sizedBufferPool struct {
pool sync.Pool
defaultSize int
}
func (p *sizedBufferPool) Get(size int) *[]byte {
buf, ok := p.pool.Get().(*[]byte)
if !ok {
buf := make([]byte, size, p.defaultSize)
return &buf
}
b := *buf
clear(b[:cap(b)])
*buf = b[:size]
return buf
}
func (p *sizedBufferPool) Put(buf *[]byte) {
if cap(*buf) < p.defaultSize {
// Ignore buffers that are too small to fit in the pool. Otherwise, when
// Get is called it will panic as it tries to index outside the bounds
// of the buffer.
return
}
p.pool.Put(buf)
}
func newSizedBufferPool(size int) *sizedBufferPool {
return &sizedBufferPool{
defaultSize: size,
}
}
var _ BufferPool = (*simpleBufferPool)(nil)
// simpleBufferPool is an implementation of the BufferPool interface that
// attempts to pool buffers with a sync.Pool. When Get is invoked, it tries to
// acquire a buffer from the pool but if that buffer is too small, it returns it
// to the pool and creates a new one.
type simpleBufferPool struct {
pool sync.Pool
}
func (p *simpleBufferPool) Get(size int) *[]byte {
bs, ok := p.pool.Get().(*[]byte)
if ok && cap(*bs) >= size {
clear((*bs)[:cap(*bs)])
*bs = (*bs)[:size]
return bs
}
// A buffer was pulled from the pool, but it is too small. Put it back in
// the pool and create one large enough.
if ok {
p.pool.Put(bs)
}
// If we're going to allocate, round up to the nearest page. This way if
// requests frequently arrive with small variation we don't allocate
// repeatedly if we get unlucky and they increase over time. By default we
// only allocate here if size > 1MiB. Because goPageSize is a power of 2, we
// can round up efficiently.
allocSize := (size + goPageSize - 1) & ^(goPageSize - 1)
b := make([]byte, size, allocSize)
return &b
}
func (p *simpleBufferPool) Put(buf *[]byte) {
p.pool.Put(buf)
} }
var _ BufferPool = NopBufferPool{} var _ BufferPool = NopBufferPool{}
// NopBufferPool is a buffer pool that returns new buffers without pooling.
type NopBufferPool struct{}
// Get returns a buffer with specified length from the pool.
func (NopBufferPool) Get(length int) *[]byte {
b := make([]byte, length)
return &b
}
// Put returns a buffer to the pool.
func (NopBufferPool) Put(*[]byte) {
}
+1 -1
View File
@@ -165,7 +165,7 @@ func (r *Reader) Close() error {
} }
func (r *Reader) freeFirstBufferIfEmpty() bool { func (r *Reader) freeFirstBufferIfEmpty() bool {
if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) { if len(r.data) == 0 || r.bufferIdx != r.data[0].Len() {
return false return false
} }
+40
View File
@@ -53,6 +53,10 @@ type Buffer interface {
Free() Free()
// Len returns the Buffer's size. // Len returns the Buffer's size.
Len() int Len() int
// Slice returns a new Buffer that is a view into this buffer's data
// from [start:end). The buffer is not modified. Panics if the buffer
// has been freed or if start/end are out of bounds.
Slice(start, end int) Buffer
split(n int) (left, right Buffer) split(n int) (left, right Buffer)
read(buf []byte) (int, Buffer) read(buf []byte) (int, Buffer)
@@ -180,6 +184,32 @@ func (b *buffer) Len() int {
return len(b.ReadOnlyData()) return len(b.ReadOnlyData())
} }
func (b *buffer) Slice(start, end int) Buffer {
if b.rootBuf == nil {
panic("Cannot slice freed buffer")
}
data := b.data[start:end] // access the data to check slice bounds
if len(data) == 0 {
return emptyBuffer{}
}
if len(data) == len(b.data) {
b.Ref()
return b
}
// We are creating a new reference (view) to a portion of the root buffer's
// data. Therefore, we must increment the reference count of the root buffer
// to ensure the underlying data is not freed while this view is still in
// use.
b.rootBuf.Ref()
s := newBuffer()
s.data = data
s.rootBuf = b.rootBuf
s.refs.Store(1)
return s
}
func (b *buffer) split(n int) (Buffer, Buffer) { func (b *buffer) split(n int) (Buffer, Buffer) {
if b.rootBuf == nil || b.rootBuf.refs.Add(1) <= 1 { if b.rootBuf == nil || b.rootBuf.refs.Add(1) <= 1 {
panic("Cannot split freed buffer") panic("Cannot split freed buffer")
@@ -240,6 +270,13 @@ func (e emptyBuffer) Len() int {
return 0 return 0
} }
func (e emptyBuffer) Slice(start, end int) Buffer {
if start != 0 || end != 0 {
panic(fmt.Sprintf("slice bounds out of range [%d:%d] with length 0", start, end))
}
return e
}
func (e emptyBuffer) split(int) (left, right Buffer) { func (e emptyBuffer) split(int) (left, right Buffer) {
return e, e return e, e
} }
@@ -264,6 +301,9 @@ func (s SliceBuffer) Free() {}
// Len is a noop implementation of Len. // Len is a noop implementation of Len.
func (s SliceBuffer) Len() int { return len(s) } func (s SliceBuffer) Len() int { return len(s) }
// Slice returns a new SliceBuffer that is a view into the receiver from [start:end).
func (s SliceBuffer) Slice(start, end int) Buffer { return s[start:end] }
func (s SliceBuffer) split(n int) (left, right Buffer) { func (s SliceBuffer) split(n int) (left, right Buffer) {
return s[:n], s[n:] return s[:n], s[n:]
} }
+3 -1
View File
@@ -192,7 +192,9 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
// DoneInfo with default value works. // DoneInfo with default value works.
pickResult.Done(balancer.DoneInfo{}) pickResult.Done(balancer.DoneInfo{})
} }
logger.Infof("blockingPicker: the picked transport is not ready, loop back to repick") if logger.V(2) {
logger.Infof("blockingPicker: the picked transport is not ready, loop back to repick")
}
// If ok == false, ac.state is not READY. // If ok == false, ac.state is not READY.
// A valid picker always returns READY subConn. This means the state of ac // A valid picker always returns READY subConn. This means the state of ac
// just changed, and picker will be updated shortly. // just changed, and picker will be updated shortly.
+34
View File
@@ -20,6 +20,7 @@ package resolver
import ( import (
"encoding/base64" "encoding/base64"
"iter"
"sort" "sort"
"strings" "strings"
) )
@@ -135,6 +136,7 @@ func (a *AddressMapV2[T]) Len() int {
} }
// Keys returns a slice of all current map keys. // Keys returns a slice of all current map keys.
// Deprecated: Use AddressMapV2.All() instead.
func (a *AddressMapV2[T]) Keys() []Address { func (a *AddressMapV2[T]) Keys() []Address {
ret := make([]Address, 0, a.Len()) ret := make([]Address, 0, a.Len())
for _, entryList := range a.m { for _, entryList := range a.m {
@@ -146,6 +148,7 @@ func (a *AddressMapV2[T]) Keys() []Address {
} }
// Values returns a slice of all current map values. // Values returns a slice of all current map values.
// Deprecated: Use AddressMapV2.All() instead.
func (a *AddressMapV2[T]) Values() []T { func (a *AddressMapV2[T]) Values() []T {
ret := make([]T, 0, a.Len()) ret := make([]T, 0, a.Len())
for _, entryList := range a.m { for _, entryList := range a.m {
@@ -156,6 +159,19 @@ func (a *AddressMapV2[T]) Values() []T {
return ret return ret
} }
// All returns an iterator over all elements.
func (a *AddressMapV2[T]) All() iter.Seq2[Address, T] {
return func(yield func(Address, T) bool) {
for _, entryList := range a.m {
for _, entry := range entryList {
if !yield(entry.addr, entry.value) {
return
}
}
}
}
}
type endpointMapKey string type endpointMapKey string
// EndpointMap is a map of endpoints to arbitrary values keyed on only the // EndpointMap is a map of endpoints to arbitrary values keyed on only the
@@ -223,6 +239,7 @@ func (em *EndpointMap[T]) Len() int {
// the unordered set of addresses. Thus, endpoint information returned is not // the unordered set of addresses. Thus, endpoint information returned is not
// the full endpoint data (drops duplicated addresses and attributes) but can be // the full endpoint data (drops duplicated addresses and attributes) but can be
// used for EndpointMap accesses. // used for EndpointMap accesses.
// Deprecated: Use EndpointMap.All() instead.
func (em *EndpointMap[T]) Keys() []Endpoint { func (em *EndpointMap[T]) Keys() []Endpoint {
ret := make([]Endpoint, 0, len(em.endpoints)) ret := make([]Endpoint, 0, len(em.endpoints))
for _, en := range em.endpoints { for _, en := range em.endpoints {
@@ -232,6 +249,7 @@ func (em *EndpointMap[T]) Keys() []Endpoint {
} }
// Values returns a slice of all current map values. // Values returns a slice of all current map values.
// Deprecated: Use EndpointMap.All() instead.
func (em *EndpointMap[T]) Values() []T { func (em *EndpointMap[T]) Values() []T {
ret := make([]T, 0, len(em.endpoints)) ret := make([]T, 0, len(em.endpoints))
for _, val := range em.endpoints { for _, val := range em.endpoints {
@@ -240,6 +258,22 @@ func (em *EndpointMap[T]) Values() []T {
return ret return ret
} }
// All returns an iterator over all elements.
// The map keys are endpoints specifying the addresses present in the endpoint
// map, in which uniqueness is determined by the unordered set of addresses.
// Thus, endpoint information returned is not the full endpoint data (drops
// duplicated addresses and attributes) but can be used for EndpointMap
// accesses.
func (em *EndpointMap[T]) All() iter.Seq2[Endpoint, T] {
return func(yield func(Endpoint, T) bool) {
for _, en := range em.endpoints {
if !yield(en.decodedKey, en.value) {
return
}
}
}
}
// Delete removes the specified endpoint from the map. // Delete removes the specified endpoint from the map.
func (em *EndpointMap[T]) Delete(e Endpoint) { func (em *EndpointMap[T]) Delete(e Endpoint) {
en := encodeEndpoint(e) en := encodeEndpoint(e)
+16 -6
View File
@@ -961,24 +961,32 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
return out, nil return out, nil
} }
// decompress processes the given data by decompressing it using either a custom decompressor or a standard compressor. // decompress processes the given data by decompressing it using either
// If a custom decompressor is provided, it takes precedence. The function validates that the decompressed data // a custom decompressor or a standard compressor. If a custom decompressor
// does not exceed the specified maximum size and returns an error if this limit is exceeded. // is provided, it takes precedence. The function validates that
// On success, it returns the decompressed data. Otherwise, it returns an error if decompression fails or the data exceeds the size limit. // the decompressed data does not exceed the specified maximum size and returns
// an error if this limit is exceeded. On success, it returns the decompressed
// data. Otherwise, it returns an error if decompression fails or the data
// exceeds the size limit.
func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompressor, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, error) { func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompressor, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, error) {
if dc != nil { if dc != nil {
uncompressed, err := dc.Do(d.Reader()) r := d.Reader()
uncompressed, err := dc.Do(r)
if err != nil { if err != nil {
r.Close() // ensure buffers are reused
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
} }
if len(uncompressed) > maxReceiveMessageSize { if len(uncompressed) > maxReceiveMessageSize {
r.Close() // ensure buffers are reused
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", len(uncompressed), maxReceiveMessageSize) return nil, status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", len(uncompressed), maxReceiveMessageSize)
} }
return mem.BufferSlice{mem.SliceBuffer(uncompressed)}, nil return mem.BufferSlice{mem.SliceBuffer(uncompressed)}, nil
} }
if compressor != nil { if compressor != nil {
dcReader, err := compressor.Decompress(d.Reader()) r := d.Reader()
dcReader, err := compressor.Decompress(r)
if err != nil { if err != nil {
r.Close() // ensure buffers are reused
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err) return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err)
} }
@@ -990,11 +998,13 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompress
} }
out, err := mem.ReadAll(dcReader, pool) out, err := mem.ReadAll(dcReader, pool)
if err != nil { if err != nil {
r.Close() // ensure buffers are reused
out.Free() out.Free()
return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err) return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err)
} }
if out.Len() > maxReceiveMessageSize { if out.Len() > maxReceiveMessageSize {
r.Close() // ensure buffers are reused
out.Free() out.Free()
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize) return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize)
} }
+42 -16
View File
@@ -42,6 +42,7 @@ import (
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
istats "google.golang.org/grpc/internal/stats" istats "google.golang.org/grpc/internal/stats"
@@ -149,6 +150,8 @@ type Server struct {
serverWorkerChannel chan func() serverWorkerChannel chan func()
serverWorkerChannelClose func() serverWorkerChannelClose func()
strictPathCheckingLogEmitted atomic.Bool
} }
type serverOptions struct { type serverOptions struct {
@@ -189,6 +192,7 @@ var defaultServerOptions = serverOptions{
maxSendMessageSize: defaultServerMaxSendMessageSize, maxSendMessageSize: defaultServerMaxSendMessageSize,
connectionTimeout: 120 * time.Second, connectionTimeout: 120 * time.Second,
writeBufferSize: defaultWriteBufSize, writeBufferSize: defaultWriteBufSize,
sharedWriteBuffer: true,
readBufferSize: defaultReadBufSize, readBufferSize: defaultReadBufSize,
bufferPool: mem.DefaultBufferPool(), bufferPool: mem.DefaultBufferPool(),
} }
@@ -1762,6 +1766,24 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv
return ss.s.WriteStatus(statusOK) return ss.s.WriteStatus(statusOK)
} }
func (s *Server) handleMalformedMethodName(stream *transport.ServerStream, ti *traceInfo) {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{stream.Method()}}, true)
ti.tr.SetError()
}
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError()
}
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
}
if ti != nil {
ti.tr.Finish()
}
}
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) { func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) {
ctx := stream.Context() ctx := stream.Context()
ctx = contextWithServer(ctx, s) ctx = contextWithServer(ctx, s)
@@ -1782,26 +1804,30 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Ser
} }
sm := stream.Method() sm := stream.Method()
if sm != "" && sm[0] == '/' { if sm == "" {
s.handleMalformedMethodName(stream, ti)
return
}
if sm[0] != '/' {
// TODO(easwars): Add a link to the CVE in the below log messages once
// published.
if envconfig.DisableStrictPathChecking {
if old := s.strictPathCheckingLogEmitted.Swap(true); !old {
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream received malformed method name %q. Allowing it because the environment variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING is set to true, but this option will be removed in a future release.", sm)
}
} else {
if old := s.strictPathCheckingLogEmitted.Swap(true); !old {
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream rejected malformed method name %q. To temporarily allow such requests, set the environment variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING to true. Note that this is not recommended as it may allow requests to bypass security policies.", sm)
}
s.handleMalformedMethodName(stream, ti)
return
}
} else {
sm = sm[1:] sm = sm[1:]
} }
pos := strings.LastIndex(sm, "/") pos := strings.LastIndex(sm, "/")
if pos == -1 { if pos == -1 {
if ti != nil { s.handleMalformedMethodName(stream, ti)
ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true)
ti.tr.SetError()
}
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError()
}
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
}
if ti != nil {
ti.tr.Finish()
}
return return
} }
service := sm[:pos] service := sm[:pos]
+2 -1
View File
@@ -21,6 +21,7 @@ package grpc
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"io" "io"
"math" "math"
rand "math/rand/v2" rand "math/rand/v2"
@@ -749,7 +750,7 @@ func (a *csAttempt) shouldRetry(err error) (bool, error) {
return false, err return false, err
} }
if cs.numRetries+1 >= rp.MaxAttempts { if cs.numRetries+1 >= rp.MaxAttempts {
return false, err return false, fmt.Errorf("max retries exhausted: failed after %d attempts: %w", cs.numRetries+1, err)
} }
var dur time.Duration var dur time.Duration
+1 -1
View File
@@ -19,4 +19,4 @@
package grpc package grpc
// Version is the current grpc version. // Version is the current grpc version.
const Version = "1.79.2" const Version = "1.81.1"
+11 -9
View File
@@ -322,7 +322,7 @@ go.uber.org/mock/mockgen
go.uber.org/mock/mockgen/model go.uber.org/mock/mockgen/model
# golang.org/x/arch v0.4.0 # golang.org/x/arch v0.4.0
## explicit; go 1.17 ## explicit; go 1.17
# golang.org/x/crypto v0.51.0 # golang.org/x/crypto v0.52.0
## explicit; go 1.25.0 ## explicit; go 1.25.0
golang.org/x/crypto/blake2b golang.org/x/crypto/blake2b
golang.org/x/crypto/blowfish golang.org/x/crypto/blowfish
@@ -364,8 +364,8 @@ golang.org/x/net/nettest
golang.org/x/net/proxy golang.org/x/net/proxy
golang.org/x/net/trace golang.org/x/net/trace
golang.org/x/net/websocket golang.org/x/net/websocket
# golang.org/x/oauth2 v0.35.0 # golang.org/x/oauth2 v0.36.0
## explicit; go 1.24.0 ## explicit; go 1.25.0
golang.org/x/oauth2 golang.org/x/oauth2
golang.org/x/oauth2/internal golang.org/x/oauth2/internal
# golang.org/x/sync v0.20.0 # golang.org/x/sync v0.20.0
@@ -423,14 +423,14 @@ golang.org/x/tools/internal/stdlib
golang.org/x/tools/internal/typeparams golang.org/x/tools/internal/typeparams
golang.org/x/tools/internal/typesinternal golang.org/x/tools/internal/typesinternal
golang.org/x/tools/internal/versions golang.org/x/tools/internal/versions
# google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 # google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171
## explicit; go 1.24.0 ## explicit; go 1.25.0
google.golang.org/genproto/googleapis/api/httpbody google.golang.org/genproto/googleapis/api/httpbody
# google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 # google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171
## explicit; go 1.24.0 ## explicit; go 1.25.0
google.golang.org/genproto/googleapis/rpc/status google.golang.org/genproto/googleapis/rpc/status
# google.golang.org/grpc v1.79.2 # google.golang.org/grpc v1.81.1
## explicit; go 1.24.0 ## explicit; go 1.25.0
google.golang.org/grpc google.golang.org/grpc
google.golang.org/grpc/attributes google.golang.org/grpc/attributes
google.golang.org/grpc/backoff google.golang.org/grpc/backoff
@@ -468,6 +468,7 @@ google.golang.org/grpc/internal/grpclog
google.golang.org/grpc/internal/grpcsync google.golang.org/grpc/internal/grpcsync
google.golang.org/grpc/internal/grpcutil google.golang.org/grpc/internal/grpcutil
google.golang.org/grpc/internal/idle google.golang.org/grpc/internal/idle
google.golang.org/grpc/internal/mem
google.golang.org/grpc/internal/metadata google.golang.org/grpc/internal/metadata
google.golang.org/grpc/internal/pretty google.golang.org/grpc/internal/pretty
google.golang.org/grpc/internal/proxyattributes google.golang.org/grpc/internal/proxyattributes
@@ -483,6 +484,7 @@ google.golang.org/grpc/internal/status
google.golang.org/grpc/internal/syscall google.golang.org/grpc/internal/syscall
google.golang.org/grpc/internal/transport google.golang.org/grpc/internal/transport
google.golang.org/grpc/internal/transport/networktype google.golang.org/grpc/internal/transport/networktype
google.golang.org/grpc/internal/transport/readyreader
google.golang.org/grpc/keepalive google.golang.org/grpc/keepalive
google.golang.org/grpc/mem google.golang.org/grpc/mem
google.golang.org/grpc/metadata google.golang.org/grpc/metadata