diff --git a/conn.go b/conn.go index ac1e56f1..411d69fb 100644 --- a/conn.go +++ b/conn.go @@ -799,6 +799,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTimestampArray(wbuf, arguments[i], TimestampTzOid) case OidOid: err = encodeOid(wbuf, arguments[i]) + case JsonOid: + err = encodeJson(wbuf, arguments[i]) + case JsonbOid: + err = encodeJson(wbuf, arguments[i]) default: return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } diff --git a/query.go b/query.go index d5d0b636..59ffbb3b 100644 --- a/query.go +++ b/query.go @@ -284,7 +284,14 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { rows.Fatal(err) } default: - rows.Fatal(fmt.Errorf("Scan cannot decode into %T", d)) + switch vr.Type().DataType { + case JsonOid: + decodeJson(vr, &d) + case JsonbOid: + decodeJson(vr, &d) + default: + rows.Fatal(fmt.Errorf("Scan cannot decode into %T", d)) + } } if vr.Err() != nil { @@ -360,6 +367,14 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeTimestamp(vr)) case InetOid, CidrOid: values = append(values, decodeInet(vr)) + case JsonOid: + var d interface{} + decodeJson(vr, &d) + values = append(values, d) + case JsonbOid: + var d interface{} + decodeJson(vr, &d) + values = append(values, d) default: rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) } diff --git a/values.go b/values.go index e1ec36c6..0213a236 100644 --- a/values.go +++ b/values.go @@ -2,6 +2,7 @@ package pgx import ( "bytes" + "encoding/json" "fmt" "math" "net" @@ -19,6 +20,7 @@ const ( Int4Oid = 23 TextOid = 25 OidOid = 26 + JsonOid = 114 CidrOid = 650 Float4Oid = 700 Float8Oid = 701 @@ -37,6 +39,7 @@ const ( TimestampArrayOid = 1115 TimestampTzOid = 1184 TimestampTzArrayOid = 1185 + JsonbOid = 3802 ) // PostgreSQL format codes @@ -995,6 +998,28 @@ func encodeBytea(w *WriteBuf, value interface{}) error { return nil } +func decodeJson(vr *ValueReader, d interface{}) error { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != JsonOid && vr.Type().DataType != JsonbOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) + } + + bytes := vr.ReadBytes(vr.Len()) + return json.Unmarshal(bytes, d) +} + +func encodeJson(w *WriteBuf, value interface{}) error { + s, err := json.Marshal(value) + if err != nil { + fmt.Errorf("Failed to encode json from type: %T", value) + } + + return encodeText(w, s) +} + func decodeDate(vr *ValueReader) time.Time { var zeroTime time.Time diff --git a/values_test.go b/values_test.go index 81646bb6..e8617417 100644 --- a/values_test.go +++ b/values_test.go @@ -65,6 +65,46 @@ func TestTimestampTzTranscode(t *testing.T) { } } +func TestJsonTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + m := map[string]string{ + "key": "value", + } + var outputJson map[string]string + + err := conn.QueryRow("select $1::json", m).Scan(&outputJson) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if m["key"] != outputJson["key"] { + t.Errorf("Did not transcode json successfully: %v is not %v", outputJson["key"], m["key"]) + } +} + +func TestJsonbTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + m := map[string]string{ + "key": "value", + } + var outputJson map[string]string + + err := conn.QueryRow("select $1::jsonb", m).Scan(&outputJson) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if m["key"] != outputJson["key"] { + t.Errorf("Did not transcode jsonb successfully: %v is not %v", outputJson["key"], m["key"]) + } +} + func mustParseCIDR(t *testing.T, s string) net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil {