mirror of https://github.com/jackc/pgx.git
Merge remote-tracking branch 'pgproto3/master' into v5-dev
Pull in pgproto3 changes and update for pgx v5non-blocking
commit
a92f1df1df
|
@ -0,0 +1,59 @@
|
||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthenticationGSS struct{}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) Backend() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeGSS {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 4)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationGSS",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthenticationGSSContinue struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) Backend() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeGSSCont {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Data = src[4:]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||||
|
dst = append(dst, a.Data...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationGSSContinue",
|
||||||
|
Data: a.Data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Data = msg.Data
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -147,6 +147,8 @@ func (b *Backend) Receive() (FrontendMessage, error) {
|
||||||
msg = &SASLResponse{}
|
msg = &SASLResponse{}
|
||||||
case AuthTypeSASLFinal:
|
case AuthTypeSASLFinal:
|
||||||
msg = &SASLResponse{}
|
msg = &SASLResponse{}
|
||||||
|
case AuthTypeGSS, AuthTypeGSSCont:
|
||||||
|
msg = &GSSResponse{}
|
||||||
case AuthTypeCleartextPassword, AuthTypeMD5Password:
|
case AuthTypeCleartextPassword, AuthTypeMD5Password:
|
||||||
fallthrough
|
fallthrough
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -48,7 +48,7 @@ func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||||
dst = append(dst, 'W')
|
dst = append(dst, 'W')
|
||||||
sp := len(dst)
|
sp := len(dst)
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = append(dst, src.OverallFormat)
|
||||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
for _, fc := range src.ColumnFormatCodes {
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
dst = pgio.AppendUint16(dst, fc)
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
package pgproto3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncodeDecode(t *testing.T) {
|
||||||
|
srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01}
|
||||||
|
dstResp := pgproto3.CopyBothResponse{}
|
||||||
|
err := dstResp.Decode(srcBytes[5:])
|
||||||
|
assert.NoError(t, err, "No errors on decode")
|
||||||
|
dstBytes := []byte{}
|
||||||
|
dstBytes = dstResp.Encode(dstBytes)
|
||||||
|
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
|
||||||
|
}
|
|
@ -16,6 +16,8 @@ type Frontend struct {
|
||||||
authenticationOk AuthenticationOk
|
authenticationOk AuthenticationOk
|
||||||
authenticationCleartextPassword AuthenticationCleartextPassword
|
authenticationCleartextPassword AuthenticationCleartextPassword
|
||||||
authenticationMD5Password AuthenticationMD5Password
|
authenticationMD5Password AuthenticationMD5Password
|
||||||
|
authenticationGSS AuthenticationGSS
|
||||||
|
authenticationGSSContinue AuthenticationGSSContinue
|
||||||
authenticationSASL AuthenticationSASL
|
authenticationSASL AuthenticationSASL
|
||||||
authenticationSASLContinue AuthenticationSASLContinue
|
authenticationSASLContinue AuthenticationSASLContinue
|
||||||
authenticationSASLFinal AuthenticationSASLFinal
|
authenticationSASLFinal AuthenticationSASLFinal
|
||||||
|
@ -179,9 +181,9 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
|
||||||
case AuthTypeSCMCreds:
|
case AuthTypeSCMCreds:
|
||||||
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
|
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
|
||||||
case AuthTypeGSS:
|
case AuthTypeGSS:
|
||||||
return nil, errors.New("AuthTypeGSS is unimplemented")
|
return &f.authenticationGSS, nil
|
||||||
case AuthTypeGSSCont:
|
case AuthTypeGSSCont:
|
||||||
return nil, errors.New("AuthTypeGSSCont is unimplemented")
|
return &f.authenticationGSSContinue, nil
|
||||||
case AuthTypeSSPI:
|
case AuthTypeSSPI:
|
||||||
return nil, errors.New("AuthTypeSSPI is unimplemented")
|
return nil, errors.New("AuthTypeSSPI is unimplemented")
|
||||||
case AuthTypeSASL:
|
case AuthTypeSASL:
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GSSResponse struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (g *GSSResponse) Frontend() {}
|
||||||
|
|
||||||
|
func (g *GSSResponse) Decode(data []byte) error {
|
||||||
|
g.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GSSResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'p')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
|
||||||
|
dst = append(dst, g.Data...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}{
|
||||||
|
Type: "GSSResponse",
|
||||||
|
Data: g.Data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
var msg struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
g.Data = msg.Data
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -37,6 +37,32 @@ func TestJSONUnmarshalAuthenticationSASL(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalAuthenticationGSS(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"AuthenticationGSS"}`)
|
||||||
|
want := AuthenticationGSS{}
|
||||||
|
|
||||||
|
var got AuthenticationGSS
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled AuthenticationGSS struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalAuthenticationGSSContinue(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"AuthenticationGSSContinue","Data":[1,2,3,4]}`)
|
||||||
|
want := AuthenticationGSSContinue{Data: []byte{1, 2, 3, 4}}
|
||||||
|
|
||||||
|
var got AuthenticationGSSContinue
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled AuthenticationGSSContinue struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) {
|
func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) {
|
||||||
data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`)
|
data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`)
|
||||||
want := AuthenticationSASLContinue{
|
want := AuthenticationSASLContinue{
|
||||||
|
@ -551,6 +577,19 @@ func TestAuthenticationMD5Password(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONUnmarshalGSSResponse(t *testing.T) {
|
||||||
|
data := []byte(`{"Type":"GSSResponse","Data":[10,20,30,40]}`)
|
||||||
|
want := GSSResponse{Data: []byte{10, 20, 30, 40}}
|
||||||
|
|
||||||
|
var got GSSResponse
|
||||||
|
if err := json.Unmarshal(data, &got); err != nil {
|
||||||
|
t.Errorf("cannot JSON unmarshal %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Error("unmarshaled GSSResponse struct doesn't match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestErrorResponse(t *testing.T) {
|
func TestErrorResponse(t *testing.T) {
|
||||||
data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`)
|
data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`)
|
||||||
want := ErrorResponse{
|
want := ErrorResponse{
|
||||||
|
|
Loading…
Reference in New Issue