From fff5b9759bf5cfed2d821c48e6258040fa1bcc73 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 4 Sep 2015 13:40:59 -0500 Subject: [PATCH] Add tests for other types of JSON objects --- query.go | 115 ++++++++++++++++++++++++----------------------- values.go | 6 ++- values_test.go | 120 +++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 159 insertions(+), 82 deletions(-) diff --git a/query.go b/query.go index 821648c9..93527e20 100644 --- a/query.go +++ b/query.go @@ -220,78 +220,79 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { for _, d := range dest { vr, _ := rows.nextColumn() - switch d := d.(type) { - case *bool: - *d = decodeBool(vr) - case *[]byte: + + // Check for []byte first as we allow sidestepping the decoding process and retrieving the raw bytes + if b, ok := d.(*[]byte); ok { // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format) // Otherwise read the bytes directly regardless of what the actual type is. if vr.Type().DataType == ByteaOid { - *d = decodeBytea(vr) + *b = decodeBytea(vr) } else { if vr.Len() != -1 { - *d = vr.ReadBytes(vr.Len()) + *b = vr.ReadBytes(vr.Len()) } else { - *d = nil + *b = nil } } - case *int64: - *d = decodeInt8(vr) - case *int16: - *d = decodeInt2(vr) - case *int32: - *d = decodeInt4(vr) - case *Oid: - *d = decodeOid(vr) - case *string: - *d = decodeText(vr) - case *float32: - *d = decodeFloat4(vr) - case *float64: - *d = decodeFloat8(vr) - case *[]bool: - *d = decodeBoolArray(vr) - case *[]int16: - *d = decodeInt2Array(vr) - case *[]int32: - *d = decodeInt4Array(vr) - case *[]int64: - *d = decodeInt8Array(vr) - case *[]float32: - *d = decodeFloat4Array(vr) - case *[]float64: - *d = decodeFloat8Array(vr) - case *[]string: - *d = decodeTextArray(vr) - case *[]time.Time: - *d = decodeTimestampArray(vr) - case *time.Time: - switch vr.Type().DataType { - case DateOid: - *d = decodeDate(vr) - case TimestampTzOid: - *d = decodeTimestampTz(vr) - case TimestampOid: - *d = decodeTimestamp(vr) - default: - rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)) - } - case *net.IPNet: - *d = decodeInet(vr) - case Scanner: - err = d.Scan(vr) + } else if s, ok := d.(Scanner); ok { + err = s.Scan(vr) if err != nil { rows.Fatal(err) } - default: - switch vr.Type().DataType { - case JsonOid, JsonbOid: - decodeJson(vr, &d) + } else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid { + decodeJson(vr, &d) + } else { + switch d := d.(type) { + case *bool: + *d = decodeBool(vr) + case *int64: + *d = decodeInt8(vr) + case *int16: + *d = decodeInt2(vr) + case *int32: + *d = decodeInt4(vr) + case *Oid: + *d = decodeOid(vr) + case *string: + *d = decodeText(vr) + case *float32: + *d = decodeFloat4(vr) + case *float64: + *d = decodeFloat8(vr) + case *[]bool: + *d = decodeBoolArray(vr) + case *[]int16: + *d = decodeInt2Array(vr) + case *[]int32: + *d = decodeInt4Array(vr) + case *[]int64: + *d = decodeInt8Array(vr) + case *[]float32: + *d = decodeFloat4Array(vr) + case *[]float64: + *d = decodeFloat8Array(vr) + case *[]string: + *d = decodeTextArray(vr) + case *[]time.Time: + *d = decodeTimestampArray(vr) + case *time.Time: + switch vr.Type().DataType { + case DateOid: + *d = decodeDate(vr) + case TimestampTzOid: + *d = decodeTimestampTz(vr) + case TimestampOid: + *d = decodeTimestamp(vr) + default: + rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)) + } + case *net.IPNet: + *d = decodeInet(vr) default: rows.Fatal(fmt.Errorf("Scan cannot decode into %T", d)) } - } + } if vr.Err() != nil { rows.Fatal(vr.Err()) } diff --git a/values.go b/values.go index 0213a236..6633ab8b 100644 --- a/values.go +++ b/values.go @@ -1008,7 +1008,11 @@ func decodeJson(vr *ValueReader, d interface{}) error { } bytes := vr.ReadBytes(vr.Len()) - return json.Unmarshal(bytes, d) + err := json.Unmarshal(bytes, d) + if err != nil { + vr.Fatal(err) + } + return err } func encodeJson(w *WriteBuf, value interface{}) error { diff --git a/values_test.go b/values_test.go index 6530d928..bdfa7cc4 100644 --- a/values_test.go +++ b/values_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "encoding/json" "github.com/jackc/pgx" "net" "reflect" @@ -77,33 +78,104 @@ func TestJsonAndJsonbTranscode(t *testing.T) { } typename := conn.PgTypes[oid].Name - // Test single level objects with map[string]string - inStringMap := map[string]string{"key": "value"} - var outStringMap map[string]string - err := conn.QueryRow("select $1::"+typename, inStringMap).Scan(&outStringMap) - if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) - } + testJsonSingleLevelStringMap(t, conn, typename) + testJsonNestedMap(t, conn, typename) + testJsonStringArray(t, conn, typename) + testJsonInt64Array(t, conn, typename) + testJsonInt16ArrayFailureDueToOverflow(t, conn, typename) + testJsonStruct(t, conn, typename) + } +} - if !reflect.DeepEqual(inStringMap, outStringMap) { - t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, inStringMap, outStringMap) - } +func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { + input := map[string]string{"key": "value"} + var output map[string]string + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + return + } - // Test nested objects with map[string]interface{} - inNestedMap := map[string]interface{}{ - "name": "Uncanny", - "stats": map[string]interface{}{"hp": 107, "maxhp": 150}, - "inventory": []string{"phone", "key"}, - } - var outNestedMap map[string]interface{} - err = conn.QueryRow("select $1::"+typename, inNestedMap).Scan(&outNestedMap) - if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) - } + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output) + return + } +} - if !reflect.DeepEqual(inStringMap, outStringMap) { - t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, inStringMap, outStringMap) - } +func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string) { + input := map[string]interface{}{ + "name": "Uncanny", + "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, + "inventory": []interface{}{"phone", "key"}, + } + var output map[string]interface{} + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + return + } + + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) + return + } +} + +func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string) { + input := []string{"foo", "bar", "baz"} + var output []string + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + } + + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output) + } +} + +func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string) { + input := []int64{1, 2, 234432} + var output []int64 + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + } + + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output) + } +} + +func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { + input := []int{1, 2, 234432} + var output []int16 + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if _, ok := err.(*json.UnmarshalTypeError); !ok { + t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err) + } +} + +func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string) { + type person struct { + Name string `json:"name"` + Age int `json:"age"` + } + + input := person{ + Name: "John", + Age: 42, + } + + var output person + + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + } + + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output) } }