mirror of https://github.com/jackc/pgx.git
Merge remote-tracking branch 'pgproto3/master' into v5-dev
Pull in pgproto3 changes and update for pgx v5non-blocking
commit
a92f1df1df
|
@ -0,0 +1,59 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type AuthenticationGSS struct{}
|
||||
|
||||
func (a *AuthenticationGSS) Backend() {}
|
||||
|
||||
func (a *AuthenticationGSS) AuthenticationResponse() {}
|
||||
|
||||
func (a *AuthenticationGSS) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeGSS {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 4)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}{
|
||||
Type: "AuthenticationGSS",
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Type string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type AuthenticationGSSContinue struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Backend() {}
|
||||
|
||||
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeGSSCont {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
a.Data = src[4:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||
dst = append(dst, a.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}{
|
||||
Type: "AuthenticationGSSContinue",
|
||||
Data: a.Data,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.Data = msg.Data
|
||||
return nil
|
||||
}
|
|
@ -147,6 +147,8 @@ func (b *Backend) Receive() (FrontendMessage, error) {
|
|||
msg = &SASLResponse{}
|
||||
case AuthTypeSASLFinal:
|
||||
msg = &SASLResponse{}
|
||||
case AuthTypeGSS, AuthTypeGSSCont:
|
||||
msg = &GSSResponse{}
|
||||
case AuthTypeCleartextPassword, AuthTypeMD5Password:
|
||||
fallthrough
|
||||
default:
|
||||
|
|
|
@ -48,7 +48,7 @@ func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
|||
dst = append(dst, 'W')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.OverallFormat)
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
package pgproto3_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01}
|
||||
dstResp := pgproto3.CopyBothResponse{}
|
||||
err := dstResp.Decode(srcBytes[5:])
|
||||
assert.NoError(t, err, "No errors on decode")
|
||||
dstBytes := []byte{}
|
||||
dstBytes = dstResp.Encode(dstBytes)
|
||||
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
|
||||
}
|
|
@ -16,6 +16,8 @@ type Frontend struct {
|
|||
authenticationOk AuthenticationOk
|
||||
authenticationCleartextPassword AuthenticationCleartextPassword
|
||||
authenticationMD5Password AuthenticationMD5Password
|
||||
authenticationGSS AuthenticationGSS
|
||||
authenticationGSSContinue AuthenticationGSSContinue
|
||||
authenticationSASL AuthenticationSASL
|
||||
authenticationSASLContinue AuthenticationSASLContinue
|
||||
authenticationSASLFinal AuthenticationSASLFinal
|
||||
|
@ -179,9 +181,9 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
|
|||
case AuthTypeSCMCreds:
|
||||
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
|
||||
case AuthTypeGSS:
|
||||
return nil, errors.New("AuthTypeGSS is unimplemented")
|
||||
return &f.authenticationGSS, nil
|
||||
case AuthTypeGSSCont:
|
||||
return nil, errors.New("AuthTypeGSSCont is unimplemented")
|
||||
return &f.authenticationGSSContinue, nil
|
||||
case AuthTypeSSPI:
|
||||
return nil, errors.New("AuthTypeSSPI is unimplemented")
|
||||
case AuthTypeSASL:
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type GSSResponse struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (g *GSSResponse) Frontend() {}
|
||||
|
||||
func (g *GSSResponse) Decode(data []byte) error {
|
||||
g.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GSSResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
|
||||
dst = append(dst, g.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}{
|
||||
Type: "GSSResponse",
|
||||
Data: g.Data,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
|
||||
var msg struct {
|
||||
Data []byte
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
g.Data = msg.Data
|
||||
return nil
|
||||
}
|
|
@ -37,6 +37,32 @@ func TestJSONUnmarshalAuthenticationSASL(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestJSONUnmarshalAuthenticationGSS(t *testing.T) {
|
||||
data := []byte(`{"Type":"AuthenticationGSS"}`)
|
||||
want := AuthenticationGSS{}
|
||||
|
||||
var got AuthenticationGSS
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Errorf("cannot JSON unmarshal %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Error("unmarshaled AuthenticationGSS struct doesn't match expected value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONUnmarshalAuthenticationGSSContinue(t *testing.T) {
|
||||
data := []byte(`{"Type":"AuthenticationGSSContinue","Data":[1,2,3,4]}`)
|
||||
want := AuthenticationGSSContinue{Data: []byte{1, 2, 3, 4}}
|
||||
|
||||
var got AuthenticationGSSContinue
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Errorf("cannot JSON unmarshal %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Error("unmarshaled AuthenticationGSSContinue struct doesn't match expected value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) {
|
||||
data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`)
|
||||
want := AuthenticationSASLContinue{
|
||||
|
@ -551,6 +577,19 @@ func TestAuthenticationMD5Password(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestJSONUnmarshalGSSResponse(t *testing.T) {
|
||||
data := []byte(`{"Type":"GSSResponse","Data":[10,20,30,40]}`)
|
||||
want := GSSResponse{Data: []byte{10, 20, 30, 40}}
|
||||
|
||||
var got GSSResponse
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Errorf("cannot JSON unmarshal %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Error("unmarshaled GSSResponse struct doesn't match expected value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorResponse(t *testing.T) {
|
||||
data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`)
|
||||
want := ErrorResponse{
|
||||
|
|
Loading…
Reference in New Issue