From 5f7d01778eaf02b0c0ef9871b934952bbf9afed5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 10 Aug 2016 16:27:44 -0500 Subject: [PATCH] Add CopyTo to support PostgreSQL copy protocol fixes #102 --- CHANGELOG.md | 1 + README.md | 1 + bench_test.go | 331 ++++++++++++++++++++++++++++++++++++++++++ conn_pool.go | 11 ++ copy_to.go | 241 +++++++++++++++++++++++++++++++ copy_to_test.go | 373 ++++++++++++++++++++++++++++++++++++++++++++++++ doc.go | 20 +++ messages.go | 4 + tx.go | 9 ++ 9 files changed, 991 insertions(+) create mode 100644 copy_to.go create mode 100644 copy_to_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d185f2b..26a1590d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ## Features +* Add CopyTo * Add PrepareEx * Add basic record to []interface{} decoding * Encode and decode between all Go and PostgreSQL integer types with bounds checking diff --git a/README.md b/README.md index c90bf966..607b38cd 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Pgx supports many additional features beyond what is available through database/ * Transaction isolation level control * Full TLS connection control * Binary format support for custom types (can be much faster) +* Copy protocol support for faster bulk data loads * Logging support * Configurable connection pool with after connect hooks to do arbitrary connection setup * PostgreSQL array to Go slice mapping for integers, floats, and strings diff --git a/bench_test.go b/bench_test.go index eb9c0595..1ea92cc4 100644 --- a/bench_test.go +++ b/bench_test.go @@ -1,6 +1,9 @@ package pgx_test import ( + "bytes" + "fmt" + "strings" "testing" "time" @@ -432,3 +435,331 @@ func BenchmarkLog15Discard(b *testing.B) { logger.Debug("benchmark", "i", i, "b.N", b.N) } } + +const benchmarkWriteTableCreateSQL = `drop table if exists t; + +create table t( + varchar_1 varchar not null, + varchar_2 varchar not null, + varchar_null_1 varchar, + date_1 date not null, + date_null_1 date, + int4_1 
int4 not null, + int4_2 int4 not null, + int4_null_1 int4, + tstz_1 timestamptz not null, + tstz_2 timestamptz, + bool_1 bool not null, + bool_2 bool not null, + bool_3 bool not null +); +` + +const benchmarkWriteTableInsertSQL = `insert into t( + varchar_1, + varchar_2, + varchar_null_1, + date_1, + date_null_1, + int4_1, + int4_2, + int4_null_1, + tstz_1, + tstz_2, + bool_1, + bool_2, + bool_3 +) values ( + $1::varchar, + $2::varchar, + $3::varchar, + $4::date, + $5::date, + $6::int4, + $7::int4, + $8::int4, + $9::timestamptz, + $10::timestamptz, + $11::bool, + $12::bool, + $13::bool +)` + +type benchmarkWriteTableCopyToSrc struct { + count int + idx int + row []interface{} +} + +func (s *benchmarkWriteTableCopyToSrc) Next() bool { + s.idx++ + return s.idx < s.count +} + +func (s *benchmarkWriteTableCopyToSrc) Values() ([]interface{}, error) { + return s.row, nil +} + +func (s *benchmarkWriteTableCopyToSrc) Err() error { + return nil +} + +func newBenchmarkWriteTableCopyToSrc(count int) pgx.CopyToSource { + return &benchmarkWriteTableCopyToSrc{ + count: count, + row: []interface{}{ + "varchar_1", + "varchar_2", + pgx.NullString{}, + time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), + pgx.NullTime{}, + 1, + 2, + pgx.NullInt32{}, + time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), + true, + false, + true, + }, + } +} + +func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyToSrc(n) + + tx, err := conn.Begin() + if err != nil { + b.Fatal(err) + } + + for src.Next() { + values, _ := src.Values() + if _, err = tx.Exec("insert_t", values...); err != nil { + b.Fatalf("Exec unexpectedly failed with: %v", err) + } + } + 
+ err = tx.Commit() + if err != nil { + b.Fatal(err) + } + } +} + +// note this function is only used for benchmarks -- it doesn't escape tableName +// or columnNames +func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyToSource) (int, error) { + maxRowsPerInsert := 65535 / len(columnNames) + rowsThisInsert := 0 + rowCount := 0 + + sqlBuf := &bytes.Buffer{} + args := make(pgx.QueryArgs, 0) + + resetQuery := func() { + sqlBuf.Reset() + fmt.Fprintf(sqlBuf, "insert into %s(%s) values", tableName, strings.Join(columnNames, ", ")) + + args = args[0:0] + + rowsThisInsert = 0 + } + resetQuery() + + tx, err := conn.Begin() + if err != nil { + return 0, err + } + defer tx.Rollback() + + for rowSrc.Next() { + if rowsThisInsert > 0 { + sqlBuf.WriteByte(',') + } + + sqlBuf.WriteByte('(') + + values, err := rowSrc.Values() + if err != nil { + return 0, err + } + + for i, val := range values { + if i > 0 { + sqlBuf.WriteByte(',') + } + sqlBuf.WriteString(args.Append(val)) + } + + sqlBuf.WriteByte(')') + + rowsThisInsert++ + + if rowsThisInsert == maxRowsPerInsert { + _, err := tx.Exec(sqlBuf.String(), args...) + if err != nil { + return 0, err + } + + rowCount += rowsThisInsert + resetQuery() + } + } + + if rowsThisInsert > 0 { + _, err := tx.Exec(sqlBuf.String(), args...) 
+ if err != nil { + return 0, err + } + + rowCount += rowsThisInsert + } + + if err := tx.Commit(); err != nil { + return 0, nil + } + + return rowCount, nil + +} + +func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + _, err := conn.Prepare("insert_t", benchmarkWriteTableInsertSQL) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyToSrc(n) + + _, err := multiInsert(conn, "t", + []string{"varchar_1", + "varchar_2", + "varchar_null_1", + "date_1", + "date_null_1", + "int4_1", + "int4_2", + "int4_null_1", + "tstz_1", + "tstz_2", + "bool_1", + "bool_2", + "bool_3"}, + src) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + mustExec(b, conn, benchmarkWriteTableCreateSQL) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + src := newBenchmarkWriteTableCopyToSrc(n) + + _, err := conn.CopyTo("t", + []string{"varchar_1", + "varchar_2", + "varchar_null_1", + "date_1", + "date_null_1", + "int4_1", + "int4_2", + "int4_null_1", + "tstz_1", + "tstz_2", + "bool_1", + "bool_2", + "bool_3"}, + src) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkWrite5RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 5) +} + +func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 5) +} + +func BenchmarkWrite5RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 5) +} + +func BenchmarkWrite10RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 10) +} + +func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 10) +} + +func BenchmarkWrite10RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 10) +} + +func BenchmarkWrite100RowsViaInsert(b 
*testing.B) { + benchmarkWriteNRowsViaInsert(b, 100) +} + +func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 100) +} + +func BenchmarkWrite100RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 100) +} + +func BenchmarkWrite1000RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 1000) +} + +func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 1000) +} + +func BenchmarkWrite1000RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 1000) +} + +func BenchmarkWrite10000RowsViaInsert(b *testing.B) { + benchmarkWriteNRowsViaInsert(b, 10000) +} + +func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { + benchmarkWriteNRowsViaMultiInsert(b, 10000) +} + +func BenchmarkWrite10000RowsViaCopy(b *testing.B) { + benchmarkWriteNRowsViaCopy(b, 10000) +} diff --git a/conn_pool.go b/conn_pool.go index 9e468cbb..fdd54114 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -496,3 +496,14 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) { return tx, nil } } + +// CopyTo acquires a connection, delegates the call to that connection, and releases the connection +func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + c, err := p.Acquire() + if err != nil { + return 0, err + } + defer p.Release(c) + + return c.CopyTo(tableName, columnNames, rowSrc) +} diff --git a/copy_to.go b/copy_to.go new file mode 100644 index 00000000..91292bb0 --- /dev/null +++ b/copy_to.go @@ -0,0 +1,241 @@ +package pgx + +import ( + "bytes" + "fmt" +) + +// CopyToRows returns a CopyToSource interface over the provided rows slice +// making it usable by *Conn.CopyTo. 
+func CopyToRows(rows [][]interface{}) CopyToSource { + return &copyToRows{rows: rows, idx: -1} +} + +type copyToRows struct { + rows [][]interface{} + idx int +} + +func (ctr *copyToRows) Next() bool { + ctr.idx++ + return ctr.idx < len(ctr.rows) +} + +func (ctr *copyToRows) Values() ([]interface{}, error) { + return ctr.rows[ctr.idx], nil +} + +func (ctr *copyToRows) Err() error { + return nil +} + +// CopyToSource is the interface used by *Conn.CopyTo as the source for copy data. +type CopyToSource interface { + // Next returns true if there is another row and makes the next row data + // available to Values(). When there are no more rows available or an error + // has occurred it returns false. + Next() bool + + // Values returns the values for the current row. + Values() ([]interface{}, error) + + // Err returns any error that has been encountered by the CopyToSource. If + // this is not nil *Conn.CopyTo will abort the copy. + Err() error +} + +type copyTo struct { + conn *Conn + tableName string + columnNames []string + rowSrc CopyToSource + readerErrChan chan error +} + +func (ct *copyTo) readUntilReadyForQuery() { + for { + t, r, err := ct.conn.rxMsg() + if err != nil { + ct.readerErrChan <- err + close(ct.readerErrChan) + return + } + + switch t { + case readyForQuery: + ct.conn.rxReadyForQuery(r) + close(ct.readerErrChan) + return + case commandComplete: + case errorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(r) + default: + err = ct.conn.processContextFreeMsg(t, r) + if err != nil { + ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + } + } + } +} + +func (ct *copyTo) waitForReaderDone() error { + var err error + for err = range ct.readerErrChan { + } + return err +} + +func (ct *copyTo) run() (int, error) { + quotedTableName := quoteIdentifier(ct.tableName) + buf := &bytes.Buffer{} + for i, cn := range ct.columnNames { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(quoteIdentifier(cn)) + } + quotedColumnNames := buf.String() 
+ + ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) + if err != nil { + return 0, err + } + + err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + if err != nil { + return 0, err + } + + err = ct.conn.readUntilCopyInResponse() + if err != nil { + return 0, err + } + + go ct.readUntilReadyForQuery() + defer ct.waitForReaderDone() + + wbuf := newWriteBuf(ct.conn, copyData) + + wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) + wbuf.WriteInt32(0) + wbuf.WriteInt32(0) + + var sentCount int + + for ct.rowSrc.Next() { + select { + case err = <-ct.readerErrChan: + return 0, err + default: + } + + if len(wbuf.buf) > 65536 { + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + // Directly manipulate wbuf to reset to reuse the same buffer + wbuf.buf = wbuf.buf[0:5] + wbuf.buf[0] = copyData + wbuf.sizeIdx = 1 + } + + sentCount++ + + values, err := ct.rowSrc.Values() + if err != nil { + ct.cancelCopyIn() + return 0, err + } + if len(values) != len(ct.columnNames) { + ct.cancelCopyIn() + return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + wbuf.WriteInt16(int16(len(ct.columnNames))) + for i, val := range values { + err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + if err != nil { + ct.cancelCopyIn() + return 0, err + } + + } + } + + if ct.rowSrc.Err() != nil { + ct.cancelCopyIn() + return 0, ct.rowSrc.Err() + } + + wbuf.WriteInt16(-1) // terminate the copy stream + + wbuf.startMsg(copyDone) + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + err = ct.waitForReaderDone() + if err != nil { + return 0, err + } + return sentCount, nil +} + +func (c *Conn) readUntilCopyInResponse() error { + for { + var t byte + var r *msgReader + t, r, err := c.rxMsg() + if err != nil { + return 
err + } + + switch t { + case copyInResponse: + return nil + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } +} + +func (ct *copyTo) cancelCopyIn() error { + wbuf := newWriteBuf(ct.conn, copyFail) + wbuf.WriteCString("client error: abort") + wbuf.closeMsg() + _, err := ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return err + } + + return nil +} + +// CopyTo uses the PostgreSQL copy protocol to perform bulk data insertion. +// It returns the number of rows copied and an error. +// +// CopyTo requires all values use the binary format. Almost all types +// implemented by pgx use the binary format by default. Types implementing +// Encoder can only be used if they encode to the binary format. +func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + ct := &copyTo{ + conn: c, + tableName: tableName, + columnNames: columnNames, + rowSrc: rowSrc, + readerErrChan: make(chan error), + } + + return ct.run() +} diff --git a/copy_to_test.go b/copy_to_test.go new file mode 100644 index 00000000..d810c4fb --- /dev/null +++ b/copy_to_test.go @@ -0,0 +1,373 @@ +package pgx_test + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func TestConnCopyToSmall(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected 
CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToLarge(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz, + h bytea + )`) + + inputRows := [][]interface{}{} + + for i := 0; i < 10000; i++ { + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyTo: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if 
!reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal") + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToFailServerSideMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b varchar not null + )`) + + inputRows := [][]interface{}{ + {int32(1), "abc"}, + {int32(2), nil}, // this row should trigger a failure + {int32(3), "def"}, + } + + copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type failSource struct { + count int +} + +func (fs *failSource) Next() bool { + time.Sleep(time.Millisecond * 100) + fs.count++ + return fs.count < 100 +} + +func (fs *failSource) Values() ([]interface{}, error) { + if fs.count == 3 { + return []interface{}{nil}, nil + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (fs *failSource) Err() error { + return nil +} + +func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + 
mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + startTime := time.Now() + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + endTime := time.Now() + copyTime := endTime.Sub(startTime) + if copyTime > time.Second { + t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFailSource struct { + count int + err error +} + +func (cfs *clientFailSource) Next() bool { + cfs.count++ + return cfs.count < 100 +} + +func (cfs *clientFailSource) Values() ([]interface{}, error) { + if cfs.count == 3 { + cfs.err = fmt.Errorf("client error") + return nil, cfs.err + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFailSource) Err() error { + return cfs.err +} + +func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + 
if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFinalErrSource struct { + count int +} + +func (cfs *clientFinalErrSource) Next() bool { + cfs.count++ + return cfs.count < 5 +} + +func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFinalErrSource) Err() error { + return fmt.Errorf("final error") +} + +func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{}) + if err == nil { + t.Errorf("Expected CopyTo return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + 
ensureConnValid(t, conn) +} diff --git a/doc.go b/doc.go index 0fd3d2f6..bf624c22 100644 --- a/doc.go +++ b/doc.go @@ -104,6 +104,26 @@ creates a transaction with a specified isolation level. return err } +Copy Protocol + +Use CopyTo to efficiently insert multiple rows at a time using the PostgreSQL +copy protocol. CopyTo accepts a CopyToSource interface. If the data is already +in a [][]interface{} use CopyToRows to wrap it in a CopyToSource interface. Or +implement CopyToSource to avoid buffering the entire data set in memory. + + rows := [][]interface{}{ + {"John", "Smith", int32(36)}, + {"Jane", "Doe", int32(29)}, + } + + copyCount, err := conn.CopyTo( + "people", + []string{"first_name", "last_name", "age"}, + pgx.CopyToRows(rows), + ) + +CopyTo can be faster than an insert with as few as 5 rows. + Listen and Notify pgx can listen to the PostgreSQL notification system with the diff --git a/messages.go b/messages.go index 1fbd9cbc..7f04f1f2 100644 --- a/messages.go +++ b/messages.go @@ -25,6 +25,10 @@ const ( noData = 'n' closeComplete = '3' flush = 'H' + copyInResponse = 'G' + copyData = 'd' + copyFail = 'f' + copyDone = 'c' ) type startupMessage struct { diff --git a/tx.go b/tx.go index e5c90c23..36f99c28 100644 --- a/tx.go +++ b/tx.go @@ -158,6 +158,15 @@ func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +// CopyTo delegates to the underlying *Conn +func (tx *Tx) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { + if tx.status != TxStatusInProgress { + return 0, ErrTxClosed + } + + return tx.conn.CopyTo(tableName, columnNames, rowSrc) +} + // Conn returns the *Conn this transaction is using. func (tx *Tx) Conn() *Conn { return tx.conn