Bump golang.org/x/net from v0.40.0 to v0.53.0
Check / check (1.22.x, macos-latest) (push) Has been cancelled
Check / check (1.22.x, ubuntu-latest) (push) Has been cancelled
Check / check (1.22.x, windows-latest) (push) Has been cancelled
Semgrep config / semgrep/ci (push) Has been cancelled

This commit is contained in:
MiguelMarcelino
2026-05-08 11:26:25 +01:00
committed by Miguel da Costa Martins Marcelino
parent 4d8df2b2c0
commit ae3799a098
171 changed files with 31621 additions and 74888 deletions
+9 -9
View File
@@ -3,6 +3,7 @@ module github.com/cloudflare/cloudflared
go 1.26
require (
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0
github.com/coreos/go-oidc/v3 v3.17.0
github.com/coreos/go-systemd/v22 v22.5.0
github.com/facebookgo/grace v0.0.0-20180706040059-75cf19382434
@@ -35,11 +36,11 @@ require (
go.opentelemetry.io/proto/otlp v1.2.0
go.uber.org/automaxprocs v1.6.0
go.uber.org/mock v0.5.1
golang.org/x/crypto v0.38.0
golang.org/x/net v0.40.0
golang.org/x/sync v0.14.0
golang.org/x/sys v0.41.0
golang.org/x/term v0.32.0
golang.org/x/crypto v0.50.0
golang.org/x/net v0.53.0
golang.org/x/sync v0.20.0
golang.org/x/sys v0.43.0
golang.org/x/term v0.42.0
google.golang.org/protobuf v1.36.6
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1
@@ -52,7 +53,6 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/bytedance/sonic v1.12.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/ebitengine/purego v0.10.0 // indirect
@@ -91,10 +91,10 @@ require (
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/metric v1.40.0 // indirect
golang.org/x/arch v0.4.0 // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/text v0.25.0 // indirect
golang.org/x/tools v0.32.0 // indirect
golang.org/x/text v0.36.0 // indirect
golang.org/x/tools v0.43.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250505200425-f936aa4a68b2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250512202823-5a2f75b736a9 // indirect
google.golang.org/grpc v1.72.2 // indirect
+16 -16
View File
@@ -245,21 +245,21 @@ golang.org/x/arch v0.4.0 h1:A8WCeEWhLwPBKNbFi5Wv5UTCBx5zzubnXDlMOFAzFMc=
golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -269,20 +269,20 @@ golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.32.0 h1:Q7N1vhpkQv7ybVzLFtTjvQya2ewbwNDZzUgfXGqtMWU=
golang.org/x/tools v0.32.0/go.mod h1:ZxrU41P/wAbZD8EDa6dDCa6XfpkhJ7HFMjHJXfBDu8s=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+8
View File
@@ -12,6 +12,8 @@ import (
// XOF defines the interface to hash functions that
// support arbitrary-length output.
//
// New callers should prefer the standard library [hash.XOF].
type XOF interface {
// Write absorbs more data into the hash's state. It panics if called
// after Read.
@@ -47,6 +49,8 @@ const maxOutputLength = (1 << 32) * 64
//
// A non-nil key turns the hash into a MAC. The key must between
// zero and 32 bytes long.
//
// The result can be safely interface-upgraded to [hash.XOF].
func NewXOF(size uint32, key []byte) (XOF, error) {
if len(key) > Size {
return nil, errKeySize
@@ -93,6 +97,10 @@ func (x *xof) Clone() XOF {
return &clone
}
func (x *xof) BlockSize() int {
return x.d.BlockSize()
}
func (x *xof) Reset() {
x.cfg[0] = byte(Size)
binary.LittleEndian.PutUint32(x.cfg[4:], uint32(Size)) // leaf length
+11
View File
@@ -0,0 +1,11 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.25
package blake2b
import "hash"
var _ hash.XOF = (*xof)(nil)
+1 -1
View File
@@ -29,7 +29,7 @@ loop:
MOVD $NUM_ROUNDS, R21
VLD1 (R11), [V30.S4, V31.S4]
// load contants
// load constants
// VLD4R (R10), [V0.S4, V1.S4, V2.S4, V3.S4]
WORD $0x4D60E940
+3
View File
@@ -38,6 +38,9 @@ type chacha20poly1305 struct {
// New returns a ChaCha20-Poly1305 AEAD that uses the given 256-bit key.
func New(key []byte) (cipher.AEAD, error) {
if fips140Enforced() {
return nil, errors.New("chacha20poly1305: use of ChaCha20Poly1305 is not allowed in FIPS 140-only mode")
}
if len(key) != KeySize {
return nil, errors.New("chacha20poly1305: bad key length")
}
+8 -2
View File
@@ -56,7 +56,10 @@ func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []
ret, out := sliceForAppend(dst, len(plaintext)+16)
if alias.InexactOverlap(out, plaintext) {
panic("chacha20poly1305: invalid buffer overlap")
panic("chacha20poly1305: invalid buffer overlap of output and input")
}
if alias.AnyOverlap(out, additionalData) {
panic("chacha20poly1305: invalid buffer overlap of output and additional data")
}
chacha20Poly1305Seal(out[:], state[:], plaintext, additionalData)
return ret
@@ -73,7 +76,10 @@ func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) (
ciphertext = ciphertext[:len(ciphertext)-16]
ret, out := sliceForAppend(dst, len(ciphertext))
if alias.InexactOverlap(out, ciphertext) {
panic("chacha20poly1305: invalid buffer overlap")
panic("chacha20poly1305: invalid buffer overlap of output and input")
}
if alias.AnyOverlap(out, additionalData) {
panic("chacha20poly1305: invalid buffer overlap of output and additional data")
}
if !chacha20Poly1305Open(out, state[:], ciphertext, additionalData) {
for i := range out {
+8 -2
View File
@@ -31,7 +31,10 @@ func (c *chacha20poly1305) sealGeneric(dst, nonce, plaintext, additionalData []b
ret, out := sliceForAppend(dst, len(plaintext)+poly1305.TagSize)
ciphertext, tag := out[:len(plaintext)], out[len(plaintext):]
if alias.InexactOverlap(out, plaintext) {
panic("chacha20poly1305: invalid buffer overlap")
panic("chacha20poly1305: invalid buffer overlap of output and input")
}
if alias.AnyOverlap(out, additionalData) {
panic("chacha20poly1305: invalid buffer overlap of output and additional data")
}
var polyKey [32]byte
@@ -67,7 +70,10 @@ func (c *chacha20poly1305) openGeneric(dst, nonce, ciphertext, additionalData []
ret, out := sliceForAppend(dst, len(ciphertext))
if alias.InexactOverlap(out, ciphertext) {
panic("chacha20poly1305: invalid buffer overlap")
panic("chacha20poly1305: invalid buffer overlap of output and input")
}
if alias.AnyOverlap(out, additionalData) {
panic("chacha20poly1305: invalid buffer overlap of output and additional data")
}
if !p.Verify(tag) {
for i := range out {
+9
View File
@@ -0,0 +1,9 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.26
package chacha20poly1305
func fips140Enforced() bool { return false }
+11
View File
@@ -0,0 +1,11 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.26
package chacha20poly1305
import "crypto/fips140"
func fips140Enforced() bool { return fips140.Enforced() }
+3
View File
@@ -22,6 +22,9 @@ type xchacha20poly1305 struct {
// preferred when nonce uniqueness cannot be trivially ensured, or whenever
// nonces are randomly generated.
func NewX(key []byte) (cipher.AEAD, error) {
if fips140Enforced() {
return nil, errors.New("chacha20poly1305: use of ChaCha20Poly1305 is not allowed in FIPS 140-only mode")
}
if len(key) != KeySize {
return nil, errors.New("chacha20poly1305: bad key length")
}
+8 -5
View File
@@ -3,11 +3,14 @@
// license that can be found in the LICENSE file.
// Package curve25519 provides an implementation of the X25519 function, which
// performs scalar multiplication on the elliptic curve known as Curve25519.
// See RFC 7748.
// performs scalar multiplication on the elliptic curve known as Curve25519
// according to [RFC 7748].
//
// This package is a wrapper for the X25519 implementation
// in the crypto/ecdh package.
// The curve25519 package is a wrapper for the X25519 implementation in the
// crypto/ecdh package. It is [frozen] and is not accepting new features.
//
// [RFC 7748]: https://datatracker.ietf.org/doc/html/rfc7748
// [frozen]: https://go.dev/wiki/Frozen
package curve25519
import "crypto/ecdh"
@@ -36,7 +39,7 @@ func ScalarBaseMult(dst, scalar *[32]byte) {
curve := ecdh.X25519()
priv, err := curve.NewPrivateKey(scalar[:])
if err != nil {
panic("curve25519: internal error: scalarBaseMult was not 32 bytes")
panic("curve25519: " + err.Error())
}
copy(dst[:], priv.PublicKey().Bytes())
}
+4
View File
@@ -3,6 +3,10 @@
// license that can be found in the LICENSE file.
// Package salsa provides low-level access to functions in the Salsa family.
//
// Deprecated: this package exposes unsafe low-level operations. New applications
// should consider using the AEAD construction in golang.org/x/crypto/chacha20poly1305
// instead. Existing users should migrate to golang.org/x/crypto/salsa20.
package salsa
import "math/bits"
+40 -27
View File
@@ -20,14 +20,19 @@ import (
// returned by MultiAlgorithmSigner and don't appear in the Signature.Format
// field.
const (
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com"
CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com"
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
// Deprecated: DSA is only supported at insecure key sizes, and was removed
// from major implementations.
CertAlgoDSAv01 = InsecureCertAlgoDSAv01
// Deprecated: DSA is only supported at insecure key sizes, and was removed
// from major implementations.
InsecureCertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com"
CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com"
// CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a
// Certificate.Type (or PublicKey.Type), but only in
@@ -228,7 +233,11 @@ func parseCert(in []byte, privAlgo string) (*Certificate, error) {
if err != nil {
return nil, err
}
// The Type() function is intended to return only certificate key types, but
// we use certKeyAlgoNames anyway for safety, to match [Certificate.Type].
if _, ok := certKeyAlgoNames[k.Type()]; ok {
return nil, fmt.Errorf("ssh: the signature key type %q is invalid for certificates", k.Type())
}
c.SignatureKey = k
c.Signature, rest, ok = parseSignatureBody(g.Signature)
if !ok || len(rest) > 0 {
@@ -296,16 +305,13 @@ type CertChecker struct {
SupportedCriticalOptions []string
// IsUserAuthority should return true if the key is recognized as an
// authority for the given user certificate. This allows for
// certificates to be signed by other certificates. This must be set
// if this CertChecker will be checking user certificates.
// authority for user certificate. This must be set if this CertChecker
// will be checking user certificates.
IsUserAuthority func(auth PublicKey) bool
// IsHostAuthority should report whether the key is recognized as
// an authority for this host. This allows for certificates to be
// signed by other keys, and for those other keys to only be valid
// signers for particular hostnames. This must be set if this
// CertChecker will be checking host certificates.
// an authority for this host. This must be set if this CertChecker
// will be checking host certificates.
IsHostAuthority func(auth PublicKey, address string) bool
// Clock is used for verifying time stamps. If nil, time.Now
@@ -442,12 +448,19 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
// SignCert signs the certificate with an authority, setting the Nonce,
// SignatureKey, and Signature fields. If the authority implements the
// MultiAlgorithmSigner interface the first algorithm in the list is used. This
// is useful if you want to sign with a specific algorithm.
// is useful if you want to sign with a specific algorithm. As specified in
// [SSH-CERTS], Section 2.1.1, authority can't be a [Certificate].
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
c.Nonce = make([]byte, 32)
if _, err := io.ReadFull(rand, c.Nonce); err != nil {
return err
}
// The Type() function is intended to return only certificate key types, but
// we use certKeyAlgoNames anyway for safety, to match [Certificate.Type].
if _, ok := certKeyAlgoNames[authority.PublicKey().Type()]; ok {
return fmt.Errorf("ssh: certificates cannot be used as authority (public key type %q)",
authority.PublicKey().Type())
}
c.SignatureKey = authority.PublicKey()
if v, ok := authority.(MultiAlgorithmSigner); ok {
@@ -485,16 +498,16 @@ func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
//
// This map must be kept in sync with the one in agent/client.go.
var certKeyAlgoNames = map[string]string{
CertAlgoRSAv01: KeyAlgoRSA,
CertAlgoRSASHA256v01: KeyAlgoRSASHA256,
CertAlgoRSASHA512v01: KeyAlgoRSASHA512,
CertAlgoDSAv01: KeyAlgoDSA,
CertAlgoECDSA256v01: KeyAlgoECDSA256,
CertAlgoECDSA384v01: KeyAlgoECDSA384,
CertAlgoECDSA521v01: KeyAlgoECDSA521,
CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256,
CertAlgoED25519v01: KeyAlgoED25519,
CertAlgoSKED25519v01: KeyAlgoSKED25519,
CertAlgoRSAv01: KeyAlgoRSA,
CertAlgoRSASHA256v01: KeyAlgoRSASHA256,
CertAlgoRSASHA512v01: KeyAlgoRSASHA512,
InsecureCertAlgoDSAv01: InsecureKeyAlgoDSA,
CertAlgoECDSA256v01: KeyAlgoECDSA256,
CertAlgoECDSA384v01: KeyAlgoECDSA384,
CertAlgoECDSA521v01: KeyAlgoECDSA521,
CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256,
CertAlgoED25519v01: KeyAlgoED25519,
CertAlgoSKED25519v01: KeyAlgoSKED25519,
}
// underlyingAlgo returns the signature algorithm associated with algo (which is
+38 -38
View File
@@ -8,6 +8,7 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/fips140"
"crypto/rc4"
"crypto/subtle"
"encoding/binary"
@@ -15,6 +16,7 @@ import (
"fmt"
"hash"
"io"
"slices"
"golang.org/x/crypto/chacha20"
"golang.org/x/crypto/internal/poly1305"
@@ -58,11 +60,11 @@ func newRC4(key, iv []byte) (cipher.Stream, error) {
type cipherMode struct {
keySize int
ivSize int
create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
create func(key, iv []byte, macKey []byte, algs DirectionAlgorithms) (packetCipher, error)
}
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
return func(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
stream, err := createFunc(key, iv)
if err != nil {
return nil, err
@@ -93,41 +95,41 @@ func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream,
}
// cipherModes documents properties of supported ciphers. Ciphers not included
// are not supported and will not be negotiated, even if explicitly requested in
// ClientConfig.Crypto.Ciphers.
var cipherModes = map[string]*cipherMode{
// Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
// are not supported and will not be negotiated, even if explicitly configured.
// When FIPS mode is enabled, only FIPS-approved algorithms are included.
var cipherModes = map[string]*cipherMode{}
// Ciphers from RFC 4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
"arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
"arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
func init() {
cipherModes[CipherAES128CTR] = &cipherMode{16, aes.BlockSize, streamCipherMode(0, newAESCTR)}
cipherModes[CipherAES192CTR] = &cipherMode{24, aes.BlockSize, streamCipherMode(0, newAESCTR)}
cipherModes[CipherAES256CTR] = &cipherMode{32, aes.BlockSize, streamCipherMode(0, newAESCTR)}
// Use of GCM with arbitrary IVs is not allowed in FIPS 140-only mode,
// we'll wire it up to NewGCMForSSH in Go 1.26.
//
// For now it means we'll work with fips140=on but not fips140=only.
cipherModes[CipherAES128GCM] = &cipherMode{16, 12, newGCMCipher}
cipherModes[CipherAES256GCM] = &cipherMode{32, 12, newGCMCipher}
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
// RC4) has problems with weak keys, and should be used with caution."
// RFC 4345 introduces improved versions of Arcfour.
"arcfour": {16, 0, streamCipherMode(0, newRC4)},
// AEAD ciphers
gcm128CipherID: {16, 12, newGCMCipher},
gcm256CipherID: {32, 12, newGCMCipher},
chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
if fips140.Enabled() {
defaultCiphers = slices.DeleteFunc(defaultCiphers, func(algo string) bool {
_, ok := cipherModes[algo]
return !ok
})
return
}
cipherModes[CipherChaCha20Poly1305] = &cipherMode{64, 0, newChaCha20Cipher}
// Insecure ciphers not included in the default configuration.
cipherModes[InsecureCipherRC4128] = &cipherMode{16, 0, streamCipherMode(1536, newRC4)}
cipherModes[InsecureCipherRC4256] = &cipherMode{32, 0, streamCipherMode(1536, newRC4)}
cipherModes[InsecureCipherRC4] = &cipherMode{16, 0, streamCipherMode(0, newRC4)}
// CBC mode is insecure and so is not included in the default config.
// (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely
// needed, it's possible to specify a custom Config to enable it.
// You should expect that an active attacker can recover plaintext if
// you do.
aes128cbcID: {16, aes.BlockSize, newAESCBCCipher},
// 3des-cbc is insecure and is not included in the default
// config.
tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher},
cipherModes[InsecureCipherAES128CBC] = &cipherMode{16, aes.BlockSize, newAESCBCCipher}
cipherModes[InsecureCipherTripleDESCBC] = &cipherMode{24, des.BlockSize, newTripleDESCBCCipher}
}
// prefixLen is the length of the packet prefix that contains the packet length
@@ -307,7 +309,7 @@ type gcmCipher struct {
buf []byte
}
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs DirectionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
@@ -429,7 +431,7 @@ type cbcCipher struct {
oracleCamouflage uint32
}
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
cbc := &cbcCipher{
mac: macModes[algs.MAC].new(macKey),
decrypter: cipher.NewCBCDecrypter(c, iv),
@@ -443,7 +445,7 @@ func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorith
return cbc, nil
}
func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newAESCBCCipher(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
@@ -457,7 +459,7 @@ func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCi
return cbc, nil
}
func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newTripleDESCBCCipher(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
c, err := des.NewTripleDESCipher(key)
if err != nil {
return nil, err
@@ -584,7 +586,7 @@ func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader
// Length of encrypted portion of the packet (header, payload, padding).
// Enforce minimum padding and packet size.
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize)
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPacketSize)
// Enforce block size.
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize
@@ -635,8 +637,6 @@ func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader
return nil
}
const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
// AEAD, which is described here:
//
@@ -650,7 +650,7 @@ type chacha20Poly1305Cipher struct {
buf []byte
}
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs DirectionAlgorithms) (packetCipher, error) {
if len(key) != 64 {
panic(len(key))
}
+1
View File
@@ -110,6 +110,7 @@ func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) e
}
c.sessionID = c.transport.getSessionID()
c.algorithms = c.transport.getAlgorithms()
return c.clientAuthenticate(config)
}
+15 -19
View File
@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"slices"
"strings"
)
@@ -83,7 +84,7 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
// success
return nil
} else if ok == authFailure {
if m := auth.method(); !contains(tried, m) {
if m := auth.method(); !slices.Contains(tried, m) {
tried = append(tried, m)
}
}
@@ -97,7 +98,7 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
findNext:
for _, a := range config.Auth {
candidateMethod := a.method()
if contains(tried, candidateMethod) {
if slices.Contains(tried, candidateMethod) {
continue
}
for _, meth := range methods {
@@ -117,15 +118,6 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
}
func contains(list []string, e string) bool {
for _, s := range list {
if s == e {
return true
}
}
return false
}
// An AuthMethod represents an instance of an RFC 4252 authentication method.
type AuthMethod interface {
// auth authenticates user over transport t.
@@ -255,7 +247,7 @@ func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiA
// Fallback to use if there is no "server-sig-algs" extension or a
// common algorithm cannot be found. We use the public key format if the
// MultiAlgorithmSigner supports it, otherwise we return an error.
if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
if !slices.Contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v",
underlyingAlgo(keyFormat), keyFormat, as.Algorithms())
}
@@ -282,14 +274,18 @@ func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiA
}
// Filter algorithms based on those supported by MultiAlgorithmSigner.
// Iterate over the signer's algorithms first to preserve its preference order.
supportedKeyAlgos := algorithmsForKeyFormat(keyFormat)
var keyAlgos []string
for _, algo := range algorithmsForKeyFormat(keyFormat) {
if contains(as.Algorithms(), underlyingAlgo(algo)) {
keyAlgos = append(keyAlgos, algo)
for _, signerAlgo := range as.Algorithms() {
if idx := slices.IndexFunc(supportedKeyAlgos, func(algo string) bool {
return underlyingAlgo(algo) == signerAlgo
}); idx >= 0 {
keyAlgos = append(keyAlgos, supportedKeyAlgos[idx])
}
}
algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos)
algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos, true)
if err != nil {
// If there is no overlap, return the fallback algorithm to support
// servers that fail to list all supported algorithms.
@@ -334,7 +330,7 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
// the key try to use the obtained algorithm as if "server-sig-algs" had
// not been implemented if supported from the algorithm signer.
if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 {
if contains(as.Algorithms(), KeyAlgoRSA) {
if slices.Contains(as.Algorithms(), KeyAlgoRSA) {
// We retry using the compat algorithm after all signers have
// been tried normally.
signers = append(signers, &multiAlgorithmSigner{
@@ -385,7 +381,7 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
// contain the "publickey" method, do not attempt to authenticate with any
// other keys. According to RFC 4252 Section 7, the latter can occur when
// additional authentication methods are required.
if success == authSuccess || !contains(methods, cb.method()) {
if success == authSuccess || !slices.Contains(methods, cb.method()) {
return success, methods, err
}
}
@@ -434,7 +430,7 @@ func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
// servers send the key type instead. OpenSSH allows any algorithm
// that matches the public key, so we do the same.
// https://github.com/openssh/openssh-portable/blob/86bdd385/sshconnect2.c#L709
if !contains(algorithmsForKeyFormat(key.Type()), msg.Algo) {
if !slices.Contains(algorithmsForKeyFormat(key.Type()), msg.Algo) {
return false, nil
}
if !bytes.Equal(msg.PubKey, pubKey) {
+363 -112
View File
@@ -6,10 +6,12 @@ package ssh
import (
"crypto"
"crypto/fips140"
"crypto/rand"
"fmt"
"io"
"math"
"slices"
"sync"
_ "crypto/sha1"
@@ -24,88 +26,298 @@ const (
serviceSSH = "ssh-connection"
)
// supportedCiphers lists ciphers we support but might not recommend.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"aes128-gcm@openssh.com", gcm256CipherID,
chacha20Poly1305ID,
"arcfour256", "arcfour128", "arcfour",
aes128cbcID,
tripledescbcID,
// The ciphers currently or previously implemented by this library, to use in
// [Config.Ciphers]. For a list, see the [Algorithms.Ciphers] returned by
// [SupportedAlgorithms] or [InsecureAlgorithms].
const (
CipherAES128GCM = "aes128-gcm@openssh.com"
CipherAES256GCM = "aes256-gcm@openssh.com"
CipherChaCha20Poly1305 = "chacha20-poly1305@openssh.com"
CipherAES128CTR = "aes128-ctr"
CipherAES192CTR = "aes192-ctr"
CipherAES256CTR = "aes256-ctr"
InsecureCipherAES128CBC = "aes128-cbc"
InsecureCipherTripleDESCBC = "3des-cbc"
InsecureCipherRC4 = "arcfour"
InsecureCipherRC4128 = "arcfour128"
InsecureCipherRC4256 = "arcfour256"
)
// The key exchanges currently or previously implemented by this library, to use
// in [Config.KeyExchanges]. For a list, see the
// [Algorithms.KeyExchanges] returned by [SupportedAlgorithms] or
// [InsecureAlgorithms].
const (
InsecureKeyExchangeDH1SHA1 = "diffie-hellman-group1-sha1"
InsecureKeyExchangeDH14SHA1 = "diffie-hellman-group14-sha1"
KeyExchangeDH14SHA256 = "diffie-hellman-group14-sha256"
KeyExchangeDH16SHA512 = "diffie-hellman-group16-sha512"
KeyExchangeECDHP256 = "ecdh-sha2-nistp256"
KeyExchangeECDHP384 = "ecdh-sha2-nistp384"
KeyExchangeECDHP521 = "ecdh-sha2-nistp521"
KeyExchangeCurve25519 = "curve25519-sha256"
InsecureKeyExchangeDHGEXSHA1 = "diffie-hellman-group-exchange-sha1"
KeyExchangeDHGEXSHA256 = "diffie-hellman-group-exchange-sha256"
// KeyExchangeMLKEM768X25519 is supported from Go 1.24.
KeyExchangeMLKEM768X25519 = "mlkem768x25519-sha256"
// An alias for KeyExchangeCurve25519SHA256. This kex ID will be added if
// KeyExchangeCurve25519SHA256 is requested for backward compatibility with
// OpenSSH versions up to 7.2.
keyExchangeCurve25519LibSSH = "curve25519-sha256@libssh.org"
)
// The message authentication code (MAC) currently or previously implemented by
// this library, to use in [Config.MACs]. For a list, see the
// [Algorithms.MACs] returned by [SupportedAlgorithms] or
// [InsecureAlgorithms].
const (
HMACSHA256ETM = "hmac-sha2-256-etm@openssh.com"
HMACSHA512ETM = "hmac-sha2-512-etm@openssh.com"
HMACSHA256 = "hmac-sha2-256"
HMACSHA512 = "hmac-sha2-512"
HMACSHA1 = "hmac-sha1"
InsecureHMACSHA196 = "hmac-sha1-96"
)
var (
// supportedKexAlgos specifies key-exchange algorithms implemented by this
// package in preference order, excluding those with security issues.
supportedKexAlgos = []string{
KeyExchangeMLKEM768X25519,
KeyExchangeCurve25519,
KeyExchangeECDHP256,
KeyExchangeECDHP384,
KeyExchangeECDHP521,
KeyExchangeDH14SHA256,
KeyExchangeDH16SHA512,
KeyExchangeDHGEXSHA256,
}
// defaultKexAlgos specifies the default preference for key-exchange
// algorithms in preference order.
defaultKexAlgos = []string{
KeyExchangeMLKEM768X25519,
KeyExchangeCurve25519,
KeyExchangeECDHP256,
KeyExchangeECDHP384,
KeyExchangeECDHP521,
KeyExchangeDH14SHA256,
InsecureKeyExchangeDH14SHA1,
}
// insecureKexAlgos specifies key-exchange algorithms implemented by this
// package and which have security issues.
insecureKexAlgos = []string{
InsecureKeyExchangeDH14SHA1,
InsecureKeyExchangeDH1SHA1,
InsecureKeyExchangeDHGEXSHA1,
}
// supportedCiphers specifies cipher algorithms implemented by this package
// in preference order, excluding those with security issues.
supportedCiphers = []string{
CipherAES128GCM,
CipherAES256GCM,
CipherChaCha20Poly1305,
CipherAES128CTR,
CipherAES192CTR,
CipherAES256CTR,
}
// defaultCiphers specifies the default preference for ciphers algorithms
// in preference order.
defaultCiphers = supportedCiphers
// insecureCiphers specifies cipher algorithms implemented by this
// package and which have security issues.
insecureCiphers = []string{
InsecureCipherAES128CBC,
InsecureCipherTripleDESCBC,
InsecureCipherRC4256,
InsecureCipherRC4128,
InsecureCipherRC4,
}
// supportedMACs specifies MAC algorithms implemented by this package in
// preference order, excluding those with security issues.
supportedMACs = []string{
HMACSHA256ETM,
HMACSHA512ETM,
HMACSHA256,
HMACSHA512,
HMACSHA1,
}
// defaultMACs specifies the default preference for MAC algorithms in
// preference order.
defaultMACs = []string{
HMACSHA256ETM,
HMACSHA512ETM,
HMACSHA256,
HMACSHA512,
HMACSHA1,
InsecureHMACSHA196,
}
// insecureMACs specifies MAC algorithms implemented by this
// package and which have security issues.
insecureMACs = []string{
InsecureHMACSHA196,
}
// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e.
// methods of authenticating servers) implemented by this package in
// preference order, excluding those with security issues.
supportedHostKeyAlgos = []string{
CertAlgoRSASHA256v01,
CertAlgoRSASHA512v01,
CertAlgoECDSA256v01,
CertAlgoECDSA384v01,
CertAlgoECDSA521v01,
CertAlgoED25519v01,
KeyAlgoRSASHA256,
KeyAlgoRSASHA512,
KeyAlgoECDSA256,
KeyAlgoECDSA384,
KeyAlgoECDSA521,
KeyAlgoED25519,
}
// defaultHostKeyAlgos specifies the default preference for host-key
// algorithms in preference order.
defaultHostKeyAlgos = []string{
CertAlgoRSASHA256v01,
CertAlgoRSASHA512v01,
CertAlgoRSAv01,
InsecureCertAlgoDSAv01,
CertAlgoECDSA256v01,
CertAlgoECDSA384v01,
CertAlgoECDSA521v01,
CertAlgoED25519v01,
KeyAlgoECDSA256,
KeyAlgoECDSA384,
KeyAlgoECDSA521,
KeyAlgoRSASHA256,
KeyAlgoRSASHA512,
KeyAlgoRSA,
InsecureKeyAlgoDSA,
KeyAlgoED25519,
}
// insecureHostKeyAlgos specifies host-key algorithms implemented by this
// package and which have security issues.
insecureHostKeyAlgos = []string{
KeyAlgoRSA,
InsecureKeyAlgoDSA,
CertAlgoRSAv01,
InsecureCertAlgoDSAv01,
}
// supportedPubKeyAuthAlgos specifies the supported client public key
// authentication algorithms. Note that this doesn't include certificate
// types since those use the underlying algorithm. Order is irrelevant.
supportedPubKeyAuthAlgos = []string{
KeyAlgoED25519,
KeyAlgoSKED25519,
KeyAlgoSKECDSA256,
KeyAlgoECDSA256,
KeyAlgoECDSA384,
KeyAlgoECDSA521,
KeyAlgoRSASHA256,
KeyAlgoRSASHA512,
}
// defaultPubKeyAuthAlgos specifies the preferred client public key
// authentication algorithms. This list is sent to the client if it supports
// the server-sig-algs extension. Order is irrelevant.
defaultPubKeyAuthAlgos = []string{
KeyAlgoED25519,
KeyAlgoSKED25519,
KeyAlgoSKECDSA256,
KeyAlgoECDSA256,
KeyAlgoECDSA384,
KeyAlgoECDSA521,
KeyAlgoRSASHA256,
KeyAlgoRSASHA512,
KeyAlgoRSA,
InsecureKeyAlgoDSA,
}
// insecurePubKeyAuthAlgos specifies client public key authentication
// algorithms implemented by this package and which have security issues.
insecurePubKeyAuthAlgos = []string{
KeyAlgoRSA,
InsecureKeyAlgoDSA,
}
)
// NegotiatedAlgorithms defines algorithms negotiated between client and server.
type NegotiatedAlgorithms struct {
KeyExchange string
HostKey string
Read DirectionAlgorithms
Write DirectionAlgorithms
}
// preferredCiphers specifies the default preference for ciphers.
var preferredCiphers = []string{
"aes128-gcm@openssh.com", gcm256CipherID,
chacha20Poly1305ID,
"aes128-ctr", "aes192-ctr", "aes256-ctr",
// Algorithms defines a set of algorithms that can be configured in the client
// or server config for negotiation during a handshake.
type Algorithms struct {
KeyExchanges []string
Ciphers []string
MACs []string
HostKeys []string
PublicKeyAuths []string
}
// supportedKexAlgos specifies the supported key-exchange algorithms in
// preference order.
var supportedKexAlgos = []string{
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
// P384 and P521 are not constant-time yet, but since we don't
// reuse ephemeral keys, using them for ECDH should be OK.
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1,
kexAlgoDH1SHA1,
func init() {
if fips140.Enabled() {
defaultHostKeyAlgos = slices.DeleteFunc(defaultHostKeyAlgos, func(algo string) bool {
_, err := hashFunc(underlyingAlgo(algo))
return err != nil
})
defaultPubKeyAuthAlgos = slices.DeleteFunc(defaultPubKeyAuthAlgos, func(algo string) bool {
_, err := hashFunc(underlyingAlgo(algo))
return err != nil
})
}
}
// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden
// for the server half.
var serverForbiddenKexAlgos = map[string]struct{}{
kexAlgoDHGEXSHA1: {}, // server half implementation is only minimal to satisfy the automated tests
kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests
func hashFunc(format string) (crypto.Hash, error) {
switch format {
case KeyAlgoRSASHA256, KeyAlgoECDSA256, KeyAlgoSKED25519, KeyAlgoSKECDSA256:
return crypto.SHA256, nil
case KeyAlgoECDSA384:
return crypto.SHA384, nil
case KeyAlgoRSASHA512, KeyAlgoECDSA521:
return crypto.SHA512, nil
case KeyAlgoED25519:
// KeyAlgoED25519 doesn't pre-hash.
return 0, nil
case KeyAlgoRSA, InsecureKeyAlgoDSA:
if fips140.Enabled() {
return 0, fmt.Errorf("ssh: hash algorithm for format %q not allowed in FIPS 140 mode", format)
}
return crypto.SHA1, nil
default:
return 0, fmt.Errorf("ssh: hash algorithm for format %q not mapped", format)
}
}
// preferredKexAlgos specifies the default preference for key-exchange
// algorithms in preference order. The diffie-hellman-group16-sha512 algorithm
// is disabled by default because it is a bit slower than the others.
var preferredKexAlgos = []string{
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA256, kexAlgoDH14SHA1,
// SupportedAlgorithms returns algorithms currently implemented by this package,
// excluding those with security issues, which are returned by
// InsecureAlgorithms. The algorithms listed here are in preference order.
func SupportedAlgorithms() Algorithms {
return Algorithms{
Ciphers: slices.Clone(supportedCiphers),
MACs: slices.Clone(supportedMACs),
KeyExchanges: slices.Clone(supportedKexAlgos),
HostKeys: slices.Clone(supportedHostKeyAlgos),
PublicKeyAuths: slices.Clone(supportedPubKeyAuthAlgos),
}
}
// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
// of authenticating servers) in preference order.
var supportedHostKeyAlgos = []string{
CertAlgoRSASHA256v01, CertAlgoRSASHA512v01,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSASHA256, KeyAlgoRSASHA512,
KeyAlgoRSA, KeyAlgoDSA,
KeyAlgoED25519,
}
// supportedMACs specifies a default set of MAC algorithms in preference order.
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
// because they have reached the end of their useful life.
var supportedMACs = []string{
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96",
// InsecureAlgorithms returns algorithms currently implemented by this package
// and which have security issues.
func InsecureAlgorithms() Algorithms {
return Algorithms{
KeyExchanges: slices.Clone(insecureKexAlgos),
Ciphers: slices.Clone(insecureCiphers),
MACs: slices.Clone(insecureMACs),
HostKeys: slices.Clone(insecureHostKeyAlgos),
PublicKeyAuths: slices.Clone(insecurePubKeyAuthAlgos),
}
}
var supportedCompressions = []string{compressionNone}
// hashFuncs keeps the mapping of supported signature algorithms to their
// respective hashes needed for signing and verification.
var hashFuncs = map[string]crypto.Hash{
KeyAlgoRSA: crypto.SHA1,
KeyAlgoRSASHA256: crypto.SHA256,
KeyAlgoRSASHA512: crypto.SHA512,
KeyAlgoDSA: crypto.SHA1,
KeyAlgoECDSA256: crypto.SHA256,
KeyAlgoECDSA384: crypto.SHA384,
KeyAlgoECDSA521: crypto.SHA512,
// KeyAlgoED25519 doesn't pre-hash.
KeyAlgoSKECDSA256: crypto.SHA256,
KeyAlgoSKED25519: crypto.SHA256,
}
// algorithmsForKeyFormat returns the supported signature algorithms for a given
// public key format (PublicKey.Type), in order of preference. See RFC 8332,
// Section 2. See also the note in sendKexInit on backwards compatibility.
@@ -120,11 +332,40 @@ func algorithmsForKeyFormat(keyFormat string) []string {
}
}
// keyFormatForAlgorithm returns the key format corresponding to the given
// signature algorithm. It returns an empty string if the signature algorithm is
// invalid or unsupported.
func keyFormatForAlgorithm(sigAlgo string) string {
switch sigAlgo {
case KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512:
return KeyAlgoRSA
case CertAlgoRSAv01, CertAlgoRSASHA256v01, CertAlgoRSASHA512v01:
return CertAlgoRSAv01
case KeyAlgoED25519,
KeyAlgoSKED25519,
KeyAlgoSKECDSA256,
KeyAlgoECDSA256,
KeyAlgoECDSA384,
KeyAlgoECDSA521,
InsecureKeyAlgoDSA,
InsecureCertAlgoDSAv01,
CertAlgoECDSA256v01,
CertAlgoECDSA384v01,
CertAlgoECDSA521v01,
CertAlgoSKECDSA256v01,
CertAlgoED25519v01,
CertAlgoSKED25519v01:
return sigAlgo
default:
return ""
}
}
// isRSA returns whether algo is a supported RSA algorithm, including certificate
// algorithms.
func isRSA(algo string) bool {
algos := algorithmsForKeyFormat(KeyAlgoRSA)
return contains(algos, underlyingAlgo(algo))
return slices.Contains(algos, underlyingAlgo(algo))
}
func isRSACert(algo string) bool {
@@ -135,18 +376,6 @@ func isRSACert(algo string) bool {
return isRSA(algo)
}
// supportedPubKeyAuthAlgos specifies the supported client public key
// authentication algorithms. Note that this doesn't include certificate types
// since those use the underlying algorithm. This list is sent to the client if
// it supports the server-sig-algs extension. Order is irrelevant.
var supportedPubKeyAuthAlgos = []string{
KeyAlgoED25519,
KeyAlgoSKED25519, KeyAlgoSKECDSA256,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA,
KeyAlgoDSA,
}
// unexpectedMessageError results when the SSH message that we received didn't
// match what we wanted.
func unexpectedMessageError(expected, got uint8) error {
@@ -158,7 +387,7 @@ func parseError(tag uint8) error {
return fmt.Errorf("ssh: parse error in message type %d", tag)
}
func findCommon(what string, client []string, server []string) (common string, err error) {
func findCommon(what string, client []string, server []string, isClient bool) (string, error) {
for _, c := range client {
for _, s := range server {
if c == s {
@@ -166,23 +395,49 @@ func findCommon(what string, client []string, server []string) (common string, e
}
}
}
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
err := &AlgorithmNegotiationError{
What: what,
}
if isClient {
err.SupportedAlgorithms = client
err.RequestedAlgorithms = server
} else {
err.SupportedAlgorithms = server
err.RequestedAlgorithms = client
}
return "", err
}
// directionAlgorithms records algorithm choices in one direction (either read or write)
type directionAlgorithms struct {
// AlgorithmNegotiationError defines the error returned if the client and the
// server cannot agree on an algorithm for key exchange, host key, cipher, MAC.
type AlgorithmNegotiationError struct {
What string
// RequestedAlgorithms lists the algorithms supported by the peer.
RequestedAlgorithms []string
// SupportedAlgorithms lists the algorithms supported on our side.
SupportedAlgorithms []string
}
func (a *AlgorithmNegotiationError) Error() string {
return fmt.Sprintf("ssh: no common algorithm for %s; we offered: %v, peer offered: %v",
a.What, a.SupportedAlgorithms, a.RequestedAlgorithms)
}
// DirectionAlgorithms defines the algorithms negotiated in one direction
// (either read or write).
type DirectionAlgorithms struct {
Cipher string
MAC string
Compression string
compression string
}
// rekeyBytes returns a rekeying intervals in bytes.
func (a *directionAlgorithms) rekeyBytes() int64 {
func (a *DirectionAlgorithms) rekeyBytes() int64 {
// According to RFC 4344 block ciphers should rekey after
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
// 128.
switch a.Cipher {
case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID:
case CipherAES128CTR, CipherAES192CTR, CipherAES256CTR, CipherAES128GCM, CipherAES256GCM, InsecureCipherAES128CBC:
return 16 * (1 << 32)
}
@@ -192,66 +447,59 @@ func (a *directionAlgorithms) rekeyBytes() int64 {
}
var aeadCiphers = map[string]bool{
gcm128CipherID: true,
gcm256CipherID: true,
chacha20Poly1305ID: true,
CipherAES128GCM: true,
CipherAES256GCM: true,
CipherChaCha20Poly1305: true,
}
type algorithms struct {
kex string
hostKey string
w directionAlgorithms
r directionAlgorithms
}
func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *NegotiatedAlgorithms, err error) {
result := &NegotiatedAlgorithms{}
func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
result := &algorithms{}
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
result.KeyExchange, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos, isClient)
if err != nil {
return
}
result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
result.HostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos, isClient)
if err != nil {
return
}
stoc, ctos := &result.w, &result.r
stoc, ctos := &result.Write, &result.Read
if isClient {
ctos, stoc = stoc, ctos
}
ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer, isClient)
if err != nil {
return
}
stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient, isClient)
if err != nil {
return
}
if !aeadCiphers[ctos.Cipher] {
ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer, isClient)
if err != nil {
return
}
}
if !aeadCiphers[stoc.Cipher] {
stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient, isClient)
if err != nil {
return
}
}
ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
ctos.compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer, isClient)
if err != nil {
return
}
stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
stoc.compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient, isClient)
if err != nil {
return
}
@@ -297,7 +545,7 @@ func (c *Config) SetDefaults() {
c.Rand = rand.Reader
}
if c.Ciphers == nil {
c.Ciphers = preferredCiphers
c.Ciphers = defaultCiphers
}
var ciphers []string
for _, c := range c.Ciphers {
@@ -309,19 +557,22 @@ func (c *Config) SetDefaults() {
c.Ciphers = ciphers
if c.KeyExchanges == nil {
c.KeyExchanges = preferredKexAlgos
c.KeyExchanges = defaultKexAlgos
}
var kexs []string
for _, k := range c.KeyExchanges {
if kexAlgoMap[k] != nil {
// Ignore the KEX if we have no kexAlgoMap definition.
kexs = append(kexs, k)
if k == KeyExchangeCurve25519 && !slices.Contains(c.KeyExchanges, keyExchangeCurve25519LibSSH) {
kexs = append(kexs, keyExchangeCurve25519LibSSH)
}
}
}
c.KeyExchanges = kexs
if c.MACs == nil {
c.MACs = supportedMACs
c.MACs = defaultMACs
}
var macs []string
for _, m := range c.MACs {
+12
View File
@@ -74,6 +74,13 @@ type Conn interface {
// Disconnect
}
// AlgorithmsConnMetadata is a ConnMetadata that can return the algorithms
// negotiated between client and server.
type AlgorithmsConnMetadata interface {
ConnMetadata
Algorithms() NegotiatedAlgorithms
}
// DiscardRequests consumes and rejects all requests from the
// passed-in channel.
func DiscardRequests(in <-chan *Request) {
@@ -106,6 +113,7 @@ type sshConn struct {
sessionID []byte
clientVersion []byte
serverVersion []byte
algorithms NegotiatedAlgorithms
}
func dup(src []byte) []byte {
@@ -141,3 +149,7 @@ func (c *sshConn) ClientVersion() []byte {
func (c *sshConn) ServerVersion() []byte {
return dup(c.serverVersion)
}
func (c *sshConn) Algorithms() NegotiatedAlgorithms {
return c.algorithms
}
+11
View File
@@ -16,8 +16,19 @@ References:
[PROTOCOL]: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL?rev=HEAD
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
[SSH-CERTS]: https://datatracker.ietf.org/doc/html/draft-miller-ssh-cert-01
[FIPS 140-3 mode]: https://go.dev/doc/security/fips140
This package does not fall under the stability promise of the Go language itself,
so its API may be changed when pressing needs arise.
# FIPS 140-3 mode
When the program is in [FIPS 140-3 mode], this package behaves as if only SP
800-140C and SP 800-140D approved cipher suites, signature algorithms,
certificate public key types and sizes, and key exchange and derivation
algorithms were implemented. Others are silently ignored and not negotiated, or
rejected. This set may depend on the algorithms supported by the FIPS 140-3 Go
Cryptographic Module selected with GOFIPS140, and may change across Go versions.
*/
package ssh
+19 -14
View File
@@ -10,6 +10,7 @@ import (
"io"
"log"
"net"
"slices"
"strings"
"sync"
)
@@ -38,7 +39,7 @@ type keyingTransport interface {
// prepareKeyChange sets up a key change. The key change for a
// direction will be effected if a msgNewKeys message is sent
// or received.
prepareKeyChange(*algorithms, *kexResult) error
prepareKeyChange(*NegotiatedAlgorithms, *kexResult) error
// setStrictMode sets the strict KEX mode, notably triggering
// sequence number resets on sending or receiving msgNewKeys.
@@ -115,7 +116,7 @@ type handshakeTransport struct {
bannerCallback BannerCallback
// Algorithms agreed in the last key exchange.
algorithms *algorithms
algorithms *NegotiatedAlgorithms
// Counters exclusively owned by readLoop.
readPacketsLeft uint32
@@ -164,7 +165,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
if config.HostKeyAlgorithms != nil {
t.hostKeyAlgorithms = config.HostKeyAlgorithms
} else {
t.hostKeyAlgorithms = supportedHostKeyAlgos
t.hostKeyAlgorithms = defaultHostKeyAlgos
}
go t.readLoop()
go t.kexLoop()
@@ -184,6 +185,10 @@ func (t *handshakeTransport) getSessionID() []byte {
return t.sessionID
}
func (t *handshakeTransport) getAlgorithms() NegotiatedAlgorithms {
return *t.algorithms
}
// waitSession waits for the session to be established. This should be
// the first thing to call after instantiating handshakeTransport.
func (t *handshakeTransport) waitSession() error {
@@ -290,7 +295,7 @@ func (t *handshakeTransport) resetWriteThresholds() {
if t.config.RekeyThreshold > 0 {
t.writeBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
t.writeBytesLeft = t.algorithms.Write.rekeyBytes()
} else {
t.writeBytesLeft = 1 << 30
}
@@ -407,7 +412,7 @@ func (t *handshakeTransport) resetReadThresholds() {
if t.config.RekeyThreshold > 0 {
t.readBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.readBytesLeft = t.algorithms.r.rekeyBytes()
t.readBytesLeft = t.algorithms.Read.rekeyBytes()
} else {
t.readBytesLeft = 1 << 30
}
@@ -523,7 +528,7 @@ func (t *handshakeTransport) sendKexInit() error {
switch s := k.(type) {
case MultiAlgorithmSigner:
for _, algo := range algorithmsForKeyFormat(keyFormat) {
if contains(s.Algorithms(), underlyingAlgo(algo)) {
if slices.Contains(s.Algorithms(), underlyingAlgo(algo)) {
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
}
}
@@ -675,7 +680,7 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
return err
}
if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) {
if t.sessionID == nil && ((isClient && slices.Contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && slices.Contains(clientInit.KexAlgos, kexStrictClient))) {
t.strictMode = true
if err := t.conn.setStrictMode(); err != nil {
return err
@@ -700,9 +705,9 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
}
}
kex, ok := kexAlgoMap[t.algorithms.kex]
kex, ok := kexAlgoMap[t.algorithms.KeyExchange]
if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.KeyExchange)
}
var result *kexResult
@@ -732,7 +737,7 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
// message with the server-sig-algs extension if the client supports it. See
// RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9.
if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
if !isClient && firstKeyExchange && slices.Contains(clientInit.KexAlgos, "ext-info-c") {
supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",")
extInfo := &extInfoMsg{
NumExtensions: 2,
@@ -786,7 +791,7 @@ func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, a
func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
for _, k := range hostKeys {
if s, ok := k.(MultiAlgorithmSigner); ok {
if !contains(s.Algorithms(), underlyingAlgo(algo)) {
if !slices.Contains(s.Algorithms(), underlyingAlgo(algo)) {
continue
}
}
@@ -809,12 +814,12 @@ func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
}
func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey)
hostKey := pickHostKey(t.hostKeys, t.algorithms.HostKey)
if hostKey == nil {
return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
}
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey)
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.HostKey)
return r, err
}
@@ -829,7 +834,7 @@ func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (
return nil, err
}
if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil {
if err := verifyHostKeySignature(hostKey, t.algorithms.HostKey, result); err != nil {
return nil, err
}
+88 -67
View File
@@ -8,33 +8,31 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/fips140"
"crypto/rand"
"crypto/subtle"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"slices"
"golang.org/x/crypto/curve25519"
)
const (
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
kexAlgoDH14SHA256 = "diffie-hellman-group14-sha256"
kexAlgoDH16SHA512 = "diffie-hellman-group16-sha512"
kexAlgoECDH256 = "ecdh-sha2-nistp256"
kexAlgoECDH384 = "ecdh-sha2-nistp384"
kexAlgoECDH521 = "ecdh-sha2-nistp521"
kexAlgoCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org"
kexAlgoCurve25519SHA256 = "curve25519-sha256"
// For the following kex only the client half contains a production
// ready implementation. The server half only consists of a minimal
// implementation to satisfy the automated tests.
kexAlgoDHGEXSHA1 = "diffie-hellman-group-exchange-sha1"
kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256"
// This is the group called diffie-hellman-group1-sha1 in RFC 4253 and
// Oakley Group 2 in RFC 2409.
oakleyGroup2 = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF"
// This is the group called diffie-hellman-group14-sha1 in RFC 4253 and
// Oakley Group 14 in RFC 3526.
oakleyGroup14 = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF"
// This is the group called diffie-hellman-group15-sha512 in RFC 8268 and
// Oakley Group 15 in RFC 3526.
oakleyGroup15 = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF"
// This is the group called diffie-hellman-group16-sha512 in RFC 8268 and
// Oakley Group 16 in RFC 3526.
oakleyGroup16 = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF"
)
// kexResult captures the outcome of a key exchange.
@@ -399,56 +397,64 @@ func ecHash(curve elliptic.Curve) crypto.Hash {
return crypto.SHA512
}
// kexAlgoMap defines the supported KEXs. KEXs not included are not supported
// and will not be negotiated, even if explicitly configured. When FIPS mode is
// enabled, only FIPS-approved algorithms are included.
var kexAlgoMap = map[string]kexAlgorithm{}
func init() {
// This is the group called diffie-hellman-group1-sha1 in
// RFC 4253 and Oakley Group 2 in RFC 2409.
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
// mlkem768x25519-sha256 we'll work with fips140=on but not fips140=only
// until Go 1.26.
kexAlgoMap[KeyExchangeMLKEM768X25519] = &mlkem768WithCurve25519sha256{}
kexAlgoMap[KeyExchangeECDHP521] = &ecdh{elliptic.P521()}
kexAlgoMap[KeyExchangeECDHP384] = &ecdh{elliptic.P384()}
kexAlgoMap[KeyExchangeECDHP256] = &ecdh{elliptic.P256()}
if fips140.Enabled() {
defaultKexAlgos = slices.DeleteFunc(defaultKexAlgos, func(algo string) bool {
_, ok := kexAlgoMap[algo]
return !ok
})
return
}
p, _ := new(big.Int).SetString(oakleyGroup2, 16)
kexAlgoMap[InsecureKeyExchangeDH1SHA1] = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
pMinus1: new(big.Int).Sub(p, bigOne),
hashFunc: crypto.SHA1,
}
// This are the groups called diffie-hellman-group14-sha1 and
// diffie-hellman-group14-sha256 in RFC 4253 and RFC 8268,
// and Oakley Group 14 in RFC 3526.
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
p, _ = new(big.Int).SetString(oakleyGroup14, 16)
group14 := &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
pMinus1: new(big.Int).Sub(p, bigOne),
}
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
kexAlgoMap[InsecureKeyExchangeDH14SHA1] = &dhGroup{
g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
hashFunc: crypto.SHA1,
}
kexAlgoMap[kexAlgoDH14SHA256] = &dhGroup{
kexAlgoMap[KeyExchangeDH14SHA256] = &dhGroup{
g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
hashFunc: crypto.SHA256,
}
// This is the group called diffie-hellman-group16-sha512 in RFC
// 8268 and Oakley Group 16 in RFC 3526.
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16)
p, _ = new(big.Int).SetString(oakleyGroup16, 16)
kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{
kexAlgoMap[KeyExchangeDH16SHA512] = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
pMinus1: new(big.Int).Sub(p, bigOne),
hashFunc: crypto.SHA512,
}
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{}
kexAlgoMap[kexAlgoCurve25519SHA256LibSSH] = &curve25519sha256{}
kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1}
kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256}
kexAlgoMap[KeyExchangeCurve25519] = &curve25519sha256{}
kexAlgoMap[keyExchangeCurve25519LibSSH] = &curve25519sha256{}
kexAlgoMap[InsecureKeyExchangeDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1}
kexAlgoMap[KeyExchangeDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256}
}
// curve25519sha256 implements the curve25519-sha256 (formerly known as
@@ -464,15 +470,17 @@ func (kp *curve25519KeyPair) generate(rand io.Reader) error {
if _, err := io.ReadFull(rand, kp.priv[:]); err != nil {
return err
}
curve25519.ScalarBaseMult(&kp.pub, &kp.priv)
p, err := curve25519.X25519(kp.priv[:], curve25519.Basepoint)
if err != nil {
return fmt.Errorf("curve25519: %w", err)
}
if len(p) != 32 {
return fmt.Errorf("curve25519: internal error: X25519 returned %d bytes, expected 32", len(p))
}
copy(kp.pub[:], p)
return nil
}
// curve25519Zeros is just an array of 32 zero bytes so that we have something
// convenient to compare against in order to reject curve25519 points with the
// wrong order.
var curve25519Zeros [32]byte
func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
var kp curve25519KeyPair
if err := kp.generate(rand); err != nil {
@@ -495,11 +503,9 @@ func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handsh
return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
}
var servPub, secret [32]byte
copy(servPub[:], reply.EphemeralPubKey)
curve25519.ScalarMult(&secret, &kp.priv, &servPub)
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
secret, err := curve25519.X25519(kp.priv[:], reply.EphemeralPubKey)
if err != nil {
return nil, fmt.Errorf("ssh: peer's curve25519 public value is not valid: %w", err)
}
h := crypto.SHA256.New()
@@ -541,11 +547,9 @@ func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handsh
return nil, err
}
var clientPub, secret [32]byte
copy(clientPub[:], kexInit.ClientPubKey)
curve25519.ScalarMult(&secret, &kp.priv, &clientPub)
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
secret, err := curve25519.X25519(kp.priv[:], kexInit.ClientPubKey)
if err != nil {
return nil, fmt.Errorf("ssh: peer's curve25519 public value is not valid: %w", err)
}
hostKeyBytes := priv.PublicKey().Marshal()
@@ -601,9 +605,9 @@ const (
func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
// Send GexRequest
kexDHGexRequest := kexDHGexRequestMsg{
MinBits: dhGroupExchangeMinimumBits,
PreferedBits: dhGroupExchangePreferredBits,
MaxBits: dhGroupExchangeMaximumBits,
MinBits: dhGroupExchangeMinimumBits,
PreferredBits: dhGroupExchangePreferredBits,
MaxBits: dhGroupExchangeMaximumBits,
}
if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil {
return nil, err
@@ -690,9 +694,7 @@ func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshak
}
// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256.
//
// This is a minimal implementation to satisfy the automated tests.
func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
func (gex *dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
// Receive GexRequest
packet, err := c.readPacket()
if err != nil {
@@ -702,13 +704,32 @@ func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshake
if err = Unmarshal(packet, &kexDHGexRequest); err != nil {
return
}
// We check that the request received is valid and that the MaxBits
// requested are at least equal to our supported minimum. This is the same
// check done in OpenSSH:
// https://github.com/openssh/openssh-portable/blob/80a2f64b/kexgexs.c#L94
//
// Furthermore, we also check that the required MinBits are less than or
// equal to 4096 because we can use up to Oakley Group 16.
if kexDHGexRequest.MaxBits < kexDHGexRequest.MinBits || kexDHGexRequest.PreferredBits < kexDHGexRequest.MinBits ||
kexDHGexRequest.MaxBits < kexDHGexRequest.PreferredBits || kexDHGexRequest.MaxBits < dhGroupExchangeMinimumBits ||
kexDHGexRequest.MinBits > 4096 {
return nil, fmt.Errorf("ssh: DH GEX request out of range, min: %d, max: %d, preferred: %d", kexDHGexRequest.MinBits,
kexDHGexRequest.MaxBits, kexDHGexRequest.PreferredBits)
}
var p *big.Int
// We hardcode sending Oakley Group 14 (2048 bits), Oakley Group 15 (3072
// bits) or Oakley Group 16 (4096 bits), based on the requested max size.
if kexDHGexRequest.MaxBits < 3072 {
p, _ = new(big.Int).SetString(oakleyGroup14, 16)
} else if kexDHGexRequest.MaxBits < 4096 {
p, _ = new(big.Int).SetString(oakleyGroup15, 16)
} else {
p, _ = new(big.Int).SetString(oakleyGroup16, 16)
}
// Send GexGroup
// This is the group called diffie-hellman-group14-sha1 in RFC
// 4253 and Oakley Group 14 in RFC 3526.
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
g := big.NewInt(2)
msg := &kexDHGexGroupMsg{
P: p,
G: g,
@@ -746,9 +767,9 @@ func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshake
h := gex.hashFunc.New()
magics.write(h)
writeString(h, hostKeyBytes)
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
binary.Write(h, binary.BigEndian, kexDHGexRequest.MinBits)
binary.Write(h, binary.BigEndian, kexDHGexRequest.PreferredBits)
binary.Write(h, binary.BigEndian, kexDHGexRequest.MaxBits)
writeInt(h, p)
writeInt(h, g)
writeInt(h, kexDHGexInit.X)
+71 -26
View File
@@ -27,6 +27,7 @@ import (
"fmt"
"io"
"math/big"
"slices"
"strings"
"golang.org/x/crypto/ssh/internal/bcrypt_pbkdf"
@@ -36,14 +37,19 @@ import (
// ClientConfig.HostKeyAlgorithms, Signature.Format, or as AlgorithmSigner
// arguments.
const (
KeyAlgoRSA = "ssh-rsa"
KeyAlgoDSA = "ssh-dss"
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256"
KeyAlgoSKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com"
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384"
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521"
KeyAlgoED25519 = "ssh-ed25519"
KeyAlgoSKED25519 = "sk-ssh-ed25519@openssh.com"
KeyAlgoRSA = "ssh-rsa"
// Deprecated: DSA is only supported at insecure key sizes, and was removed
// from major implementations.
KeyAlgoDSA = InsecureKeyAlgoDSA
// Deprecated: DSA is only supported at insecure key sizes, and was removed
// from major implementations.
InsecureKeyAlgoDSA = "ssh-dss"
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256"
KeyAlgoSKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com"
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384"
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521"
KeyAlgoED25519 = "ssh-ed25519"
KeyAlgoSKED25519 = "sk-ssh-ed25519@openssh.com"
// KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, not
// public key formats, so they can't appear as a PublicKey.Type. The
@@ -67,7 +73,7 @@ func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err err
switch algo {
case KeyAlgoRSA:
return parseRSA(in)
case KeyAlgoDSA:
case InsecureKeyAlgoDSA:
return parseDSA(in)
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
return parseECDSA(in)
@@ -77,13 +83,18 @@ func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err err
return parseED25519(in)
case KeyAlgoSKED25519:
return parseSKEd25519(in)
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
case CertAlgoRSAv01, InsecureCertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
cert, err := parseCert(in, certKeyAlgoNames[algo])
if err != nil {
return nil, nil, err
}
return cert, nil, nil
}
if keyFormat := keyFormatForAlgorithm(algo); keyFormat != "" {
return nil, nil, fmt.Errorf("ssh: signature algorithm %q isn't a key format; key is malformed and should be re-encoded with type %q",
algo, keyFormat)
}
return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo)
}
@@ -186,9 +197,10 @@ func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey
return "", nil, nil, "", nil, io.EOF
}
// ParseAuthorizedKey parses a public key from an authorized_keys
// file used in OpenSSH according to the sshd(8) manual page.
// ParseAuthorizedKey parses a public key from an authorized_keys file used in
// OpenSSH according to the sshd(8) manual page. Invalid lines are ignored.
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
var lastErr error
for len(in) > 0 {
end := bytes.IndexByte(in, '\n')
if end != -1 {
@@ -217,6 +229,8 @@ func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []str
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
return out, comment, options, rest, nil
} else {
lastErr = err
}
// No key type recognised. Maybe there's an options field at
@@ -259,16 +273,22 @@ func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []str
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
options = candidateOptions
return out, comment, options, rest, nil
} else {
lastErr = err
}
in = rest
continue
}
if lastErr != nil {
return nil, "", nil, nil, fmt.Errorf("ssh: no key found; last parsing error for ignored line: %w", lastErr)
}
return nil, "", nil, nil, errors.New("ssh: no key found")
}
// ParsePublicKey parses an SSH public key formatted for use in
// ParsePublicKey parses an SSH public key or certificate formatted for use in
// the SSH wire protocol according to RFC 4253, section 6.6.
func ParsePublicKey(in []byte) (out PublicKey, err error) {
algo, in, ok := parseString(in)
@@ -390,11 +410,11 @@ func NewSignerWithAlgorithms(signer AlgorithmSigner, algorithms []string) (Multi
}
for _, algo := range algorithms {
if !contains(supportedAlgos, algo) {
if !slices.Contains(supportedAlgos, algo) {
return nil, fmt.Errorf("ssh: algorithm %q is not supported for key type %q",
algo, signer.PublicKey().Type())
}
if !contains(signerAlgos, algo) {
if !slices.Contains(signerAlgos, algo) {
return nil, fmt.Errorf("ssh: algorithm %q is restricted for the provided signer", algo)
}
}
@@ -481,10 +501,13 @@ func (r *rsaPublicKey) Marshal() []byte {
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
supportedAlgos := algorithmsForKeyFormat(r.Type())
if !contains(supportedAlgos, sig.Format) {
if !slices.Contains(supportedAlgos, sig.Format) {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
}
hash := hashFuncs[sig.Format]
hash, err := hashFunc(sig.Format)
if err != nil {
return err
}
h := hash.New()
h.Write(data)
digest := h.Sum(nil)
@@ -601,7 +624,11 @@ func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
h := hashFuncs[sig.Format].New()
hash, err := hashFunc(sig.Format)
if err != nil {
return err
}
h := hash.New()
h.Write(data)
digest := h.Sum(nil)
@@ -646,7 +673,11 @@ func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm
return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
}
h := hashFuncs[k.PublicKey().Type()].New()
hash, err := hashFunc(k.PublicKey().Type())
if err != nil {
return nil, err
}
h := hash.New()
h.Write(data)
digest := h.Sum(nil)
r, s, err := dsa.Sign(rand, k.PrivateKey, digest)
@@ -796,8 +827,11 @@ func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
h := hashFuncs[sig.Format].New()
hash, err := hashFunc(sig.Format)
if err != nil {
return err
}
h := hash.New()
h.Write(data)
digest := h.Sum(nil)
@@ -900,8 +934,11 @@ func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
h := hashFuncs[sig.Format].New()
hash, err := hashFunc(sig.Format)
if err != nil {
return err
}
h := hash.New()
h.Write([]byte(k.application))
appDigest := h.Sum(nil)
@@ -1004,7 +1041,11 @@ func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error {
return fmt.Errorf("invalid size %d for Ed25519 public key", l)
}
h := hashFuncs[sig.Format].New()
hash, err := hashFunc(sig.Format)
if err != nil {
return err
}
h := hash.New()
h.Write([]byte(k.application))
appDigest := h.Sum(nil)
@@ -1107,11 +1148,14 @@ func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm
algorithm = s.pubKey.Type()
}
if !contains(s.Algorithms(), algorithm) {
if !slices.Contains(s.Algorithms(), algorithm) {
return nil, fmt.Errorf("ssh: unsupported signature algorithm %q for key format %q", algorithm, s.pubKey.Type())
}
hashFunc := hashFuncs[algorithm]
hashFunc, err := hashFunc(algorithm)
if err != nil {
return nil, err
}
var digest []byte
if hashFunc != 0 {
h := hashFunc.New()
@@ -1446,6 +1490,7 @@ type openSSHEncryptedPrivateKey struct {
NumKeys uint32
PubKey []byte
PrivKeyBlock []byte
Rest []byte `ssh:"rest"`
}
type openSSHPrivateKey struct {
+29 -13
View File
@@ -7,11 +7,13 @@ package ssh
// Message authentication support
import (
"crypto/fips140"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"hash"
"slices"
)
type macMode struct {
@@ -46,23 +48,37 @@ func (t truncatingMAC) Size() int {
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
var macModes = map[string]*macMode{
"hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash {
// macModes defines the supported MACs. MACs not included are not supported
// and will not be negotiated, even if explicitly configured. When FIPS mode is
// enabled, only FIPS-approved algorithms are included.
var macModes = map[string]*macMode{}
func init() {
macModes[HMACSHA512ETM] = &macMode{64, true, func(key []byte) hash.Hash {
return hmac.New(sha512.New, key)
}},
"hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash {
}}
macModes[HMACSHA256ETM] = &macMode{32, true, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}},
"hmac-sha2-512": {64, false, func(key []byte) hash.Hash {
}}
macModes[HMACSHA512] = &macMode{64, false, func(key []byte) hash.Hash {
return hmac.New(sha512.New, key)
}},
"hmac-sha2-256": {32, false, func(key []byte) hash.Hash {
}}
macModes[HMACSHA256] = &macMode{32, false, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}},
"hmac-sha1": {20, false, func(key []byte) hash.Hash {
}}
if fips140.Enabled() {
defaultMACs = slices.DeleteFunc(defaultMACs, func(algo string) bool {
_, ok := macModes[algo]
return !ok
})
return
}
macModes[HMACSHA1] = &macMode{20, false, func(key []byte) hash.Hash {
return hmac.New(sha1.New, key)
}},
"hmac-sha1-96": {20, false, func(key []byte) hash.Hash {
}}
macModes[InsecureHMACSHA196] = &macMode{20, false, func(key []byte) hash.Hash {
return truncatingMAC{12, hmac.New(sha1.New, key)}
}},
}}
}
+4 -4
View File
@@ -122,9 +122,9 @@ type kexDHGexReplyMsg struct {
const msgKexDHGexRequest = 34
type kexDHGexRequestMsg struct {
MinBits uint32 `sshtype:"34"`
PreferedBits uint32
MaxBits uint32
MinBits uint32 `sshtype:"34"`
PreferredBits uint32
MaxBits uint32
}
// See RFC 4253, section 10.
@@ -792,7 +792,7 @@ func marshalString(to []byte, s []byte) []byte {
return to[len(s):]
}
var bigIntType = reflect.TypeOf((*big.Int)(nil))
var bigIntType = reflect.TypeFor[*big.Int]()
// Decode a packet into its corresponding message.
func decode(packet []byte) (interface{}, error) {
-19
View File
@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.24
package ssh
import (
@@ -13,27 +11,10 @@ import (
"errors"
"fmt"
"io"
"runtime"
"slices"
"golang.org/x/crypto/curve25519"
)
const (
kexAlgoMLKEM768xCurve25519SHA256 = "mlkem768x25519-sha256"
)
func init() {
// After Go 1.24rc1 mlkem swapped the order of return values of Encapsulate.
// See #70950.
if runtime.Version() == "go1.24rc1" {
return
}
supportedKexAlgos = slices.Insert(supportedKexAlgos, 0, kexAlgoMLKEM768xCurve25519SHA256)
preferredKexAlgos = slices.Insert(preferredKexAlgos, 0, kexAlgoMLKEM768xCurve25519SHA256)
kexAlgoMap[kexAlgoMLKEM768xCurve25519SHA256] = &mlkem768WithCurve25519sha256{}
}
// mlkem768WithCurve25519sha256 implements the hybrid ML-KEM768 with
// curve25519-sha256 key exchange method, as described by
// draft-kampanakis-curdle-ssh-pq-ke-05 section 2.3.3.
+34 -12
View File
@@ -10,6 +10,7 @@ import (
"fmt"
"io"
"net"
"slices"
"strings"
)
@@ -43,6 +44,9 @@ type Permissions struct {
// pass data from the authentication callbacks to the server
// application layer.
Extensions map[string]string
// ExtraData allows to store user defined data.
ExtraData map[any]any
}
type GSSAPIWithMICConfig struct {
@@ -126,6 +130,21 @@ type ServerConfig struct {
// Permissions.Extensions entry.
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// VerifiedPublicKeyCallback, if non-nil, is called after a client
// successfully confirms having control over a key that was previously
// approved by PublicKeyCallback. The permissions object passed to the
// callback is the one returned by PublicKeyCallback for the given public
// key and its ownership is transferred to the callback. The returned
// Permissions object can be the same object, optionally modified, or a
// completely new object. If VerifiedPublicKeyCallback is non-nil,
// PublicKeyCallback is not allowed to return a PartialSuccessError, which
// can instead be returned by VerifiedPublicKeyCallback.
//
// VerifiedPublicKeyCallback does not affect which authentication methods
// are included in the list of methods that can be attempted by the client.
VerifiedPublicKeyCallback func(conn ConnMetadata, key PublicKey, permissions *Permissions,
signatureAlgorithm string) (*Permissions, error)
// KeyboardInteractiveCallback, if non-nil, is called when
// keyboard-interactive authentication is selected (RFC
// 4256). The client object's Challenge function should be
@@ -243,22 +262,15 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha
fullConf.MaxAuthTries = 6
}
if len(fullConf.PublicKeyAuthAlgorithms) == 0 {
fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos
fullConf.PublicKeyAuthAlgorithms = defaultPubKeyAuthAlgos
} else {
for _, algo := range fullConf.PublicKeyAuthAlgorithms {
if !contains(supportedPubKeyAuthAlgos, algo) {
if !slices.Contains(SupportedAlgorithms().PublicKeyAuths, algo) && !slices.Contains(InsecureAlgorithms().PublicKeyAuths, algo) {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo)
}
}
}
// Check if the config contains any unsupported key exchanges
for _, kex := range fullConf.KeyExchanges {
if _, ok := serverForbiddenKexAlgos[kex]; ok {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex)
}
}
s := &connection{
sshConn: sshConn{conn: c},
@@ -315,6 +327,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
// We just did the key change, so the session ID is established.
s.sessionID = s.transport.getSessionID()
s.algorithms = s.transport.getAlgorithms()
var packet []byte
if packet, err = s.transport.readPacket(); err != nil {
@@ -637,7 +650,7 @@ userAuthLoop:
return nil, parseError(msgUserAuthRequest)
}
algo := string(algoBytes)
if !contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) {
if !slices.Contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
break
}
@@ -658,6 +671,9 @@ userAuthLoop:
candidate.pubKeyData = pubKeyData
candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
if isPartialSuccessError && config.VerifiedPublicKeyCallback != nil {
return nil, errors.New("ssh: invalid library usage: PublicKeyCallback must not return partial success when VerifiedPublicKeyCallback is defined")
}
if (candidate.result == nil || isPartialSuccessError) &&
candidate.perms != nil &&
@@ -701,7 +717,7 @@ userAuthLoop:
// ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public
// key type. The algorithm and public key type must be
// consistent: both must be certificate algorithms, or neither.
if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) {
if !slices.Contains(algorithmsForKeyFormat(pubKey.Type()), algo) {
authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q",
pubKey.Type(), algo)
break
@@ -711,7 +727,7 @@ userAuthLoop:
// algorithm name that corresponds to algo with
// sig.Format. This is usually the same, but
// for certs, the names differ.
if !contains(config.PublicKeyAuthAlgorithms, sig.Format) {
if !slices.Contains(config.PublicKeyAuthAlgorithms, sig.Format) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
break
}
@@ -728,6 +744,12 @@ userAuthLoop:
authErr = candidate.result
perms = candidate.perms
if authErr == nil && config.VerifiedPublicKeyCallback != nil {
// Only call VerifiedPublicKeyCallback after the key has been accepted
// and successfully verified. If authErr is non-nil, the key is not
// considered verified and the callback must not run.
perms, authErr = config.VerifiedPublicKeyCallback(s, pubKey, perms, algo)
}
}
case "gssapi-with-mic":
if authConfig.GSSAPIWithMICConfig == nil {
+7 -1
View File
@@ -106,6 +106,13 @@ func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) {
if !ok {
return nil, errors.New("parse uint32 failed")
}
// Each ASN.1 encoded OID must have a minimum
// of 2 bytes; 64 maximum mechanisms is an
// arbitrary, but reasonable ceiling.
const maxMechs = 64
if n > maxMechs || int(n)*2 > len(rest) {
return nil, errors.New("invalid mechanism count")
}
s := &userAuthRequestGSSAPI{
N: n,
OIDS: make([]asn1.ObjectIdentifier, n),
@@ -122,7 +129,6 @@ func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) {
if rest, err = asn1.Unmarshal(desiredMech, &s.OIDS[i]); err != nil {
return nil, err
}
}
return s, nil
}
+2 -2
View File
@@ -44,7 +44,7 @@ func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
if !ok {
return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
}
ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
ch := c.forwards.add("unix", socketPath)
return &unixListener{socketPath, c, ch}, nil
}
@@ -96,7 +96,7 @@ func (l *unixListener) Accept() (net.Conn, error) {
// Close closes the listener.
func (l *unixListener) Close() error {
// this also closes the listener.
l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
l.conn.forwards.remove("unix", l.socketPath)
m := streamLocalChannelForwardMsg{
l.socketPath,
}
+80 -44
View File
@@ -11,6 +11,7 @@ import (
"io"
"math/rand"
"net"
"net/netip"
"strconv"
"strings"
"sync"
@@ -22,14 +23,21 @@ import (
// the returned net.Listener. The listener must be serviced, or the
// SSH connection may hang.
// N must be "tcp", "tcp4", "tcp6", or "unix".
//
// If the address is a hostname, it is sent to the remote peer as-is, without
// being resolved locally, and the Listener Addr method will return a zero IP.
func (c *Client) Listen(n, addr string) (net.Listener, error) {
switch n {
case "tcp", "tcp4", "tcp6":
laddr, err := net.ResolveTCPAddr(n, addr)
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return c.ListenTCP(laddr)
port, err := strconv.ParseInt(portStr, 10, 32)
if err != nil {
return nil, err
}
return c.listenTCPInternal(host, int(port))
case "unix":
return c.ListenUnix(addr)
default:
@@ -102,15 +110,24 @@ func (c *Client) handleForwards() {
// ListenTCP requests the remote peer open a listening socket
// on laddr. Incoming connections will be available by calling
// Accept on the returned net.Listener.
//
// ListenTCP accepts an IP address, to provide a hostname use [Client.Listen]
// with "tcp", "tcp4", or "tcp6" network instead.
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
return c.autoPortListenWorkaround(laddr)
}
return c.listenTCPInternal(laddr.IP.String(), laddr.Port)
}
func (c *Client) listenTCPInternal(host string, port int) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
m := channelForwardMsg{
laddr.IP.String(),
uint32(laddr.Port),
host,
uint32(port),
}
// send message
ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
@@ -123,20 +140,33 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
// If the original port was 0, then the remote side will
// supply a real port number in the response.
if laddr.Port == 0 {
if port == 0 {
var p struct {
Port uint32
}
if err := Unmarshal(resp, &p); err != nil {
return nil, err
}
laddr.Port = int(p.Port)
port = int(p.Port)
}
// Construct a local address placeholder for the remote listener. If the
// original host is an IP address, preserve it so that Listener.Addr()
// reports the same IP. If the host is a hostname or cannot be parsed as an
// IP, fall back to IPv4zero. The port field is always set, even if the
// original port was 0, because in that case the remote server will assign
// one, allowing callers to determine which port was selected.
ip := net.IPv4zero
if parsed, err := netip.ParseAddr(host); err == nil {
ip = net.IP(parsed.AsSlice())
}
laddr := &net.TCPAddr{
IP: ip,
Port: port,
}
addr := net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
ch := c.forwards.add("tcp", addr)
// Register this forward, using the port number we obtained.
ch := c.forwards.add(laddr)
return &tcpListener{laddr, c, ch}, nil
return &tcpListener{laddr, addr, c, ch}, nil
}
// forwardList stores a mapping between remote
@@ -149,8 +179,9 @@ type forwardList struct {
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a tcpListener.
type forwardEntry struct {
laddr net.Addr
c chan forward
addr string // host:port or socket path
network string // tcp or unix
c chan forward
}
// forward represents an incoming forwarded tcpip connection. The
@@ -161,12 +192,13 @@ type forward struct {
raddr net.Addr // the raddr of the incoming connection
}
func (l *forwardList) add(addr net.Addr) chan forward {
func (l *forwardList) add(n, addr string) chan forward {
l.Lock()
defer l.Unlock()
f := forwardEntry{
laddr: addr,
c: make(chan forward, 1),
addr: addr,
network: n,
c: make(chan forward, 1),
}
l.entries = append(l.entries, f)
return f.c
@@ -185,19 +217,20 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
if port == 0 || port > 65535 {
return nil, fmt.Errorf("ssh: port number out of range: %d", port)
}
ip := net.ParseIP(string(addr))
if ip == nil {
ip, err := netip.ParseAddr(addr)
if err != nil {
return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
}
return &net.TCPAddr{IP: ip, Port: int(port)}, nil
return &net.TCPAddr{IP: net.IP(ip.AsSlice()), Port: int(port)}, nil
}
func (l *forwardList) handleChannels(in <-chan NewChannel) {
for ch := range in {
var (
laddr net.Addr
raddr net.Addr
err error
addr string
network string
raddr net.Addr
err error
)
switch channelType := ch.ChannelType(); channelType {
case "forwarded-tcpip":
@@ -207,40 +240,34 @@ func (l *forwardList) handleChannels(in <-chan NewChannel) {
continue
}
// RFC 4254 section 7.2 specifies that incoming
// addresses should list the address, in string
// format. It is implied that this should be an IP
// address, as it would be impossible to connect to it
// otherwise.
laddr, err = parseTCPAddr(payload.Addr, payload.Port)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
// RFC 4254 section 7.2 specifies that incoming addresses should
// list the address that was connected, in string format. It is the
// same address used in the tcpip-forward request. The originator
// address is an IP address instead.
addr = net.JoinHostPort(payload.Addr, strconv.FormatUint(uint64(payload.Port), 10))
raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
network = "tcp"
case "forwarded-streamlocal@openssh.com":
var payload forwardedStreamLocalPayload
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
continue
}
laddr = &net.UnixAddr{
Name: payload.SocketPath,
Net: "unix",
}
addr = payload.SocketPath
raddr = &net.UnixAddr{
Name: "@",
Net: "unix",
}
network = "unix"
default:
panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
}
if ok := l.forward(laddr, raddr, ch); !ok {
if ok := l.forward(network, addr, raddr, ch); !ok {
// Section 7.2, implementations MUST reject spurious incoming
// connections.
ch.Reject(Prohibited, "no forward for address")
@@ -252,11 +279,11 @@ func (l *forwardList) handleChannels(in <-chan NewChannel) {
// remove removes the forward entry, and the channel feeding its
// listener.
func (l *forwardList) remove(addr net.Addr) {
func (l *forwardList) remove(n, addr string) {
l.Lock()
defer l.Unlock()
for i, f := range l.entries {
if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
if n == f.network && addr == f.addr {
l.entries = append(l.entries[:i], l.entries[i+1:]...)
close(f.c)
return
@@ -274,11 +301,11 @@ func (l *forwardList) closeAll() {
l.entries = nil
}
func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
func (l *forwardList) forward(n, addr string, raddr net.Addr, ch NewChannel) bool {
l.Lock()
defer l.Unlock()
for _, f := range l.entries {
if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
if n == f.network && addr == f.addr {
f.c <- forward{newCh: ch, raddr: raddr}
return true
}
@@ -288,6 +315,7 @@ func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
type tcpListener struct {
laddr *net.TCPAddr
addr string
conn *Client
in <-chan forward
@@ -314,13 +342,21 @@ func (l *tcpListener) Accept() (net.Conn, error) {
// Close closes the listener.
func (l *tcpListener) Close() error {
host, port, err := net.SplitHostPort(l.addr)
if err != nil {
return err
}
rport, err := strconv.ParseUint(port, 10, 32)
if err != nil {
return err
}
m := channelForwardMsg{
l.laddr.IP.String(),
uint32(l.laddr.Port),
host,
uint32(rport),
}
// this also closes the listener.
l.conn.forwards.remove(l.laddr)
l.conn.forwards.remove("tcp", l.addr)
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
if err == nil && !ok {
err = errors.New("ssh: cancel-tcpip-forward failed")
+8 -11
View File
@@ -8,6 +8,7 @@ import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
)
@@ -16,13 +17,6 @@ import (
// wire. No message decoding is done, to minimize the impact on timing.
const debugTransport = false
const (
gcm128CipherID = "aes128-gcm@openssh.com"
gcm256CipherID = "aes256-gcm@openssh.com"
aes128cbcID = "aes128-cbc"
tripledescbcID = "3des-cbc"
)
// packetConn represents a transport that implements packet based
// operations.
type packetConn interface {
@@ -92,14 +86,14 @@ func (t *transport) setInitialKEXDone() {
// prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
func (t *transport) prepareKeyChange(algs *NegotiatedAlgorithms, kexResult *kexResult) error {
ciph, err := newPacketCipher(t.reader.dir, algs.Read, kexResult)
if err != nil {
return err
}
t.reader.pendingKeyChange <- ciph
ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
ciph, err = newPacketCipher(t.writer.dir, algs.Write, kexResult)
if err != nil {
return err
}
@@ -259,8 +253,11 @@ var (
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
func newPacketCipher(d direction, algs DirectionAlgorithms, kex *kexResult) (packetCipher, error) {
cipherMode := cipherModes[algs.Cipher]
if cipherMode == nil {
return nil, fmt.Errorf("ssh: unsupported cipher %v", algs.Cipher)
}
iv := make([]byte, cipherMode.ivSize)
key := make([]byte, cipherMode.keySize)
+1 -1
View File
@@ -33,7 +33,7 @@ type printer struct {
}
// printf prints to the buffer.
func (p *printer) printf(format string, args ...interface{}) {
func (p *printer) printf(format string, args ...any) {
fmt.Fprintf(p, format, args...)
}
+2 -2
View File
@@ -94,7 +94,7 @@ func (x *FileSyntax) Span() (start, end Position) {
// line, the new line is added at the end of the block containing hint,
// extracting hint into a new block if it is not yet in one.
//
// If the hint is non-nil buts its first token does not match,
// If the hint is non-nil but its first token does not match,
// the new line is added after the block containing hint
// (or hint itself, if not in a block).
//
@@ -600,7 +600,7 @@ func (in *input) readToken() {
// Checked all punctuation. Must be identifier token.
if c := in.peekRune(); !isIdent(c) {
in.Error(fmt.Sprintf("unexpected input character %#q", c))
in.Error(fmt.Sprintf("unexpected input character %#q", rune(c)))
}
// Scan over identifier.
+101 -33
View File
@@ -20,10 +20,11 @@
package modfile
import (
"cmp"
"errors"
"fmt"
"path/filepath"
"sort"
"slices"
"strconv"
"strings"
"unicode"
@@ -44,6 +45,7 @@ type File struct {
Replace []*Replace
Retract []*Retract
Tool []*Tool
Ignore []*Ignore
Syntax *FileSyntax
}
@@ -100,6 +102,12 @@ type Tool struct {
Syntax *Line
}
// An Ignore is a single ignore statement.
type Ignore struct {
Path string
Syntax *Line
}
// A VersionInterval represents a range of versions with upper and lower bounds.
// Intervals are closed: both bounds are included. When Low is equal to High,
// the interval may refer to a single version ('v1.2.3') or an interval
@@ -304,7 +312,7 @@ func parseToFile(file string, data []byte, fix VersionFixer, strict bool) (parse
})
}
continue
case "module", "godebug", "require", "exclude", "replace", "retract", "tool":
case "module", "godebug", "require", "exclude", "replace", "retract", "tool", "ignore":
for _, l := range x.Line {
f.add(&errs, x, l, x.Token[0], l.Token, fix, strict)
}
@@ -337,7 +345,7 @@ func (f *File) add(errs *ErrorList, block *LineBlock, line *Line, verb string, a
// and simply ignore those statements.
if !strict {
switch verb {
case "go", "module", "retract", "require":
case "go", "module", "retract", "require", "ignore":
// want these even for dependency go.mods
default:
return
@@ -360,7 +368,7 @@ func (f *File) add(errs *ErrorList, block *LineBlock, line *Line, verb string, a
Err: err,
})
}
errorf := func(format string, args ...interface{}) {
errorf := func(format string, args ...any) {
wrapError(fmt.Errorf(format, args...))
}
@@ -531,6 +539,21 @@ func (f *File) add(errs *ErrorList, block *LineBlock, line *Line, verb string, a
Path: s,
Syntax: line,
})
case "ignore":
if len(args) != 1 {
errorf("ignore directive expects exactly one argument")
return
}
s, err := parseString(&args[0])
if err != nil {
errorf("invalid quoted string: %v", err)
return
}
f.Ignore = append(f.Ignore, &Ignore{
Path: s,
Syntax: line,
})
}
}
@@ -551,7 +574,7 @@ func parseReplace(filename string, line *Line, verb string, args []string, fix V
Err: err,
}
}
errorf := func(format string, args ...interface{}) *Error {
errorf := func(format string, args ...any) *Error {
return wrapError(fmt.Errorf(format, args...))
}
@@ -662,7 +685,7 @@ func (f *WorkFile) add(errs *ErrorList, line *Line, verb string, args []string,
Err: err,
})
}
errorf := func(format string, args ...interface{}) {
errorf := func(format string, args ...any) {
wrapError(fmt.Errorf(format, args...))
}
@@ -1571,7 +1594,7 @@ func (f *File) AddRetract(vi VersionInterval, rationale string) error {
r.Syntax = f.Syntax.addLine(nil, "retract", "[", AutoQuote(vi.Low), ",", AutoQuote(vi.High), "]")
}
if rationale != "" {
for _, line := range strings.Split(rationale, "\n") {
for line := range strings.SplitSeq(rationale, "\n") {
com := Comment{Token: "// " + line}
r.Syntax.Comment().Before = append(r.Syntax.Comment().Before, com)
}
@@ -1619,6 +1642,36 @@ func (f *File) DropTool(path string) error {
return nil
}
// AddIgnore adds a new ignore directive with the given path.
// It does nothing if the ignore line already exists.
func (f *File) AddIgnore(path string) error {
for _, t := range f.Ignore {
if t.Path == path {
return nil
}
}
f.Ignore = append(f.Ignore, &Ignore{
Path: path,
Syntax: f.Syntax.addLine(nil, "ignore", path),
})
f.SortBlocks()
return nil
}
// DropIgnore removes a ignore directive with the given path.
// It does nothing if no such ignore directive exists.
func (f *File) DropIgnore(path string) error {
for _, t := range f.Ignore {
if t.Path == path {
t.Syntax.markRemoved()
*t = Ignore{}
}
}
return nil
}
func (f *File) SortBlocks() {
f.removeDups() // otherwise sorting is unsafe
@@ -1633,15 +1686,13 @@ func (f *File) SortBlocks() {
if !ok {
continue
}
less := lineLess
less := compareLine
if block.Token[0] == "exclude" && useSemanticSortForExclude {
less = lineExcludeLess
less = compareLineExclude
} else if block.Token[0] == "retract" {
less = lineRetractLess
less = compareLineRetract
}
sort.SliceStable(block.Line, func(i, j int) bool {
return less(block.Line[i], block.Line[j])
})
slices.SortStableFunc(block.Line, less)
}
}
@@ -1657,10 +1708,10 @@ func (f *File) SortBlocks() {
// retract directives are not de-duplicated since comments are
// meaningful, and versions may be retracted multiple times.
func (f *File) removeDups() {
removeDups(f.Syntax, &f.Exclude, &f.Replace, &f.Tool)
removeDups(f.Syntax, &f.Exclude, &f.Replace, &f.Tool, &f.Ignore)
}
func removeDups(syntax *FileSyntax, exclude *[]*Exclude, replace *[]*Replace, tool *[]*Tool) {
func removeDups(syntax *FileSyntax, exclude *[]*Exclude, replace *[]*Replace, tool *[]*Tool, ignore *[]*Ignore) {
kill := make(map[*Line]bool)
// Remove duplicate excludes.
@@ -1719,6 +1770,24 @@ func removeDups(syntax *FileSyntax, exclude *[]*Exclude, replace *[]*Replace, to
*tool = newTool
}
if ignore != nil {
haveIgnore := make(map[string]bool)
for _, i := range *ignore {
if haveIgnore[i.Path] {
kill[i.Syntax] = true
continue
}
haveIgnore[i.Path] = true
}
var newIgnore []*Ignore
for _, i := range *ignore {
if !kill[i.Syntax] {
newIgnore = append(newIgnore, i)
}
}
*ignore = newIgnore
}
// Duplicate require and retract directives are not removed.
// Drop killed statements from the syntax tree.
@@ -1746,39 +1815,38 @@ func removeDups(syntax *FileSyntax, exclude *[]*Exclude, replace *[]*Replace, to
syntax.Stmt = stmts
}
// lineLess returns whether li should be sorted before lj. It sorts
// lexicographically without assigning any special meaning to tokens.
func lineLess(li, lj *Line) bool {
// compareLine compares li and lj. It sorts lexicographically without assigning
// any special meaning to tokens.
func compareLine(li, lj *Line) int {
for k := 0; k < len(li.Token) && k < len(lj.Token); k++ {
if li.Token[k] != lj.Token[k] {
return li.Token[k] < lj.Token[k]
return cmp.Compare(li.Token[k], lj.Token[k])
}
}
return len(li.Token) < len(lj.Token)
return cmp.Compare(len(li.Token), len(lj.Token))
}
// lineExcludeLess reports whether li should be sorted before lj for lines in
// an "exclude" block.
func lineExcludeLess(li, lj *Line) bool {
// compareLineExclude compares li and lj for lines in an "exclude" block.
func compareLineExclude(li, lj *Line) int {
if len(li.Token) != 2 || len(lj.Token) != 2 {
// Not a known exclude specification.
// Fall back to sorting lexicographically.
return lineLess(li, lj)
return compareLine(li, lj)
}
// An exclude specification has two tokens: ModulePath and Version.
// Compare module path by string order and version by semver rules.
if pi, pj := li.Token[0], lj.Token[0]; pi != pj {
return pi < pj
return cmp.Compare(pi, pj)
}
return semver.Compare(li.Token[1], lj.Token[1]) < 0
return semver.Compare(li.Token[1], lj.Token[1])
}
// lineRetractLess returns whether li should be sorted before lj for lines in
// a "retract" block. It treats each line as a version interval. Single versions
// are compared as if they were intervals with the same low and high version.
// compareLineRetract compares li and lj for lines in a "retract" block.
// It treats each line as a version interval. Single versions are compared as
// if they were intervals with the same low and high version.
// Intervals are sorted in descending order, first by low version, then by
// high version, using semver.Compare.
func lineRetractLess(li, lj *Line) bool {
// high version, using [semver.Compare].
func compareLineRetract(li, lj *Line) int {
interval := func(l *Line) VersionInterval {
if len(l.Token) == 1 {
return VersionInterval{Low: l.Token[0], High: l.Token[0]}
@@ -1792,9 +1860,9 @@ func lineRetractLess(li, lj *Line) bool {
vii := interval(li)
vij := interval(lj)
if cmp := semver.Compare(vii.Low, vij.Low); cmp != 0 {
return cmp > 0
return -cmp
}
return semver.Compare(vii.High, vij.High) > 0
return -semver.Compare(vii.High, vij.High)
}
// checkCanonicalVersion returns a non-nil error if vers is not a canonical
+3 -5
View File
@@ -6,7 +6,7 @@ package modfile
import (
"fmt"
"sort"
"slices"
"strings"
)
@@ -315,9 +315,7 @@ func (f *WorkFile) SortBlocks() {
if !ok {
continue
}
sort.SliceStable(block.Line, func(i, j int) bool {
return lineLess(block.Line[i], block.Line[j])
})
slices.SortStableFunc(block.Line, compareLine)
}
}
@@ -331,5 +329,5 @@ func (f *WorkFile) SortBlocks() {
// retract directives are not de-duplicated since comments are
// meaningful, and versions may be retracted multiple times.
func (f *WorkFile) removeDups() {
removeDups(f.Syntax, nil, &f.Replace, nil)
removeDups(f.Syntax, nil, &f.Replace, nil, nil)
}
+12 -13
View File
@@ -96,10 +96,11 @@ package module
// Changes to the semantics in this file require approval from rsc.
import (
"cmp"
"errors"
"fmt"
"path"
"sort"
"slices"
"strings"
"unicode"
"unicode/utf8"
@@ -260,7 +261,7 @@ func modPathOK(r rune) bool {
// importPathOK reports whether r can appear in a package import path element.
//
// Import paths are intermediate between module paths and file paths: we allow
// Import paths are intermediate between module paths and file paths: we
// disallow characters that would be confusing or ambiguous as arguments to
// 'go get' (such as '@' and ' ' ), but allow certain characters that are
// otherwise-unambiguous on the command line and historically used for some
@@ -657,17 +658,15 @@ func CanonicalVersion(v string) string {
// optionally followed by a tie-breaking suffix introduced by a slash character,
// like in "v0.0.1/go.mod".
func Sort(list []Version) {
sort.Slice(list, func(i, j int) bool {
mi := list[i]
mj := list[j]
if mi.Path != mj.Path {
return mi.Path < mj.Path
slices.SortFunc(list, func(i, j Version) int {
if i.Path != j.Path {
return strings.Compare(i.Path, j.Path)
}
// To help go.sum formatting, allow version/file.
// Compare semver prefix by semver rules,
// file by string order.
vi := mi.Version
vj := mj.Version
vi := i.Version
vj := j.Version
var fi, fj string
if k := strings.Index(vi, "/"); k >= 0 {
vi, fi = vi[:k], vi[k:]
@@ -676,9 +675,9 @@ func Sort(list []Version) {
vj, fj = vj[:k], vj[k:]
}
if vi != vj {
return semver.Compare(vi, vj) < 0
return semver.Compare(vi, vj)
}
return fi < fj
return cmp.Compare(fi, fj)
})
}
@@ -803,8 +802,8 @@ func MatchPrefixPatterns(globs, target string) bool {
for globs != "" {
// Extract next non-empty glob in comma-separated list.
var glob string
if i := strings.Index(globs, ","); i >= 0 {
glob, globs = globs[:i], globs[i+1:]
if before, after, ok := strings.Cut(globs, ","); ok {
glob, globs = before, after
} else {
glob, globs = globs, ""
}
+20 -14
View File
@@ -22,7 +22,10 @@
// as shorthands for vMAJOR.0.0 and vMAJOR.MINOR.0.
package semver
import "sort"
import (
"slices"
"strings"
)
// parsed returns the parsed form of a semantic version string.
type parsed struct {
@@ -42,8 +45,8 @@ func IsValid(v string) bool {
// Canonical returns the canonical formatting of the semantic version v.
// It fills in any missing .MINOR or .PATCH and discards build metadata.
// Two semantic versions compare equal only if their canonical formattings
// are identical strings.
// Two semantic versions compare equal only if their canonical formatting
// is an identical string.
// The canonical invalid semantic version is the empty string.
func Canonical(v string) string {
p, ok := parse(v)
@@ -154,19 +157,22 @@ func Max(v, w string) string {
// ByVersion implements [sort.Interface] for sorting semantic version strings.
type ByVersion []string
func (vs ByVersion) Len() int { return len(vs) }
func (vs ByVersion) Swap(i, j int) { vs[i], vs[j] = vs[j], vs[i] }
func (vs ByVersion) Less(i, j int) bool {
cmp := Compare(vs[i], vs[j])
if cmp != 0 {
return cmp < 0
}
return vs[i] < vs[j]
func (vs ByVersion) Len() int { return len(vs) }
func (vs ByVersion) Swap(i, j int) { vs[i], vs[j] = vs[j], vs[i] }
func (vs ByVersion) Less(i, j int) bool { return compareVersion(vs[i], vs[j]) < 0 }
// Sort sorts a list of semantic version strings using [Compare] and falls back
// to use [strings.Compare] if both versions are considered equal.
func Sort(list []string) {
slices.SortFunc(list, compareVersion)
}
// Sort sorts a list of semantic version strings using [ByVersion].
func Sort(list []string) {
sort.Sort(ByVersion(list))
func compareVersion(a, b string) int {
cmp := Compare(a, b)
if cmp != 0 {
return cmp
}
return strings.Compare(a, b)
}
func parse(v string) (p parsed, ok bool) {
+22 -48
View File
@@ -2,44 +2,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package context defines the Context type, which carries deadlines,
// cancellation signals, and other request-scoped values across API boundaries
// and between processes.
// As of Go 1.7 this package is available in the standard library under the
// name [context], and migrating to it can be done automatically with [go fix].
// Package context has been superseded by the standard library [context] package.
//
// Incoming requests to a server should create a [Context], and outgoing
// calls to servers should accept a Context. The chain of function
// calls between them must propagate the Context, optionally replacing
// it with a derived Context created using [WithCancel], [WithDeadline],
// [WithTimeout], or [WithValue].
//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
//
// Do not store Contexts inside a struct type; instead, pass a Context
// explicitly to each function that needs it. This is discussed further in
// https://go.dev/blog/context-and-structs. The Context should be the first
// parameter, typically named ctx:
//
// func DoSomething(ctx context.Context, arg Arg) error {
// // ... use ctx ...
// }
//
// Do not pass a nil [Context], even if a function permits it. Pass [context.TODO]
// if you are unsure about which Context to use.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
//
// The same Context may be passed to functions running in different goroutines;
// Contexts are safe for simultaneous use by multiple goroutines.
//
// See https://go.dev/blog/context for example code for a server that uses
// Contexts.
//
// [go fix]: https://go.dev/cmd/go#hdr-Update_packages_to_use_new_APIs
// Deprecated: Use the standard library context package instead.
package context
import (
@@ -51,36 +16,37 @@ import (
// API boundaries.
//
// Context's methods may be called by multiple goroutines simultaneously.
//
//go:fix inline
type Context = context.Context
// Canceled is the error returned by [Context.Err] when the context is canceled
// for some reason other than its deadline passing.
//
//go:fix inline
var Canceled = context.Canceled
// DeadlineExceeded is the error returned by [Context.Err] when the context is canceled
// due to its deadline passing.
//
//go:fix inline
var DeadlineExceeded = context.DeadlineExceeded
// Background returns a non-nil, empty Context. It is never canceled, has no
// values, and has no deadline. It is typically used by the main function,
// initialization, and tests, and as the top-level Context for incoming
// requests.
func Background() Context {
return background
}
//
//go:fix inline
func Background() Context { return context.Background() }
// TODO returns a non-nil, empty Context. Code should use context.TODO when
// it's unclear which Context to use or it is not yet available (because the
// surrounding function has not yet been extended to accept a Context
// parameter).
func TODO() Context {
return todo
}
var (
background = context.Background()
todo = context.TODO()
)
//
//go:fix inline
func TODO() Context { return context.TODO() }
// A CancelFunc tells an operation to abandon its work.
// A CancelFunc does not wait for the work to stop.
@@ -95,6 +61,8 @@ type CancelFunc = context.CancelFunc
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this [Context] complete.
//
//go:fix inline
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
return context.WithCancel(parent)
}
@@ -108,6 +76,8 @@ func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this [Context] complete.
//
//go:fix inline
func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
return context.WithDeadline(parent, d)
}
@@ -122,6 +92,8 @@ func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
// defer cancel() // releases resources if slowOperation completes before timeout elapses
// return slowOperation(ctx)
// }
//
//go:fix inline
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
return context.WithTimeout(parent, timeout)
}
@@ -139,6 +111,8 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
// interface{}, context keys often have concrete type
// struct{}. Alternatively, exported context key variables' static
// type should be a pointer or interface.
//
//go:fix inline
func WithValue(parent Context, key, val interface{}) Context {
return context.WithValue(parent, key, val)
}
+20
View File
@@ -0,0 +1,20 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.27
package http2
import "net/http"
// Support for go.dev/issue/75500 is added in Go 1.27. In case anyone uses
// x/net with versions before Go 1.27, we return true here so that their write
// scheduler will still be the round-robin write scheduler rather than the RFC
// 9218 write scheduler. That way, older users of Go will not see a sudden
// change of behavior just from importing x/net.
//
// TODO(nsh): remove this file after x/net go.mod is at Go 1.27.
func clientPriorityDisabled(_ *http.Server) bool {
return true
}
+13
View File
@@ -0,0 +1,13 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.27
package http2
import "net/http"
func clientPriorityDisabled(s *http.Server) bool {
return s.DisableClientPriority
}
+55 -8
View File
@@ -27,6 +27,7 @@ import (
// - If the resulting value is zero or out of range, use a default.
type http2Config struct {
MaxConcurrentStreams uint32
StrictMaxConcurrentRequests bool
MaxDecoderHeaderTableSize uint32
MaxEncoderHeaderTableSize uint32
MaxReadFrameSize uint32
@@ -55,7 +56,7 @@ func configFromServer(h1 *http.Server, h2 *Server) http2Config {
PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites,
CountError: h2.CountError,
}
fillNetHTTPServerConfig(&conf, h1)
fillNetHTTPConfig(&conf, h1.HTTP2)
setConfigDefaults(&conf, true)
return conf
}
@@ -64,12 +65,13 @@ func configFromServer(h1 *http.Server, h2 *Server) http2Config {
// (the net/http Transport).
func configFromTransport(h2 *Transport) http2Config {
conf := http2Config{
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
StrictMaxConcurrentRequests: h2.StrictMaxConcurrentStreams,
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
}
// Unlike most config fields, where out-of-range values revert to the default,
@@ -81,7 +83,7 @@ func configFromTransport(h2 *Transport) http2Config {
}
if h2.t1 != nil {
fillNetHTTPTransportConfig(&conf, h2.t1)
fillNetHTTPConfig(&conf, h2.t1.HTTP2)
}
setConfigDefaults(&conf, false)
return conf
@@ -120,3 +122,48 @@ func adjustHTTP1MaxHeaderSize(n int64) int64 {
const typicalHeaders = 10 // conservative
return n + typicalHeaders*perFieldOverhead
}
func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) {
if h2 == nil {
return
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if http2ConfigStrictMaxConcurrentRequests(h2) {
conf.StrictMaxConcurrentRequests = true
}
if h2.MaxEncoderHeaderTableSize != 0 {
conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize)
}
if h2.MaxDecoderHeaderTableSize != 0 {
conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize)
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if h2.MaxReadFrameSize != 0 {
conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize)
}
if h2.MaxReceiveBufferPerConnection != 0 {
conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection)
}
if h2.MaxReceiveBufferPerStream != 0 {
conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream)
}
if h2.SendPingTimeout != 0 {
conf.SendPingTimeout = h2.SendPingTimeout
}
if h2.PingTimeout != 0 {
conf.PingTimeout = h2.PingTimeout
}
if h2.WriteByteTimeout != 0 {
conf.WriteByteTimeout = h2.WriteByteTimeout
}
if h2.PermitProhibitedCipherSuites {
conf.PermitProhibitedCipherSuites = true
}
if h2.CountError != nil {
conf.CountError = h2.CountError
}
}
-61
View File
@@ -1,61 +0,0 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.24
package http2
import "net/http"
// fillNetHTTPServerConfig sets fields in conf from srv.HTTP2.
func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {
fillNetHTTPConfig(conf, srv.HTTP2)
}
// fillNetHTTPTransportConfig sets fields in conf from tr.HTTP2.
func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {
fillNetHTTPConfig(conf, tr.HTTP2)
}
func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) {
if h2 == nil {
return
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if h2.MaxEncoderHeaderTableSize != 0 {
conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize)
}
if h2.MaxDecoderHeaderTableSize != 0 {
conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize)
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if h2.MaxReadFrameSize != 0 {
conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize)
}
if h2.MaxReceiveBufferPerConnection != 0 {
conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection)
}
if h2.MaxReceiveBufferPerStream != 0 {
conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream)
}
if h2.SendPingTimeout != 0 {
conf.SendPingTimeout = h2.SendPingTimeout
}
if h2.PingTimeout != 0 {
conf.PingTimeout = h2.PingTimeout
}
if h2.WriteByteTimeout != 0 {
conf.WriteByteTimeout = h2.WriteByteTimeout
}
if h2.PermitProhibitedCipherSuites {
conf.PermitProhibitedCipherSuites = true
}
if h2.CountError != nil {
conf.CountError = h2.CountError
}
}
+15
View File
@@ -0,0 +1,15 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.26
package http2
import (
"net/http"
)
func http2ConfigStrictMaxConcurrentRequests(h2 *http.HTTP2Config) bool {
return false
}
+15
View File
@@ -0,0 +1,15 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.26
package http2
import (
"net/http"
)
func http2ConfigStrictMaxConcurrentRequests(h2 *http.HTTP2Config) bool {
return h2.StrictMaxConcurrentRequests
}
-16
View File
@@ -1,16 +0,0 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.24
package http2
import "net/http"
// Pre-Go 1.24 fallback.
// The Server.HTTP2 and Transport.HTTP2 config fields were added in Go 1.24.
func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {}
func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {}
+232 -63
View File
@@ -11,11 +11,13 @@ import (
"fmt"
"io"
"log"
"slices"
"strings"
"sync"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/internal/httpsfv"
)
const frameHeaderLen = 9
@@ -23,40 +25,43 @@ const frameHeaderLen = 9
var padZeros = make([]byte, 255) // zeros for padding
// A FrameType is a registered frame type as defined in
// https://httpwg.org/specs/rfc7540.html#rfc.section.11.2
// https://httpwg.org/specs/rfc7540.html#rfc.section.11.2 and other future
// RFCs.
type FrameType uint8
const (
FrameData FrameType = 0x0
FrameHeaders FrameType = 0x1
FramePriority FrameType = 0x2
FrameRSTStream FrameType = 0x3
FrameSettings FrameType = 0x4
FramePushPromise FrameType = 0x5
FramePing FrameType = 0x6
FrameGoAway FrameType = 0x7
FrameWindowUpdate FrameType = 0x8
FrameContinuation FrameType = 0x9
FrameData FrameType = 0x0
FrameHeaders FrameType = 0x1
FramePriority FrameType = 0x2
FrameRSTStream FrameType = 0x3
FrameSettings FrameType = 0x4
FramePushPromise FrameType = 0x5
FramePing FrameType = 0x6
FrameGoAway FrameType = 0x7
FrameWindowUpdate FrameType = 0x8
FrameContinuation FrameType = 0x9
FramePriorityUpdate FrameType = 0x10
)
var frameName = map[FrameType]string{
FrameData: "DATA",
FrameHeaders: "HEADERS",
FramePriority: "PRIORITY",
FrameRSTStream: "RST_STREAM",
FrameSettings: "SETTINGS",
FramePushPromise: "PUSH_PROMISE",
FramePing: "PING",
FrameGoAway: "GOAWAY",
FrameWindowUpdate: "WINDOW_UPDATE",
FrameContinuation: "CONTINUATION",
var frameNames = [...]string{
FrameData: "DATA",
FrameHeaders: "HEADERS",
FramePriority: "PRIORITY",
FrameRSTStream: "RST_STREAM",
FrameSettings: "SETTINGS",
FramePushPromise: "PUSH_PROMISE",
FramePing: "PING",
FrameGoAway: "GOAWAY",
FrameWindowUpdate: "WINDOW_UPDATE",
FrameContinuation: "CONTINUATION",
FramePriorityUpdate: "PRIORITY_UPDATE",
}
func (t FrameType) String() string {
if s, ok := frameName[t]; ok {
return s
if int(t) < len(frameNames) {
return frameNames[t]
}
return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t))
return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", t)
}
// Flags is a bitmask of HTTP/2 flags.
@@ -124,22 +129,25 @@ var flagName = map[FrameType]map[Flags]string{
// might be 0).
type frameParser func(fc *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error)
var frameParsers = map[FrameType]frameParser{
FrameData: parseDataFrame,
FrameHeaders: parseHeadersFrame,
FramePriority: parsePriorityFrame,
FrameRSTStream: parseRSTStreamFrame,
FrameSettings: parseSettingsFrame,
FramePushPromise: parsePushPromise,
FramePing: parsePingFrame,
FrameGoAway: parseGoAwayFrame,
FrameWindowUpdate: parseWindowUpdateFrame,
FrameContinuation: parseContinuationFrame,
var frameParsers = [...]frameParser{
FrameData: parseDataFrame,
FrameHeaders: parseHeadersFrame,
FramePriority: parsePriorityFrame,
FrameRSTStream: parseRSTStreamFrame,
FrameSettings: parseSettingsFrame,
FramePushPromise: parsePushPromise,
FramePing: parsePingFrame,
FrameGoAway: parseGoAwayFrame,
FrameWindowUpdate: parseWindowUpdateFrame,
FrameContinuation: parseContinuationFrame,
FramePriorityUpdate: parsePriorityUpdateFrame,
}
func typeFrameParser(t FrameType) frameParser {
if f := frameParsers[t]; f != nil {
return f
if int(t) < len(frameParsers) {
if f := frameParsers[t]; f != nil {
return f
}
}
return parseUnknownFrame
}
@@ -280,6 +288,8 @@ type Framer struct {
// lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION.
lastHeaderStream uint32
// lastFrameType holds the type of the last frame for verifying frame order.
lastFrameType FrameType
maxReadSize uint32
headerBuf [frameHeaderLen]byte
@@ -347,7 +357,7 @@ func (fr *Framer) maxHeaderListSize() uint32 {
func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) {
// Write the FrameHeader.
f.wbuf = append(f.wbuf[:0],
0, // 3 bytes of length, filled in in endWrite
0, // 3 bytes of length, filled in endWrite
0,
0,
byte(ftype),
@@ -488,30 +498,41 @@ func terminalReadFrameError(err error) bool {
return err != nil
}
// ReadFrame reads a single frame. The returned Frame is only valid
// until the next call to ReadFrame.
// ReadFrameHeader reads the header of the next frame.
// It reads the 9-byte fixed frame header, and does not read any portion of the
// frame payload. The caller is responsible for consuming the payload, either
// with ReadFrameForHeader or directly from the Framer's io.Reader.
//
// If the frame is larger than previously set with SetMaxReadFrameSize, the
// returned error is ErrFrameTooLarge. Other errors may be of type
// ConnectionError, StreamError, or anything else from the underlying
// reader.
// If the frame is larger than previously set with SetMaxReadFrameSize, it
// returns the frame header and ErrFrameTooLarge.
//
// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID
// indicates the stream responsible for the error.
func (fr *Framer) ReadFrame() (Frame, error) {
// If the returned FrameHeader.StreamID is non-zero, it indicates the stream
// responsible for the error.
func (fr *Framer) ReadFrameHeader() (FrameHeader, error) {
fr.errDetail = nil
if fr.lastFrame != nil {
fr.lastFrame.invalidate()
}
fh, err := readFrameHeader(fr.headerBuf[:], fr.r)
if err != nil {
return nil, err
return fh, err
}
if fh.Length > fr.maxReadSize {
if fh == invalidHTTP1LookingFrameHeader() {
return nil, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", err)
return fh, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", ErrFrameTooLarge)
}
return nil, ErrFrameTooLarge
return fh, ErrFrameTooLarge
}
if err := fr.checkFrameOrder(fh); err != nil {
return fh, err
}
return fh, nil
}
// ReadFrameForHeader reads the payload for the frame with the given FrameHeader.
//
// It behaves identically to ReadFrame, other than not checking the maximum
// frame size.
func (fr *Framer) ReadFrameForHeader(fh FrameHeader) (Frame, error) {
if fr.lastFrame != nil {
fr.lastFrame.invalidate()
}
payload := fr.getReadBuf(fh.Length)
if _, err := io.ReadFull(fr.r, payload); err != nil {
@@ -527,9 +548,7 @@ func (fr *Framer) ReadFrame() (Frame, error) {
}
return nil, err
}
if err := fr.checkFrameOrder(f); err != nil {
return nil, err
}
fr.lastFrame = f
if fr.logReads {
fr.debugReadLoggerf("http2: Framer %p: read %v", fr, summarizeFrame(f))
}
@@ -539,6 +558,24 @@ func (fr *Framer) ReadFrame() (Frame, error) {
return f, nil
}
// ReadFrame reads a single frame. The returned Frame is only valid
// until the next call to ReadFrame or ReadFrameBodyForHeader.
//
// If the frame is larger than previously set with SetMaxReadFrameSize, the
// returned error is ErrFrameTooLarge. Other errors may be of type
// ConnectionError, StreamError, or anything else from the underlying
// reader.
//
// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID
// indicates the stream responsible for the error.
func (fr *Framer) ReadFrame() (Frame, error) {
fh, err := fr.ReadFrameHeader()
if err != nil {
return nil, err
}
return fr.ReadFrameForHeader(fh)
}
// connError returns ConnectionError(code) but first
// stashes away a public reason to the caller can optionally relay it
// to the peer before hanging up on them. This might help others debug
@@ -551,20 +588,19 @@ func (fr *Framer) connError(code ErrCode, reason string) error {
// checkFrameOrder reports an error if f is an invalid frame to return
// next from ReadFrame. Mostly it checks whether HEADERS and
// CONTINUATION frames are contiguous.
func (fr *Framer) checkFrameOrder(f Frame) error {
last := fr.lastFrame
fr.lastFrame = f
func (fr *Framer) checkFrameOrder(fh FrameHeader) error {
lastType := fr.lastFrameType
fr.lastFrameType = fh.Type
if fr.AllowIllegalReads {
return nil
}
fh := f.Header()
if fr.lastHeaderStream != 0 {
if fh.Type != FrameContinuation {
return fr.connError(ErrCodeProtocol,
fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
fh.Type, fh.StreamID,
last.Header().Type, fr.lastHeaderStream))
lastType, fr.lastHeaderStream))
}
if fh.StreamID != fr.lastHeaderStream {
return fr.connError(ErrCodeProtocol,
@@ -1152,7 +1188,41 @@ type PriorityFrame struct {
PriorityParam
}
// PriorityParam are the stream prioritzation parameters.
// defaultRFC9218Priority determines what priority we should use as the default
// value.
//
// According to RFC 9218, by default, streams should be given an urgency of 3
// and should be non-incremental. However, making streams non-incremental by
// default would be a huge change to our historical behavior where we would
// round-robin writes across streams. When streams are non-incremental, we
// would process streams of the same urgency one-by-one to completion instead.
//
// To avoid such a sudden change which might break some HTTP/2 users, this
// function allows the caller to specify whether they can actually use the
// default value as specified in RFC 9218. If not, this function will return a
// priority value where streams are incremental by default instead: effectively
// a round-robin between stream of the same urgency.
//
// As an example, a server might not be able to use the RFC 9218 default value
// when it's not sure that the client it is serving is aware of RFC 9218.
func defaultRFC9218Priority(canUseDefault bool) PriorityParam {
if canUseDefault {
return PriorityParam{
urgency: 3,
incremental: 0,
}
}
return PriorityParam{
urgency: 3,
incremental: 1,
}
}
// Note that HTTP/2 has had two different prioritization schemes, and
// PriorityParam struct below is a superset of both schemes. The exported
// symbols are from RFC 7540 and the non-exported ones are from RFC 9218.
// PriorityParam are the stream prioritization parameters.
type PriorityParam struct {
// StreamDep is a 31-bit stream identifier for the
// stream that this stream depends on. Zero means no
@@ -1167,6 +1237,20 @@ type PriorityParam struct {
// the spec, "Add one to the value to obtain a weight between
// 1 and 256."
Weight uint8
// "The urgency (u) parameter value is Integer (see Section 3.3.1 of
// [STRUCTURED-FIELDS]), between 0 and 7 inclusive, in descending order of
// priority. The default is 3."
urgency uint8
// "The incremental (i) parameter value is Boolean (see Section 3.3.6 of
// [STRUCTURED-FIELDS]). It indicates if an HTTP response can be processed
// incrementally, i.e., provide some meaningful output as chunks of the
// response arrive."
//
// We use uint8 (i.e. 0 is false, 1 is true) instead of bool so we can
// avoid unnecessary type conversions and because either type takes 1 byte.
incremental uint8
}
func (p PriorityParam) IsZero() bool {
@@ -1215,6 +1299,74 @@ func (f *Framer) WritePriority(streamID uint32, p PriorityParam) error {
return f.endWrite()
}
// PriorityUpdateFrame is a PRIORITY_UPDATE frame as described in
// https://www.rfc-editor.org/rfc/rfc9218.html#name-the-priority_update-frame.
type PriorityUpdateFrame struct {
FrameHeader
Priority string
PrioritizedStreamID uint32
}
func parseRFC9218Priority(s string, canUseDefault bool) (p PriorityParam, ok bool) {
p = defaultRFC9218Priority(canUseDefault)
ok = httpsfv.ParseDictionary(s, func(key, val, _ string) {
switch key {
case "u":
if u, ok := httpsfv.ParseInteger(val); ok && u >= 0 && u <= 7 {
p.urgency = uint8(u)
}
case "i":
if i, ok := httpsfv.ParseBoolean(val); ok {
if i {
p.incremental = 1
} else {
p.incremental = 0
}
}
}
})
if !ok {
return defaultRFC9218Priority(canUseDefault), ok
}
return p, true
}
func parsePriorityUpdateFrame(_ *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) {
if fh.StreamID != 0 {
countError("frame_priority_update_non_zero_stream")
return nil, connError{ErrCodeProtocol, "PRIORITY_UPDATE frame with non-zero stream ID"}
}
if len(payload) < 4 {
countError("frame_priority_update_bad_length")
return nil, connError{ErrCodeFrameSize, fmt.Sprintf("PRIORITY_UPDATE frame payload size was %d; want at least 4", len(payload))}
}
v := binary.BigEndian.Uint32(payload[:4])
streamID := v & 0x7fffffff // mask off high bit
if streamID == 0 {
countError("frame_priority_update_prioritizing_zero_stream")
return nil, connError{ErrCodeProtocol, "PRIORITY_UPDATE frame with prioritized stream ID of zero"}
}
return &PriorityUpdateFrame{
FrameHeader: fh,
PrioritizedStreamID: streamID,
Priority: string(payload[4:]),
}, nil
}
// WritePriorityUpdate writes a PRIORITY_UPDATE frame.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *Framer) WritePriorityUpdate(streamID uint32, priority string) error {
if !validStreamID(streamID) && !f.AllowIllegalWrites {
return errStreamID
}
f.startWrite(FramePriorityUpdate, 0, 0)
f.writeUint32(streamID)
f.writeBytes([]byte(priority))
return f.endWrite()
}
// A RSTStreamFrame allows for abnormal termination of a stream.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.4
type RSTStreamFrame struct {
@@ -1496,6 +1648,23 @@ func (mh *MetaHeadersFrame) PseudoFields() []hpack.HeaderField {
return mh.Fields
}
func (mh *MetaHeadersFrame) rfc9218Priority(priorityAware bool) (p PriorityParam, priorityAwareAfter, hasIntermediary bool) {
var s string
for _, field := range mh.Fields {
if field.Name == "priority" {
s = field.Value
priorityAware = true
}
if slices.Contains([]string{"via", "forwarded", "x-forwarded-for"}, field.Name) {
hasIntermediary = true
}
}
// No need to check for ok. parseRFC9218Priority will return a default
// value if there is no priority field or if the field cannot be parsed.
p, _ = parseRFC9218Priority(s, priorityAware && !hasIntermediary)
return p, priorityAware, hasIntermediary
}
func (mh *MetaHeadersFrame) checkPseudos() error {
var isRequest, isResponse bool
pf := mh.PseudoFields()
+14 -3
View File
@@ -15,21 +15,32 @@ import (
"runtime"
"strconv"
"sync"
"sync/atomic"
)
var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
// Setting DebugGoroutines to false during a test to disable goroutine debugging
// results in race detector complaints when a test leaves goroutines running before
// returning. Tests shouldn't do this, of course, but when they do it generally shows
// up as infrequent, hard-to-debug flakes. (See #66519.)
//
// Disable goroutine debugging during individual tests with an atomic bool.
// (Note that it's safe to enable/disable debugging mid-test, so the actual race condition
// here is harmless.)
var disableDebugGoroutines atomic.Bool
type goroutineLock uint64
func newGoroutineLock() goroutineLock {
if !DebugGoroutines {
if !DebugGoroutines || disableDebugGoroutines.Load() {
return 0
}
return goroutineLock(curGoroutineID())
}
func (g goroutineLock) check() {
if !DebugGoroutines {
if !DebugGoroutines || disableDebugGoroutines.Load() {
return
}
if curGoroutineID() != uint64(g) {
@@ -38,7 +49,7 @@ func (g goroutineLock) check() {
}
func (g goroutineLock) checkNotOn() {
if !DebugGoroutines {
if !DebugGoroutines || disableDebugGoroutines.Load() {
return
}
if curGoroutineID() == uint64(g) {
+10 -3
View File
@@ -6,6 +6,7 @@ package hpack
import (
"fmt"
"strings"
)
// headerFieldTable implements a list of HeaderFields.
@@ -54,10 +55,16 @@ func (t *headerFieldTable) len() int {
// addEntry adds a new entry.
func (t *headerFieldTable) addEntry(f HeaderField) {
// Prevent f from escaping to the heap.
f2 := HeaderField{
Name: strings.Clone(f.Name),
Value: strings.Clone(f.Value),
Sensitive: f.Sensitive,
}
id := uint64(t.len()) + t.evictCount + 1
t.byName[f.Name] = id
t.byNameValue[pairNameValue{f.Name, f.Value}] = id
t.ents = append(t.ents, f)
t.byName[f2.Name] = id
t.byNameValue[pairNameValue{f2.Name, f2.Value}] = id
t.ents = append(t.ents, f2)
}
// evictOldest evicts the n oldest entries in the table.
+18 -35
View File
@@ -4,20 +4,21 @@
// Package http2 implements the HTTP/2 protocol.
//
// This package is low-level and intended to be used directly by very
// few people. Most users will use it indirectly through the automatic
// use by the net/http package (from Go 1.6 and later).
// For use in earlier Go versions see ConfigureServer. (Transport support
// requires Go 1.6 or later)
// Almost no users should need to import this package directly.
// The net/http package supports HTTP/2 natively.
//
// See https://http2.github.io/ for more information on HTTP/2.
// To enable or disable HTTP/2 support in net/http clients and servers, see
// [http.Transport.Protocols] and [http.Server.Protocols].
//
// See https://http2.golang.org/ for a test server running this code.
// To configure HTTP/2 parameters, see
// [http.Transport.HTTP2] and [http.Server.HTTP2].
//
// To create HTTP/1 or HTTP/2 connections, see
// [http.Transport.NewClientConn].
package http2 // import "golang.org/x/net/http2"
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
@@ -37,7 +38,6 @@ var (
VerboseLogs bool
logFrameWrites bool
logFrameReads bool
inTests bool
// Enabling extended CONNECT by causes browsers to attempt to use
// WebSockets-over-HTTP/2. This results in problems when the server's websocket
@@ -173,6 +173,7 @@ const (
SettingMaxFrameSize SettingID = 0x5
SettingMaxHeaderListSize SettingID = 0x6
SettingEnableConnectProtocol SettingID = 0x8
SettingNoRFC7540Priorities SettingID = 0x9
)
var settingName = map[SettingID]string{
@@ -183,6 +184,7 @@ var settingName = map[SettingID]string{
SettingMaxFrameSize: "MAX_FRAME_SIZE",
SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
SettingEnableConnectProtocol: "ENABLE_CONNECT_PROTOCOL",
SettingNoRFC7540Priorities: "NO_RFC7540_PRIORITIES",
}
func (s SettingID) String() string {
@@ -257,15 +259,13 @@ func (cw closeWaiter) Wait() {
// idle memory usage with many connections.
type bufferedWriter struct {
_ incomparable
group synctestGroupInterface // immutable
conn net.Conn // immutable
bw *bufio.Writer // non-nil when data is buffered
byteTimeout time.Duration // immutable, WriteByteTimeout
conn net.Conn // immutable
bw *bufio.Writer // non-nil when data is buffered
byteTimeout time.Duration // immutable, WriteByteTimeout
}
func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter {
func newBufferedWriter(conn net.Conn, timeout time.Duration) *bufferedWriter {
return &bufferedWriter{
group: group,
conn: conn,
byteTimeout: timeout,
}
@@ -316,24 +316,18 @@ func (w *bufferedWriter) Flush() error {
type bufferedWriterTimeoutWriter bufferedWriter
func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p)
return writeWithByteTimeout(w.conn, w.byteTimeout, p)
}
// writeWithByteTimeout writes to conn.
// If more than timeout passes without any bytes being written to the connection,
// the write fails.
func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
func writeWithByteTimeout(conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
if timeout <= 0 {
return conn.Write(p)
}
for {
var now time.Time
if group == nil {
now = time.Now()
} else {
now = group.Now()
}
conn.SetWriteDeadline(now.Add(timeout))
conn.SetWriteDeadline(time.Now().Add(timeout))
nn, err := conn.Write(p[n:])
n += nn
if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
@@ -419,14 +413,3 @@ func (s *sorter) SortStrings(ss []string) {
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).
type incomparable [0]func()
// synctestGroupInterface is the methods of synctestGroup used by Server and Transport.
// It's defined as an interface here to let us keep synctestGroup entirely test-only
// and not a part of non-test builds.
type synctestGroupInterface interface {
Join()
Now() time.Time
NewTimer(d time.Duration) timer
AfterFunc(d time.Duration, f func()) timer
ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc)
}
+147 -84
View File
@@ -164,6 +164,8 @@ type Server struct {
// NewWriteScheduler constructs a write scheduler for a connection.
// If nil, a default scheduler is chosen.
//
// Deprecated: User-provided write schedulers are deprecated.
NewWriteScheduler func() WriteScheduler
// CountError, if non-nil, is called on HTTP/2 server errors.
@@ -176,44 +178,15 @@ type Server struct {
// so that we don't embed a Mutex in this struct, which will make the
// struct non-copyable, which might break some callers.
state *serverInternalState
// Synchronization group used for testing.
// Outside of tests, this is nil.
group synctestGroupInterface
}
func (s *Server) markNewGoroutine() {
if s.group != nil {
s.group.Join()
}
}
func (s *Server) now() time.Time {
if s.group != nil {
return s.group.Now()
}
return time.Now()
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (s *Server) newTimer(d time.Duration) timer {
if s.group != nil {
return s.group.NewTimer(d)
}
return timeTimer{time.NewTimer(d)}
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (s *Server) afterFunc(d time.Duration, f func()) timer {
if s.group != nil {
return s.group.AfterFunc(d, f)
}
return timeTimer{time.AfterFunc(d, f)}
}
type serverInternalState struct {
mu sync.Mutex
activeConns map[*serverConn]struct{}
// Pool of error channels. This is per-Server rather than global
// because channels can't be reused across synctest bubbles.
errChanPool sync.Pool
}
func (s *serverInternalState) registerConn(sc *serverConn) {
@@ -245,6 +218,27 @@ func (s *serverInternalState) startGracefulShutdown() {
s.mu.Unlock()
}
// Global error channel pool used for uninitialized Servers.
// We use a per-Server pool when possible to avoid using channels across synctest bubbles.
var errChanPool = sync.Pool{
New: func() any { return make(chan error, 1) },
}
func (s *serverInternalState) getErrChan() chan error {
if s == nil {
return errChanPool.Get().(chan error) // Server used without calling ConfigureServer
}
return s.errChanPool.Get().(chan error)
}
func (s *serverInternalState) putErrChan(ch chan error) {
if s == nil {
errChanPool.Put(ch) // Server used without calling ConfigureServer
return
}
s.errChanPool.Put(ch)
}
// ConfigureServer adds HTTP/2 support to a net/http Server.
//
// The configuration conf may be nil.
@@ -257,7 +251,10 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if conf == nil {
conf = new(Server)
}
conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})}
conf.state = &serverInternalState{
activeConns: make(map[*serverConn]struct{}),
errChanPool: sync.Pool{New: func() any { return make(chan error, 1) }},
}
if h1, h2 := s, conf; h2.IdleTimeout == 0 {
if h1.IdleTimeout != 0 {
h2.IdleTimeout = h1.IdleTimeout
@@ -423,6 +420,9 @@ func (o *ServeConnOpts) handler() http.Handler {
//
// The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
if opts == nil {
opts = &ServeConnOpts{}
}
s.serveConn(c, opts, nil)
}
@@ -438,7 +438,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
conn: c,
baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(),
bw: newBufferedWriter(s.group, c, conf.WriteByteTimeout),
bw: newBufferedWriter(c, conf.WriteByteTimeout),
handler: opts.handler(),
streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult),
@@ -474,10 +474,13 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
sc.conn.SetWriteDeadline(time.Time{})
}
if s.NewWriteScheduler != nil {
switch {
case s.NewWriteScheduler != nil:
sc.writeSched = s.NewWriteScheduler()
} else {
case clientPriorityDisabled(http1srv):
sc.writeSched = newRoundRobinWriteScheduler()
default:
sc.writeSched = newPriorityWriteSchedulerRFC9218()
}
// These start at the RFC-specified defaults. If there is a higher
@@ -638,11 +641,11 @@ type serverConn struct {
pingSent bool
sentPingData [8]byte
goAwayCode ErrCode
shutdownTimer timer // nil until used
idleTimer timer // nil if unused
shutdownTimer *time.Timer // nil until used
idleTimer *time.Timer // nil if unused
readIdleTimeout time.Duration
pingTimeout time.Duration
readIdleTimer timer // nil if unused
readIdleTimer *time.Timer // nil if unused
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
@@ -650,6 +653,23 @@ type serverConn struct {
// Used by startGracefulShutdown.
shutdownOnce sync.Once
// Used for RFC 9218 prioritization.
hasIntermediary bool // connection is done via an intermediary / proxy
priorityAware bool // the client has sent priority signal, meaning that it is aware of it.
}
func (sc *serverConn) writeSchedIgnoresRFC7540() bool {
switch sc.writeSched.(type) {
case *priorityWriteSchedulerRFC9218:
return true
case *randomWriteScheduler:
return true
case *roundRobinWriteScheduler:
return true
default:
return false
}
}
func (sc *serverConn) maxHeaderListSize() uint32 {
@@ -687,12 +707,12 @@ type stream struct {
flow outflow // limits writing from Handler to client
inflow inflow // what the client is allowed to POST/etc to us
state streamState
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline timer // nil if unused
writeDeadline timer // nil if unused
closeErr error // set before cw is closed
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline *time.Timer // nil if unused
writeDeadline *time.Timer // nil if unused
closeErr error // set before cw is closed
trailer http.Header // accumulated trailers
reqTrailer http.Header // handler's Request.Trailer
@@ -848,7 +868,6 @@ type readFrameResult struct {
// consumer is done with the frame.
// It's run on its own goroutine.
func (sc *serverConn) readFrames() {
sc.srv.markNewGoroutine()
gate := make(chan struct{})
gateDone := func() { gate <- struct{}{} }
for {
@@ -881,7 +900,6 @@ type frameWriteResult struct {
// At most one goroutine can be running writeFrameAsync at a time per
// serverConn.
func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) {
sc.srv.markNewGoroutine()
var err error
if wd == nil {
err = wr.write.writeFrame(sc)
@@ -942,6 +960,9 @@ func (sc *serverConn) serve(conf http2Config) {
if !disableExtendedConnectProtocol {
settings = append(settings, Setting{SettingEnableConnectProtocol, 1})
}
if sc.writeSchedIgnoresRFC7540() {
settings = append(settings, Setting{SettingNoRFC7540Priorities, 1})
}
sc.writeFrame(FrameWriteRequest{
write: settings,
})
@@ -965,22 +986,22 @@ func (sc *serverConn) serve(conf http2Config) {
sc.setConnState(http.StateIdle)
if sc.srv.IdleTimeout > 0 {
sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop()
}
if conf.SendPingTimeout > 0 {
sc.readIdleTimeout = conf.SendPingTimeout
sc.readIdleTimer = sc.srv.afterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
sc.readIdleTimer = time.AfterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
defer sc.readIdleTimer.Stop()
}
go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer)
settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop()
lastFrameTime := sc.srv.now()
lastFrameTime := time.Now()
loopNum := 0
for {
loopNum++
@@ -994,7 +1015,7 @@ func (sc *serverConn) serve(conf http2Config) {
case res := <-sc.wroteFrameCh:
sc.wroteFrame(res)
case res := <-sc.readFrameCh:
lastFrameTime = sc.srv.now()
lastFrameTime = time.Now()
// Process any written frames before reading new frames from the client since a
// written frame could have triggered a new stream to be started.
if sc.writingFrameAsync {
@@ -1077,7 +1098,7 @@ func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
}
pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
now := sc.srv.now()
now := time.Now()
if pingAt.After(now) {
// We received frames since arming the ping timer.
// Reset it for the next possible timeout.
@@ -1141,10 +1162,10 @@ func (sc *serverConn) readPreface() error {
errc <- nil
}
}()
timer := sc.srv.newTimer(prefaceTimeout) // TODO: configurable on *Server?
timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server?
defer timer.Stop()
select {
case <-timer.C():
case <-timer.C:
return errPrefaceTimeout
case err := <-errc:
if err == nil {
@@ -1156,10 +1177,6 @@ func (sc *serverConn) readPreface() error {
}
}
var errChanPool = sync.Pool{
New: func() interface{} { return make(chan error, 1) },
}
var writeDataPool = sync.Pool{
New: func() interface{} { return new(writeData) },
}
@@ -1167,7 +1184,7 @@ var writeDataPool = sync.Pool{
// writeDataFromHandler writes DATA response frames from a handler on
// the given stream.
func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error {
ch := errChanPool.Get().(chan error)
ch := sc.srv.state.getErrChan()
writeArg := writeDataPool.Get().(*writeData)
*writeArg = writeData{stream.id, data, endStream}
err := sc.writeFrameFromHandler(FrameWriteRequest{
@@ -1199,7 +1216,7 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStrea
return errStreamClosed
}
}
errChanPool.Put(ch)
sc.srv.state.putErrChan(ch)
if frameWriteDone {
writeDataPool.Put(writeArg)
}
@@ -1513,7 +1530,7 @@ func (sc *serverConn) goAway(code ErrCode) {
func (sc *serverConn) shutDownIn(d time.Duration) {
sc.serveG.check()
sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer)
sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer)
}
func (sc *serverConn) resetStream(se StreamError) {
@@ -1631,6 +1648,8 @@ func (sc *serverConn) processFrame(f Frame) error {
// A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE
// frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
return sc.countError("push_promise", ConnectionError(ErrCodeProtocol))
case *PriorityUpdateFrame:
return sc.processPriorityUpdate(f)
default:
sc.vlogf("http2: server ignoring frame: %v", f.Header())
return nil
@@ -1811,6 +1830,10 @@ func (sc *serverConn) processSetting(s Setting) error {
case SettingEnableConnectProtocol:
// Receipt of this parameter by a server does not
// have any impact
case SettingNoRFC7540Priorities:
if s.Val > 1 {
return ConnectionError(ErrCodeProtocol)
}
default:
// Unknown setting: "An endpoint that receives a SETTINGS
// frame with any unknown or unsupported identifier MUST
@@ -2081,13 +2104,33 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
if f.StreamEnded() {
initialState = stateHalfClosedRemote
}
st := sc.newStream(id, 0, initialState)
// We are handling two special cases here:
// 1. When a request is sent via an intermediary, we force priority to be
// u=3,i. This is essentially a round-robin behavior, and is done to ensure
// fairness between, for example, multiple clients using the same proxy.
// 2. Until a client has shown that it is aware of RFC 9218, we make its
// streams non-incremental by default. This is done to preserve the
// historical behavior of handling streams in a round-robin manner, rather
// than one-by-one to completion.
initialPriority := defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary)
if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); ok && !sc.hasIntermediary {
headerPriority, priorityAware, hasIntermediary := f.rfc9218Priority(sc.priorityAware)
initialPriority = headerPriority
sc.hasIntermediary = hasIntermediary
if priorityAware {
sc.priorityAware = true
}
}
st := sc.newStream(id, 0, initialState, initialPriority)
if f.HasPriority() {
if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
return err
}
sc.writeSched.AdjustStream(st.id, f.Priority)
if !sc.writeSchedIgnoresRFC7540() {
sc.writeSched.AdjustStream(st.id, f.Priority)
}
}
rw, req, err := sc.newWriterAndRequest(st, f)
@@ -2118,7 +2161,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
return sc.scheduleHandler(id, rw, req, handler)
@@ -2128,7 +2171,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
sc.serveG.check()
id := uint32(1)
sc.maxClientStreamID = id
st := sc.newStream(id, 0, stateHalfClosedRemote)
st := sc.newStream(id, 0, stateHalfClosedRemote, defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary))
st.reqTrailer = req.Trailer
if st.reqTrailer != nil {
st.trailer = make(http.Header)
@@ -2193,11 +2236,32 @@ func (sc *serverConn) processPriority(f *PriorityFrame) error {
if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
return err
}
// We need to avoid calling AdjustStream when using the RFC 9218 write
// scheduler. Otherwise, incremental's zero value in PriorityParam will
// unexpectedly make all streams non-incremental. This causes us to process
// streams one-by-one to completion rather than doing it in a round-robin
// manner (the historical behavior), which might be unexpected to users.
if sc.writeSchedIgnoresRFC7540() {
return nil
}
sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam)
return nil
}
func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream {
func (sc *serverConn) processPriorityUpdate(f *PriorityUpdateFrame) error {
sc.priorityAware = true
if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); !ok {
return nil
}
p, ok := parseRFC9218Priority(f.Priority, sc.priorityAware)
if !ok {
return sc.countError("unparsable_priority_update", streamError(f.PrioritizedStreamID, ErrCodeProtocol))
}
sc.writeSched.AdjustStream(f.PrioritizedStreamID, p)
return nil
}
func (sc *serverConn) newStream(id, pusherID uint32, state streamState, priority PriorityParam) *stream {
sc.serveG.check()
if id == 0 {
panic("internal error: cannot create stream with id 0")
@@ -2216,11 +2280,11 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.initialStreamRecvWindowSize)
if sc.hs.WriteTimeout > 0 {
st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
sc.streams[id] = st
sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID})
sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID, priority: priority})
if st.isPushed() {
sc.curPushedStreams++
} else {
@@ -2405,7 +2469,6 @@ func (sc *serverConn) handlerDone() {
// Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
sc.srv.markNewGoroutine()
defer sc.sendServeMsg(handlerDoneMsg)
didPanic := true
defer func() {
@@ -2454,7 +2517,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro
// waiting for this frame to be written, so an http.Flush mid-handler
// writes out the correct value of keys, before a handler later potentially
// mutates it.
errc = errChanPool.Get().(chan error)
errc = sc.srv.state.getErrChan()
}
if err := sc.writeFrameFromHandler(FrameWriteRequest{
write: headerData,
@@ -2466,7 +2529,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro
if errc != nil {
select {
case err := <-errc:
errChanPool.Put(errc)
sc.srv.state.putErrChan(errc)
return err
case <-sc.doneServing:
return errClientDisconnected
@@ -2573,7 +2636,7 @@ func (b *requestBody) Read(p []byte) (n int, err error) {
if err == io.EOF {
b.sawEOF = true
}
if b.conn == nil && inTests {
if b.conn == nil {
return
}
b.conn.noteBodyReadFromHandler(b.stream, n, err)
@@ -2702,7 +2765,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
var date string
if _, ok := rws.snapHeader["Date"]; !ok {
// TODO(bradfitz): be faster here, like net/http? measure.
date = rws.conn.srv.now().UTC().Format(http.TimeFormat)
date = time.Now().UTC().Format(http.TimeFormat)
}
for _, v := range rws.snapHeader["Trailer"] {
@@ -2824,7 +2887,7 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() {
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
if !deadline.IsZero() && deadline.Before(time.Now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onReadTimeout()
@@ -2840,9 +2903,9 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.readDeadline = nil
} else if st.readDeadline == nil {
st.readDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onReadTimeout)
st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
} else {
st.readDeadline.Reset(deadline.Sub(sc.srv.now()))
st.readDeadline.Reset(deadline.Sub(time.Now()))
}
})
return nil
@@ -2850,7 +2913,7 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
if !deadline.IsZero() && deadline.Before(time.Now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onWriteTimeout()
@@ -2866,9 +2929,9 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.writeDeadline = nil
} else if st.writeDeadline == nil {
st.writeDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onWriteTimeout)
st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
} else {
st.writeDeadline.Reset(deadline.Sub(sc.srv.now()))
st.writeDeadline.Reset(deadline.Sub(time.Now()))
}
})
return nil
@@ -3147,7 +3210,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error {
method: opts.Method,
url: u,
header: cloneHeader(opts.Header),
done: errChanPool.Get().(chan error),
done: sc.srv.state.getErrChan(),
}
select {
@@ -3164,7 +3227,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error {
case <-st.cw:
return errStreamClosed
case err := <-msg.done:
errChanPool.Put(msg.done)
sc.srv.state.putErrChan(msg.done)
return err
}
}
@@ -3227,7 +3290,7 @@ func (sc *serverConn) startPush(msg *startPushRequest) {
// transition to "half closed (remote)" after sending the initial HEADERS, but
// we start in "half closed (remote)" for simplicity.
// See further comments at the definition of stateHalfClosedRemote.
promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote)
promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote, defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary))
rw, req, err := sc.newWriterAndRequestNoBody(promised, httpcommon.ServerRequestParam{
Method: msg.method,
Scheme: msg.url.Scheme,
-20
View File
@@ -1,20 +0,0 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import "time"
// A timer is a time.Timer, as an interface which can be replaced in tests.
type timer = interface {
C() <-chan time.Time
Reset(d time.Duration) bool
Stop() bool
}
// timeTimer adapts a time.Timer to the timer interface.
type timeTimer struct {
*time.Timer
}
func (t timeTimer) C() <-chan time.Time { return t.Timer.C }
+266 -120
View File
@@ -9,6 +9,7 @@ package http2
import (
"bufio"
"bytes"
"compress/flate"
"compress/gzip"
"context"
"crypto/rand"
@@ -193,50 +194,6 @@ type Transport struct {
type transportTestHooks struct {
newclientconn func(*ClientConn)
group synctestGroupInterface
}
func (t *Transport) markNewGoroutine() {
if t != nil && t.transportTestHooks != nil {
t.transportTestHooks.group.Join()
}
}
func (t *Transport) now() time.Time {
if t != nil && t.transportTestHooks != nil {
return t.transportTestHooks.group.Now()
}
return time.Now()
}
func (t *Transport) timeSince(when time.Time) time.Duration {
if t != nil && t.transportTestHooks != nil {
return t.now().Sub(when)
}
return time.Since(when)
}
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (t *Transport) newTimer(d time.Duration) timer {
if t.transportTestHooks != nil {
return t.transportTestHooks.group.NewTimer(d)
}
return timeTimer{time.NewTimer(d)}
}
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
func (t *Transport) afterFunc(d time.Duration, f func()) timer {
if t.transportTestHooks != nil {
return t.transportTestHooks.group.AfterFunc(d, f)
}
return timeTimer{time.AfterFunc(d, f)}
}
func (t *Transport) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
if t.transportTestHooks != nil {
return t.transportTestHooks.group.ContextWithTimeout(ctx, d)
}
return context.WithTimeout(ctx, d)
}
func (t *Transport) maxHeaderListSize() uint32 {
@@ -366,7 +323,7 @@ type ClientConn struct {
readerErr error // set before readerDone is closed
idleTimeout time.Duration // or 0 for never
idleTimer timer
idleTimer *time.Timer
mu sync.Mutex // guards following
cond *sync.Cond // hold mu; broadcast on flow/closed changes
@@ -399,6 +356,7 @@ type ClientConn struct {
readIdleTimeout time.Duration
pingTimeout time.Duration
extendedConnectAllowed bool
strictMaxConcurrentStreams bool
// rstStreamPingsBlocked works around an unfortunate gRPC behavior.
// gRPC strictly limits the number of PING frames that it will receive.
@@ -418,11 +376,24 @@ type ClientConn struct {
// completely unresponsive connection.
pendingResets int
// readBeforeStreamID is the smallest stream ID that has not been followed by
// a frame read from the peer. We use this to determine when a request may
// have been sent to a completely unresponsive connection:
// If the request ID is less than readBeforeStreamID, then we have had some
// indication of life on the connection since sending the request.
readBeforeStreamID uint32
// reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
// Write to reqHeaderMu to lock it, read from it to unlock.
// Lock reqmu BEFORE mu or wmu.
reqHeaderMu chan struct{}
// internalStateHook reports state changes back to the net/http.ClientConn.
// Note that this is different from the user state hook registered by
// net/http.ClientConn.SetStateHook: The internal hook calls ClientConn,
// which calls the user hook.
internalStateHook func()
// wmu is held while writing.
// Acquire BEFORE mu when holding both, to avoid blocking mu on network writes.
// Only acquire both at the same time when changing peer settings.
@@ -534,14 +505,12 @@ func (cs *clientStream) closeReqBodyLocked() {
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
go func() {
cs.cc.t.markNewGoroutine()
cs.reqBody.Close()
close(reqBodyClosed)
}()
}
type stickyErrWriter struct {
group synctestGroupInterface
conn net.Conn
timeout time.Duration
err *error
@@ -551,7 +520,7 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p)
n, err = writeWithByteTimeout(sew.conn, sew.timeout, p)
*sew.err = err
return n, err
}
@@ -650,9 +619,9 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
tm := t.newTimer(d)
tm := time.NewTimer(d)
select {
case <-tm.C():
case <-tm.C:
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
continue
case <-req.Context().Done():
@@ -699,6 +668,7 @@ var (
errClientConnUnusable = errors.New("http2: client conn not usable")
errClientConnNotEstablished = errors.New("http2: client conn could not be established")
errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
errClientConnForceClosed = errors.New("http2: client connection force closed via ClientConn.Close")
)
// shouldRetryRequest is called by RoundTrip when a request fails to get
@@ -742,19 +712,12 @@ func canRetryError(err error) bool {
return true
}
if se, ok := err.(StreamError); ok {
if se.Code == ErrCodeProtocol && se.Cause == errFromPeer {
// See golang/go#47635, golang/go#42777
return true
}
return se.Code == ErrCodeRefusedStream
}
return false
}
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
if t.transportTestHooks != nil {
return t.newClientConn(nil, singleUse)
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
@@ -763,7 +726,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b
if err != nil {
return nil, err
}
return t.newClientConn(tconn, singleUse)
return t.newClientConn(tconn, singleUse, nil)
}
func (t *Transport) newTLSConfig(host string) *tls.Config {
@@ -815,10 +778,10 @@ func (t *Transport) expectContinueTimeout() time.Duration {
}
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives())
return t.newClientConn(c, t.disableKeepAlives(), nil)
}
func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
func (t *Transport) newClientConn(c net.Conn, singleUse bool, internalStateHook func()) (*ClientConn, error) {
conf := configFromTransport(t)
cc := &ClientConn{
t: t,
@@ -829,7 +792,8 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
initialWindowSize: 65535, // spec default
initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
maxConcurrentStreams: initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
strictMaxConcurrentStreams: conf.StrictMaxConcurrentRequests,
peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
streams: make(map[uint32]*clientStream),
singleUse: singleUse,
seenSettingsChan: make(chan struct{}),
@@ -838,14 +802,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
pingTimeout: conf.PingTimeout,
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
lastActive: t.now(),
lastActive: time.Now(),
internalStateHook: internalStateHook,
}
var group synctestGroupInterface
if t.transportTestHooks != nil {
t.markNewGoroutine()
t.transportTestHooks.newclientconn(cc)
c = cc.tconn
group = t.group
}
if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -857,7 +819,6 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
// TODO: adjust this writer size to account for frame size +
// MTU + crypto/tls record padding.
cc.bw = bufio.NewWriter(stickyErrWriter{
group: group,
conn: c,
timeout: conf.WriteByteTimeout,
err: &cc.werr,
@@ -906,7 +867,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
// Start the idle timer after the connection is fully initialized.
if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d
cc.idleTimer = t.afterFunc(d, cc.onIdleTimeout)
cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout)
}
go cc.readLoop()
@@ -917,7 +878,7 @@ func (cc *ClientConn) healthCheck() {
pingTimeout := cc.pingTimeout
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout)
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
@@ -1067,7 +1028,7 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
return
}
var maxConcurrentOkay bool
if cc.t.StrictMaxConcurrentStreams {
if cc.strictMaxConcurrentStreams {
// We'll tell the caller we can take a new request to
// prevent the caller from dialing a new TCP
// connection, but then we'll block later before
@@ -1083,10 +1044,7 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
maxConcurrentOkay = cc.currentRequestCountLocked() < int(cc.maxConcurrentStreams)
}
st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
!cc.doNotReuse &&
int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
!cc.tooIdleLocked()
st.canTakeNewRequest = maxConcurrentOkay && cc.isUsableLocked()
// If this connection has never been used for a request and is closed,
// then let it take a request (which will fail).
@@ -1102,6 +1060,31 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
return
}
func (cc *ClientConn) isUsableLocked() bool {
return cc.goAway == nil &&
!cc.closed &&
!cc.closing &&
!cc.doNotReuse &&
int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
!cc.tooIdleLocked()
}
// canReserveLocked reports whether a net/http.ClientConn can reserve a slot on this conn.
//
// This follows slightly different rules than clientConnIdleState.canTakeNewRequest.
// We only permit reservations up to the conn's concurrency limit.
// This differs from ClientConn.ReserveNewRequest, which permits reservations
// past the limit when StrictMaxConcurrentStreams is set.
func (cc *ClientConn) canReserveLocked() bool {
if cc.currentRequestCountLocked() >= int(cc.maxConcurrentStreams) {
return false
}
if !cc.isUsableLocked() {
return false
}
return true
}
// currentRequestCountLocked reports the number of concurrency slots currently in use,
// including active streams, reserved slots, and reset streams waiting for acknowledgement.
func (cc *ClientConn) currentRequestCountLocked() int {
@@ -1113,6 +1096,14 @@ func (cc *ClientConn) canTakeNewRequestLocked() bool {
return st.canTakeNewRequest
}
// availableLocked reports the number of concurrency slots available.
func (cc *ClientConn) availableLocked() int {
if !cc.canTakeNewRequestLocked() {
return 0
}
return max(0, int(cc.maxConcurrentStreams)-cc.currentRequestCountLocked())
}
// tooIdleLocked reports whether this connection has been been sitting idle
// for too much wall time.
func (cc *ClientConn) tooIdleLocked() bool {
@@ -1120,7 +1111,7 @@ func (cc *ClientConn) tooIdleLocked() bool {
// times are compared based on their wall time. We don't want
// to reuse a connection that's been sitting idle during
// VM/laptop suspend if monotonic time was also frozen.
return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && cc.t.timeSince(cc.lastIdle.Round(0)) > cc.idleTimeout
return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout
}
// onIdleTimeout is called from a time.AfterFunc goroutine. It will
@@ -1137,6 +1128,7 @@ func (cc *ClientConn) closeConn() {
t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn)
defer t.Stop()
cc.tconn.Close()
cc.maybeCallStateHook()
}
// A tls.Conn.Close can hang for a long time if the peer is unresponsive.
@@ -1186,7 +1178,6 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
done := make(chan struct{})
cancelled := false // guarded by cc.mu
go func() {
cc.t.markNewGoroutine()
cc.mu.Lock()
defer cc.mu.Unlock()
for {
@@ -1257,8 +1248,7 @@ func (cc *ClientConn) closeForError(err error) {
//
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
func (cc *ClientConn) Close() error {
err := errors.New("http2: client connection force closed via ClientConn.Close")
cc.closeForError(err)
cc.closeForError(errClientConnForceClosed)
return nil
}
@@ -1427,7 +1417,6 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream))
//
// It sends the request and performs post-request cleanup (closing Request.Body, etc.).
func (cs *clientStream) doRequest(req *http.Request, streamf func(*clientStream)) {
cs.cc.t.markNewGoroutine()
err := cs.writeRequest(req, streamf)
cs.cleanupWriteRequest(err)
}
@@ -1558,9 +1547,9 @@ func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStre
var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 {
timer := cc.t.newTimer(d)
timer := time.NewTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C()
respHeaderTimer = timer.C
respHeaderRecv = cs.respHeaderRecv
}
// Wait until the peer half-closes its end of the stream,
@@ -1665,6 +1654,8 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
}
bodyClosed := cs.reqBodyClosed
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
// Have we read any frames from the connection since sending this request?
readSinceStream := cc.readBeforeStreamID > cs.ID
cc.mu.Unlock()
if mustCloseBody {
cs.reqBody.Close()
@@ -1696,8 +1687,10 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
//
// This could be due to the server becoming unresponsive.
// To avoid sending too many requests on a dead connection,
// we let the request continue to consume a concurrency slot
// until we can confirm the server is still responding.
// if we haven't read any frames from the connection since
// sending this request, we let it continue to consume
// a concurrency slot until we can confirm the server is
// still responding.
// We do this by sending a PING frame along with the RST_STREAM
// (unless a ping is already in flight).
//
@@ -1708,7 +1701,7 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
// because it's short lived and will probably be closed before
// we get the ping response.
ping := false
if !closeOnIdle {
if !closeOnIdle && !readSinceStream {
cc.mu.Lock()
// rstStreamPingsBlocked works around a gRPC behavior:
// see comment on the field for details.
@@ -1742,6 +1735,7 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
}
close(cs.donec)
cc.maybeCallStateHook()
}
// awaitOpenSlotForStreamLocked waits until len(streams) < maxConcurrentStreams.
@@ -1753,7 +1747,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
// Return a fatal error which aborts the retry loop.
return errClientConnNotEstablished
}
cc.lastActive = cc.t.now()
cc.lastActive = time.Now()
if cc.closed || !cc.canTakeNewRequestLocked() {
return errClientConnUnusable
}
@@ -2092,10 +2086,10 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
if len(cc.streams) != slen-1 {
panic("forgetting unknown stream id")
}
cc.lastActive = cc.t.now()
cc.lastActive = time.Now()
if len(cc.streams) == 0 && cc.idleTimer != nil {
cc.idleTimer.Reset(cc.idleTimeout)
cc.lastIdle = cc.t.now()
cc.lastIdle = time.Now()
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
@@ -2121,7 +2115,6 @@ type clientConnReadLoop struct {
// readLoop runs in its own goroutine and reads and dispatches frames.
func (cc *ClientConn) readLoop() {
cc.t.markNewGoroutine()
rl := &clientConnReadLoop{cc: cc}
defer rl.cleanup()
cc.readerErr = rl.run()
@@ -2188,9 +2181,9 @@ func (rl *clientConnReadLoop) cleanup() {
if cc.idleTimeout > 0 && unusedWaitTime > cc.idleTimeout {
unusedWaitTime = cc.idleTimeout
}
idleTime := cc.t.now().Sub(cc.lastActive)
idleTime := time.Now().Sub(cc.lastActive)
if atomic.LoadUint32(&cc.atomicReused) == 0 && idleTime < unusedWaitTime && !cc.closedOnIdle {
cc.idleTimer = cc.t.afterFunc(unusedWaitTime-idleTime, func() {
cc.idleTimer = time.AfterFunc(unusedWaitTime-idleTime, func() {
cc.t.connPool().MarkDead(cc)
})
} else {
@@ -2250,9 +2243,9 @@ func (rl *clientConnReadLoop) run() error {
cc := rl.cc
gotSettings := false
readIdleTimeout := cc.readIdleTimeout
var t timer
var t *time.Timer
if readIdleTimeout != 0 {
t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck)
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
}
for {
f, err := cc.fr.ReadFrame()
@@ -2779,6 +2772,11 @@ func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
cs.abortStream(err)
}
func (rl *clientConnReadLoop) endStreamErrorLocked(cs *clientStream, err error) {
cs.readAborted = true
cs.abortStreamLocked(err)
}
// Constants passed to streamByID for documentation purposes.
const (
headerOrDataFrame = true
@@ -2795,6 +2793,7 @@ func (rl *clientConnReadLoop) streamByID(id uint32, headerOrData bool) *clientSt
// See comment on ClientConn.rstStreamPingsBlocked for details.
rl.cc.rstStreamPingsBlocked = false
}
rl.cc.readBeforeStreamID = rl.cc.nextStreamID
cs := rl.cc.streams[id]
if cs != nil && !cs.readAborted {
return cs
@@ -2845,6 +2844,7 @@ func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error {
func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
cc := rl.cc
defer cc.maybeCallStateHook()
cc.mu.Lock()
defer cc.mu.Unlock()
@@ -2858,6 +2858,9 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
var seenMaxConcurrentStreams bool
err := f.ForeachSetting(func(s Setting) error {
if err := s.Valid(); err != nil {
return err
}
switch s.ID {
case SettingMaxFrameSize:
cc.maxFrameSize = s.Val
@@ -2889,9 +2892,6 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
cc.henc.SetMaxDynamicTableSize(s.Val)
cc.peerMaxHeaderTableSize = s.Val
case SettingEnableConnectProtocol:
if err := s.Valid(); err != nil {
return err
}
// If the peer wants to send us SETTINGS_ENABLE_CONNECT_PROTOCOL,
// we require that it do so in the first SETTINGS frame.
//
@@ -2944,7 +2944,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
if !fl.add(int32(f.Increment)) {
// For stream, the sender sends RST_STREAM with an error code of FLOW_CONTROL_ERROR
if cs != nil {
rl.endStreamError(cs, StreamError{
rl.endStreamErrorLocked(cs, StreamError{
StreamID: f.StreamID,
Code: ErrCodeFlowControl,
})
@@ -2998,7 +2998,6 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
var pingError error
errc := make(chan struct{})
go func() {
cc.t.markNewGoroutine()
cc.wmu.Lock()
defer cc.wmu.Unlock()
if pingError = cc.fr.WritePing(false, p); pingError != nil {
@@ -3026,6 +3025,7 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
func (rl *clientConnReadLoop) processPing(f *PingFrame) error {
if f.IsAck() {
cc := rl.cc
defer cc.maybeCallStateHook()
cc.mu.Lock()
defer cc.mu.Unlock()
// If ack, notify listener if any
@@ -3128,41 +3128,104 @@ type erringRoundTripper struct{ err error }
func (rt erringRoundTripper) RoundTripErr() error { return rt.err }
func (rt erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { return nil, rt.err }
var errConcurrentReadOnResBody = errors.New("http2: concurrent read on response body")
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
// get gzip.Reader from the pool on the first call to Read.
// After Close is called it puts gzip.Reader to the pool immediately
// if there is no Read in progress or later when Read completes.
type gzipReader struct {
_ incomparable
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
mu sync.Mutex // guards zr and zerr
zr *gzip.Reader // stores gzip reader from the pool between reads
zerr error // sticky gzip reader init error or sentinel value to detect concurrent read and read after close
}
type eofReader struct{}
func (eofReader) Read([]byte) (int, error) { return 0, io.EOF }
func (eofReader) ReadByte() (byte, error) { return 0, io.EOF }
var gzipPool = sync.Pool{New: func() any { return new(gzip.Reader) }}
// gzipPoolGet gets a gzip.Reader from the pool and resets it to read from r.
func gzipPoolGet(r io.Reader) (*gzip.Reader, error) {
zr := gzipPool.Get().(*gzip.Reader)
if err := zr.Reset(r); err != nil {
gzipPoolPut(zr)
return nil, err
}
return zr, nil
}
// gzipPoolPut puts a gzip.Reader back into the pool.
func gzipPoolPut(zr *gzip.Reader) {
// Reset will allocate bufio.Reader if we pass it anything
// other than a flate.Reader, so ensure that it's getting one.
var r flate.Reader = eofReader{}
zr.Reset(r)
gzipPool.Put(zr)
}
// acquire returns a gzip.Reader for reading response body.
// The reader must be released after use.
func (gz *gzipReader) acquire() (*gzip.Reader, error) {
gz.mu.Lock()
defer gz.mu.Unlock()
if gz.zerr != nil {
return nil, gz.zerr
}
if gz.zr == nil {
gz.zr, gz.zerr = gzipPoolGet(gz.body)
if gz.zerr != nil {
return nil, gz.zerr
}
}
ret := gz.zr
gz.zr, gz.zerr = nil, errConcurrentReadOnResBody
return ret, nil
}
// release returns the gzip.Reader to the pool if Close was called during Read.
func (gz *gzipReader) release(zr *gzip.Reader) {
gz.mu.Lock()
defer gz.mu.Unlock()
if gz.zerr == errConcurrentReadOnResBody {
gz.zr, gz.zerr = zr, nil
} else { // fs.ErrClosed
gzipPoolPut(zr)
}
}
// close returns the gzip.Reader to the pool immediately or
// signals release to do so after Read completes.
func (gz *gzipReader) close() {
gz.mu.Lock()
defer gz.mu.Unlock()
if gz.zerr == nil && gz.zr != nil {
gzipPoolPut(gz.zr)
gz.zr = nil
}
gz.zerr = fs.ErrClosed
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
zr, err := gz.acquire()
if err != nil {
return 0, err
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
defer gz.release(zr)
return zr.Read(p)
}
func (gz *gzipReader) Close() error {
if err := gz.body.Close(); err != nil {
return err
}
gz.zerr = fs.ErrClosed
return nil
gz.close()
return gz.body.Close()
}
type errorReader struct{ err error }
func (r errorReader) Read(p []byte) (int, error) { return 0, r.err }
// isConnectionCloseRequest reports whether req should use its own
// connection for a single request and then close the connection.
func isConnectionCloseRequest(req *http.Request) bool {
@@ -3182,9 +3245,13 @@ func registerHTTPSProtocol(t *http.Transport, rt noDialH2RoundTripper) (err erro
}
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
// if there's already has a cached connection to the host.
// if there's already a cached connection to the host.
// (The field is exported so it can be accessed via reflect from net/http; tested
// by TestNoDialH2RoundTripperType)
//
// A noDialH2RoundTripper is registered with http1.Transport.RegisterProtocol,
// and the http1.Transport can use type assertions to call non-RoundTrip methods on it.
// This lets us expose, for example, NewClientConn to net/http.
type noDialH2RoundTripper struct{ *Transport }
func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -3195,6 +3262,85 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err
return res, err
}
func (rt noDialH2RoundTripper) NewClientConn(conn net.Conn, internalStateHook func()) (http.RoundTripper, error) {
tr := rt.Transport
cc, err := tr.newClientConn(conn, tr.disableKeepAlives(), internalStateHook)
if err != nil {
return nil, err
}
// RoundTrip should block when the conn is at its concurrency limit,
// not return an error. Setting strictMaxConcurrentStreams enables this.
cc.strictMaxConcurrentStreams = true
return netHTTPClientConn{cc}, nil
}
// netHTTPClientConn wraps ClientConn and implements the interface net/http expects from
// the RoundTripper returned by NewClientConn.
type netHTTPClientConn struct {
cc *ClientConn
}
func (cc netHTTPClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
return cc.cc.RoundTrip(req)
}
func (cc netHTTPClientConn) Close() error {
return cc.cc.Close()
}
func (cc netHTTPClientConn) Err() error {
cc.cc.mu.Lock()
defer cc.cc.mu.Unlock()
if cc.cc.closed {
return errors.New("connection closed")
}
return nil
}
func (cc netHTTPClientConn) Reserve() error {
defer cc.cc.maybeCallStateHook()
cc.cc.mu.Lock()
defer cc.cc.mu.Unlock()
if !cc.cc.canReserveLocked() {
return errors.New("connection is unavailable")
}
cc.cc.streamsReserved++
return nil
}
func (cc netHTTPClientConn) Release() {
defer cc.cc.maybeCallStateHook()
cc.cc.mu.Lock()
defer cc.cc.mu.Unlock()
// We don't complain if streamsReserved is 0.
//
// This is consistent with RoundTrip: both Release and RoundTrip will
// consume a reservation iff one exists.
if cc.cc.streamsReserved > 0 {
cc.cc.streamsReserved--
}
}
func (cc netHTTPClientConn) Available() int {
cc.cc.mu.Lock()
defer cc.cc.mu.Unlock()
return cc.cc.availableLocked()
}
func (cc netHTTPClientConn) InFlight() int {
cc.cc.mu.Lock()
defer cc.cc.mu.Unlock()
return cc.cc.currentRequestCountLocked()
}
func (cc *ClientConn) maybeCallStateHook() {
if cc.internalStateHook != nil {
cc.internalStateHook()
}
}
func (t *Transport) idleConnTimeout() time.Duration {
// to keep things backwards compatible, we use non-zero values of
// IdleConnTimeout, followed by using the IdleConnTimeout on the underlying
@@ -3228,7 +3374,7 @@ func traceGotConn(req *http.Request, cc *ClientConn, reused bool) {
cc.mu.Lock()
ci.WasIdle = len(cc.streams) == 0 && reused
if ci.WasIdle && !cc.lastActive.IsZero() {
ci.IdleTime = cc.t.timeSince(cc.lastActive)
ci.IdleTime = time.Since(cc.lastActive)
}
cc.mu.Unlock()
+58 -15
View File
@@ -8,6 +8,8 @@ import "fmt"
// WriteScheduler is the interface implemented by HTTP/2 write schedulers.
// Methods are never called concurrently.
//
// Deprecated: User-provided write schedulers are deprecated.
type WriteScheduler interface {
// OpenStream opens a new stream in the write scheduler.
// It is illegal to call this with streamID=0 or with a streamID that is
@@ -38,13 +40,19 @@ type WriteScheduler interface {
}
// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream.
//
// Deprecated: User-provided write schedulers are deprecated.
type OpenStreamOptions struct {
// PusherID is zero if the stream was initiated by the client. Otherwise,
// PusherID names the stream that pushed the newly opened stream.
PusherID uint32
// priority is used to set the priority of the newly opened stream.
priority PriorityParam
}
// FrameWriteRequest is a request to write a frame.
//
// Deprecated: User-provided write schedulers are deprecated.
type FrameWriteRequest struct {
// write is the interface value that does the writing, once the
// WriteScheduler has selected this frame to write. The write
@@ -183,45 +191,75 @@ func (wr *FrameWriteRequest) replyToWriter(err error) {
}
// writeQueue is used by implementations of WriteScheduler.
//
// Each writeQueue contains a queue of FrameWriteRequests, meant to store all
// FrameWriteRequests associated with a given stream. This is implemented as a
// two-stage queue: currQueue[currPos:] and nextQueue. Removing an item is done
// by incrementing currPos of currQueue. Adding an item is done by appending it
// to the nextQueue. If currQueue is empty when trying to remove an item, we
// can swap currQueue and nextQueue to remedy the situation.
// This two-stage queue is analogous to the use of two lists in Okasaki's
// purely functional queue but without the overhead of reversing the list when
// swapping stages.
//
// writeQueue also contains prev and next, this can be used by implementations
// of WriteScheduler to construct data structures that represent the order of
// writing between different streams (e.g. circular linked list).
type writeQueue struct {
s []FrameWriteRequest
currQueue []FrameWriteRequest
nextQueue []FrameWriteRequest
currPos int
prev, next *writeQueue
}
func (q *writeQueue) empty() bool { return len(q.s) == 0 }
func (q *writeQueue) empty() bool {
return (len(q.currQueue) - q.currPos + len(q.nextQueue)) == 0
}
func (q *writeQueue) push(wr FrameWriteRequest) {
q.s = append(q.s, wr)
q.nextQueue = append(q.nextQueue, wr)
}
func (q *writeQueue) shift() FrameWriteRequest {
if len(q.s) == 0 {
if q.empty() {
panic("invalid use of queue")
}
wr := q.s[0]
// TODO: less copy-happy queue.
copy(q.s, q.s[1:])
q.s[len(q.s)-1] = FrameWriteRequest{}
q.s = q.s[:len(q.s)-1]
if q.currPos >= len(q.currQueue) {
q.currQueue, q.currPos, q.nextQueue = q.nextQueue, 0, q.currQueue[:0]
}
wr := q.currQueue[q.currPos]
q.currQueue[q.currPos] = FrameWriteRequest{}
q.currPos++
return wr
}
func (q *writeQueue) peek() *FrameWriteRequest {
if q.currPos < len(q.currQueue) {
return &q.currQueue[q.currPos]
}
if len(q.nextQueue) > 0 {
return &q.nextQueue[0]
}
return nil
}
// consume consumes up to n bytes from q.s[0]. If the frame is
// entirely consumed, it is removed from the queue. If the frame
// is partially consumed, the frame is kept with the consumed
// bytes removed. Returns true iff any bytes were consumed.
func (q *writeQueue) consume(n int32) (FrameWriteRequest, bool) {
if len(q.s) == 0 {
if q.empty() {
return FrameWriteRequest{}, false
}
consumed, rest, numresult := q.s[0].Consume(n)
consumed, rest, numresult := q.peek().Consume(n)
switch numresult {
case 0:
return FrameWriteRequest{}, false
case 1:
q.shift()
case 2:
q.s[0] = rest
*q.peek() = rest
}
return consumed, true
}
@@ -230,10 +268,15 @@ type writeQueuePool []*writeQueue
// put inserts an unused writeQueue into the pool.
func (p *writeQueuePool) put(q *writeQueue) {
for i := range q.s {
q.s[i] = FrameWriteRequest{}
for i := range q.currQueue {
q.currQueue[i] = FrameWriteRequest{}
}
q.s = q.s[:0]
for i := range q.nextQueue {
q.nextQueue[i] = FrameWriteRequest{}
}
q.currQueue = q.currQueue[:0]
q.nextQueue = q.nextQueue[:0]
q.currPos = 0
*p = append(*p, q)
}
@@ -11,9 +11,11 @@ import (
)
// RFC 7540, Section 5.3.5: the default weight is 16.
const priorityDefaultWeight = 15 // 16 = 15 + 1
const priorityDefaultWeightRFC7540 = 15 // 16 = 15 + 1
// PriorityWriteSchedulerConfig configures a priorityWriteScheduler.
//
// Deprecated: User-provided write schedulers are deprecated.
type PriorityWriteSchedulerConfig struct {
// MaxClosedNodesInTree controls the maximum number of closed streams to
// retain in the priority tree. Setting this to zero saves a small amount
@@ -55,7 +57,14 @@ type PriorityWriteSchedulerConfig struct {
// NewPriorityWriteScheduler constructs a WriteScheduler that schedules
// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3.
// If cfg is nil, default options are used.
//
// Deprecated: The RFC 7540 write scheduler has known bugs and performance issues,
// and RFC 7540 prioritization was deprecated in RFC 9113.
func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler {
return newPriorityWriteSchedulerRFC7540(cfg)
}
func newPriorityWriteSchedulerRFC7540(cfg *PriorityWriteSchedulerConfig) WriteScheduler {
if cfg == nil {
// For justification of these defaults, see:
// https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY
@@ -66,8 +75,8 @@ func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler
}
}
ws := &priorityWriteScheduler{
nodes: make(map[uint32]*priorityNode),
ws := &priorityWriteSchedulerRFC7540{
nodes: make(map[uint32]*priorityNodeRFC7540),
maxClosedNodesInTree: cfg.MaxClosedNodesInTree,
maxIdleNodesInTree: cfg.MaxIdleNodesInTree,
enableWriteThrottle: cfg.ThrottleOutOfOrderWrites,
@@ -81,32 +90,32 @@ func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler
return ws
}
type priorityNodeState int
type priorityNodeStateRFC7540 int
const (
priorityNodeOpen priorityNodeState = iota
priorityNodeClosed
priorityNodeIdle
priorityNodeOpenRFC7540 priorityNodeStateRFC7540 = iota
priorityNodeClosedRFC7540
priorityNodeIdleRFC7540
)
// priorityNode is a node in an HTTP/2 priority tree.
// priorityNodeRFC7540 is a node in an HTTP/2 priority tree.
// Each node is associated with a single stream ID.
// See RFC 7540, Section 5.3.
type priorityNode struct {
q writeQueue // queue of pending frames to write
id uint32 // id of the stream, or 0 for the root of the tree
weight uint8 // the actual weight is weight+1, so the value is in [1,256]
state priorityNodeState // open | closed | idle
bytes int64 // number of bytes written by this node, or 0 if closed
subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree
type priorityNodeRFC7540 struct {
q writeQueue // queue of pending frames to write
id uint32 // id of the stream, or 0 for the root of the tree
weight uint8 // the actual weight is weight+1, so the value is in [1,256]
state priorityNodeStateRFC7540 // open | closed | idle
bytes int64 // number of bytes written by this node, or 0 if closed
subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree
// These links form the priority tree.
parent *priorityNode
kids *priorityNode // start of the kids list
prev, next *priorityNode // doubly-linked list of siblings
parent *priorityNodeRFC7540
kids *priorityNodeRFC7540 // start of the kids list
prev, next *priorityNodeRFC7540 // doubly-linked list of siblings
}
func (n *priorityNode) setParent(parent *priorityNode) {
func (n *priorityNodeRFC7540) setParent(parent *priorityNodeRFC7540) {
if n == parent {
panic("setParent to self")
}
@@ -141,7 +150,7 @@ func (n *priorityNode) setParent(parent *priorityNode) {
}
}
func (n *priorityNode) addBytes(b int64) {
func (n *priorityNodeRFC7540) addBytes(b int64) {
n.bytes += b
for ; n != nil; n = n.parent {
n.subtreeBytes += b
@@ -154,7 +163,7 @@ func (n *priorityNode) addBytes(b int64) {
//
// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true
// if any ancestor p of n is still open (ignoring the root node).
func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f func(*priorityNode, bool) bool) bool {
func (n *priorityNodeRFC7540) walkReadyInOrder(openParent bool, tmp *[]*priorityNodeRFC7540, f func(*priorityNodeRFC7540, bool) bool) bool {
if !n.q.empty() && f(n, openParent) {
return true
}
@@ -165,7 +174,7 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f
// Don't consider the root "open" when updating openParent since
// we can't send data frames on the root stream (only control frames).
if n.id != 0 {
openParent = openParent || (n.state == priorityNodeOpen)
openParent = openParent || (n.state == priorityNodeOpenRFC7540)
}
// Common case: only one kid or all kids have the same weight.
@@ -195,7 +204,7 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f
*tmp = append(*tmp, n.kids)
n.kids.setParent(nil)
}
sort.Sort(sortPriorityNodeSiblings(*tmp))
sort.Sort(sortPriorityNodeSiblingsRFC7540(*tmp))
for i := len(*tmp) - 1; i >= 0; i-- {
(*tmp)[i].setParent(n) // setParent inserts at the head of n.kids
}
@@ -207,15 +216,15 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f
return false
}
type sortPriorityNodeSiblings []*priorityNode
type sortPriorityNodeSiblingsRFC7540 []*priorityNodeRFC7540
func (z sortPriorityNodeSiblings) Len() int { return len(z) }
func (z sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] }
func (z sortPriorityNodeSiblings) Less(i, k int) bool {
func (z sortPriorityNodeSiblingsRFC7540) Len() int { return len(z) }
func (z sortPriorityNodeSiblingsRFC7540) Swap(i, k int) { z[i], z[k] = z[k], z[i] }
func (z sortPriorityNodeSiblingsRFC7540) Less(i, k int) bool {
// Prefer the subtree that has sent fewer bytes relative to its weight.
// See sections 5.3.2 and 5.3.4.
wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes)
wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes)
wi, bi := float64(z[i].weight)+1, float64(z[i].subtreeBytes)
wk, bk := float64(z[k].weight)+1, float64(z[k].subtreeBytes)
if bi == 0 && bk == 0 {
return wi >= wk
}
@@ -225,13 +234,13 @@ func (z sortPriorityNodeSiblings) Less(i, k int) bool {
return bi/bk <= wi/wk
}
type priorityWriteScheduler struct {
type priorityWriteSchedulerRFC7540 struct {
// root is the root of the priority tree, where root.id = 0.
// The root queues control frames that are not associated with any stream.
root priorityNode
root priorityNodeRFC7540
// nodes maps stream ids to priority tree nodes.
nodes map[uint32]*priorityNode
nodes map[uint32]*priorityNodeRFC7540
// maxID is the maximum stream id in nodes.
maxID uint32
@@ -239,7 +248,7 @@ type priorityWriteScheduler struct {
// lists of nodes that have been closed or are idle, but are kept in
// the tree for improved prioritization. When the lengths exceed either
// maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded.
closedNodes, idleNodes []*priorityNode
closedNodes, idleNodes []*priorityNodeRFC7540
// From the config.
maxClosedNodesInTree int
@@ -248,19 +257,19 @@ type priorityWriteScheduler struct {
enableWriteThrottle bool
// tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations.
tmp []*priorityNode
tmp []*priorityNodeRFC7540
// pool of empty queues for reuse.
queuePool writeQueuePool
}
func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) {
func (ws *priorityWriteSchedulerRFC7540) OpenStream(streamID uint32, options OpenStreamOptions) {
// The stream may be currently idle but cannot be opened or closed.
if curr := ws.nodes[streamID]; curr != nil {
if curr.state != priorityNodeIdle {
if curr.state != priorityNodeIdleRFC7540 {
panic(fmt.Sprintf("stream %d already opened", streamID))
}
curr.state = priorityNodeOpen
curr.state = priorityNodeOpenRFC7540
return
}
@@ -272,11 +281,11 @@ func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStream
if parent == nil {
parent = &ws.root
}
n := &priorityNode{
n := &priorityNodeRFC7540{
q: *ws.queuePool.get(),
id: streamID,
weight: priorityDefaultWeight,
state: priorityNodeOpen,
weight: priorityDefaultWeightRFC7540,
state: priorityNodeOpenRFC7540,
}
n.setParent(parent)
ws.nodes[streamID] = n
@@ -285,24 +294,23 @@ func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStream
}
}
func (ws *priorityWriteScheduler) CloseStream(streamID uint32) {
func (ws *priorityWriteSchedulerRFC7540) CloseStream(streamID uint32) {
if streamID == 0 {
panic("violation of WriteScheduler interface: cannot close stream 0")
}
if ws.nodes[streamID] == nil {
panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID))
}
if ws.nodes[streamID].state != priorityNodeOpen {
if ws.nodes[streamID].state != priorityNodeOpenRFC7540 {
panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID))
}
n := ws.nodes[streamID]
n.state = priorityNodeClosed
n.state = priorityNodeClosedRFC7540
n.addBytes(-n.bytes)
q := n.q
ws.queuePool.put(&q)
n.q.s = nil
if ws.maxClosedNodesInTree > 0 {
ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n)
} else {
@@ -310,7 +318,7 @@ func (ws *priorityWriteScheduler) CloseStream(streamID uint32) {
}
}
func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) {
func (ws *priorityWriteSchedulerRFC7540) AdjustStream(streamID uint32, priority PriorityParam) {
if streamID == 0 {
panic("adjustPriority on root")
}
@@ -324,11 +332,11 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit
return
}
ws.maxID = streamID
n = &priorityNode{
n = &priorityNodeRFC7540{
q: *ws.queuePool.get(),
id: streamID,
weight: priorityDefaultWeight,
state: priorityNodeIdle,
weight: priorityDefaultWeightRFC7540,
state: priorityNodeIdleRFC7540,
}
n.setParent(&ws.root)
ws.nodes[streamID] = n
@@ -340,7 +348,7 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit
parent := ws.nodes[priority.StreamDep]
if parent == nil {
n.setParent(&ws.root)
n.weight = priorityDefaultWeight
n.weight = priorityDefaultWeightRFC7540
return
}
@@ -381,8 +389,8 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit
n.weight = priority.Weight
}
func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) {
var n *priorityNode
func (ws *priorityWriteSchedulerRFC7540) Push(wr FrameWriteRequest) {
var n *priorityNodeRFC7540
if wr.isControl() {
n = &ws.root
} else {
@@ -401,8 +409,8 @@ func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) {
n.q.push(wr)
}
func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) {
ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNode, openParent bool) bool {
func (ws *priorityWriteSchedulerRFC7540) Pop() (wr FrameWriteRequest, ok bool) {
ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNodeRFC7540, openParent bool) bool {
limit := int32(math.MaxInt32)
if openParent {
limit = ws.writeThrottleLimit
@@ -428,7 +436,7 @@ func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) {
return wr, ok
}
func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, maxSize int, n *priorityNode) {
func (ws *priorityWriteSchedulerRFC7540) addClosedOrIdleNode(list *[]*priorityNodeRFC7540, maxSize int, n *priorityNodeRFC7540) {
if maxSize == 0 {
return
}
@@ -442,7 +450,7 @@ func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, max
*list = append(*list, n)
}
func (ws *priorityWriteScheduler) removeNode(n *priorityNode) {
func (ws *priorityWriteSchedulerRFC7540) removeNode(n *priorityNodeRFC7540) {
for n.kids != nil {
n.kids.setParent(n.parent)
}
+224
View File
@@ -0,0 +1,224 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"fmt"
"math"
)
type streamMetadata struct {
location *writeQueue
priority PriorityParam
}
type priorityWriteSchedulerRFC9218 struct {
// control contains control frames (SETTINGS, PING, etc.).
control writeQueue
// heads contain the head of a circular list of streams.
// We put these heads within a nested array that represents urgency and
// incremental, as defined in
// https://www.rfc-editor.org/rfc/rfc9218.html#name-priority-parameters.
// 8 represents u=0 up to u=7, and 2 represents i=false and i=true.
heads [8][2]*writeQueue
// streams contains a mapping between each stream ID and their metadata, so
// we can quickly locate them when needing to, for example, adjust their
// priority.
streams map[uint32]streamMetadata
// queuePool are empty queues for reuse.
queuePool writeQueuePool
// prioritizeIncremental is used to determine whether we should prioritize
// incremental streams or not, when urgency is the same in a given Pop()
// call.
prioritizeIncremental bool
// priorityUpdateBuf is used to buffer the most recent PRIORITY_UPDATE we
// receive per https://www.rfc-editor.org/rfc/rfc9218.html#name-the-priority_update-frame.
priorityUpdateBuf struct {
// streamID being 0 means that the buffer is empty. This is a safe
// assumption as PRIORITY_UPDATE for stream 0 is a PROTOCOL_ERROR.
streamID uint32
priority PriorityParam
}
}
func newPriorityWriteSchedulerRFC9218() WriteScheduler {
ws := &priorityWriteSchedulerRFC9218{
streams: make(map[uint32]streamMetadata),
}
return ws
}
func (ws *priorityWriteSchedulerRFC9218) OpenStream(streamID uint32, opt OpenStreamOptions) {
if ws.streams[streamID].location != nil {
panic(fmt.Errorf("stream %d already opened", streamID))
}
if streamID == ws.priorityUpdateBuf.streamID {
ws.priorityUpdateBuf.streamID = 0
opt.priority = ws.priorityUpdateBuf.priority
}
q := ws.queuePool.get()
ws.streams[streamID] = streamMetadata{
location: q,
priority: opt.priority,
}
u, i := opt.priority.urgency, opt.priority.incremental
if ws.heads[u][i] == nil {
ws.heads[u][i] = q
q.next = q
q.prev = q
} else {
// Queues are stored in a ring.
// Insert the new stream before ws.head, putting it at the end of the list.
q.prev = ws.heads[u][i].prev
q.next = ws.heads[u][i]
q.prev.next = q
q.next.prev = q
}
}
func (ws *priorityWriteSchedulerRFC9218) CloseStream(streamID uint32) {
metadata := ws.streams[streamID]
q, u, i := metadata.location, metadata.priority.urgency, metadata.priority.incremental
if q == nil {
return
}
if q.next == q {
// This was the only open stream.
ws.heads[u][i] = nil
} else {
q.prev.next = q.next
q.next.prev = q.prev
if ws.heads[u][i] == q {
ws.heads[u][i] = q.next
}
}
delete(ws.streams, streamID)
ws.queuePool.put(q)
}
func (ws *priorityWriteSchedulerRFC9218) AdjustStream(streamID uint32, priority PriorityParam) {
metadata := ws.streams[streamID]
q, u, i := metadata.location, metadata.priority.urgency, metadata.priority.incremental
if q == nil {
ws.priorityUpdateBuf.streamID = streamID
ws.priorityUpdateBuf.priority = priority
return
}
// Remove stream from current location.
if q.next == q {
// This was the only open stream.
ws.heads[u][i] = nil
} else {
q.prev.next = q.next
q.next.prev = q.prev
if ws.heads[u][i] == q {
ws.heads[u][i] = q.next
}
}
// Insert stream to the new queue.
u, i = priority.urgency, priority.incremental
if ws.heads[u][i] == nil {
ws.heads[u][i] = q
q.next = q
q.prev = q
} else {
// Queues are stored in a ring.
// Insert the new stream before ws.head, putting it at the end of the list.
q.prev = ws.heads[u][i].prev
q.next = ws.heads[u][i]
q.prev.next = q
q.next.prev = q
}
// Update the metadata.
ws.streams[streamID] = streamMetadata{
location: q,
priority: priority,
}
}
func (ws *priorityWriteSchedulerRFC9218) Push(wr FrameWriteRequest) {
if wr.isControl() {
ws.control.push(wr)
return
}
q := ws.streams[wr.StreamID()].location
if q == nil {
// This is a closed stream.
// wr should not be a HEADERS or DATA frame.
// We push the request onto the control queue.
if wr.DataSize() > 0 {
panic("add DATA on non-open stream")
}
ws.control.push(wr)
return
}
q.push(wr)
}
func (ws *priorityWriteSchedulerRFC9218) Pop() (FrameWriteRequest, bool) {
// Control and RST_STREAM frames first.
if !ws.control.empty() {
return ws.control.shift(), true
}
// On the next Pop(), we want to prioritize incremental if we prioritized
// non-incremental request of the same urgency this time. Vice-versa.
// i.e. when there are incremental and non-incremental requests at the same
// priority, we give 50% of our bandwidth to the incremental ones in
// aggregate and 50% to the first non-incremental one (since
// non-incremental streams do not use round-robin writes).
ws.prioritizeIncremental = !ws.prioritizeIncremental
// Always prioritize lowest u (i.e. highest urgency level).
for u := range ws.heads {
for i := range ws.heads[u] {
// When we want to prioritize incremental, we try to pop i=true
// first before i=false when u is the same.
if ws.prioritizeIncremental {
i = (i + 1) % 2
}
q := ws.heads[u][i]
if q == nil {
continue
}
for {
if wr, ok := q.consume(math.MaxInt32); ok {
if i == 1 {
// For incremental streams, we update head to q.next so
// we can round-robin between multiple streams that can
// immediately benefit from partial writes.
ws.heads[u][i] = q.next
} else {
// For non-incremental streams, we try to finish one to
// completion rather than doing round-robin. However,
// we update head here so that if q.consume() is !ok
// (e.g. the stream has no more frame to consume), head
// is updated to the next q that has frames to consume
// on future iterations. This way, we do not prioritize
// writing to unavailable stream on next Pop() calls,
// preventing head-of-line blocking.
ws.heads[u][i] = q
}
return wr, true
}
q = q.next
if q == ws.heads[u][i] {
break
}
}
}
}
return FrameWriteRequest{}, false
}
+2
View File
@@ -10,6 +10,8 @@ import "math"
// priorities. Control frames like SETTINGS and PING are written before DATA
// frames, but if no control frames are queued and multiple streams have queued
// HEADERS or DATA frames, Pop selects a ready stream arbitrarily.
//
// Deprecated: User-provided write schedulers are deprecated.
func NewRandomWriteScheduler() WriteScheduler {
return &randomWriteScheduler{sq: make(map[uint32]*writeQueue)}
}
+1 -1
View File
@@ -25,7 +25,7 @@ type roundRobinWriteScheduler struct {
}
// newRoundRobinWriteScheduler constructs a new write scheduler.
// The round robin scheduler priorizes control frames
// The round robin scheduler prioritizes control frames
// like SETTINGS and PING over DATA frames.
// When there are no control frames to send, it performs a round-robin
// selection from the ready streams.
+2 -3
View File
@@ -9,7 +9,6 @@ import (
"net"
"runtime"
"golang.org/x/net/internal/socket"
"golang.org/x/net/ipv4"
)
@@ -49,12 +48,12 @@ func ParseIPv4Header(b []byte) (*ipv4.Header, error) {
}
switch runtime.GOOS {
case "darwin", "ios":
h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4]))
h.TotalLen = int(binary.NativeEndian.Uint16(b[2:4]))
case "freebsd":
if freebsdVersion >= 1000000 {
h.TotalLen = int(binary.BigEndian.Uint16(b[2:4]))
} else {
h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4]))
h.TotalLen = int(binary.NativeEndian.Uint16(b[2:4]))
}
default:
h.TotalLen = int(binary.BigEndian.Uint16(b[2:4]))
+2 -2
View File
@@ -51,7 +51,7 @@ type EncodeHeadersParam struct {
DefaultUserAgent string
}
// EncodeHeadersParam is the result of EncodeHeaders.
// EncodeHeadersResult is the result of EncodeHeaders.
type EncodeHeadersResult struct {
HasBody bool
HasTrailers bool
@@ -399,7 +399,7 @@ type ServerRequestResult struct {
// If the request should be rejected, this is a short string suitable for passing
// to the http2 package's CountError function.
// It might be a bit odd to return errors this way rather than returing an error,
// It might be a bit odd to return errors this way rather than returning an error,
// but this ensures we don't forget to include a CountError reason.
InvalidReason string
}
+665
View File
@@ -0,0 +1,665 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package httpsfv provides functionality for dealing with HTTP Structured
// Field Values.
package httpsfv
import (
"slices"
"strconv"
"strings"
"time"
"unicode/utf8"
)
func isLCAlpha(b byte) bool {
return (b >= 'a' && b <= 'z')
}
func isAlpha(b byte) bool {
return isLCAlpha(b) || (b >= 'A' && b <= 'Z')
}
func isDigit(b byte) bool {
return b >= '0' && b <= '9'
}
func isVChar(b byte) bool {
return b >= 0x21 && b <= 0x7e
}
func isSP(b byte) bool {
return b == 0x20
}
func isTChar(b byte) bool {
if isAlpha(b) || isDigit(b) {
return true
}
return slices.Contains([]byte{'!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~'}, b)
}
func countLeftWhitespace(s string) int {
i := 0
for _, ch := range []byte(s) {
if ch != ' ' && ch != '\t' {
break
}
i++
}
return i
}
// https://www.rfc-editor.org/rfc/rfc4648#section-8.
func decOctetHex(ch1, ch2 byte) (ch byte, ok bool) {
decBase16 := func(in byte) (out byte, ok bool) {
if !isDigit(in) && !(in >= 'a' && in <= 'f') {
return 0, false
}
if isDigit(in) {
return in - '0', true
}
return in - 'a' + 10, true
}
if ch1, ok = decBase16(ch1); !ok {
return 0, ok
}
if ch2, ok = decBase16(ch2); !ok {
return 0, ok
}
return ch1<<4 | ch2, true
}
// ParseList parses a list from a given HTTP Structured Field Values.
//
// Given an HTTP SFV string that represents a list, it will call the given
// function using each of the members and parameters contained in the list.
// This allows the caller to extract information out of the list.
//
// This function will return once it encounters the end of the string, or
// something that is not a list. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-list.
func ParseList(s string, f func(member, param string)) (ok bool) {
for len(s) != 0 {
var member, param string
if len(s) != 0 && s[0] == '(' {
if member, s, ok = consumeBareInnerList(s, nil); !ok {
return ok
}
} else {
if member, s, ok = consumeBareItem(s); !ok {
return ok
}
}
if param, s, ok = consumeParameter(s, nil); !ok {
return ok
}
if f != nil {
f(member, param)
}
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
break
}
if s[0] != ',' {
return false
}
s = s[1:]
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
return false
}
}
return true
}
// consumeBareInnerList consumes an inner list
// (https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-inner-list),
// except for the inner list's top-most parameter.
// For example, given `(a;b c;d);e`, it will consume only `(a;b c;d)`.
func consumeBareInnerList(s string, f func(bareItem, param string)) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != '(' {
return "", s, false
}
rest = s[1:]
for len(rest) != 0 {
var bareItem, param string
rest = rest[countLeftWhitespace(rest):]
if len(rest) != 0 && rest[0] == ')' {
rest = rest[1:]
break
}
if bareItem, rest, ok = consumeBareItem(rest); !ok {
return "", s, ok
}
if param, rest, ok = consumeParameter(rest, nil); !ok {
return "", s, ok
}
if len(rest) == 0 || (rest[0] != ')' && !isSP(rest[0])) {
return "", s, false
}
if f != nil {
f(bareItem, param)
}
}
return s[:len(s)-len(rest)], rest, true
}
// ParseBareInnerList parses a bare inner list from a given HTTP Structured
// Field Values.
//
// We define a bare inner list as an inner list
// (https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-inner-list),
// without the top-most parameter of the inner list. For example, given the
// inner list `(a;b c;d);e`, the bare inner list would be `(a;b c;d)`.
//
// Given an HTTP SFV string that represents a bare inner list, it will call the
// given function using each of the bare item and parameter within the bare
// inner list. This allows the caller to extract information out of the bare
// inner list.
//
// This function will return once it encounters the end of the bare inner list,
// or something that is not a bare inner list. If it cannot consume the entire
// given string, the ok value returned will be false.
func ParseBareInnerList(s string, f func(bareItem, param string)) (ok bool) {
_, rest, ok := consumeBareInnerList(s, f)
return rest == "" && ok
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-item.
func consumeItem(s string, f func(bareItem, param string)) (consumed, rest string, ok bool) {
var bareItem, param string
if bareItem, rest, ok = consumeBareItem(s); !ok {
return "", s, ok
}
if param, rest, ok = consumeParameter(rest, nil); !ok {
return "", s, ok
}
if f != nil {
f(bareItem, param)
}
return s[:len(s)-len(rest)], rest, true
}
// ParseItem parses an item from a given HTTP Structured Field Values.
//
// Given an HTTP SFV string that represents an item, it will call the given
// function once, with the bare item and the parameter of the item. This allows
// the caller to extract information out of the item.
//
// This function will return once it encounters the end of the string, or
// something that is not an item. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-item.
func ParseItem(s string, f func(bareItem, param string)) (ok bool) {
_, rest, ok := consumeItem(s, f)
return rest == "" && ok
}
// ParseDictionary parses a dictionary from a given HTTP Structured Field
// Values.
//
// Given an HTTP SFV string that represents a dictionary, it will call the
// given function using each of the keys, values, and parameters contained in
// the dictionary. This allows the caller to extract information out of the
// dictionary.
//
// This function will return once it encounters the end of the string, or
// something that is not a dictionary. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-dictionary.
func ParseDictionary(s string, f func(key, val, param string)) (ok bool) {
for len(s) != 0 {
var key, val, param string
val = "?1" // Default value for empty val is boolean true.
if key, s, ok = consumeKey(s); !ok {
return ok
}
if len(s) != 0 && s[0] == '=' {
s = s[1:]
if len(s) != 0 && s[0] == '(' {
if val, s, ok = consumeBareInnerList(s, nil); !ok {
return ok
}
} else {
if val, s, ok = consumeBareItem(s); !ok {
return ok
}
}
}
if param, s, ok = consumeParameter(s, nil); !ok {
return ok
}
if f != nil {
f(key, val, param)
}
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
break
}
if s[0] == ',' {
s = s[1:]
}
s = s[countLeftWhitespace(s):]
if len(s) == 0 {
return false
}
}
return true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#parse-param.
func consumeParameter(s string, f func(key, val string)) (consumed, rest string, ok bool) {
rest = s
for len(rest) != 0 {
var key, val string
val = "?1" // Default value for empty val is boolean true.
if rest[0] != ';' {
break
}
rest = rest[1:]
rest = rest[countLeftWhitespace(rest):]
key, rest, ok = consumeKey(rest)
if !ok {
return "", s, ok
}
if len(rest) != 0 && rest[0] == '=' {
rest = rest[1:]
val, rest, ok = consumeBareItem(rest)
if !ok {
return "", s, ok
}
}
if f != nil {
f(key, val)
}
}
return s[:len(s)-len(rest)], rest, true
}
// ParseParameter parses a parameter from a given HTTP Structured Field Values.
//
// Given an HTTP SFV string that represents a parameter, it will call the given
// function using each of the keys and values contained in the parameter. This
// allows the caller to extract information out of the parameter.
//
// This function will return once it encounters the end of the string, or
// something that is not a parameter. If it cannot consume the entire given
// string, the ok value returned will be false.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#parse-param.
func ParseParameter(s string, f func(key, val string)) (ok bool) {
_, rest, ok := consumeParameter(s, f)
return rest == "" && ok
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-key.
func consumeKey(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || (!isLCAlpha(s[0]) && s[0] != '*') {
return "", s, false
}
i := 0
for _, ch := range []byte(s) {
if !isLCAlpha(ch) && !isDigit(ch) && !slices.Contains([]byte("_-.*"), ch) {
break
}
i++
}
return s[:i], s[i:], true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim.
func consumeIntegerOrDecimal(s string) (consumed, rest string, ok bool) {
var i, signOffset, periodIndex int
var isDecimal bool
if i < len(s) && s[i] == '-' {
i++
signOffset++
}
if i >= len(s) {
return "", s, false
}
if !isDigit(s[i]) {
return "", s, false
}
for i < len(s) {
ch := s[i]
if isDigit(ch) {
i++
continue
}
if !isDecimal && ch == '.' {
if i-signOffset > 12 {
return "", s, false
}
periodIndex = i
isDecimal = true
i++
continue
}
break
}
if !isDecimal && i-signOffset > 15 {
return "", s, false
}
if isDecimal {
if i-signOffset > 16 {
return "", s, false
}
if s[i-1] == '.' {
return "", s, false
}
if i-periodIndex-1 > 3 {
return "", s, false
}
}
return s[:i], s[i:], true
}
// ParseInteger parses an integer from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid integer. It returns the
// parsed integer and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim.
func ParseInteger(s string) (parsed int64, ok bool) {
if _, rest, ok := consumeIntegerOrDecimal(s); !ok || rest != "" {
return 0, false
}
if n, err := strconv.ParseInt(s, 10, 64); err == nil {
return n, true
}
return 0, false
}
// ParseDecimal parses a decimal from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid decimal. It returns the
// parsed decimal and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim.
func ParseDecimal(s string) (parsed float64, ok bool) {
if _, rest, ok := consumeIntegerOrDecimal(s); !ok || rest != "" {
return 0, false
}
if !strings.Contains(s, ".") {
return 0, false
}
if n, err := strconv.ParseFloat(s, 64); err == nil {
return n, true
}
return 0, false
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-string.
func consumeString(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != '"' {
return "", s, false
}
for i := 1; i < len(s); i++ {
switch ch := s[i]; ch {
case '\\':
if i+1 >= len(s) {
return "", s, false
}
i++
if ch = s[i]; ch != '"' && ch != '\\' {
return "", s, false
}
case '"':
return s[:i+1], s[i+1:], true
default:
if !isVChar(ch) && !isSP(ch) {
return "", s, false
}
}
}
return "", s, false
}
// ParseString parses a Go string from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid string. It returns the
// parsed string and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-string.
func ParseString(s string) (parsed string, ok bool) {
if _, rest, ok := consumeString(s); !ok || rest != "" {
return "", false
}
return s[1 : len(s)-1], true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-token
func consumeToken(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || (!isAlpha(s[0]) && s[0] != '*') {
return "", s, false
}
i := 0
for _, ch := range []byte(s) {
if !isTChar(ch) && !slices.Contains([]byte(":/"), ch) {
break
}
i++
}
return s[:i], s[i:], true
}
// ParseToken parses a token from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid token. It returns the
// parsed token and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-token
func ParseToken(s string) (parsed string, ok bool) {
if _, rest, ok := consumeToken(s); !ok || rest != "" {
return "", false
}
return s, true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-byte-sequence.
func consumeByteSequence(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != ':' {
return "", s, false
}
for i := 1; i < len(s); i++ {
if ch := s[i]; ch == ':' {
return s[:i+1], s[i+1:], true
}
if ch := s[i]; !isAlpha(ch) && !isDigit(ch) && !slices.Contains([]byte("+/="), ch) {
return "", s, false
}
}
return "", s, false
}
// ParseByteSequence parses a byte sequence from a given HTTP Structured Field
// Values.
//
// The entire HTTP SFV string must consist of a valid byte sequence. It returns
// the parsed byte sequence and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-byte-sequence.
func ParseByteSequence(s string) (parsed []byte, ok bool) {
if _, rest, ok := consumeByteSequence(s); !ok || rest != "" {
return nil, false
}
return []byte(s[1 : len(s)-1]), true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-boolean.
func consumeBoolean(s string) (consumed, rest string, ok bool) {
if len(s) >= 2 && (s[:2] == "?0" || s[:2] == "?1") {
return s[:2], s[2:], true
}
return "", s, false
}
// ParseBoolean parses a boolean from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid boolean. It returns the
// parsed boolean and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-boolean.
func ParseBoolean(s string) (parsed bool, ok bool) {
if _, rest, ok := consumeBoolean(s); !ok || rest != "" {
return false, false
}
return s == "?1", true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-date.
func consumeDate(s string) (consumed, rest string, ok bool) {
if len(s) == 0 || s[0] != '@' {
return "", s, false
}
if _, rest, ok = consumeIntegerOrDecimal(s[1:]); !ok {
return "", s, ok
}
consumed = s[:len(s)-len(rest)]
if slices.Contains([]byte(consumed), '.') {
return "", s, false
}
return consumed, rest, ok
}
// ParseDate parses a date from a given HTTP Structured Field Values.
//
// The entire HTTP SFV string must consist of a valid date. It returns the
// parsed date and an ok boolean value, indicating success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-date.
func ParseDate(s string) (parsed time.Time, ok bool) {
if _, rest, ok := consumeDate(s); !ok || rest != "" {
return time.Time{}, false
}
if n, ok := ParseInteger(s[1:]); !ok {
return time.Time{}, false
} else {
return time.Unix(n, 0), true
}
}
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-display-string.
func consumeDisplayString(s string) (consumed, rest string, ok bool) {
// To prevent excessive allocation, especially when input is large, we
// maintain a buffer of 4 bytes to keep track of the last rune we
// encounter. This way, we can validate that the display string conforms to
// UTF-8 without actually building the whole string.
var lastRune [4]byte
var runeLen int
isPartOfValidRune := func(ch byte) bool {
lastRune[runeLen] = ch
runeLen++
if utf8.FullRune(lastRune[:runeLen]) {
r, s := utf8.DecodeRune(lastRune[:runeLen])
if r == utf8.RuneError {
return false
}
copy(lastRune[:], lastRune[s:runeLen])
runeLen -= s
return true
}
return runeLen <= 4
}
if len(s) <= 1 || s[:2] != `%"` {
return "", s, false
}
i := 2
for i < len(s) {
ch := s[i]
if !isVChar(ch) && !isSP(ch) {
return "", s, false
}
switch ch {
case '"':
if runeLen > 0 {
return "", s, false
}
return s[:i+1], s[i+1:], true
case '%':
if i+2 >= len(s) {
return "", s, false
}
if ch, ok = decOctetHex(s[i+1], s[i+2]); !ok {
return "", s, ok
}
if ok = isPartOfValidRune(ch); !ok {
return "", s, ok
}
i += 3
default:
if ok = isPartOfValidRune(ch); !ok {
return "", s, ok
}
i++
}
}
return "", s, false
}
// ParseDisplayString parses a display string from a given HTTP Structured
// Field Values.
//
// The entire HTTP SFV string must consist of a valid display string. It
// returns the parsed display string and an ok boolean value, indicating
// success or not.
//
// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-display-string.
func ParseDisplayString(s string) (parsed string, ok bool) {
if _, rest, ok := consumeDisplayString(s); !ok || rest != "" {
return "", false
}
// consumeDisplayString() already validates that we have a valid display
// string. Therefore, we can just construct the display string, without
// validating it again.
s = s[2 : len(s)-1]
var b strings.Builder
for i := 0; i < len(s); {
if s[i] == '%' {
decoded, _ := decOctetHex(s[i+1], s[i+2])
b.WriteByte(decoded)
i += 3
continue
}
b.WriteByte(s[i])
i++
}
return b.String(), true
}
// https://www.rfc-editor.org/rfc/rfc9651.html#parse-bare-item.
func consumeBareItem(s string) (consumed, rest string, ok bool) {
if len(s) == 0 {
return "", s, false
}
ch := s[0]
switch {
case ch == '-' || isDigit(ch):
return consumeIntegerOrDecimal(s)
case ch == '"':
return consumeString(s)
case ch == '*' || isAlpha(ch):
return consumeToken(s)
case ch == ':':
return consumeByteSequence(s)
case ch == '?':
return consumeBoolean(s)
case ch == '@':
return consumeDate(s)
case ch == '%':
return consumeDisplayString(s)
default:
return "", s, false
}
}
+1 -1
View File
@@ -2,6 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && go1.12
//go:build darwin
// This exists solely so we can linkname in symbols from syscall.
+5 -2
View File
@@ -6,7 +6,10 @@
package socket
import "unsafe"
import (
"encoding/binary"
"unsafe"
)
func (h *msghdr) pack(vs []iovec, bs [][]byte, oob []byte, sa []byte) {
for i := range vs {
@@ -31,5 +34,5 @@ func (h *msghdr) controllen() int {
}
func (h *msghdr) flags() int {
return int(NativeEndian.Uint32(h.Pad_cgo_2[:]))
return int(binary.NativeEndian.Uint32(h.Pad_cgo_2[:]))
}
+3 -2
View File
@@ -7,6 +7,7 @@
package socket // import "golang.org/x/net/internal/socket"
import (
"encoding/binary"
"errors"
"net"
"runtime"
@@ -58,7 +59,7 @@ func (o *Option) GetInt(c *Conn) (int, error) {
if o.Len == 1 {
return int(b[0]), nil
}
return int(NativeEndian.Uint32(b[:4])), nil
return int(binary.NativeEndian.Uint32(b[:4])), nil
}
// Set writes the option and value to the kernel.
@@ -84,7 +85,7 @@ func (o *Option) SetInt(c *Conn, v int) error {
b = []byte{byte(v)}
} else {
var bb [4]byte
NativeEndian.PutUint32(bb[:o.Len], uint32(v))
binary.NativeEndian.PutUint32(bb[:o.Len], uint32(v))
b = bb[:4]
}
return o.set(c, b)
-23
View File
@@ -1,23 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package socket
import (
"encoding/binary"
"unsafe"
)
// NativeEndian is the machine native endian implementation of ByteOrder.
var NativeEndian binary.ByteOrder
func init() {
i := uint32(1)
b := (*[4]byte)(unsafe.Pointer(&i))
if b[0] == 1 {
NativeEndian = binary.LittleEndian
} else {
NativeEndian = binary.BigEndian
}
}
+5 -5
View File
@@ -36,7 +36,7 @@ func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
if ip4 := ip.To4(); ip4 != nil {
switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET))
binary.NativeEndian.PutUint16(b[:2], uint16(sysAF_INET))
default:
b[0] = sizeofSockaddrInet4
b[1] = sysAF_INET
@@ -48,7 +48,7 @@ func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6))
binary.NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6))
default:
b[0] = sizeofSockaddrInet6
b[1] = sysAF_INET6
@@ -56,7 +56,7 @@ func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
binary.BigEndian.PutUint16(b[2:4], uint16(port))
copy(b[8:24], ip6)
if zone != "" {
NativeEndian.PutUint32(b[24:28], uint32(zoneCache.index(zone)))
binary.NativeEndian.PutUint32(b[24:28], uint32(zoneCache.index(zone)))
}
return sizeofSockaddrInet6
}
@@ -70,7 +70,7 @@ func parseInetAddr(b []byte, network string) (net.Addr, error) {
var af int
switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows":
af = int(NativeEndian.Uint16(b[:2]))
af = int(binary.NativeEndian.Uint16(b[:2]))
default:
af = int(b[1])
}
@@ -89,7 +89,7 @@ func parseInetAddr(b []byte, network string) (net.Addr, error) {
}
ip = make(net.IP, net.IPv6len)
copy(ip, b[8:24])
if id := int(NativeEndian.Uint32(b[24:28])); id > 0 {
if id := int(binary.NativeEndian.Uint32(b[24:28])); id > 0 {
zone = zoneCache.name(id)
}
}
+1 -1
View File
@@ -297,7 +297,7 @@ func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter,
b = append(b, up.Username...)
b = append(b, byte(len(up.Password)))
b = append(b, up.Password...)
// TODO(mikio): handle IO deadlines and cancelation if
// TODO(mikio): handle IO deadlines and cancellation if
// necessary
if _, err := rw.Write(b); err != nil {
return err
+8 -10
View File
@@ -9,8 +9,6 @@ import (
"fmt"
"net"
"runtime"
"golang.org/x/net/internal/socket"
)
const (
@@ -68,12 +66,12 @@ func (h *Header) Marshal() ([]byte, error) {
flagsAndFragOff := (h.FragOff & 0x1fff) | int(h.Flags<<13)
switch runtime.GOOS {
case "darwin", "ios", "dragonfly", "netbsd":
socket.NativeEndian.PutUint16(b[2:4], uint16(h.TotalLen))
socket.NativeEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
binary.NativeEndian.PutUint16(b[2:4], uint16(h.TotalLen))
binary.NativeEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
case "freebsd":
if freebsdVersion < 1100000 {
socket.NativeEndian.PutUint16(b[2:4], uint16(h.TotalLen))
socket.NativeEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
binary.NativeEndian.PutUint16(b[2:4], uint16(h.TotalLen))
binary.NativeEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
} else {
binary.BigEndian.PutUint16(b[2:4], uint16(h.TotalLen))
binary.BigEndian.PutUint16(b[6:8], uint16(flagsAndFragOff))
@@ -127,15 +125,15 @@ func (h *Header) Parse(b []byte) error {
h.Dst = net.IPv4(b[16], b[17], b[18], b[19])
switch runtime.GOOS {
case "darwin", "ios", "dragonfly", "netbsd":
h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4])) + hdrlen
h.FragOff = int(socket.NativeEndian.Uint16(b[6:8]))
h.TotalLen = int(binary.NativeEndian.Uint16(b[2:4])) + hdrlen
h.FragOff = int(binary.NativeEndian.Uint16(b[6:8]))
case "freebsd":
if freebsdVersion < 1100000 {
h.TotalLen = int(socket.NativeEndian.Uint16(b[2:4]))
h.TotalLen = int(binary.NativeEndian.Uint16(b[2:4]))
if freebsdVersion < 1000000 {
h.TotalLen += hdrlen
}
h.FragOff = int(socket.NativeEndian.Uint16(b[6:8]))
h.FragOff = int(binary.NativeEndian.Uint16(b[6:8]))
} else {
h.TotalLen = int(binary.BigEndian.Uint16(b[2:4]))
h.FragOff = int(binary.BigEndian.Uint16(b[6:8]))
+2 -1
View File
@@ -7,6 +7,7 @@
package ipv6
import (
"encoding/binary"
"unsafe"
"golang.org/x/net/internal/iana"
@@ -19,7 +20,7 @@ func marshal2292HopLimit(b []byte, cm *ControlMessage) []byte {
m := socket.ControlMessage(b)
m.MarshalHeader(iana.ProtocolIPv6, unix.IPV6_2292HOPLIMIT, 4)
if cm != nil {
socket.NativeEndian.PutUint32(m.Data(4), uint32(cm.HopLimit))
binary.NativeEndian.PutUint32(m.Data(4), uint32(cm.HopLimit))
}
return m.Next(4)
}
+5 -4
View File
@@ -7,6 +7,7 @@
package ipv6
import (
"encoding/binary"
"net"
"unsafe"
@@ -20,26 +21,26 @@ func marshalTrafficClass(b []byte, cm *ControlMessage) []byte {
m := socket.ControlMessage(b)
m.MarshalHeader(iana.ProtocolIPv6, unix.IPV6_TCLASS, 4)
if cm != nil {
socket.NativeEndian.PutUint32(m.Data(4), uint32(cm.TrafficClass))
binary.NativeEndian.PutUint32(m.Data(4), uint32(cm.TrafficClass))
}
return m.Next(4)
}
func parseTrafficClass(cm *ControlMessage, b []byte) {
cm.TrafficClass = int(socket.NativeEndian.Uint32(b[:4]))
cm.TrafficClass = int(binary.NativeEndian.Uint32(b[:4]))
}
func marshalHopLimit(b []byte, cm *ControlMessage) []byte {
m := socket.ControlMessage(b)
m.MarshalHeader(iana.ProtocolIPv6, unix.IPV6_HOPLIMIT, 4)
if cm != nil {
socket.NativeEndian.PutUint32(m.Data(4), uint32(cm.HopLimit))
binary.NativeEndian.PutUint32(m.Data(4), uint32(cm.HopLimit))
}
return m.Next(4)
}
func parseHopLimit(cm *ControlMessage, b []byte) {
cm.HopLimit = int(socket.NativeEndian.Uint32(b[:4]))
cm.HopLimit = int(binary.NativeEndian.Uint32(b[:4]))
}
func marshalPacketInfo(b []byte, cm *ControlMessage) []byte {
+3 -3
View File
@@ -142,7 +142,7 @@ func testPingPong(t *testing.T, c1, c2 net.Conn) {
}
// testRacyRead tests that it is safe to mutate the input Read buffer
// immediately after cancelation has occurred.
// immediately after cancellation has occurred.
func testRacyRead(t *testing.T, c1, c2 net.Conn) {
go chunkedCopy(c2, rand.New(rand.NewSource(0)))
@@ -170,7 +170,7 @@ func testRacyRead(t *testing.T, c1, c2 net.Conn) {
}
// testRacyWrite tests that it is safe to mutate the input Write buffer
// immediately after cancelation has occurred.
// immediately after cancellation has occurred.
func testRacyWrite(t *testing.T, c1, c2 net.Conn) {
go chunkedCopy(io.Discard, c2)
@@ -318,7 +318,7 @@ func testCloseTimeout(t *testing.T, c1, c2 net.Conn) {
defer wg.Wait()
wg.Add(3)
// Test for cancelation upon connection closure.
// Test for cancellation upon connection closure.
c1.SetDeadline(neverTimeout)
go func() {
defer wg.Done()
+2 -2
View File
@@ -58,8 +58,8 @@ func RenderEvents(w http.ResponseWriter, req *http.Request, sensitive bool) {
Buckets: buckets,
}
data.Families = make([]string, 0, len(families))
famMu.RLock()
data.Families = make([]string, 0, len(families))
for name := range families {
data.Families = append(data.Families, name)
}
@@ -508,7 +508,7 @@ const eventsHTML = `
<tr class="first">
<td class="when">{{$el.When}}</td>
<td class="elapsed">{{$el.ElapsedTime}}</td>
<td>{{$el.Title}}
<td>{{$el.Title}}</td>
</tr>
{{if $.Expanded}}
<tr>
+1
View File
@@ -440,6 +440,7 @@ func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (er
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != 101 {
return ErrBadStatus
}
+35 -92
View File
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
// Package errgroup provides synchronization, error propagation, and Context
// cancelation for groups of goroutines working on subtasks of a common task.
// cancellation for groups of goroutines working on subtasks of a common task.
//
// [errgroup.Group] is related to [sync.WaitGroup] but adds handling of tasks
// returning errors.
@@ -12,8 +12,6 @@ package errgroup
import (
"context"
"fmt"
"runtime"
"runtime/debug"
"sync"
)
@@ -33,10 +31,6 @@ type Group struct {
errOnce sync.Once
err error
mu sync.Mutex
panicValue any // = PanicError | PanicValue; non-nil if some Group.Go goroutine panicked.
abnormal bool // some Group.Go goroutine terminated abnormally (panic or goexit).
}
func (g *Group) done() {
@@ -56,80 +50,47 @@ func WithContext(ctx context.Context) (*Group, context.Context) {
return &Group{cancel: cancel}, ctx
}
// Wait blocks until all function calls from the Go method have returned
// normally, then returns the first non-nil error (if any) from them.
//
// If any of the calls panics, Wait panics with a [PanicValue];
// and if any of them calls [runtime.Goexit], Wait calls runtime.Goexit.
// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel(g.err)
}
if g.panicValue != nil {
panic(g.panicValue)
}
if g.abnormal {
runtime.Goexit()
}
return g.err
}
// Go calls the given function in a new goroutine.
// The first call to Go must happen before a Wait.
// It blocks until the new goroutine can be added without the number of
// active goroutines in the group exceeding the configured limit.
//
// The first call to Go must happen before a Wait.
// It blocks until the new goroutine can be added without the number of
// goroutines in the group exceeding the configured limit.
//
// The first goroutine in the group that returns a non-nil error, panics, or
// invokes [runtime.Goexit] will cancel the associated Context, if any.
// The first goroutine in the group that returns a non-nil error will
// cancel the associated Context, if any. The error will be returned
// by Wait.
func (g *Group) Go(f func() error) {
if g.sem != nil {
g.sem <- token{}
}
g.add(f)
}
func (g *Group) add(f func() error) {
g.wg.Add(1)
go func() {
defer g.done()
normalReturn := false
defer func() {
if normalReturn {
return
}
v := recover()
g.mu.Lock()
defer g.mu.Unlock()
if !g.abnormal {
if g.cancel != nil {
g.cancel(g.err)
}
g.abnormal = true
}
if v != nil && g.panicValue == nil {
switch v := v.(type) {
case error:
g.panicValue = PanicError{
Recovered: v,
Stack: debug.Stack(),
}
default:
g.panicValue = PanicValue{
Recovered: v,
Stack: debug.Stack(),
}
}
}
}()
err := f()
normalReturn = true
if err != nil {
// It is tempting to propagate panics from f()
// up to the goroutine that calls Wait, but
// it creates more problems than it solves:
// - it delays panics arbitrarily,
// making bugs harder to detect;
// - it turns f's panic stack into a mere value,
// hiding it from crash-monitoring tools;
// - it risks deadlocks that hide the panic entirely,
// if f's panic leaves the program in a state
// that prevents the Wait call from being reached.
// See #53757, #74275, #74304, #74306.
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
@@ -154,7 +115,19 @@ func (g *Group) TryGo(f func() error) bool {
}
}
g.add(f)
g.wg.Add(1)
go func() {
defer g.done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel(g.err)
}
})
}
}()
return true
}
@@ -171,38 +144,8 @@ func (g *Group) SetLimit(n int) {
g.sem = nil
return
}
if len(g.sem) != 0 {
panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
if active := len(g.sem); active != 0 {
panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", active))
}
g.sem = make(chan token, n)
}
// PanicError wraps an error recovered from an unhandled panic
// when calling a function passed to Go or TryGo.
type PanicError struct {
Recovered error
Stack []byte // result of call to [debug.Stack]
}
func (p PanicError) Error() string {
// A Go Error method conventionally does not include a stack dump, so omit it
// here. (Callers who care can extract it from the Stack field.)
return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered)
}
func (p PanicError) Unwrap() error { return p.Recovered }
// PanicValue wraps a value that does not implement the error interface,
// recovered from an unhandled panic when calling a function passed to Go or
// TryGo.
type PanicValue struct {
Recovered any
Stack []byte // result of call to [debug.Stack]
}
func (p PanicValue) String() string {
if len(p.Stack) > 0 {
return fmt.Sprintf("recovered from errgroup.Group: %v\n%s", p.Recovered, p.Stack)
}
return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered)
}
+12
View File
@@ -0,0 +1,12 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && arm64 && gc
#include "textflag.h"
TEXT libc_sysctlbyname_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_sysctlbyname(SB)
GLOBL ·libc_sysctlbyname_trampoline_addr(SB), RODATA, $8
DATA ·libc_sysctlbyname_trampoline_addr(SB)/8, $libc_sysctlbyname_trampoline<>(SB)
+3 -6
View File
@@ -44,14 +44,11 @@ func initOptions() {
}
func archInit() {
switch runtime.GOOS {
case "freebsd":
if runtime.GOOS == "freebsd" {
readARM64Registers()
case "linux", "netbsd", "openbsd", "windows":
} else {
// Most platforms don't seem to allow directly reading these registers.
doinit()
default:
// Many platforms don't seem to allow reading these registers.
setMinimalFeatures()
}
}
+67
View File
@@ -0,0 +1,67 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && arm64 && gc
package cpu
func doinit() {
setMinimalFeatures()
// The feature flags are explained in [Instruction Set Detection].
// There are some differences between MacOS versions:
//
// MacOS 11 and 12 do not have "hw.optional" sysctl values for some of the features.
//
// MacOS 13 changed some of the naming conventions to align with ARM Architecture Reference Manual.
// For example "hw.optional.armv8_2_sha512" became "hw.optional.arm.FEAT_SHA512".
// It currently checks both to stay compatible with MacOS 11 and 12.
// The old names also work with MacOS 13, however it's not clear whether
// they will continue working with future OS releases.
//
// Once MacOS 12 is no longer supported the old names can be removed.
//
// [Instruction Set Detection]: https://developer.apple.com/documentation/kernel/1387446-sysctlbyname/determining_instruction_set_characteristics
// Encryption, hashing and checksum capabilities
// For the following flags there are no MacOS 11 sysctl flags.
ARM64.HasAES = true || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_AES\x00"))
ARM64.HasPMULL = true || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_PMULL\x00"))
ARM64.HasSHA1 = true || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_SHA1\x00"))
ARM64.HasSHA2 = true || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_SHA256\x00"))
ARM64.HasSHA3 = darwinSysctlEnabled([]byte("hw.optional.armv8_2_sha3\x00")) || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_SHA3\x00"))
ARM64.HasSHA512 = darwinSysctlEnabled([]byte("hw.optional.armv8_2_sha512\x00")) || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_SHA512\x00"))
ARM64.HasCRC32 = darwinSysctlEnabled([]byte("hw.optional.armv8_crc32\x00"))
// Atomic and memory ordering
ARM64.HasATOMICS = darwinSysctlEnabled([]byte("hw.optional.armv8_1_atomics\x00")) || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_LSE\x00"))
ARM64.HasLRCPC = darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_LRCPC\x00"))
// SIMD and floating point capabilities
ARM64.HasFPHP = darwinSysctlEnabled([]byte("hw.optional.neon_fp16\x00")) || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_FP16\x00"))
ARM64.HasASIMDHP = darwinSysctlEnabled([]byte("hw.optional.neon_hpfp\x00")) || darwinSysctlEnabled([]byte("hw.optional.AdvSIMD_HPFPCvt\x00"))
ARM64.HasASIMDRDM = darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_RDM\x00"))
ARM64.HasASIMDDP = darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_DotProd\x00"))
ARM64.HasASIMDFHM = darwinSysctlEnabled([]byte("hw.optional.armv8_2_fhm\x00")) || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_FHM\x00"))
ARM64.HasI8MM = darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_I8MM\x00"))
ARM64.HasJSCVT = darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_JSCVT\x00"))
ARM64.HasFCMA = darwinSysctlEnabled([]byte("hw.optional.armv8_3_compnum\x00")) || darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_FCMA\x00"))
// Miscellaneous
ARM64.HasDCPOP = darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_DPB\x00"))
ARM64.HasEVTSTRM = darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_ECV\x00"))
ARM64.HasDIT = darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_DIT\x00"))
// Not supported, but added for completeness
ARM64.HasCPUID = false
ARM64.HasSM3 = false // darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_SM3\x00"))
ARM64.HasSM4 = false // darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_SM4\x00"))
ARM64.HasSVE = false // darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_SVE\x00"))
ARM64.HasSVE2 = false // darwinSysctlEnabled([]byte("hw.optional.arm.FEAT_SVE2\x00"))
}
+31
View File
@@ -0,0 +1,31 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && arm64 && !gc
package cpu
import "runtime"
func doinit() {
setMinimalFeatures()
ARM64.HasASIMD = true
ARM64.HasFP = true
// Go already assumes these to be available because they were on the M1
// and these are supported on all Apple arm64 chips.
ARM64.HasAES = true
ARM64.HasPMULL = true
ARM64.HasSHA1 = true
ARM64.HasSHA2 = true
if runtime.GOOS != "ios" {
// Apple A7 processors do not support these, however
// M-series SoCs are at least armv8.4-a
ARM64.HasCRC32 = true // armv8.1
ARM64.HasATOMICS = true // armv8.2
ARM64.HasJSCVT = true // armv8.3, if HasFP
}
}
+1
View File
@@ -9,3 +9,4 @@ package cpu
func getisar0() uint64 { return 0 }
func getisar1() uint64 { return 0 }
func getpfr0() uint64 { return 0 }
func getzfr0() uint64 { return 0 }
+4 -2
View File
@@ -2,8 +2,10 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !linux && !netbsd && !openbsd && !windows && arm64
//go:build !darwin && !linux && !netbsd && !openbsd && arm64
package cpu
func doinit() {}
func doinit() {
setMinimalFeatures()
}
-42
View File
@@ -1,42 +0,0 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
import (
"golang.org/x/sys/windows"
)
func doinit() {
// set HasASIMD and HasFP to true as per
// https://learn.microsoft.com/en-us/cpp/build/arm64-windows-abi-conventions?view=msvc-170#base-requirements
//
// The ARM64 version of Windows always presupposes that it's running on an ARMv8 or later architecture.
// Both floating-point and NEON support are presumed to be present in hardware.
//
ARM64.HasASIMD = true
ARM64.HasFP = true
if windows.IsProcessorFeaturePresent(windows.PF_ARM_V8_CRYPTO_INSTRUCTIONS_AVAILABLE) {
ARM64.HasAES = true
ARM64.HasPMULL = true
ARM64.HasSHA1 = true
ARM64.HasSHA2 = true
}
ARM64.HasSHA3 = windows.IsProcessorFeaturePresent(windows.PF_ARM_SHA3_INSTRUCTIONS_AVAILABLE)
ARM64.HasCRC32 = windows.IsProcessorFeaturePresent(windows.PF_ARM_V8_CRC32_INSTRUCTIONS_AVAILABLE)
ARM64.HasSHA512 = windows.IsProcessorFeaturePresent(windows.PF_ARM_SHA512_INSTRUCTIONS_AVAILABLE)
ARM64.HasATOMICS = windows.IsProcessorFeaturePresent(windows.PF_ARM_V81_ATOMIC_INSTRUCTIONS_AVAILABLE)
if windows.IsProcessorFeaturePresent(windows.PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) {
ARM64.HasASIMDDP = true
ARM64.HasASIMDRDM = true
}
if windows.IsProcessorFeaturePresent(windows.PF_ARM_V83_LRCPC_INSTRUCTIONS_AVAILABLE) {
ARM64.HasLRCPC = true
ARM64.HasSM3 = true
}
ARM64.HasSVE = windows.IsProcessorFeaturePresent(windows.PF_ARM_SVE_INSTRUCTIONS_AVAILABLE)
ARM64.HasSVE2 = windows.IsProcessorFeaturePresent(windows.PF_ARM_SVE2_INSTRUCTIONS_AVAILABLE)
ARM64.HasJSCVT = windows.IsProcessorFeaturePresent(windows.PF_ARM_V83_JSCVT_INSTRUCTIONS_AVAILABLE)
}
+54
View File
@@ -0,0 +1,54 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Minimal copy from internal/cpu and runtime to make sysctl calls.
//go:build darwin && arm64 && gc
package cpu
import (
"syscall"
"unsafe"
)
type Errno = syscall.Errno
// adapted from internal/cpu/cpu_arm64_darwin.go
func darwinSysctlEnabled(name []byte) bool {
out := int32(0)
nout := unsafe.Sizeof(out)
if ret := sysctlbyname(&name[0], (*byte)(unsafe.Pointer(&out)), &nout, nil, 0); ret != nil {
return false
}
return out > 0
}
//go:cgo_import_dynamic libc_sysctl sysctl "/usr/lib/libSystem.B.dylib"
var libc_sysctlbyname_trampoline_addr uintptr
// adapted from runtime/sys_darwin.go in the pattern of sysctl() above, as defined in x/sys/unix
func sysctlbyname(name *byte, old *byte, oldlen *uintptr, new *byte, newlen uintptr) error {
if _, _, err := syscall_syscall6(
libc_sysctlbyname_trampoline_addr,
uintptr(unsafe.Pointer(name)),
uintptr(unsafe.Pointer(old)),
uintptr(unsafe.Pointer(oldlen)),
uintptr(unsafe.Pointer(new)),
uintptr(newlen),
0,
); err != 0 {
return err
}
return nil
}
//go:cgo_import_dynamic libc_sysctlbyname sysctlbyname "/usr/lib/libSystem.B.dylib"
// Implemented in the runtime package (runtime/sys_darwin.go)
func syscall_syscall6(fn, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err Errno)
//go:linkname syscall_syscall6 syscall.syscall6
+1 -7
View File
@@ -19,13 +19,7 @@ import (
// A Note is a string describing a process note.
// It implements the os.Signal interface.
type Note string
func (n Note) Signal() {}
func (n Note) String() string {
return string(n)
}
type Note = syscall.Note
var (
Stdin = 0
+125 -104
View File
@@ -593,110 +593,115 @@ const (
)
const (
NDA_UNSPEC = 0x0
NDA_DST = 0x1
NDA_LLADDR = 0x2
NDA_CACHEINFO = 0x3
NDA_PROBES = 0x4
NDA_VLAN = 0x5
NDA_PORT = 0x6
NDA_VNI = 0x7
NDA_IFINDEX = 0x8
NDA_MASTER = 0x9
NDA_LINK_NETNSID = 0xa
NDA_SRC_VNI = 0xb
NTF_USE = 0x1
NTF_SELF = 0x2
NTF_MASTER = 0x4
NTF_PROXY = 0x8
NTF_EXT_LEARNED = 0x10
NTF_OFFLOADED = 0x20
NTF_ROUTER = 0x80
NUD_INCOMPLETE = 0x1
NUD_REACHABLE = 0x2
NUD_STALE = 0x4
NUD_DELAY = 0x8
NUD_PROBE = 0x10
NUD_FAILED = 0x20
NUD_NOARP = 0x40
NUD_PERMANENT = 0x80
NUD_NONE = 0x0
IFA_UNSPEC = 0x0
IFA_ADDRESS = 0x1
IFA_LOCAL = 0x2
IFA_LABEL = 0x3
IFA_BROADCAST = 0x4
IFA_ANYCAST = 0x5
IFA_CACHEINFO = 0x6
IFA_MULTICAST = 0x7
IFA_FLAGS = 0x8
IFA_RT_PRIORITY = 0x9
IFA_TARGET_NETNSID = 0xa
IFAL_LABEL = 0x2
IFAL_ADDRESS = 0x1
RT_SCOPE_UNIVERSE = 0x0
RT_SCOPE_SITE = 0xc8
RT_SCOPE_LINK = 0xfd
RT_SCOPE_HOST = 0xfe
RT_SCOPE_NOWHERE = 0xff
RT_TABLE_UNSPEC = 0x0
RT_TABLE_COMPAT = 0xfc
RT_TABLE_DEFAULT = 0xfd
RT_TABLE_MAIN = 0xfe
RT_TABLE_LOCAL = 0xff
RT_TABLE_MAX = 0xffffffff
RTA_UNSPEC = 0x0
RTA_DST = 0x1
RTA_SRC = 0x2
RTA_IIF = 0x3
RTA_OIF = 0x4
RTA_GATEWAY = 0x5
RTA_PRIORITY = 0x6
RTA_PREFSRC = 0x7
RTA_METRICS = 0x8
RTA_MULTIPATH = 0x9
RTA_FLOW = 0xb
RTA_CACHEINFO = 0xc
RTA_TABLE = 0xf
RTA_MARK = 0x10
RTA_MFC_STATS = 0x11
RTA_VIA = 0x12
RTA_NEWDST = 0x13
RTA_PREF = 0x14
RTA_ENCAP_TYPE = 0x15
RTA_ENCAP = 0x16
RTA_EXPIRES = 0x17
RTA_PAD = 0x18
RTA_UID = 0x19
RTA_TTL_PROPAGATE = 0x1a
RTA_IP_PROTO = 0x1b
RTA_SPORT = 0x1c
RTA_DPORT = 0x1d
RTN_UNSPEC = 0x0
RTN_UNICAST = 0x1
RTN_LOCAL = 0x2
RTN_BROADCAST = 0x3
RTN_ANYCAST = 0x4
RTN_MULTICAST = 0x5
RTN_BLACKHOLE = 0x6
RTN_UNREACHABLE = 0x7
RTN_PROHIBIT = 0x8
RTN_THROW = 0x9
RTN_NAT = 0xa
RTN_XRESOLVE = 0xb
SizeofNlMsghdr = 0x10
SizeofNlMsgerr = 0x14
SizeofRtGenmsg = 0x1
SizeofNlAttr = 0x4
SizeofRtAttr = 0x4
SizeofIfInfomsg = 0x10
SizeofIfAddrmsg = 0x8
SizeofIfAddrlblmsg = 0xc
SizeofIfaCacheinfo = 0x10
SizeofRtMsg = 0xc
SizeofRtNexthop = 0x8
SizeofNdUseroptmsg = 0x10
SizeofNdMsg = 0xc
NDA_UNSPEC = 0x0
NDA_DST = 0x1
NDA_LLADDR = 0x2
NDA_CACHEINFO = 0x3
NDA_PROBES = 0x4
NDA_VLAN = 0x5
NDA_PORT = 0x6
NDA_VNI = 0x7
NDA_IFINDEX = 0x8
NDA_MASTER = 0x9
NDA_LINK_NETNSID = 0xa
NDA_SRC_VNI = 0xb
NTF_USE = 0x1
NTF_SELF = 0x2
NTF_MASTER = 0x4
NTF_PROXY = 0x8
NTF_EXT_LEARNED = 0x10
NTF_OFFLOADED = 0x20
NTF_ROUTER = 0x80
NUD_INCOMPLETE = 0x1
NUD_REACHABLE = 0x2
NUD_STALE = 0x4
NUD_DELAY = 0x8
NUD_PROBE = 0x10
NUD_FAILED = 0x20
NUD_NOARP = 0x40
NUD_PERMANENT = 0x80
NUD_NONE = 0x0
IFA_UNSPEC = 0x0
IFA_ADDRESS = 0x1
IFA_LOCAL = 0x2
IFA_LABEL = 0x3
IFA_BROADCAST = 0x4
IFA_ANYCAST = 0x5
IFA_CACHEINFO = 0x6
IFA_MULTICAST = 0x7
IFA_FLAGS = 0x8
IFA_RT_PRIORITY = 0x9
IFA_TARGET_NETNSID = 0xa
IFAL_LABEL = 0x2
IFAL_ADDRESS = 0x1
RT_SCOPE_UNIVERSE = 0x0
RT_SCOPE_SITE = 0xc8
RT_SCOPE_LINK = 0xfd
RT_SCOPE_HOST = 0xfe
RT_SCOPE_NOWHERE = 0xff
RT_TABLE_UNSPEC = 0x0
RT_TABLE_COMPAT = 0xfc
RT_TABLE_DEFAULT = 0xfd
RT_TABLE_MAIN = 0xfe
RT_TABLE_LOCAL = 0xff
RT_TABLE_MAX = 0xffffffff
RTA_UNSPEC = 0x0
RTA_DST = 0x1
RTA_SRC = 0x2
RTA_IIF = 0x3
RTA_OIF = 0x4
RTA_GATEWAY = 0x5
RTA_PRIORITY = 0x6
RTA_PREFSRC = 0x7
RTA_METRICS = 0x8
RTA_MULTIPATH = 0x9
RTA_FLOW = 0xb
RTA_CACHEINFO = 0xc
RTA_TABLE = 0xf
RTA_MARK = 0x10
RTA_MFC_STATS = 0x11
RTA_VIA = 0x12
RTA_NEWDST = 0x13
RTA_PREF = 0x14
RTA_ENCAP_TYPE = 0x15
RTA_ENCAP = 0x16
RTA_EXPIRES = 0x17
RTA_PAD = 0x18
RTA_UID = 0x19
RTA_TTL_PROPAGATE = 0x1a
RTA_IP_PROTO = 0x1b
RTA_SPORT = 0x1c
RTA_DPORT = 0x1d
RTN_UNSPEC = 0x0
RTN_UNICAST = 0x1
RTN_LOCAL = 0x2
RTN_BROADCAST = 0x3
RTN_ANYCAST = 0x4
RTN_MULTICAST = 0x5
RTN_BLACKHOLE = 0x6
RTN_UNREACHABLE = 0x7
RTN_PROHIBIT = 0x8
RTN_THROW = 0x9
RTN_NAT = 0xa
RTN_XRESOLVE = 0xb
PREFIX_UNSPEC = 0x0
PREFIX_ADDRESS = 0x1
PREFIX_CACHEINFO = 0x2
SizeofNlMsghdr = 0x10
SizeofNlMsgerr = 0x14
SizeofRtGenmsg = 0x1
SizeofNlAttr = 0x4
SizeofRtAttr = 0x4
SizeofIfInfomsg = 0x10
SizeofPrefixmsg = 0xc
SizeofPrefixCacheinfo = 0x8
SizeofIfAddrmsg = 0x8
SizeofIfAddrlblmsg = 0xc
SizeofIfaCacheinfo = 0x10
SizeofRtMsg = 0xc
SizeofRtNexthop = 0x8
SizeofNdUseroptmsg = 0x10
SizeofNdMsg = 0xc
)
type NlMsghdr struct {
@@ -735,6 +740,22 @@ type IfInfomsg struct {
Change uint32
}
type Prefixmsg struct {
Family uint8
Pad1 uint8
Pad2 uint16
Ifindex int32
Type uint8
Len uint8
Flags uint8
Pad3 uint8
}
type PrefixCacheinfo struct {
Preferred_time uint32
Valid_time uint32
}
type IfAddrmsg struct {
Family uint8
Prefixlen uint8
+1
View File
@@ -8,5 +8,6 @@ package windows
import "syscall"
type Signal = syscall.Signal
type Errno = syscall.Errno
type SysProcAttr = syscall.SysProcAttr
+1 -36
View File
@@ -163,42 +163,7 @@ func (p *Proc) Addr() uintptr {
// (according to the semantics of the specific function being called) before consulting
// the error. The error will be guaranteed to contain windows.Errno.
func (p *Proc) Call(a ...uintptr) (r1, r2 uintptr, lastErr error) {
switch len(a) {
case 0:
return syscall.Syscall(p.Addr(), uintptr(len(a)), 0, 0, 0)
case 1:
return syscall.Syscall(p.Addr(), uintptr(len(a)), a[0], 0, 0)
case 2:
return syscall.Syscall(p.Addr(), uintptr(len(a)), a[0], a[1], 0)
case 3:
return syscall.Syscall(p.Addr(), uintptr(len(a)), a[0], a[1], a[2])
case 4:
return syscall.Syscall6(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], 0, 0)
case 5:
return syscall.Syscall6(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], 0)
case 6:
return syscall.Syscall6(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5])
case 7:
return syscall.Syscall9(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], 0, 0)
case 8:
return syscall.Syscall9(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], 0)
case 9:
return syscall.Syscall9(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8])
case 10:
return syscall.Syscall12(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], 0, 0)
case 11:
return syscall.Syscall12(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], 0)
case 12:
return syscall.Syscall12(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11])
case 13:
return syscall.Syscall15(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], 0, 0)
case 14:
return syscall.Syscall15(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], 0)
case 15:
return syscall.Syscall15(p.Addr(), uintptr(len(a)), a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14])
default:
panic("Call " + p.Name + " with too many arguments " + itoa(len(a)) + ".")
}
return syscall.SyscallN(p.Addr(), a...)
}
// A LazyDLL implements access to a single DLL.
+14 -1
View File
@@ -198,7 +198,20 @@ type KeyInfo struct {
// ModTime returns the key's last write time.
func (ki *KeyInfo) ModTime() time.Time {
return time.Unix(0, ki.lastWriteTime.Nanoseconds())
lastHigh, lastLow := ki.lastWriteTime.HighDateTime, ki.lastWriteTime.LowDateTime
// 100-nanosecond intervals since January 1, 1601
hsec := uint64(lastHigh)<<32 + uint64(lastLow)
// Convert _before_ gauging; the nanosecond difference between Epoch (00:00:00
// UTC, January 1, 1970) and Filetime's zero offset (January 1, 1601) is out
// of bounds for int64: -11644473600*1e7*1e2 < math.MinInt64
sec := int64(hsec/1e7) - 11644473600
nsec := int64(hsec%1e7) * 100
return time.Unix(sec, nsec)
}
// modTimeZero reports whether the key's last write time is zero.
func (ki *KeyInfo) modTimeZero() bool {
return ki.lastWriteTime.LowDateTime == 0 && ki.lastWriteTime.HighDateTime == 0
}
// Stat retrieves information about the open key k.
+5 -1
View File
@@ -1438,13 +1438,17 @@ func GetSecurityInfo(handle Handle, objectType SE_OBJECT_TYPE, securityInformati
}
// GetNamedSecurityInfo queries the security information for a given named object and returns the self-relative security
// descriptor result on the Go heap.
// descriptor result on the Go heap. The security descriptor might be nil, even when err is nil, if the object exists
// but has no security descriptor.
func GetNamedSecurityInfo(objectName string, objectType SE_OBJECT_TYPE, securityInformation SECURITY_INFORMATION) (sd *SECURITY_DESCRIPTOR, err error) {
var winHeapSD *SECURITY_DESCRIPTOR
err = getNamedSecurityInfo(objectName, objectType, securityInformation, nil, nil, nil, nil, &winHeapSD)
if err != nil {
return
}
if winHeapSD == nil {
return nil, nil
}
defer LocalFree(Handle(unsafe.Pointer(winHeapSD)))
return winHeapSD.copySelfRelativeSecurityDescriptor(), nil
}
-14
View File
@@ -1490,20 +1490,6 @@ func Getgid() (gid int) { return -1 }
func Getegid() (egid int) { return -1 }
func Getgroups() (gids []int, err error) { return nil, syscall.EWINDOWS }
type Signal int
func (s Signal) Signal() {}
func (s Signal) String() string {
if 0 <= s && int(s) < len(signals) {
str := signals[s]
if str != "" {
return str
}
}
return "signal " + itoa(int(s))
}
func LoadCreateSymbolicLink() error {
return procCreateSymbolicLinkW.Find()
}
+3 -1
View File
@@ -20,12 +20,14 @@ func isTerminal(fd int) bool {
return err == nil
}
// This is intended to be used on a console input handle.
// See https://learn.microsoft.com/en-us/windows/console/setconsolemode
func makeRaw(fd int) (*State, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
raw := st &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
raw := st &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT)
raw |= windows.ENABLE_VIRTUAL_TERMINAL_INPUT
if err := windows.SetConsoleMode(windows.Handle(fd), raw); err != nil {
return nil, err
+37 -6
View File
@@ -146,6 +146,7 @@ const (
keyCtrlD = 4
keyCtrlU = 21
keyEnter = '\r'
keyLF = '\n'
keyEscape = 27
keyBackspace = 127
keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota
@@ -159,7 +160,9 @@ const (
keyEnd
keyDeleteWord
keyDeleteLine
keyDelete
keyClearScreen
keyTranspose
keyPasteStart
keyPasteEnd
)
@@ -193,6 +196,8 @@ func bytesToKey(b []byte, pasteActive bool) (rune, []byte) {
return keyDeleteLine, b[1:]
case 12: // ^L
return keyClearScreen, b[1:]
case 20: // ^T
return keyTranspose, b[1:]
case 23: // ^W
return keyDeleteWord, b[1:]
case 14: // ^N
@@ -227,6 +232,10 @@ func bytesToKey(b []byte, pasteActive bool) (rune, []byte) {
}
}
if !pasteActive && len(b) >= 4 && b[0] == keyEscape && b[1] == '[' && b[2] == '3' && b[3] == '~' {
return keyDelete, b[4:]
}
if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
@@ -412,7 +421,7 @@ func (t *Terminal) eraseNPreviousChars(n int) {
}
}
// countToLeftWord returns then number of characters from the cursor to the
// countToLeftWord returns the number of characters from the cursor to the
// start of the previous word.
func (t *Terminal) countToLeftWord() int {
if t.pos == 0 {
@@ -437,7 +446,7 @@ func (t *Terminal) countToLeftWord() int {
return t.pos - pos
}
// countToRightWord returns then number of characters from the cursor to the
// countToRightWord returns the number of characters from the cursor to the
// start of the next word.
func (t *Terminal) countToRightWord() int {
pos := t.pos
@@ -477,7 +486,7 @@ func visualLength(runes []rune) int {
return length
}
// histroryAt unlocks the terminal and relocks it while calling History.At.
// historyAt unlocks the terminal and relocks it while calling History.At.
func (t *Terminal) historyAt(idx int) (string, bool) {
t.lock.Unlock() // Unlock to avoid deadlock if History methods use the output writer.
defer t.lock.Lock() // panic in At (or Len) protection.
@@ -497,7 +506,7 @@ func (t *Terminal) historyAdd(entry string) {
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (t *Terminal) handleKey(key rune) (line string, ok bool) {
if t.pasteActive && key != keyEnter {
if t.pasteActive && key != keyEnter && key != keyLF {
t.addKeyToLine(key)
return
}
@@ -567,7 +576,7 @@ func (t *Terminal) handleKey(key rune) (line string, ok bool) {
t.setLine(runes, len(runes))
}
}
case keyEnter:
case keyEnter, keyLF:
t.moveCursorToPos(len(t.line))
t.queue([]rune("\r\n"))
line = string(t.line)
@@ -589,7 +598,7 @@ func (t *Terminal) handleKey(key rune) (line string, ok bool) {
}
t.line = t.line[:t.pos]
t.moveCursorToPos(t.pos)
case keyCtrlD:
case keyCtrlD, keyDelete:
// Erase the character under the current position.
// The EOF case when the line is empty is handled in
// readLine().
@@ -599,6 +608,24 @@ func (t *Terminal) handleKey(key rune) (line string, ok bool) {
}
case keyCtrlU:
t.eraseNPreviousChars(t.pos)
case keyTranspose:
// This transposes the two characters around the cursor and advances the cursor. Best-effort.
if len(t.line) < 2 || t.pos < 1 {
return
}
swap := t.pos
if swap == len(t.line) {
swap-- // special: at end of line, swap previous two chars
}
t.line[swap-1], t.line[swap] = t.line[swap], t.line[swap-1]
if t.pos < len(t.line) {
t.pos++
}
if t.echo {
t.moveCursorToPos(swap - 1)
t.writeLine(t.line[swap-1:])
t.moveCursorToPos(t.pos)
}
case keyClearScreen:
// Erases the screen and moves the cursor to the home position.
t.queue([]rune("\x1b[2J\x1b[H"))
@@ -812,6 +839,10 @@ func (t *Terminal) readLine() (line string, err error) {
if !t.pasteActive {
lineIsPasted = false
}
// If we have CR, consume LF if present (CRLF sequence) to avoid returning an extra empty line.
if key == keyEnter && len(rest) > 0 && rest[0] == keyLF {
rest = rest[1:]
}
line, lineOk = t.handleKey(key)
}
if len(rest) > 0 {
-2255
View File
File diff suppressed because it is too large Load Diff
-2316
View File
File diff suppressed because it is too large Load Diff
-2359
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -1,6 +1,6 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
//go:build go1.21
//go:build !go1.27
package cases
File diff suppressed because it is too large Load Diff
-2215
View File
File diff suppressed because it is too large Load Diff
+4
View File
@@ -334,3 +334,7 @@ func (t *Transformer) advanceString(s string) (n int, ok bool) {
}
return n, true
}
func (t *Transformer) isFinal() bool {
return t.state == ruleLTRFinal || t.state == ruleRTLFinal || t.state == ruleInitial
}
-11
View File
@@ -1,11 +0,0 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.10
package bidirule
func (t *Transformer) isFinal() bool {
return t.state == ruleLTRFinal || t.state == ruleRTLFinal || t.state == ruleInitial
}
-14
View File
@@ -1,14 +0,0 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.10
package bidirule
func (t *Transformer) isFinal() bool {
if !t.isRTL() {
return true
}
return t.state == ruleLTRFinal || t.state == ruleRTLFinal || t.state == ruleInitial
}
+2 -9
View File
@@ -427,13 +427,6 @@ type isolatingRunSequence struct {
func (i *isolatingRunSequence) Len() int { return len(i.indexes) }
func maxLevel(a, b level) level {
if a > b {
return a
}
return b
}
// Rule X10, second bullet: Determine the start-of-sequence (sos) and end-of-sequence (eos) types,
// either L or R, for each isolating run sequence.
func (p *paragraph) isolatingRunSequence(indexes []int) *isolatingRunSequence {
@@ -474,8 +467,8 @@ func (p *paragraph) isolatingRunSequence(indexes []int) *isolatingRunSequence {
indexes: indexes,
types: types,
level: level,
sos: typeForLevel(maxLevel(prevLevel, level)),
eos: typeForLevel(maxLevel(succLevel, level)),
sos: typeForLevel(max(prevLevel, level)),
eos: typeForLevel(max(succLevel, level)),
}
}

Some files were not shown because too many files have changed in this diff Show More