Handle json/jsonb in binary to support CopyTo

fixes #189
pull/193/head
Jack Christensen 2016-10-01 10:58:04 -05:00
parent a9199847a8
commit f7b6b3f077
4 changed files with 111 additions and 9 deletions

View File

@ -917,7 +917,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
wbuf.WriteInt16(TextFormatCode)
default:
switch oid {
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid:
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid, JsonOid, JsonbOid:
wbuf.WriteInt16(BinaryFormatCode)
default:
wbuf.WriteInt16(TextFormatCode)

View File

@ -119,6 +119,61 @@ func TestConnCopyToLarge(t *testing.T) {
ensureConnValid(t, conn)
}
func TestConnCopyToJSON(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
if _, ok := conn.PgTypes[oid]; !ok {
return // No JSON/JSONB type -- must be running against old PostgreSQL
}
}
mustExec(t, conn, `create temporary table foo(
a json,
b jsonb
)`)
inputRows := [][]interface{}{
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
{nil, nil},
}
copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyTo: %v", err)
}
if copyCount != len(inputRows) {
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query("select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
ensureConnValid(t, conn)
}
func TestConnCopyToFailServerSideMidway(t *testing.T) {
t.Parallel()

View File

@ -298,13 +298,17 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
if err != nil {
rows.Fatal(scanArgError{col: i, err: err})
}
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
} else if vr.Type().DataType == JsonOid {
// Because the argument passed to decodeJSON will escape the heap.
// This allows d to be stack allocated and only copied to the heap when
// we actually are decoding JSON. This saves one memory allocation per
// row.
d2 := d
decodeJSON(vr, &d2)
} else if vr.Type().DataType == JsonbOid {
// Same trick as above for getting stack allocation
d2 := d
decodeJSONB(vr, &d2)
} else {
if err := Decode(vr, d); err != nil {
rows.Fatal(scanArgError{col: i, err: err})
@ -393,7 +397,7 @@ func (rows *Rows) Values() ([]interface{}, error) {
values = append(values, d)
case JsonbOid:
var d interface{}
decodeJSON(vr, &d)
decodeJSONB(vr, &d)
values = append(values, d)
default:
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))

View File

@ -91,23 +91,25 @@ func init() {
"bool": BinaryFormatCode,
"bytea": BinaryFormatCode,
"char": BinaryFormatCode,
"cid": BinaryFormatCode,
"cidr": BinaryFormatCode,
"date": BinaryFormatCode,
"float4": BinaryFormatCode,
"float8": BinaryFormatCode,
"json": BinaryFormatCode,
"jsonb": BinaryFormatCode,
"inet": BinaryFormatCode,
"int2": BinaryFormatCode,
"int4": BinaryFormatCode,
"int8": BinaryFormatCode,
"oid": BinaryFormatCode,
"tid": BinaryFormatCode,
"xid": BinaryFormatCode,
"cid": BinaryFormatCode,
"record": BinaryFormatCode,
"text": BinaryFormatCode,
"tid": BinaryFormatCode,
"timestamp": BinaryFormatCode,
"timestamptz": BinaryFormatCode,
"varchar": BinaryFormatCode,
"xid": BinaryFormatCode,
}
}
@ -889,9 +891,12 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
return Encode(wbuf, oid, arg)
}
if oid == JsonOid || oid == JsonbOid {
if oid == JsonOid {
return encodeJSON(wbuf, oid, arg)
}
if oid == JsonbOid {
return encodeJSONB(wbuf, oid, arg)
}
switch arg := arg.(type) {
case []string:
@ -1914,7 +1919,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error {
return nil
}
if vr.Type().DataType != JsonOid && vr.Type().DataType != JsonbOid {
if vr.Type().DataType != JsonOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType)))
}
@ -1927,7 +1932,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error {
}
func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error {
if oid != JsonOid && oid != JsonbOid {
if oid != JsonOid {
return fmt.Errorf("cannot encode JSON into oid %v", oid)
}
@ -1942,6 +1947,44 @@ func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error {
return nil
}
func decodeJSONB(vr *ValueReader, d interface{}) error {
if vr.Len() == -1 {
return nil
}
if vr.Type().DataType != JsonbOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType)))
}
bytes := vr.ReadBytes(vr.Len())
if bytes[0] != 1 {
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0])))
}
err := json.Unmarshal(bytes[1:], d)
if err != nil {
vr.Fatal(err)
}
return err
}
func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error {
if oid != JsonbOid {
return fmt.Errorf("cannot encode JSON into oid %v", oid)
}
s, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("Failed to encode json from type: %T", value)
}
w.WriteInt32(int32(len(s) + 1))
w.WriteByte(1) // JSONB format header
w.WriteBytes(s)
return nil
}
func decodeDate(vr *ValueReader) time.Time {
var zeroTime time.Time