Add tests for other types of JSON objects

pull/89/head
Jack Christensen 2015-09-04 13:40:59 -05:00
parent 9d200733b9
commit fff5b9759b
3 changed files with 159 additions and 82 deletions

View File

@ -220,21 +220,31 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
for _, d := range dest { for _, d := range dest {
vr, _ := rows.nextColumn() vr, _ := rows.nextColumn()
switch d := d.(type) {
case *bool: // Check for []byte first as we allow sidestepping the decoding process and retrieving the raw bytes
*d = decodeBool(vr) if b, ok := d.(*[]byte); ok {
case *[]byte:
// If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format) // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format)
// Otherwise read the bytes directly regardless of what the actual type is. // Otherwise read the bytes directly regardless of what the actual type is.
if vr.Type().DataType == ByteaOid { if vr.Type().DataType == ByteaOid {
*d = decodeBytea(vr) *b = decodeBytea(vr)
} else { } else {
if vr.Len() != -1 { if vr.Len() != -1 {
*d = vr.ReadBytes(vr.Len()) *b = vr.ReadBytes(vr.Len())
} else { } else {
*d = nil *b = nil
} }
} }
} else if s, ok := d.(Scanner); ok {
err = s.Scan(vr)
if err != nil {
rows.Fatal(err)
}
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
decodeJson(vr, &d)
} else {
switch d := d.(type) {
case *bool:
*d = decodeBool(vr)
case *int64: case *int64:
*d = decodeInt8(vr) *d = decodeInt8(vr)
case *int16: case *int16:
@ -278,20 +288,11 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
} }
case *net.IPNet: case *net.IPNet:
*d = decodeInet(vr) *d = decodeInet(vr)
case Scanner:
err = d.Scan(vr)
if err != nil {
rows.Fatal(err)
}
default:
switch vr.Type().DataType {
case JsonOid, JsonbOid:
decodeJson(vr, &d)
default: default:
rows.Fatal(fmt.Errorf("Scan cannot decode into %T", d)) rows.Fatal(fmt.Errorf("Scan cannot decode into %T", d))
} }
}
}
if vr.Err() != nil { if vr.Err() != nil {
rows.Fatal(vr.Err()) rows.Fatal(vr.Err())
} }

View File

@ -1008,7 +1008,11 @@ func decodeJson(vr *ValueReader, d interface{}) error {
} }
bytes := vr.ReadBytes(vr.Len()) bytes := vr.ReadBytes(vr.Len())
return json.Unmarshal(bytes, d) err := json.Unmarshal(bytes, d)
if err != nil {
vr.Fatal(err)
}
return err
} }
func encodeJson(w *WriteBuf, value interface{}) error { func encodeJson(w *WriteBuf, value interface{}) error {

View File

@ -1,6 +1,7 @@
package pgx_test package pgx_test
import ( import (
"encoding/json"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"net" "net"
"reflect" "reflect"
@ -77,33 +78,104 @@ func TestJsonAndJsonbTranscode(t *testing.T) {
} }
typename := conn.PgTypes[oid].Name typename := conn.PgTypes[oid].Name
// Test single level objects with map[string]string testJsonSingleLevelStringMap(t, conn, typename)
inStringMap := map[string]string{"key": "value"} testJsonNestedMap(t, conn, typename)
var outStringMap map[string]string testJsonStringArray(t, conn, typename)
err := conn.QueryRow("select $1::"+typename, inStringMap).Scan(&outStringMap) testJsonInt64Array(t, conn, typename)
testJsonInt16ArrayFailureDueToOverflow(t, conn, typename)
testJsonStruct(t, conn, typename)
}
}
func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) {
input := map[string]string{"key": "value"}
var output map[string]string
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil { if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err) t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
return
} }
if !reflect.DeepEqual(inStringMap, outStringMap) { if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, inStringMap, outStringMap) t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output)
return
} }
}
// Test nested objects with map[string]interface{} func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string) {
inNestedMap := map[string]interface{}{ input := map[string]interface{}{
"name": "Uncanny", "name": "Uncanny",
"stats": map[string]interface{}{"hp": 107, "maxhp": 150}, "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)},
"inventory": []string{"phone", "key"}, "inventory": []interface{}{"phone", "key"},
} }
var outNestedMap map[string]interface{} var output map[string]interface{}
err = conn.QueryRow("select $1::"+typename, inNestedMap).Scan(&outNestedMap) err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
return
}
if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output)
return
}
}
func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string) {
input := []string{"foo", "bar", "baz"}
var output []string
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil { if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err) t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
} }
if !reflect.DeepEqual(inStringMap, outStringMap) { if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, inStringMap, outStringMap) t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output)
} }
}
func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string) {
input := []int64{1, 2, 234432}
var output []int64
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
}
if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output)
}
}
func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) {
input := []int{1, 2, 234432}
var output []int16
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if _, ok := err.(*json.UnmarshalTypeError); !ok {
t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err)
}
}
func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string) {
type person struct {
Name string `json:"name"`
Age int `json:"age"`
}
input := person{
Name: "John",
Age: 42,
}
var output person
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
if err != nil {
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
}
if !reflect.DeepEqual(input, output) {
t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output)
} }
} }