From 7e43eca3d3c7f71bfbdfa88cd8c05191400e2701 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 8 Aug 2016 16:31:01 -0500 Subject: [PATCH 01/75] Remove one allocation per pool query --- conn_pool.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index 775fb091..9e468cbb 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -30,6 +30,8 @@ type ConnPool struct { pgTypes map[Oid]PgType pgsql_af_inet *byte pgsql_af_inet6 *byte + txAfterClose func(tx *Tx) + rowsAfterClose func(rows *Rows) } type ConnPoolStat struct { @@ -68,6 +70,14 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p.logLevel = LogLevelNone } + p.txAfterClose = func(tx *Tx) { + p.Release(tx.Conn()) + } + + p.rowsAfterClose = func(rows *Rows) { + p.Release(rows.Conn()) + } + p.allConnections = make([]*Conn, 0, p.maxConnections) p.availableConnections = make([]*Conn, 0, p.maxConnections) p.preparedStatements = make(map[string]*PreparedStatement) @@ -486,11 +496,3 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) { return tx, nil } } - -func (p *ConnPool) txAfterClose(tx *Tx) { - p.Release(tx.Conn()) -} - -func (p *ConnPool) rowsAfterClose(rows *Rows) { - p.Release(rows.Conn()) -} From bb73d8427902891bbad7b949b9c60b32949d935f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 8 Aug 2016 17:01:01 -0500 Subject: [PATCH 02/75] Remove unnecessary buf from msgReader Replace with bufio.Reader.Peek for short sizes --- msg_reader.go | 57 +++++++++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/msg_reader.go b/msg_reader.go index fd74a63b..069094cd 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -10,7 +10,6 @@ import ( // msgReader is a helper that reads values from a PostgreSQL message. 
type msgReader struct { reader *bufio.Reader - buf [128]byte msgBytesRemaining int32 err error log func(lvl int, msg string, ctx ...interface{}) @@ -47,10 +46,15 @@ func (r *msgReader) rxMsg() (byte, error) { } } - b := r.buf[0:5] - _, err := io.ReadFull(r.reader, b) + b, err := r.reader.Peek(5) + if err != nil { + r.fatal(err) + return 0, err + } + msgType := b[0] r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4 - return b[0], err + r.reader.Discard(5) + return msgType, nil } func (r *msgReader) readByte() byte { @@ -88,8 +92,7 @@ func (r *msgReader) readInt16() int16 { return 0 } - b := r.buf[0:2] - _, err := io.ReadFull(r.reader, b) + b, err := r.reader.Peek(2) if err != nil { r.fatal(err) return 0 @@ -97,6 +100,8 @@ func (r *msgReader) readInt16() int16 { n := int16(binary.BigEndian.Uint16(b)) + r.reader.Discard(2) + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } @@ -115,8 +120,7 @@ func (r *msgReader) readInt32() int32 { return 0 } - b := r.buf[0:4] - _, err := io.ReadFull(r.reader, b) + b, err := r.reader.Peek(4) if err != nil { r.fatal(err) return 0 @@ -124,6 +128,8 @@ func (r *msgReader) readInt32() int32 { n := int32(binary.BigEndian.Uint32(b)) + r.reader.Discard(4) + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } @@ -142,8 +148,7 @@ func (r *msgReader) readInt64() int64 { return 0 } - b := r.buf[0:8] - _, err := io.ReadFull(r.reader, b) + b, err := r.reader.Peek(8) if err != nil { r.fatal(err) return 0 @@ -151,6 +156,8 @@ func (r *msgReader) readInt64() int64 { n := int64(binary.BigEndian.Uint64(b)) + r.reader.Discard(8) + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } @@ -190,32 +197,34 @@ func (r *msgReader) readCString() string { } // readString reads count bytes and returns as string -func (r *msgReader) readString(count int32) string { +func (r *msgReader) readString(countI32 int32) string { if r.err != nil { return "" } - r.msgBytesRemaining -= count + r.msgBytesRemaining -= countI32 if r.msgBytesRemaining < 0 { r.fatal(errors.New("read past end of message")) return "" } - var b []byte - if count <= int32(len(r.buf)) { - b = r.buf[0:int(count)] + count := int(countI32) + var s string + + if r.reader.Buffered() >= count { + buf, _ := r.reader.Peek(count) + s = string(buf) + r.reader.Discard(count) } else { - b = make([]byte, int(count)) + buf := make([]byte, int(count)) + _, err := io.ReadFull(r.reader, buf) + if err != nil { + r.fatal(err) + return "" + } + s = string(buf) } - _, err := io.ReadFull(r.reader, b) - if err != nil { - r.fatal(err) - return "" - } - - s := string(b) - if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) } From 5f7d01778eaf02b0c0ef9871b934952bbf9afed5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 10 Aug 2016 16:27:44 -0500 Subject: [PATCH 03/75] Add CopyTo to support PostgreSQL copy protocol fixes #102 --- CHANGELOG.md | 1 + README.md | 1 + bench_test.go | 331 ++++++++++++++++++++++++++++++++++++++++++ conn_pool.go | 11 ++ copy_to.go | 241 +++++++++++++++++++++++++++++++ copy_to_test.go | 373 ++++++++++++++++++++++++++++++++++++++++++++++++ doc.go | 20 +++ messages.go | 4 + tx.go | 9 ++ 9 files changed, 991 insertions(+) create mode 100644 copy_to.go create mode 100644 copy_to_test.go diff 
--git a/CHANGELOG.md b/CHANGELOG.md index 6d185f2b..26a1590d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ## Features +* Add CopyTo * Add PrepareEx * Add basic record to []interface{} decoding * Encode and decode between all Go and PostgreSQL integer types with bounds checking diff --git a/README.md b/README.md index c90bf966..607b38cd 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Pgx supports many additional features beyond what is available through database/ * Transaction isolation level control * Full TLS connection control * Binary format support for custom types (can be much faster) +* Copy protocol support for faster bulk data loads * Logging support * Configurable connection pool with after connect hooks to do arbitrary connection setup * PostgreSQL array to Go slice mapping for integers, floats, and strings diff --git a/bench_test.go b/bench_test.go index eb9c0595..1ea92cc4 100644 --- a/bench_test.go +++ b/bench_test.go @@ -1,6 +1,9 @@ package pgx_test import ( + "bytes" + "fmt" + "strings" "testing" "time" @@ -432,3 +435,331 @@ func BenchmarkLog15Discard(b *testing.B) { logger.Debug("benchmark", "i", i, "b.N", b.N) } } + +const benchmarkWriteTableCreateSQL = `drop table if exists t; + +create table t( + varchar_1 varchar not null, + varchar_2 varchar not null, + varchar_null_1 varchar, + date_1 date not null, + date_null_1 date, + int4_1 int4 not null, + int4_2 int4 not null, + int4_null_1 int4, + tstz_1 timestamptz not null, + tstz_2 timestamptz, + bool_1 bool not null, + bool_2 bool not null, + bool_3 bool not null +); +` + +const benchmarkWriteTableInsertSQL = `insert into t( + varchar_1, + varchar_2, + varchar_null_1, + date_1, + date_null_1, + int4_1, + int4_2, + int4_null_1, + tstz_1, + tstz_2, + bool_1, + bool_2, + bool_3 +) values ( + $1::varchar, + $2::varchar, + $3::varchar, + $4::date, + $5::date, + $6::int4, + $7::int4, + $8::int4, + $9::timestamptz, + $10::timestamptz, + $11::bool, + $12::bool, + $13::bool +)` + +type benchmarkWriteTableCopyToSrc struct { + count int + idx int + row []interface{} +} + +func (s *benchmarkWriteTableCopyToSrc) Next() bool { + s.idx++ + return s.idx < s.count +} + +func (s *benchmarkWriteTableCopyToSrc) Values() ([]interface{}, error) { + return s.row, nil +} + +func (s *benchmarkWriteTableCopyToSrc) Err() error { + return nil +} + +func newBenchmarkWriteTableCopyToSrc(count int) pgx.CopyToSource { + return &benchmarkWriteTableCopyToSrc{ + count: count, + row: []interface{}{ + "varchar_1", + "varchar_2", + pgx.NullString{}, + time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), + pgx.NullTime{}, + 1, + 2, + pgx.NullInt32{}, + time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), + true, + false, + true, + }, + } +} + +func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyToSrc(n) + + tx, err := conn.Begin() + if err != nil { + b.Fatal(err) + } + + for src.Next() { + values, _ := src.Values() + if _, err = tx.Exec("insert_t", values...); err != nil { + b.Fatalf("Exec unexpectedly failed with: %v", err) + } + } + + err = tx.Commit() + if err != nil { + b.Fatal(err) + } + } +} + +// note this function is only used for benchmarks -- it doesn't escape tableName +// or 
columnNames +func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyToSource) (int, error) { + maxRowsPerInsert := 65535 / len(columnNames) + rowsThisInsert := 0 + rowCount := 0 + + sqlBuf := &bytes.Buffer{} + args := make(pgx.QueryArgs, 0) + + resetQuery := func() { + sqlBuf.Reset() + fmt.Fprintf(sqlBuf, "insert into %s(%s) values", tableName, strings.Join(columnNames, ", ")) + + args = args[0:0] + + rowsThisInsert = 0 + } + resetQuery() + + tx, err := conn.Begin() + if err != nil { + return 0, err + } + defer tx.Rollback() + + for rowSrc.Next() { + if rowsThisInsert > 0 { + sqlBuf.WriteByte(',') + } + + sqlBuf.WriteByte('(') + + values, err := rowSrc.Values() + if err != nil { + return 0, err + } + + for i, val := range values { + if i > 0 { + sqlBuf.WriteByte(',') + } + sqlBuf.WriteString(args.Append(val)) + } + + sqlBuf.WriteByte(')') + + rowsThisInsert++ + + if rowsThisInsert == maxRowsPerInsert { + _, err := tx.Exec(sqlBuf.String(), args...) + if err != nil { + return 0, err + } + + rowCount += rowsThisInsert + resetQuery() + } + } + + if rowsThisInsert > 0 { + _, err := tx.Exec(sqlBuf.String(), args...) + if err != nil { + return 0, err + } + + rowCount += rowsThisInsert + } + + if err := tx.Commit(); err != nil { + return 0, nil + } + + return rowCount, nil + +} + +func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyToSrc(n) + + _, err := multiInsert(conn, "t", + []string{"varchar_1", + "varchar_2", + "varchar_null_1", + "date_1", + "date_null_1", + "int4_1", + "int4_2", + "int4_null_1", + "tstz_1", + "tstz_2", + "bool_1", + "bool_2", + "bool_3"}, + src) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyToSrc(n) + + _, err := conn.CopyTo("t", + []string{"varchar_1", + "varchar_2", + "varchar_null_1", + "date_1", + "date_null_1", + "int4_1", + "int4_2", + "int4_null_1", + "tstz_1", + "tstz_2", + "bool_1", + "bool_2", + "bool_3"}, + src) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkWrite5RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 5) +} + +func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 5) +} + +func BenchmarkWrite5RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 5) +} + +func BenchmarkWrite10RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 10) +} + +func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 10) +} + +func BenchmarkWrite10RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 10) +} + +func BenchmarkWrite100RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 100) +} + +func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 100) +} + +func BenchmarkWrite100RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 100) +} + +func BenchmarkWrite1000RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 1000) +} + +func 
BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 1000) +} + +func BenchmarkWrite1000RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 1000) +} + +func BenchmarkWrite10000RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 10000) +} + +func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 10000) +} + +func BenchmarkWrite10000RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 10000) +} diff --git a/conn_pool.go b/conn_pool.go index 9e468cbb..fdd54114 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -496,3 +496,14 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) { return tx, nil } } + +// CopyTo acquires a connection, delegates the call to that connection, and releases the connection +func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + c, err := p.Acquire() + if err != nil { + return 0, err + } + defer p.Release(c) + + return c.CopyTo(tableName, columnNames, rowSrc) +} diff --git a/copy_to.go b/copy_to.go new file mode 100644 index 00000000..91292bb0 --- /dev/null +++ b/copy_to.go @@ -0,0 +1,241 @@ +package pgx + +import ( + "bytes" + "fmt" +) + +// CopyToRows returns a CopyToSource interface over the provided rows slice +// making it usable by *Conn.CopyTo. +func CopyToRows(rows [][]interface{}) CopyToSource { + return ©ToRows{rows: rows, idx: -1} +} + +type copyToRows struct { + rows [][]interface{} + idx int +} + +func (ctr *copyToRows) Next() bool { + ctr.idx++ + return ctr.idx < len(ctr.rows) +} + +func (ctr *copyToRows) Values() ([]interface{}, error) { + return ctr.rows[ctr.idx], nil +} + +func (ctr *copyToRows) Err() error { + return nil +} + +// CopyToSource is the interface used by *Conn.CopyTo as the source for copy data. +type CopyToSource interface { + // Next returns true if there is another row and makes the next row data + // available to Values(). When there are no more rows available or an error + // has occurred it returns false. + Next() bool + + // Values returns the values for the current row. + Values() ([]interface{}, error) + + // Err returns any error that has been encountered by the CopyToSource. If + // this is not nil *Conn.CopyTo will abort the copy. 
+ Err() error +} + +type copyTo struct { + conn *Conn + tableName string + columnNames []string + rowSrc CopyToSource + readerErrChan chan error +} + +func (ct *copyTo) readUntilReadyForQuery() { + for { + t, r, err := ct.conn.rxMsg() + if err != nil { + ct.readerErrChan <- err + close(ct.readerErrChan) + return + } + + switch t { + case readyForQuery: + ct.conn.rxReadyForQuery(r) + close(ct.readerErrChan) + return + case commandComplete: + case errorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(r) + default: + err = ct.conn.processContextFreeMsg(t, r) + if err != nil { + ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + } + } + } +} + +func (ct *copyTo) waitForReaderDone() error { + var err error + for err = range ct.readerErrChan { + } + return err +} + +func (ct *copyTo) run() (int, error) { + quotedTableName := quoteIdentifier(ct.tableName) + buf := &bytes.Buffer{} + for i, cn := range ct.columnNames { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(quoteIdentifier(cn)) + } + quotedColumnNames := buf.String() + + ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) + if err != nil { + return 0, err + } + + err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + if err != nil { + return 0, err + } + + err = ct.conn.readUntilCopyInResponse() + if err != nil { + return 0, err + } + + go ct.readUntilReadyForQuery() + defer ct.waitForReaderDone() + + wbuf := newWriteBuf(ct.conn, copyData) + + wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) + wbuf.WriteInt32(0) + wbuf.WriteInt32(0) + + var sentCount int + + for ct.rowSrc.Next() { + select { + case err = <-ct.readerErrChan: + return 0, err + default: + } + + if len(wbuf.buf) > 65536 { + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + // Directly manipulate wbuf to reset to reuse the same buffer + wbuf.buf = wbuf.buf[0:5] + wbuf.buf[0] = copyData + wbuf.sizeIdx = 1 + } + + sentCount++ + + values, err := ct.rowSrc.Values() + if err != nil { + ct.cancelCopyIn() + return 0, err + } + if len(values) != len(ct.columnNames) { + ct.cancelCopyIn() + return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + wbuf.WriteInt16(int16(len(ct.columnNames))) + for i, val := range values { + err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + if err != nil { + ct.cancelCopyIn() + return 0, err + } + + } + } + + if ct.rowSrc.Err() != nil { + ct.cancelCopyIn() + return 0, ct.rowSrc.Err() + } + + wbuf.WriteInt16(-1) // terminate the copy stream + + wbuf.startMsg(copyDone) + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + err = ct.waitForReaderDone() + if err != nil { + return 0, err + } + return sentCount, nil +} + +func (c *Conn) readUntilCopyInResponse() error { + for { + var t byte + var r *msgReader + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case copyInResponse: + return nil + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } +} + +func (ct *copyTo) cancelCopyIn() error { + wbuf := newWriteBuf(ct.conn, copyFail) + wbuf.WriteCString("client error: abort") + wbuf.closeMsg() + _, err := ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return err + } + + return nil +} + +// CopyTo uses the PostgreSQL copy protocol to perform bulk data insertion. 
+// It returns the number of rows copied and an error. +// +// CopyTo requires all values use the binary format. Almost all types +// implemented by pgx use the binary format by default. Types implementing +// Encoder can only be used if they encode to the binary format. +func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + ct := ©To{ + conn: c, + tableName: tableName, + columnNames: columnNames, + rowSrc: rowSrc, + readerErrChan: make(chan error), + } + + return ct.run() +} diff --git a/copy_to_test.go b/copy_to_test.go new file mode 100644 index 00000000..d810c4fb --- /dev/null +++ b/copy_to_test.go @@ -0,0 +1,373 @@ +package pgx_test + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func TestConnCopyToSmall(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToLarge(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz, + h bytea + )`) + + inputRows := [][]interface{}{} + + for i := 0; i < 10000; i++ { + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input 
rows and output rows do not equal") + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToFailServerSideMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b varchar not null + )`) + + inputRows := [][]interface{}{ + {int32(1), "abc"}, + {int32(2), nil}, // this row should trigger a failure + {int32(3), "def"}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type failSource struct { + count int +} + +func (fs *failSource) Next() bool { + time.Sleep(time.Millisecond * 100) + fs.count++ + return fs.count < 100 +} + +func (fs *failSource) Values() ([]interface{}, error) { + if fs.count == 3 { + return []interface{}{nil}, nil + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (fs *failSource) Err() error { + return nil +} + +func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + startTime := time.Now() + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + endTime := time.Now() + copyTime := endTime.Sub(startTime) + if copyTime > time.Second { + t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFailSource struct { + count int + err error +} + +func (cfs *clientFailSource) Next() bool { + cfs.count++ + return cfs.count < 100 +} + +func (cfs *clientFailSource) Values() ([]interface{}, error) { + if cfs.count == 3 { + cfs.err = fmt.Errorf("client error") + return nil, cfs.err + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFailSource) Err() error { + return cfs.err +} + 
+func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFinalErrSource struct { + count int +} + +func (cfs *clientFinalErrSource) Next() bool { + cfs.count++ + return cfs.count < 5 +} + +func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFinalErrSource) Err() error { + return fmt.Errorf("final error") +} + +func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} diff --git a/doc.go b/doc.go index 0fd3d2f6..bf624c22 100644 --- a/doc.go +++ b/doc.go @@ -104,6 +104,26 @@ creates a transaction with a specified isolation level. return err } +Copy Protocol + +Use CopyTo to efficiently insert multiple rows at a time using the PostgreSQL +copy protocol. CopyTo accepts a CopyToSource interface. If the data is already +in a [][]interface{} use CopyToRows to wrap it in a CopyToSource interface. Or +implement CopyToSource to avoid buffering the entire data set in memory. + + rows := [][]interface{}{ + {"John", "Smith", int32(36)}, + {"Jane", "Doe", int32(29)}, + } + + copyCount, err := conn.CopyTo( + "people", + []string{"first_name", "last_name", "age"}, + pgx.CopyToRows(rows), + ) + +CopyTo can be faster than an insert with as few as 5 rows. 
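As an illustrative sketch (not part of this patch), a caller can implement CopyToSource directly to stream rows without buffering them all in memory. The names chanCopyToSource and rowChan are hypothetical, and each channel value is assumed to already match the target column types:

    // chanCopyToSource adapts a channel of rows to the CopyToSource interface.
    type chanCopyToSource struct {
        rows    <-chan []interface{} // each value is one row of column values
        current []interface{}
    }

    // Next blocks until the next row arrives or the channel is closed.
    func (s *chanCopyToSource) Next() bool {
        row, ok := <-s.rows
        s.current = row
        return ok
    }

    // Values returns the row made available by the last call to Next.
    func (s *chanCopyToSource) Values() ([]interface{}, error) {
        return s.current, nil
    }

    // Err reports no error; a real source would surface its own read errors here.
    func (s *chanCopyToSource) Err() error {
        return nil
    }

    copyCount, err := conn.CopyTo(
        "people",
        []string{"first_name", "last_name", "age"},
        &chanCopyToSource{rows: rowChan},
    )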
+ Listen and Notify pgx can listen to the PostgreSQL notification system with the diff --git a/messages.go b/messages.go index 1fbd9cbc..7f04f1f2 100644 --- a/messages.go +++ b/messages.go @@ -25,6 +25,10 @@ const ( noData = 'n' closeComplete = '3' flush = 'H' + copyInResponse = 'G' + copyData = 'd' + copyFail = 'f' + copyDone = 'c' ) type startupMessage struct { diff --git a/tx.go b/tx.go index e5c90c23..36f99c28 100644 --- a/tx.go +++ b/tx.go @@ -158,6 +158,15 @@ func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +// CopyTo delegates to the underlying *Conn +func (tx *Tx) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + if tx.status != TxStatusInProgress { + return 0, ErrTxClosed + } + + return tx.conn.CopyTo(tableName, columnNames, rowSrc) +} + // Conn returns the *Conn this transaction is using. func (tx *Tx) Conn() *Conn { return tx.conn From cfb0304ab0cf553a522d6d8a1374a5316b1c694a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 10 Aug 2016 16:28:29 -0500 Subject: [PATCH 04/75] Fix typos --- CHANGELOG.md | 2 +- doc.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26a1590d..41c78c21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ * Encode and decode between all Go and PostgreSQL integer types with bounds checking * Decode inet/cidr to net.IP * Encode/decode [][]byte to/from bytea[] -* Encode/decode named types whoses underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64 +* Encode/decode named types whose underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64 ## Performance diff --git a/doc.go b/doc.go index bf624c22..65ca385c 100644 --- a/doc.go +++ b/doc.go @@ -50,7 +50,7 @@ pgx also implements QueryRow in the same style as database/sql. return err } -Use exec to execute a query that does not return a result set. +Use Exec to execute a query that does not return a result set. commandTag, err := conn.Exec("delete from widgets where id=$1", 42) if err != nil { From 7a2738f9f231b1a71c66d5c346d2286953df4166 Mon Sep 17 00:00:00 2001 From: Mostafa Hajizadeh Date: Thu, 11 Aug 2016 08:06:58 +0430 Subject: [PATCH 05/75] Fix minor documentation mistake: s/slice/null/ --- doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc.go b/doc.go index 65ca385c..c202f861 100644 --- a/doc.go +++ b/doc.go @@ -156,7 +156,7 @@ Array Mapping pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type. Go slices of native types do not -support nulls, so if a PostgreSQL array that contains a slice is read into a +support nulls, so if a PostgreSQL array that contains a null is read into a native Go slice an error will occur. Hstore Mapping From 9ce81d7ab7131c25219ca0bd1c09e2d1de8a930e Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sun, 21 Aug 2016 14:22:32 -0400 Subject: [PATCH 06/75] Updates test instructions in README Lets the user know about extra packages that need to be installed for the tests to run, and that connection_settings_test.go.example has been renamed to conn_config_test.go.example. --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 607b38cd..ea404819 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,13 @@ skip tests for connection types that are not configured. 
### Normal Test Environment -To setup the normal test environment run the following SQL: +To setup the normal test environment, first install these dependencies: + + go get github.com/jackc/fake + go get github.com/shopspring/decimal + go get gopkg.in/inconshreveable/log15.v2 + +Then run the following SQL: create user pgx_md5 password 'secret'; create database pgx_test; @@ -66,7 +72,7 @@ Connect to database pgx_test and run: create extension hstore; -Next open connection_settings_test.go.example and make a copy without the +Next open conn_config_test.go.example and make a copy without the .example. If your PostgreSQL server is accepting connections on 127.0.0.1, then you are done. From 32862d9bf85d25b6f415e55656a940d7a926c6aa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 26 Aug 2016 07:38:12 -0500 Subject: [PATCH 07/75] Update versioning policy --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ea404819..cf244b04 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,8 @@ If you are developing on Windows with TCP connections: ## Version Policy -pgx follows semantic versioning for the documented public API. ```master``` -branch tracks the latest stable branch (```v2```). Consider using ```import -"gopkg.in/jackc/pgx.v2"``` to lock to the ```v2``` branch or use a vendoring -tool such as [godep](https://github.com/tools/godep). +pgx follows semantic versioning for the documented public API on stable releases. Branch ```v2``` is the latest stable release. ```master``` can contain new features or behavior that will change or be removed before being merged to the stable ```v2``` branch (in practice, this occurs very rarely). + +Consider using a vendoring +tool such as [godep](https://github.com/tools/godep) or importing pgx via ```import +"gopkg.in/jackc/pgx.v2"``` to lock to the ```v2``` branch. From 2508faa9ced29f4bbb3a39483e4b5ac2627a7c48 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 26 Aug 2016 07:39:15 -0500 Subject: [PATCH 08/75] Release 2.9.0 --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 41c78c21..ffcc9594 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -# Unreleased +# 2.9.0 (August 26, 2016) ## Fixes From 9f6b99e3321fc72a0eac4542098030669cd2b7a2 Mon Sep 17 00:00:00 2001 From: Martin Hamrle Date: Tue, 30 Aug 2016 19:59:16 +0200 Subject: [PATCH 09/75] Cleanups Cleanups suggested by gometalinter tools. 
--- conn.go | 36 ++++++++++++++++---------------- conn_config_test.go.example | 2 -- conn_pool.go | 41 ++++++++++++++++++------------------- conn_pool_test.go | 4 ++-- conn_test.go | 2 +- fastpath.go | 2 -- hstore.go | 1 - messages.go | 10 ++++----- msg_reader.go | 4 ++-- query_test.go | 5 +++-- sql.go | 4 ++-- sql_test.go | 5 +++-- values.go | 36 +++++++++++++++----------------- values_test.go | 22 ++++++++++---------- 14 files changed, 84 insertions(+), 90 deletions(-) diff --git a/conn.go b/conn.go index c2519003..c928ed98 100644 --- a/conn.go +++ b/conn.go @@ -63,8 +63,8 @@ type Conn struct { logLevel int mr msgReader fp *fastpath - pgsql_af_inet *byte - pgsql_af_inet6 *byte + pgsqlAfInet *byte + pgsqlAfInet6 *byte busy bool poolResetCount int preallocatedRows []Rows @@ -145,7 +145,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { return connect(config, nil, nil, nil) } -func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsql_af_inet *byte, pgsql_af_inet6 *byte) (c *Conn, err error) { +func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) { c = new(Conn) c.config = config @@ -157,13 +157,13 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsql_af_inet *byte, pgs } } - if pgsql_af_inet != nil { - c.pgsql_af_inet = new(byte) - *c.pgsql_af_inet = *pgsql_af_inet + if pgsqlAfInet != nil { + c.pgsqlAfInet = new(byte) + *c.pgsqlAfInet = *pgsqlAfInet } - if pgsql_af_inet6 != nil { - c.pgsql_af_inet6 = new(byte) - *c.pgsql_af_inet6 = *pgsql_af_inet6 + if pgsqlAfInet6 != nil { + c.pgsqlAfInet6 = new(byte) + *c.pgsqlAfInet6 = *pgsqlAfInet6 } if c.config.LogLevel != 0 { @@ -315,7 +315,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - if c.pgsql_af_inet == nil || c.pgsql_af_inet6 == nil { + if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil { err = c.loadInetConstants() if err != nil { return err @@ -372,8 +372,8 @@ func (c *Conn) loadInetConstants() error { return err } - c.pgsql_af_inet = &ipv4[0] - c.pgsql_af_inet6 = &ipv6[0] + c.pgsqlAfInet = &ipv4[0] + c.pgsqlAfInet6 = &ipv6[0] return nil } @@ -446,7 +446,7 @@ func ParseURI(uri string) (ConnConfig, error) { return cp, nil } -var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) +var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) // ParseDSN parses a database DSN (data source name) into a ConnConfig // @@ -462,7 +462,7 @@ var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) func ParseDSN(s string) (ConnConfig, error) { var cp ConnConfig - m := dsn_regexp.FindAllStringSubmatch(s, -1) + m := dsnRegexp.FindAllStringSubmatch(s, -1) var sslmode string @@ -477,11 +477,11 @@ func ParseDSN(s string) (ConnConfig, error) { case "host": cp.Host = b[2] case "port": - if p, err := strconv.ParseUint(b[2], 10, 16); err != nil { + p, err := strconv.ParseUint(b[2], 10, 16) + if err != nil { return cp, err - } else { - cp.Port = uint16(p) } + cp.Port = uint16(p) case "dbname": cp.Database = b[2] case "sslmode": @@ -627,7 +627,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared if opts != nil { if len(opts.ParameterOids) > 65535 { - return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))) + return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) } 
wbuf.WriteInt16(int16(len(opts.ParameterOids))) for _, oid := range opts.ParameterOids { diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 358e0247..0b80d490 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -11,7 +11,6 @@ var tcpConnConfig *pgx.ConnConfig = nil var unixSocketConnConfig *pgx.ConnConfig = nil var md5ConnConfig *pgx.ConnConfig = nil var plainPasswordConnConfig *pgx.ConnConfig = nil -var noPasswordConnConfig *pgx.ConnConfig = nil var invalidUserConnConfig *pgx.ConnConfig = nil var tlsConnConfig *pgx.ConnConfig = nil var customDialerConnConfig *pgx.ConnConfig = nil @@ -20,7 +19,6 @@ var customDialerConnConfig *pgx.ConnConfig = nil // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} -// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"} // var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} // var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} // var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_pool.go b/conn_pool.go index fdd54114..6fbe143a 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -28,8 +28,8 @@ type ConnPool struct { preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration pgTypes map[Oid]PgType - pgsql_af_inet *byte - pgsql_af_inet6 *byte + pgsqlAfInet *byte + pgsqlAfInet6 *byte txAfterClose func(tx *Tx) rowsAfterClose func(rows *Rows) } @@ -148,26 +148,25 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Create a new connection. // Careful here: createConnectionUnlocked() removes the current lock, // creates a connection and then locks it back. 
- if c, err := p.createConnectionUnlocked(); err == nil { - c.poolResetCount = p.resetCount - p.allConnections = append(p.allConnections, c) - return c, nil - } else { + c, err := p.createConnectionUnlocked() + if err != nil { return nil, err } - } else { - // All connections are in use and we cannot create more - if p.logLevel >= LogLevelWarn { - p.logger.Warn("All connections in pool are busy - waiting...") - } + c.poolResetCount = p.resetCount + p.allConnections = append(p.allConnections, c) + return c, nil + } + // All connections are in use and we cannot create more + if p.logLevel >= LogLevelWarn { + p.logger.Warn("All connections in pool are busy - waiting...") + } - // Wait until there is an available connection OR room to create a new connection - for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections { - if p.deadlinePassed(deadline) { - return nil, errors.New("Timeout: All connections in pool are busy") - } - p.cond.Wait() + // Wait until there is an available connection OR room to create a new connection + for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections { + if p.deadlinePassed(deadline) { + return nil, errors.New("Timeout: All connections in pool are busy") } + p.cond.Wait() } // Stop the timer so that we do not spawn it on every acquire call. @@ -282,7 +281,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.pgTypes, p.pgsql_af_inet, p.pgsql_af_inet6) + c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6) if err != nil { return nil, err } @@ -318,8 +317,8 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { // all the known statements for the new connection. func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { p.pgTypes = c.PgTypes - p.pgsql_af_inet = c.pgsql_af_inet - p.pgsql_af_inet6 = c.pgsql_af_inet6 + p.pgsqlAfInet = c.pgsqlAfInet + p.pgsqlAfInet6 = c.pgsqlAfInet6 if p.afterConnect != nil { err := p.afterConnect(c) diff --git a/conn_pool_test.go b/conn_pool_test.go index 9aa31758..e3ae0036 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -40,7 +40,7 @@ func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) { func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) { startTime := time.Now() c, err := pool.Acquire() - return c, time.Now().Sub(startTime), err + return c, time.Since(startTime), err } func TestNewConnPool(t *testing.T) { @@ -215,7 +215,7 @@ func TestPoolNonBlockingConnections(t *testing.T) { // Prior to createConnectionUnlocked() use the test took // maxConnections * openTimeout seconds to complete. // With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds. 
- timeTaken := time.Now().Sub(startedAt) + timeTaken := time.Since(startedAt) if timeTaken > openTimeout+1*time.Second { t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken) } diff --git a/conn_test.go b/conn_test.go index 181a3ed2..9ed073ce 100644 --- a/conn_test.go +++ b/conn_test.go @@ -914,7 +914,7 @@ func TestPrepareQueryManyParameters(t *testing.T) { args := make([]interface{}, 0, tt.count) for j := 0; j < tt.count; j++ { params = append(params, fmt.Sprintf("($%d::text)", j+1)) - args = append(args, strconv.FormatInt(int64(j), 10)) + args = append(args, strconv.Itoa(j)) } sql := "values" + strings.Join(params, ", ") diff --git a/fastpath.go b/fastpath.go index 8814e559..19b98784 100644 --- a/fastpath.go +++ b/fastpath.go @@ -4,8 +4,6 @@ import ( "encoding/binary" ) -type fastpathArg []byte - func newFastpath(cn *Conn) *fastpath { return &fastpath{cn: cn, fns: make(map[string]Oid)} } diff --git a/hstore.go b/hstore.go index a5d40cce..0ab9f779 100644 --- a/hstore.go +++ b/hstore.go @@ -15,7 +15,6 @@ const ( hsVal hsNul hsNext - hsEnd ) type hstoreParser struct { diff --git a/messages.go b/messages.go index 7f04f1f2..db0258de 100644 --- a/messages.go +++ b/messages.go @@ -39,10 +39,10 @@ func newStartupMessage() *startupMessage { return &startupMessage{map[string]string{}} } -func (self *startupMessage) Bytes() (buf []byte) { +func (s *startupMessage) Bytes() (buf []byte) { buf = make([]byte, 8, 128) binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber)) - for key, value := range self.options { + for key, value := range s.options { buf = append(buf, key...) buf = append(buf, 0) buf = append(buf, value...) @@ -89,8 +89,8 @@ type PgError struct { Routine string } -func (self PgError) Error() string { - return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")" +func (pe PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } func newWriteBuf(c *Conn, t byte) *WriteBuf { @@ -99,7 +99,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf { return &c.writeBuf } -// WrifeBuf is used build messages to send to the PostgreSQL server. It is used +// WriteBuf is used build messages to send to the PostgreSQL server. It is used // by the Encoder interface when implementing custom encoders. type WriteBuf struct { buf []byte diff --git a/msg_reader.go b/msg_reader.go index 069094cd..c8869bdd 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -62,7 +62,7 @@ func (r *msgReader) readByte() byte { return 0 } - r.msgBytesRemaining -= 1 + r.msgBytesRemaining-- if r.msgBytesRemaining < 0 { r.fatal(errors.New("read past end of message")) return 0 @@ -216,7 +216,7 @@ func (r *msgReader) readString(countI32 int32) string { s = string(buf) r.reader.Discard(count) } else { - buf := make([]byte, int(count)) + buf := make([]byte, count) _, err := io.ReadFull(r.reader, buf) if err != nil { r.fatal(err) diff --git a/query_test.go b/query_test.go index 2cf8b3cd..06a18ffe 100644 --- a/query_test.go +++ b/query_test.go @@ -3,11 +3,12 @@ package pgx_test import ( "bytes" "database/sql" - "github.com/jackc/pgx" "strings" "testing" "time" + "github.com/jackc/pgx" + "github.com/shopspring/decimal" ) @@ -784,7 +785,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) { t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) } - if bytes.Compare(actual, tt.expected) != 0 { + if !bytes.Equal(actual, tt.expected) { t.Errorf("%d. 
Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) } diff --git a/sql.go b/sql.go index 9445c263..7ee0f2a0 100644 --- a/sql.go +++ b/sql.go @@ -14,7 +14,7 @@ func init() { placeholders = make([]string, 64) for i := 1; i < 64; i++ { - placeholders[i] = "$" + strconv.FormatInt(int64(i), 10) + placeholders[i] = "$" + strconv.Itoa(i) } } @@ -25,5 +25,5 @@ func (qa *QueryArgs) Append(v interface{}) string { if len(*qa) < len(placeholders) { return placeholders[len(*qa)] } - return "$" + strconv.FormatInt(int64(len(*qa)), 10) + return "$" + strconv.Itoa(len(*qa)) } diff --git a/sql_test.go b/sql_test.go index eafd92fa..dd036035 100644 --- a/sql_test.go +++ b/sql_test.go @@ -1,16 +1,17 @@ package pgx_test import ( - "github.com/jackc/pgx" "strconv" "testing" + + "github.com/jackc/pgx" ) func TestQueryArgs(t *testing.T) { var qa pgx.QueryArgs for i := 1; i < 512; i++ { - expectedPlaceholder := "$" + strconv.FormatInt(int64(i), 10) + expectedPlaceholder := "$" + strconv.Itoa(i) placeholder := qa.Append(i) if placeholder != expectedPlaceholder { t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder) diff --git a/values.go b/values.go index b6e0a84b..ee43813b 100644 --- a/values.go +++ b/values.go @@ -225,28 +225,28 @@ type NullString struct { Valid bool // Valid is true if String is not NULL } -func (s *NullString) Scan(vr *ValueReader) error { +func (n *NullString) Scan(vr *ValueReader) error { // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later if vr.Len() == -1 { - s.String, s.Valid = "", false + n.String, n.Valid = "", false return nil } - s.Valid = true - s.String = decodeText(vr) + n.Valid = true + n.String = decodeText(vr) return vr.Err() } func (n NullString) FormatCode() int16 { return TextFormatCode } -func (s NullString) Encode(w *WriteBuf, oid Oid) error { - if !s.Valid { +func (n NullString) Encode(w *WriteBuf, oid Oid) error { + if !n.Valid { w.WriteInt32(-1) return nil } - return encodeString(w, oid, s.String) + return encodeString(w, oid, n.String) } // NullInt16 represents an smallint that may be null. 
NullInt16 implements the @@ -621,10 +621,9 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { if refVal.IsNil() { wbuf.WriteInt32(-1) return nil - } else { - arg = refVal.Elem().Interface() - return Encode(wbuf, oid, arg) } + arg = refVal.Elem().Interface() + return Encode(wbuf, oid, arg) } if oid == JsonOid || oid == JsonbOid { @@ -892,14 +891,13 @@ func Decode(vr *ValueReader, d interface{}) error { el.Set(reflect.Zero(el.Type())) } return nil - } else { - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - d = el.Interface() - return Decode(vr, d) } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + d = el.Interface() + return Decode(vr, d) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n := decodeInt(vr) if el.OverflowInt(n) { @@ -1645,10 +1643,10 @@ func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error { switch len(value.IP) { case net.IPv4len: size = 8 - family = *w.conn.pgsql_af_inet + family = *w.conn.pgsqlAfInet case net.IPv6len: size = 20 - family = *w.conn.pgsql_af_inet6 + family = *w.conn.pgsqlAfInet6 default: return fmt.Errorf("Unexpected IP length: %v", len(value.IP)) } diff --git a/values_test.go b/values_test.go index 063598d9..7a690055 100644 --- a/values_test.go +++ b/values_test.go @@ -630,7 +630,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::bool[]", []bool{true, false, true}, &[]bool{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]bool))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]bool))) { t.Errorf("failed to encode bool[]") } }, @@ -638,7 +638,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]int16))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]int16))) { t.Errorf("failed to encode smallint[]") } }, @@ -646,7 +646,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]uint16))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { t.Errorf("failed to encode smallint[]") } }, @@ -654,7 +654,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]int32))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]int32))) { t.Errorf("failed to encode int[]") } }, @@ -662,7 +662,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]uint32))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { t.Errorf("failed to encode int[]") } }, @@ -670,7 +670,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]int64))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]int64))) { t.Errorf("failed to encode bigint[]") } }, @@ -678,7 +678,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]uint64))) == 
false { + if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { t.Errorf("failed to encode bigint[]") } }, @@ -686,7 +686,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]string))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]string))) { t.Errorf("failed to encode text[]") } }, @@ -694,7 +694,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { t.Errorf("failed to encode time.Time[] to timestamp[]") } }, @@ -702,7 +702,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { t.Errorf("failed to encode time.Time[] to timestamptz[]") } }, @@ -718,7 +718,7 @@ func TestArrayDecoding(t *testing.T) { for i := range queryBytesSliceSlice { qb := queryBytesSliceSlice[i] sb := scanBytesSliceSlice[i] - if bytes.Compare(qb, sb) != 0 { + if !bytes.Equal(qb, sb) { t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) } } From 7dbfd4bf4bec3fff9d54e87757a9d73272dd84ca Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 1 Sep 2016 22:55:18 -0400 Subject: [PATCH 10/75] Switches oid to uint32 --- messages.go | 2 +- msg_reader.go | 28 ++++++++++++++++++++++++++++ value_reader.go | 16 +++++++++++++++- values_test.go | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 2 deletions(-) diff --git a/messages.go b/messages.go index db0258de..7e5c3b54 100644 --- a/messages.go +++ b/messages.go @@ -53,7 +53,7 @@ func (s *startupMessage) Bytes() (buf []byte) { return buf } -type Oid int32 +type Oid uint32 type FieldDescription struct { Name string diff --git a/msg_reader.go b/msg_reader.go index c8869bdd..2bcd2d51 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -137,6 +137,34 @@ func (r *msgReader) readInt32() int32 { return n } +func (r *msgReader) readUint32() uint32 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 4 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(4) + if err != nil { + r.fatal(err) + return 0 + } + + n := uint32(binary.BigEndian.Uint32(b)) + + r.reader.Discard(4) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + func (r *msgReader) readInt64() int64 { if r.err != nil { return 0 diff --git a/value_reader.go b/value_reader.go index 4936b887..6e552ea8 100644 --- a/value_reader.go +++ b/value_reader.go @@ -74,6 +74,20 @@ func (r *ValueReader) ReadInt32() int32 { return r.mr.readInt32() } +func (r *ValueReader) ReadUint32() uint32 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 4 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readUint32() +} + func (r *ValueReader) ReadInt64() int64 { if r.err != nil { return 0 @@ -89,7 +103,7 @@ func (r *ValueReader) ReadInt64() int64 { } func (r *ValueReader) 
ReadOid() Oid { - return Oid(r.ReadInt32()) + return Oid(r.ReadUint32()) } // ReadString reads count bytes and returns as string diff --git a/values_test.go b/values_test.go index 7a690055..3e650b61 100644 --- a/values_test.go +++ b/values_test.go @@ -551,6 +551,39 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } } +func TestOid(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value pgx.Oid + }{ + {"select $1::oid", 0}, + {"select $1::oid", 1}, + {"select $1::oid", 4294967295}, + } + + for i, tt := range tests { + expected := tt.value + var actual pgx.Oid + + err := conn.QueryRow(tt.sql, expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, expected) + continue + } + + if actual != expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, expected, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} + func TestNullX(t *testing.T) { t.Parallel() From 1061b1f978ed615ebb05f9b800cc7bcaa0e600c1 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 3 Sep 2016 18:04:55 -0400 Subject: [PATCH 11/75] Adds Xid type --- messages.go | 6 ++++ values.go | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++ values_test.go | 4 +++ 3 files changed, 100 insertions(+) diff --git a/messages.go b/messages.go index 7e5c3b54..053a4c13 100644 --- a/messages.go +++ b/messages.go @@ -138,6 +138,12 @@ func (wb *WriteBuf) WriteInt32(n int32) { wb.buf = append(wb.buf, b...) } +func (wb *WriteBuf) WriteUint32(n uint32) { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, uint32(n)) + wb.buf = append(wb.buf, b...) +} + func (wb *WriteBuf) WriteInt64(n int64) { b := make([]byte, 8) binary.BigEndian.PutUint64(b, uint64(n)) diff --git a/values.go b/values.go index ee43813b..d2c56b92 100644 --- a/values.go +++ b/values.go @@ -22,6 +22,7 @@ const ( Int4Oid = 23 TextOid = 25 OidOid = 26 + XidOid = 28 JsonOid = 114 CidrOid = 650 CidrArrayOid = 651 @@ -93,6 +94,7 @@ func init() { "int4": BinaryFormatCode, "int8": BinaryFormatCode, "oid": BinaryFormatCode, + "xid": BinaryFormatCode, "record": BinaryFormatCode, "text": BinaryFormatCode, "timestamp": BinaryFormatCode, @@ -327,6 +329,47 @@ func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { return encodeInt32(w, oid, n.Int32) } +type Xid uint32 + +// NullXid represents a transaction ID (Xid) that may be null. NullXid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullXid struct { + Xid Xid + Valid bool // Valid is true if Int32 is not NULL +} + +func (n *NullXid) Scan(vr *ValueReader) error { + if vr.Type().DataType != XidOid { + return SerializationError(fmt.Sprintf("NullXid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Xid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Xid = decodeXid(vr) + return vr.Err() +} + +func (n NullXid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullXid) Encode(w *WriteBuf, oid Oid) error { + if oid != XidOid { + return SerializationError(fmt.Sprintf("NullXid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeXid(w, oid, n.Xid) +} + // NullInt64 represents an bigint that may be null. 
NullInt64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -691,6 +734,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return encodeIPNetSlice(wbuf, oid, arg) case Oid: return encodeOid(wbuf, oid, arg) + case Xid: + return encodeXid(wbuf, oid, arg) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -815,6 +860,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = uint64(n) case *Oid: *v = decodeOid(vr) + case *Xid: + *v = decodeXid(vr) case *string: *v = decodeText(vr) case *float32: @@ -1339,6 +1386,49 @@ func encodeOid(w *WriteBuf, oid Oid, value Oid) error { return nil } +func decodeXid(vr *ValueReader) Xid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Xid")) + return Xid(0) + } + + if vr.Type().DataType != XidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Xid", vr.Type().DataType))) + return Xid(0) + } + + // Unlikely Xid will ever go over the wire as text format, but who knows? + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + } + return Xid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return Xid(0) + } + return Xid(vr.ReadUint32()) + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Xid(0) + } +} + +func encodeXid(w *WriteBuf, oid Oid, value Xid) error { + if oid != XidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Xid", oid) + } + + w.WriteInt32(4) + w.WriteUint32(uint32(value)) + + return nil +} + func decodeFloat4(vr *ValueReader) float32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float32")) diff --git a/values_test.go b/values_test.go index 3e650b61..c7474a1c 100644 --- a/values_test.go +++ b/values_test.go @@ -594,6 +594,7 @@ func TestNullX(t *testing.T) { s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 + xid pgx.NullXid i64 pgx.NullInt64 f32 pgx.NullFloat32 f64 pgx.NullFloat64 @@ -615,6 +616,9 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, + {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}}, + {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}}, + {"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, 
[]interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}}, From 99bfc154f0ff1da92d35ea4d7a07c45d219e61f3 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 3 Sep 2016 18:19:33 -0400 Subject: [PATCH 12/75] Makes Oid casting consistent Also fixes uint32 encoding in a few places. --- messages.go | 6 ++++++ values.go | 10 +++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/messages.go b/messages.go index 7e5c3b54..c158d5d2 100644 --- a/messages.go +++ b/messages.go @@ -138,6 +138,12 @@ func (wb *WriteBuf) WriteInt32(n int32) { wb.buf = append(wb.buf, b...) } +func (wb *WriteBuf) WriteUint32(n uint32) { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, n) + wb.buf = append(wb.buf, b...) +} + func (wb *WriteBuf) WriteInt64(n int64) { b := make([]byte, 8) binary.BigEndian.PutUint64(b, uint64(n)) diff --git a/values.go b/values.go index ee43813b..a1d2b515 100644 --- a/values.go +++ b/values.go @@ -1299,19 +1299,19 @@ func decodeInt4(vr *ValueReader) int32 { func decodeOid(vr *ValueReader) Oid { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into Oid")) - return 0 + return Oid(0) } if vr.Type().DataType != OidOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Oid", vr.Type().DataType))) - return 0 + return Oid(0) } // Oid needs to decode text format because it is used in loadPgTypes switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) - n, err := strconv.ParseInt(s, 10, 32) + n, err := strconv.ParseUint(s, 10, 32) if err != nil { vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) } @@ -1319,7 +1319,7 @@ func decodeOid(vr *ValueReader) Oid { case BinaryFormatCode: if vr.Len() != 4 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) - return 0 + return Oid(0) } return Oid(vr.ReadInt32()) default: @@ -1334,7 +1334,7 @@ func encodeOid(w *WriteBuf, oid Oid, value Oid) error { } w.WriteInt32(4) - w.WriteInt32(int32(value)) + w.WriteUint32(uint32(value)) return nil } From 074bcd7139309f003b76859e91d8c8dd1576d740 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 3 Sep 2016 18:30:36 -0400 Subject: [PATCH 13/75] Adds docs for Oid type. --- messages.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/messages.go b/messages.go index c158d5d2..fec34dbb 100644 --- a/messages.go +++ b/messages.go @@ -53,6 +53,10 @@ func (s *startupMessage) Bytes() (buf []byte) { return buf } +// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, +// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented +// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h +// in the PostgreSQL sources. 
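The new Oid doc comment is easier to follow with a usage sketch. This is not part of the patch; it assumes an open *pgx.Conn named conn inside a function that returns error:

    // pg_class.oid is one of the system-table primary keys the comment refers to,
    // so it scans directly into pgx.Oid.
    var tableOid pgx.Oid
    err := conn.QueryRow("select oid from pg_class where relname = $1", "pg_class").Scan(&tableOid)
    if err != nil {
        return err
    }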
type Oid uint32 type FieldDescription struct { From 0c7277fe15359549b236292644f4ad7eff753544 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 5 Sep 2016 08:08:39 -0500 Subject: [PATCH 14/75] Update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ffcc9594..a9d70ee0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# Unreleased + +## Fixes + +* Oid underlying type changed to uint32, previously it was incorrectly int32 + # 2.9.0 (August 26, 2016) ## Fixes From 7adabc9b93697c46b98c18c30f043910fb41ffba Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Mon, 5 Sep 2016 10:59:24 -0400 Subject: [PATCH 15/75] Improves documentation of Xid type --- values.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/values.go b/values.go index e3486bed..80bd3ed8 100644 --- a/values.go +++ b/values.go @@ -329,9 +329,23 @@ func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { return encodeInt32(w, oid, n.Int32) } +// Xid is PostgreSQL's Transaction ID type. +// +// In later versions of PostgreSQL, it is the type used for the backend_xid +// and backend_xmin columns of the pg_stat_activity system view. +// +// Also, when one does +// +// select xmin, xmax, * from some_table; +// +// it is the data type of the xmin and xmax hidden system columns. +// +// It is currently implemented as an unsigned for byte integer. +// Its definition can be found in src/include/postgres_ext.h as TransactionId +// in the PostgreSQL sources. type Xid uint32 -// NullXid represents a transaction ID (Xid) that may be null. NullXid implements the +// NullXid represents a Transaction ID (Xid) that may be null. NullXid implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. // From 60ab3403abb7c4f3e99cf995ba92af476f4910e0 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Mon, 5 Sep 2016 12:15:34 -0400 Subject: [PATCH 16/75] Adds Cid/NullCid type --- values.go | 101 +++++++++++++++++++++++++++++++++++++++++++++++++ values_test.go | 4 ++ 2 files changed, 105 insertions(+) diff --git a/values.go b/values.go index 80bd3ed8..de5bcfd5 100644 --- a/values.go +++ b/values.go @@ -23,6 +23,7 @@ const ( TextOid = 25 OidOid = 26 XidOid = 28 + CidOid = 29 JsonOid = 114 CidrOid = 650 CidrArrayOid = 651 @@ -95,6 +96,7 @@ func init() { "int8": BinaryFormatCode, "oid": BinaryFormatCode, "xid": BinaryFormatCode, + "cid": BinaryFormatCode, "record": BinaryFormatCode, "text": BinaryFormatCode, "timestamp": BinaryFormatCode, @@ -384,6 +386,58 @@ func (n NullXid) Encode(w *WriteBuf, oid Oid) error { return encodeXid(w, oid, n.Xid) } +// Cid is PostgreSQL's Command Identifier type. +// +// When one does +// +// select cmin, cmax, * from some_table; +// +// it is the data type of the cmin and cmax hidden system columns. +// +// It is currently implemented as an unsigned for byte integer. +// Its definition can be found in src/include/c.h as CommandId +// in the PostgreSQL sources. +type Cid uint32 + +// NullCid represents a Command Identifier (Cid) that may be null. NullCid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. 
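The hidden system columns named in the Xid and Cid comments can be read back through these types. A minimal sketch, assuming an open *pgx.Conn named conn and an existing table called some_table:

    // xmin and cmin are the hidden system columns described above; they scan
    // into pgx.Xid and pgx.Cid respectively.
    var (
        xmin pgx.Xid
        cmin pgx.Cid
    )
    err := conn.QueryRow("select xmin, cmin from some_table limit 1").Scan(&xmin, &cmin)
    if err != nil {
        return err
    }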
+type NullCid struct { + Cid Cid + Valid bool // Valid is true if Int32 is not NULL +} + +func (n *NullCid) Scan(vr *ValueReader) error { + if vr.Type().DataType != CidOid { + return SerializationError(fmt.Sprintf("NullCid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Cid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Cid = decodeCid(vr) + return vr.Err() +} + +func (n NullCid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullCid) Encode(w *WriteBuf, oid Oid) error { + if oid != CidOid { + return SerializationError(fmt.Sprintf("NullCid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeCid(w, oid, n.Cid) +} + // NullInt64 represents an bigint that may be null. NullInt64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -750,6 +804,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return encodeOid(wbuf, oid, arg) case Xid: return encodeXid(wbuf, oid, arg) + case Cid: + return encodeCid(wbuf, oid, arg) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -876,6 +932,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeOid(vr) case *Xid: *v = decodeXid(vr) + case *Cid: + *v = decodeCid(vr) case *string: *v = decodeText(vr) case *float32: @@ -1443,6 +1501,49 @@ func encodeXid(w *WriteBuf, oid Oid, value Xid) error { return nil } +func decodeCid(vr *ValueReader) Cid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Cid")) + return Cid(0) + } + + if vr.Type().DataType != CidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Cid", vr.Type().DataType))) + return Cid(0) + } + + // Unlikely Cid will ever go over the wire as text format, but who knows? 
+ switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + } + return Cid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return Cid(0) + } + return Cid(vr.ReadUint32()) + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Cid(0) + } +} + +func encodeCid(w *WriteBuf, oid Oid, value Cid) error { + if oid != CidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Cid", oid) + } + + w.WriteInt32(4) + w.WriteUint32(uint32(value)) + + return nil +} + func decodeFloat4(vr *ValueReader) float32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float32")) diff --git a/values_test.go b/values_test.go index c7474a1c..2325b6f1 100644 --- a/values_test.go +++ b/values_test.go @@ -595,6 +595,7 @@ func TestNullX(t *testing.T) { i16 pgx.NullInt16 i32 pgx.NullInt32 xid pgx.NullXid + cid pgx.NullCid i64 pgx.NullInt64 f32 pgx.NullFloat32 f64 pgx.NullFloat64 @@ -619,6 +620,9 @@ func TestNullX(t *testing.T) { {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}}, {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}}, {"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}}, + {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, + {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, + {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}}, From 57b3037e96fc06e32354c48666b1986a06abf034 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 10 Sep 2016 14:49:39 -0400 Subject: [PATCH 17/75] Adds tid oid --- values.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/values.go b/values.go index de5bcfd5..04c66317 100644 --- a/values.go +++ b/values.go @@ -22,6 +22,7 @@ const ( Int4Oid = 23 TextOid = 25 OidOid = 26 + TidOid = 27 XidOid = 28 CidOid = 29 JsonOid = 114 @@ -95,6 +96,7 @@ func init() { "int4": BinaryFormatCode, "int8": BinaryFormatCode, "oid": BinaryFormatCode, + "tid": BinaryFormatCode, "xid": BinaryFormatCode, "cid": BinaryFormatCode, "record": BinaryFormatCode, From cba72d47c5f74ca539a66f2a37affb8878579d23 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 10 Sep 2016 19:44:33 -0400 Subject: [PATCH 18/75] Fixes typo --- values.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/values.go b/values.go index 80bd3ed8..dcc378d1 100644 --- a/values.go +++ b/values.go @@ -340,7 +340,7 @@ func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { // // it is the data type of the xmin and xmax hidden system columns. // -// It is currently implemented as an unsigned for byte integer. +// It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. type Xid uint32 From f49b92d5a89033429c07a3fcfe74fd76426c2a13 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 10 Sep 2016 19:47:51 -0400 Subject: [PATCH 19/75] Fixes typo --- values.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/values.go b/values.go index de5bcfd5..1bf224d5 100644 --- a/values.go +++ b/values.go @@ -342,7 +342,7 @@ func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { // // it is the data type of the xmin and xmax hidden system columns. // -// It is currently implemented as an unsigned for byte integer. +// It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. type Xid uint32 @@ -394,7 +394,7 @@ func (n NullXid) Encode(w *WriteBuf, oid Oid) error { // // it is the data type of the cmin and cmax hidden system columns. // -// It is currently implemented as an unsigned for byte integer. +// It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/c.h as CommandId // in the PostgreSQL sources. type Cid uint32 From 55bd3a9134fc2af077863b62696beab9df67ea06 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 10 Sep 2016 20:37:02 -0400 Subject: [PATCH 20/75] Adds binary tid decode --- msg_reader.go | 28 ++++++++++++++ values.go | 104 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 2 deletions(-) diff --git a/msg_reader.go b/msg_reader.go index 2bcd2d51..59617b73 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -137,6 +137,34 @@ func (r *msgReader) readInt32() int32 { return n } +func (r *msgReader) readUint16() uint16 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 2 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(2) + if err != nil { + r.fatal(err) + return 0 + } + + n := uint16(binary.BigEndian.Uint16(b)) + + r.reader.Discard(2) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + func (r *msgReader) readUint32() uint32 { if r.err != nil { return 0 diff --git a/values.go b/values.go index 04c66317..51adfa1c 100644 --- a/values.go +++ b/values.go @@ -344,7 +344,7 @@ func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { // // it is the data type of the xmin and xmax hidden system columns. // -// It is currently implemented as an unsigned for byte integer. +// It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. type Xid uint32 @@ -396,7 +396,7 @@ func (n NullXid) Encode(w *WriteBuf, oid Oid) error { // // it is the data type of the cmin and cmax hidden system columns. // -// It is currently implemented as an unsigned for byte integer. +// It is currently implemented as an unsigned four byte integer. 
// Its definition can be found in src/include/c.h as CommandId // in the PostgreSQL sources. type Cid uint32 @@ -440,6 +440,61 @@ func (n NullCid) Encode(w *WriteBuf, oid Oid) error { return encodeCid(w, oid, n.Cid) } +// Tid is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair of unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type Tid struct { + BlockNumber uint16 + OffsetNumber uint16 +} + +// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullTid struct { + Tid Tid + Valid bool // Valid is true if Int32 is not NULL +} + +func (n *NullTid) Scan(vr *ValueReader) error { + if vr.Type().DataType != TidOid { + return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Tid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Tid = decodeTid(vr) + return vr.Err() +} + +func (n NullTid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullTid) Encode(w *WriteBuf, oid Oid) error { + if oid != TidOid { + return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeTid(w, oid, n.Tid) +} + // NullInt64 represents an bigint that may be null. NullInt64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -934,6 +989,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeOid(vr) case *Xid: *v = decodeXid(vr) + case *Tid: + *v = decodeTid(vr) case *Cid: *v = decodeCid(vr) case *string: @@ -1546,6 +1603,49 @@ func encodeCid(w *WriteBuf, oid Oid, value Cid) error { return nil } +func decodeTid(vr *ValueReader) Tid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Tid")) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + if vr.Type().DataType != TidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + // Unlikely Tid will ever go over the wire as text format, but who knows?
+ switch vr.Type().FormatCode { + case TextFormatCode: // XXX: not done yet src/backend/utils/adt/tid.c for hints; s already contains the string, so we just have to parse out (uint16,uint16) + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseUint(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + } + return Tid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + return Tid{BlockNumber: vr.ReadUint16(), OffsetNumber: vr.ReadUint16()} + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } +} + +func encodeTid(w *WriteBuf, oid Oid, value Tid) error { + if oid != TidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid) + } + + w.WriteInt32(4) + w.WriteUint32(uint32(value)) + + return nil +} + func decodeFloat4(vr *ValueReader) float32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float32")) From 72084ad1b558ecc8f2b84efdddc4b49c5b77e6c4 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sun, 11 Sep 2016 10:02:27 -0400 Subject: [PATCH 21/75] Gets Tid parsing working --- messages.go | 6 ++++++ values.go | 28 +++++++++++++++++++++++----- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/messages.go b/messages.go index fec34dbb..8eb4b8ef 100644 --- a/messages.go +++ b/messages.go @@ -136,6 +136,12 @@ func (wb *WriteBuf) WriteInt16(n int16) { wb.buf = append(wb.buf, b...) } +func (wb *WriteBuf) WriteUint16(n uint16) { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, n) + wb.buf = append(wb.buf, b...) +} + func (wb *WriteBuf) WriteInt32(n int32) { b := make([]byte, 4) binary.BigEndian.PutUint32(b, uint32(n)) diff --git a/values.go b/values.go index 51adfa1c..69fdaae0 100644 --- a/values.go +++ b/values.go @@ -8,6 +8,7 @@ import ( "math" "net" "reflect" + "regexp" "strconv" "strings" "time" @@ -1603,6 +1604,10 @@ func encodeCid(w *WriteBuf, oid Oid, value Cid) error { return nil } +// Note that we do not match negative numbers, because neither the +// BlockNumber nor OffsetNumber of a Tid can be negative. +var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) + func decodeTid(vr *ValueReader) Tid { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into Tid")) @@ -1616,13 +1621,25 @@ func decodeTid(vr *ValueReader) Tid { // Unlikely Tid will ever go over the wire as text format, but who knows? 
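For reference, the text form this switch has to handle looks like "(4,7)": block number and offset in parentheses, which is what tidRegexp above captures. Once decoding works, callers can read the hidden ctid column; a minimal sketch, assuming an open *pgx.Conn named conn and a table called some_table:

    // ctid names the physical location of a row and scans into pgx.Tid.
    var ctid pgx.Tid
    err := conn.QueryRow("select ctid from some_table limit 1").Scan(&ctid)
    if err != nil {
        return err
    }
    // ctid.BlockNumber and ctid.OffsetNumber now identify the row's page and line pointer.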
switch vr.Type().FormatCode { - case TextFormatCode: // XXX: not done yet src/backend/utils/adt/tid.c for hints; s already contains the string, so we just have to parse out (uint16,uint16) + case TextFormatCode: s := vr.ReadString(vr.Len()) - n, err := strconv.ParseUint(s, 10, 32) - if err != nil { + + match := tidRegexp.FindStringSubmatch(s) + if match == nil { vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + return Tid{BlockNumber: 0, OffsetNumber: 0} } - return Tid(n) + + blockNumber, err := strconv.ParseUint(s, 10, 16) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid BlockNumber part of a Tid: %v", s))) + } + + offsetNumber, err := strconv.ParseUint(s, 10, 16) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s))) + } + return Tid{BlockNumber: blockNumber, OffsetNumber: offsetNumber} case BinaryFormatCode: if vr.Len() != 4 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) @@ -1641,7 +1658,8 @@ func encodeTid(w *WriteBuf, oid Oid, value Tid) error { } w.WriteInt32(4) - w.WriteUint32(uint32(value)) + w.WriteUint16(uint16(value.BlockNumber)) + w.WriteUint16(uint16(value.OffsetNumber)) return nil } From 00bd3062e0d4f0c55df05876a43313525c1a2c48 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sun, 11 Sep 2016 17:33:34 -0400 Subject: [PATCH 22/75] Figures out tid binary wire formatting --- value_reader.go | 14 ++++++++++++++ values.go | 16 ++++++++-------- values_test.go | 4 ++++ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/value_reader.go b/value_reader.go index 6e552ea8..a4897543 100644 --- a/value_reader.go +++ b/value_reader.go @@ -60,6 +60,20 @@ func (r *ValueReader) ReadInt16() int16 { return r.mr.readInt16() } +func (r *ValueReader) ReadUint16() uint16 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 2 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readUint16() +} + func (r *ValueReader) ReadInt32() int32 { if r.err != nil { return 0 diff --git a/values.go b/values.go index 69fdaae0..ae7b7cf7 100644 --- a/values.go +++ b/values.go @@ -453,7 +453,7 @@ func (n NullCid) Encode(w *WriteBuf, oid Oid) error { // Its conversion functions can be found in src/backend/utils/adt/tid.c // in the PostgreSQL sources. 
type Tid struct { - BlockNumber uint16 + BlockNumber uint32 OffsetNumber uint16 } @@ -473,7 +473,7 @@ func (n *NullTid) Scan(vr *ValueReader) error { } if vr.Len() == -1 { - n.Tid, n.Valid = 0, false + n.Tid, n.Valid = Tid{BlockNumber: 0, OffsetNumber: 0}, false return nil } n.Valid = true @@ -1639,13 +1639,13 @@ func decodeTid(vr *ValueReader) Tid { if err != nil { vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s))) } - return Tid{BlockNumber: blockNumber, OffsetNumber: offsetNumber} + return Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber)} case BinaryFormatCode: - if vr.Len() != 4 { + if vr.Len() != 6 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) return Tid{BlockNumber: 0, OffsetNumber: 0} } - return Tid{BlockNumber: vr.ReadUint16(), OffsetNumber: vr.ReadUint16()} + return Tid{BlockNumber: vr.ReadUint32(), OffsetNumber: vr.ReadUint16()} default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return Tid{BlockNumber: 0, OffsetNumber: 0} @@ -1657,9 +1657,9 @@ func encodeTid(w *WriteBuf, oid Oid, value Tid) error { return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid) } - w.WriteInt32(4) - w.WriteUint16(uint16(value.BlockNumber)) - w.WriteUint16(uint16(value.OffsetNumber)) + w.WriteInt32(6) + w.WriteUint32(value.BlockNumber) + w.WriteUint16(value.OffsetNumber) return nil } diff --git a/values_test.go b/values_test.go index 2325b6f1..cea70b9c 100644 --- a/values_test.go +++ b/values_test.go @@ -596,6 +596,7 @@ func TestNullX(t *testing.T) { i32 pgx.NullInt32 xid pgx.NullXid cid pgx.NullCid + tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 f64 pgx.NullFloat64 @@ -623,6 +624,9 @@ func TestNullX(t *testing.T) { {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}}, From 818dcbf2b664379261d49ce1569844fc680988d8 Mon Sep 17 00:00:00 2001 From: Manni 
Wood Date: Sat, 17 Sep 2016 23:11:59 -0400 Subject: [PATCH 23/75] Adds "char" type --- values.go | 84 +++++++++++++++++++++++++++++++++++++++++++++++++- values_test.go | 4 +++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/values.go b/values.go index ae7b7cf7..ffc1ceab 100644 --- a/values.go +++ b/values.go @@ -18,6 +18,7 @@ import ( const ( BoolOid = 16 ByteaOid = 17 + CharOid = 18 Int8Oid = 20 Int2Oid = 21 Int4Oid = 23 @@ -88,6 +89,7 @@ func init() { "_varchar": BinaryFormatCode, "bool": BinaryFormatCode, "bytea": BinaryFormatCode, + "char": BinaryFormatCode, "cidr": BinaryFormatCode, "date": BinaryFormatCode, "float4": BinaryFormatCode, @@ -256,7 +258,53 @@ func (n NullString) Encode(w *WriteBuf, oid Oid) error { return encodeString(w, oid, n.String) } -// NullInt16 represents an smallint that may be null. NullInt16 implements the +// The pgx.Char type is for PostgreSQL's special 8-bit-only +// "char" type more akin to the C language's char type, or Go's byte type. +// (Note that the name in PostgreSQL itself is "char" and not char.) +// It gets used a lot +// in PostgreSQL's system tables to hold a single ASCII character value. +type Char byte + +// NullChar represents a pgx.Char that may be null. NullChar implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. +type NullChar struct { + Char Char + Valid bool // Valid is true if Char is not NULL +} + +func (n *NullChar) Scan(vr *ValueReader) error { + if vr.Type().DataType != CharOid { + return SerializationError(fmt.Sprintf("NullChar.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Char, n.Valid = 0, false + return nil + } + n.Valid = true + n.Char = decodeChar(vr) + return vr.Err() +} + +func (n NullChar) FormatCode() int16 { return BinaryFormatCode } + +func (n NullChar) Encode(w *WriteBuf, oid Oid) error { + if oid != CharOid { + return SerializationError(fmt.Sprintf("NullChar.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeChar(w, oid, n.Char) +} + +// NullInt16 represents a smallint that may be null. NullInt16 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan for prepared and unprepared queries. 
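The new "char" type is easiest to see against one of the system columns it exists for. A minimal sketch, assuming an open *pgx.Conn named conn and relying on the decode path added further down in this patch:

    // pg_class.relkind is a "char" column holding a single ASCII code letter
    // ('r' for ordinary tables, 'i' for indexes, and so on).
    var relkind pgx.Char
    err := conn.QueryRow("select relkind from pg_class where relname = $1", "pg_class").Scan(&relkind)
    if err != nil {
        return err
    }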
// @@ -810,6 +858,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return encodeInt(wbuf, oid, arg) case uint: return encodeUInt(wbuf, oid, arg) + case Char: + return encodeChar(wbuf, oid, arg) case int8: return encodeInt8(wbuf, oid, arg) case uint8: @@ -986,6 +1036,8 @@ func Decode(vr *ValueReader, d interface{}) error { return fmt.Errorf("%d is less than zero for uint64", n) } *v = uint64(n) + case *Char: + *v = decodeChar(vr) case *Oid: *v = decodeOid(vr) case *Xid: @@ -1185,6 +1237,30 @@ func decodeInt8(vr *ValueReader) int64 { return vr.ReadInt64() } +func decodeChar(vr *ValueReader) Char { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into char")) + return Char(0) + } + + if vr.Type().DataType != CharOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into char", vr.Type().DataType))) + return Char(0) + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Char(0) + } + + if vr.Len() != 1 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a char: %d", vr.Len()))) + return Char(0) + } + + return Char(vr.ReadByte()) +} + func decodeInt2(vr *ValueReader) int16 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int16")) @@ -1270,6 +1346,12 @@ func encodeUInt(w *WriteBuf, oid Oid, value uint) error { return nil } +func encodeChar(w *WriteBuf, oid Oid, value Char) error { + w.WriteInt32(1) + w.WriteByte(byte(value)) + return nil +} + func encodeInt8(w *WriteBuf, oid Oid, value int8) error { switch oid { case Int2Oid: diff --git a/values_test.go b/values_test.go index cea70b9c..10ef5de4 100644 --- a/values_test.go +++ b/values_test.go @@ -594,6 +594,7 @@ func TestNullX(t *testing.T) { s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 + c pgx.NullChar xid pgx.NullXid cid pgx.NullCid tid pgx.NullTid @@ -621,6 +622,9 @@ func TestNullX(t *testing.T) { {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}}, {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}}, {"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}}, + {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, + {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, + {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, From cd21fd9035b2dfddebbceb4cd30e47cfd2313bba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 19 Sep 2016 07:47:16 -0500 Subject: [PATCH 24/75] Fix missing documentation fixes #177 
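The sentence restored here describes the DefaultTypeFormats map. A small sketch of what an entry controls follows; whether applications should override entries themselves is an assumption, the map is simply an exported package variable:

    // The package's init registers the built-in binary types, so this holds:
    _ = pgx.DefaultTypeFormats["bytea"] == pgx.BinaryFormatCode // true
    // A program that preferred text for a given type could, before opening any
    // connections, change the entry (assumption: done once at startup):
    pgx.DefaultTypeFormats["bytea"] = pgx.TextFormatCode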
--- values.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/values.go b/values.go index 1bf224d5..4d542bf1 100644 --- a/values.go +++ b/values.go @@ -66,7 +66,8 @@ const minInt = -maxInt - 1 // or binary). In theory the Scanner interface should be the one to determine // the format of the returned values. However, the query has already been // executed by the time Scan is called so it has no chance to set the format. -// So for types that should be returned in binary th +// So for types that should always be returned in binary the format should be +// set here. var DefaultTypeFormats map[string]int16 func init() { From 94203a55ada6ef85fc0e4d4b11badcdc5c48c115 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Mon, 19 Sep 2016 20:40:13 -0400 Subject: [PATCH 25/75] Adds same comment fix about binary settings as on master --- values.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/values.go b/values.go index ffc1ceab..b7112b9b 100644 --- a/values.go +++ b/values.go @@ -69,7 +69,8 @@ const minInt = -maxInt - 1 // or binary). In theory the Scanner interface should be the one to determine // the format of the returned values. However, the query has already been // executed by the time Scan is called so it has no chance to set the format. -// So for types that should be returned in binary th +// So for types that should always be returned in binary the format should be +// set here. var DefaultTypeFormats map[string]int16 func init() { From 88ac6ff200bb6410851a71244a2118f83b7c8ed3 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Mon, 19 Sep 2016 20:43:03 -0400 Subject: [PATCH 26/75] Reformats "char" comment a bit --- values.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/values.go b/values.go index b7112b9b..859d79a1 100644 --- a/values.go +++ b/values.go @@ -261,9 +261,9 @@ func (n NullString) Encode(w *WriteBuf, oid Oid) error { // The pgx.Char type is for PostgreSQL's special 8-bit-only // "char" type more akin to the C language's char type, or Go's byte type. -// (Note that the name in PostgreSQL itself is "char" and not char.) -// It gets used a lot -// in PostgreSQL's system tables to hold a single ASCII character value. +// (Note that the name in PostgreSQL itself is "char", in double-quotes, +// and not char.) It gets used a lot in PostgreSQL's system tables to hold +// a single ASCII character value. type Char byte // NullChar represents a pgx.Char that may be null. NullChar implements the From 256cbf0010ae83c21449f2600d1ecb29e7694496 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Mon, 19 Sep 2016 20:48:31 -0400 Subject: [PATCH 27/75] Adds example column to pgx.Char doc --- values.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/values.go b/values.go index 859d79a1..f8991f88 100644 --- a/values.go +++ b/values.go @@ -263,7 +263,7 @@ func (n NullString) Encode(w *WriteBuf, oid Oid) error { // "char" type more akin to the C language's char type, or Go's byte type. // (Note that the name in PostgreSQL itself is "char", in double-quotes, // and not char.) It gets used a lot in PostgreSQL's system tables to hold -// a single ASCII character value. +// a single ASCII character value (eg pg_class.relkind). type Char byte // NullChar represents a pgx.Char that may be null. 
NullChar implements the From cc1ad69c32966191a25c586c399ed4983b2b12df Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Tue, 20 Sep 2016 21:11:30 -0400 Subject: [PATCH 28/75] Adds NullOid type Oids are rarely null, but they can be: on the right hand side of a left join, for instance. This commit moves the Oid type def from messages.go to values.go, so it can live alongside the other types. It removes the special case for testing Oid and now leverages the TestNullX test instead. --- messages.go | 6 ------ values.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++++--- values_test.go | 37 ++++--------------------------------- 3 files changed, 52 insertions(+), 42 deletions(-) diff --git a/messages.go b/messages.go index 8eb4b8ef..317ba273 100644 --- a/messages.go +++ b/messages.go @@ -53,12 +53,6 @@ func (s *startupMessage) Bytes() (buf []byte) { return buf } -// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, -// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented -// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h -// in the PostgreSQL sources. -type Oid uint32 - type FieldDescription struct { Name string Table Oid diff --git a/values.go b/values.go index f8991f88..545f1031 100644 --- a/values.go +++ b/values.go @@ -383,6 +383,51 @@ func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { return encodeInt32(w, oid, n.Int32) } +// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, +// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented +// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h +// in the PostgreSQL sources. +type Oid uint32 + +// NullOid represents an Object Identifier (Oid) that may be null. NullOid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullOid struct { + Oid Oid + Valid bool // Valid is true if Oid is not NULL +} + +func (n *NullOid) Scan(vr *ValueReader) error { + if vr.Type().DataType != OidOid { + return SerializationError(fmt.Sprintf("NullOid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Oid, n.Valid = 0, false + return nil + } + n.Valid = true + n.Oid = decodeOid(vr) + return vr.Err() +} + +func (n NullOid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullOid) Encode(w *WriteBuf, oid Oid) error { + if oid != OidOid { + return SerializationError(fmt.Sprintf("NullOid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeOid(w, oid, n.Oid) +} + // Xid is PostgreSQL's Transaction ID type. // // In later versions of PostgreSQL, it is the type used for the backend_xid // and backend_xmin columns of the pg_stat_activity system view. // // Also, when one does // // select xmin, xmax, * from some_table; // // it is the data type of the xmin and xmax hidden system columns. // // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. type Xid uint32 @@ -406,7 +451,7 @@ type Xid uint32 // If Valid is false then the value is NULL. type NullXid struct { Xid Xid - Valid bool // Valid is true if Int32 is not NULL + Valid bool // Valid is true if Xid is not NULL } func (n *NullXid) Scan(vr *ValueReader) error { @@ -458,7 +503,7 @@ type Cid uint32 // If Valid is false then the value is NULL.
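The left-join case mentioned in this patch's commit message, as a minimal sketch assuming an open *pgx.Conn named conn and a table called some_table:

    // For a table with no indexes the left join produces a NULL oid,
    // which NullOid reports through Valid.
    var indexOid pgx.NullOid
    err := conn.QueryRow(
        "select i.indexrelid from pg_class c left join pg_index i on i.indrelid = c.oid where c.relname = $1 limit 1",
        "some_table",
    ).Scan(&indexOid)
    if err != nil {
        return err
    }
    if !indexOid.Valid {
        // some_table has no indexes
    }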
type NullCid struct { Cid Cid - Valid bool // Valid is true if Int32 is not NULL + Valid bool // Valid is true if Cid is not NULL } func (n *NullCid) Scan(vr *ValueReader) error { @@ -513,7 +558,7 @@ type Tid struct { // If Valid is false then the value is NULL. type NullTid struct { Tid Tid - Valid bool // Valid is true if Int32 is not NULL + Valid bool // Valid is true if Tid is not NULL } func (n *NullTid) Scan(vr *ValueReader) error { diff --git a/values_test.go b/values_test.go index 10ef5de4..7de4e0c2 100644 --- a/values_test.go +++ b/values_test.go @@ -551,39 +551,6 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } } -func TestOid(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - tests := []struct { - sql string - value pgx.Oid - }{ - {"select $1::oid", 0}, - {"select $1::oid", 1}, - {"select $1::oid", 4294967295}, - } - - for i, tt := range tests { - expected := tt.value - var actual pgx.Oid - - err := conn.QueryRow(tt.sql, expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, expected) - continue - } - - if actual != expected { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, expected, actual, tt.sql) - } - - ensureConnValid(t, conn) - } -} - func TestNullX(t *testing.T) { t.Parallel() @@ -595,6 +562,7 @@ func TestNullX(t *testing.T) { i16 pgx.NullInt16 i32 pgx.NullInt32 c pgx.NullChar + oid pgx.NullOid xid pgx.NullXid cid pgx.NullCid tid pgx.NullTid @@ -619,6 +587,9 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, + {"select $1::oid", []interface{}{pgx.NullOid{Oid: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 1, Valid: true}}}, + {"select $1::oid", []interface{}{pgx.NullOid{Oid: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 0, Valid: false}}}, + {"select $1::oid", []interface{}{pgx.NullOid{Oid: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 4294967295, Valid: true}}}, {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}}, {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}}, {"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}}, From f5b269d56a0e0da79e9affe11a960c2380393fa5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 22 Sep 2016 07:37:18 -0500 Subject: [PATCH 29/75] Update Go versions on Travis --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 314e26e8..b63c864c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: - - 1.6.2 - - 1.5.2 + - 1.7.1 + - 1.6.3 - tip # Derived from https://github.com/lib/pq/blob/master/.travis.yml From 79acbeac0efc90dda7356cf5d3f586de8f7b223e Mon Sep 17 00:00:00 2001 From: Jack 
Christensen Date: Thu, 22 Sep 2016 07:47:28 -0500 Subject: [PATCH 30/75] Tweak test for better travis compat. --- conn_pool_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/conn_pool_test.go b/conn_pool_test.go index e3ae0036..2163f515 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -366,12 +366,12 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) { if err != nil { t.Fatalf("Unable to Acquire: %v", err) } - rows, _ := c.Query("select 1") + rows, _ := c.Query("select 1, pg_sleep(0.02)") rows.Close() pool.Release(c) } - for i := 0; i < 1000; i++ { + for i := 0; i < 10; i++ { doSomething() } @@ -381,7 +381,7 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) { } var wg sync.WaitGroup - for i := 0; i < 1000; i++ { + for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() From 2ccec0026e6b7d34625c59d1f0d1d43befe4d58d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 22 Sep 2016 07:51:47 -0500 Subject: [PATCH 31/75] go get test dependencies on travis --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.travis.yml b/.travis.yml index b63c864c..ca2306a1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -37,6 +37,9 @@ before_script: - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" +install: + - go get -u github.com/shopspring/decimal + script: - go test -v -race -short ./... From dd7d777682e135e19355c45afdfa4a609759dc06 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 22 Sep 2016 08:32:54 -0500 Subject: [PATCH 32/75] Fetch another dependency for travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index ca2306a1..5bf5a167 100644 --- a/.travis.yml +++ b/.travis.yml @@ -39,6 +39,7 @@ before_script: install: - go get -u github.com/shopspring/decimal + - go get -u gopkg.in/inconshreveable/log15.v2 script: - go test -v -race -short ./... From b1a77cfa3116f319ef00fb785e0776a1fa0d4fb4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 22 Sep 2016 08:38:16 -0500 Subject: [PATCH 33/75] And another dependency for travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 5bf5a167..537dca8d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -40,6 +40,7 @@ before_script: install: - go get -u github.com/shopspring/decimal - go get -u gopkg.in/inconshreveable/log15.v2 + - go get -u github.com/jackc/fake script: - go test -v -race -short ./... From bcfb1f4d7cbc45ee96a42ec1efbb4d1d8cc5e0d3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 22 Sep 2016 08:56:46 -0500 Subject: [PATCH 34/75] Use pointer for decimal.Decimal Fix breakage caused by 54efccb61ffe0a31b6a1908bdc8a35f491da01ea in https://github.com/shopspring/decimal --- query_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/query_test.go b/query_test.go index 06a18ffe..457bc1fb 100644 --- a/query_test.go +++ b/query_test.go @@ -1282,7 +1282,7 @@ func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { } var num decimal.Decimal - err = conn.QueryRow("select $1::decimal", expected).Scan(&num) + err = conn.QueryRow("select $1::decimal", &expected).Scan(&num) if err != nil { t.Fatalf("Scan failed: %v", err) } From 383c12177e7736901470195a5da96ec52e6c5002 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Sep 2016 12:49:51 -0500 Subject: [PATCH 35/75] Add mapping information for core types. 
refs #183 --- doc.go | 134 ++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 84 insertions(+), 50 deletions(-) diff --git a/doc.go b/doc.go index c202f861..980c5a74 100644 --- a/doc.go +++ b/doc.go @@ -81,63 +81,39 @@ releasing connections when you do not need that level of control. return err } -Transactions +Base Type Mapping -Transactions are started by calling Begin or BeginIso. The BeginIso variant -creates a transaction with a specified isolation level. +pgx maps between all common base types directly between Go and PostgreSQL. In +particular: - tx, err := conn.Begin() - if err != nil { - return err - } - // Rollback is safe to call even if the tx is already closed, so if - // the tx commits successfully, this is a no-op - defer tx.Rollback() + Go PostgreSQL + ----------------------- + string varchar + text - _, err = tx.Exec("insert into foo(id) values (1)") - if err != nil { - return err - } + // Integers are automatically be converted to any other integer type if + // it can be done without overflow or underflow. + int8 + int16 smallint + int32 int + int64 bigint + int + uint8 + uint16 + uint32 + uint64 + uint - err = tx.Commit() - if err != nil { - return err - } + // Floats are strict and do not automatically convert like integers. + float32 float4 + float64 float8 -Copy Protocol + time.Time date + timestamp + timestamptz -Use CopyTo to efficiently insert multiple rows at a time using the PostgreSQL -copy protocol. CopyTo accepts a CopyToSource interface. If the data is already -in a [][]interface{} use CopyToRows to wrap it in a CopyToSource interface. Or -implement CopyToSource to avoid buffering the entire data set in memory. + []byte bytea - rows := [][]interface{}{ - {"John", "Smith", int32(36)}, - {"Jane", "Doe", int32(29)}, - } - - copyCount, err := conn.CopyTo( - "people", - []string{"first_name", "last_name", "age"}, - pgx.CopyToRows(rows), - ) - -CopyTo can be faster than an insert with as few as 5 rows. - -Listen and Notify - -pgx can listen to the PostgreSQL notification system with the -WaitForNotification function. It takes a maximum time to wait for a -notification. - - err := conn.Listen("channelname") - if err != nil { - return nil - } - - if notification, err := conn.WaitForNotification(time.Second); err != nil { - // do something with notification - } Null Mapping @@ -212,6 +188,64 @@ the raw bytes returned by PostgreSQL. This can be especially useful for reading varchar, text, json, and jsonb values directly into a []byte and avoiding the type conversion from string. +Transactions + +Transactions are started by calling Begin or BeginIso. The BeginIso variant +creates a transaction with a specified isolation level. + + tx, err := conn.Begin() + if err != nil { + return err + } + // Rollback is safe to call even if the tx is already closed, so if + // the tx commits successfully, this is a no-op + defer tx.Rollback() + + _, err = tx.Exec("insert into foo(id) values (1)") + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + +Copy Protocol + +Use CopyTo to efficiently insert multiple rows at a time using the PostgreSQL +copy protocol. CopyTo accepts a CopyToSource interface. If the data is already +in a [][]interface{} use CopyToRows to wrap it in a CopyToSource interface. Or +implement CopyToSource to avoid buffering the entire data set in memory. 
+ + rows := [][]interface{}{ + {"John", "Smith", int32(36)}, + {"Jane", "Doe", int32(29)}, + } + + copyCount, err := conn.CopyTo( + "people", + []string{"first_name", "last_name", "age"}, + pgx.CopyToRows(rows), + ) + +CopyTo can be faster than an insert with as few as 5 rows. + +Listen and Notify + +pgx can listen to the PostgreSQL notification system with the +WaitForNotification function. It takes a maximum time to wait for a +notification. + + err := conn.Listen("channelname") + if err != nil { + return nil + } + + if notification, err := conn.WaitForNotification(time.Second); err != nil { + // do something with notification + } + TLS The pgx ConnConfig struct has a TLSConfig field. If this field is From c25e3dd82605c0d12a44019bf5fcb444f71fa0b7 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 29 Sep 2016 00:25:19 -0400 Subject: [PATCH 36/75] Adds Name/NullName types --- values.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++ values_test.go | 28 ++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/values.go b/values.go index 545f1031..ac472c51 100644 --- a/values.go +++ b/values.go @@ -19,6 +19,7 @@ const ( BoolOid = 16 ByteaOid = 17 CharOid = 18 + NameOid = 19 Int8Oid = 20 Int2Oid = 21 Int4Oid = 23 @@ -65,6 +66,12 @@ const maxUint = ^uint(0) const maxInt = int(maxUint >> 1) const minInt = -maxInt - 1 +// NameDataLen is the same as PostgreSQL's NAMEDATALEN, defined in +// src/include/pg_config_manual.h. It is how many bytes long identifiers +// are allowed to be, including the trailing '\0' at the end of C strings. +// (Identifieres are table names, column names, function names, etc.) +const NameDataLen = 64 + // DefaultTypeFormats maps type names to their default requested format (text // or binary). In theory the Scanner interface should be the one to determine // the format of the returned values. However, the query has already been @@ -91,6 +98,7 @@ func init() { "bool": BinaryFormatCode, "bytea": BinaryFormatCode, "char": BinaryFormatCode, + "name": BinaryFormatCode, "cidr": BinaryFormatCode, "date": BinaryFormatCode, "float4": BinaryFormatCode, @@ -259,6 +267,55 @@ func (n NullString) Encode(w *WriteBuf, oid Oid) error { return encodeString(w, oid, n.String) } +// The pgx.Name type is for PostgreSQL's special 63-byte +// name data type, used for identifiers like table names. +type Name string + +// LengthOK is a convenience method that returns false if a name is longer +// than PostgreSQL will allow. PostgreSQL identifiers are allowed +// to be 63 bytes long (NAMEDATALEN in the PostgreSQL source code +// is defined as 64 bytes long, but the 64th char is the '\0' C +// string terminator.) +func (n Name) LengthOK() bool { + return len(string(n)) < NameDataLen +} + +// NullName represents a pgx.Name that may be null. NullName implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. 
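A usage sketch for the new Name type, assuming an open *pgx.Conn named conn inside a function that returns error; the over-long identifier check is the motivation for LengthOK:

    // LengthOK guards against identifiers longer than PostgreSQL will keep.
    ident := pgx.Name("a_table_name_built_dynamically_somewhere_else")
    if !ident.LengthOK() {
        return fmt.Errorf("identifier %q exceeds NameDataLen-1 bytes", ident)
    }
    // pg_class.relname has the name type, so it scans into pgx.Name.
    var relname pgx.Name
    err := conn.QueryRow("select relname from pg_class where relname = $1", ident).Scan(&relname)
    if err != nil {
        return err
    }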
+type NullName struct { + Name Name + Valid bool // Valid is true if Char is not NULL +} + +func (n *NullName) Scan(vr *ValueReader) error { + if vr.Type().DataType != NameOid { + return SerializationError(fmt.Sprintf("NullName.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Name, n.Valid = "", false + return nil + } + + n.Valid = true + n.Name = Name(decodeText(vr)) + return vr.Err() +} + +func (n NullName) FormatCode() int16 { return TextFormatCode } + +func (n NullName) Encode(w *WriteBuf, oid Oid) error { + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeString(w, oid, string(n.Name)) +} + // The pgx.Char type is for PostgreSQL's special 8-bit-only // "char" type more akin to the C language's char type, or Go's byte type. // (Note that the name in PostgreSQL itself is "char", in double-quotes, @@ -906,6 +963,10 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return encodeUInt(wbuf, oid, arg) case Char: return encodeChar(wbuf, oid, arg) + case Name: + // The name data type goes over the wire using the same format as string, + // so just cast to string and use encodeString + return encodeString(wbuf, oid, string(arg)) case int8: return encodeInt8(wbuf, oid, arg) case uint8: @@ -1084,6 +1145,9 @@ func Decode(vr *ValueReader, d interface{}) error { *v = uint64(n) case *Char: *v = decodeChar(vr) + case *Name: + // name goes over the wire just like text + *v = Name(decodeText(vr)) case *Oid: *v = decodeOid(vr) case *Xid: diff --git a/values_test.go b/values_test.go index 7de4e0c2..0deb3982 100644 --- a/values_test.go +++ b/values_test.go @@ -551,6 +551,28 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } } +func TestNameLengthOK(t *testing.T) { + tests := []struct { + input pgx.Name + expected bool + }{ + {"", true}, + {"1234", true}, + {"123456789012345678901234567890123456789012345678901234567890123", true}, + {"1234567890123456789012345678901234567890123456789012345678901234", false}, + } + + var actual bool + + for i, tt := range tests { + actual = tt.input.LengthOK() + + if actual != tt.expected { + t.Errorf("%d. 
Expected %v, got %v (name -> %v)", i, tt.expected, actual, tt.input) + } + } +} + func TestNullX(t *testing.T) { t.Parallel() @@ -562,6 +584,7 @@ func TestNullX(t *testing.T) { i16 pgx.NullInt16 i32 pgx.NullInt32 c pgx.NullChar + n pgx.NullName oid pgx.NullOid xid pgx.NullXid cid pgx.NullCid @@ -596,6 +619,11 @@ func TestNullX(t *testing.T) { {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, + {"select $1::name", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}}, + {"select $1::name", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}}, + // bytes past NameDataLen-1 (63 bytes) get silently truncated by PostgreSQL + {"select $1::name", []interface{}{pgx.NullString{String: "1234567890123456789012345678901234567890123456789012345678901234", Valid: true}}, + []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "123456789012345678901234567890123456789012345678901234567890123", Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, From ca96431b5e3a3431212e405c37e40cd9bc3989cc Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 29 Sep 2016 00:36:56 -0400 Subject: [PATCH 37/75] Fixes a documentation typo --- values.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/values.go b/values.go index ac472c51..d0940421 100644 --- a/values.go +++ b/values.go @@ -287,7 +287,7 @@ func (n Name) LengthOK() bool { // If Valid is false then the value is NULL. 
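As a usage sketch, the new type can appear on both sides of a query, exercising the Encoder and Scanner implementations above. The roundTripName helper and the open *pgx.Conn are assumptions for the sketch.

	// Sketch only: pgx.NullName as both query argument and Scan destination.
	func roundTripName(conn *pgx.Conn) (pgx.NullName, error) {
		in := pgx.NullName{Name: "foo", Valid: true}
		var out pgx.NullName
		err := conn.QueryRow("select $1::name", in).Scan(&out)
		return out, err
	}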
type NullName struct { Name Name - Valid bool // Valid is true if Char is not NULL + Valid bool // Valid is true if Name is not NULL } func (n *NullName) Scan(vr *ValueReader) error { From f7b6b3f077837daf4c0086b46d3eef6cce5e532a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Oct 2016 10:58:04 -0500 Subject: [PATCH 38/75] Handle json/jsonb in binary to support CopyTo fixes #189 --- conn.go | 2 +- copy_to_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ query.go | 8 +++++-- values.go | 55 +++++++++++++++++++++++++++++++++++++++++++------ 4 files changed, 111 insertions(+), 9 deletions(-) diff --git a/conn.go b/conn.go index c928ed98..9f15a999 100644 --- a/conn.go +++ b/conn.go @@ -917,7 +917,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid, JsonOid, JsonbOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) diff --git a/copy_to_test.go b/copy_to_test.go index d810c4fb..95b9af2d 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -119,6 +119,61 @@ func TestConnCopyToLarge(t *testing.T) { ensureConnValid(t, conn) } +func TestConnCopyToJSON(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + if _, ok := conn.PgTypes[oid]; !ok { + return // No JSON/JSONB type -- must be running against old PostgreSQL + } + } + + mustExec(t, conn, `create temporary table foo( + a json, + b jsonb + )`) + + inputRows := [][]interface{}{ + {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + {nil, nil}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + func TestConnCopyToFailServerSideMidway(t *testing.T) { t.Parallel() diff --git a/query.go b/query.go index 50c8e290..4e4b8e53 100644 --- a/query.go +++ b/query.go @@ -298,13 +298,17 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { 
rows.Fatal(scanArgError{col: i, err: err}) } - } else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid { + } else if vr.Type().DataType == JsonOid { // Because the argument passed to decodeJSON will escape the heap. // This allows d to be stack allocated and only copied to the heap when // we actually are decoding JSON. This saves one memory allocation per // row. d2 := d decodeJSON(vr, &d2) + } else if vr.Type().DataType == JsonbOid { + // Same trick as above for getting stack allocation + d2 := d + decodeJSONB(vr, &d2) } else { if err := Decode(vr, d); err != nil { rows.Fatal(scanArgError{col: i, err: err}) @@ -393,7 +397,7 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, d) case JsonbOid: var d interface{} - decodeJSON(vr, &d) + decodeJSONB(vr, &d) values = append(values, d) default: rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) diff --git a/values.go b/values.go index 545f1031..31f754ef 100644 --- a/values.go +++ b/values.go @@ -91,23 +91,25 @@ func init() { "bool": BinaryFormatCode, "bytea": BinaryFormatCode, "char": BinaryFormatCode, + "cid": BinaryFormatCode, "cidr": BinaryFormatCode, "date": BinaryFormatCode, "float4": BinaryFormatCode, "float8": BinaryFormatCode, + "json": BinaryFormatCode, + "jsonb": BinaryFormatCode, "inet": BinaryFormatCode, "int2": BinaryFormatCode, "int4": BinaryFormatCode, "int8": BinaryFormatCode, "oid": BinaryFormatCode, - "tid": BinaryFormatCode, - "xid": BinaryFormatCode, - "cid": BinaryFormatCode, "record": BinaryFormatCode, "text": BinaryFormatCode, + "tid": BinaryFormatCode, "timestamp": BinaryFormatCode, "timestamptz": BinaryFormatCode, "varchar": BinaryFormatCode, + "xid": BinaryFormatCode, } } @@ -889,9 +891,12 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return Encode(wbuf, oid, arg) } - if oid == JsonOid || oid == JsonbOid { + if oid == JsonOid { return encodeJSON(wbuf, oid, arg) } + if oid == JsonbOid { + return encodeJSONB(wbuf, oid, arg) + } switch arg := arg.(type) { case []string: @@ -1914,7 +1919,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { return nil } - if vr.Type().DataType != JsonOid && vr.Type().DataType != JsonbOid { + if vr.Type().DataType != JsonOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) } @@ -1927,7 +1932,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { } func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error { - if oid != JsonOid && oid != JsonbOid { + if oid != JsonOid { return fmt.Errorf("cannot encode JSON into oid %v", oid) } @@ -1942,6 +1947,44 @@ func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error { return nil } +func decodeJSONB(vr *ValueReader, d interface{}) error { + if vr.Len() == -1 { + return nil + } + + if vr.Type().DataType != JsonbOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType))) + } + + bytes := vr.ReadBytes(vr.Len()) + if bytes[0] != 1 { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0]))) + } + + err := json.Unmarshal(bytes[1:], d) + if err != nil { + vr.Fatal(err) + } + return err +} + +func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error { + if oid != JsonbOid { + return fmt.Errorf("cannot encode JSON into oid %v", oid) + } + + s, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("Failed to encode json from type: %T", value) + } + + w.WriteInt32(int32(len(s) + 1)) + w.WriteByte(1) // JSONB format header + 
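+	// A binary-format jsonb datum is a one-byte version number (currently
+	// always 1) followed by the JSON text itself, which is why the length
+	// written above is len(s)+1.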
w.WriteBytes(s) + + return nil +} + func decodeDate(vr *ValueReader) time.Time { var zeroTime time.Time From ed2ab0a12942159f0c0b15cc9c2485383db95a54 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Oct 2016 11:06:15 -0500 Subject: [PATCH 39/75] Update changlog --- CHANGELOG.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a9d70ee0..2513bb30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,16 @@ ## Fixes -* Oid underlying type changed to uint32, previously it was incorrectly int32 +* Oid underlying type changed to uint32, previously it was incorrectly int32 (Manni Wood) + +## Features + +* Add xid type support (Manni Wood) +* Add cid type support (Manni Wood) +* Add tid type support (Manni Wood) +* Add "char" type support (Manni Wood) +* Add NullOid type (Manni Wood) +* Add json/jsonb binary support to allow use with CopyTo # 2.9.0 (August 26, 2016) From c8575984d842fe3f99205f166e3429dbe37b8a24 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 1 Oct 2016 13:46:48 -0400 Subject: [PATCH 40/75] Removes name length convenience method --- values.go | 28 ++++++++++++---------------- values_test.go | 25 ------------------------- 2 files changed, 12 insertions(+), 41 deletions(-) diff --git a/values.go b/values.go index d0940421..45f3e9b0 100644 --- a/values.go +++ b/values.go @@ -66,12 +66,6 @@ const maxUint = ^uint(0) const maxInt = int(maxUint >> 1) const minInt = -maxInt - 1 -// NameDataLen is the same as PostgreSQL's NAMEDATALEN, defined in -// src/include/pg_config_manual.h. It is how many bytes long identifiers -// are allowed to be, including the trailing '\0' at the end of C strings. -// (Identifieres are table names, column names, function names, etc.) -const NameDataLen = 64 - // DefaultTypeFormats maps type names to their default requested format (text // or binary). In theory the Scanner interface should be the one to determine // the format of the returned values. However, the query has already been @@ -267,19 +261,21 @@ func (n NullString) Encode(w *WriteBuf, oid Oid) error { return encodeString(w, oid, n.String) } -// The pgx.Name type is for PostgreSQL's special 63-byte +// Name is a type used for PostgreSQL's special 63-byte // name data type, used for identifiers like table names. +// The pg_class.relname column is a good example of where the +// name data type is used. +// +// Note that the underlying Go data type of pgx.Name is string, +// so there is no way to enforce the 63-byte length. Inputting +// a longer name into PostgreSQL will result in silent truncation +// to 63 bytes. +// +// Also, if you have custom-compiled PostgreSQL and set +// NAMEDATALEN to a different value, obviously that number of +// bytes applies, rather than the default 63. type Name string -// LengthOK is a convenience method that returns false if a name is longer -// than PostgreSQL will allow. PostgreSQL identifiers are allowed -// to be 63 bytes long (NAMEDATALEN in the PostgreSQL source code -// is defined as 64 bytes long, but the 64th char is the '\0' C -// string terminator.) -func (n Name) LengthOK() bool { - return len(string(n)) < NameDataLen -} - // NullName represents a pgx.Name that may be null. NullName implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan for prepared and unprepared queries. 
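A minimal sketch makes the silent truncation described in the comment above visible; the showNameTruncation helper, the open *pgx.Conn, and the imports ("fmt", "strings", "github.com/jackc/pgx") are assumptions for the sketch.

	// Sketch only: a 70-byte value sent as name comes back truncated to 63 bytes.
	func showNameTruncation(conn *pgx.Conn) error {
		long := pgx.Name(strings.Repeat("x", 70))
		var got pgx.Name
		if err := conn.QueryRow("select $1::name", long).Scan(&got); err != nil {
			return err
		}
		fmt.Printf("sent %d bytes, got back %d bytes\n", len(long), len(got)) // 70, 63
		return nil
	}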
diff --git a/values_test.go b/values_test.go index 0deb3982..6a92ee30 100644 --- a/values_test.go +++ b/values_test.go @@ -551,28 +551,6 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } } -func TestNameLengthOK(t *testing.T) { - tests := []struct { - input pgx.Name - expected bool - }{ - {"", true}, - {"1234", true}, - {"123456789012345678901234567890123456789012345678901234567890123", true}, - {"1234567890123456789012345678901234567890123456789012345678901234", false}, - } - - var actual bool - - for i, tt := range tests { - actual = tt.input.LengthOK() - - if actual != tt.expected { - t.Errorf("%d. Expected %v, got %v (name -> %v)", i, tt.expected, actual, tt.input) - } - } -} - func TestNullX(t *testing.T) { t.Parallel() @@ -621,9 +599,6 @@ func TestNullX(t *testing.T) { {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, {"select $1::name", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}}, {"select $1::name", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}}, - // bytes past NameDataLen-1 (63 bytes) get silently truncated by PostgreSQL - {"select $1::name", []interface{}{pgx.NullString{String: "1234567890123456789012345678901234567890123456789012345678901234", Valid: true}}, - []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "123456789012345678901234567890123456789012345678901234567890123", Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, From c14c63d63c936f3879573b4c3a4cdf8160ee5484 Mon Sep 17 00:00:00 2001 From: Nathaniel Waisbrot Date: Mon, 3 Oct 2016 08:45:41 -0400 Subject: [PATCH 41/75] Fix test failure when DB and client are not in the same time zone Explicitly set the time zone to UTC in the database and in the test expectation. Then compare the two times in the client-local time zone. --- values_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/values_test.go b/values_test.go index 7de4e0c2..9c5dbfb2 100644 --- a/values_test.go +++ b/values_test.go @@ -1095,11 +1095,11 @@ func TestRowDecode(t *testing.T) { expected []interface{} }{ { - "select row(1, 'cat', '2015-01-01 08:12:42'::timestamptz)", + "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", []interface{}{ int32(1), "cat", - time.Date(2015, 1, 1, 8, 12, 42, 0, time.Local), + time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), }, }, } From f1de186a93b4d1c69a8ac4f27233496d028cef36 Mon Sep 17 00:00:00 2001 From: Alexander Staubo Date: Mon, 3 Oct 2016 15:56:01 -0400 Subject: [PATCH 42/75] Connection pool timeout should return a consistent error value so clients can test for it. 
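A brief sketch of the kind of client-side check this change enables, comparing against the sentinel value instead of matching an error string; the acquireOrReport helper, the pool setup, and the "fmt" import are assumptions for the sketch.

	// Sketch only: assumes a *pgx.ConnPool created with an AcquireTimeout.
	func acquireOrReport(pool *pgx.ConnPool) (*pgx.Conn, error) {
		conn, err := pool.Acquire()
		if err == pgx.ErrAcquireTimeout {
			return nil, fmt.Errorf("pool saturated, retry later: %v", err)
		}
		return conn, err
	}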
--- conn_pool.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index 6fbe143a..eac731dc 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -40,6 +40,9 @@ type ConnPoolStat struct { AvailableConnections int // unused live connections } +// ErrAcquireTimeout occurs when an attempt to acquire a connection times out. +var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool") + // NewConnPool creates a new ConnPool. config.ConnConfig is passed through to // Connect directly. func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { @@ -131,7 +134,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Make sure the deadline (if it is) has not passed yet if p.deadlinePassed(deadline) { - return nil, errors.New("Timeout: Acquire connection timeout") + return nil, ErrAcquireTimeout } // If there is a deadline then start a timeout timer @@ -164,7 +167,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Wait until there is an available connection OR room to create a new connection for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections { if p.deadlinePassed(deadline) { - return nil, errors.New("Timeout: All connections in pool are busy") + return nil, ErrAcquireTimeout } p.cond.Wait() } From bf1cc4dbac6b3982f9c7ec4bc6aba50602a25ffc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Oct 2016 08:42:30 -0500 Subject: [PATCH 43/75] Fix test for new named ErrAcquireTimeout --- conn_pool_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conn_pool_test.go b/conn_pool_test.go index 2163f515..71a361a6 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -276,8 +276,8 @@ func TestPoolWithAcquireTimeoutSet(t *testing.T) { // ... then try to consume 1 more. It should fail after a short timeout. 
_, timeTaken, err := acquireWithTimeTaken(pool) - if err == nil || err.Error() != "Timeout: All connections in pool are busy" { - t.Fatalf("Expected error to be 'Timeout: All connections in pool are busy', instead it was '%v'", err) + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) } if timeTaken < connAllocTimeout { t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) From a62698f8a7e2fd27806443f340ae79c34b2cb3c7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Oct 2016 08:43:39 -0500 Subject: [PATCH 44/75] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2513bb30..4e930b13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ * Add "char" type support (Manni Wood) * Add NullOid type (Manni Wood) * Add json/jsonb binary support to allow use with CopyTo +* Add named error ErrAcquireTimeout (Alexander Staubo) # 2.9.0 (August 26, 2016) From 7dec41fb6d42af32089c03b9681f547164a90a62 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Mon, 10 Oct 2016 20:41:57 -0400 Subject: [PATCH 45/75] Fixes TestNullX test of NullName --- values_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/values_test.go b/values_test.go index fc5a58b3..ab40c65c 100644 --- a/values_test.go +++ b/values_test.go @@ -597,8 +597,8 @@ func TestNullX(t *testing.T) { {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, - {"select $1::name", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}}, - {"select $1::name", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}}, + {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}}, + {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, From 3734a92a71d7831b3c6ee7e4d9754a8dea5d9ab4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Oct 2016 14:27:38 -0500 Subject: [PATCH 46/75] Log TLS connection errors as info when fallback available fixes #198 --- conn.go | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/conn.go b/conn.go index 9f15a999..a5ae3a48 100644 --- a/conn.go +++ b/conn.go @@ -209,12 +209,21 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, 
pgsqlAfInet *byte, pgsql c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial } + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) + } err = c.connect(config, network, address, config.TLSConfig) if err != nil && config.UseFallbackTLS { + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, fmt.Sprintf("Connect with TLSConfig failed, trying FallbackTLSConfig: %v", err)) + } err = c.connect(config, network, address, config.FallbackTLSConfig) } if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, fmt.Sprintf("Connect failed: %v", err)) + } return nil, err } @@ -222,23 +231,14 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsql } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) - } c.conn, err = c.config.Dial(network, address) if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, fmt.Sprintf("Connection failed: %v", err)) - } return err } defer func() { if c != nil && err != nil { c.conn.Close() c.alive = false - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, err.Error()) - } } }() @@ -253,9 +253,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.log(LogLevelDebug, "Starting TLS handshake") } if err := c.startTLS(tlsConfig); err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, fmt.Sprintf("TLS failed: %v", err)) - } return err } } From f73791c6c9511963b6a37a01caf9650878216c0e Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 27 Oct 2016 21:33:56 -0400 Subject: [PATCH 47/75] Adds NullAclItem --- values.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ values_test.go | 3 +++ 2 files changed, 53 insertions(+) diff --git a/values.go b/values.go index f2e43473..12004433 100644 --- a/values.go +++ b/values.go @@ -44,6 +44,7 @@ const ( Int8ArrayOid = 1016 Float4ArrayOid = 1021 Float8ArrayOid = 1022 + AclItemOid = 1033 InetArrayOid = 1041 VarcharOid = 1043 DateOid = 1082 @@ -89,6 +90,7 @@ func init() { "_timestamp": BinaryFormatCode, "_timestamptz": BinaryFormatCode, "_varchar": BinaryFormatCode, + "aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) "bool": BinaryFormatCode, "bytea": BinaryFormatCode, "char": BinaryFormatCode, @@ -263,6 +265,47 @@ func (n NullString) Encode(w *WriteBuf, oid Oid) error { return encodeString(w, oid, n.String) } +// AclItem is used for PostgreSQL's aclitem data type. +type AclItem string + +// NullAclItem represents a pgx.AclItem that may be null. NullAclItem implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. 
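As a usage sketch, an aclitem can be round-tripped through the nullable wrapper introduced here, modeled on the test cases added with this patch; the roundTripAclItem helper and the open *pgx.Conn are assumptions for the sketch.

	// Sketch only: pgx.NullAclItem as both query argument and Scan destination.
	// The referenced role (postgres) must exist, since aclitems cannot name
	// non-existent roles.
	func roundTripAclItem(conn *pgx.Conn) (pgx.NullAclItem, error) {
		in := pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}
		var out pgx.NullAclItem
		err := conn.QueryRow("select $1::aclitem", in).Scan(&out)
		return out, err
	}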
+type NullAclItem struct { + AclItem AclItem + Valid bool // Valid is true if AclItem is not NULL +} + +func (n *NullAclItem) Scan(vr *ValueReader) error { + if vr.Type().DataType != AclItemOid { + return SerializationError(fmt.Sprintf("NullAclItem.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.AclItem, n.Valid = "", false + return nil + } + + n.Valid = true + n.AclItem = AclItem(decodeText(vr)) + return vr.Err() +} + +// Particularly important to return TextFormatCode, seeing as Postgres +// only ever sends aclitem as text, not binary. +func (n NullAclItem) FormatCode() int16 { return TextFormatCode } + +func (n NullAclItem) Encode(w *WriteBuf, oid Oid) error { + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeString(w, oid, string(n.AclItem)) +} + // Name is a type used for PostgreSQL's special 63-byte // name data type, used for identifiers like table names. // The pg_class.relname column is a good example of where the @@ -964,6 +1007,10 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return encodeUInt(wbuf, oid, arg) case Char: return encodeChar(wbuf, oid, arg) + case AclItem: + // The aclitem data type goes over the wire using the same format as string, + // so just cast to string and use encodeString + return encodeString(wbuf, oid, string(arg)) case Name: // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString @@ -1146,6 +1193,9 @@ func Decode(vr *ValueReader, d interface{}) error { *v = uint64(n) case *Char: *v = decodeChar(vr) + case *AclItem: + // aclitem goes over the wire just like text + *v = AclItem(decodeText(vr)) case *Name: // name goes over the wire just like text *v = Name(decodeText(vr)) diff --git a/values_test.go b/values_test.go index ab40c65c..c198c57f 100644 --- a/values_test.go +++ b/values_test.go @@ -562,6 +562,7 @@ func TestNullX(t *testing.T) { i16 pgx.NullInt16 i32 pgx.NullInt32 c pgx.NullChar + a pgx.NullAclItem n pgx.NullName oid pgx.NullOid xid pgx.NullXid @@ -599,6 +600,8 @@ func TestNullX(t *testing.T) { {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}}, {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}}, + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, From df033d499fa22b4140283b39e2e636ca8ba5c2a8 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 27 Oct 2016 
21:57:46 -0400 Subject: [PATCH 48/75] Adds a tricky user to test This allows us to test aclitem encoding with tricky SQL identifiers. The user actually has to exist, or the aclitem will be incorrect. --- .travis.yml | 1 + README.md | 1 + values.go | 13 ++++++++++++- values_test.go | 1 + 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 537dca8d..60f734ad 100644 --- a/.travis.yml +++ b/.travis.yml @@ -36,6 +36,7 @@ before_script: - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + - psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" install: - go get -u github.com/shopspring/decimal diff --git a/README.md b/README.md index cf244b04..b78fbc5c 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ To setup the normal test environment, first install these dependencies: Then run the following SQL: create user pgx_md5 password 'secret'; + create user " tricky, ' } "" \ test user " superuser password 'secret'; create database pgx_test; Connect to database pgx_test and run: diff --git a/values.go b/values.go index 12004433..6cb6e429 100644 --- a/values.go +++ b/values.go @@ -265,7 +265,18 @@ func (n NullString) Encode(w *WriteBuf, oid Oid) error { return encodeString(w, oid, n.String) } -// AclItem is used for PostgreSQL's aclitem data type. +// AclItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// type AclItem string // NullAclItem represents a pgx.AclItem that may be null. 
NullAclItem implements the diff --git a/values_test.go b/values_test.go index c198c57f..7f0571d1 100644 --- a/values_test.go +++ b/values_test.go @@ -600,6 +600,7 @@ func TestNullX(t *testing.T) { {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}}, {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}}, + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, From c1177f292e01fbfa9ca3b4187788df49e2f8657b Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 27 Oct 2016 22:02:12 -0400 Subject: [PATCH 49/75] Adds note on why tricky test user has to actually exist --- .travis.yml | 2 ++ values_test.go | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 60f734ad..b120b33a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -29,6 +29,8 @@ env: - PGVERSION=9.2 - PGVERSION=9.1 +# The tricky test user, below, has to actually exist so that it can be used in a test +# of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. 
before_script: - mv conn_config_test.go.travis conn_config_test.go - psql -U postgres -c 'create database pgx_test' diff --git a/values_test.go b/values_test.go index 7f0571d1..8b85ceef 100644 --- a/values_test.go +++ b/values_test.go @@ -600,9 +600,10 @@ func TestNullX(t *testing.T) { {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}}, {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}}, - {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, + // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks + {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, From 8c76baabbd6d3a2c03dc2791a05e04a9465a0cc3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 28 Oct 2016 15:56:11 -0500 Subject: [PATCH 50/75] Add changelog note of jsonb []byte change fixes #200 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e930b13..bedf106b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ * Add json/jsonb binary support to allow use with CopyTo * Add named error ErrAcquireTimeout (Alexander Staubo) +## Compatibility + +* jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work. 
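A minimal sketch of the migration this note implies: since binary-format jsonb parameters are passed through encoding/json, hand pgx a Go value that json.Marshal can encode rather than pre-serialized bytes. The insertDoc helper, the foo(doc jsonb) table, the open *pgx.Conn, and the imports ("encoding/json", "github.com/jackc/pgx") are assumptions for the sketch.

	// Sketch only: unmarshal raw JSON bytes into a Go value first, then let
	// pgx marshal that value when encoding the jsonb parameter.
	func insertDoc(conn *pgx.Conn, raw []byte) error {
		var doc interface{}
		if err := json.Unmarshal(raw, &doc); err != nil {
			return err
		}
		_, err := conn.Exec("insert into foo(doc) values($1)", doc)
		return err
	}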
+ # 2.9.0 (August 26, 2016) ## Fixes From 84439a13cb4bdb9911455577f500595d0aef0f9a Mon Sep 17 00:00:00 2001 From: ferhat elmas Date: Wed, 9 Nov 2016 00:50:33 +0100 Subject: [PATCH 51/75] Simplify map composite literals as gofmt -s handles --- conn.go | 2 +- hstore_test.go | 26 +++++++++++++------------- stdlib/sql.go | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/conn.go b/conn.go index a5ae3a48..14d89897 100644 --- a/conn.go +++ b/conn.go @@ -427,7 +427,7 @@ func ParseURI(uri string) (ConnConfig, error) { } ignoreKeys := map[string]struct{}{ - "sslmode": struct{}{}, + "sslmode": {}, } cp.RuntimeParams = make(map[string]string) diff --git a/hstore_test.go b/hstore_test.go index dba5206b..c948f0cd 100644 --- a/hstore_test.go +++ b/hstore_test.go @@ -85,23 +85,23 @@ func TestNullHstoreTranscode(t *testing.T) { {pgx.NullHstore{}, "null"}, {pgx.NullHstore{Valid: true}, "empty"}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}}, Valid: true}, "single key/value"}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar", Valid: true}, "baz": pgx.NullString{String: "quz", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}}, Valid: true}, "multiple key/values"}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"NULL": pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}}, Valid: true}, `string "NULL" key`}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "NULL", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}}, Valid: true}, `string "NULL" value`}, {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "", Valid: false}}, + Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}}, Valid: true}, `NULL value`}, } @@ -120,36 +120,36 @@ func TestNullHstoreTranscode(t *testing.T) { } for _, sst := range specialStringTests { tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input + "foo": pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}}, Valid: true}, "key with " + sst.description + " at beginning"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}}, Valid: true}, "key with " + sst.description + " in middle"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input: pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}}, Valid: true}, "key with " + sst.description + " at end"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input: pgx.NullString{String: "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}}, Valid: true}, "key is " + sst.description}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: sst.input + "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}}, 
Valid: true}, "value with " + sst.description + " at beginning"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar" + sst.input + "bar", Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}}, Valid: true}, "value with " + sst.description + " in middle"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: "bar" + sst.input, Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}}, Valid: true}, "value with " + sst.description + " at end"}) tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": pgx.NullString{String: sst.input, Valid: true}}, + Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}}, Valid: true}, "value is " + sst.description}) } diff --git a/stdlib/sql.go b/stdlib/sql.go index 5bf2c113..610aefd4 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -230,7 +230,7 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er // text format so that pgx.Rows.Values doesn't decode it into a native type // (e.g. []int32) func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { - for i, _ := range ps.FieldDescriptions { + for i := range ps.FieldDescriptions { intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType] if !intrinsic { ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode From c9292c44e604ff434b39af08d0a250c877fc00f6 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 11:42:07 -0500 Subject: [PATCH 52/75] Adds aclitem[] len 1 ability --- values.go | 18 ++++++++++++++++++ values_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/values.go b/values.go index 6cb6e429..2f64f2d5 100644 --- a/values.go +++ b/values.go @@ -45,6 +45,7 @@ const ( Float4ArrayOid = 1021 Float8ArrayOid = 1022 AclItemOid = 1033 + AclItemArrayOid = 1034 InetArrayOid = 1041 VarcharOid = 1043 DateOid = 1082 @@ -77,6 +78,7 @@ var DefaultTypeFormats map[string]int16 func init() { DefaultTypeFormats = map[string]int16{ + "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) "_bool": BinaryFormatCode, "_bytea": BinaryFormatCode, "_cidr": BinaryFormatCode, @@ -981,6 +983,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return Encode(wbuf, oid, v) case string: return encodeString(wbuf, oid, arg) + case []AclItem: + return encodeAclItemSlice(wbuf, oid, arg) case []byte: return encodeByteSlice(wbuf, oid, arg) case [][]byte: @@ -1224,6 +1228,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeFloat4(vr) case *float64: *v = decodeFloat8(vr) + case *[]AclItem: + *v = decodeAclItemArray(vr) case *[]bool: *v = decodeBoolArray(vr) case *[]int16: @@ -2993,6 +2999,18 @@ func decodeTextArray(vr *ValueReader) []string { return a } +// XXX: encodeAclItemSlice; using text encoding, not binary +func encodeAclItemSlice(w *WriteBuf, oid Oid, value []AclItem) error { + w.WriteInt32(int32(len("{=r/postgres}"))) + w.WriteBytes([]byte("{=r/postgres}")) + return nil +} + +// XXX: decodeAclItemArray; using text encoding, not binary +func decodeAclItemArray(vr *ValueReader) []AclItem { + return []AclItem{"=r/postgres"} +} + func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { var elOid Oid switch oid { diff --git a/values_test.go b/values_test.go index 
8b85ceef..c2a89d79 100644 --- a/values_test.go +++ b/values_test.go @@ -643,6 +643,42 @@ func TestNullX(t *testing.T) { } } +func TestAclArrayDecoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + tests := []struct { + sql string + query interface{} + scan interface{} + assert func(*testing.T, interface{}, interface{}) + }{ + { + "select $1::aclitem[]", + []pgx.AclItem{"=r/postgres"}, + &[]pgx.AclItem{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]pgx.AclItem))) { + t.Errorf("failed to encode aclitem[]") + } + }, + }, + } + for i, tt := range tests { + err := conn.QueryRow(tt.sql, tt.query).Scan(tt.scan) + if err != nil { + t.Errorf(`%d. error reading array: %v`, i, err) + if pgerr, ok := err.(pgx.PgError); ok { + t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail) + } + continue + } + tt.assert(t, tt.query, tt.scan) + ensureConnValid(t, conn) + } +} + func TestArrayDecoding(t *testing.T) { t.Parallel() From a80ef6d35fbe82480d58b1278b8fe55af5bd12e0 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 11:46:07 -0500 Subject: [PATCH 53/75] Actually takes the first arg --- values.go | 5 +++-- values_test.go | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/values.go b/values.go index 2f64f2d5..0723189c 100644 --- a/values.go +++ b/values.go @@ -3001,8 +3001,9 @@ func decodeTextArray(vr *ValueReader) []string { // XXX: encodeAclItemSlice; using text encoding, not binary func encodeAclItemSlice(w *WriteBuf, oid Oid, value []AclItem) error { - w.WriteInt32(int32(len("{=r/postgres}"))) - w.WriteBytes([]byte("{=r/postgres}")) + str := "{" + value[0] + "}" + w.WriteInt32(int32(len(str))) + w.WriteBytes([]byte(str)) return nil } diff --git a/values_test.go b/values_test.go index c2a89d79..92ec01a6 100644 --- a/values_test.go +++ b/values_test.go @@ -643,6 +643,7 @@ func TestNullX(t *testing.T) { } } +// XXX func TestAclArrayDecoding(t *testing.T) { t.Parallel() From 36bdbd7cb105417f883270555f853b2cc17c43b5 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 11:56:04 -0500 Subject: [PATCH 54/75] Parses actual return string ...but only handles aclitem[] size 1 --- values.go | 11 ++++++++++- values_test.go | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/values.go b/values.go index 0723189c..251671bc 100644 --- a/values.go +++ b/values.go @@ -3009,7 +3009,16 @@ func encodeAclItemSlice(w *WriteBuf, oid Oid, value []AclItem) error { // XXX: decodeAclItemArray; using text encoding, not binary func decodeAclItemArray(vr *ValueReader) []AclItem { - return []AclItem{"=r/postgres"} + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into []AclItem")) + return nil + } + + str := vr.ReadString(vr.Len()) + // remove the '{' at the front and the '}' at the end + str = str[1 : len(str)-1] + return []AclItem{AclItem(str)} + // return []AclItem{"=r/postgres"} } func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { diff --git a/values_test.go b/values_test.go index 92ec01a6..326335e0 100644 --- a/values_test.go +++ b/values_test.go @@ -661,7 +661,7 @@ func TestAclArrayDecoding(t *testing.T) { &[]pgx.AclItem{}, func(t *testing.T, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]pgx.AclItem))) { - t.Errorf("failed to encode aclitem[]") + t.Errorf("failed to encode aclitem[]\n EXPECTED: %v\n ACTUAL: %v", query, *(scan.(*[]pgx.AclItem))) } }, }, From 7d7bc873964b2dedd64a24a5769e92e2381d6847 
Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 12:01:03 -0500 Subject: [PATCH 55/75] Moves sql outside of struct --- values.go | 1 - values_test.go | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/values.go b/values.go index 251671bc..e0d56351 100644 --- a/values.go +++ b/values.go @@ -3018,7 +3018,6 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { // remove the '{' at the front and the '}' at the end str = str[1 : len(str)-1] return []AclItem{AclItem(str)} - // return []AclItem{"=r/postgres"} } func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { diff --git a/values_test.go b/values_test.go index 326335e0..4c62007c 100644 --- a/values_test.go +++ b/values_test.go @@ -649,14 +649,14 @@ func TestAclArrayDecoding(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) + + sql := "select $1::aclitem[]" tests := []struct { - sql string query interface{} scan interface{} assert func(*testing.T, interface{}, interface{}) }{ { - "select $1::aclitem[]", []pgx.AclItem{"=r/postgres"}, &[]pgx.AclItem{}, func(t *testing.T, query, scan interface{}) { @@ -667,7 +667,7 @@ func TestAclArrayDecoding(t *testing.T) { }, } for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.query).Scan(tt.scan) + err := conn.QueryRow(sql, tt.query).Scan(tt.scan) if err != nil { t.Errorf(`%d. error reading array: %v`, i, err) if pgerr, ok := err.(pgx.PgError); ok { From d9ab2197539c064100e1ae4a7e0ca8dc36156217 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 12:07:48 -0500 Subject: [PATCH 56/75] Pulls out aclitem[] assert func --- values_test.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/values_test.go b/values_test.go index 4c62007c..46652d4d 100644 --- a/values_test.go +++ b/values_test.go @@ -643,6 +643,12 @@ func TestNullX(t *testing.T) { } } +func assertAclItemSlicesEqual(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]pgx.AclItem))) { + t.Errorf("failed to encode aclitem[]\n EXPECTED: %v\n ACTUAL: %v", query, *(scan.(*[]pgx.AclItem))) + } +} + // XXX func TestAclArrayDecoding(t *testing.T) { t.Parallel() @@ -652,18 +658,12 @@ func TestAclArrayDecoding(t *testing.T) { sql := "select $1::aclitem[]" tests := []struct { - query interface{} - scan interface{} - assert func(*testing.T, interface{}, interface{}) + query interface{} + scan interface{} }{ { []pgx.AclItem{"=r/postgres"}, &[]pgx.AclItem{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]pgx.AclItem))) { - t.Errorf("failed to encode aclitem[]\n EXPECTED: %v\n ACTUAL: %v", query, *(scan.(*[]pgx.AclItem))) - } - }, }, } for i, tt := range tests { @@ -675,7 +675,7 @@ func TestAclArrayDecoding(t *testing.T) { } continue } - tt.assert(t, tt.query, tt.scan) + assertAclItemSlicesEqual(t, tt.query, tt.scan) ensureConnValid(t, conn) } } From 104c01df218e7a68cb25e5a87ddbd08942f3082d Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 12:28:31 -0500 Subject: [PATCH 57/75] Handles aclitem lists of 1+ --- values.go | 20 +++++++++++++++++--- values_test.go | 4 ++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/values.go b/values.go index e0d56351..395b59ce 100644 --- a/values.go +++ b/values.go @@ -3000,8 +3000,15 @@ func decodeTextArray(vr *ValueReader) []string { } // XXX: encodeAclItemSlice; using text encoding, not binary -func encodeAclItemSlice(w *WriteBuf, oid Oid, value []AclItem) error { - str := "{" + 
value[0] + "}" +func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { + // cast aclitems into strings so we can use strings.Join + strs := make([]string, len(aclitems)) + for i := range strs { + strs[i] = string(aclitems[i]) + } + + str := strings.Join(strs, ",") + str = "{" + str + "}" w.WriteInt32(int32(len(str))) w.WriteBytes([]byte(str)) return nil @@ -3017,7 +3024,14 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { str := vr.ReadString(vr.Len()) // remove the '{' at the front and the '}' at the end str = str[1 : len(str)-1] - return []AclItem{AclItem(str)} + strs := strings.Split(str, ",") + + // cast strings into AclItems before returning + aclitems := make([]AclItem, len(strs)) + for i := range aclitems { + aclitems[i] = AclItem(strs[i]) + } + return aclitems } func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { diff --git a/values_test.go b/values_test.go index 46652d4d..83244b1b 100644 --- a/values_test.go +++ b/values_test.go @@ -665,6 +665,10 @@ func TestAclArrayDecoding(t *testing.T) { []pgx.AclItem{"=r/postgres"}, &[]pgx.AclItem{}, }, + { + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, + &[]pgx.AclItem{}, + }, } for i, tt := range tests { err := conn.QueryRow(sql, tt.query).Scan(tt.scan) From 96b652cc95e53d52315dcb33e5a3ea69bc13f19e Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 12:36:55 -0500 Subject: [PATCH 58/75] Makes aclitem test types more specific --- values_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/values_test.go b/values_test.go index 83244b1b..662980c2 100644 --- a/values_test.go +++ b/values_test.go @@ -643,9 +643,9 @@ func TestNullX(t *testing.T) { } } -func assertAclItemSlicesEqual(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]pgx.AclItem))) { - t.Errorf("failed to encode aclitem[]\n EXPECTED: %v\n ACTUAL: %v", query, *(scan.(*[]pgx.AclItem))) +func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) { + if !reflect.DeepEqual(query, scan) { + t.Errorf("failed to encode aclitem[]\n EXPECTED: %v\n ACTUAL: %v", query, scan) } } @@ -658,20 +658,20 @@ func TestAclArrayDecoding(t *testing.T) { sql := "select $1::aclitem[]" tests := []struct { - query interface{} - scan interface{} + query []pgx.AclItem + scan []pgx.AclItem }{ { []pgx.AclItem{"=r/postgres"}, - &[]pgx.AclItem{}, + []pgx.AclItem{}, }, { []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, - &[]pgx.AclItem{}, + []pgx.AclItem{}, }, } for i, tt := range tests { - err := conn.QueryRow(sql, tt.query).Scan(tt.scan) + err := conn.QueryRow(sql, tt.query).Scan(&tt.scan) if err != nil { t.Errorf(`%d. 
error reading array: %v`, i, err) if pgerr, ok := err.(pgx.PgError); ok { From b12a1bb8bccd00071a22e94af41fc0dc11ff6525 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 12:38:30 -0500 Subject: [PATCH 59/75] Removes scan from test struct --- values_test.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/values_test.go b/values_test.go index 662980c2..17a98197 100644 --- a/values_test.go +++ b/values_test.go @@ -657,21 +657,20 @@ func TestAclArrayDecoding(t *testing.T) { defer closeConn(t, conn) sql := "select $1::aclitem[]" + var scan []pgx.AclItem + tests := []struct { query []pgx.AclItem - scan []pgx.AclItem }{ { []pgx.AclItem{"=r/postgres"}, - []pgx.AclItem{}, }, { []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, - []pgx.AclItem{}, }, } for i, tt := range tests { - err := conn.QueryRow(sql, tt.query).Scan(&tt.scan) + err := conn.QueryRow(sql, tt.query).Scan(&scan) if err != nil { t.Errorf(`%d. error reading array: %v`, i, err) if pgerr, ok := err.(pgx.PgError); ok { @@ -679,7 +678,7 @@ func TestAclArrayDecoding(t *testing.T) { } continue } - assertAclItemSlicesEqual(t, tt.query, tt.scan) + assertAclItemSlicesEqual(t, tt.query, scan) ensureConnValid(t, conn) } } From 9b8e3043baf38a70b943e89e22b7b61a77a0c7f7 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sat, 12 Nov 2016 12:46:31 -0500 Subject: [PATCH 60/75] Handles empty aclitems --- values.go | 6 ++++++ values_test.go | 3 +++ 2 files changed, 9 insertions(+) diff --git a/values.go b/values.go index 395b59ce..ff2dfbfb 100644 --- a/values.go +++ b/values.go @@ -3022,6 +3022,12 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { } str := vr.ReadString(vr.Len()) + + // short-circuit empty array + if str == "{}" { + return []AclItem{} + } + // remove the '{' at the front and the '}' at the end str = str[1 : len(str)-1] strs := strings.Split(str, ",") diff --git a/values_test.go b/values_test.go index 17a98197..4cbe40ea 100644 --- a/values_test.go +++ b/values_test.go @@ -662,6 +662,9 @@ func TestAclArrayDecoding(t *testing.T) { tests := []struct { query []pgx.AclItem }{ + { + []pgx.AclItem{}, + }, { []pgx.AclItem{"=r/postgres"}, }, From 4ba4d0097a929af934958aedc052e8302b3a484f Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Sun, 13 Nov 2016 18:08:36 -0500 Subject: [PATCH 61/75] Gets formatting correct for tricky ingoing string ...but broken for outgoing string; must fix next --- values_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/values_test.go b/values_test.go index 4cbe40ea..01e5114b 100644 --- a/values_test.go +++ b/values_test.go @@ -645,7 +645,7 @@ func TestNullX(t *testing.T) { func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) { if !reflect.DeepEqual(query, scan) { - t.Errorf("failed to encode aclitem[]\n EXPECTED: %v\n ACTUAL: %v", query, scan) + t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan) } } @@ -671,6 +671,9 @@ func TestAclArrayDecoding(t *testing.T) { { []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, }, + { + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/\" tricky\, ' \} \"\" \\ test user \"`}, + }, } for i, tt := range tests { err := conn.QueryRow(sql, tt.query).Scan(&scan) From 5712d02e1b58313997f84705de92966505584210 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Tue, 15 Nov 2016 21:53:22 -0500 Subject: [PATCH 62/75] Gets tricky acl parsing working --- values.go | 152 
++++++++++++++++++++++++++++++++++++++++++++++++- values_test.go | 5 +- 2 files changed, 153 insertions(+), 4 deletions(-) diff --git a/values.go b/values.go index ff2dfbfb..fe8f82fa 100644 --- a/values.go +++ b/values.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "io" "math" "net" "reflect" @@ -2999,12 +3000,42 @@ func decodeTextArray(vr *ValueReader) []string { return a } +func EscapeAclItem(acl string) (string, error) { + var buf bytes.Buffer + r := strings.NewReader(acl) + for { + rn, _, err := r.ReadRune() + if err != nil { + if err == io.EOF { + // This error was expected and is OK + return buf.String(), nil + } + // This error was not expected + return "", err + } + if NeedsEscape(rn) { + buf.WriteRune('\\') + } + buf.WriteRune(rn) + } +} + +func NeedsEscape(rn rune) bool { + return rn == '\\' || rn == ',' || rn == '"' || rn == '}' +} + // XXX: encodeAclItemSlice; using text encoding, not binary func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { // cast aclitems into strings so we can use strings.Join strs := make([]string, len(aclitems)) + var escaped string + var err error for i := range strs { - strs[i] = string(aclitems[i]) + escaped, err = EscapeAclItem(string(aclitems[i])) + if err != nil { + return err + } + strs[i] = string(escaped) } str := strings.Join(strs, ",") @@ -3014,6 +3045,121 @@ func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { return nil } +func ParseAclItemArray(arr string) ([]string, error) { + r := strings.NewReader(arr) + // Difficult to guess a performant initial capacity for a slice of + // values, but let's go with 5. + vals := make([]string, 0, 5) + // A single value + vlu := "" + for { + // Grab the first/next/last rune to see if we are dealing with a + // quoted value, an unquoted value, or the end of the string. + rn, _, err := r.ReadRune() + if err != nil { + if err == io.EOF { + // This error was expected and is OK + return vals, nil + } + // This error was not expected + return nil, err + } + + if rn == '"' { + // Discard the opening quote of the quoted value. + vlu, err = ParseQuotedAclItem(r) + } else { + // We have just read the first rune of an unquoted (bare) value; + // put it back so that ParseBareValue can read it. + err := r.UnreadRune() + if err != nil { + // This error was not expected. + return nil, err + } + vlu, err = ParseBareAclItem(r) + } + + if err != nil { + if err == io.EOF { + // This error was expected and is OK. + vals = append(vals, vlu) + return vals, nil + } + // This error was not expected. + return nil, err + } + vals = append(vals, vlu) + } +} + +func ParseBareAclItem(r *strings.Reader) (string, error) { + var buf bytes.Buffer + for { + rn, _, err := r.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF. + // (io.EOF marks the end of a bare value at the end of a string) + return buf.String(), err + } + if rn == ',' { + // A comma marks the end of a bare value. + return buf.String(), nil + } else { + buf.WriteRune(rn) + } + } +} + +func ParseQuotedAclItem(r *strings.Reader) (string, error) { + var buf bytes.Buffer + for { + rn, escaped, err := ReadPossiblyEscapedRune(r) + if err != nil { + if err == io.EOF { + // Even when it is the last value, the final rune of + // a quoted value should be the final closing quote, not io.EOF. + return "", fmt.Errorf("unexpected end of quoted value") + } + // Return the read value in case the error is a harmless io.EOF. 
+ return buf.String(), err + } + if !escaped && rn == '"' { + // An unescaped double quote marks the end of a quoted value. + // The next rune should either be a comma or the end of the string. + rn, _, err := r.ReadRune() + if err != nil { + // Return the read value in case the error is a harmless io.EOF. + return buf.String(), err + } + if rn != ',' { + return "", fmt.Errorf("unexpected rune after quoted value") + } + return buf.String(), nil + } + buf.WriteRune(rn) + } +} + +// Returns the next rune from r, unless it is a backslash; +// in that case, it returns the rune after the backslash. The second +// return value tells us whether or not the rune was +// preceeded by a backslash (escaped). +func ReadPossiblyEscapedRune(r *strings.Reader) (rune, bool, error) { + rn, _, err := r.ReadRune() + if err != nil { + return 0, false, err + } + if rn == '\\' { + // Discard the backslash and read the next rune. + rn, _, err = r.ReadRune() + if err != nil { + return 0, false, err + } + return rn, true, nil + } + return rn, false, nil +} + // XXX: decodeAclItemArray; using text encoding, not binary func decodeAclItemArray(vr *ValueReader) []AclItem { if vr.Len() == -1 { @@ -3030,7 +3176,9 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { // remove the '{' at the front and the '}' at the end str = str[1 : len(str)-1] - strs := strings.Split(str, ",") + strs, _ := ParseAclItemArray(str) + // XXX: what do I do with the error here? + // XXX strs := strings.Split(str, ",") // cast strings into AclItems before returning aclitems := make([]AclItem, len(strs)) diff --git a/values_test.go b/values_test.go index 01e5114b..8c5d1032 100644 --- a/values_test.go +++ b/values_test.go @@ -672,13 +672,14 @@ func TestAclArrayDecoding(t *testing.T) { []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, }, { - []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/\" tricky\, ' \} \"\" \\ test user \"`}, + []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`}, }, } for i, tt := range tests { err := conn.QueryRow(sql, tt.query).Scan(&scan) if err != nil { - t.Errorf(`%d. error reading array: %v`, i, err) + // t.Errorf(`%d. error reading array: %v`, i, err) + t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query) if pgerr, ok := err.(pgx.PgError); ok { t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail) } From 6ec7e84dbfe77cf613aab7077cc532a0daeda80b Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Tue, 15 Nov 2016 22:05:52 -0500 Subject: [PATCH 63/75] Handles parse error for aclitem[] --- values.go | 8 +++++--- values_test.go | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/values.go b/values.go index fe8f82fa..cf3a0266 100644 --- a/values.go +++ b/values.go @@ -3176,9 +3176,11 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { // remove the '{' at the front and the '}' at the end str = str[1 : len(str)-1] - strs, _ := ParseAclItemArray(str) - // XXX: what do I do with the error here? 
- // XXX strs := strings.Split(str, ",") + strs, err := ParseAclItemArray(str) + if err != nil { + vr.Fatal(ProtocolError(err.Error())) + return nil + } // cast strings into AclItems before returning aclitems := make([]AclItem, len(strs)) diff --git a/values_test.go b/values_test.go index 8c5d1032..bbb22f24 100644 --- a/values_test.go +++ b/values_test.go @@ -649,7 +649,6 @@ func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) { } } -// XXX func TestAclArrayDecoding(t *testing.T) { t.Parallel() From 1ebcbab8a30b72bb9568839794116ccac0121bf1 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Tue, 15 Nov 2016 22:09:55 -0500 Subject: [PATCH 64/75] Removes unneeded XXXs --- values.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/values.go b/values.go index cf3a0266..61aaeb81 100644 --- a/values.go +++ b/values.go @@ -3024,7 +3024,6 @@ func NeedsEscape(rn rune) bool { return rn == '\\' || rn == ',' || rn == '"' || rn == '}' } -// XXX: encodeAclItemSlice; using text encoding, not binary func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { // cast aclitems into strings so we can use strings.Join strs := make([]string, len(aclitems)) @@ -3160,7 +3159,6 @@ func ReadPossiblyEscapedRune(r *strings.Reader) (rune, bool, error) { return rn, false, nil } -// XXX: decodeAclItemArray; using text encoding, not binary func decodeAclItemArray(vr *ValueReader) []AclItem { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into []AclItem")) From 7b3488b088492d493be71487c28209cb7b005feb Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Tue, 15 Nov 2016 22:14:08 -0500 Subject: [PATCH 65/75] Makes parseAclItemArray helpers private --- values.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/values.go b/values.go index 61aaeb81..d8ad8fb3 100644 --- a/values.go +++ b/values.go @@ -3044,7 +3044,7 @@ func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { return nil } -func ParseAclItemArray(arr string) ([]string, error) { +func parseAclItemArray(arr string) ([]string, error) { r := strings.NewReader(arr) // Difficult to guess a performant initial capacity for a slice of // values, but let's go with 5. @@ -3066,7 +3066,7 @@ func ParseAclItemArray(arr string) ([]string, error) { if rn == '"' { // Discard the opening quote of the quoted value. - vlu, err = ParseQuotedAclItem(r) + vlu, err = parseQuotedAclItem(r) } else { // We have just read the first rune of an unquoted (bare) value; // put it back so that ParseBareValue can read it. @@ -3075,7 +3075,7 @@ func ParseAclItemArray(arr string) ([]string, error) { // This error was not expected. 
return nil, err } - vlu, err = ParseBareAclItem(r) + vlu, err = parseBareAclItem(r) } if err != nil { @@ -3091,7 +3091,7 @@ func ParseAclItemArray(arr string) ([]string, error) { } } -func ParseBareAclItem(r *strings.Reader) (string, error) { +func parseBareAclItem(r *strings.Reader) (string, error) { var buf bytes.Buffer for { rn, _, err := r.ReadRune() @@ -3109,10 +3109,10 @@ func ParseBareAclItem(r *strings.Reader) (string, error) { } } -func ParseQuotedAclItem(r *strings.Reader) (string, error) { +func parseQuotedAclItem(r *strings.Reader) (string, error) { var buf bytes.Buffer for { - rn, escaped, err := ReadPossiblyEscapedRune(r) + rn, escaped, err := readPossiblyEscapedRune(r) if err != nil { if err == io.EOF { // Even when it is the last value, the final rune of @@ -3143,7 +3143,7 @@ func ParseQuotedAclItem(r *strings.Reader) (string, error) { // in that case, it returns the rune after the backslash. The second // return value tells us whether or not the rune was // preceeded by a backslash (escaped). -func ReadPossiblyEscapedRune(r *strings.Reader) (rune, bool, error) { +func readPossiblyEscapedRune(r *strings.Reader) (rune, bool, error) { rn, _, err := r.ReadRune() if err != nil { return 0, false, err @@ -3174,7 +3174,7 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { // remove the '{' at the front and the '}' at the end str = str[1 : len(str)-1] - strs, err := ParseAclItemArray(str) + strs, err := parseAclItemArray(str) if err != nil { vr.Fatal(ProtocolError(err.Error())) return nil From 323e2b3f7817715c392cc848417c485bf5564bfd Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Tue, 15 Nov 2016 22:22:57 -0500 Subject: [PATCH 66/75] Adds aclitem helper func tests --- aclitem_parse_test.go | 126 ++++++++++++++++++++++++++++++++++++++++++ values.go | 4 +- 2 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 aclitem_parse_test.go diff --git a/aclitem_parse_test.go b/aclitem_parse_test.go new file mode 100644 index 00000000..a0e5e858 --- /dev/null +++ b/aclitem_parse_test.go @@ -0,0 +1,126 @@ +package pgx + +import ( + "reflect" + "testing" +) + +func TestEscapeAclItem(t *testing.T) { + tests := []struct { + input string + expected string + }{ + { + "foo", + "foo", + }, + { + `foo, "\}`, + `foo\, \"\\\}`, + }, + } + + for i, tt := range tests { + actual, err := escapeAclItem(tt.input) + + if err != nil { + t.Errorf("%d. 
Unexpected error %v", i, err) + } + + if actual != tt.expected { + t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual) + } + } +} + +func TestParseAclItemArray(t *testing.T) { + tests := []struct { + input string + expected []string + errMsg string + }{ + { + "", + []string{}, + "", + }, + { + "one", + []string{"one"}, + "", + }, + { + `"one"`, + []string{"one"}, + "", + }, + { + "one,two,three", + []string{"one", "two", "three"}, + "", + }, + { + `"one","two","three"`, + []string{"one", "two", "three"}, + "", + }, + { + `"one",two,"three"`, + []string{"one", "two", "three"}, + "", + }, + { + `one,two,"three"`, + []string{"one", "two", "three"}, + "", + }, + { + `"one","two",three`, + []string{"one", "two", "three"}, + "", + }, + { + `"one","t w o",three`, + []string{"one", "t w o", "three"}, + "", + }, + { + `"one","t, w o\"\}\\",three`, + []string{"one", `t, w o"}\`, "three"}, + "", + }, + { + `"one","two",three"`, + []string{"one", "two", `three"`}, + "", + }, + { + `"one","two,"three"`, + nil, + "unexpected rune after quoted value", + }, + { + `"one","two","three`, + nil, + "unexpected end of quoted value", + }, + } + + for i, tt := range tests { + actual, err := parseAclItemArray(tt.input) + + if err != nil { + if tt.errMsg == "" { + t.Errorf("%d. Unexpected error %v", i, err) + } else if err.Error() != tt.errMsg { + t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error()) + } + } else if tt.errMsg != "" { + t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg) + } + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual) + } + } +} diff --git a/values.go b/values.go index d8ad8fb3..a209ec80 100644 --- a/values.go +++ b/values.go @@ -3000,7 +3000,7 @@ func decodeTextArray(vr *ValueReader) []string { return a } -func EscapeAclItem(acl string) (string, error) { +func escapeAclItem(acl string) (string, error) { var buf bytes.Buffer r := strings.NewReader(acl) for { @@ -3030,7 +3030,7 @@ func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { var escaped string var err error for i := range strs { - escaped, err = EscapeAclItem(string(aclitems[i])) + escaped, err = escapeAclItem(string(aclitems[i])) if err != nil { return err } From 4b430a254ec0bfc251270e639b1099000711ac54 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 17 Nov 2016 21:38:00 -0500 Subject: [PATCH 67/75] Improves docs around aclitem[] --- values.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/values.go b/values.go index a209ec80..4312a61f 100644 --- a/values.go +++ b/values.go @@ -3000,6 +3000,12 @@ func decodeTextArray(vr *ValueReader) []string { return a } +// escapeAclItem escapes an AclItem before it is added to +// its aclitem[] string representation. The PostgreSQL aclitem +// datatype itself can need escapes because it follows the +// formatting rules of SQL identifiers. Think of this function +// as escaping the escapes, so that PostgreSQL's array parser +// will do the right thing. 
func escapeAclItem(acl string) (string, error) { var buf bytes.Buffer r := strings.NewReader(acl) @@ -3013,17 +3019,22 @@ func escapeAclItem(acl string) (string, error) { // This error was not expected return "", err } - if NeedsEscape(rn) { + if needsEscape(rn) { buf.WriteRune('\\') } buf.WriteRune(rn) } } -func NeedsEscape(rn rune) bool { +// needsEscape determines whether or not a rune needs escaping +// before being placed in the textual representation of an +// aclitem[] array. +func needsEscape(rn rune) bool { return rn == '\\' || rn == ',' || rn == '"' || rn == '}' } +// encodeAclItemSlice encodes a slice of AclItems in +// their textual represention for PostgreSQL. func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { // cast aclitems into strings so we can use strings.Join strs := make([]string, len(aclitems)) @@ -3044,10 +3055,12 @@ func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { return nil } +// parseAclItemArray parses the textual representation +// of the aclitem[] type. func parseAclItemArray(arr string) ([]string, error) { r := strings.NewReader(arr) // Difficult to guess a performant initial capacity for a slice of - // values, but let's go with 5. + // aclitems, but let's go with 5. vals := make([]string, 0, 5) // A single value vlu := "" From 3906f7c0d084c178ee8155abd9b8e570d5d43f25 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 17 Nov 2016 21:45:46 -0500 Subject: [PATCH 68/75] Casts aclitem earl to avoid O(2n) --- aclitem_parse_test.go | 24 ++++++++++++------------ values.go | 16 +++++----------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/aclitem_parse_test.go b/aclitem_parse_test.go index a0e5e858..5c7c748f 100644 --- a/aclitem_parse_test.go +++ b/aclitem_parse_test.go @@ -36,62 +36,62 @@ func TestEscapeAclItem(t *testing.T) { func TestParseAclItemArray(t *testing.T) { tests := []struct { input string - expected []string + expected []AclItem errMsg string }{ { "", - []string{}, + []AclItem{}, "", }, { "one", - []string{"one"}, + []AclItem{"one"}, "", }, { `"one"`, - []string{"one"}, + []AclItem{"one"}, "", }, { "one,two,three", - []string{"one", "two", "three"}, + []AclItem{"one", "two", "three"}, "", }, { `"one","two","three"`, - []string{"one", "two", "three"}, + []AclItem{"one", "two", "three"}, "", }, { `"one",two,"three"`, - []string{"one", "two", "three"}, + []AclItem{"one", "two", "three"}, "", }, { `one,two,"three"`, - []string{"one", "two", "three"}, + []AclItem{"one", "two", "three"}, "", }, { `"one","two",three`, - []string{"one", "two", "three"}, + []AclItem{"one", "two", "three"}, "", }, { `"one","t w o",three`, - []string{"one", "t w o", "three"}, + []AclItem{"one", "t w o", "three"}, "", }, { `"one","t, w o\"\}\\",three`, - []string{"one", `t, w o"}\`, "three"}, + []AclItem{"one", `t, w o"}\`, "three"}, "", }, { `"one","two",three"`, - []string{"one", "two", `three"`}, + []AclItem{"one", "two", `three"`}, "", }, { diff --git a/values.go b/values.go index 4312a61f..490d2f38 100644 --- a/values.go +++ b/values.go @@ -3057,11 +3057,11 @@ func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { // parseAclItemArray parses the textual representation // of the aclitem[] type. -func parseAclItemArray(arr string) ([]string, error) { +func parseAclItemArray(arr string) ([]AclItem, error) { r := strings.NewReader(arr) // Difficult to guess a performant initial capacity for a slice of // aclitems, but let's go with 5. 
- vals := make([]string, 0, 5) + vals := make([]AclItem, 0, 5) // A single value vlu := "" for { @@ -3094,13 +3094,13 @@ func parseAclItemArray(arr string) ([]string, error) { if err != nil { if err == io.EOF { // This error was expected and is OK. - vals = append(vals, vlu) + vals = append(vals, AclItem(vlu)) return vals, nil } // This error was not expected. return nil, err } - vals = append(vals, vlu) + vals = append(vals, AclItem(vlu)) } } @@ -3192,13 +3192,7 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { vr.Fatal(ProtocolError(err.Error())) return nil } - - // cast strings into AclItems before returning - aclitems := make([]AclItem, len(strs)) - for i := range aclitems { - aclitems[i] = AclItem(strs[i]) - } - return aclitems + return strs } func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { From bce83fd4ba2f29c5528fc688f49c9df952205c1f Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 17 Nov 2016 21:59:05 -0500 Subject: [PATCH 69/75] Better names and efficiency --- values.go | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/values.go b/values.go index 490d2f38..c10ef782 100644 --- a/values.go +++ b/values.go @@ -3007,22 +3007,22 @@ func decodeTextArray(vr *ValueReader) []string { // as escaping the escapes, so that PostgreSQL's array parser // will do the right thing. func escapeAclItem(acl string) (string, error) { - var buf bytes.Buffer - r := strings.NewReader(acl) + var escapedAclItem bytes.Buffer + reader := strings.NewReader(acl) for { - rn, _, err := r.ReadRune() + rn, _, err := reader.ReadRune() if err != nil { if err == io.EOF { - // This error was expected and is OK - return buf.String(), nil + // Here, EOF is an expected end state, not an error. + return escapedAclItem.String(), nil } // This error was not expected return "", err } if needsEscape(rn) { - buf.WriteRune('\\') + escapedAclItem.WriteRune('\\') } - buf.WriteRune(rn) + escapedAclItem.WriteRune(rn) } } @@ -3038,18 +3038,21 @@ func needsEscape(rn rune) bool { func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { // cast aclitems into strings so we can use strings.Join strs := make([]string, len(aclitems)) - var escaped string + var escapedAclItem string var err error for i := range strs { - escaped, err = escapeAclItem(string(aclitems[i])) + escapedAclItem, err = escapeAclItem(string(aclitems[i])) if err != nil { return err } - strs[i] = string(escaped) + strs[i] = string(escapedAclItem) } - str := strings.Join(strs, ",") - str = "{" + str + "}" + var buf bytes.Buffer + buf.WriteRune('{') + buf.WriteString(strings.Join(strs, ",")) + buf.WriteRune('}') + str := buf.String() w.WriteInt32(int32(len(str))) w.WriteBytes([]byte(str)) return nil @@ -3070,7 +3073,7 @@ func parseAclItemArray(arr string) ([]AclItem, error) { rn, _, err := r.ReadRune() if err != nil { if err == io.EOF { - // This error was expected and is OK + // Here, EOF is an expected end state, not an error. return vals, nil } // This error was not expected @@ -3093,7 +3096,7 @@ func parseAclItemArray(arr string) ([]AclItem, error) { if err != nil { if err == io.EOF { - // This error was expected and is OK. + // Here, EOF is an expected end state, not an error.. 
vals = append(vals, AclItem(vlu)) return vals, nil } From 09ee8a9b703ca995aacfd4540915e37b62edb005 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 17 Nov 2016 22:08:56 -0500 Subject: [PATCH 70/75] Returns AclItem earlier --- values.go | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/values.go b/values.go index c10ef782..58ac118f 100644 --- a/values.go +++ b/values.go @@ -3036,7 +3036,6 @@ func needsEscape(rn rune) bool { // encodeAclItemSlice encodes a slice of AclItems in // their textual represention for PostgreSQL. func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { - // cast aclitems into strings so we can use strings.Join strs := make([]string, len(aclitems)) var escapedAclItem string var err error @@ -3061,20 +3060,20 @@ func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { // parseAclItemArray parses the textual representation // of the aclitem[] type. func parseAclItemArray(arr string) ([]AclItem, error) { - r := strings.NewReader(arr) + reader := strings.NewReader(arr) // Difficult to guess a performant initial capacity for a slice of // aclitems, but let's go with 5. - vals := make([]AclItem, 0, 5) + aclItems := make([]AclItem, 0, 5) // A single value - vlu := "" + aclItem := AclItem("") for { // Grab the first/next/last rune to see if we are dealing with a // quoted value, an unquoted value, or the end of the string. - rn, _, err := r.ReadRune() + rn, _, err := reader.ReadRune() if err != nil { if err == io.EOF { // Here, EOF is an expected end state, not an error. - return vals, nil + return aclItems, nil } // This error was not expected return nil, err @@ -3082,50 +3081,49 @@ func parseAclItemArray(arr string) ([]AclItem, error) { if rn == '"' { // Discard the opening quote of the quoted value. - vlu, err = parseQuotedAclItem(r) + aclItem, err = parseQuotedAclItem(reader) } else { // We have just read the first rune of an unquoted (bare) value; // put it back so that ParseBareValue can read it. - err := r.UnreadRune() + err := reader.UnreadRune() if err != nil { - // This error was not expected. return nil, err } - vlu, err = parseBareAclItem(r) + aclItem, err = parseBareAclItem(reader) } if err != nil { if err == io.EOF { // Here, EOF is an expected end state, not an error.. - vals = append(vals, AclItem(vlu)) - return vals, nil + aclItems = append(aclItems, aclItem) + return aclItems, nil } // This error was not expected. return nil, err } - vals = append(vals, AclItem(vlu)) + aclItems = append(aclItems, aclItem) } } -func parseBareAclItem(r *strings.Reader) (string, error) { +func parseBareAclItem(r *strings.Reader) (AclItem, error) { var buf bytes.Buffer for { rn, _, err := r.ReadRune() if err != nil { // Return the read value in case the error is a harmless io.EOF. // (io.EOF marks the end of a bare value at the end of a string) - return buf.String(), err + return AclItem(buf.String()), err } if rn == ',' { // A comma marks the end of a bare value. - return buf.String(), nil + return AclItem(buf.String()), nil } else { buf.WriteRune(rn) } } } -func parseQuotedAclItem(r *strings.Reader) (string, error) { +func parseQuotedAclItem(r *strings.Reader) (AclItem, error) { var buf bytes.Buffer for { rn, escaped, err := readPossiblyEscapedRune(r) @@ -3133,10 +3131,10 @@ func parseQuotedAclItem(r *strings.Reader) (string, error) { if err == io.EOF { // Even when it is the last value, the final rune of // a quoted value should be the final closing quote, not io.EOF. 
- return "", fmt.Errorf("unexpected end of quoted value") + return AclItem(""), fmt.Errorf("unexpected end of quoted value") } // Return the read value in case the error is a harmless io.EOF. - return buf.String(), err + return AclItem(buf.String()), err } if !escaped && rn == '"' { // An unescaped double quote marks the end of a quoted value. @@ -3144,12 +3142,12 @@ func parseQuotedAclItem(r *strings.Reader) (string, error) { rn, _, err := r.ReadRune() if err != nil { // Return the read value in case the error is a harmless io.EOF. - return buf.String(), err + return AclItem(buf.String()), err } if rn != ',' { - return "", fmt.Errorf("unexpected rune after quoted value") + return AclItem(""), fmt.Errorf("unexpected rune after quoted value") } - return buf.String(), nil + return AclItem(buf.String()), nil } buf.WriteRune(rn) } From 7bd2e85f31660a6fc200b064c8c8d99b0ba3974f Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 17 Nov 2016 22:18:09 -0500 Subject: [PATCH 71/75] Improves names and comments --- values.go | 57 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/values.go b/values.go index 58ac118f..77ae0bca 100644 --- a/values.go +++ b/values.go @@ -3105,51 +3105,55 @@ func parseAclItemArray(arr string) ([]AclItem, error) { } } -func parseBareAclItem(r *strings.Reader) (AclItem, error) { - var buf bytes.Buffer +// parseBareAclItem parses a bare (unquoted) aclitem from reader +func parseBareAclItem(reader *strings.Reader) (AclItem, error) { + var aclItem bytes.Buffer for { - rn, _, err := r.ReadRune() + rn, _, err := reader.ReadRune() if err != nil { // Return the read value in case the error is a harmless io.EOF. - // (io.EOF marks the end of a bare value at the end of a string) - return AclItem(buf.String()), err + // (io.EOF marks the end of a bare aclitem at the end of a string) + return AclItem(aclItem.String()), err } if rn == ',' { - // A comma marks the end of a bare value. - return AclItem(buf.String()), nil + // A comma marks the end of a bare aclitem. + return AclItem(aclItem.String()), nil } else { - buf.WriteRune(rn) + aclItem.WriteRune(rn) } } } -func parseQuotedAclItem(r *strings.Reader) (AclItem, error) { - var buf bytes.Buffer +// parseQuotedAclItem parses an aclitem which is in double quotes from reader +func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) { + var aclItem bytes.Buffer for { - rn, escaped, err := readPossiblyEscapedRune(r) + rn, escaped, err := readPossiblyEscapedRune(reader) if err != nil { if err == io.EOF { // Even when it is the last value, the final rune of - // a quoted value should be the final closing quote, not io.EOF. + // a quoted aclitem should be the final closing quote, not io.EOF. return AclItem(""), fmt.Errorf("unexpected end of quoted value") } - // Return the read value in case the error is a harmless io.EOF. - return AclItem(buf.String()), err + // Return the read aclitem in case the error is a harmless io.EOF, + // which will be determined by the caller. + return AclItem(aclItem.String()), err } if !escaped && rn == '"' { // An unescaped double quote marks the end of a quoted value. // The next rune should either be a comma or the end of the string. - rn, _, err := r.ReadRune() + rn, _, err := reader.ReadRune() if err != nil { - // Return the read value in case the error is a harmless io.EOF. - return AclItem(buf.String()), err + // Return the read value in case the error is a harmless io.EOF, + // which will be determined by the caller. 
+ return AclItem(aclItem.String()), err } if rn != ',' { return AclItem(""), fmt.Errorf("unexpected rune after quoted value") } - return AclItem(buf.String()), nil + return AclItem(aclItem.String()), nil } - buf.WriteRune(rn) + aclItem.WriteRune(rn) } } @@ -3157,14 +3161,14 @@ func parseQuotedAclItem(r *strings.Reader) (AclItem, error) { // in that case, it returns the rune after the backslash. The second // return value tells us whether or not the rune was // preceeded by a backslash (escaped). -func readPossiblyEscapedRune(r *strings.Reader) (rune, bool, error) { - rn, _, err := r.ReadRune() +func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) { + rn, _, err := reader.ReadRune() if err != nil { return 0, false, err } if rn == '\\' { // Discard the backslash and read the next rune. - rn, _, err = r.ReadRune() + rn, _, err = reader.ReadRune() if err != nil { return 0, false, err } @@ -3181,19 +3185,20 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { str := vr.ReadString(vr.Len()) - // short-circuit empty array + // Short-circuit empty array. if str == "{}" { return []AclItem{} } - // remove the '{' at the front and the '}' at the end + // Remove the '{' at the front and the '}' at the end, + // so that parseAclItemArray doesn't have to deal with them. str = str[1 : len(str)-1] - strs, err := parseAclItemArray(str) + aclItems, err := parseAclItemArray(str) if err != nil { vr.Fatal(ProtocolError(err.Error())) return nil } - return strs + return aclItems } func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { From 3beac831cf40e5361fa4edda0b2d25af91e14ff3 Mon Sep 17 00:00:00 2001 From: Manni Wood Date: Thu, 17 Nov 2016 22:25:00 -0500 Subject: [PATCH 72/75] Adds formatting notes --- values.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/values.go b/values.go index 77ae0bca..8a7a49cb 100644 --- a/values.go +++ b/values.go @@ -3058,7 +3058,10 @@ func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { } // parseAclItemArray parses the textual representation -// of the aclitem[] type. +// of the aclitem[] type. The textual representation is chosen because +// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin). +// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +// for formatting notes. 
func parseAclItemArray(arr string) ([]AclItem, error) { reader := strings.NewReader(arr) // Difficult to guess a performant initial capacity for a slice of From f8930d614fe24d1cb5fcf1b93b74728f30ecf5d6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 23 Nov 2016 12:00:48 -0600 Subject: [PATCH 73/75] tricky user does not need superuser --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b78fbc5c..ccbd1dc1 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ To setup the normal test environment, first install these dependencies: Then run the following SQL: create user pgx_md5 password 'secret'; - create user " tricky, ' } "" \ test user " superuser password 'secret'; + create user " tricky, ' } "" \ test user " password 'secret'; create database pgx_test; Connect to database pgx_test and run: From c952d48a5c8a3678c5bcdac6f247cc78bb749824 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 23 Nov 2016 12:29:21 -0600 Subject: [PATCH 74/75] Return first err in decodeJSONB fixes #212 --- values.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/values.go b/values.go index 8a7a49cb..8d2bfefd 100644 --- a/values.go +++ b/values.go @@ -2081,12 +2081,16 @@ func decodeJSONB(vr *ValueReader, d interface{}) error { } if vr.Type().DataType != JsonbOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType))) + err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType)) + vr.Fatal(err) + return err } bytes := vr.ReadBytes(vr.Len()) if bytes[0] != 1 { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0]))) + err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0])) + vr.Fatal(err) + return err } err := json.Unmarshal(bytes[1:], d) From e96c105b55d469fddef56bfa7db5ecc7ab31d7b2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 23 Nov 2016 12:52:04 -0600 Subject: [PATCH 75/75] decodeJSONB works for text and binary format --- values.go | 13 +++++---- values_test.go | 71 +++++++++++++++++++++++++++----------------------- 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/values.go b/values.go index 8d2bfefd..b4466b82 100644 --- a/values.go +++ b/values.go @@ -2087,13 +2087,16 @@ func decodeJSONB(vr *ValueReader, d interface{}) error { } bytes := vr.ReadBytes(vr.Len()) - if bytes[0] != 1 { - err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0])) - vr.Fatal(err) - return err + if vr.Type().FormatCode == BinaryFormatCode { + if bytes[0] != 1 { + err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0])) + vr.Fatal(err) + return err + } + bytes = bytes[1:] } - err := json.Unmarshal(bytes[1:], d) + err := json.Unmarshal(bytes, d) if err != nil { vr.Fatal(err) } diff --git a/values_test.go b/values_test.go index bbb22f24..42d5bd3d 100644 --- a/values_test.go +++ b/values_test.go @@ -88,67 +88,74 @@ func TestJsonAndJsonbTranscode(t *testing.T) { if _, ok := conn.PgTypes[oid]; !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } - typename := conn.PgTypes[oid].Name - testJsonString(t, conn, typename) - testJsonStringPointer(t, conn, typename) - testJsonSingleLevelStringMap(t, conn, typename) - testJsonNestedMap(t, conn, typename) - testJsonStringArray(t, conn, typename) - testJsonInt64Array(t, conn, typename) - testJsonInt16ArrayFailureDueToOverflow(t, conn, typename) - testJsonStruct(t, conn, typename) + for _, format 
:= range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { + pgtype := conn.PgTypes[oid] + pgtype.DefaultFormat = format + conn.PgTypes[oid] = pgtype + + typename := conn.PgTypes[oid].Name + + testJsonString(t, conn, typename, format) + testJsonStringPointer(t, conn, typename, format) + testJsonSingleLevelStringMap(t, conn, typename, format) + testJsonNestedMap(t, conn, typename, format) + testJsonStringArray(t, conn, typename, format) + testJsonInt64Array(t, conn, typename, format) + testJsonInt16ArrayFailureDueToOverflow(t, conn, typename, format) + testJsonStruct(t, conn, typename, format) + } } } -func testJsonString(t *testing.T, conn *pgx.Conn, typename string) { +func testJsonString(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) return } if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) + t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) return } } -func testJsonStringPointer(t *testing.T, conn *pgx.Conn, typename string) { +func testJsonStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) return } if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) + t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) return } } -func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { +func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) return } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) return } } -func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string) { +func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := map[string]interface{}{ "name": "Uncanny", "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, @@ -157,52 +164,52 @@ func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string) { var output map[string]interface{} err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, 
err) return } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) return } } -func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string) { +func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := []string{"foo", "bar", "baz"} var output []string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) } } -func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string) { +func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := []int64{1, 2, 234432} var output []int64 err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) } } -func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { +func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { input := []int{1, 2, 234432} var output []int16 err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { - t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err) + t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) } } -func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string) { +func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { type person struct { Name string `json:"name"` Age int `json:"age"` @@ -217,11 +224,11 @@ func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string) { err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) } if !reflect.DeepEqual(input, output) { - t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output) + t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) } }
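
The aclitem[] support built up in patches 59-72 comes down to two text-format rules: on encode, backslash-escape the runes the array parser treats specially (backslash, comma, double quote, closing brace); on decode, split the array body into bare and double-quoted elements while honoring backslash escapes inside quotes. The following standalone Go sketch mirrors that technique under those assumptions; the helper names (escapeACL, parseACLArray, parseBare, parseQuoted) are hypothetical stand-ins, not the unexported pgx functions from the patches above.

    package main

    import (
    	"bytes"
    	"fmt"
    	"io"
    	"strings"
    )

    // escapeACL backslash-escapes the runes that are special inside an
    // aclitem[] literal: backslash, comma, double quote, and closing brace.
    func escapeACL(s string) string {
    	var buf bytes.Buffer
    	for _, rn := range s {
    		if rn == '\\' || rn == ',' || rn == '"' || rn == '}' {
    			buf.WriteRune('\\')
    		}
    		buf.WriteRune(rn)
    	}
    	return buf.String()
    }

    // parseACLArray splits the body of an aclitem[] literal (outer braces
    // already stripped) into elements, handling bare and quoted values.
    func parseACLArray(body string) ([]string, error) {
    	var items []string
    	r := strings.NewReader(body)
    	for {
    		rn, _, err := r.ReadRune()
    		if err != nil {
    			if err == io.EOF {
    				return items, nil
    			}
    			return nil, err
    		}
    		var item string
    		if rn == '"' {
    			item, err = parseQuoted(r)
    		} else {
    			r.UnreadRune()
    			item, err = parseBare(r)
    		}
    		if err == io.EOF { // EOF after a complete value ends the array
    			return append(items, item), nil
    		}
    		if err != nil {
    			return nil, err
    		}
    		items = append(items, item)
    	}
    }

    // parseBare reads an unquoted value; a comma or end of input terminates it.
    func parseBare(r *strings.Reader) (string, error) {
    	var buf bytes.Buffer
    	for {
    		rn, _, err := r.ReadRune()
    		if err != nil || rn == ',' {
    			return buf.String(), err
    		}
    		buf.WriteRune(rn)
    	}
    }

    // parseQuoted reads a double-quoted value, unescaping backslash escapes;
    // the closing quote must be followed by a comma or the end of input.
    func parseQuoted(r *strings.Reader) (string, error) {
    	var buf bytes.Buffer
    	for {
    		rn, _, err := r.ReadRune()
    		if err != nil {
    			return "", fmt.Errorf("unexpected end of quoted value")
    		}
    		switch rn {
    		case '\\': // keep the escaped rune literally
    			rn, _, err = r.ReadRune()
    			if err != nil {
    				return "", fmt.Errorf("dangling backslash")
    			}
    			buf.WriteRune(rn)
    		case '"': // closing quote; expect ',' or end of input
    			next, _, err := r.ReadRune()
    			if err == io.EOF || next == ',' {
    				return buf.String(), err
    			}
    			return "", fmt.Errorf("unexpected rune after quoted value")
    		default:
    			buf.WriteRune(rn)
    		}
    	}
    }

    func main() {
    	fmt.Println(escapeACL(`postgres=arwdDxt/" tricky, ' } "" \ test user "`))
    	items, err := parseACLArray(`"=r/postgres",postgres=arwdDxt/postgres,"postgres=arwdDxt/\" tricky, ' } \"\" \\ test user \""`)
    	fmt.Println(items, err)
    }

Running the sketch prints the escaped form of the tricky aclitem and recovers the three original values from the quoted array body, which is the same round trip the TestAclArrayDecoding cases exercise through the server.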
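
Patches 74 and 75 turn on one detail of the jsonb wire format: the binary representation carries a one-byte version header (currently 1) ahead of the JSON text, while the text representation is the JSON text itself. A minimal sketch of that dispatch, using a hypothetical standalone helper rather than pgx's ValueReader-based decodeJSONB:

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    // decodeJSONBBytes unmarshals a raw jsonb column value into dst.
    // Assumes the two formats handled in patch 75: text (JSON as-is) and
    // binary (version byte 1 followed by the JSON text).
    func decodeJSONBBytes(raw []byte, binaryFormat bool, dst interface{}) error {
    	if binaryFormat {
    		if len(raw) == 0 || raw[0] != 1 {
    			return fmt.Errorf("unknown jsonb format byte")
    		}
    		raw = raw[1:]
    	}
    	return json.Unmarshal(raw, dst)
    }

    func main() {
    	var m map[string]string
    	// Text format: plain JSON.
    	if err := decodeJSONBBytes([]byte(`{"key":"value"}`), false, &m); err != nil {
    		panic(err)
    	}
    	fmt.Println(m)
    	// Binary format: 0x01 followed by the same JSON.
    	if err := decodeJSONBBytes(append([]byte{1}, []byte(`{"key":"value"}`)...), true, &m); err != nil {
    		panic(err)
    	}
    	fmt.Println(m)
    }

This also shows why patch 74 matters: returning the ProtocolError from the format check (instead of only recording it via Fatal) lets the caller see the first failure directly rather than a later, less specific one.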