Merge remote-tracking branch 'pgproto3/master' into v5-dev

Pull in pgproto3 changes and update for pgx v5
non-blocking
Jack Christensen 2022-04-23 10:43:48 -05:00
commit a92f1df1df
8 changed files with 240 additions and 3 deletions

View File

@ -0,0 +1,59 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type AuthenticationGSS struct{}
func (a *AuthenticationGSS) Backend() {}
func (a *AuthenticationGSS) AuthenticationResponse() {}
func (a *AuthenticationGSS) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSS {
return errors.New("bad auth type")
}
return nil
}
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst
}
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSS",
})
}
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,68 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type AuthenticationGSSContinue struct {
Data []byte
}
func (a *AuthenticationGSSContinue) Backend() {}
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSSCont {
return errors.New("bad auth type")
}
a.Data = src[4:]
return nil
}
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return dst
}
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSSContinue",
Data: a.Data,
})
}
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
a.Data = msg.Data
return nil
}

View File

@ -147,6 +147,8 @@ func (b *Backend) Receive() (FrontendMessage, error) {
msg = &SASLResponse{}
case AuthTypeSASLFinal:
msg = &SASLResponse{}
case AuthTypeGSS, AuthTypeGSSCont:
msg = &GSSResponse{}
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:

View File

@ -48,7 +48,7 @@ func (src *CopyBothResponse) Encode(dst []byte) []byte {
dst = append(dst, 'W')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)

View File

@ -0,0 +1,18 @@
package pgproto3_test
import (
"testing"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/assert"
)
func TestEncodeDecode(t *testing.T) {
srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01}
dstResp := pgproto3.CopyBothResponse{}
err := dstResp.Decode(srcBytes[5:])
assert.NoError(t, err, "No errors on decode")
dstBytes := []byte{}
dstBytes = dstResp.Encode(dstBytes)
assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
}

View File

@ -16,6 +16,8 @@ type Frontend struct {
authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword
authenticationMD5Password AuthenticationMD5Password
authenticationGSS AuthenticationGSS
authenticationGSSContinue AuthenticationGSSContinue
authenticationSASL AuthenticationSASL
authenticationSASLContinue AuthenticationSASLContinue
authenticationSASLFinal AuthenticationSASLFinal
@ -179,9 +181,9 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
case AuthTypeSCMCreds:
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
case AuthTypeGSS:
return nil, errors.New("AuthTypeGSS is unimplemented")
return &f.authenticationGSS, nil
case AuthTypeGSSCont:
return nil, errors.New("AuthTypeGSSCont is unimplemented")
return &f.authenticationGSSContinue, nil
case AuthTypeSSPI:
return nil, errors.New("AuthTypeSSPI is unimplemented")
case AuthTypeSASL:

49
pgproto3/gss_response.go Normal file
View File

@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type GSSResponse struct {
Data []byte
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (g *GSSResponse) Frontend() {}
func (g *GSSResponse) Decode(data []byte) error {
g.Data = data
return nil
}
func (g *GSSResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
dst = append(dst, g.Data...)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "GSSResponse",
Data: g.Data,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
var msg struct {
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
g.Data = msg.Data
return nil
}

View File

@ -37,6 +37,32 @@ func TestJSONUnmarshalAuthenticationSASL(t *testing.T) {
}
}
func TestJSONUnmarshalAuthenticationGSS(t *testing.T) {
data := []byte(`{"Type":"AuthenticationGSS"}`)
want := AuthenticationGSS{}
var got AuthenticationGSS
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled AuthenticationGSS struct doesn't match expected value")
}
}
func TestJSONUnmarshalAuthenticationGSSContinue(t *testing.T) {
data := []byte(`{"Type":"AuthenticationGSSContinue","Data":[1,2,3,4]}`)
want := AuthenticationGSSContinue{Data: []byte{1, 2, 3, 4}}
var got AuthenticationGSSContinue
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled AuthenticationGSSContinue struct doesn't match expected value")
}
}
func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) {
data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`)
want := AuthenticationSASLContinue{
@ -551,6 +577,19 @@ func TestAuthenticationMD5Password(t *testing.T) {
}
}
func TestJSONUnmarshalGSSResponse(t *testing.T) {
data := []byte(`{"Type":"GSSResponse","Data":[10,20,30,40]}`)
want := GSSResponse{Data: []byte{10, 20, 30, 40}}
var got GSSResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled GSSResponse struct doesn't match expected value")
}
}
func TestErrorResponse(t *testing.T) {
data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`)
want := ErrorResponse{