mirror of https://github.com/jackc/pgx.git
Add tests for other types of JSON objects
parent
9d200733b9
commit
fff5b9759b
115
query.go
115
query.go
|
@ -220,78 +220,79 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
|
||||
for _, d := range dest {
|
||||
vr, _ := rows.nextColumn()
|
||||
switch d := d.(type) {
|
||||
case *bool:
|
||||
*d = decodeBool(vr)
|
||||
case *[]byte:
|
||||
|
||||
// Check for []byte first as we allow sidestepping the decoding process and retrieving the raw bytes
|
||||
if b, ok := d.(*[]byte); ok {
|
||||
// If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format)
|
||||
// Otherwise read the bytes directly regardless of what the actual type is.
|
||||
if vr.Type().DataType == ByteaOid {
|
||||
*d = decodeBytea(vr)
|
||||
*b = decodeBytea(vr)
|
||||
} else {
|
||||
if vr.Len() != -1 {
|
||||
*d = vr.ReadBytes(vr.Len())
|
||||
*b = vr.ReadBytes(vr.Len())
|
||||
} else {
|
||||
*d = nil
|
||||
*b = nil
|
||||
}
|
||||
}
|
||||
case *int64:
|
||||
*d = decodeInt8(vr)
|
||||
case *int16:
|
||||
*d = decodeInt2(vr)
|
||||
case *int32:
|
||||
*d = decodeInt4(vr)
|
||||
case *Oid:
|
||||
*d = decodeOid(vr)
|
||||
case *string:
|
||||
*d = decodeText(vr)
|
||||
case *float32:
|
||||
*d = decodeFloat4(vr)
|
||||
case *float64:
|
||||
*d = decodeFloat8(vr)
|
||||
case *[]bool:
|
||||
*d = decodeBoolArray(vr)
|
||||
case *[]int16:
|
||||
*d = decodeInt2Array(vr)
|
||||
case *[]int32:
|
||||
*d = decodeInt4Array(vr)
|
||||
case *[]int64:
|
||||
*d = decodeInt8Array(vr)
|
||||
case *[]float32:
|
||||
*d = decodeFloat4Array(vr)
|
||||
case *[]float64:
|
||||
*d = decodeFloat8Array(vr)
|
||||
case *[]string:
|
||||
*d = decodeTextArray(vr)
|
||||
case *[]time.Time:
|
||||
*d = decodeTimestampArray(vr)
|
||||
case *time.Time:
|
||||
switch vr.Type().DataType {
|
||||
case DateOid:
|
||||
*d = decodeDate(vr)
|
||||
case TimestampTzOid:
|
||||
*d = decodeTimestampTz(vr)
|
||||
case TimestampOid:
|
||||
*d = decodeTimestamp(vr)
|
||||
default:
|
||||
rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType))
|
||||
}
|
||||
case *net.IPNet:
|
||||
*d = decodeInet(vr)
|
||||
case Scanner:
|
||||
err = d.Scan(vr)
|
||||
} else if s, ok := d.(Scanner); ok {
|
||||
err = s.Scan(vr)
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
}
|
||||
default:
|
||||
switch vr.Type().DataType {
|
||||
case JsonOid, JsonbOid:
|
||||
decodeJson(vr, &d)
|
||||
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
|
||||
decodeJson(vr, &d)
|
||||
} else {
|
||||
switch d := d.(type) {
|
||||
case *bool:
|
||||
*d = decodeBool(vr)
|
||||
case *int64:
|
||||
*d = decodeInt8(vr)
|
||||
case *int16:
|
||||
*d = decodeInt2(vr)
|
||||
case *int32:
|
||||
*d = decodeInt4(vr)
|
||||
case *Oid:
|
||||
*d = decodeOid(vr)
|
||||
case *string:
|
||||
*d = decodeText(vr)
|
||||
case *float32:
|
||||
*d = decodeFloat4(vr)
|
||||
case *float64:
|
||||
*d = decodeFloat8(vr)
|
||||
case *[]bool:
|
||||
*d = decodeBoolArray(vr)
|
||||
case *[]int16:
|
||||
*d = decodeInt2Array(vr)
|
||||
case *[]int32:
|
||||
*d = decodeInt4Array(vr)
|
||||
case *[]int64:
|
||||
*d = decodeInt8Array(vr)
|
||||
case *[]float32:
|
||||
*d = decodeFloat4Array(vr)
|
||||
case *[]float64:
|
||||
*d = decodeFloat8Array(vr)
|
||||
case *[]string:
|
||||
*d = decodeTextArray(vr)
|
||||
case *[]time.Time:
|
||||
*d = decodeTimestampArray(vr)
|
||||
case *time.Time:
|
||||
switch vr.Type().DataType {
|
||||
case DateOid:
|
||||
*d = decodeDate(vr)
|
||||
case TimestampTzOid:
|
||||
*d = decodeTimestampTz(vr)
|
||||
case TimestampOid:
|
||||
*d = decodeTimestamp(vr)
|
||||
default:
|
||||
rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType))
|
||||
}
|
||||
case *net.IPNet:
|
||||
*d = decodeInet(vr)
|
||||
default:
|
||||
rows.Fatal(fmt.Errorf("Scan cannot decode into %T", d))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if vr.Err() != nil {
|
||||
rows.Fatal(vr.Err())
|
||||
}
|
||||
|
|
|
@ -1008,7 +1008,11 @@ func decodeJson(vr *ValueReader, d interface{}) error {
|
|||
}
|
||||
|
||||
bytes := vr.ReadBytes(vr.Len())
|
||||
return json.Unmarshal(bytes, d)
|
||||
err := json.Unmarshal(bytes, d)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func encodeJson(w *WriteBuf, value interface{}) error {
|
||||
|
|
120
values_test.go
120
values_test.go
|
@ -1,6 +1,7 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/jackc/pgx"
|
||||
"net"
|
||||
"reflect"
|
||||
|
@ -77,33 +78,104 @@ func TestJsonAndJsonbTranscode(t *testing.T) {
|
|||
}
|
||||
typename := conn.PgTypes[oid].Name
|
||||
|
||||
// Test single level objects with map[string]string
|
||||
inStringMap := map[string]string{"key": "value"}
|
||||
var outStringMap map[string]string
|
||||
err := conn.QueryRow("select $1::"+typename, inStringMap).Scan(&outStringMap)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
}
|
||||
testJsonSingleLevelStringMap(t, conn, typename)
|
||||
testJsonNestedMap(t, conn, typename)
|
||||
testJsonStringArray(t, conn, typename)
|
||||
testJsonInt64Array(t, conn, typename)
|
||||
testJsonInt16ArrayFailureDueToOverflow(t, conn, typename)
|
||||
testJsonStruct(t, conn, typename)
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inStringMap, outStringMap) {
|
||||
t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, inStringMap, outStringMap)
|
||||
}
|
||||
func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
input := map[string]string{"key": "value"}
|
||||
var output map[string]string
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Test nested objects with map[string]interface{}
|
||||
inNestedMap := map[string]interface{}{
|
||||
"name": "Uncanny",
|
||||
"stats": map[string]interface{}{"hp": 107, "maxhp": 150},
|
||||
"inventory": []string{"phone", "key"},
|
||||
}
|
||||
var outNestedMap map[string]interface{}
|
||||
err = conn.QueryRow("select $1::"+typename, inNestedMap).Scan(&outNestedMap)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
}
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inStringMap, outStringMap) {
|
||||
t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, inStringMap, outStringMap)
|
||||
}
|
||||
func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
input := map[string]interface{}{
|
||||
"name": "Uncanny",
|
||||
"stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)},
|
||||
"inventory": []interface{}{"phone", "key"},
|
||||
}
|
||||
var output map[string]interface{}
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
input := []string{"foo", "bar", "baz"}
|
||||
var output []string
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output)
|
||||
}
|
||||
}
|
||||
|
||||
func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
input := []int64{1, 2, 234432}
|
||||
var output []int64
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output)
|
||||
}
|
||||
}
|
||||
|
||||
func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
input := []int{1, 2, 234432}
|
||||
var output []int16
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if _, ok := err.(*json.UnmarshalTypeError); !ok {
|
||||
t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err)
|
||||
}
|
||||
}
|
||||
|
||||
func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string) {
|
||||
type person struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
input := person{
|
||||
Name: "John",
|
||||
Age: 42,
|
||||
}
|
||||
|
||||
var output person
|
||||
|
||||
err := conn.QueryRow("select $1::"+typename, input).Scan(&output)
|
||||
if err != nil {
|
||||
t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue