diff --git a/authentication_cleartext_password.go b/authentication_cleartext_password.go index dd82c7a7..1b87a718 100644 --- a/authentication_cleartext_password.go +++ b/authentication_cleartext_password.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -37,3 +38,12 @@ func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) return dst } + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationCleartextPassword", + }) +} diff --git a/authentication_md5_password.go b/authentication_md5_password.go index d505d264..95795b31 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -41,3 +42,33 @@ func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { dst = append(dst, src.Salt[:]...) return dst } + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Salt [4]byte + }{ + Type: "AuthenticationMD5Password", + Salt: src.Salt, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Salt [4]byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Salt = msg.Salt + return nil +} diff --git a/authentication_ok.go b/authentication_ok.go index 7b13c6e0..ad69b907 100644 --- a/authentication_ok.go +++ b/authentication_ok.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -37,3 +38,12 @@ func (src *AuthenticationOk) Encode(dst []byte) []byte { dst = pgio.AppendUint32(dst, AuthTypeOk) return dst } + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationOk) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationOK", + }) +} diff --git a/authentication_sasl.go b/authentication_sasl.go index c57ae32d..d2b09750 100644 --- a/authentication_sasl.go +++ b/authentication_sasl.go @@ -3,6 +3,7 @@ package pgproto3 import ( "bytes" "encoding/binary" + "encoding/json" "errors" "github.com/jackc/pgio" @@ -58,3 +59,14 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte { return dst } + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASL) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanisms []string + }{ + Type: "AuthenticationSASL", + AuthMechanisms: src.AuthMechanisms, + }) +} diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go index 62a16c76..d258065f 100644 --- a/authentication_sasl_continue.go +++ b/authentication_sasl_continue.go @@ -48,6 +48,17 @@ func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLContinue", + Data: string(src.Data), + }) +} + // UnmarshalJSON implements encoding/json.Unmarshaler. func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error { // Ignore null, like in the main JSON package. diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go index de5e454a..6a681d73 100644 --- a/authentication_sasl_final.go +++ b/authentication_sasl_final.go @@ -48,6 +48,17 @@ func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { return dst } +// MarshalJSON implements encoding/json.Unmarshaler. +func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLFinal", + Data: string(src.Data), + }) +} + // UnmarshalJSON implements encoding/json.Unmarshaler. func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error { // Ignore null, like in the main JSON package. diff --git a/error_response.go b/error_response.go index 4eb0a196..ec51e019 100644 --- a/error_response.go +++ b/error_response.go @@ -3,6 +3,7 @@ package pgproto3 import ( "bytes" "encoding/binary" + "encoding/json" "strconv" ) @@ -225,3 +226,109 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { return buf.Bytes() } + +// MarshalJSON implements encoding/json.Marshaler. +func (src ErrorResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + }{ + Type: "ErrorResponse", + Severity: src.Severity, + SeverityUnlocalized: src.SeverityUnlocalized, + Code: src.Code, + Message: src.Message, + Detail: src.Detail, + Hint: src.Hint, + Position: src.Position, + InternalPosition: src.InternalPosition, + InternalQuery: src.InternalQuery, + Where: src.Where, + SchemaName: src.SchemaName, + TableName: src.TableName, + ColumnName: src.ColumnName, + DataTypeName: src.DataTypeName, + ConstraintName: src.ConstraintName, + File: src.File, + Line: src.Line, + Routine: src.Routine, + UnknownFields: src.UnknownFields, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Severity = msg.Severity + dst.SeverityUnlocalized = msg.SeverityUnlocalized + dst.Code = msg.Code + dst.Message = msg.Message + dst.Detail = msg.Detail + dst.Hint = msg.Hint + dst.Position = msg.Position + dst.InternalPosition = msg.InternalPosition + dst.InternalQuery = msg.InternalQuery + dst.Where = msg.Where + dst.SchemaName = msg.SchemaName + dst.TableName = msg.TableName + dst.ColumnName = msg.ColumnName + dst.DataTypeName = msg.DataTypeName + dst.ConstraintName = msg.ConstraintName + dst.File = msg.File + dst.Line = msg.Line + dst.Routine = msg.Routine + + dst.UnknownFields = msg.UnknownFields + + return nil +} diff --git a/json_test.go b/json_test.go index c73807ab..eab26252 100644 --- a/json_test.go +++ b/json_test.go @@ -23,9 +23,9 @@ func TestJSONUnmarshalAuthenticationMD5Password(t *testing.T) { } func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { - data := []byte(`{"Type":"AuthenticationSASL", "AuthMechanisms":[]}`) + data := []byte(`{"Type":"AuthenticationSASL","AuthMechanisms":["SCRAM-SHA-256"]}`) want := AuthenticationSASL{ - AuthMechanisms: []string{}, + []string{"SCRAM-SHA-256"}, } var got AuthenticationSASL @@ -38,9 +38,9 @@ func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { } func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { - data := []byte(`{"Type":"AuthenticationSASLContinue"}`) + data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`) want := AuthenticationSASLContinue{ - Data: []byte{}, + Data: []byte{'1'}, } var got AuthenticationSASLContinue @@ -53,9 +53,9 @@ func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { } func TestJSONUnmarshalAuthenticationSASLFinal(t *testing.T) { - data := []byte(`{"Type":"AuthenticationSASLFinal"}`) + data := []byte(`{"Type":"AuthenticationSASLFinal", "Data":"1"}`) want := AuthenticationSASLFinal{ - Data: []byte{}, + Data: []byte{'1'}, } var got AuthenticationSASLFinal @@ -463,8 +463,11 @@ func TestJSONUnmarshalQuery(t *testing.T) { } func TestJSONUnmarshalSASLInitialResponse(t *testing.T) { - data := []byte(`{"Type":"SASLInitialResponse"}`) - want := SASLInitialResponse{} + data := []byte(`{"Type":"SASLInitialResponse", "AuthMechanism":"SCRAM-SHA-256", "Data": "6D"}`) + want := SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: []byte{109}, + } var got SASLInitialResponse if err := json.Unmarshal(data, &got); err != nil { @@ -506,3 +509,64 @@ func TestJSONUnmarshalStartupMessage(t *testing.T) { t.Error("unmarshaled StartupMessage struct doesn't match expected value") } } + +func TestAuthenticationOK(t *testing.T) { + data := []byte(`{"Type":"AuthenticationOK"}`) + want := AuthenticationOk{} + + var got AuthenticationOk + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationOK struct doesn't match expected value") + } +} + +func TestAuthenticationCleartextPassword(t *testing.T) { + data := []byte(`{"Type":"AuthenticationCleartextPassword"}`) + want := AuthenticationCleartextPassword{} + + var got AuthenticationCleartextPassword + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationCleartextPassword struct doesn't match expected value") + } +} + +func TestAuthenticationMD5Password(t *testing.T) { + data := []byte(`{"Type":"AuthenticationMD5Password","Salt":[1,2,3,4]}`) + want := AuthenticationMD5Password{ + Salt: [4]byte{1, 2, 3, 4}, + } + + var got AuthenticationMD5Password + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationMD5Password struct doesn't match expected value") + } +} + +func TestErrorResponse(t *testing.T) { + data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`) + want := ErrorResponse{ + UnknownFields: map[byte]string{ + 'p': "foo", + }, + Code: "Fail", + Position: 1, + Message: "this is an error", + } + + var got ErrorResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled ErrorResponse struct doesn't match expected value") + } +}