diff --git a/pgproto3/authentication_gss.go b/pgproto3/authentication_gss.go new file mode 100644 index 00000000..0d234222 --- /dev/null +++ b/pgproto3/authentication_gss.go @@ -0,0 +1,59 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type AuthenticationGSS struct{} + +func (a *AuthenticationGSS) Backend() {} + +func (a *AuthenticationGSS) AuthenticationResponse() {} + +func (a *AuthenticationGSS) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSS { + return errors.New("bad auth type") + } + return nil +} + +func (a *AuthenticationGSS) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 4) + dst = pgio.AppendUint32(dst, AuthTypeGSS) + return dst +} + +func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSS", + }) +} + +func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + return nil +} diff --git a/pgproto3/authentication_gss_continue.go b/pgproto3/authentication_gss_continue.go new file mode 100644 index 00000000..63789dc1 --- /dev/null +++ b/pgproto3/authentication_gss_continue.go @@ -0,0 +1,68 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type AuthenticationGSSContinue struct { + Data []byte +} + +func (a *AuthenticationGSSContinue) Backend() {} + +func (a *AuthenticationGSSContinue) AuthenticationResponse() {} + +func (a *AuthenticationGSSContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSSCont { + return errors.New("bad auth type") + } + + a.Data = src[4:] + return nil +} + +func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, int32(len(a.Data))+8) + dst = pgio.AppendUint32(dst, AuthTypeGSSCont) + dst = append(dst, a.Data...) + return dst +} + +func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSSContinue", + Data: a.Data, + }) +} + +func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + a.Data = msg.Data + return nil +} diff --git a/pgproto3/backend.go b/pgproto3/backend.go index c8d2f331..b7db6f76 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -147,6 +147,8 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &SASLResponse{} case AuthTypeSASLFinal: msg = &SASLResponse{} + case AuthTypeGSS, AuthTypeGSSCont: + msg = &GSSResponse{} case AuthTypeCleartextPassword, AuthTypeMD5Password: fallthrough default: diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index 3243dbc1..8840a89e 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -48,7 +48,7 @@ func (src *CopyBothResponse) Encode(dst []byte) []byte { dst = append(dst, 'W') sp := len(dst) dst = pgio.AppendInt32(dst, -1) - + dst = append(dst, src.OverallFormat) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) diff --git a/pgproto3/copy_both_response_test.go b/pgproto3/copy_both_response_test.go new file mode 100644 index 00000000..4437de1d --- /dev/null +++ b/pgproto3/copy_both_response_test.go @@ -0,0 +1,18 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" +) + +func TestEncodeDecode(t *testing.T) { + srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01} + dstResp := pgproto3.CopyBothResponse{} + err := dstResp.Decode(srcBytes[5:]) + assert.NoError(t, err, "No errors on decode") + dstBytes := []byte{} + dstBytes = dstResp.Encode(dstBytes) + assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match") +} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index ea6757ad..435275d6 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -16,6 +16,8 @@ type Frontend struct { authenticationOk AuthenticationOk authenticationCleartextPassword AuthenticationCleartextPassword authenticationMD5Password AuthenticationMD5Password + authenticationGSS AuthenticationGSS + authenticationGSSContinue AuthenticationGSSContinue authenticationSASL AuthenticationSASL authenticationSASLContinue AuthenticationSASLContinue authenticationSASLFinal AuthenticationSASLFinal @@ -179,9 +181,9 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er case AuthTypeSCMCreds: return nil, errors.New("AuthTypeSCMCreds is unimplemented") case AuthTypeGSS: - return nil, errors.New("AuthTypeGSS is unimplemented") + return &f.authenticationGSS, nil case AuthTypeGSSCont: - return nil, errors.New("AuthTypeGSSCont is unimplemented") + return &f.authenticationGSSContinue, nil case AuthTypeSSPI: return nil, errors.New("AuthTypeSSPI is unimplemented") case AuthTypeSASL: diff --git a/pgproto3/gss_response.go b/pgproto3/gss_response.go new file mode 100644 index 00000000..64bfbd04 --- /dev/null +++ b/pgproto3/gss_response.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/json" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type GSSResponse struct { + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (g *GSSResponse) Frontend() {} + +func (g *GSSResponse) Decode(data []byte) error { + g.Data = data + return nil +} + +func (g *GSSResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(g.Data))) + dst = append(dst, g.Data...) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (g *GSSResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "GSSResponse", + Data: g.Data, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (g *GSSResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + g.Data = msg.Data + return nil +} diff --git a/pgproto3/json_test.go b/pgproto3/json_test.go index eab26252..8fad4f88 100644 --- a/pgproto3/json_test.go +++ b/pgproto3/json_test.go @@ -37,6 +37,32 @@ func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { } } +func TestJSONUnmarshalAuthenticationGSS(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSS"}`) + want := AuthenticationGSS{} + + var got AuthenticationGSS + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSS struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationGSSContinue(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSSContinue","Data":[1,2,3,4]}`) + want := AuthenticationGSSContinue{Data: []byte{1, 2, 3, 4}} + + var got AuthenticationGSSContinue + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSSContinue struct doesn't match expected value") + } +} + func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`) want := AuthenticationSASLContinue{ @@ -551,6 +577,19 @@ func TestAuthenticationMD5Password(t *testing.T) { } } +func TestJSONUnmarshalGSSResponse(t *testing.T) { + data := []byte(`{"Type":"GSSResponse","Data":[10,20,30,40]}`) + want := GSSResponse{Data: []byte{10, 20, 30, 40}} + + var got GSSResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled GSSResponse struct doesn't match expected value") + } +} + func TestErrorResponse(t *testing.T) { data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`) want := ErrorResponse{