mirror of https://github.com/jackc/pgx.git
parent
a9199847a8
commit
f7b6b3f077
2
conn.go
2
conn.go
|
@ -917,7 +917,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
switch oid {
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid:
|
||||
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid, JsonOid, JsonbOid:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
default:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
|
|
|
@ -119,6 +119,61 @@ func TestConnCopyToLarge(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
|
||||
if _, ok := conn.PgTypes[oid]; !ok {
|
||||
return // No JSON/JSONB type -- must be running against old PostgreSQL
|
||||
}
|
||||
}
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a json,
|
||||
b jsonb
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
|
||||
{nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyTo: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyToFailServerSideMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
8
query.go
8
query.go
|
@ -298,13 +298,17 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
if err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid {
|
||||
} else if vr.Type().DataType == JsonOid {
|
||||
// Because the argument passed to decodeJSON will escape the heap.
|
||||
// This allows d to be stack allocated and only copied to the heap when
|
||||
// we actually are decoding JSON. This saves one memory allocation per
|
||||
// row.
|
||||
d2 := d
|
||||
decodeJSON(vr, &d2)
|
||||
} else if vr.Type().DataType == JsonbOid {
|
||||
// Same trick as above for getting stack allocation
|
||||
d2 := d
|
||||
decodeJSONB(vr, &d2)
|
||||
} else {
|
||||
if err := Decode(vr, d); err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
|
@ -393,7 +397,7 @@ func (rows *Rows) Values() ([]interface{}, error) {
|
|||
values = append(values, d)
|
||||
case JsonbOid:
|
||||
var d interface{}
|
||||
decodeJSON(vr, &d)
|
||||
decodeJSONB(vr, &d)
|
||||
values = append(values, d)
|
||||
default:
|
||||
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
|
||||
|
|
55
values.go
55
values.go
|
@ -91,23 +91,25 @@ func init() {
|
|||
"bool": BinaryFormatCode,
|
||||
"bytea": BinaryFormatCode,
|
||||
"char": BinaryFormatCode,
|
||||
"cid": BinaryFormatCode,
|
||||
"cidr": BinaryFormatCode,
|
||||
"date": BinaryFormatCode,
|
||||
"float4": BinaryFormatCode,
|
||||
"float8": BinaryFormatCode,
|
||||
"json": BinaryFormatCode,
|
||||
"jsonb": BinaryFormatCode,
|
||||
"inet": BinaryFormatCode,
|
||||
"int2": BinaryFormatCode,
|
||||
"int4": BinaryFormatCode,
|
||||
"int8": BinaryFormatCode,
|
||||
"oid": BinaryFormatCode,
|
||||
"tid": BinaryFormatCode,
|
||||
"xid": BinaryFormatCode,
|
||||
"cid": BinaryFormatCode,
|
||||
"record": BinaryFormatCode,
|
||||
"text": BinaryFormatCode,
|
||||
"tid": BinaryFormatCode,
|
||||
"timestamp": BinaryFormatCode,
|
||||
"timestamptz": BinaryFormatCode,
|
||||
"varchar": BinaryFormatCode,
|
||||
"xid": BinaryFormatCode,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -889,9 +891,12 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
|
|||
return Encode(wbuf, oid, arg)
|
||||
}
|
||||
|
||||
if oid == JsonOid || oid == JsonbOid {
|
||||
if oid == JsonOid {
|
||||
return encodeJSON(wbuf, oid, arg)
|
||||
}
|
||||
if oid == JsonbOid {
|
||||
return encodeJSONB(wbuf, oid, arg)
|
||||
}
|
||||
|
||||
switch arg := arg.(type) {
|
||||
case []string:
|
||||
|
@ -1914,7 +1919,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != JsonOid && vr.Type().DataType != JsonbOid {
|
||||
if vr.Type().DataType != JsonOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType)))
|
||||
}
|
||||
|
||||
|
@ -1927,7 +1932,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error {
|
|||
}
|
||||
|
||||
func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error {
|
||||
if oid != JsonOid && oid != JsonbOid {
|
||||
if oid != JsonOid {
|
||||
return fmt.Errorf("cannot encode JSON into oid %v", oid)
|
||||
}
|
||||
|
||||
|
@ -1942,6 +1947,44 @@ func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func decodeJSONB(vr *ValueReader, d interface{}) error {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != JsonbOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType)))
|
||||
}
|
||||
|
||||
bytes := vr.ReadBytes(vr.Len())
|
||||
if bytes[0] != 1 {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0])))
|
||||
}
|
||||
|
||||
err := json.Unmarshal(bytes[1:], d)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error {
|
||||
if oid != JsonbOid {
|
||||
return fmt.Errorf("cannot encode JSON into oid %v", oid)
|
||||
}
|
||||
|
||||
s, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to encode json from type: %T", value)
|
||||
}
|
||||
|
||||
w.WriteInt32(int32(len(s) + 1))
|
||||
w.WriteByte(1) // JSONB format header
|
||||
w.WriteBytes(s)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeDate(vr *ValueReader) time.Time {
|
||||
var zeroTime time.Time
|
||||
|
||||
|
|
Loading…
Reference in New Issue