mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 06:30:22 +00:00
NOISSUE - Update dependencies (#1176)
* Update dependencies Signed-off-by: Dušan Borovčanin <dusan.borovcanin@mainflux.com> * Fix mProxy version Signed-off-by: dusanb <borovcanindusan1@gmail.com.com> Co-authored-by: dusanb <borovcanindusan1@gmail.com.com>
This commit is contained in:
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/mainflux/mainflux/authn/postgres"
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
const wrong string = "wrong-value"
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/mainflux/mainflux/bootstrap/postgres"
|
||||
"github.com/mainflux/mainflux/logger"
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-redis/redis"
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -6,36 +6,38 @@ require (
|
||||
github.com/BurntSushi/toml v0.3.1
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||
github.com/docker/docker v1.13.1
|
||||
github.com/dustin/go-coap v0.0.0-20170214053734-ddcc80675fa4
|
||||
github.com/dustin/go-coap v0.0.0-20190908170653-752e0f79981e
|
||||
github.com/eclipse/paho.mqtt.golang v1.2.0
|
||||
github.com/fatih/color v1.7.0
|
||||
github.com/go-kit/kit v0.9.0
|
||||
github.com/go-redis/redis v6.15.0+incompatible
|
||||
github.com/fatih/color v1.9.0
|
||||
github.com/go-kit/kit v0.10.0
|
||||
github.com/go-redis/redis v6.15.7+incompatible
|
||||
github.com/go-zoo/bone v1.3.0
|
||||
github.com/gocql/gocql v0.0.0-20181106112037-68ae1e384be4
|
||||
github.com/gofrs/uuid v3.2.0+incompatible
|
||||
github.com/gocql/gocql v0.0.0-20200511135441-57b003a04490
|
||||
github.com/gofrs/uuid v3.3.0+incompatible
|
||||
github.com/gogo/protobuf v1.3.1
|
||||
github.com/golang/protobuf v1.4.0
|
||||
github.com/golang/protobuf v1.4.1
|
||||
github.com/gopcua/opcua v0.1.6
|
||||
github.com/hokaccha/go-prettyjson v0.0.0-20180920040306-f579f869bbfe
|
||||
github.com/influxdata/influxdb v1.6.4
|
||||
github.com/hokaccha/go-prettyjson v0.0.0-20190818114111-108c894c2c0e
|
||||
github.com/influxdata/influxdb v1.8.0
|
||||
github.com/jmoiron/sqlx v1.2.1-0.20190319043955-cdf62fdf55f6
|
||||
github.com/lib/pq v1.0.0
|
||||
github.com/mainflux/mproxy v0.1.8
|
||||
github.com/lib/pq v1.5.2
|
||||
github.com/mainflux/mproxy v0.2.0
|
||||
github.com/mainflux/senml v1.0.1
|
||||
github.com/nats-io/nats.go v1.9.1
|
||||
github.com/nats-io/nats.go v1.10.0
|
||||
github.com/opentracing/opentracing-go v1.1.0
|
||||
github.com/ory/dockertest/v3 v3.6.0
|
||||
github.com/pelletier/go-toml v1.7.0
|
||||
github.com/prometheus/client_golang v1.5.1
|
||||
github.com/rubenv/sql-migrate v0.0.0-20181106121204-ba2c6a7295c5
|
||||
github.com/spf13/cobra v0.0.5
|
||||
github.com/spf13/viper v1.5.0
|
||||
github.com/prometheus/client_golang v1.6.0
|
||||
github.com/rubenv/sql-migrate v0.0.0-20200429072036-ae26b214fa43
|
||||
github.com/spf13/cobra v1.0.0
|
||||
github.com/spf13/viper v1.7.0
|
||||
github.com/stretchr/testify v1.5.1
|
||||
github.com/uber/jaeger-client-go v2.22.1+incompatible
|
||||
go.mongodb.org/mongo-driver v1.1.3
|
||||
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b
|
||||
gonum.org/v1/gonum v0.0.0-20190808205415-ced62fe5104b
|
||||
google.golang.org/grpc v1.27.1
|
||||
gopkg.in/ory/dockertest.v3 v3.3.5
|
||||
github.com/uber/jaeger-client-go v2.23.1+incompatible
|
||||
go.mongodb.org/mongo-driver v1.3.3
|
||||
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37
|
||||
golang.org/x/net v0.0.0-20200513185701-a91f0712d120
|
||||
golang.org/x/sys v0.0.0-20200513112337-417ce2331b5c // indirect
|
||||
gonum.org/v1/gonum v0.7.0
|
||||
google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884 // indirect
|
||||
google.golang.org/grpc v1.29.1
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/gocql/gocql"
|
||||
log "github.com/mainflux/mainflux/logger"
|
||||
"github.com/mainflux/mainflux/writers/cassandra"
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
var logger, _ = log.New(os.Stdout, log.Info.String())
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
influxdb "github.com/influxdata/influxdb/client/v2"
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/mainflux/mainflux/logger"
|
||||
"github.com/mainflux/mainflux/readers/postgres"
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
+2
-2
@@ -2,9 +2,9 @@
|
||||
NPROC=$(nproc)
|
||||
GO_VERSION=1.13
|
||||
PROTOC_VERSION=3.11.4
|
||||
PROTOC_GEN_VERSION=v1.3.3
|
||||
PROTOC_GEN_VERSION=v1.4.1
|
||||
PROTOC_GOFAST_VERSION=v1.3.1
|
||||
GRPC_VERSION=v1.27.1
|
||||
GRPC_VERSION=v1.29.1
|
||||
|
||||
function version_gt() { test "$(printf '%s\n' "$@" | sort -V | head -n 1)" != "$1"; }
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/mainflux/mainflux/logger"
|
||||
"github.com/mainflux/mainflux/things/postgres"
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-redis/redis"
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/mainflux/mainflux/users/postgres"
|
||||
dockertest "gopkg.in/ory/dockertest.v3"
|
||||
dockertest "github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
const wrong string = "wrong-value"
|
||||
|
||||
+137
-137
@@ -1,137 +1,137 @@
|
||||
package winio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type fileFullEaInformation struct {
|
||||
NextEntryOffset uint32
|
||||
Flags uint8
|
||||
NameLength uint8
|
||||
ValueLength uint16
|
||||
}
|
||||
|
||||
var (
|
||||
fileFullEaInformationSize = binary.Size(&fileFullEaInformation{})
|
||||
|
||||
errInvalidEaBuffer = errors.New("invalid extended attribute buffer")
|
||||
errEaNameTooLarge = errors.New("extended attribute name too large")
|
||||
errEaValueTooLarge = errors.New("extended attribute value too large")
|
||||
)
|
||||
|
||||
// ExtendedAttribute represents a single Windows EA.
|
||||
type ExtendedAttribute struct {
|
||||
Name string
|
||||
Value []byte
|
||||
Flags uint8
|
||||
}
|
||||
|
||||
func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
|
||||
var info fileFullEaInformation
|
||||
err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info)
|
||||
if err != nil {
|
||||
err = errInvalidEaBuffer
|
||||
return
|
||||
}
|
||||
|
||||
nameOffset := fileFullEaInformationSize
|
||||
nameLen := int(info.NameLength)
|
||||
valueOffset := nameOffset + int(info.NameLength) + 1
|
||||
valueLen := int(info.ValueLength)
|
||||
nextOffset := int(info.NextEntryOffset)
|
||||
if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) {
|
||||
err = errInvalidEaBuffer
|
||||
return
|
||||
}
|
||||
|
||||
ea.Name = string(b[nameOffset : nameOffset+nameLen])
|
||||
ea.Value = b[valueOffset : valueOffset+valueLen]
|
||||
ea.Flags = info.Flags
|
||||
if info.NextEntryOffset != 0 {
|
||||
nb = b[info.NextEntryOffset:]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION
|
||||
// buffer retrieved from BackupRead, ZwQueryEaFile, etc.
|
||||
func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) {
|
||||
for len(b) != 0 {
|
||||
ea, nb, err := parseEa(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
eas = append(eas, ea)
|
||||
b = nb
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error {
|
||||
if int(uint8(len(ea.Name))) != len(ea.Name) {
|
||||
return errEaNameTooLarge
|
||||
}
|
||||
if int(uint16(len(ea.Value))) != len(ea.Value) {
|
||||
return errEaValueTooLarge
|
||||
}
|
||||
entrySize := uint32(fileFullEaInformationSize + len(ea.Name) + 1 + len(ea.Value))
|
||||
withPadding := (entrySize + 3) &^ 3
|
||||
nextOffset := uint32(0)
|
||||
if !last {
|
||||
nextOffset = withPadding
|
||||
}
|
||||
info := fileFullEaInformation{
|
||||
NextEntryOffset: nextOffset,
|
||||
Flags: ea.Flags,
|
||||
NameLength: uint8(len(ea.Name)),
|
||||
ValueLength: uint16(len(ea.Value)),
|
||||
}
|
||||
|
||||
err := binary.Write(buf, binary.LittleEndian, &info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write([]byte(ea.Name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = buf.WriteByte(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write(ea.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write([]byte{0, 0, 0}[0 : withPadding-entrySize])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeExtendedAttributes encodes a list of EAs into a FILE_FULL_EA_INFORMATION
|
||||
// buffer for use with BackupWrite, ZwSetEaFile, etc.
|
||||
func EncodeExtendedAttributes(eas []ExtendedAttribute) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
for i := range eas {
|
||||
last := false
|
||||
if i == len(eas)-1 {
|
||||
last = true
|
||||
}
|
||||
|
||||
err := writeEa(&buf, &eas[i], last)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
package winio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type fileFullEaInformation struct {
|
||||
NextEntryOffset uint32
|
||||
Flags uint8
|
||||
NameLength uint8
|
||||
ValueLength uint16
|
||||
}
|
||||
|
||||
var (
|
||||
fileFullEaInformationSize = binary.Size(&fileFullEaInformation{})
|
||||
|
||||
errInvalidEaBuffer = errors.New("invalid extended attribute buffer")
|
||||
errEaNameTooLarge = errors.New("extended attribute name too large")
|
||||
errEaValueTooLarge = errors.New("extended attribute value too large")
|
||||
)
|
||||
|
||||
// ExtendedAttribute represents a single Windows EA.
|
||||
type ExtendedAttribute struct {
|
||||
Name string
|
||||
Value []byte
|
||||
Flags uint8
|
||||
}
|
||||
|
||||
func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
|
||||
var info fileFullEaInformation
|
||||
err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info)
|
||||
if err != nil {
|
||||
err = errInvalidEaBuffer
|
||||
return
|
||||
}
|
||||
|
||||
nameOffset := fileFullEaInformationSize
|
||||
nameLen := int(info.NameLength)
|
||||
valueOffset := nameOffset + int(info.NameLength) + 1
|
||||
valueLen := int(info.ValueLength)
|
||||
nextOffset := int(info.NextEntryOffset)
|
||||
if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) {
|
||||
err = errInvalidEaBuffer
|
||||
return
|
||||
}
|
||||
|
||||
ea.Name = string(b[nameOffset : nameOffset+nameLen])
|
||||
ea.Value = b[valueOffset : valueOffset+valueLen]
|
||||
ea.Flags = info.Flags
|
||||
if info.NextEntryOffset != 0 {
|
||||
nb = b[info.NextEntryOffset:]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION
|
||||
// buffer retrieved from BackupRead, ZwQueryEaFile, etc.
|
||||
func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) {
|
||||
for len(b) != 0 {
|
||||
ea, nb, err := parseEa(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
eas = append(eas, ea)
|
||||
b = nb
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error {
|
||||
if int(uint8(len(ea.Name))) != len(ea.Name) {
|
||||
return errEaNameTooLarge
|
||||
}
|
||||
if int(uint16(len(ea.Value))) != len(ea.Value) {
|
||||
return errEaValueTooLarge
|
||||
}
|
||||
entrySize := uint32(fileFullEaInformationSize + len(ea.Name) + 1 + len(ea.Value))
|
||||
withPadding := (entrySize + 3) &^ 3
|
||||
nextOffset := uint32(0)
|
||||
if !last {
|
||||
nextOffset = withPadding
|
||||
}
|
||||
info := fileFullEaInformation{
|
||||
NextEntryOffset: nextOffset,
|
||||
Flags: ea.Flags,
|
||||
NameLength: uint8(len(ea.Name)),
|
||||
ValueLength: uint16(len(ea.Value)),
|
||||
}
|
||||
|
||||
err := binary.Write(buf, binary.LittleEndian, &info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write([]byte(ea.Name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = buf.WriteByte(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write(ea.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write([]byte{0, 0, 0}[0 : withPadding-entrySize])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeExtendedAttributes encodes a list of EAs into a FILE_FULL_EA_INFORMATION
|
||||
// buffer for use with BackupWrite, ZwSetEaFile, etc.
|
||||
func EncodeExtendedAttributes(eas []ExtendedAttribute) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
for i := range eas {
|
||||
last := false
|
||||
if i == len(eas)-1 {
|
||||
last = true
|
||||
}
|
||||
|
||||
err := writeEa(&buf, &eas[i], last)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
+17
-1
@@ -16,6 +16,7 @@ import (
|
||||
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
|
||||
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
||||
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
||||
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
||||
|
||||
type atomicBool int32
|
||||
|
||||
@@ -79,6 +80,7 @@ type win32File struct {
|
||||
wg sync.WaitGroup
|
||||
wgLock sync.RWMutex
|
||||
closing atomicBool
|
||||
socket bool
|
||||
readDeadline deadlineHandler
|
||||
writeDeadline deadlineHandler
|
||||
}
|
||||
@@ -109,7 +111,13 @@ func makeWin32File(h syscall.Handle) (*win32File, error) {
|
||||
}
|
||||
|
||||
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
|
||||
return makeWin32File(h)
|
||||
// If we return the result of makeWin32File directly, it can result in an
|
||||
// interface-wrapped nil, rather than a nil interface value.
|
||||
f, err := makeWin32File(h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// closeHandle closes the resources associated with a Win32 handle
|
||||
@@ -190,6 +198,10 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
|
||||
if f.closing.isSet() {
|
||||
err = ErrFileClosed
|
||||
}
|
||||
} else if err != nil && f.socket {
|
||||
// err is from Win32. Query the overlapped structure to get the winsock error.
|
||||
var bytes, flags uint32
|
||||
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
|
||||
}
|
||||
case <-timeout:
|
||||
cancelIoEx(f.handle, &c.o)
|
||||
@@ -265,6 +277,10 @@ func (f *win32File) Flush() error {
|
||||
return syscall.FlushFileBuffers(f.handle)
|
||||
}
|
||||
|
||||
func (f *win32File) Fd() uintptr {
|
||||
return uintptr(f.handle)
|
||||
}
|
||||
|
||||
func (d *deadlineHandler) set(deadline time.Time) error {
|
||||
d.setLock.Lock()
|
||||
defer d.setLock.Unlock()
|
||||
|
||||
+2
-1
@@ -20,7 +20,8 @@ const (
|
||||
// FileBasicInfo contains file access time and file attributes information.
|
||||
type FileBasicInfo struct {
|
||||
CreationTime, LastAccessTime, LastWriteTime, ChangeTime syscall.Filetime
|
||||
FileAttributes uintptr // includes padding
|
||||
FileAttributes uint32
|
||||
pad uint32 // padding
|
||||
}
|
||||
|
||||
// GetFileBasicInfo retrieves times and attributes for a file.
|
||||
|
||||
+9
@@ -0,0 +1,9 @@
|
||||
module github.com/Microsoft/go-winio
|
||||
|
||||
go 1.12
|
||||
|
||||
require (
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/sirupsen/logrus v1.4.1
|
||||
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b
|
||||
)
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k=
|
||||
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b h1:ag/x1USPSsqHud38I9BAC88qdNLDHHtQ4mlgQIZPPNA=
|
||||
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
+305
@@ -0,0 +1,305 @@
|
||||
package winio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/Microsoft/go-winio/pkg/guid"
|
||||
)
|
||||
|
||||
//sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
|
||||
|
||||
const (
|
||||
afHvSock = 34 // AF_HYPERV
|
||||
|
||||
socketError = ^uintptr(0)
|
||||
)
|
||||
|
||||
// An HvsockAddr is an address for a AF_HYPERV socket.
|
||||
type HvsockAddr struct {
|
||||
VMID guid.GUID
|
||||
ServiceID guid.GUID
|
||||
}
|
||||
|
||||
type rawHvsockAddr struct {
|
||||
Family uint16
|
||||
_ uint16
|
||||
VMID guid.GUID
|
||||
ServiceID guid.GUID
|
||||
}
|
||||
|
||||
// Network returns the address's network name, "hvsock".
|
||||
func (addr *HvsockAddr) Network() string {
|
||||
return "hvsock"
|
||||
}
|
||||
|
||||
func (addr *HvsockAddr) String() string {
|
||||
return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
|
||||
}
|
||||
|
||||
// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
|
||||
func VsockServiceID(port uint32) guid.GUID {
|
||||
g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3")
|
||||
g.Data1 = port
|
||||
return g
|
||||
}
|
||||
|
||||
func (addr *HvsockAddr) raw() rawHvsockAddr {
|
||||
return rawHvsockAddr{
|
||||
Family: afHvSock,
|
||||
VMID: addr.VMID,
|
||||
ServiceID: addr.ServiceID,
|
||||
}
|
||||
}
|
||||
|
||||
func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
|
||||
addr.VMID = raw.VMID
|
||||
addr.ServiceID = raw.ServiceID
|
||||
}
|
||||
|
||||
// HvsockListener is a socket listener for the AF_HYPERV address family.
|
||||
type HvsockListener struct {
|
||||
sock *win32File
|
||||
addr HvsockAddr
|
||||
}
|
||||
|
||||
// HvsockConn is a connected socket of the AF_HYPERV address family.
|
||||
type HvsockConn struct {
|
||||
sock *win32File
|
||||
local, remote HvsockAddr
|
||||
}
|
||||
|
||||
func newHvSocket() (*win32File, error) {
|
||||
fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1)
|
||||
if err != nil {
|
||||
return nil, os.NewSyscallError("socket", err)
|
||||
}
|
||||
f, err := makeWin32File(fd)
|
||||
if err != nil {
|
||||
syscall.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
f.socket = true
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// ListenHvsock listens for connections on the specified hvsock address.
|
||||
func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
|
||||
l := &HvsockListener{addr: *addr}
|
||||
sock, err := newHvSocket()
|
||||
if err != nil {
|
||||
return nil, l.opErr("listen", err)
|
||||
}
|
||||
sa := addr.raw()
|
||||
err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa)))
|
||||
if err != nil {
|
||||
return nil, l.opErr("listen", os.NewSyscallError("socket", err))
|
||||
}
|
||||
err = syscall.Listen(sock.handle, 16)
|
||||
if err != nil {
|
||||
return nil, l.opErr("listen", os.NewSyscallError("listen", err))
|
||||
}
|
||||
return &HvsockListener{sock: sock, addr: *addr}, nil
|
||||
}
|
||||
|
||||
func (l *HvsockListener) opErr(op string, err error) error {
|
||||
return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *HvsockListener) Addr() net.Addr {
|
||||
return &l.addr
|
||||
}
|
||||
|
||||
// Accept waits for the next connection and returns it.
|
||||
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
|
||||
sock, err := newHvSocket()
|
||||
if err != nil {
|
||||
return nil, l.opErr("accept", err)
|
||||
}
|
||||
defer func() {
|
||||
if sock != nil {
|
||||
sock.Close()
|
||||
}
|
||||
}()
|
||||
c, err := l.sock.prepareIo()
|
||||
if err != nil {
|
||||
return nil, l.opErr("accept", err)
|
||||
}
|
||||
defer l.sock.wg.Done()
|
||||
|
||||
// AcceptEx, per documentation, requires an extra 16 bytes per address.
|
||||
const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
|
||||
var addrbuf [addrlen * 2]byte
|
||||
|
||||
var bytes uint32
|
||||
err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o)
|
||||
_, err = l.sock.asyncIo(c, nil, bytes, err)
|
||||
if err != nil {
|
||||
return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
|
||||
}
|
||||
conn := &HvsockConn{
|
||||
sock: sock,
|
||||
}
|
||||
conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
|
||||
conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
|
||||
sock = nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close closes the listener, causing any pending Accept calls to fail.
|
||||
func (l *HvsockListener) Close() error {
|
||||
return l.sock.Close()
|
||||
}
|
||||
|
||||
/* Need to finish ConnectEx handling
|
||||
func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) {
|
||||
sock, err := newHvSocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if sock != nil {
|
||||
sock.Close()
|
||||
}
|
||||
}()
|
||||
c, err := sock.prepareIo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sock.wg.Done()
|
||||
var bytes uint32
|
||||
err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o)
|
||||
_, err = sock.asyncIo(ctx, c, nil, bytes, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := &HvsockConn{
|
||||
sock: sock,
|
||||
remote: *addr,
|
||||
}
|
||||
sock = nil
|
||||
return conn, nil
|
||||
}
|
||||
*/
|
||||
|
||||
func (conn *HvsockConn) opErr(op string, err error) error {
|
||||
return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) Read(b []byte) (int, error) {
|
||||
c, err := conn.sock.prepareIo()
|
||||
if err != nil {
|
||||
return 0, conn.opErr("read", err)
|
||||
}
|
||||
defer conn.sock.wg.Done()
|
||||
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
||||
var flags, bytes uint32
|
||||
err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
|
||||
n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err)
|
||||
if err != nil {
|
||||
if _, ok := err.(syscall.Errno); ok {
|
||||
err = os.NewSyscallError("wsarecv", err)
|
||||
}
|
||||
return 0, conn.opErr("read", err)
|
||||
} else if n == 0 {
|
||||
err = io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) Write(b []byte) (int, error) {
|
||||
t := 0
|
||||
for len(b) != 0 {
|
||||
n, err := conn.write(b)
|
||||
if err != nil {
|
||||
return t + n, err
|
||||
}
|
||||
t += n
|
||||
b = b[n:]
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) write(b []byte) (int, error) {
|
||||
c, err := conn.sock.prepareIo()
|
||||
if err != nil {
|
||||
return 0, conn.opErr("write", err)
|
||||
}
|
||||
defer conn.sock.wg.Done()
|
||||
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
||||
var bytes uint32
|
||||
err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
|
||||
n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err)
|
||||
if err != nil {
|
||||
if _, ok := err.(syscall.Errno); ok {
|
||||
err = os.NewSyscallError("wsasend", err)
|
||||
}
|
||||
return 0, conn.opErr("write", err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close closes the socket connection, failing any pending read or write calls.
|
||||
func (conn *HvsockConn) Close() error {
|
||||
return conn.sock.Close()
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) shutdown(how int) error {
|
||||
err := syscall.Shutdown(conn.sock.handle, syscall.SHUT_RD)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("shutdown", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseRead shuts down the read end of the socket.
|
||||
func (conn *HvsockConn) CloseRead() error {
|
||||
err := conn.shutdown(syscall.SHUT_RD)
|
||||
if err != nil {
|
||||
return conn.opErr("close", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseWrite shuts down the write end of the socket, notifying the other endpoint that
|
||||
// no more data will be written.
|
||||
func (conn *HvsockConn) CloseWrite() error {
|
||||
err := conn.shutdown(syscall.SHUT_WR)
|
||||
if err != nil {
|
||||
return conn.opErr("close", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address of the connection.
|
||||
func (conn *HvsockConn) LocalAddr() net.Addr {
|
||||
return &conn.local
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address of the connection.
|
||||
func (conn *HvsockConn) RemoteAddr() net.Addr {
|
||||
return &conn.remote
|
||||
}
|
||||
|
||||
// SetDeadline implements the net.Conn SetDeadline method.
|
||||
func (conn *HvsockConn) SetDeadline(t time.Time) error {
|
||||
conn.SetReadDeadline(t)
|
||||
conn.SetWriteDeadline(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements the net.Conn SetReadDeadline method.
|
||||
func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
|
||||
return conn.sock.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
|
||||
func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
|
||||
return conn.sock.SetWriteDeadline(t)
|
||||
}
|
||||
+176
-90
@@ -3,10 +3,13 @@
|
||||
package winio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
@@ -15,10 +18,51 @@ import (
|
||||
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe
|
||||
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW
|
||||
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW
|
||||
//sys waitNamedPipe(name string, timeout uint32) (err error) = WaitNamedPipeW
|
||||
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
||||
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
||||
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
|
||||
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
|
||||
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
||||
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
|
||||
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
|
||||
|
||||
type ioStatusBlock struct {
|
||||
Status, Information uintptr
|
||||
}
|
||||
|
||||
type objectAttributes struct {
|
||||
Length uintptr
|
||||
RootDirectory uintptr
|
||||
ObjectName *unicodeString
|
||||
Attributes uintptr
|
||||
SecurityDescriptor *securityDescriptor
|
||||
SecurityQoS uintptr
|
||||
}
|
||||
|
||||
type unicodeString struct {
|
||||
Length uint16
|
||||
MaximumLength uint16
|
||||
Buffer uintptr
|
||||
}
|
||||
|
||||
type securityDescriptor struct {
|
||||
Revision byte
|
||||
Sbz1 byte
|
||||
Control uint16
|
||||
Owner uintptr
|
||||
Group uintptr
|
||||
Sacl uintptr
|
||||
Dacl uintptr
|
||||
}
|
||||
|
||||
type ntstatus int32
|
||||
|
||||
func (status ntstatus) Err() error {
|
||||
if status >= 0 {
|
||||
return nil
|
||||
}
|
||||
return rtlNtStatusToDosError(status)
|
||||
}
|
||||
|
||||
const (
|
||||
cERROR_PIPE_BUSY = syscall.Errno(231)
|
||||
@@ -26,21 +70,20 @@ const (
|
||||
cERROR_PIPE_CONNECTED = syscall.Errno(535)
|
||||
cERROR_SEM_TIMEOUT = syscall.Errno(121)
|
||||
|
||||
cPIPE_ACCESS_DUPLEX = 0x3
|
||||
cFILE_FLAG_FIRST_PIPE_INSTANCE = 0x80000
|
||||
cSECURITY_SQOS_PRESENT = 0x100000
|
||||
cSECURITY_ANONYMOUS = 0
|
||||
|
||||
cPIPE_REJECT_REMOTE_CLIENTS = 0x8
|
||||
|
||||
cPIPE_UNLIMITED_INSTANCES = 255
|
||||
|
||||
cNMPWAIT_USE_DEFAULT_WAIT = 0
|
||||
cNMPWAIT_NOWAIT = 1
|
||||
cSECURITY_SQOS_PRESENT = 0x100000
|
||||
cSECURITY_ANONYMOUS = 0
|
||||
|
||||
cPIPE_TYPE_MESSAGE = 4
|
||||
|
||||
cPIPE_READMODE_MESSAGE = 2
|
||||
|
||||
cFILE_OPEN = 1
|
||||
cFILE_CREATE = 2
|
||||
|
||||
cFILE_PIPE_MESSAGE_TYPE = 1
|
||||
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
|
||||
|
||||
cSE_DACL_PRESENT = 4
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -121,6 +164,11 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
|
||||
// zero-byte message, ensure that all future Read() calls
|
||||
// also return EOF.
|
||||
f.readEOF = true
|
||||
} else if err == syscall.ERROR_MORE_DATA {
|
||||
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
||||
// and the message still has more bytes. Treat this as a success, since
|
||||
// this package presents all named pipes as byte streams.
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
@@ -133,40 +181,53 @@ func (s pipeAddress) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
||||
func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return syscall.Handle(0), ctx.Err()
|
||||
default:
|
||||
h, err := createFile(*path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
|
||||
if err == nil {
|
||||
return h, nil
|
||||
}
|
||||
if err != cERROR_PIPE_BUSY {
|
||||
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
||||
}
|
||||
// Wait 10 msec and try again. This is a rather simplistic
|
||||
// view, as we always try each 10 milliseconds.
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DialPipe connects to a named pipe by path, timing out if the connection
|
||||
// takes longer than the specified duration. If timeout is nil, then the timeout
|
||||
// is the default timeout established by the pipe server.
|
||||
// takes longer than the specified duration. If timeout is nil, then we use
|
||||
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
||||
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
||||
var absTimeout time.Time
|
||||
if timeout != nil {
|
||||
absTimeout = time.Now().Add(*timeout)
|
||||
} else {
|
||||
absTimeout = time.Now().Add(time.Second * 2)
|
||||
}
|
||||
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
||||
conn, err := DialPipeContext(ctx, path)
|
||||
if err == context.DeadlineExceeded {
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
||||
// cancellation or timeout.
|
||||
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
|
||||
var err error
|
||||
var h syscall.Handle
|
||||
for {
|
||||
h, err = createFile(path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
|
||||
if err != cERROR_PIPE_BUSY {
|
||||
break
|
||||
}
|
||||
now := time.Now()
|
||||
var ms uint32
|
||||
if absTimeout.IsZero() {
|
||||
ms = cNMPWAIT_USE_DEFAULT_WAIT
|
||||
} else if now.After(absTimeout) {
|
||||
ms = cNMPWAIT_NOWAIT
|
||||
} else {
|
||||
ms = uint32(absTimeout.Sub(now).Nanoseconds() / 1000 / 1000)
|
||||
}
|
||||
err = waitNamedPipe(path, ms)
|
||||
if err != nil {
|
||||
if err == cERROR_SEM_TIMEOUT {
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
h, err = tryDialPipe(ctx, &path)
|
||||
if err != nil {
|
||||
return nil, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var flags uint32
|
||||
@@ -175,16 +236,6 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var state uint32
|
||||
err = getNamedPipeHandleState(h, &state, nil, nil, nil, nil, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if state&cPIPE_READMODE_MESSAGE != 0 {
|
||||
return nil, &os.PathError{Op: "open", Path: path, Err: errors.New("message readmode pipes not supported")}
|
||||
}
|
||||
|
||||
f, err := makeWin32File(h)
|
||||
if err != nil {
|
||||
syscall.Close(h)
|
||||
@@ -207,43 +258,87 @@ type acceptResponse struct {
|
||||
}
|
||||
|
||||
type win32PipeListener struct {
|
||||
firstHandle syscall.Handle
|
||||
path string
|
||||
securityDescriptor []byte
|
||||
config PipeConfig
|
||||
acceptCh chan (chan acceptResponse)
|
||||
closeCh chan int
|
||||
doneCh chan int
|
||||
firstHandle syscall.Handle
|
||||
path string
|
||||
config PipeConfig
|
||||
acceptCh chan (chan acceptResponse)
|
||||
closeCh chan int
|
||||
doneCh chan int
|
||||
}
|
||||
|
||||
func makeServerPipeHandle(path string, securityDescriptor []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
|
||||
var flags uint32 = cPIPE_ACCESS_DUPLEX | syscall.FILE_FLAG_OVERLAPPED
|
||||
if first {
|
||||
flags |= cFILE_FLAG_FIRST_PIPE_INSTANCE
|
||||
}
|
||||
|
||||
var mode uint32 = cPIPE_REJECT_REMOTE_CLIENTS
|
||||
if c.MessageMode {
|
||||
mode |= cPIPE_TYPE_MESSAGE
|
||||
}
|
||||
|
||||
sa := &syscall.SecurityAttributes{}
|
||||
sa.Length = uint32(unsafe.Sizeof(*sa))
|
||||
if securityDescriptor != nil {
|
||||
len := uint32(len(securityDescriptor))
|
||||
sa.SecurityDescriptor = localAlloc(0, len)
|
||||
defer localFree(sa.SecurityDescriptor)
|
||||
copy((*[0xffff]byte)(unsafe.Pointer(sa.SecurityDescriptor))[:], securityDescriptor)
|
||||
}
|
||||
h, err := createNamedPipe(path, flags, mode, cPIPE_UNLIMITED_INSTANCES, uint32(c.OutputBufferSize), uint32(c.InputBufferSize), 0, sa)
|
||||
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
|
||||
path16, err := syscall.UTF16FromString(path)
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
}
|
||||
|
||||
var oa objectAttributes
|
||||
oa.Length = unsafe.Sizeof(oa)
|
||||
|
||||
var ntPath unicodeString
|
||||
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
|
||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
}
|
||||
defer localFree(ntPath.Buffer)
|
||||
oa.ObjectName = &ntPath
|
||||
|
||||
// The security descriptor is only needed for the first pipe.
|
||||
if first {
|
||||
if sd != nil {
|
||||
len := uint32(len(sd))
|
||||
sdb := localAlloc(0, len)
|
||||
defer localFree(sdb)
|
||||
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
|
||||
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
|
||||
} else {
|
||||
// Construct the default named pipe security descriptor.
|
||||
var dacl uintptr
|
||||
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
||||
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
|
||||
}
|
||||
defer localFree(dacl)
|
||||
|
||||
sdb := &securityDescriptor{
|
||||
Revision: 1,
|
||||
Control: cSE_DACL_PRESENT,
|
||||
Dacl: dacl,
|
||||
}
|
||||
oa.SecurityDescriptor = sdb
|
||||
}
|
||||
}
|
||||
|
||||
typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS)
|
||||
if c.MessageMode {
|
||||
typ |= cFILE_PIPE_MESSAGE_TYPE
|
||||
}
|
||||
|
||||
disposition := uint32(cFILE_OPEN)
|
||||
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
|
||||
if first {
|
||||
disposition = cFILE_CREATE
|
||||
// By not asking for read or write access, the named pipe file system
|
||||
// will put this pipe into an initially disconnected state, blocking
|
||||
// client connections until the next call with first == false.
|
||||
access = syscall.SYNCHRONIZE
|
||||
}
|
||||
|
||||
timeout := int64(-50 * 10000) // 50ms
|
||||
|
||||
var (
|
||||
h syscall.Handle
|
||||
iosb ioStatusBlock
|
||||
)
|
||||
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
}
|
||||
|
||||
runtime.KeepAlive(ntPath)
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
|
||||
h, err := makeServerPipeHandle(l.path, l.securityDescriptor, &l.config, false)
|
||||
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -354,22 +449,13 @@ func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Immediately open and then close a client handle so that the named pipe is
|
||||
// created but not currently accepting connections.
|
||||
h2, err := createFile(path, 0, 0, nil, syscall.OPEN_EXISTING, cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
|
||||
if err != nil {
|
||||
syscall.Close(h)
|
||||
return nil, err
|
||||
}
|
||||
syscall.Close(h2)
|
||||
l := &win32PipeListener{
|
||||
firstHandle: h,
|
||||
path: path,
|
||||
securityDescriptor: sd,
|
||||
config: *c,
|
||||
acceptCh: make(chan (chan acceptResponse)),
|
||||
closeCh: make(chan int),
|
||||
doneCh: make(chan int),
|
||||
firstHandle: h,
|
||||
path: path,
|
||||
config: *c,
|
||||
acceptCh: make(chan (chan acceptResponse)),
|
||||
closeCh: make(chan int),
|
||||
doneCh: make(chan int),
|
||||
}
|
||||
go l.listenerRoutine()
|
||||
return l, nil
|
||||
|
||||
+235
@@ -0,0 +1,235 @@
|
||||
// Package guid provides a GUID type. The backing structure for a GUID is
|
||||
// identical to that used by the golang.org/x/sys/windows GUID type.
|
||||
// There are two main binary encodings used for a GUID, the big-endian encoding,
|
||||
// and the Windows (mixed-endian) encoding. See here for details:
|
||||
// https://en.wikipedia.org/wiki/Universally_unique_identifier#Encoding
|
||||
package guid
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// Variant specifies which GUID variant (or "type") of the GUID. It determines
|
||||
// how the entirety of the rest of the GUID is interpreted.
|
||||
type Variant uint8
|
||||
|
||||
// The variants specified by RFC 4122.
|
||||
const (
|
||||
// VariantUnknown specifies a GUID variant which does not conform to one of
|
||||
// the variant encodings specified in RFC 4122.
|
||||
VariantUnknown Variant = iota
|
||||
VariantNCS
|
||||
VariantRFC4122
|
||||
VariantMicrosoft
|
||||
VariantFuture
|
||||
)
|
||||
|
||||
// Version specifies how the bits in the GUID were generated. For instance, a
|
||||
// version 4 GUID is randomly generated, and a version 5 is generated from the
|
||||
// hash of an input string.
|
||||
type Version uint8
|
||||
|
||||
var _ = (encoding.TextMarshaler)(GUID{})
|
||||
var _ = (encoding.TextUnmarshaler)(&GUID{})
|
||||
|
||||
// GUID represents a GUID/UUID. It has the same structure as
|
||||
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
|
||||
// that type. It is defined as its own type so that stringification and
|
||||
// marshaling can be supported. The representation matches that used by native
|
||||
// Windows code.
|
||||
type GUID windows.GUID
|
||||
|
||||
// NewV4 returns a new version 4 (pseudorandom) GUID, as defined by RFC 4122.
|
||||
func NewV4() (GUID, error) {
|
||||
var b [16]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return GUID{}, err
|
||||
}
|
||||
|
||||
g := FromArray(b)
|
||||
g.setVersion(4) // Version 4 means randomly generated.
|
||||
g.setVariant(VariantRFC4122)
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// NewV5 returns a new version 5 (generated from a string via SHA-1 hashing)
|
||||
// GUID, as defined by RFC 4122. The RFC is unclear on the encoding of the name,
|
||||
// and the sample code treats it as a series of bytes, so we do the same here.
|
||||
//
|
||||
// Some implementations, such as those found on Windows, treat the name as a
|
||||
// big-endian UTF16 stream of bytes. If that is desired, the string can be
|
||||
// encoded as such before being passed to this function.
|
||||
func NewV5(namespace GUID, name []byte) (GUID, error) {
|
||||
b := sha1.New()
|
||||
namespaceBytes := namespace.ToArray()
|
||||
b.Write(namespaceBytes[:])
|
||||
b.Write(name)
|
||||
|
||||
a := [16]byte{}
|
||||
copy(a[:], b.Sum(nil))
|
||||
|
||||
g := FromArray(a)
|
||||
g.setVersion(5) // Version 5 means generated from a string.
|
||||
g.setVariant(VariantRFC4122)
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func fromArray(b [16]byte, order binary.ByteOrder) GUID {
|
||||
var g GUID
|
||||
g.Data1 = order.Uint32(b[0:4])
|
||||
g.Data2 = order.Uint16(b[4:6])
|
||||
g.Data3 = order.Uint16(b[6:8])
|
||||
copy(g.Data4[:], b[8:16])
|
||||
return g
|
||||
}
|
||||
|
||||
func (g GUID) toArray(order binary.ByteOrder) [16]byte {
|
||||
b := [16]byte{}
|
||||
order.PutUint32(b[0:4], g.Data1)
|
||||
order.PutUint16(b[4:6], g.Data2)
|
||||
order.PutUint16(b[6:8], g.Data3)
|
||||
copy(b[8:16], g.Data4[:])
|
||||
return b
|
||||
}
|
||||
|
||||
// FromArray constructs a GUID from a big-endian encoding array of 16 bytes.
|
||||
func FromArray(b [16]byte) GUID {
|
||||
return fromArray(b, binary.BigEndian)
|
||||
}
|
||||
|
||||
// ToArray returns an array of 16 bytes representing the GUID in big-endian
|
||||
// encoding.
|
||||
func (g GUID) ToArray() [16]byte {
|
||||
return g.toArray(binary.BigEndian)
|
||||
}
|
||||
|
||||
// FromWindowsArray constructs a GUID from a Windows encoding array of bytes.
|
||||
func FromWindowsArray(b [16]byte) GUID {
|
||||
return fromArray(b, binary.LittleEndian)
|
||||
}
|
||||
|
||||
// ToWindowsArray returns an array of 16 bytes representing the GUID in Windows
|
||||
// encoding.
|
||||
func (g GUID) ToWindowsArray() [16]byte {
|
||||
return g.toArray(binary.LittleEndian)
|
||||
}
|
||||
|
||||
func (g GUID) String() string {
|
||||
return fmt.Sprintf(
|
||||
"%08x-%04x-%04x-%04x-%012x",
|
||||
g.Data1,
|
||||
g.Data2,
|
||||
g.Data3,
|
||||
g.Data4[:2],
|
||||
g.Data4[2:])
|
||||
}
|
||||
|
||||
// FromString parses a string containing a GUID and returns the GUID. The only
|
||||
// format currently supported is the `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`
|
||||
// format.
|
||||
func FromString(s string) (GUID, error) {
|
||||
if len(s) != 36 {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
|
||||
var g GUID
|
||||
|
||||
data1, err := strconv.ParseUint(s[0:8], 16, 32)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data1 = uint32(data1)
|
||||
|
||||
data2, err := strconv.ParseUint(s[9:13], 16, 16)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data2 = uint16(data2)
|
||||
|
||||
data3, err := strconv.ParseUint(s[14:18], 16, 16)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data3 = uint16(data3)
|
||||
|
||||
for i, x := range []int{19, 21, 24, 26, 28, 30, 32, 34} {
|
||||
v, err := strconv.ParseUint(s[x:x+2], 16, 8)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data4[i] = uint8(v)
|
||||
}
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (g *GUID) setVariant(v Variant) {
|
||||
d := g.Data4[0]
|
||||
switch v {
|
||||
case VariantNCS:
|
||||
d = (d & 0x7f)
|
||||
case VariantRFC4122:
|
||||
d = (d & 0x3f) | 0x80
|
||||
case VariantMicrosoft:
|
||||
d = (d & 0x1f) | 0xc0
|
||||
case VariantFuture:
|
||||
d = (d & 0x0f) | 0xe0
|
||||
case VariantUnknown:
|
||||
fallthrough
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid variant: %d", v))
|
||||
}
|
||||
g.Data4[0] = d
|
||||
}
|
||||
|
||||
// Variant returns the GUID variant, as defined in RFC 4122.
|
||||
func (g GUID) Variant() Variant {
|
||||
b := g.Data4[0]
|
||||
if b&0x80 == 0 {
|
||||
return VariantNCS
|
||||
} else if b&0xc0 == 0x80 {
|
||||
return VariantRFC4122
|
||||
} else if b&0xe0 == 0xc0 {
|
||||
return VariantMicrosoft
|
||||
} else if b&0xe0 == 0xe0 {
|
||||
return VariantFuture
|
||||
}
|
||||
return VariantUnknown
|
||||
}
|
||||
|
||||
func (g *GUID) setVersion(v Version) {
|
||||
g.Data3 = (g.Data3 & 0x0fff) | (uint16(v) << 12)
|
||||
}
|
||||
|
||||
// Version returns the GUID version, as defined in RFC 4122.
|
||||
func (g GUID) Version() Version {
|
||||
return Version((g.Data3 & 0xF000) >> 12)
|
||||
}
|
||||
|
||||
// MarshalText returns the textual representation of the GUID.
|
||||
func (g GUID) MarshalText() ([]byte, error) {
|
||||
return []byte(g.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText takes the textual representation of a GUID, and unmarhals it
|
||||
// into this GUID.
|
||||
func (g *GUID) UnmarshalText(text []byte) error {
|
||||
g2, err := FromString(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*g = g2
|
||||
return nil
|
||||
}
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
package winio
|
||||
|
||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go file.go pipe.go sd.go fileinfo.go privilege.go backup.go
|
||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go file.go pipe.go sd.go fileinfo.go privilege.go backup.go hvsock.go
|
||||
|
||||
+65
-23
@@ -1,4 +1,4 @@
|
||||
// MACHINE GENERATED BY 'go generate' COMMAND; DO NOT EDIT
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package winio
|
||||
|
||||
@@ -38,19 +38,25 @@ func errnoErr(e syscall.Errno) error {
|
||||
|
||||
var (
|
||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
|
||||
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
||||
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
||||
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
||||
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
|
||||
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
||||
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
||||
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
|
||||
procCreateFileW = modkernel32.NewProc("CreateFileW")
|
||||
procWaitNamedPipeW = modkernel32.NewProc("WaitNamedPipeW")
|
||||
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
|
||||
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
|
||||
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
|
||||
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
|
||||
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
||||
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
||||
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
||||
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
|
||||
procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW")
|
||||
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
||||
@@ -69,6 +75,7 @@ var (
|
||||
procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW")
|
||||
procBackupRead = modkernel32.NewProc("BackupRead")
|
||||
procBackupWrite = modkernel32.NewProc("BackupWrite")
|
||||
procbind = modws2_32.NewProc("bind")
|
||||
)
|
||||
|
||||
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
|
||||
@@ -120,6 +127,24 @@ func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err erro
|
||||
return
|
||||
}
|
||||
|
||||
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
||||
var _p0 uint32
|
||||
if wait {
|
||||
_p0 = 1
|
||||
} else {
|
||||
_p0 = 0
|
||||
}
|
||||
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
|
||||
if r1 == 0 {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
|
||||
if r1 == 0 {
|
||||
@@ -176,27 +201,6 @@ func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityA
|
||||
return
|
||||
}
|
||||
|
||||
func waitNamedPipe(name string, timeout uint32) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _waitNamedPipe(_p0, timeout)
|
||||
}
|
||||
|
||||
func _waitNamedPipe(name *uint16, timeout uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procWaitNamedPipeW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(timeout), 0)
|
||||
if r1 == 0 {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
|
||||
if r1 == 0 {
|
||||
@@ -227,6 +231,32 @@ func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
|
||||
return
|
||||
}
|
||||
|
||||
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
|
||||
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
|
||||
status = ntstatus(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func rtlNtStatusToDosError(status ntstatus) (winerr error) {
|
||||
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
|
||||
if r0 != 0 {
|
||||
winerr = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) {
|
||||
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
|
||||
status = ntstatus(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
|
||||
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
|
||||
status = ntstatus(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(accountName)
|
||||
@@ -518,3 +548,15 @@ func backupWrite(h syscall.Handle, b []byte, bytesWritten *uint32, abort bool, p
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen))
|
||||
if r1 == socketError {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Generated
Generated
Vendored
+1
-1
@@ -1,6 +1,6 @@
|
||||
language: go
|
||||
go:
|
||||
- 1.3.3
|
||||
- 1.7
|
||||
- 1.x
|
||||
- tip
|
||||
before_install:
|
||||
Generated
Vendored
Generated
Vendored
+1
-1
@@ -24,7 +24,7 @@ See https://godoc.org/github.com/cenkalti/backoff#pkg-examples
|
||||
[coveralls]: https://coveralls.io/github/cenkalti/backoff?branch=master
|
||||
[coveralls image]: https://coveralls.io/repos/github/cenkalti/backoff/badge.svg?branch=master
|
||||
|
||||
[google-http-java-client]: https://github.com/google/google-http-java-client
|
||||
[google-http-java-client]: https://github.com/google/google-http-java-client/blob/da1aa993e90285ec18579f1553339b00e19b3ab5/google-http-client/src/main/java/com/google/api/client/util/ExponentialBackOff.java
|
||||
[exponential backoff wiki]: http://en.wikipedia.org/wiki/Exponential_backoff
|
||||
|
||||
[advanced example]: https://godoc.org/github.com/cenkalti/backoff#example_
|
||||
Generated
Vendored
Generated
Vendored
+7
-4
@@ -1,9 +1,8 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// BackOffContext is a backoff policy that stops retrying after the context
|
||||
@@ -52,9 +51,13 @@ func (b *backOffContext) Context() context.Context {
|
||||
|
||||
func (b *backOffContext) NextBackOff() time.Duration {
|
||||
select {
|
||||
case <-b.Context().Done():
|
||||
case <-b.ctx.Done():
|
||||
return Stop
|
||||
default:
|
||||
return b.BackOff.NextBackOff()
|
||||
}
|
||||
next := b.BackOff.NextBackOff()
|
||||
if deadline, ok := b.ctx.Deadline(); ok && deadline.Sub(time.Now()) < next {
|
||||
return Stop
|
||||
}
|
||||
return next
|
||||
}
|
||||
Generated
Vendored
+1
-6
@@ -63,7 +63,6 @@ type ExponentialBackOff struct {
|
||||
|
||||
currentInterval time.Duration
|
||||
startTime time.Time
|
||||
random *rand.Rand
|
||||
}
|
||||
|
||||
// Clock is an interface that returns current time for BackOff.
|
||||
@@ -89,7 +88,6 @@ func NewExponentialBackOff() *ExponentialBackOff {
|
||||
MaxInterval: DefaultMaxInterval,
|
||||
MaxElapsedTime: DefaultMaxElapsedTime,
|
||||
Clock: SystemClock,
|
||||
random: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
b.Reset()
|
||||
return b
|
||||
@@ -118,10 +116,7 @@ func (b *ExponentialBackOff) NextBackOff() time.Duration {
|
||||
return Stop
|
||||
}
|
||||
defer b.incrementCurrentInterval()
|
||||
if b.random == nil {
|
||||
b.random = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
}
|
||||
return getRandomValueFromInterval(b.RandomizationFactor, b.random.Float64(), b.currentInterval)
|
||||
return getRandomValueFromInterval(b.RandomizationFactor, rand.Float64(), b.currentInterval)
|
||||
}
|
||||
|
||||
// GetElapsedTime returns the elapsed time since an ExponentialBackOff instance
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
module github.com/cenkalti/backoff/v3
|
||||
|
||||
go 1.12
|
||||
Generated
Vendored
+8
-4
@@ -15,7 +15,6 @@ type Notify func(error, time.Duration)
|
||||
|
||||
// Retry the operation o until it does not return error or BackOff stops.
|
||||
// o is guaranteed to be run at least once.
|
||||
// It is the caller's responsibility to reset b after Retry returns.
|
||||
//
|
||||
// If o returns a *PermanentError, the operation is not retried, and the
|
||||
// wrapped error is returned.
|
||||
@@ -29,6 +28,7 @@ func Retry(o Operation, b BackOff) error { return RetryNotify(o, b, nil) }
|
||||
func RetryNotify(operation Operation, b BackOff, notify Notify) error {
|
||||
var err error
|
||||
var next time.Duration
|
||||
var t *time.Timer
|
||||
|
||||
cb := ensureContext(b)
|
||||
|
||||
@@ -42,7 +42,7 @@ func RetryNotify(operation Operation, b BackOff, notify Notify) error {
|
||||
return permanent.Err
|
||||
}
|
||||
|
||||
if next = b.NextBackOff(); next == Stop {
|
||||
if next = cb.NextBackOff(); next == Stop {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -50,11 +50,15 @@ func RetryNotify(operation Operation, b BackOff, notify Notify) error {
|
||||
notify(err, next)
|
||||
}
|
||||
|
||||
t := time.NewTimer(next)
|
||||
if t == nil {
|
||||
t = time.NewTimer(next)
|
||||
defer t.Stop()
|
||||
} else {
|
||||
t.Reset(next)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-cb.Context().Done():
|
||||
t.Stop()
|
||||
return err
|
||||
case <-t.C:
|
||||
}
|
||||
Generated
Vendored
-2
@@ -1,7 +1,6 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -34,7 +33,6 @@ func NewTicker(b BackOff) *Ticker {
|
||||
}
|
||||
t.b.Reset()
|
||||
go t.run()
|
||||
runtime.SetFinalizer(t, (*Ticker).Stop)
|
||||
return t
|
||||
}
|
||||
|
||||
Generated
Vendored
+4
-15
@@ -1,6 +1,7 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
https://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
@@ -175,28 +176,16 @@
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
Copyright The containerd Authors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
https://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
+16
@@ -1,3 +1,19 @@
|
||||
/*
|
||||
Copyright The containerd Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package pathdriver
|
||||
|
||||
import (
|
||||
|
||||
+1
-1
@@ -113,7 +113,7 @@ func SplitProtoPort(rawPort string) (string, string) {
|
||||
}
|
||||
|
||||
func validateProto(proto string) bool {
|
||||
for _, availableProto := range []string{"tcp", "udp"} {
|
||||
for _, availableProto := range []string{"tcp", "udp", "sctp"} {
|
||||
if availableProto == proto {
|
||||
return true
|
||||
}
|
||||
|
||||
+1
-1
@@ -27,7 +27,7 @@
|
||||
|
||||
[people.akihirosuda]
|
||||
Name = "Akihiro Suda"
|
||||
Email = "suda.akihiro@lab.ntt.co.jp"
|
||||
Email = "akihiro.suda.cz@hco.ntt.co.jp"
|
||||
GitHub = "AkihiroSuda"
|
||||
|
||||
[people.dnephin]
|
||||
|
||||
+1
-1
@@ -1,7 +1,7 @@
|
||||
dependencies:
|
||||
post:
|
||||
# install golint
|
||||
- go get github.com/golang/lint/golint
|
||||
- go get golang.org/x/lint/golint
|
||||
|
||||
test:
|
||||
pre:
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ func HumanDuration(d time.Duration) string {
|
||||
return fmt.Sprintf("%d seconds", seconds)
|
||||
} else if minutes := int(d.Minutes()); minutes == 1 {
|
||||
return "About a minute"
|
||||
} else if minutes < 46 {
|
||||
} else if minutes < 60 {
|
||||
return fmt.Sprintf("%d minutes", minutes)
|
||||
} else if hours := int(d.Hours() + 0.5); hours == 1 {
|
||||
return "About an hour"
|
||||
|
||||
+7
-2
@@ -96,8 +96,13 @@ func ParseUlimit(val string) (*Ulimit, error) {
|
||||
return nil, fmt.Errorf("too many limit value arguments - %s, can only have up to two, `soft[:hard]`", parts[1])
|
||||
}
|
||||
|
||||
if soft > *hard {
|
||||
return nil, fmt.Errorf("ulimit soft limit must be less than or equal to hard limit: %d > %d", soft, *hard)
|
||||
if *hard != -1 {
|
||||
if soft == -1 {
|
||||
return nil, fmt.Errorf("ulimit soft limit must be less than or equal to hard limit: soft: -1 (unlimited), hard: %d", *hard)
|
||||
}
|
||||
if soft > *hard {
|
||||
return nil, fmt.Errorf("ulimit soft limit must be less than or equal to hard limit: %d > %d", soft, *hard)
|
||||
}
|
||||
}
|
||||
|
||||
return &Ulimit{Name: parts[0], Soft: soft, Hard: *hard}, nil
|
||||
|
||||
+1
-1
@@ -210,7 +210,7 @@ var optionDefs = [256]optionDef{
|
||||
}
|
||||
|
||||
// MediaType specifies the content type of a message.
|
||||
type MediaType byte
|
||||
type MediaType uint16
|
||||
|
||||
// Content types.
|
||||
const (
|
||||
|
||||
-5
@@ -1,5 +0,0 @@
|
||||
language: go
|
||||
go:
|
||||
- 1.8.x
|
||||
- tip
|
||||
|
||||
-27
@@ -1,27 +0,0 @@
|
||||
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
|
||||
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/mattn/go-colorable"
|
||||
packages = ["."]
|
||||
revision = "167de6bfdfba052fa6b2d3664c8f5272e23c9072"
|
||||
version = "v0.0.9"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/mattn/go-isatty"
|
||||
packages = ["."]
|
||||
revision = "0360b2af4f38e8d38c7fce2a9f4e702702d73a39"
|
||||
version = "v0.0.3"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/sys"
|
||||
packages = ["unix"]
|
||||
revision = "37707fdb30a5b38865cfb95e5aab41707daec7fd"
|
||||
|
||||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
inputs-digest = "e8a50671c3cb93ea935bf210b1cd20702876b9d9226129be581ef646d1565cdc"
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
-30
@@ -1,30 +0,0 @@
|
||||
|
||||
# Gopkg.toml example
|
||||
#
|
||||
# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md
|
||||
# for detailed Gopkg.toml documentation.
|
||||
#
|
||||
# required = ["github.com/user/thing/cmd/thing"]
|
||||
# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
|
||||
#
|
||||
# [[constraint]]
|
||||
# name = "github.com/user/project"
|
||||
# version = "1.0.0"
|
||||
#
|
||||
# [[constraint]]
|
||||
# name = "github.com/user/project2"
|
||||
# branch = "dev"
|
||||
# source = "github.com/myfork/project2"
|
||||
#
|
||||
# [[override]]
|
||||
# name = "github.com/x/y"
|
||||
# version = "2.4.0"
|
||||
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/mattn/go-colorable"
|
||||
version = "0.0.9"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/mattn/go-isatty"
|
||||
version = "0.0.3"
|
||||
+7
-4
@@ -1,6 +1,12 @@
|
||||
# Color [](https://godoc.org/github.com/fatih/color) [](https://travis-ci.org/fatih/color)
|
||||
# Archived project. No maintenance.
|
||||
|
||||
This project is not maintained anymore and is archived. Feel free to fork and
|
||||
make your own changes if needed. For more detail read my blog post: [Taking an indefinite sabbatical from my projects](https://arslan.io/2018/10/09/taking-an-indefinite-sabbatical-from-my-projects/)
|
||||
|
||||
Thanks to everyone for their valuable feedback and contributions.
|
||||
|
||||
|
||||
# Color [](https://godoc.org/github.com/fatih/color)
|
||||
|
||||
Color lets you use colorized outputs in terms of [ANSI Escape
|
||||
Codes](http://en.wikipedia.org/wiki/ANSI_escape_code#Colors) in Go (Golang). It
|
||||
@@ -17,9 +23,6 @@ suits you.
|
||||
go get github.com/fatih/color
|
||||
```
|
||||
|
||||
Note that the `vendor` folder is here for stability. Remove the folder if you
|
||||
already have the dependencies in your GOPATH.
|
||||
|
||||
## Examples
|
||||
|
||||
### Standard colors
|
||||
|
||||
+8
@@ -0,0 +1,8 @@
|
||||
module github.com/fatih/color
|
||||
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/mattn/go-colorable v0.1.4
|
||||
github.com/mattn/go-isatty v0.0.11
|
||||
)
|
||||
+8
@@ -0,0 +1,8 @@
|
||||
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
|
||||
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM=
|
||||
github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
+3
-1
@@ -31,7 +31,9 @@ func (l *jsonLogger) Log(keyvals ...interface{}) error {
|
||||
}
|
||||
merge(m, k, v)
|
||||
}
|
||||
return json.NewEncoder(l.Writer).Encode(m)
|
||||
enc := json.NewEncoder(l.Writer)
|
||||
enc.SetEscapeHTML(false)
|
||||
return enc.Encode(m)
|
||||
}
|
||||
|
||||
func merge(dst map[string]interface{}, k, v interface{}) {
|
||||
|
||||
+1
-1
@@ -51,7 +51,7 @@ type Gauge struct {
|
||||
lvs lv.LabelValues
|
||||
}
|
||||
|
||||
// NewGaugeFrom construts and registers a Prometheus GaugeVec,
|
||||
// NewGaugeFrom constructs and registers a Prometheus GaugeVec,
|
||||
// and returns a usable Gauge object.
|
||||
func NewGaugeFrom(opts prometheus.GaugeOpts, labelNames []string) *Gauge {
|
||||
gv := prometheus.NewGaugeVec(opts, labelNames)
|
||||
|
||||
+11
@@ -26,3 +26,14 @@ func NewLogErrorHandler(logger log.Logger) *LogErrorHandler {
|
||||
func (h *LogErrorHandler) Handle(ctx context.Context, err error) {
|
||||
h.logger.Log("err", err)
|
||||
}
|
||||
|
||||
// The ErrorHandlerFunc type is an adapter to allow the use of
|
||||
// ordinary function as ErrorHandler. If f is a function
|
||||
// with the appropriate signature, ErrorHandlerFunc(f) is a
|
||||
// ErrorHandler that calls f.
|
||||
type ErrorHandlerFunc func(ctx context.Context, err error)
|
||||
|
||||
// Handle calls f(ctx, err).
|
||||
func (f ErrorHandlerFunc) Handle(ctx context.Context, err error) {
|
||||
f(ctx, err)
|
||||
}
|
||||
|
||||
+1
-4
@@ -1,4 +1 @@
|
||||
_testdata/
|
||||
_testdata2/
|
||||
logfmt-fuzz.zip
|
||||
logfmt.test.exe
|
||||
.vscode/
|
||||
|
||||
+2
@@ -6,6 +6,8 @@ go:
|
||||
- "1.9.x"
|
||||
- "1.10.x"
|
||||
- "1.11.x"
|
||||
- "1.12.x"
|
||||
- "1.13.x"
|
||||
- "tip"
|
||||
|
||||
before_install:
|
||||
|
||||
+7
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.5.0] - 2020-01-03
|
||||
|
||||
### Changed
|
||||
- Remove the dependency on github.com/kr/logfmt by [@ChrisHines]
|
||||
- Move fuzz code to github.com/go-logfmt/fuzzlogfmt by [@ChrisHines]
|
||||
|
||||
## [0.4.0] - 2018-11-21
|
||||
|
||||
### Added
|
||||
@@ -30,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Decoder by [@ChrisHines]
|
||||
- MarshalKeyvals by [@ChrisHines]
|
||||
|
||||
[0.5.0]: https://github.com/go-logfmt/logfmt/compare/v0.4.0...v0.5.0
|
||||
[0.4.0]: https://github.com/go-logfmt/logfmt/compare/v0.3.0...v0.4.0
|
||||
[0.3.0]: https://github.com/go-logfmt/logfmt/compare/v0.2.0...v0.3.0
|
||||
[0.2.0]: https://github.com/go-logfmt/logfmt/compare/v0.1.0...v0.2.0
|
||||
|
||||
+3
-3
@@ -79,7 +79,7 @@ key:
|
||||
dec.pos += p
|
||||
if dec.pos > start {
|
||||
dec.key = line[start:dec.pos]
|
||||
if multibyte && bytes.IndexRune(dec.key, utf8.RuneError) != -1 {
|
||||
if multibyte && bytes.ContainsRune(dec.key, utf8.RuneError) {
|
||||
dec.syntaxError(invalidKeyError)
|
||||
return false
|
||||
}
|
||||
@@ -97,7 +97,7 @@ key:
|
||||
dec.pos += p
|
||||
if dec.pos > start {
|
||||
dec.key = line[start:dec.pos]
|
||||
if multibyte && bytes.IndexRune(dec.key, utf8.RuneError) != -1 {
|
||||
if multibyte && bytes.ContainsRune(dec.key, utf8.RuneError) {
|
||||
dec.syntaxError(invalidKeyError)
|
||||
return false
|
||||
}
|
||||
@@ -110,7 +110,7 @@ key:
|
||||
dec.pos = len(line)
|
||||
if dec.pos > start {
|
||||
dec.key = line[start:dec.pos]
|
||||
if multibyte && bytes.IndexRune(dec.key, utf8.RuneError) != -1 {
|
||||
if multibyte && bytes.ContainsRune(dec.key, utf8.RuneError) {
|
||||
dec.syntaxError(invalidKeyError)
|
||||
return false
|
||||
}
|
||||
|
||||
-126
@@ -1,126 +0,0 @@
|
||||
// +build gofuzz
|
||||
|
||||
package logfmt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
kr "github.com/kr/logfmt"
|
||||
)
|
||||
|
||||
// Fuzz checks reserialized data matches
|
||||
func Fuzz(data []byte) int {
|
||||
parsed, err := parse(data)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
var w1 bytes.Buffer
|
||||
if err = write(parsed, &w1); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
parsed, err = parse(w1.Bytes())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var w2 bytes.Buffer
|
||||
if err = write(parsed, &w2); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if !bytes.Equal(w1.Bytes(), w2.Bytes()) {
|
||||
panic(fmt.Sprintf("reserialized data does not match:\n%q\n%q\n", w1.Bytes(), w2.Bytes()))
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// FuzzVsKR checks go-logfmt/logfmt against kr/logfmt
|
||||
func FuzzVsKR(data []byte) int {
|
||||
parsed, err := parse(data)
|
||||
parsedKR, errKR := parseKR(data)
|
||||
|
||||
// github.com/go-logfmt/logfmt is a stricter parser. It returns errors for
|
||||
// more inputs than github.com/kr/logfmt. Ignore any inputs that have a
|
||||
// stict error.
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Fail if the more forgiving parser finds an error not found by the
|
||||
// stricter parser.
|
||||
if errKR != nil {
|
||||
panic(fmt.Sprintf("unmatched error: %v", errKR))
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(parsed, parsedKR) {
|
||||
panic(fmt.Sprintf("parsers disagree:\n%+v\n%+v\n", parsed, parsedKR))
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
type kv struct {
|
||||
k, v []byte
|
||||
}
|
||||
|
||||
func parse(data []byte) ([][]kv, error) {
|
||||
var got [][]kv
|
||||
dec := NewDecoder(bytes.NewReader(data))
|
||||
for dec.ScanRecord() {
|
||||
var kvs []kv
|
||||
for dec.ScanKeyval() {
|
||||
kvs = append(kvs, kv{dec.Key(), dec.Value()})
|
||||
}
|
||||
got = append(got, kvs)
|
||||
}
|
||||
return got, dec.Err()
|
||||
}
|
||||
|
||||
func parseKR(data []byte) ([][]kv, error) {
|
||||
var (
|
||||
s = bufio.NewScanner(bytes.NewReader(data))
|
||||
err error
|
||||
h saveHandler
|
||||
got [][]kv
|
||||
)
|
||||
for err == nil && s.Scan() {
|
||||
h.kvs = nil
|
||||
err = kr.Unmarshal(s.Bytes(), &h)
|
||||
got = append(got, h.kvs)
|
||||
}
|
||||
if err == nil {
|
||||
err = s.Err()
|
||||
}
|
||||
return got, err
|
||||
}
|
||||
|
||||
type saveHandler struct {
|
||||
kvs []kv
|
||||
}
|
||||
|
||||
func (h *saveHandler) HandleLogfmt(key, val []byte) error {
|
||||
if len(key) == 0 {
|
||||
key = nil
|
||||
}
|
||||
if len(val) == 0 {
|
||||
val = nil
|
||||
}
|
||||
h.kvs = append(h.kvs, kv{key, val})
|
||||
return nil
|
||||
}
|
||||
|
||||
func write(recs [][]kv, w io.Writer) error {
|
||||
enc := NewEncoder(w)
|
||||
for _, rec := range recs {
|
||||
for _, f := range rec {
|
||||
if err := enc.EncodeKeyval(f.k, f.v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := enc.EndRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
module github.com/go-logfmt/logfmt
|
||||
|
||||
require github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515
|
||||
go 1.13
|
||||
|
||||
-2
@@ -1,2 +0,0 @@
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY=
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||
+1
@@ -8,6 +8,7 @@ go:
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
- 1.11.x
|
||||
- 1.12.x
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@ bench: testdeps
|
||||
|
||||
testdata/redis:
|
||||
mkdir -p $@
|
||||
wget -qO- https://github.com/antirez/redis/archive/unstable.tar.gz | tar xvz --strip-components=1 -C $@
|
||||
wget -qO- https://github.com/antirez/redis/archive/5.0.tar.gz | tar xvz --strip-components=1 -C $@
|
||||
|
||||
testdata/redis/src/redis-server: testdata/redis
|
||||
sed -i.bak 's/libjemalloc.a/libjemalloc.a -lrt/g' $</src/Makefile
|
||||
|
||||
+2
-2
@@ -9,7 +9,7 @@ Supports:
|
||||
- Redis 3 commands except QUIT, MONITOR, SLOWLOG and SYNC.
|
||||
- Automatic connection pooling with [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support.
|
||||
- [Pub/Sub](https://godoc.org/github.com/go-redis/redis#PubSub).
|
||||
- [Transactions](https://godoc.org/github.com/go-redis/redis#Multi).
|
||||
- [Transactions](https://godoc.org/github.com/go-redis/redis#example-Client-TxPipeline).
|
||||
- [Pipeline](https://godoc.org/github.com/go-redis/redis#example-Client-Pipeline) and [TxPipeline](https://godoc.org/github.com/go-redis/redis#example-Client-TxPipeline).
|
||||
- [Scripting](https://godoc.org/github.com/go-redis/redis#Script).
|
||||
- [Timeouts](https://godoc.org/github.com/go-redis/redis#Options).
|
||||
@@ -143,4 +143,4 @@ BenchmarkRedisClusterPing-4 100000 11535 ns/op 117 B/op
|
||||
|
||||
- [Golang PostgreSQL ORM](https://github.com/go-pg/pg)
|
||||
- [Golang msgpack](https://github.com/vmihailenco/msgpack)
|
||||
- [Golang message task queue](https://github.com/go-msgqueue/msgqueue)
|
||||
- [Golang message task queue](https://github.com/vmihailenco/taskq)
|
||||
|
||||
+66
-59
@@ -17,7 +17,6 @@ import (
|
||||
"github.com/go-redis/redis/internal/hashtag"
|
||||
"github.com/go-redis/redis/internal/pool"
|
||||
"github.com/go-redis/redis/internal/proto"
|
||||
"github.com/go-redis/redis/internal/singleflight"
|
||||
)
|
||||
|
||||
var errClusterNoNodes = fmt.Errorf("redis: cluster has no nodes")
|
||||
@@ -243,8 +242,6 @@ type clusterNodes struct {
|
||||
clusterAddrs []string
|
||||
closed bool
|
||||
|
||||
nodeCreateGroup singleflight.Group
|
||||
|
||||
_generation uint32 // atomic
|
||||
}
|
||||
|
||||
@@ -347,11 +344,6 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) {
|
||||
return node, nil
|
||||
}
|
||||
|
||||
v, err := c.nodeCreateGroup.Do(addr, func() (interface{}, error) {
|
||||
node := newClusterNode(c.opt, addr)
|
||||
return node, nil
|
||||
})
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -361,15 +353,13 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) {
|
||||
|
||||
node, ok := c.allNodes[addr]
|
||||
if ok {
|
||||
_ = v.(*clusterNode).Close()
|
||||
return node, err
|
||||
}
|
||||
node = v.(*clusterNode)
|
||||
|
||||
node = newClusterNode(c.opt, addr)
|
||||
|
||||
c.allAddrs = appendIfNotExists(c.allAddrs, addr)
|
||||
if err == nil {
|
||||
c.clusterAddrs = append(c.clusterAddrs, addr)
|
||||
}
|
||||
c.clusterAddrs = append(c.clusterAddrs, addr)
|
||||
c.allNodes[addr] = node
|
||||
|
||||
return node, err
|
||||
@@ -713,12 +703,12 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
|
||||
if ctx == nil {
|
||||
panic("nil context")
|
||||
}
|
||||
c2 := c.copy()
|
||||
c2 := c.clone()
|
||||
c2.ctx = ctx
|
||||
return c2
|
||||
}
|
||||
|
||||
func (c *ClusterClient) copy() *ClusterClient {
|
||||
func (c *ClusterClient) clone() *ClusterClient {
|
||||
cp := *c
|
||||
cp.init()
|
||||
return &cp
|
||||
@@ -782,6 +772,11 @@ func cmdSlot(cmd Cmder, pos int) int {
|
||||
}
|
||||
|
||||
func (c *ClusterClient) cmdSlot(cmd Cmder) int {
|
||||
args := cmd.Args()
|
||||
if args[0] == "cluster" && args[1] == "getkeysinslot" {
|
||||
return args[2].(int)
|
||||
}
|
||||
|
||||
cmdInfo := c.cmdInfo(cmd.Name())
|
||||
return cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo))
|
||||
}
|
||||
@@ -793,7 +788,7 @@ func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) {
|
||||
}
|
||||
|
||||
cmdInfo := c.cmdInfo(cmd.Name())
|
||||
slot := cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo))
|
||||
slot := c.cmdSlot(cmd)
|
||||
|
||||
if c.opt.ReadOnly && cmdInfo != nil && cmdInfo.ReadOnly {
|
||||
if c.opt.RouteByLatency {
|
||||
@@ -854,15 +849,12 @@ func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if internal.IsRetryableError(err, true) {
|
||||
if err != Nil {
|
||||
c.state.LazyReload()
|
||||
continue
|
||||
}
|
||||
|
||||
moved, ask, addr := internal.IsMovedError(err)
|
||||
if moved || ask {
|
||||
c.state.LazyReload()
|
||||
node, err = c.nodes.GetOrCreate(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -870,7 +862,7 @@ func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
if err == pool.ErrClosed {
|
||||
if err == pool.ErrClosed || internal.IsReadOnlyError(err) {
|
||||
node, err = c.slotMasterNode(slot)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -878,6 +870,10 @@ func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
if internal.IsRetryableError(err, true) {
|
||||
continue
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -942,6 +938,9 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error {
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err != Nil {
|
||||
c.state.LazyReload()
|
||||
}
|
||||
|
||||
// If slave is loading - pick another node.
|
||||
if c.opt.ReadOnly && internal.IsLoadingError(err) {
|
||||
@@ -950,9 +949,23 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error {
|
||||
continue
|
||||
}
|
||||
|
||||
if internal.IsRetryableError(err, true) {
|
||||
c.state.LazyReload()
|
||||
var moved bool
|
||||
var addr string
|
||||
moved, ask, addr = internal.IsMovedError(err)
|
||||
if moved || ask {
|
||||
node, err = c.nodes.GetOrCreate(addr)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err == pool.ErrClosed || internal.IsReadOnlyError(err) {
|
||||
node = nil
|
||||
continue
|
||||
}
|
||||
|
||||
if internal.IsRetryableError(err, true) {
|
||||
// First retry the same node.
|
||||
if attempt == 0 {
|
||||
continue
|
||||
@@ -966,24 +979,6 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error {
|
||||
continue
|
||||
}
|
||||
|
||||
var moved bool
|
||||
var addr string
|
||||
moved, ask, addr = internal.IsMovedError(err)
|
||||
if moved || ask {
|
||||
c.state.LazyReload()
|
||||
|
||||
node, err = c.nodes.GetOrCreate(addr)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err == pool.ErrClosed {
|
||||
node = nil
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
@@ -1203,6 +1198,7 @@ func (c *ClusterClient) WrapProcessPipeline(
|
||||
fn func(oldProcess func([]Cmder) error) func([]Cmder) error,
|
||||
) {
|
||||
c.processPipeline = fn(c.processPipeline)
|
||||
c.processTxPipeline = fn(c.processTxPipeline)
|
||||
}
|
||||
|
||||
func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error {
|
||||
@@ -1532,40 +1528,51 @@ func (c *ClusterClient) txPipelineReadQueued(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClusterClient) pubSub(channels []string) *PubSub {
|
||||
func (c *ClusterClient) pubSub() *PubSub {
|
||||
var node *clusterNode
|
||||
pubsub := &PubSub{
|
||||
opt: c.opt.clientOptions(),
|
||||
|
||||
newConn: func(channels []string) (*pool.Conn, error) {
|
||||
if node == nil {
|
||||
var slot int
|
||||
if len(channels) > 0 {
|
||||
slot = hashtag.Slot(channels[0])
|
||||
} else {
|
||||
slot = -1
|
||||
}
|
||||
|
||||
masterNode, err := c.slotMasterNode(slot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node = masterNode
|
||||
if node != nil {
|
||||
panic("node != nil")
|
||||
}
|
||||
return node.Client.newConn()
|
||||
|
||||
var err error
|
||||
if len(channels) > 0 {
|
||||
slot := hashtag.Slot(channels[0])
|
||||
node, err = c.slotMasterNode(slot)
|
||||
} else {
|
||||
node, err = c.nodes.Random()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cn, err := node.Client.newConn()
|
||||
if err != nil {
|
||||
node = nil
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
},
|
||||
closeConn: func(cn *pool.Conn) error {
|
||||
return node.Client.connPool.CloseConn(cn)
|
||||
err := node.Client.connPool.CloseConn(cn)
|
||||
node = nil
|
||||
return err
|
||||
},
|
||||
}
|
||||
pubsub.init()
|
||||
|
||||
return pubsub
|
||||
}
|
||||
|
||||
// Subscribe subscribes the client to the specified channels.
|
||||
// Channels can be omitted to create empty subscription.
|
||||
func (c *ClusterClient) Subscribe(channels ...string) *PubSub {
|
||||
pubsub := c.pubSub(channels)
|
||||
pubsub := c.pubSub()
|
||||
if len(channels) > 0 {
|
||||
_ = pubsub.Subscribe(channels...)
|
||||
}
|
||||
@@ -1575,7 +1582,7 @@ func (c *ClusterClient) Subscribe(channels ...string) *PubSub {
|
||||
// PSubscribe subscribes the client to the given patterns.
|
||||
// Patterns can be omitted to create empty subscription.
|
||||
func (c *ClusterClient) PSubscribe(channels ...string) *PubSub {
|
||||
pubsub := c.pubSub(channels)
|
||||
pubsub := c.pubSub()
|
||||
if len(channels) > 0 {
|
||||
_ = pubsub.PSubscribe(channels...)
|
||||
}
|
||||
|
||||
+42
-6
@@ -218,6 +218,25 @@ func (cmd *Cmd) Uint64() (uint64, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *Cmd) Float32() (float32, error) {
|
||||
if cmd.err != nil {
|
||||
return 0, cmd.err
|
||||
}
|
||||
switch val := cmd.val.(type) {
|
||||
case int64:
|
||||
return float32(val), nil
|
||||
case string:
|
||||
f, err := strconv.ParseFloat(val, 32)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return float32(f), nil
|
||||
default:
|
||||
err := fmt.Errorf("redis: unexpected type=%T for Float32", val)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *Cmd) Float64() (float64, error) {
|
||||
if cmd.err != nil {
|
||||
return 0, cmd.err
|
||||
@@ -585,6 +604,17 @@ func (cmd *StringCmd) Uint64() (uint64, error) {
|
||||
return strconv.ParseUint(cmd.Val(), 10, 64)
|
||||
}
|
||||
|
||||
func (cmd *StringCmd) Float32() (float32, error) {
|
||||
if cmd.err != nil {
|
||||
return 0, cmd.err
|
||||
}
|
||||
f, err := strconv.ParseFloat(cmd.Val(), 32)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return float32(f), nil
|
||||
}
|
||||
|
||||
func (cmd *StringCmd) Float64() (float64, error) {
|
||||
if cmd.err != nil {
|
||||
return 0, cmd.err
|
||||
@@ -687,12 +717,12 @@ func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error {
|
||||
func stringSliceParser(rd *proto.Reader, n int64) (interface{}, error) {
|
||||
ss := make([]string, 0, n)
|
||||
for i := int64(0); i < n; i++ {
|
||||
s, err := rd.ReadString()
|
||||
if err == Nil {
|
||||
switch s, err := rd.ReadString(); {
|
||||
case err == Nil:
|
||||
ss = append(ss, "")
|
||||
} else if err != nil {
|
||||
case err != nil:
|
||||
return nil, err
|
||||
} else {
|
||||
default:
|
||||
ss = append(ss, s)
|
||||
}
|
||||
}
|
||||
@@ -969,14 +999,20 @@ func xMessageSliceParser(rd *proto.Reader, n int64) (interface{}, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var values map[string]interface{}
|
||||
|
||||
v, err := rd.ReadArrayReply(stringInterfaceMapParser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err != proto.Nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
values = v.(map[string]interface{})
|
||||
}
|
||||
|
||||
msgs = append(msgs, XMessage{
|
||||
ID: id,
|
||||
Values: v.(map[string]interface{}),
|
||||
Values: values,
|
||||
})
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
+12
-4
@@ -270,6 +270,7 @@ type Cmdable interface {
|
||||
ClusterResetHard() *StatusCmd
|
||||
ClusterInfo() *StringCmd
|
||||
ClusterKeySlot(key string) *IntCmd
|
||||
ClusterGetKeysInSlot(slot int, count int) *StringSliceCmd
|
||||
ClusterCountFailureReports(nodeID string) *IntCmd
|
||||
ClusterCountKeysInSlot(slot int) *IntCmd
|
||||
ClusterDelSlots(slots ...int) *StatusCmd
|
||||
@@ -1452,10 +1453,11 @@ func (c *cmdable) XGroupDelConsumer(stream, group, consumer string) *IntCmd {
|
||||
type XReadGroupArgs struct {
|
||||
Group string
|
||||
Consumer string
|
||||
Streams []string
|
||||
Count int64
|
||||
Block time.Duration
|
||||
NoAck bool
|
||||
// List of streams and ids.
|
||||
Streams []string
|
||||
Count int64
|
||||
Block time.Duration
|
||||
NoAck bool
|
||||
}
|
||||
|
||||
func (c *cmdable) XReadGroup(a *XReadGroupArgs) *XStreamSliceCmd {
|
||||
@@ -2402,6 +2404,12 @@ func (c *cmdable) ClusterKeySlot(key string) *IntCmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c *cmdable) ClusterGetKeysInSlot(slot int, count int) *StringSliceCmd {
|
||||
cmd := NewStringSliceCmd("cluster", "getkeysinslot", slot, count)
|
||||
c.process(cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c *cmdable) ClusterCountFailureReports(nodeID string) *IntCmd {
|
||||
cmd := NewIntCmd("cluster", "count-failure-reports", nodeID)
|
||||
c.process(cmd)
|
||||
|
||||
+6
-1
@@ -47,7 +47,8 @@ func IsBadConn(err error, allowTimeout bool) bool {
|
||||
return false
|
||||
}
|
||||
if IsRedisError(err) {
|
||||
return strings.HasPrefix(err.Error(), "READONLY ")
|
||||
// #790
|
||||
return IsReadOnlyError(err)
|
||||
}
|
||||
if allowTimeout {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
@@ -82,3 +83,7 @@ func IsMovedError(err error) (moved bool, ask bool, addr string) {
|
||||
func IsLoadingError(err error) bool {
|
||||
return strings.HasPrefix(err.Error(), "LOADING ")
|
||||
}
|
||||
|
||||
func IsReadOnlyError(err error) bool {
|
||||
return strings.HasPrefix(err.Error(), "READONLY ")
|
||||
}
|
||||
|
||||
+6
-4
@@ -17,14 +17,16 @@ type Conn struct {
|
||||
rdLocked bool
|
||||
wr *proto.Writer
|
||||
|
||||
InitedAt time.Time
|
||||
pooled bool
|
||||
usedAt atomic.Value
|
||||
Inited bool
|
||||
pooled bool
|
||||
createdAt time.Time
|
||||
usedAt atomic.Value
|
||||
}
|
||||
|
||||
func NewConn(netConn net.Conn) *Conn {
|
||||
cn := &Conn{
|
||||
netConn: netConn,
|
||||
netConn: netConn,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
cn.rd = proto.NewReader(netConn)
|
||||
cn.wr = proto.NewWriter(netConn)
|
||||
|
||||
+4
-4
@@ -38,7 +38,7 @@ type Pooler interface {
|
||||
|
||||
Get() (*Conn, error)
|
||||
Put(*Conn)
|
||||
Remove(*Conn)
|
||||
Remove(*Conn, error)
|
||||
|
||||
Len() int
|
||||
IdleLen() int
|
||||
@@ -289,7 +289,7 @@ func (p *ConnPool) popIdle() *Conn {
|
||||
|
||||
func (p *ConnPool) Put(cn *Conn) {
|
||||
if !cn.pooled {
|
||||
p.Remove(cn)
|
||||
p.Remove(cn, nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ func (p *ConnPool) Put(cn *Conn) {
|
||||
p.freeTurn()
|
||||
}
|
||||
|
||||
func (p *ConnPool) Remove(cn *Conn) {
|
||||
func (p *ConnPool) Remove(cn *Conn, reason error) {
|
||||
p.removeConn(cn)
|
||||
p.freeTurn()
|
||||
_ = p.closeConn(cn)
|
||||
@@ -468,7 +468,7 @@ func (p *ConnPool) isStaleConn(cn *Conn) bool {
|
||||
if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout {
|
||||
return true
|
||||
}
|
||||
if p.opt.MaxConnAge > 0 && now.Sub(cn.InitedAt) >= p.opt.MaxConnAge {
|
||||
if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
+167
-17
@@ -1,53 +1,203 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
stateDefault = 0
|
||||
stateInited = 1
|
||||
stateClosed = 2
|
||||
)
|
||||
|
||||
type BadConnError struct {
|
||||
wrapped error
|
||||
}
|
||||
|
||||
var _ error = (*BadConnError)(nil)
|
||||
|
||||
func (e BadConnError) Error() string {
|
||||
return "pg: Conn is in a bad state"
|
||||
}
|
||||
|
||||
func (e BadConnError) Unwrap() error {
|
||||
return e.wrapped
|
||||
}
|
||||
|
||||
type SingleConnPool struct {
|
||||
cn *Conn
|
||||
pool Pooler
|
||||
level int32 // atomic
|
||||
|
||||
state uint32 // atomic
|
||||
ch chan *Conn
|
||||
|
||||
_badConnError atomic.Value
|
||||
}
|
||||
|
||||
var _ Pooler = (*SingleConnPool)(nil)
|
||||
|
||||
func NewSingleConnPool(cn *Conn) *SingleConnPool {
|
||||
return &SingleConnPool{
|
||||
cn: cn,
|
||||
func NewSingleConnPool(pool Pooler) *SingleConnPool {
|
||||
p, ok := pool.(*SingleConnPool)
|
||||
if !ok {
|
||||
p = &SingleConnPool{
|
||||
pool: pool,
|
||||
ch: make(chan *Conn, 1),
|
||||
}
|
||||
}
|
||||
atomic.AddInt32(&p.level, 1)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) SetConn(cn *Conn) {
|
||||
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
|
||||
p.ch <- cn
|
||||
} else {
|
||||
panic("not reached")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) NewConn() (*Conn, error) {
|
||||
panic("not implemented")
|
||||
return p.pool.NewConn()
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) CloseConn(*Conn) error {
|
||||
panic("not implemented")
|
||||
func (p *SingleConnPool) CloseConn(cn *Conn) error {
|
||||
return p.pool.CloseConn(cn)
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Get() (*Conn, error) {
|
||||
return p.cn, nil
|
||||
// In worst case this races with Close which is not a very common operation.
|
||||
for i := 0; i < 1000; i++ {
|
||||
switch atomic.LoadUint32(&p.state) {
|
||||
case stateDefault:
|
||||
cn, err := p.pool.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
|
||||
return cn, nil
|
||||
}
|
||||
p.pool.Remove(cn, ErrClosed)
|
||||
case stateInited:
|
||||
if err := p.badConnError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cn, ok := <-p.ch
|
||||
if !ok {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
return cn, nil
|
||||
case stateClosed:
|
||||
return nil, ErrClosed
|
||||
default:
|
||||
panic("not reached")
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("pg: SingleConnPool.Get: infinite loop")
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Put(cn *Conn) {
|
||||
if p.cn != cn {
|
||||
panic("p.cn != cn")
|
||||
defer func() {
|
||||
if recover() != nil {
|
||||
p.freeConn(cn)
|
||||
}
|
||||
}()
|
||||
p.ch <- cn
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) freeConn(cn *Conn) {
|
||||
if err := p.badConnError(); err != nil {
|
||||
p.pool.Remove(cn, err)
|
||||
} else {
|
||||
p.pool.Put(cn)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Remove(cn *Conn) {
|
||||
if p.cn != cn {
|
||||
panic("p.cn != cn")
|
||||
}
|
||||
func (p *SingleConnPool) Remove(cn *Conn, reason error) {
|
||||
defer func() {
|
||||
if recover() != nil {
|
||||
p.pool.Remove(cn, ErrClosed)
|
||||
}
|
||||
}()
|
||||
p._badConnError.Store(BadConnError{wrapped: reason})
|
||||
p.ch <- cn
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Len() int {
|
||||
return 1
|
||||
switch atomic.LoadUint32(&p.state) {
|
||||
case stateDefault:
|
||||
return 0
|
||||
case stateInited:
|
||||
return 1
|
||||
case stateClosed:
|
||||
return 0
|
||||
default:
|
||||
panic("not reached")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) IdleLen() int {
|
||||
return 0
|
||||
return len(p.ch)
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Stats() *Stats {
|
||||
return nil
|
||||
return &Stats{}
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Close() error {
|
||||
level := atomic.AddInt32(&p.level, -1)
|
||||
if level > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
state := atomic.LoadUint32(&p.state)
|
||||
if state == stateClosed {
|
||||
return ErrClosed
|
||||
}
|
||||
if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
|
||||
close(p.ch)
|
||||
cn, ok := <-p.ch
|
||||
if ok {
|
||||
p.freeConn(cn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("pg: SingleConnPool.Close: infinite loop")
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Reset() error {
|
||||
if p.badConnError() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case cn, ok := <-p.ch:
|
||||
if !ok {
|
||||
return ErrClosed
|
||||
}
|
||||
p.pool.Remove(cn, ErrClosed)
|
||||
p._badConnError.Store(BadConnError{wrapped: nil})
|
||||
default:
|
||||
return fmt.Errorf("pg: SingleConnPool does not have a Conn")
|
||||
}
|
||||
|
||||
if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
|
||||
state := atomic.LoadUint32(&p.state)
|
||||
return fmt.Errorf("pg: invalid SingleConnPool state: %d", state)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) badConnError() error {
|
||||
if v := p._badConnError.Load(); v != nil {
|
||||
err := v.(BadConnError)
|
||||
if err.wrapped != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+5
-5
@@ -55,13 +55,13 @@ func (p *StickyConnPool) putUpstream() {
|
||||
|
||||
func (p *StickyConnPool) Put(cn *Conn) {}
|
||||
|
||||
func (p *StickyConnPool) removeUpstream() {
|
||||
p.pool.Remove(p.cn)
|
||||
func (p *StickyConnPool) removeUpstream(reason error) {
|
||||
p.pool.Remove(p.cn, reason)
|
||||
p.cn = nil
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) Remove(cn *Conn) {
|
||||
p.removeUpstream()
|
||||
func (p *StickyConnPool) Remove(cn *Conn, reason error) {
|
||||
p.removeUpstream(reason)
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) Len() int {
|
||||
@@ -101,7 +101,7 @@ func (p *StickyConnPool) Close() error {
|
||||
if p.reusable {
|
||||
p.putUpstream()
|
||||
} else {
|
||||
p.removeUpstream()
|
||||
p.removeUpstream(ErrClosed)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
-64
@@ -1,64 +0,0 @@
|
||||
/*
|
||||
Copyright 2013 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package singleflight provides a duplicate function call suppression
|
||||
// mechanism.
|
||||
package singleflight
|
||||
|
||||
import "sync"
|
||||
|
||||
// call is an in-flight or completed Do call
|
||||
type call struct {
|
||||
wg sync.WaitGroup
|
||||
val interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
// Group represents a class of work and forms a namespace in which
|
||||
// units of work can be executed with duplicate suppression.
|
||||
type Group struct {
|
||||
mu sync.Mutex // protects m
|
||||
m map[string]*call // lazily initialized
|
||||
}
|
||||
|
||||
// Do executes and returns the results of the given function, making
|
||||
// sure that only one execution is in-flight for a given key at a
|
||||
// time. If a duplicate comes in, the duplicate caller waits for the
|
||||
// original to complete and receives the same results.
|
||||
func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
|
||||
g.mu.Lock()
|
||||
if g.m == nil {
|
||||
g.m = make(map[string]*call)
|
||||
}
|
||||
if c, ok := g.m[key]; ok {
|
||||
g.mu.Unlock()
|
||||
c.wg.Wait()
|
||||
return c.val, c.err
|
||||
}
|
||||
c := new(call)
|
||||
c.wg.Add(1)
|
||||
g.m[key] = c
|
||||
g.mu.Unlock()
|
||||
|
||||
c.val, c.err = fn()
|
||||
c.wg.Done()
|
||||
|
||||
g.mu.Lock()
|
||||
delete(g.m, key)
|
||||
g.mu.Unlock()
|
||||
|
||||
return c.val, c.err
|
||||
}
|
||||
+10
@@ -27,3 +27,13 @@ func isLower(s string) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func Unwrap(err error) error {
|
||||
u, ok := err.(interface {
|
||||
Unwrap() error
|
||||
})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return u.Unwrap()
|
||||
}
|
||||
|
||||
+3
@@ -101,6 +101,9 @@ func (opt *Options) init() {
|
||||
if opt.Network == "" {
|
||||
opt.Network = "tcp"
|
||||
}
|
||||
if opt.Addr == "" {
|
||||
opt.Addr = "localhost:6379"
|
||||
}
|
||||
if opt.Dialer == nil {
|
||||
opt.Dialer = func() (net.Conn, error) {
|
||||
netDialer := &net.Dialer{
|
||||
|
||||
+20
@@ -8,8 +8,22 @@ import (
|
||||
|
||||
type pipelineExecer func([]Cmder) error
|
||||
|
||||
// Pipeliner is an mechanism to realise Redis Pipeline technique.
|
||||
//
|
||||
// Pipelining is a technique to extremely speed up processing by packing
|
||||
// operations to batches, send them at once to Redis and read a replies in a
|
||||
// singe step.
|
||||
// See https://redis.io/topics/pipelining
|
||||
//
|
||||
// Pay attention, that Pipeline is not a transaction, so you can get unexpected
|
||||
// results in case of big pipelines and small read/write timeouts.
|
||||
// Redis client has retransmission logic in case of timeouts, pipeline
|
||||
// can be retransmitted and commands can be executed more then once.
|
||||
// To avoid this: it is good idea to use reasonable bigger read/write timeouts
|
||||
// depends of your batch size and/or use TxPipeline.
|
||||
type Pipeliner interface {
|
||||
StatefulCmdable
|
||||
Do(args ...interface{}) *Cmd
|
||||
Process(cmd Cmder) error
|
||||
Close() error
|
||||
Discard() error
|
||||
@@ -31,6 +45,12 @@ type Pipeline struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *Pipeline) Do(args ...interface{}) *Cmd {
|
||||
cmd := NewCmd(args...)
|
||||
_ = c.Process(cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Process queues the cmd for later execution.
|
||||
func (c *Pipeline) Process(cmd Cmder) error {
|
||||
c.mu.Lock()
|
||||
|
||||
+51
-11
@@ -3,6 +3,7 @@ package redis
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
|
||||
var errPingTimeout = errors.New("redis: ping timeout")
|
||||
|
||||
// PubSub implements Pub/Sub commands bas described in
|
||||
// PubSub implements Pub/Sub commands as described in
|
||||
// http://redis.io/topics/pubsub. Message receiving is NOT safe
|
||||
// for concurrent use by multiple goroutines.
|
||||
//
|
||||
@@ -29,8 +30,9 @@ type PubSub struct {
|
||||
cn *pool.Conn
|
||||
channels map[string]struct{}
|
||||
patterns map[string]struct{}
|
||||
closed bool
|
||||
exit chan struct{}
|
||||
|
||||
closed bool
|
||||
exit chan struct{}
|
||||
|
||||
cmd *Cmd
|
||||
|
||||
@@ -39,6 +41,12 @@ type PubSub struct {
|
||||
ping chan struct{}
|
||||
}
|
||||
|
||||
func (c *PubSub) String() string {
|
||||
channels := mapKeys(c.channels)
|
||||
channels = append(channels, mapKeys(c.patterns)...)
|
||||
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
|
||||
}
|
||||
|
||||
func (c *PubSub) init() {
|
||||
c.exit = make(chan struct{})
|
||||
}
|
||||
@@ -389,16 +397,39 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
|
||||
// It periodically sends Ping messages to test connection health.
|
||||
// The channel is closed with PubSub. Receive* APIs can not be used
|
||||
// after channel is created.
|
||||
//
|
||||
// If the Go channel is full for 30 seconds the message is dropped.
|
||||
func (c *PubSub) Channel() <-chan *Message {
|
||||
c.chOnce.Do(c.initChannel)
|
||||
return c.channel(100)
|
||||
}
|
||||
|
||||
// ChannelSize is like Channel, but creates a Go channel
|
||||
// with specified buffer size.
|
||||
func (c *PubSub) ChannelSize(size int) <-chan *Message {
|
||||
return c.channel(size)
|
||||
}
|
||||
|
||||
func (c *PubSub) channel(size int) <-chan *Message {
|
||||
c.chOnce.Do(func() {
|
||||
c.initChannel(size)
|
||||
})
|
||||
if cap(c.ch) != size {
|
||||
err := fmt.Errorf("redis: PubSub.Channel is called with different buffer size")
|
||||
panic(err)
|
||||
}
|
||||
return c.ch
|
||||
}
|
||||
|
||||
func (c *PubSub) initChannel() {
|
||||
c.ch = make(chan *Message, 100)
|
||||
c.ping = make(chan struct{}, 10)
|
||||
func (c *PubSub) initChannel(size int) {
|
||||
const timeout = 30 * time.Second
|
||||
|
||||
c.ch = make(chan *Message, size)
|
||||
c.ping = make(chan struct{}, 1)
|
||||
|
||||
go func() {
|
||||
timer := time.NewTimer(timeout)
|
||||
timer.Stop()
|
||||
|
||||
var errCount int
|
||||
for {
|
||||
msg, err := c.Receive()
|
||||
@@ -413,6 +444,7 @@ func (c *PubSub) initChannel() {
|
||||
errCount++
|
||||
continue
|
||||
}
|
||||
|
||||
errCount = 0
|
||||
|
||||
// Any message is as good as a ping.
|
||||
@@ -427,16 +459,24 @@ func (c *PubSub) initChannel() {
|
||||
case *Pong:
|
||||
// Ignore.
|
||||
case *Message:
|
||||
c.ch <- msg
|
||||
timer.Reset(timeout)
|
||||
select {
|
||||
case c.ch <- msg:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
case <-timer.C:
|
||||
internal.Logf(
|
||||
"redis: %s channel is full for %s (message is dropped)",
|
||||
c, timeout)
|
||||
}
|
||||
default:
|
||||
internal.Logf("redis: unknown message: %T", msg)
|
||||
internal.Logf("redis: unknown message type: %T", msg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
const timeout = 5 * time.Second
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
timer.Stop()
|
||||
|
||||
|
||||
+20
-17
@@ -51,11 +51,10 @@ func (c *baseClient) newConn() (*pool.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cn.InitedAt.IsZero() {
|
||||
if err := c.initConn(cn); err != nil {
|
||||
_ = c.connPool.CloseConn(cn)
|
||||
return nil, err
|
||||
}
|
||||
err = c.initConn(cn)
|
||||
if err != nil {
|
||||
_ = c.connPool.CloseConn(cn)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
@@ -85,12 +84,13 @@ func (c *baseClient) _getConn() (*pool.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cn.InitedAt.IsZero() {
|
||||
err := c.initConn(cn)
|
||||
if err != nil {
|
||||
c.connPool.Remove(cn)
|
||||
err = c.initConn(cn)
|
||||
if err != nil {
|
||||
c.connPool.Remove(cn, err)
|
||||
if err := internal.Unwrap(err); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
@@ -102,7 +102,7 @@ func (c *baseClient) releaseConn(cn *pool.Conn, err error) {
|
||||
}
|
||||
|
||||
if internal.IsBadConn(err, false) {
|
||||
c.connPool.Remove(cn)
|
||||
c.connPool.Remove(cn, err)
|
||||
} else {
|
||||
c.connPool.Put(cn)
|
||||
}
|
||||
@@ -116,12 +116,15 @@ func (c *baseClient) releaseConnStrict(cn *pool.Conn, err error) {
|
||||
if err == nil || internal.IsRedisError(err) {
|
||||
c.connPool.Put(cn)
|
||||
} else {
|
||||
c.connPool.Remove(cn)
|
||||
c.connPool.Remove(cn, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseClient) initConn(cn *pool.Conn) error {
|
||||
cn.InitedAt = time.Now()
|
||||
if cn.Inited {
|
||||
return nil
|
||||
}
|
||||
cn.Inited = true
|
||||
|
||||
if c.opt.Password == "" &&
|
||||
c.opt.DB == 0 &&
|
||||
@@ -201,9 +204,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = cn.WithReader(c.cmdTimeout(cmd), func(rd *proto.Reader) error {
|
||||
return cmd.readReply(rd)
|
||||
})
|
||||
err = cn.WithReader(c.cmdTimeout(cmd), cmd.readReply)
|
||||
c.releaseConn(cn, err)
|
||||
if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) {
|
||||
continue
|
||||
@@ -237,7 +238,7 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
|
||||
func (c *baseClient) Close() error {
|
||||
var firstErr error
|
||||
if c.onClose != nil {
|
||||
if err := c.onClose(); err != nil && firstErr == nil {
|
||||
if err := c.onClose(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
@@ -543,10 +544,12 @@ type Conn struct {
|
||||
}
|
||||
|
||||
func newConn(opt *Options, cn *pool.Conn) *Conn {
|
||||
connPool := pool.NewSingleConnPool(nil)
|
||||
connPool.SetConn(cn)
|
||||
c := Conn{
|
||||
baseClient: baseClient{
|
||||
opt: opt,
|
||||
connPool: pool.NewSingleConnPool(cn),
|
||||
connPool: connPool,
|
||||
},
|
||||
}
|
||||
c.baseClient.init()
|
||||
|
||||
+50
-6
@@ -273,9 +273,13 @@ func (c *ringShards) Heartbeat(frequency time.Duration) {
|
||||
|
||||
// rebalance removes dead shards from the Ring.
|
||||
func (c *ringShards) rebalance() {
|
||||
c.mu.RLock()
|
||||
shards := c.shards
|
||||
c.mu.RUnlock()
|
||||
|
||||
hash := newConsistentHash(c.opt)
|
||||
var shardsNum int
|
||||
for name, shard := range c.shards {
|
||||
for name, shard := range shards {
|
||||
if shard.IsUp() {
|
||||
hash.Add(name)
|
||||
shardsNum++
|
||||
@@ -319,12 +323,12 @@ func (c *ringShards) Close() error {
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Ring is a Redis client that uses constistent hashing to distribute
|
||||
// Ring is a Redis client that uses consistent hashing to distribute
|
||||
// keys across multiple Redis servers (shards). It's safe for
|
||||
// concurrent use by multiple goroutines.
|
||||
//
|
||||
// Ring monitors the state of each shard and removes dead shards from
|
||||
// the ring. When shard comes online it is added back to the ring. This
|
||||
// the ring. When a shard comes online it is added back to the ring. This
|
||||
// gives you maximum availability and partition tolerance, but no
|
||||
// consistency between different shards or even clients. Each client
|
||||
// uses shards that are available to the client and does not do any
|
||||
@@ -357,7 +361,8 @@ func NewRing(opt *RingOptions) *Ring {
|
||||
|
||||
ring.process = ring.defaultProcess
|
||||
ring.processPipeline = ring.defaultProcessPipeline
|
||||
ring.cmdable.setProcessor(ring.Process)
|
||||
|
||||
ring.init()
|
||||
|
||||
for name, addr := range opt.Addrs {
|
||||
clopt := opt.clientOptions()
|
||||
@@ -370,6 +375,10 @@ func NewRing(opt *RingOptions) *Ring {
|
||||
return ring
|
||||
}
|
||||
|
||||
func (c *Ring) init() {
|
||||
c.cmdable.setProcessor(c.Process)
|
||||
}
|
||||
|
||||
func (c *Ring) Context() context.Context {
|
||||
if c.ctx != nil {
|
||||
return c.ctx
|
||||
@@ -381,13 +390,15 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {
|
||||
if ctx == nil {
|
||||
panic("nil context")
|
||||
}
|
||||
c2 := c.copy()
|
||||
c2 := c.clone()
|
||||
c2.ctx = ctx
|
||||
return c2
|
||||
}
|
||||
|
||||
func (c *Ring) copy() *Ring {
|
||||
func (c *Ring) clone() *Ring {
|
||||
cp := *c
|
||||
cp.init()
|
||||
|
||||
return &cp
|
||||
}
|
||||
|
||||
@@ -653,6 +664,39 @@ func (c *Ring) Close() error {
|
||||
return c.shards.Close()
|
||||
}
|
||||
|
||||
func (c *Ring) Watch(fn func(*Tx) error, keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return fmt.Errorf("redis: Watch requires at least one key")
|
||||
}
|
||||
|
||||
var shards []*ringShard
|
||||
for _, key := range keys {
|
||||
if key != "" {
|
||||
shard, err := c.shards.GetByKey(hashtag.Key(key))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
shards = append(shards, shard)
|
||||
}
|
||||
}
|
||||
|
||||
if len(shards) == 0 {
|
||||
return fmt.Errorf("redis: Watch requires at least one shard")
|
||||
}
|
||||
|
||||
if len(shards) > 1 {
|
||||
for _, shard := range shards[1:] {
|
||||
if shard.Client != shards[0].Client {
|
||||
err := fmt.Errorf("redis: Watch requires all keys to be in the same shard")
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return shards[0].Client.Watch(fn, keys...)
|
||||
}
|
||||
|
||||
func newConsistentHash(opt *RingOptions) *consistenthash.Map {
|
||||
return consistenthash.New(opt.HashReplicas, consistenthash.Hash(opt.Hash))
|
||||
}
|
||||
|
||||
+46
-12
@@ -90,9 +90,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
|
||||
opt: opt,
|
||||
connPool: failover.Pool(),
|
||||
|
||||
onClose: func() error {
|
||||
return failover.Close()
|
||||
},
|
||||
onClose: failover.Close,
|
||||
},
|
||||
}
|
||||
c.baseClient.init()
|
||||
@@ -164,6 +162,39 @@ func (c *SentinelClient) Sentinels(name string) *SliceCmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Failover forces a failover as if the master was not reachable, and without
|
||||
// asking for agreement to other Sentinels.
|
||||
func (c *SentinelClient) Failover(name string) *StatusCmd {
|
||||
cmd := NewStatusCmd("sentinel", "failover", name)
|
||||
c.Process(cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Reset resets all the masters with matching name. The pattern argument is a
|
||||
// glob-style pattern. The reset process clears any previous state in a master
|
||||
// (including a failover in progress), and removes every slave and sentinel
|
||||
// already discovered and associated with the master.
|
||||
func (c *SentinelClient) Reset(pattern string) *IntCmd {
|
||||
cmd := NewIntCmd("sentinel", "reset", pattern)
|
||||
c.Process(cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// FlushConfig forces Sentinel to rewrite its configuration on disk, including
|
||||
// the current Sentinel state.
|
||||
func (c *SentinelClient) FlushConfig() *StatusCmd {
|
||||
cmd := NewStatusCmd("sentinel", "flushconfig")
|
||||
c.Process(cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Master shows the state and info of the specified master.
|
||||
func (c *SentinelClient) Master(name string) *StringStringMapCmd {
|
||||
cmd := NewStringStringMapCmd("sentinel", "master", name)
|
||||
c.Process(cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
type sentinelFailover struct {
|
||||
sentinelAddrs []string
|
||||
|
||||
@@ -214,7 +245,9 @@ func (c *sentinelFailover) MasterAddr() (string, error) {
|
||||
}
|
||||
|
||||
func (c *sentinelFailover) masterAddr() (string, error) {
|
||||
c.mu.RLock()
|
||||
addr := c.getMasterAddr()
|
||||
c.mu.RUnlock()
|
||||
if addr != "" {
|
||||
return addr, nil
|
||||
}
|
||||
@@ -222,6 +255,15 @@ func (c *sentinelFailover) masterAddr() (string, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
addr = c.getMasterAddr()
|
||||
if addr != "" {
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
if c.sentinel != nil {
|
||||
c.closeSentinel()
|
||||
}
|
||||
|
||||
for i, sentinelAddr := range c.sentinelAddrs {
|
||||
sentinel := NewSentinelClient(&Options{
|
||||
Addr: sentinelAddr,
|
||||
@@ -260,9 +302,7 @@ func (c *sentinelFailover) masterAddr() (string, error) {
|
||||
}
|
||||
|
||||
func (c *sentinelFailover) getMasterAddr() string {
|
||||
c.mu.RLock()
|
||||
sentinel := c.sentinel
|
||||
c.mu.RUnlock()
|
||||
|
||||
if sentinel == nil {
|
||||
return ""
|
||||
@@ -272,11 +312,6 @@ func (c *sentinelFailover) getMasterAddr() string {
|
||||
if err != nil {
|
||||
internal.Logf("sentinel: GetMasterAddrByName name=%q failed: %s",
|
||||
c.masterName, err)
|
||||
c.mu.Lock()
|
||||
if c.sentinel == sentinel {
|
||||
c.closeSentinel()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -358,8 +393,7 @@ func (c *sentinelFailover) listen(pubsub *PubSub) {
|
||||
break
|
||||
}
|
||||
|
||||
switch msg.Channel {
|
||||
case "+switch-master":
|
||||
if msg.Channel == "+switch-master" {
|
||||
parts := strings.Split(msg.Payload, " ")
|
||||
if parts[0] != c.masterName {
|
||||
internal.Logf("sentinel: ignore addr for master=%q", parts[0])
|
||||
|
||||
+2
-2
@@ -29,10 +29,10 @@ func (c *Client) newTx() *Tx {
|
||||
return &tx
|
||||
}
|
||||
|
||||
// Watch prepares a transcaction and marks the keys to be watched
|
||||
// Watch prepares a transaction and marks the keys to be watched
|
||||
// for conditional execution if there are any keys.
|
||||
//
|
||||
// The transaction is automatically closed when the fn exits.
|
||||
// The transaction is automatically closed when fn exits.
|
||||
func (c *Client) Watch(fn func(*Tx) error, keys ...string) error {
|
||||
tx := c.newTx()
|
||||
if len(keys) > 0 {
|
||||
|
||||
+1
@@ -155,6 +155,7 @@ type UniversalClient interface {
|
||||
Watch(fn func(*Tx) error, keys ...string) error
|
||||
Process(cmd Cmder) error
|
||||
WrapProcess(fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error)
|
||||
WrapProcessPipeline(fn func(oldProcess func([]Cmder) error) func([]Cmder) error)
|
||||
Subscribe(channels ...string) *PubSub
|
||||
PSubscribe(channels ...string) *PubSub
|
||||
Close() error
|
||||
|
||||
+9
-7
@@ -13,24 +13,26 @@ matrix:
|
||||
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- master
|
||||
|
||||
env:
|
||||
global:
|
||||
- GOMAXPROCS=2
|
||||
matrix:
|
||||
- CASS=2.2.13
|
||||
- CASS=2.1.21
|
||||
AUTH=true
|
||||
- CASS=2.2.13
|
||||
- CASS=2.2.14
|
||||
AUTH=true
|
||||
- CASS=2.2.14
|
||||
AUTH=false
|
||||
- CASS=3.0.17
|
||||
- CASS=3.0.18
|
||||
AUTH=false
|
||||
- CASS=3.11.3
|
||||
- CASS=3.11.4
|
||||
AUTH=false
|
||||
|
||||
go:
|
||||
- "1.10"
|
||||
- "1.11"
|
||||
- 1.12.x
|
||||
- 1.13.x
|
||||
|
||||
install:
|
||||
- ./install_test_deps.sh $TRAVIS_REPO_SLUG
|
||||
|
||||
+7
@@ -108,3 +108,10 @@ Luke Hines <lukehines@protonmail.com>
|
||||
Jacob Greenleaf <jacob@jacobgreenleaf.com>
|
||||
Alex Lourie <alex@instaclustr.com>; <djay.il@gmail.com>
|
||||
Marco Cadetg <cadetg@gmail.com>
|
||||
Karl Matthias <karl@matthias.org>
|
||||
Thomas Meson <zllak@hycik.org>
|
||||
Martin Sucha <martin.sucha@kiwi.com>; <git@mm.ms47.eu>
|
||||
Pavel Buchinchik <p.buchinchik@gmail.com>
|
||||
Rintaro Okamura <rintaro.okamura@gmail.com>
|
||||
Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com>
|
||||
Jorge Bay <jorgebg@apache.org>
|
||||
|
||||
+18
-2
@@ -19,8 +19,8 @@ The following matrix shows the versions of Go and Cassandra that are tested with
|
||||
|
||||
Go/Cassandra | 2.1.x | 2.2.x | 3.x.x
|
||||
-------------| -------| ------| ---------
|
||||
1.10 | yes | yes | yes
|
||||
1.11 | yes | yes | yes
|
||||
1.12 | yes | yes | yes
|
||||
1.13 | yes | yes | yes
|
||||
|
||||
Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project.
|
||||
|
||||
@@ -166,6 +166,22 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Authentication
|
||||
-------
|
||||
|
||||
```go
|
||||
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||
cluster.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: "user",
|
||||
Password: "password"
|
||||
}
|
||||
cluster.Keyspace = "example"
|
||||
cluster.Consistency = gocql.Quorum
|
||||
session, _ := cluster.CreateSession()
|
||||
defer session.Close()
|
||||
```
|
||||
|
||||
Data Binding
|
||||
------------
|
||||
|
||||
|
||||
+26
-16
@@ -5,6 +5,7 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
@@ -45,22 +46,23 @@ type ClusterConfig struct {
|
||||
// highest supported protocol for the cluster. In clusters with nodes of different
|
||||
// versions the protocol selected is not defined (ie, it can be any of the supported in the cluster)
|
||||
ProtoVersion int
|
||||
Timeout time.Duration // connection timeout (default: 600ms)
|
||||
ConnectTimeout time.Duration // initial connection timeout, used during initial dial to server (default: 600ms)
|
||||
Port int // port (default: 9042)
|
||||
Keyspace string // initial keyspace (optional)
|
||||
NumConns int // number of connections per host (default: 2)
|
||||
Consistency Consistency // default consistency level (default: Quorum)
|
||||
Compressor Compressor // compression algorithm (default: nil)
|
||||
Authenticator Authenticator // authenticator (default: nil)
|
||||
RetryPolicy RetryPolicy // Default retry policy to use for queries (default: 0)
|
||||
ConvictionPolicy ConvictionPolicy // Decide whether to mark host as down based on the error and host info (default: SimpleConvictionPolicy)
|
||||
ReconnectionPolicy ReconnectionPolicy // Default reconnection policy to use for reconnecting before trying to mark host as down (default: see below)
|
||||
SocketKeepalive time.Duration // The keepalive period to use, enabled if > 0 (default: 0)
|
||||
MaxPreparedStmts int // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
|
||||
MaxRoutingKeyInfo int // Sets the maximum cache size for query info about statements for each session (default: 1000)
|
||||
PageSize int // Default page size to use for created sessions (default: 5000)
|
||||
SerialConsistency SerialConsistency // Sets the consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL (default: unset)
|
||||
Timeout time.Duration // connection timeout (default: 600ms)
|
||||
ConnectTimeout time.Duration // initial connection timeout, used during initial dial to server (default: 600ms)
|
||||
Port int // port (default: 9042)
|
||||
Keyspace string // initial keyspace (optional)
|
||||
NumConns int // number of connections per host (default: 2)
|
||||
Consistency Consistency // default consistency level (default: Quorum)
|
||||
Compressor Compressor // compression algorithm (default: nil)
|
||||
Authenticator Authenticator // authenticator (default: nil)
|
||||
AuthProvider func(h *HostInfo) (Authenticator, error) // an authenticator factory. Can be used to create alternative authenticators (default: nil)
|
||||
RetryPolicy RetryPolicy // Default retry policy to use for queries (default: 0)
|
||||
ConvictionPolicy ConvictionPolicy // Decide whether to mark host as down based on the error and host info (default: SimpleConvictionPolicy)
|
||||
ReconnectionPolicy ReconnectionPolicy // Default reconnection policy to use for reconnecting before trying to mark host as down (default: see below)
|
||||
SocketKeepalive time.Duration // The keepalive period to use, enabled if > 0 (default: 0)
|
||||
MaxPreparedStmts int // Sets the maximum cache size for prepared statements globally for gocql (default: 1000)
|
||||
MaxRoutingKeyInfo int // Sets the maximum cache size for query info about statements for each session (default: 1000)
|
||||
PageSize int // Default page size to use for created sessions (default: 5000)
|
||||
SerialConsistency SerialConsistency // Sets the consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL (default: unset)
|
||||
SslOpts *SslOptions
|
||||
DefaultTimestamp bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above)
|
||||
// PoolConfig configures the underlying connection pool, allowing the
|
||||
@@ -143,10 +145,18 @@ type ClusterConfig struct {
|
||||
// (default: 200 microseconds)
|
||||
WriteCoalesceWaitTime time.Duration
|
||||
|
||||
// Dialer will be used to establish all connections created for this Cluster.
|
||||
// If not provided, a default dialer configured with ConnectTimeout will be used.
|
||||
Dialer Dialer
|
||||
|
||||
// internal config for testing
|
||||
disableControlConn bool
|
||||
}
|
||||
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// NewCluster generates a new config for the default cluster implementation.
|
||||
//
|
||||
// The supplied hosts are used to initially connect to the cluster then the rest of
|
||||
|
||||
+273
-178
@@ -28,6 +28,9 @@ var (
|
||||
"org.apache.cassandra.auth.PasswordAuthenticator",
|
||||
"com.instaclustr.cassandra.auth.SharedSecretAuthenticator",
|
||||
"com.datastax.bdp.cassandra.auth.DseAuthenticator",
|
||||
"io.aiven.cassandra.auth.AivenAuthenticator",
|
||||
"com.ericsson.bss.cassandra.ecaudit.auth.AuditPasswordAuthenticator",
|
||||
"com.amazon.helenus.auth.HelenusAuthenticator",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -96,10 +99,14 @@ type ConnConfig struct {
|
||||
CQLVersion string
|
||||
Timeout time.Duration
|
||||
ConnectTimeout time.Duration
|
||||
Dialer Dialer
|
||||
Compressor Compressor
|
||||
Authenticator Authenticator
|
||||
AuthProvider func(h *HostInfo) (Authenticator, error)
|
||||
Keepalive time.Duration
|
||||
tlsConfig *tls.Config
|
||||
|
||||
tlsConfig *tls.Config
|
||||
disableCoalesce bool
|
||||
}
|
||||
|
||||
type ConnErrorHandler interface {
|
||||
@@ -135,7 +142,7 @@ type Conn struct {
|
||||
headerBuf [maxFrameHeaderSize]byte
|
||||
|
||||
streams *streams.IDGenerator
|
||||
mu sync.RWMutex
|
||||
mu sync.Mutex
|
||||
calls map[int]*callReq
|
||||
|
||||
errorHandler ConnErrorHandler
|
||||
@@ -150,50 +157,78 @@ type Conn struct {
|
||||
session *Session
|
||||
|
||||
closed int32
|
||||
quit chan struct{}
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
timeouts int64
|
||||
}
|
||||
|
||||
// Connect establishes a connection to a Cassandra node.
|
||||
func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
|
||||
// connect establishes a connection to a Cassandra node using session's connection config.
|
||||
func (s *Session) connect(ctx context.Context, host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
|
||||
return s.dial(ctx, host, s.connCfg, errorHandler)
|
||||
}
|
||||
|
||||
// dial establishes a connection to a Cassandra node and notifies the session's connectObserver.
|
||||
func (s *Session) dial(ctx context.Context, host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
|
||||
var obs ObservedConnect
|
||||
if s.connectObserver != nil {
|
||||
obs.Host = host
|
||||
obs.Start = time.Now()
|
||||
}
|
||||
|
||||
conn, err := s.dialWithoutObserver(ctx, host, connConfig, errorHandler)
|
||||
|
||||
if s.connectObserver != nil {
|
||||
obs.End = time.Now()
|
||||
obs.Err = err
|
||||
s.connectObserver.ObserveConnect(obs)
|
||||
}
|
||||
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// dialWithoutObserver establishes connection to a Cassandra node.
|
||||
//
|
||||
// dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead.
|
||||
func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
|
||||
ip := host.ConnectAddress()
|
||||
port := host.port
|
||||
|
||||
// TODO(zariel): remove these
|
||||
if len(ip) == 0 || ip.IsUnspecified() {
|
||||
if !validIpAddr(ip) {
|
||||
panic(fmt.Sprintf("host missing connect ip address: %v", ip))
|
||||
} else if port == 0 {
|
||||
panic(fmt.Sprintf("host missing port: %v", port))
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
conn net.Conn
|
||||
)
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: cfg.ConnectTimeout,
|
||||
}
|
||||
if cfg.Keepalive > 0 {
|
||||
dialer.KeepAlive = cfg.Keepalive
|
||||
dialer := cfg.Dialer
|
||||
if dialer == nil {
|
||||
d := &net.Dialer{
|
||||
Timeout: cfg.ConnectTimeout,
|
||||
}
|
||||
if cfg.Keepalive > 0 {
|
||||
d.KeepAlive = cfg.Keepalive
|
||||
}
|
||||
dialer = d
|
||||
}
|
||||
|
||||
// TODO(zariel): handle ipv6 zone
|
||||
addr := (&net.TCPAddr{IP: ip, Port: port}).String()
|
||||
|
||||
if cfg.tlsConfig != nil {
|
||||
// the TLS config is safe to be reused by connections but it must not
|
||||
// be modified after being used.
|
||||
conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig)
|
||||
} else {
|
||||
conn, err = dialer.Dial("tcp", addr)
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", host.HostnameAndPort())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.tlsConfig != nil {
|
||||
// the TLS config is safe to be reused by connections but it must not
|
||||
// be modified after being used.
|
||||
tconn := tls.Client(conn, cfg.tlsConfig)
|
||||
if err := tconn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
conn = tconn
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
c := &Conn{
|
||||
conn: conn,
|
||||
r: bufio.NewReader(conn),
|
||||
@@ -203,8 +238,6 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa
|
||||
addr: conn.RemoteAddr().String(),
|
||||
errorHandler: errorHandler,
|
||||
compressor: cfg.Compressor,
|
||||
auth: cfg.Authenticator,
|
||||
quit: make(chan struct{}),
|
||||
session: s,
|
||||
streams: streams.New(cfg.ProtoVersion),
|
||||
host: host,
|
||||
@@ -213,40 +246,51 @@ func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHa
|
||||
w: conn,
|
||||
timeout: cfg.Timeout,
|
||||
},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
var (
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
)
|
||||
if cfg.ConnectTimeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(context.TODO(), cfg.ConnectTimeout)
|
||||
} else {
|
||||
ctx, cancel = context.WithCancel(context.TODO())
|
||||
if err := c.init(ctx); err != nil {
|
||||
cancel()
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Conn) init(ctx context.Context) error {
|
||||
if c.session.cfg.AuthProvider != nil {
|
||||
var err error
|
||||
c.auth, err = c.cfg.AuthProvider(c.host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
c.auth = c.cfg.Authenticator
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
startup := &startupCoordinator{
|
||||
frameTicker: make(chan struct{}),
|
||||
conn: c,
|
||||
}
|
||||
|
||||
c.timeout = cfg.ConnectTimeout
|
||||
c.timeout = c.cfg.ConnectTimeout
|
||||
if err := startup.setupConn(ctx); err != nil {
|
||||
c.close()
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
c.timeout = cfg.Timeout
|
||||
c.timeout = c.cfg.Timeout
|
||||
|
||||
// dont coalesce startup frames
|
||||
if s.cfg.WriteCoalesceWaitTime > 0 {
|
||||
c.w = newWriteCoalescer(c.w, s.cfg.WriteCoalesceWaitTime, c.quit)
|
||||
if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce {
|
||||
c.w = newWriteCoalescer(c.conn, c.timeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done())
|
||||
}
|
||||
|
||||
go c.serve()
|
||||
go c.serve(ctx)
|
||||
go c.heartBeat(ctx)
|
||||
|
||||
return c, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
@@ -282,10 +326,18 @@ type startupCoordinator struct {
|
||||
}
|
||||
|
||||
func (s *startupCoordinator) setupConn(ctx context.Context) error {
|
||||
var cancel context.CancelFunc
|
||||
if s.conn.timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, s.conn.timeout)
|
||||
} else {
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
startupErr := make(chan error)
|
||||
go func() {
|
||||
for range s.frameTicker {
|
||||
err := s.conn.recv()
|
||||
err := s.conn.recv(ctx)
|
||||
if err != nil {
|
||||
select {
|
||||
case startupErr <- err:
|
||||
@@ -432,7 +484,7 @@ func (c *Conn) closeWithError(err error) {
|
||||
// we should attempt to deliver the error back to the caller if it
|
||||
// exists
|
||||
if err != nil {
|
||||
c.mu.RLock()
|
||||
c.mu.Lock()
|
||||
for _, req := range c.calls {
|
||||
// we need to send the error to all waiting queries, put the state
|
||||
// of this conn into not active so that it can not execute any queries.
|
||||
@@ -441,11 +493,11 @@ func (c *Conn) closeWithError(err error) {
|
||||
case <-req.timeout:
|
||||
}
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// if error was nil then unblock the quit channel
|
||||
close(c.quit)
|
||||
c.cancel()
|
||||
cerr := c.close()
|
||||
|
||||
if err != nil {
|
||||
@@ -467,10 +519,10 @@ func (c *Conn) Close() {
|
||||
// Serve starts the stream multiplexer for this connection, which is required
|
||||
// to execute any queries. This method runs as long as the connection is
|
||||
// open and is therefore usually called in a separate goroutine.
|
||||
func (c *Conn) serve() {
|
||||
func (c *Conn) serve(ctx context.Context) {
|
||||
var err error
|
||||
for err == nil {
|
||||
err = c.recv()
|
||||
err = c.recv(ctx)
|
||||
}
|
||||
|
||||
c.closeWithError(err)
|
||||
@@ -495,7 +547,54 @@ func (p *protocolError) Error() string {
|
||||
return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame)
|
||||
}
|
||||
|
||||
func (c *Conn) recv() error {
|
||||
func (c *Conn) heartBeat(ctx context.Context) {
|
||||
sleepTime := 1 * time.Second
|
||||
timer := time.NewTimer(sleepTime)
|
||||
defer timer.Stop()
|
||||
|
||||
var failures int
|
||||
|
||||
for {
|
||||
if failures > 5 {
|
||||
c.closeWithError(fmt.Errorf("gocql: heartbeat failed"))
|
||||
return
|
||||
}
|
||||
|
||||
timer.Reset(sleepTime)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
framer, err := c.exec(context.Background(), &writeOptionsFrame{}, nil)
|
||||
if err != nil {
|
||||
failures++
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := framer.parseFrame()
|
||||
if err != nil {
|
||||
// invalid frame
|
||||
failures++
|
||||
continue
|
||||
}
|
||||
|
||||
switch resp.(type) {
|
||||
case *supportedFrame:
|
||||
// Everything ok
|
||||
sleepTime = 5 * time.Second
|
||||
failures = 0
|
||||
case error:
|
||||
// TODO: should we do something here?
|
||||
default:
|
||||
panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) recv(ctx context.Context) error {
|
||||
// not safe for concurrent reads
|
||||
|
||||
// read a full header, ignore timeouts, as this is being ran in a loop
|
||||
@@ -521,6 +620,7 @@ func (c *Conn) recv() error {
|
||||
Length: int32(head.length),
|
||||
Start: headStartTime,
|
||||
End: headEndTime,
|
||||
Host: c.host,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -552,12 +652,15 @@ func (c *Conn) recv() error {
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
c.mu.Lock()
|
||||
call, ok := c.calls[head.stream]
|
||||
c.mu.RUnlock()
|
||||
delete(c.calls, head.stream)
|
||||
c.mu.Unlock()
|
||||
if call == nil || call.framer == nil || !ok {
|
||||
Logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head)
|
||||
return c.discardFrame(head)
|
||||
} else if head.stream != call.streamID {
|
||||
panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream))
|
||||
}
|
||||
|
||||
err = call.framer.readFrame(&head)
|
||||
@@ -574,30 +677,19 @@ func (c *Conn) recv() error {
|
||||
select {
|
||||
case call.resp <- err:
|
||||
case <-call.timeout:
|
||||
c.releaseStream(head.stream)
|
||||
case <-c.quit:
|
||||
c.releaseStream(call)
|
||||
case <-ctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) releaseStream(stream int) {
|
||||
c.mu.Lock()
|
||||
call := c.calls[stream]
|
||||
if call != nil && stream != call.streamID {
|
||||
panic(fmt.Sprintf("attempt to release streamID with invalid stream: %d -> %+v\n", stream, call))
|
||||
} else if call == nil {
|
||||
panic(fmt.Sprintf("releasing a stream not in use: %d", stream))
|
||||
}
|
||||
delete(c.calls, stream)
|
||||
c.mu.Unlock()
|
||||
|
||||
func (c *Conn) releaseStream(call *callReq) {
|
||||
if call.timer != nil {
|
||||
call.timer.Stop()
|
||||
}
|
||||
|
||||
streamPool.Put(call)
|
||||
c.streams.Clear(stream)
|
||||
c.streams.Clear(call.streamID)
|
||||
}
|
||||
|
||||
func (c *Conn) handleTimeout() {
|
||||
@@ -606,16 +698,6 @@ func (c *Conn) handleTimeout() {
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
streamPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &callReq{
|
||||
resp: make(chan error),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type callReq struct {
|
||||
// could use a waitgroup but this allows us to do timeouts on the read/send
|
||||
resp chan error
|
||||
@@ -641,19 +723,20 @@ func (c *deadlineWriter) Write(p []byte) (int, error) {
|
||||
return c.w.Write(p)
|
||||
}
|
||||
|
||||
func newWriteCoalescer(w io.Writer, d time.Duration, quit <-chan struct{}) *writeCoalescer {
|
||||
func newWriteCoalescer(conn net.Conn, timeout time.Duration, d time.Duration, quit <-chan struct{}) *writeCoalescer {
|
||||
wc := &writeCoalescer{
|
||||
writeCh: make(chan struct{}), // TODO: could this be sync?
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
w: w,
|
||||
c: conn,
|
||||
quit: quit,
|
||||
timeout: timeout,
|
||||
}
|
||||
go wc.writeFlusher(d)
|
||||
return wc
|
||||
}
|
||||
|
||||
type writeCoalescer struct {
|
||||
w io.Writer
|
||||
c net.Conn
|
||||
|
||||
quit <-chan struct{}
|
||||
writeCh chan struct{}
|
||||
@@ -662,6 +745,7 @@ type writeCoalescer struct {
|
||||
// cond waits for the buffer to be flushed
|
||||
cond *sync.Cond
|
||||
buffers net.Buffers
|
||||
timeout time.Duration
|
||||
|
||||
// result of the write
|
||||
err error
|
||||
@@ -673,10 +757,14 @@ func (w *writeCoalescer) flushLocked() {
|
||||
return
|
||||
}
|
||||
|
||||
if w.timeout > 0 {
|
||||
w.c.SetWriteDeadline(time.Now().Add(w.timeout))
|
||||
}
|
||||
|
||||
// Given we are going to do a fanout n is useless and according to
|
||||
// the docs WriteTo should return 0 and err or bytes written and
|
||||
// no error.
|
||||
_, w.err = w.buffers.WriteTo(w.w)
|
||||
_, w.err = w.buffers.WriteTo(w.c)
|
||||
if w.err != nil {
|
||||
w.buffers = nil
|
||||
}
|
||||
@@ -766,10 +854,12 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
|
||||
// resp is basically a waiting semaphore protecting the framer
|
||||
framer := newFramer(c, c, c.compressor, c.version)
|
||||
|
||||
call := streamPool.Get().(*callReq)
|
||||
call.framer = framer
|
||||
call.timeout = make(chan struct{})
|
||||
call.streamID = stream
|
||||
call := &callReq{
|
||||
framer: framer,
|
||||
timeout: make(chan struct{}),
|
||||
streamID: stream,
|
||||
resp: make(chan error),
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
existingCall := c.calls[stream]
|
||||
@@ -833,7 +923,7 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
|
||||
// this is because the request is still outstanding and we have
|
||||
// been handed another error from another stream which caused the
|
||||
// connection to close.
|
||||
c.releaseStream(stream)
|
||||
c.releaseStream(call)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -844,7 +934,7 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
|
||||
case <-ctxDone:
|
||||
close(call.timeout)
|
||||
return nil, ctx.Err()
|
||||
case <-c.quit:
|
||||
case <-c.ctx.Done():
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
|
||||
@@ -854,7 +944,7 @@ func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*frame
|
||||
//
|
||||
// Ensure that the stream is not released if there are potentially outstanding
|
||||
// requests on the stream to prevent nil pointer dereferences in recv().
|
||||
defer c.releaseStream(stream)
|
||||
defer c.releaseStream(call)
|
||||
|
||||
if v := framer.header.version.version(); v != c.version {
|
||||
return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
|
||||
@@ -870,8 +960,8 @@ type preparedStatment struct {
|
||||
}
|
||||
|
||||
type inflightPrepare struct {
|
||||
wg sync.WaitGroup
|
||||
err error
|
||||
done chan struct{}
|
||||
err error
|
||||
|
||||
preparedStatment *preparedStatment
|
||||
}
|
||||
@@ -879,69 +969,76 @@ type inflightPrepare struct {
|
||||
func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
|
||||
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
|
||||
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
|
||||
flight := new(inflightPrepare)
|
||||
flight.wg.Add(1)
|
||||
flight := &inflightPrepare{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
lru.Add(stmtCacheKey, flight)
|
||||
return flight
|
||||
})
|
||||
|
||||
if ok {
|
||||
flight.wg.Wait()
|
||||
if !ok {
|
||||
go func() {
|
||||
defer close(flight.done)
|
||||
|
||||
prep := &writePrepareFrame{
|
||||
statement: stmt,
|
||||
}
|
||||
if c.version > protoVersion4 {
|
||||
prep.keyspace = c.currentKeyspace
|
||||
}
|
||||
|
||||
// we won the race to do the load, if our context is canceled we shouldnt
|
||||
// stop the load as other callers are waiting for it but this caller should get
|
||||
// their context cancelled error.
|
||||
framer, err := c.exec(c.ctx, prep, tracer)
|
||||
if err != nil {
|
||||
flight.err = err
|
||||
c.session.stmtsLRU.remove(stmtCacheKey)
|
||||
return
|
||||
}
|
||||
|
||||
frame, err := framer.parseFrame()
|
||||
if err != nil {
|
||||
flight.err = err
|
||||
c.session.stmtsLRU.remove(stmtCacheKey)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
|
||||
// everytime we need to parse a frame.
|
||||
if len(framer.traceID) > 0 && tracer != nil {
|
||||
tracer.Trace(framer.traceID)
|
||||
}
|
||||
|
||||
switch x := frame.(type) {
|
||||
case *resultPreparedFrame:
|
||||
flight.preparedStatment = &preparedStatment{
|
||||
// defensively copy as we will recycle the underlying buffer after we
|
||||
// return.
|
||||
id: copyBytes(x.preparedID),
|
||||
// the type info's should _not_ have a reference to the framers read buffer,
|
||||
// therefore we can just copy them directly.
|
||||
request: x.reqMeta,
|
||||
response: x.respMeta,
|
||||
}
|
||||
case error:
|
||||
flight.err = x
|
||||
default:
|
||||
flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
|
||||
}
|
||||
|
||||
if flight.err != nil {
|
||||
c.session.stmtsLRU.remove(stmtCacheKey)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-flight.done:
|
||||
return flight.preparedStatment, flight.err
|
||||
}
|
||||
|
||||
prep := &writePrepareFrame{
|
||||
statement: stmt,
|
||||
}
|
||||
if c.version > protoVersion4 {
|
||||
prep.keyspace = c.currentKeyspace
|
||||
}
|
||||
|
||||
framer, err := c.exec(ctx, prep, tracer)
|
||||
if err != nil {
|
||||
flight.err = err
|
||||
flight.wg.Done()
|
||||
c.session.stmtsLRU.remove(stmtCacheKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
frame, err := framer.parseFrame()
|
||||
if err != nil {
|
||||
flight.err = err
|
||||
flight.wg.Done()
|
||||
c.session.stmtsLRU.remove(stmtCacheKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
|
||||
// everytime we need to parse a frame.
|
||||
if len(framer.traceID) > 0 && tracer != nil {
|
||||
tracer.Trace(framer.traceID)
|
||||
}
|
||||
|
||||
switch x := frame.(type) {
|
||||
case *resultPreparedFrame:
|
||||
flight.preparedStatment = &preparedStatment{
|
||||
// defensively copy as we will recycle the underlying buffer after we
|
||||
// return.
|
||||
id: copyBytes(x.preparedID),
|
||||
// the type info's should _not_ have a reference to the framers read buffer,
|
||||
// therefore we can just copy them directly.
|
||||
request: x.reqMeta,
|
||||
response: x.respMeta,
|
||||
}
|
||||
case error:
|
||||
flight.err = x
|
||||
default:
|
||||
flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
|
||||
}
|
||||
flight.wg.Done()
|
||||
|
||||
if flight.err != nil {
|
||||
c.session.stmtsLRU.remove(stmtCacheKey)
|
||||
}
|
||||
|
||||
return flight.preparedStatment, flight.err
|
||||
}
|
||||
|
||||
func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
|
||||
@@ -989,7 +1086,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
|
||||
info *preparedStatment
|
||||
)
|
||||
|
||||
if qry.shouldPrepare() {
|
||||
if !qry.skipPrepare && qry.shouldPrepare() {
|
||||
// Prepare all DML queries. Other queries can not be prepared.
|
||||
var err error
|
||||
info, err = c.prepareStatement(ctx, qry.stmt, qry.trace)
|
||||
@@ -997,11 +1094,8 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
|
||||
return &Iter{err: err}
|
||||
}
|
||||
|
||||
var values []interface{}
|
||||
|
||||
if qry.binding == nil {
|
||||
values = qry.values
|
||||
} else {
|
||||
values := qry.values
|
||||
if qry.binding != nil {
|
||||
values, err = qry.binding(&QueryInfo{
|
||||
Id: info.id,
|
||||
Args: info.request.columns,
|
||||
@@ -1105,11 +1199,8 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
|
||||
return iter
|
||||
case *RequestErrUnprepared:
|
||||
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt)
|
||||
if c.session.stmtsLRU.remove(stmtCacheKey) {
|
||||
return c.executeQuery(ctx, qry)
|
||||
}
|
||||
|
||||
return &Iter{err: x, framer: framer}
|
||||
c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId)
|
||||
return c.executeQuery(ctx, qry)
|
||||
case error:
|
||||
return &Iter{err: x, framer: framer}
|
||||
default:
|
||||
@@ -1141,9 +1232,9 @@ func (c *Conn) AvailableStreams() int {
|
||||
|
||||
func (c *Conn) UseKeyspace(keyspace string) error {
|
||||
q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
|
||||
q.params.consistency = Any
|
||||
q.params.consistency = c.session.cons
|
||||
|
||||
framer, err := c.exec(context.Background(), q, nil)
|
||||
framer, err := c.exec(c.ctx, q, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1249,14 +1340,9 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
|
||||
stmt, found := stmts[string(x.StatementId)]
|
||||
if found {
|
||||
key := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
|
||||
c.session.stmtsLRU.remove(key)
|
||||
}
|
||||
|
||||
if found {
|
||||
return c.executeBatch(ctx, batch)
|
||||
} else {
|
||||
return &Iter{err: x, framer: framer}
|
||||
c.session.stmtsLRU.evictPreparedID(key, x.StatementId)
|
||||
}
|
||||
return c.executeBatch(ctx, batch)
|
||||
case *resultRowsFrame:
|
||||
iter := &Iter{
|
||||
meta: x.meta,
|
||||
@@ -1275,16 +1361,19 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
|
||||
func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
|
||||
q := c.session.Query(statement, values...).Consistency(One)
|
||||
q.trace = nil
|
||||
q.skipPrepare = true
|
||||
q.disableSkipMetadata = true
|
||||
return c.executeQuery(ctx, q)
|
||||
}
|
||||
|
||||
func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
|
||||
const (
|
||||
peerSchemas = "SELECT schema_version, peer FROM system.peers"
|
||||
peerSchemas = "SELECT * FROM system.peers"
|
||||
localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
|
||||
)
|
||||
|
||||
var versions map[string]struct{}
|
||||
var schemaVersion string
|
||||
|
||||
endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
|
||||
for time.Now().Before(endDeadline) {
|
||||
@@ -1292,16 +1381,22 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
|
||||
|
||||
versions = make(map[string]struct{})
|
||||
|
||||
var schemaVersion string
|
||||
var peer string
|
||||
for iter.Scan(&schemaVersion, &peer) {
|
||||
if schemaVersion == "" {
|
||||
Logger.Printf("skipping peer entry with empty schema_version: peer=%q", peer)
|
||||
rows, err := iter.SliceMap()
|
||||
if err != nil {
|
||||
goto cont
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port})
|
||||
if err != nil {
|
||||
goto cont
|
||||
}
|
||||
if !isValidPeer(host) || host.schemaVersion == "" {
|
||||
Logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host)
|
||||
continue
|
||||
}
|
||||
|
||||
versions[schemaVersion] = struct{}{}
|
||||
schemaVersion = ""
|
||||
versions[host.schemaVersion] = struct{}{}
|
||||
}
|
||||
|
||||
if err = iter.Close(); err != nil {
|
||||
@@ -1352,7 +1447,7 @@ func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) {
|
||||
port := c.conn.RemoteAddr().(*net.TCPAddr).Port
|
||||
|
||||
// TODO(zariel): avoid doing this here
|
||||
host, err := c.session.hostInfoFromMap(row, port)
|
||||
host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.connectAddress, port: port})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+12
-9
@@ -90,14 +90,17 @@ func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
|
||||
}
|
||||
|
||||
return &ConnConfig{
|
||||
ProtoVersion: cfg.ProtoVersion,
|
||||
CQLVersion: cfg.CQLVersion,
|
||||
Timeout: cfg.Timeout,
|
||||
ConnectTimeout: cfg.ConnectTimeout,
|
||||
Compressor: cfg.Compressor,
|
||||
Authenticator: cfg.Authenticator,
|
||||
Keepalive: cfg.SocketKeepalive,
|
||||
tlsConfig: tlsConfig,
|
||||
ProtoVersion: cfg.ProtoVersion,
|
||||
CQLVersion: cfg.CQLVersion,
|
||||
Timeout: cfg.Timeout,
|
||||
ConnectTimeout: cfg.ConnectTimeout,
|
||||
Dialer: cfg.Dialer,
|
||||
Compressor: cfg.Compressor,
|
||||
Authenticator: cfg.Authenticator,
|
||||
AuthProvider: cfg.AuthProvider,
|
||||
Keepalive: cfg.SocketKeepalive,
|
||||
tlsConfig: tlsConfig,
|
||||
disableCoalesce: tlsConfig != nil, // write coalescing doesn't work with framing on top of TCP like in TLS.
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -504,7 +507,7 @@ func (pool *hostConnPool) connect() (err error) {
|
||||
var conn *Conn
|
||||
reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
|
||||
for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
|
||||
conn, err = pool.session.connect(pool.host, pool)
|
||||
conn, err = pool.session.connect(pool.session.ctx, pool.host, pool)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
+19
-13
@@ -116,7 +116,7 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
|
||||
|
||||
// Check if host is a literal IP address
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
hosts = append(hosts, &HostInfo{connectAddress: ip, port: port})
|
||||
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
@@ -142,21 +142,21 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
hosts = append(hosts, &HostInfo{connectAddress: ip, port: port})
|
||||
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
|
||||
}
|
||||
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func shuffleHosts(hosts []*HostInfo) []*HostInfo {
|
||||
mutRandr.Lock()
|
||||
perm := randr.Perm(len(hosts))
|
||||
mutRandr.Unlock()
|
||||
shuffled := make([]*HostInfo, len(hosts))
|
||||
copy(shuffled, hosts)
|
||||
|
||||
for i, host := range hosts {
|
||||
shuffled[perm[i]] = host
|
||||
}
|
||||
mutRandr.Lock()
|
||||
randr.Shuffle(len(hosts), func(i, j int) {
|
||||
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
|
||||
})
|
||||
mutRandr.Unlock()
|
||||
|
||||
return shuffled
|
||||
}
|
||||
@@ -166,10 +166,13 @@ func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
|
||||
// node.
|
||||
shuffled := shuffleHosts(endpoints)
|
||||
|
||||
cfg := *c.session.connCfg
|
||||
cfg.disableCoalesce = true
|
||||
|
||||
var err error
|
||||
for _, host := range shuffled {
|
||||
var conn *Conn
|
||||
conn, err = c.session.connect(host, c)
|
||||
conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
@@ -218,7 +221,7 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
|
||||
var err error
|
||||
for _, host := range hosts {
|
||||
var conn *Conn
|
||||
conn, err = c.session.dial(host, &connCfg, handler)
|
||||
conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
@@ -340,7 +343,7 @@ func (c *controlConn) reconnect(refreshring bool) {
|
||||
var newConn *Conn
|
||||
if host != nil {
|
||||
// try to connect to the old host
|
||||
conn, err := c.session.connect(host, c)
|
||||
conn, err := c.session.connect(c.session.ctx, host, c)
|
||||
if err != nil {
|
||||
// host is dead
|
||||
// TODO: this is replicated in a few places
|
||||
@@ -362,7 +365,7 @@ func (c *controlConn) reconnect(refreshring bool) {
|
||||
}
|
||||
|
||||
var err error
|
||||
newConn, err = c.session.connect(host, c)
|
||||
newConn, err = c.session.connect(c.session.ctx, host, c)
|
||||
if err != nil {
|
||||
// TODO: add log handler for things like this
|
||||
return
|
||||
@@ -386,7 +389,10 @@ func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
|
||||
}
|
||||
|
||||
oldConn := c.getConn()
|
||||
if oldConn.conn != conn {
|
||||
|
||||
// If connection has long gone, and not been attempted for awhile,
|
||||
// it's possible to have oldConn as nil here (#1297).
|
||||
if oldConn != nil && oldConn.conn != conn {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+3
@@ -361,6 +361,9 @@ type ObservedFrameHeader struct {
|
||||
Start time.Time
|
||||
// EndHeader is the time we finished reading the frame header off the network connection.
|
||||
End time.Time
|
||||
|
||||
// Host is Host of the connection the frame header was read from.
|
||||
Host *HostInfo
|
||||
}
|
||||
|
||||
func (f ObservedFrameHeader) String() string {
|
||||
|
||||
+6
@@ -1,7 +1,13 @@
|
||||
module github.com/gocql/gocql
|
||||
|
||||
require (
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect
|
||||
github.com/golang/snappy v0.0.0-20170215233205-553a64147049
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
github.com/stretchr/testify v1.3.0 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1
|
||||
)
|
||||
|
||||
go 1.13
|
||||
|
||||
-3
@@ -1,3 +0,0 @@
|
||||
github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk=
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
|
||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||
+22
@@ -0,0 +1,22 @@
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk=
|
||||
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
||||
+14
@@ -26,6 +26,8 @@ func goType(t TypeInfo) reflect.Type {
|
||||
return reflect.TypeOf(*new(string))
|
||||
case TypeBigInt, TypeCounter:
|
||||
return reflect.TypeOf(*new(int64))
|
||||
case TypeTime:
|
||||
return reflect.TypeOf(*new(time.Duration))
|
||||
case TypeTimestamp:
|
||||
return reflect.TypeOf(*new(time.Time))
|
||||
case TypeBlob:
|
||||
@@ -83,14 +85,24 @@ func getCassandraBaseType(name string) Type {
|
||||
return TypeBoolean
|
||||
case "counter":
|
||||
return TypeCounter
|
||||
case "date":
|
||||
return TypeDate
|
||||
case "decimal":
|
||||
return TypeDecimal
|
||||
case "double":
|
||||
return TypeDouble
|
||||
case "duration":
|
||||
return TypeDuration
|
||||
case "float":
|
||||
return TypeFloat
|
||||
case "int":
|
||||
return TypeInt
|
||||
case "smallint":
|
||||
return TypeSmallInt
|
||||
case "tinyint":
|
||||
return TypeTinyInt
|
||||
case "time":
|
||||
return TypeTime
|
||||
case "timestamp":
|
||||
return TypeTimestamp
|
||||
case "uuid":
|
||||
@@ -229,6 +241,8 @@ func getApacheCassandraType(class string) Type {
|
||||
return TypeSmallInt
|
||||
case "ByteType":
|
||||
return TypeTinyInt
|
||||
case "TimeType":
|
||||
return TypeTime
|
||||
case "DateType", "TimestampType":
|
||||
return TypeTimestamp
|
||||
case "UUIDType", "LexicalUUIDType":
|
||||
|
||||
+31
-14
@@ -89,6 +89,10 @@ func (c cassVersion) Before(major, minor, patch int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c cassVersion) AtLeast(major, minor, patch int) bool {
|
||||
return !c.Before(major, minor, patch)
|
||||
}
|
||||
|
||||
func (c cassVersion) String() string {
|
||||
return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch)
|
||||
}
|
||||
@@ -106,6 +110,7 @@ type HostInfo struct {
|
||||
// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
|
||||
// that we are thread safe use a mutex to access all fields.
|
||||
mu sync.RWMutex
|
||||
hostname string
|
||||
peer net.IP
|
||||
broadcastAddress net.IP
|
||||
listenAddress net.IP
|
||||
@@ -123,6 +128,7 @@ type HostInfo struct {
|
||||
clusterName string
|
||||
version cassVersion
|
||||
state nodeState
|
||||
schemaVersion string
|
||||
tokens []string
|
||||
}
|
||||
|
||||
@@ -222,8 +228,9 @@ func (h *HostInfo) PreferredIP() net.IP {
|
||||
|
||||
func (h *HostInfo) DataCenter() string {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.dataCenter
|
||||
dc := h.dataCenter
|
||||
h.mu.RUnlock()
|
||||
return dc
|
||||
}
|
||||
|
||||
func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
|
||||
@@ -235,8 +242,9 @@ func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
|
||||
|
||||
func (h *HostInfo) Rack() string {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.rack
|
||||
rack := h.rack
|
||||
h.mu.RUnlock()
|
||||
return rack
|
||||
}
|
||||
|
||||
func (h *HostInfo) setRack(rack string) *HostInfo {
|
||||
@@ -407,15 +415,22 @@ func (h *HostInfo) IsUp() bool {
|
||||
return h != nil && h.State() == NodeUp
|
||||
}
|
||||
|
||||
func (h *HostInfo) HostnameAndPort() string {
|
||||
if h.hostname == "" {
|
||||
h.hostname = h.ConnectAddress().String()
|
||||
}
|
||||
return net.JoinHostPort(h.hostname, strconv.Itoa(h.port))
|
||||
}
|
||||
|
||||
func (h *HostInfo) String() string {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
connectAddr, source := h.connectAddressLocked()
|
||||
return fmt.Sprintf("[HostInfo connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+
|
||||
return fmt.Sprintf("[HostInfo hostname=%q connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+
|
||||
"preferred_ip=%q connect_addr=%q connect_addr_source=%q "+
|
||||
"port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]",
|
||||
h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP,
|
||||
h.hostname, h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP,
|
||||
connectAddr, source,
|
||||
h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens))
|
||||
}
|
||||
@@ -446,15 +461,11 @@ func checkSystemSchema(control *controlConn) (bool, error) {
|
||||
|
||||
// Given a map that represents a row from either system.local or system.peers
|
||||
// return as much information as we can in *HostInfo
|
||||
func (s *Session) hostInfoFromMap(row map[string]interface{}, port int) (*HostInfo, error) {
|
||||
func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) {
|
||||
const assertErrorMsg = "Assertion failed for %s"
|
||||
var ok bool
|
||||
|
||||
// Default to our connected port if the cluster doesn't have port information
|
||||
host := HostInfo{
|
||||
port: port,
|
||||
}
|
||||
|
||||
for key, value := range row {
|
||||
switch key {
|
||||
case "data_center":
|
||||
@@ -539,6 +550,12 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, port int) (*HostIn
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(assertErrorMsg, "dse_version")
|
||||
}
|
||||
case "schema_version":
|
||||
schemaVersion, ok := value.(UUID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(assertErrorMsg, "schema_version")
|
||||
}
|
||||
host.schemaVersion = schemaVersion.String()
|
||||
}
|
||||
// TODO(thrawn01): Add 'port'? once CASSANDRA-7544 is complete
|
||||
// Not sure what the port field will be called until the JIRA issue is complete
|
||||
@@ -548,7 +565,7 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, port int) (*HostIn
|
||||
host.connectAddress = ip
|
||||
host.port = port
|
||||
|
||||
return &host, nil
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// Ask the control node for host info on all it's known peers
|
||||
@@ -571,7 +588,7 @@ func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
|
||||
|
||||
for _, row := range rows {
|
||||
// extract all available info about the peer
|
||||
host, err := r.session.hostInfoFromMap(row, r.session.cfg.Port)
|
||||
host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !isValidPeer(host) {
|
||||
@@ -633,7 +650,7 @@ func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
h, err := r.session.hostInfoFromMap(row, port)
|
||||
h, err := r.session.hostInfoFromMap(row, &HostInfo{port: port})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+6
-1
@@ -75,10 +75,15 @@ function run_tests() {
|
||||
else
|
||||
sleep 1s
|
||||
go test -tags "cassandra gocql_debug" -timeout=5m -race $args
|
||||
|
||||
ccm clear
|
||||
ccm start --wait-for-binary-proto
|
||||
sleep 1s
|
||||
|
||||
go test -tags "integration gocql_debug" -timeout=5m -race $args
|
||||
|
||||
ccm clear
|
||||
ccm start
|
||||
ccm start --wait-for-binary-proto
|
||||
sleep 1s
|
||||
|
||||
go test -tags "ccm gocql_debug" -timeout=5m -race $args
|
||||
|
||||
+3
-3
@@ -1,11 +1,11 @@
|
||||
// +build appengine
|
||||
// +build appengine s390x
|
||||
|
||||
package murmur
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func getBlock(data []byte, n int) (int64, int64) {
|
||||
k1 := binary.LittleEndian.Int64(data[n*16:])
|
||||
k2 := binary.LittleEndian.Int64(data[(n*16)+8:])
|
||||
k1 := int64(binary.LittleEndian.Uint64(data[n*16:]))
|
||||
k2 := int64(binary.LittleEndian.Uint64(data[(n*16)+8:]))
|
||||
return k1, k2
|
||||
}
|
||||
|
||||
+1
@@ -1,4 +1,5 @@
|
||||
// +build !appengine
|
||||
// +build !s390x
|
||||
|
||||
package murmur
|
||||
|
||||
|
||||
+173
-55
@@ -82,7 +82,9 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
|
||||
return marshalDouble(info, value)
|
||||
case TypeDecimal:
|
||||
return marshalDecimal(info, value)
|
||||
case TypeTimestamp, TypeTime:
|
||||
case TypeTime:
|
||||
return marshalTime(info, value)
|
||||
case TypeTimestamp:
|
||||
return marshalTimestamp(info, value)
|
||||
case TypeList, TypeSet:
|
||||
return marshalList(info, value)
|
||||
@@ -146,7 +148,9 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
|
||||
return unmarshalDouble(info, data, value)
|
||||
case TypeDecimal:
|
||||
return unmarshalDecimal(info, data, value)
|
||||
case TypeTimestamp, TypeTime:
|
||||
case TypeTime:
|
||||
return unmarshalTime(info, data, value)
|
||||
case TypeTimestamp:
|
||||
return unmarshalTimestamp(info, data, value)
|
||||
case TypeList, TypeSet:
|
||||
return unmarshalList(info, data, value)
|
||||
@@ -709,9 +713,6 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
|
||||
return nil
|
||||
case *uint:
|
||||
unitVal := uint64(int64Val)
|
||||
if ^uint(0) == math.MaxUint32 && unitVal > math.MaxUint32 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v)
|
||||
}
|
||||
switch info.Type() {
|
||||
case TypeInt:
|
||||
*v = uint(unitVal) & 0xFFFFFFFF
|
||||
@@ -720,6 +721,9 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
|
||||
case TypeTinyInt:
|
||||
*v = uint(unitVal) & 0xFF
|
||||
default:
|
||||
if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v)
|
||||
}
|
||||
*v = uint(unitVal)
|
||||
}
|
||||
return nil
|
||||
@@ -745,15 +749,17 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
|
||||
*v = int32(int64Val)
|
||||
return nil
|
||||
case *uint32:
|
||||
if int64Val > math.MaxUint32 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
|
||||
}
|
||||
switch info.Type() {
|
||||
case TypeInt:
|
||||
*v = uint32(int64Val) & 0xFFFFFFFF
|
||||
case TypeSmallInt:
|
||||
*v = uint32(int64Val) & 0xFFFF
|
||||
case TypeTinyInt:
|
||||
*v = uint32(int64Val) & 0xFF
|
||||
default:
|
||||
if int64Val < 0 || int64Val > math.MaxUint32 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
|
||||
}
|
||||
*v = uint32(int64Val) & 0xFFFFFFFF
|
||||
}
|
||||
return nil
|
||||
@@ -764,13 +770,15 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
|
||||
*v = int16(int64Val)
|
||||
return nil
|
||||
case *uint16:
|
||||
if int64Val > math.MaxUint16 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
|
||||
}
|
||||
switch info.Type() {
|
||||
case TypeSmallInt:
|
||||
*v = uint16(int64Val) & 0xFFFF
|
||||
case TypeTinyInt:
|
||||
*v = uint16(int64Val) & 0xFF
|
||||
default:
|
||||
if int64Val < 0 || int64Val > math.MaxUint16 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
|
||||
}
|
||||
*v = uint16(int64Val) & 0xFFFF
|
||||
}
|
||||
return nil
|
||||
@@ -781,7 +789,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
|
||||
*v = int8(int64Val)
|
||||
return nil
|
||||
case *uint8:
|
||||
if int64Val > math.MaxUint8 {
|
||||
if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
|
||||
}
|
||||
*v = uint8(int64Val) & 0xFF
|
||||
@@ -829,34 +837,69 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
|
||||
rv.SetInt(int64Val)
|
||||
return nil
|
||||
case reflect.Uint:
|
||||
if int64Val < 0 || (^uint(0) == math.MaxUint32 && int64Val > math.MaxUint32) {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
|
||||
unitVal := uint64(int64Val)
|
||||
switch info.Type() {
|
||||
case TypeInt:
|
||||
rv.SetUint(unitVal & 0xFFFFFFFF)
|
||||
case TypeSmallInt:
|
||||
rv.SetUint(unitVal & 0xFFFF)
|
||||
case TypeTinyInt:
|
||||
rv.SetUint(unitVal & 0xFF)
|
||||
default:
|
||||
if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %s", unitVal, rv.Type())
|
||||
}
|
||||
rv.SetUint(unitVal)
|
||||
}
|
||||
rv.SetUint(uint64(int64Val))
|
||||
return nil
|
||||
case reflect.Uint64:
|
||||
if int64Val < 0 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
|
||||
unitVal := uint64(int64Val)
|
||||
switch info.Type() {
|
||||
case TypeInt:
|
||||
rv.SetUint(unitVal & 0xFFFFFFFF)
|
||||
case TypeSmallInt:
|
||||
rv.SetUint(unitVal & 0xFFFF)
|
||||
case TypeTinyInt:
|
||||
rv.SetUint(unitVal & 0xFF)
|
||||
default:
|
||||
rv.SetUint(unitVal)
|
||||
}
|
||||
rv.SetUint(uint64(int64Val))
|
||||
return nil
|
||||
case reflect.Uint32:
|
||||
if int64Val < 0 || int64Val > math.MaxUint32 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
|
||||
unitVal := uint64(int64Val)
|
||||
switch info.Type() {
|
||||
case TypeInt:
|
||||
rv.SetUint(unitVal & 0xFFFFFFFF)
|
||||
case TypeSmallInt:
|
||||
rv.SetUint(unitVal & 0xFFFF)
|
||||
case TypeTinyInt:
|
||||
rv.SetUint(unitVal & 0xFF)
|
||||
default:
|
||||
if int64Val < 0 || int64Val > math.MaxUint32 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type())
|
||||
}
|
||||
rv.SetUint(unitVal & 0xFFFFFFFF)
|
||||
}
|
||||
rv.SetUint(uint64(int64Val))
|
||||
return nil
|
||||
case reflect.Uint16:
|
||||
if int64Val < 0 || int64Val > math.MaxUint16 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
|
||||
unitVal := uint64(int64Val)
|
||||
switch info.Type() {
|
||||
case TypeSmallInt:
|
||||
rv.SetUint(unitVal & 0xFFFF)
|
||||
case TypeTinyInt:
|
||||
rv.SetUint(unitVal & 0xFF)
|
||||
default:
|
||||
if int64Val < 0 || int64Val > math.MaxUint16 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type())
|
||||
}
|
||||
rv.SetUint(unitVal & 0xFFFF)
|
||||
}
|
||||
rv.SetUint(uint64(int64Val))
|
||||
return nil
|
||||
case reflect.Uint8:
|
||||
if int64Val < 0 || int64Val > math.MaxUint8 {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
|
||||
if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) {
|
||||
return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type())
|
||||
}
|
||||
rv.SetUint(uint64(int64Val))
|
||||
rv.SetUint(uint64(int64Val) & 0xff)
|
||||
return nil
|
||||
}
|
||||
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
|
||||
@@ -1090,7 +1133,7 @@ func encBigInt2C(n *big.Int) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) {
|
||||
func marshalTime(info TypeInfo, value interface{}) ([]byte, error) {
|
||||
switch v := value.(type) {
|
||||
case Marshaler:
|
||||
return v.MarshalCQL(info)
|
||||
@@ -1098,12 +1141,6 @@ func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) {
|
||||
return nil, nil
|
||||
case int64:
|
||||
return encBigInt(v), nil
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
return []byte{}, nil
|
||||
}
|
||||
x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
|
||||
return encBigInt(x), nil
|
||||
case time.Duration:
|
||||
return encBigInt(v.Nanoseconds()), nil
|
||||
}
|
||||
@@ -1120,6 +1157,59 @@ func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) {
|
||||
return nil, marshalErrorf("can not marshal %T into %s", value, info)
|
||||
}
|
||||
|
||||
func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) {
|
||||
switch v := value.(type) {
|
||||
case Marshaler:
|
||||
return v.MarshalCQL(info)
|
||||
case unsetColumn:
|
||||
return nil, nil
|
||||
case int64:
|
||||
return encBigInt(v), nil
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
return []byte{}, nil
|
||||
}
|
||||
x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
|
||||
return encBigInt(x), nil
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(value)
|
||||
switch rv.Type().Kind() {
|
||||
case reflect.Int64:
|
||||
return encBigInt(rv.Int()), nil
|
||||
}
|
||||
return nil, marshalErrorf("can not marshal %T into %s", value, info)
|
||||
}
|
||||
|
||||
func unmarshalTime(info TypeInfo, data []byte, value interface{}) error {
|
||||
switch v := value.(type) {
|
||||
case Unmarshaler:
|
||||
return v.UnmarshalCQL(info, data)
|
||||
case *int64:
|
||||
*v = decBigInt(data)
|
||||
return nil
|
||||
case *time.Duration:
|
||||
*v = time.Duration(decBigInt(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(value)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
|
||||
}
|
||||
rv = rv.Elem()
|
||||
switch rv.Type().Kind() {
|
||||
case reflect.Int64:
|
||||
rv.SetInt(decBigInt(data))
|
||||
return nil
|
||||
}
|
||||
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
|
||||
}
|
||||
|
||||
func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error {
|
||||
switch v := value.(type) {
|
||||
case Unmarshaler:
|
||||
@@ -1137,8 +1227,6 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error {
|
||||
nsec := (x - sec*1000) * 1000000
|
||||
*v = time.Unix(sec, nsec).In(time.UTC)
|
||||
return nil
|
||||
case *time.Duration:
|
||||
*v = time.Duration(decBigInt(data))
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(value)
|
||||
@@ -1212,6 +1300,16 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
|
||||
timestamp := (int64(current) - int64(origin)) * 86400000
|
||||
*v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC)
|
||||
return nil
|
||||
case *string:
|
||||
if len(data) == 0 {
|
||||
*v = ""
|
||||
return nil
|
||||
}
|
||||
var origin uint32 = 1 << 31
|
||||
var current uint32 = binary.BigEndian.Uint32(data)
|
||||
timestamp := (int64(current) - int64(origin)) * 86400000
|
||||
*v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC).Format("2006-01-02")
|
||||
return nil
|
||||
}
|
||||
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
|
||||
}
|
||||
@@ -1400,11 +1498,17 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) {
|
||||
return nil, marshalErrorf("can not marshal %T into %s", value, info)
|
||||
}
|
||||
|
||||
func readCollectionSize(info CollectionType, data []byte) (size, read int) {
|
||||
func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) {
|
||||
if info.proto > protoVersion2 {
|
||||
if len(data) < 4 {
|
||||
return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof")
|
||||
}
|
||||
size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||
read = 4
|
||||
} else {
|
||||
if len(data) < 2 {
|
||||
return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof")
|
||||
}
|
||||
size = int(data[0])<<8 | int(data[1])
|
||||
read = 2
|
||||
}
|
||||
@@ -1437,10 +1541,10 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
|
||||
rv.Set(reflect.Zero(t))
|
||||
return nil
|
||||
}
|
||||
if len(data) < 2 {
|
||||
return unmarshalErrorf("unmarshal list: unexpected eof")
|
||||
n, p, err := readCollectionSize(listInfo, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, p := readCollectionSize(listInfo, data)
|
||||
data = data[p:]
|
||||
if k == reflect.Array {
|
||||
if rv.Len() != n {
|
||||
@@ -1450,11 +1554,14 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
|
||||
rv.Set(reflect.MakeSlice(t, n, n))
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
if len(data) < 2 {
|
||||
m, p, err := readCollectionSize(listInfo, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = data[p:]
|
||||
if len(data) < m {
|
||||
return unmarshalErrorf("unmarshal list: unexpected eof")
|
||||
}
|
||||
m, p := readCollectionSize(listInfo, data)
|
||||
data = data[p:]
|
||||
if err := Unmarshal(listInfo.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1478,15 +1585,16 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) {
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(value)
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
t := rv.Type()
|
||||
if t.Kind() != reflect.Map {
|
||||
return nil, marshalErrorf("can not marshal %T into %s", value, info)
|
||||
}
|
||||
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
n := rv.Len()
|
||||
|
||||
@@ -1537,25 +1645,34 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error {
|
||||
return nil
|
||||
}
|
||||
rv.Set(reflect.MakeMap(t))
|
||||
if len(data) < 2 {
|
||||
return unmarshalErrorf("unmarshal map: unexpected eof")
|
||||
n, p, err := readCollectionSize(mapInfo, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, p := readCollectionSize(mapInfo, data)
|
||||
data = data[p:]
|
||||
for i := 0; i < n; i++ {
|
||||
if len(data) < 2 {
|
||||
return unmarshalErrorf("unmarshal list: unexpected eof")
|
||||
m, p, err := readCollectionSize(mapInfo, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m, p := readCollectionSize(mapInfo, data)
|
||||
data = data[p:]
|
||||
if len(data) < m {
|
||||
return unmarshalErrorf("unmarshal map: unexpected eof")
|
||||
}
|
||||
key := reflect.New(t.Key())
|
||||
if err := Unmarshal(mapInfo.Key, data[:m], key.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
data = data[m:]
|
||||
|
||||
m, p = readCollectionSize(mapInfo, data)
|
||||
m, p, err = readCollectionSize(mapInfo, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = data[p:]
|
||||
if len(data) < m {
|
||||
return unmarshalErrorf("unmarshal map: unexpected eof")
|
||||
}
|
||||
val := reflect.New(t.Elem())
|
||||
if err := Unmarshal(mapInfo.Elem, data[:m], val.Interface()); err != nil {
|
||||
return err
|
||||
@@ -1806,8 +1923,9 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
|
||||
for i, elem := range tuple.Elems {
|
||||
// each element inside data is a [bytes]
|
||||
var p []byte
|
||||
p, data = readBytes(data)
|
||||
|
||||
if len(data) > 4 {
|
||||
p, data = readBytes(data)
|
||||
}
|
||||
err := Unmarshal(elem, p, v[i])
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
+66
-3
@@ -22,6 +22,7 @@ type KeyspaceMetadata struct {
|
||||
Tables map[string]*TableMetadata
|
||||
Functions map[string]*FunctionMetadata
|
||||
Aggregates map[string]*AggregateMetadata
|
||||
Views map[string]*ViewMetadata
|
||||
}
|
||||
|
||||
// schema metadata for a table (a.k.a. column family)
|
||||
@@ -81,6 +82,14 @@ type AggregateMetadata struct {
|
||||
finalFunc string
|
||||
}
|
||||
|
||||
// ViewMetadata holds the metadata for views.
|
||||
type ViewMetadata struct {
|
||||
Keyspace string
|
||||
Name string
|
||||
FieldNames []string
|
||||
FieldTypes []TypeInfo
|
||||
}
|
||||
|
||||
// the ordering of the column with regard to its comparator
|
||||
type ColumnOrder bool
|
||||
|
||||
@@ -233,9 +242,13 @@ func (s *schemaDescriber) refreshSchema(keyspaceName string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
views, err := getViewsMetadata(s.session, keyspaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// organize the schema data
|
||||
compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns, functions, aggregates)
|
||||
compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns, functions, aggregates, views)
|
||||
|
||||
// update the cache
|
||||
s.cache[keyspaceName] = keyspace
|
||||
@@ -255,6 +268,7 @@ func compileMetadata(
|
||||
columns []ColumnMetadata,
|
||||
functions []FunctionMetadata,
|
||||
aggregates []AggregateMetadata,
|
||||
views []ViewMetadata,
|
||||
) {
|
||||
keyspace.Tables = make(map[string]*TableMetadata)
|
||||
for i := range tables {
|
||||
@@ -272,6 +286,10 @@ func compileMetadata(
|
||||
aggregate.StateFunc = *keyspace.Functions[aggregate.stateFunc]
|
||||
keyspace.Aggregates[aggregate.Name] = &aggregate
|
||||
}
|
||||
keyspace.Views = make(map[string]*ViewMetadata, len(views))
|
||||
for i := range views {
|
||||
keyspace.Views[views[i].Name] = &views[i]
|
||||
}
|
||||
|
||||
// add columns from the schema data
|
||||
for i := range columns {
|
||||
@@ -849,11 +867,56 @@ func getTypeInfo(t string) TypeInfo {
|
||||
return getCassandraType(t)
|
||||
}
|
||||
|
||||
func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMetadata, error) {
|
||||
func getViewsMetadata(session *Session, keyspaceName string) ([]ViewMetadata, error) {
|
||||
if session.cfg.ProtoVersion == protoVersion1 {
|
||||
return nil, nil
|
||||
}
|
||||
var tableName string
|
||||
if session.useSystemSchema {
|
||||
tableName = "system_schema.types"
|
||||
} else {
|
||||
tableName = "system.schema_usertypes"
|
||||
}
|
||||
stmt := fmt.Sprintf(`
|
||||
SELECT
|
||||
type_name,
|
||||
field_names,
|
||||
field_types
|
||||
FROM %s
|
||||
WHERE keyspace_name = ?`, tableName)
|
||||
|
||||
var views []ViewMetadata
|
||||
|
||||
rows := session.control.query(stmt, keyspaceName).Scanner()
|
||||
for rows.Next() {
|
||||
view := ViewMetadata{Keyspace: keyspaceName}
|
||||
var argumentTypes []string
|
||||
err := rows.Scan(&view.Name,
|
||||
&view.FieldNames,
|
||||
&argumentTypes,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
view.FieldTypes = make([]TypeInfo, len(argumentTypes))
|
||||
for i, argumentType := range argumentTypes {
|
||||
view.FieldTypes[i] = getTypeInfo(argumentType)
|
||||
}
|
||||
views = append(views, view)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return views, nil
|
||||
}
|
||||
|
||||
func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMetadata, error) {
|
||||
if session.cfg.ProtoVersion == protoVersion1 || !session.hasAggregatesAndFunctions {
|
||||
return nil, nil
|
||||
}
|
||||
var tableName string
|
||||
if session.useSystemSchema {
|
||||
tableName = "system_schema.functions"
|
||||
} else {
|
||||
@@ -905,7 +968,7 @@ func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMeta
|
||||
}
|
||||
|
||||
func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMetadata, error) {
|
||||
if session.cfg.ProtoVersion == protoVersion1 {
|
||||
if session.cfg.ProtoVersion == protoVersion1 || !session.hasAggregatesAndFunctions {
|
||||
return nil, nil
|
||||
}
|
||||
var tableName string
|
||||
|
||||
+216
-129
@@ -333,9 +333,8 @@ func RoundRobinHostPolicy() HostSelectionPolicy {
|
||||
}
|
||||
|
||||
type roundRobinHostPolicy struct {
|
||||
hosts cowHostList
|
||||
pos uint32
|
||||
mu sync.RWMutex
|
||||
hosts cowHostList
|
||||
lastUsedHostIdx uint64
|
||||
}
|
||||
|
||||
func (r *roundRobinHostPolicy) IsLocal(*HostInfo) bool { return true }
|
||||
@@ -344,25 +343,8 @@ func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {}
|
||||
func (r *roundRobinHostPolicy) Init(*Session) {}
|
||||
|
||||
func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost {
|
||||
// i is used to limit the number of attempts to find a host
|
||||
// to the number of hosts known to this policy
|
||||
var i int
|
||||
return func() SelectedHost {
|
||||
hosts := r.hosts.get()
|
||||
if len(hosts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// always increment pos to evenly distribute traffic in case of
|
||||
// failures
|
||||
pos := atomic.AddUint32(&r.pos, 1) - 1
|
||||
if i >= len(hosts) {
|
||||
return nil
|
||||
}
|
||||
host := hosts[(pos)%uint32(len(hosts))]
|
||||
i++
|
||||
return (*selectedHost)(host)
|
||||
}
|
||||
nextStartOffset := atomic.AddUint64(&r.lastUsedHostIdx, 1)
|
||||
return roundRobbin(int(nextStartOffset), r.hosts.get())
|
||||
}
|
||||
|
||||
func (r *roundRobinHostPolicy) AddHost(host *HostInfo) {
|
||||
@@ -387,6 +369,18 @@ func ShuffleReplicas() func(*tokenAwareHostPolicy) {
|
||||
}
|
||||
}
|
||||
|
||||
// NonLocalReplicasFallback enables fallback to replicas that are not considered local.
|
||||
//
|
||||
// TokenAwareHostPolicy used with DCAwareHostPolicy fallback first selects replicas by partition key in local DC, then
|
||||
// falls back to other nodes in the local DC. Enabling NonLocalReplicasFallback causes TokenAwareHostPolicy
|
||||
// to first select replicas by partition key in local DC, then replicas by partition key in remote DCs and fall back
|
||||
// to other nodes in local DC.
|
||||
func NonLocalReplicasFallback() func(policy *tokenAwareHostPolicy) {
|
||||
return func(t *tokenAwareHostPolicy) {
|
||||
t.nonLocalReplicasFallback = true
|
||||
}
|
||||
}
|
||||
|
||||
// TokenAwareHostPolicy is a token aware host selection policy, where hosts are
|
||||
// selected based on the partition key, so queries are sent to the host which
|
||||
// owns the partition. Fallback is used when routing information is not available.
|
||||
@@ -398,25 +392,35 @@ func TokenAwareHostPolicy(fallback HostSelectionPolicy, opts ...func(*tokenAware
|
||||
return p
|
||||
}
|
||||
|
||||
type keyspaceMeta struct {
|
||||
replicas map[string]map[token][]*HostInfo
|
||||
// clusterMeta holds metadata about cluster topology.
|
||||
// It is used inside atomic.Value and shallow copies are used when replacing it,
|
||||
// so fields should not be modified in-place. Instead, to modify a field a copy of the field should be made
|
||||
// and the pointer in clusterMeta updated to point to the new value.
|
||||
type clusterMeta struct {
|
||||
// replicas is map[keyspace]map[token]hosts
|
||||
replicas map[string]tokenRingReplicas
|
||||
tokenRing *tokenRing
|
||||
}
|
||||
|
||||
type tokenAwareHostPolicy struct {
|
||||
fallback HostSelectionPolicy
|
||||
getKeyspaceMetadata func(keyspace string) (*KeyspaceMetadata, error)
|
||||
getKeyspaceName func() string
|
||||
|
||||
shuffleReplicas bool
|
||||
nonLocalReplicasFallback bool
|
||||
|
||||
// mu protects writes to hosts, partitioner, metadata.
|
||||
// reads can be unlocked as long as they are not used for updating state later.
|
||||
mu sync.Mutex
|
||||
hosts cowHostList
|
||||
mu sync.RWMutex
|
||||
partitioner string
|
||||
fallback HostSelectionPolicy
|
||||
session *Session
|
||||
|
||||
tokenRing atomic.Value // *tokenRing
|
||||
keyspaces atomic.Value // *keyspaceMeta
|
||||
|
||||
shuffleReplicas bool
|
||||
metadata atomic.Value // *clusterMeta
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) Init(s *Session) {
|
||||
t.session = s
|
||||
t.getKeyspaceMetadata = s.KeyspaceMetadata
|
||||
t.getKeyspaceName = func() string { return s.cfg.Keyspace }
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
|
||||
@@ -424,34 +428,36 @@ func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool {
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) {
|
||||
meta, _ := t.keyspaces.Load().(*keyspaceMeta)
|
||||
var size = 1
|
||||
if meta != nil {
|
||||
size = len(meta.replicas)
|
||||
}
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
meta := t.getMetadataForUpdate()
|
||||
t.updateReplicas(meta, update.Keyspace)
|
||||
t.metadata.Store(meta)
|
||||
}
|
||||
|
||||
newMeta := &keyspaceMeta{
|
||||
replicas: make(map[string]map[token][]*HostInfo, size),
|
||||
}
|
||||
// updateReplicas updates replicas in clusterMeta.
|
||||
// It must be called with t.mu mutex locked.
|
||||
// meta must not be nil and it's replicas field will be updated.
|
||||
func (t *tokenAwareHostPolicy) updateReplicas(meta *clusterMeta, keyspace string) {
|
||||
newReplicas := make(map[string]tokenRingReplicas, len(meta.replicas))
|
||||
|
||||
ks, err := t.session.KeyspaceMetadata(update.Keyspace)
|
||||
ks, err := t.getKeyspaceMetadata(keyspace)
|
||||
if err == nil {
|
||||
strat := getStrategy(ks)
|
||||
tr := t.tokenRing.Load().(*tokenRing)
|
||||
if tr != nil {
|
||||
newMeta.replicas[update.Keyspace] = strat.replicaMap(t.hosts.get(), tr.tokens)
|
||||
}
|
||||
}
|
||||
|
||||
if meta != nil {
|
||||
for ks, replicas := range meta.replicas {
|
||||
if ks != update.Keyspace {
|
||||
newMeta.replicas[ks] = replicas
|
||||
if strat != nil {
|
||||
if meta != nil && meta.tokenRing != nil {
|
||||
newReplicas[keyspace] = strat.replicaMap(meta.tokenRing)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.keyspaces.Store(newMeta)
|
||||
for ks, replicas := range meta.replicas {
|
||||
if ks != keyspace {
|
||||
newReplicas[ks] = replicas
|
||||
}
|
||||
}
|
||||
|
||||
meta.replicas = newReplicas
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
|
||||
@@ -461,50 +467,96 @@ func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) {
|
||||
if t.partitioner != partitioner {
|
||||
t.fallback.SetPartitioner(partitioner)
|
||||
t.partitioner = partitioner
|
||||
|
||||
t.resetTokenRing(partitioner)
|
||||
meta := t.getMetadataForUpdate()
|
||||
meta.resetTokenRing(t.partitioner, t.hosts.get())
|
||||
t.updateReplicas(meta, t.getKeyspaceName())
|
||||
t.metadata.Store(meta)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) {
|
||||
t.hosts.add(host)
|
||||
t.fallback.AddHost(host)
|
||||
t.mu.Lock()
|
||||
if t.hosts.add(host) {
|
||||
meta := t.getMetadataForUpdate()
|
||||
meta.resetTokenRing(t.partitioner, t.hosts.get())
|
||||
t.updateReplicas(meta, t.getKeyspaceName())
|
||||
t.metadata.Store(meta)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
t.mu.RLock()
|
||||
partitioner := t.partitioner
|
||||
t.mu.RUnlock()
|
||||
t.resetTokenRing(partitioner)
|
||||
t.fallback.AddHost(host)
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) AddHosts(hosts []*HostInfo) {
|
||||
t.mu.Lock()
|
||||
|
||||
for _, host := range hosts {
|
||||
t.hosts.add(host)
|
||||
}
|
||||
|
||||
meta := t.getMetadataForUpdate()
|
||||
meta.resetTokenRing(t.partitioner, t.hosts.get())
|
||||
t.updateReplicas(meta, t.getKeyspaceName())
|
||||
t.metadata.Store(meta)
|
||||
|
||||
t.mu.Unlock()
|
||||
|
||||
for _, host := range hosts {
|
||||
t.fallback.AddHost(host)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) {
|
||||
t.hosts.remove(host.ConnectAddress())
|
||||
t.fallback.RemoveHost(host)
|
||||
t.mu.Lock()
|
||||
if t.hosts.remove(host.ConnectAddress()) {
|
||||
meta := t.getMetadataForUpdate()
|
||||
meta.resetTokenRing(t.partitioner, t.hosts.get())
|
||||
t.updateReplicas(meta, t.getKeyspaceName())
|
||||
t.metadata.Store(meta)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
t.mu.RLock()
|
||||
partitioner := t.partitioner
|
||||
t.mu.RUnlock()
|
||||
t.resetTokenRing(partitioner)
|
||||
t.fallback.RemoveHost(host)
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) {
|
||||
// TODO: need to avoid doing all the work on AddHost on hostup/down
|
||||
// because it now expensive to calculate the replica map for each
|
||||
// token
|
||||
t.AddHost(host)
|
||||
t.fallback.HostUp(host)
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) {
|
||||
t.RemoveHost(host)
|
||||
t.fallback.HostDown(host)
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) resetTokenRing(partitioner string) {
|
||||
// getMetadataReadOnly returns current cluster metadata.
|
||||
// Metadata uses copy on write, so the returned value should be only used for reading.
|
||||
// To obtain a copy that could be updated, use getMetadataForUpdate instead.
|
||||
func (t *tokenAwareHostPolicy) getMetadataReadOnly() *clusterMeta {
|
||||
meta, _ := t.metadata.Load().(*clusterMeta)
|
||||
return meta
|
||||
}
|
||||
|
||||
// getMetadataForUpdate returns clusterMeta suitable for updating.
|
||||
// It is a SHALLOW copy of current metadata in case it was already set or new empty clusterMeta otherwise.
|
||||
// This function should be called with t.mu mutex locked and the mutex should not be released before
|
||||
// storing the new metadata.
|
||||
func (t *tokenAwareHostPolicy) getMetadataForUpdate() *clusterMeta {
|
||||
metaReadOnly := t.getMetadataReadOnly()
|
||||
meta := new(clusterMeta)
|
||||
if metaReadOnly != nil {
|
||||
*meta = *metaReadOnly
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
// resetTokenRing creates a new tokenRing.
|
||||
// It must be called with t.mu locked.
|
||||
func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo) {
|
||||
if partitioner == "" {
|
||||
// partitioner not yet set
|
||||
return
|
||||
}
|
||||
|
||||
// create a new token ring
|
||||
hosts := t.hosts.get()
|
||||
tokenRing, err := newTokenRing(partitioner, hosts)
|
||||
if err != nil {
|
||||
Logger.Printf("Unable to update the token ring due to error: %s", err)
|
||||
@@ -512,16 +564,7 @@ func (t *tokenAwareHostPolicy) resetTokenRing(partitioner string) {
|
||||
}
|
||||
|
||||
// replace the token ring
|
||||
t.tokenRing.Store(tokenRing)
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) getReplicas(keyspace string, token token) ([]*HostInfo, bool) {
|
||||
meta, _ := t.keyspaces.Load().(*keyspaceMeta)
|
||||
if meta == nil {
|
||||
return nil, false
|
||||
}
|
||||
tokens, ok := meta.replicas[keyspace][token]
|
||||
return tokens, ok
|
||||
m.tokenRing = tokenRing
|
||||
}
|
||||
|
||||
func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
|
||||
@@ -536,28 +579,29 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
|
||||
return t.fallback.Pick(qry)
|
||||
}
|
||||
|
||||
tr, _ := t.tokenRing.Load().(*tokenRing)
|
||||
if tr == nil {
|
||||
meta := t.getMetadataReadOnly()
|
||||
if meta == nil || meta.tokenRing == nil {
|
||||
return t.fallback.Pick(qry)
|
||||
}
|
||||
|
||||
token := tr.partitioner.Hash(routingKey)
|
||||
primaryEndpoint := tr.GetHostForToken(token)
|
||||
token := meta.tokenRing.partitioner.Hash(routingKey)
|
||||
ht := meta.replicas[qry.Keyspace()].replicasFor(token)
|
||||
|
||||
if primaryEndpoint == nil || token == nil {
|
||||
return t.fallback.Pick(qry)
|
||||
}
|
||||
|
||||
replicas, ok := t.getReplicas(qry.Keyspace(), token)
|
||||
if !ok {
|
||||
replicas = []*HostInfo{primaryEndpoint}
|
||||
} else if t.shuffleReplicas {
|
||||
replicas = shuffleHosts(replicas)
|
||||
var replicas []*HostInfo
|
||||
if ht == nil {
|
||||
host, _ := meta.tokenRing.GetHostForToken(token)
|
||||
replicas = []*HostInfo{host}
|
||||
} else {
|
||||
replicas = ht.hosts
|
||||
if t.shuffleReplicas {
|
||||
replicas = shuffleHosts(replicas)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
fallbackIter NextHost
|
||||
i int
|
||||
i, j int
|
||||
remote []*HostInfo
|
||||
)
|
||||
|
||||
used := make(map[*HostInfo]bool, len(replicas))
|
||||
@@ -566,12 +610,29 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
|
||||
h := replicas[i]
|
||||
i++
|
||||
|
||||
if h.IsUp() && t.fallback.IsLocal(h) {
|
||||
if !t.fallback.IsLocal(h) {
|
||||
remote = append(remote, h)
|
||||
continue
|
||||
}
|
||||
|
||||
if h.IsUp() {
|
||||
used[h] = true
|
||||
return (*selectedHost)(h)
|
||||
}
|
||||
}
|
||||
|
||||
if t.nonLocalReplicasFallback {
|
||||
for j < len(remote) {
|
||||
h := remote[j]
|
||||
j++
|
||||
|
||||
if h.IsUp() {
|
||||
used[h] = true
|
||||
return (*selectedHost)(h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fallbackIter == nil {
|
||||
// fallback
|
||||
fallbackIter = t.fallback.Pick(qry)
|
||||
@@ -580,9 +641,11 @@ func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost {
|
||||
// filter the token aware selected hosts from the fallback hosts
|
||||
for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() {
|
||||
if !used[fallbackHost.Info()] {
|
||||
used[fallbackHost.Info()] = true
|
||||
return fallbackHost
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -730,11 +793,10 @@ func (host selectedHostPoolHost) Mark(err error) {
|
||||
}
|
||||
|
||||
type dcAwareRR struct {
|
||||
local string
|
||||
pos uint32
|
||||
mu sync.RWMutex
|
||||
localHosts cowHostList
|
||||
remoteHosts cowHostList
|
||||
local string
|
||||
localHosts cowHostList
|
||||
remoteHosts cowHostList
|
||||
lastUsedHostIdx uint64
|
||||
}
|
||||
|
||||
// DCAwareRoundRobinPolicy is a host selection policies which will prioritize and
|
||||
@@ -753,7 +815,7 @@ func (d *dcAwareRR) IsLocal(host *HostInfo) bool {
|
||||
}
|
||||
|
||||
func (d *dcAwareRR) AddHost(host *HostInfo) {
|
||||
if host.DataCenter() == d.local {
|
||||
if d.IsLocal(host) {
|
||||
d.localHosts.add(host)
|
||||
} else {
|
||||
d.remoteHosts.add(host)
|
||||
@@ -761,7 +823,7 @@ func (d *dcAwareRR) AddHost(host *HostInfo) {
|
||||
}
|
||||
|
||||
func (d *dcAwareRR) RemoveHost(host *HostInfo) {
|
||||
if host.DataCenter() == d.local {
|
||||
if d.IsLocal(host) {
|
||||
d.localHosts.remove(host.ConnectAddress())
|
||||
} else {
|
||||
d.remoteHosts.remove(host.ConnectAddress())
|
||||
@@ -771,33 +833,53 @@ func (d *dcAwareRR) RemoveHost(host *HostInfo) {
|
||||
func (d *dcAwareRR) HostUp(host *HostInfo) { d.AddHost(host) }
|
||||
func (d *dcAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) }
|
||||
|
||||
func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
|
||||
var i int
|
||||
return func() SelectedHost {
|
||||
var hosts []*HostInfo
|
||||
localHosts := d.localHosts.get()
|
||||
remoteHosts := d.remoteHosts.get()
|
||||
if len(localHosts) != 0 {
|
||||
hosts = localHosts
|
||||
} else {
|
||||
hosts = remoteHosts
|
||||
}
|
||||
if len(hosts) == 0 {
|
||||
return nil
|
||||
}
|
||||
// This function is supposed to be called in a fashion
|
||||
// roundRobbin(offset, hostsPriority1, hostsPriority2, hostsPriority3 ... )
|
||||
//
|
||||
// E.g. for DC-naive strategy:
|
||||
// roundRobbin(offset, allHosts)
|
||||
//
|
||||
// For tiered and DC-aware strategy:
|
||||
// roundRobbin(offset, localHosts, remoteHosts)
|
||||
func roundRobbin(shift int, hosts ...[]*HostInfo) NextHost {
|
||||
currentLayer := 0
|
||||
currentlyObserved := 0
|
||||
|
||||
// always increment pos to evenly distribute traffic in case of
|
||||
// failures
|
||||
pos := atomic.AddUint32(&d.pos, 1) - 1
|
||||
if i >= len(localHosts)+len(remoteHosts) {
|
||||
return nil
|
||||
return func() SelectedHost {
|
||||
|
||||
// iterate over layers
|
||||
for {
|
||||
if currentLayer == len(hosts) {
|
||||
return nil
|
||||
}
|
||||
|
||||
currentLayerSize := len(hosts[currentLayer])
|
||||
|
||||
// iterate over hosts within a layer
|
||||
for {
|
||||
currentlyObserved++
|
||||
if currentlyObserved > currentLayerSize {
|
||||
currentLayer++
|
||||
currentlyObserved = 0
|
||||
break
|
||||
}
|
||||
|
||||
h := hosts[currentLayer][(shift+currentlyObserved)%currentLayerSize]
|
||||
|
||||
if h.IsUp() {
|
||||
return (*selectedHost)(h)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
host := hosts[(pos)%uint32(len(hosts))]
|
||||
i++
|
||||
return (*selectedHost)(host)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
|
||||
nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1)
|
||||
return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get())
|
||||
}
|
||||
|
||||
// ConvictionPolicy interface is used by gocql to determine if a host should be
|
||||
// marked as DOWN based on the error and host info
|
||||
type ConvictionPolicy interface {
|
||||
@@ -850,10 +932,15 @@ func (c *ConstantReconnectionPolicy) GetMaxRetries() int {
|
||||
type ExponentialReconnectionPolicy struct {
|
||||
MaxRetries int
|
||||
InitialInterval time.Duration
|
||||
MaxInterval time.Duration
|
||||
}
|
||||
|
||||
func (e *ExponentialReconnectionPolicy) GetInterval(currentRetry int) time.Duration {
|
||||
return getExponentialTime(e.InitialInterval, math.MaxInt16*time.Second, e.GetMaxRetries())
|
||||
max := e.MaxInterval
|
||||
if max < e.InitialInterval {
|
||||
max = math.MaxInt16 * time.Second
|
||||
}
|
||||
return getExponentialTime(e.InitialInterval, max, currentRetry)
|
||||
}
|
||||
|
||||
func (e *ExponentialReconnectionPolicy) GetMaxRetries() int {
|
||||
|
||||
+26
-1
@@ -1,6 +1,7 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/gocql/gocql/internal/lru"
|
||||
"sync"
|
||||
)
|
||||
@@ -59,6 +60,30 @@ func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *infligh
|
||||
}
|
||||
|
||||
func (p *preparedLRU) keyFor(addr, keyspace, statement string) string {
|
||||
// TODO: maybe use []byte for keys?
|
||||
// TODO: we should just use a struct for the key in the map
|
||||
return addr + keyspace + statement
|
||||
}
|
||||
|
||||
func (p *preparedLRU) evictPreparedID(key string, id []byte) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
val, ok := p.lru.Get(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ifp, ok := val.(*inflightPrepare)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ifp.done:
|
||||
if bytes.Equal(id, ifp.preparedStatment.id) {
|
||||
p.lru.Remove(key)
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user