From f7b6b3f077837daf4c0086b46d3eef6cce5e532a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Oct 2016 10:58:04 -0500 Subject: [PATCH] Handle json/jsonb in binary to support CopyTo fixes #189 --- conn.go | 2 +- copy_to_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ query.go | 8 +++++-- values.go | 55 +++++++++++++++++++++++++++++++++++++++++++------ 4 files changed, 111 insertions(+), 9 deletions(-) diff --git a/conn.go b/conn.go index c928ed98..9f15a999 100644 --- a/conn.go +++ b/conn.go @@ -917,7 +917,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid, JsonOid, JsonbOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) diff --git a/copy_to_test.go b/copy_to_test.go index d810c4fb..95b9af2d 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -119,6 +119,61 @@ func TestConnCopyToLarge(t *testing.T) { ensureConnValid(t, conn) } +func TestConnCopyToJSON(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + if _, ok := conn.PgTypes[oid]; !ok { + return // No JSON/JSONB type -- must be running against old PostgreSQL + } + } + + mustExec(t, conn, `create temporary table foo( + a json, + b jsonb + )`) + + inputRows := [][]interface{}{ + {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + {nil, nil}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + func TestConnCopyToFailServerSideMidway(t *testing.T) { t.Parallel() diff --git a/query.go b/query.go index 50c8e290..4e4b8e53 100644 --- a/query.go +++ b/query.go @@ -298,13 +298,17 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid { + } else if vr.Type().DataType == JsonOid { // Because the argument passed to decodeJSON will escape the heap. // This allows d to be stack allocated and only copied to the heap when // we actually are decoding JSON. This saves one memory allocation per // row. d2 := d decodeJSON(vr, &d2) + } else if vr.Type().DataType == JsonbOid { + // Same trick as above for getting stack allocation + d2 := d + decodeJSONB(vr, &d2) } else { if err := Decode(vr, d); err != nil { rows.Fatal(scanArgError{col: i, err: err}) @@ -393,7 +397,7 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, d) case JsonbOid: var d interface{} - decodeJSON(vr, &d) + decodeJSONB(vr, &d) values = append(values, d) default: rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) diff --git a/values.go b/values.go index 545f1031..31f754ef 100644 --- a/values.go +++ b/values.go @@ -91,23 +91,25 @@ func init() { "bool": BinaryFormatCode, "bytea": BinaryFormatCode, "char": BinaryFormatCode, + "cid": BinaryFormatCode, "cidr": BinaryFormatCode, "date": BinaryFormatCode, "float4": BinaryFormatCode, "float8": BinaryFormatCode, + "json": BinaryFormatCode, + "jsonb": BinaryFormatCode, "inet": BinaryFormatCode, "int2": BinaryFormatCode, "int4": BinaryFormatCode, "int8": BinaryFormatCode, "oid": BinaryFormatCode, - "tid": BinaryFormatCode, - "xid": BinaryFormatCode, - "cid": BinaryFormatCode, "record": BinaryFormatCode, "text": BinaryFormatCode, + "tid": BinaryFormatCode, "timestamp": BinaryFormatCode, "timestamptz": BinaryFormatCode, "varchar": BinaryFormatCode, + "xid": BinaryFormatCode, } } @@ -889,9 +891,12 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return Encode(wbuf, oid, arg) } - if oid == JsonOid || oid == JsonbOid { + if oid == JsonOid { return encodeJSON(wbuf, oid, arg) } + if oid == JsonbOid { + return encodeJSONB(wbuf, oid, arg) + } switch arg := arg.(type) { case []string: @@ -1914,7 +1919,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { return nil } - if vr.Type().DataType != JsonOid && vr.Type().DataType != JsonbOid { + if vr.Type().DataType != JsonOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) } @@ -1927,7 +1932,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { } func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error { - if oid != JsonOid && oid != JsonbOid { + if oid != JsonOid { return fmt.Errorf("cannot encode JSON into oid %v", oid) } @@ -1942,6 +1947,44 @@ func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error { return nil } +func decodeJSONB(vr *ValueReader, d interface{}) error { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != JsonbOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType))) + } + + bytes := vr.ReadBytes(vr.Len()) + if bytes[0] != 1 { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0]))) + } + + err := json.Unmarshal(bytes[1:], d) + if err != nil { + vr.Fatal(err) + } + return err +} + +func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error { + if oid != JsonbOid { + return fmt.Errorf("cannot encode JSON into oid %v", oid) + } + + s, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("Failed to encode json from type: %T", value) + } + + w.WriteInt32(int32(len(s) + 1)) + w.WriteByte(1) // JSONB format header + w.WriteBytes(s) + + return nil +} + func decodeDate(vr *ValueReader) time.Time { var zeroTime time.Time