diff --git a/conn.go b/conn.go index cf34d267..1007811e 100644 --- a/conn.go +++ b/conn.go @@ -146,6 +146,19 @@ func (ct CommandTag) RowsAffected() int64 { return n } +// Identifier a PostgreSQL identifier or name. Identifiers can be composed of +// multiple parts such as ["schema", "table"] or ["table", "column"]. +type Identifier []string + +// Sanitize returns a sanitized string safe for SQL interpolation. +func (ident Identifier) Sanitize() string { + parts := make([]string, len(ident)) + for i := range ident { + parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` + } + return strings.Join(parts, ".") +} + // ErrNoRows occurs when rows are expected but none are returned. var ErrNoRows = errors.New("no rows in result set") diff --git a/conn_test.go b/conn_test.go index d863999c..e1c780b8 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1541,3 +1541,40 @@ func TestSetLogLevel(t *testing.T) { t.Fatal("Expected logger to be called, but it wasn't") } } + +func TestIdentifierSanitize(t *testing.T) { + t.Parallel() + + tests := []struct { + ident pgx.Identifier + expected string + }{ + { + ident: pgx.Identifier{`foo`}, + expected: `"foo"`, + }, + { + ident: pgx.Identifier{`select`}, + expected: `"select"`, + }, + { + ident: pgx.Identifier{`foo`, `bar`}, + expected: `"foo"."bar"`, + }, + { + ident: pgx.Identifier{`you should " not do this`}, + expected: `"you should "" not do this"`, + }, + { + ident: pgx.Identifier{`you should " not do this`, `please don't`}, + expected: `"you should "" not do this"."please don't"`, + }, + } + + for i, tt := range tests { + qval := tt.ident.Sanitize() + if qval != tt.expected { + t.Errorf("%d. Expected Sanitize %v to return %v but it was %v", i, tt.ident, tt.expected, qval) + } + } +} diff --git a/copy_from.go b/copy_from.go new file mode 100644 index 00000000..1f8a2306 --- /dev/null +++ b/copy_from.go @@ -0,0 +1,241 @@ +package pgx + +import ( + "bytes" + "fmt" +) + +// CopyFromRows returns a CopyFromSource interface over the provided rows slice +// making it usable by *Conn.CopyFrom. +func CopyFromRows(rows [][]interface{}) CopyFromSource { + return ©FromRows{rows: rows, idx: -1} +} + +type copyFromRows struct { + rows [][]interface{} + idx int +} + +func (ctr *copyFromRows) Next() bool { + ctr.idx++ + return ctr.idx < len(ctr.rows) +} + +func (ctr *copyFromRows) Values() ([]interface{}, error) { + return ctr.rows[ctr.idx], nil +} + +func (ctr *copyFromRows) Err() error { + return nil +} + +// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. +type CopyFromSource interface { + // Next returns true if there is another row and makes the next row data + // available to Values(). When there are no more rows available or an error + // has occurred it returns false. + Next() bool + + // Values returns the values for the current row. + Values() ([]interface{}, error) + + // Err returns any error that has been encountered by the CopyFromSource. If + // this is not nil *Conn.CopyFrom will abort the copy. + Err() error +} + +type copyFrom struct { + conn *Conn + tableName Identifier + columnNames []string + rowSrc CopyFromSource + readerErrChan chan error +} + +func (ct *copyFrom) readUntilReadyForQuery() { + for { + t, r, err := ct.conn.rxMsg() + if err != nil { + ct.readerErrChan <- err + close(ct.readerErrChan) + return + } + + switch t { + case readyForQuery: + ct.conn.rxReadyForQuery(r) + close(ct.readerErrChan) + return + case commandComplete: + case errorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(r) + default: + err = ct.conn.processContextFreeMsg(t, r) + if err != nil { + ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + } + } + } +} + +func (ct *copyFrom) waitForReaderDone() error { + var err error + for err = range ct.readerErrChan { + } + return err +} + +func (ct *copyFrom) run() (int, error) { + quotedTableName := ct.tableName.Sanitize() + buf := &bytes.Buffer{} + for i, cn := range ct.columnNames { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(quoteIdentifier(cn)) + } + quotedColumnNames := buf.String() + + ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) + if err != nil { + return 0, err + } + + err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + if err != nil { + return 0, err + } + + err = ct.conn.readUntilCopyInResponse() + if err != nil { + return 0, err + } + + go ct.readUntilReadyForQuery() + defer ct.waitForReaderDone() + + wbuf := newWriteBuf(ct.conn, copyData) + + wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) + wbuf.WriteInt32(0) + wbuf.WriteInt32(0) + + var sentCount int + + for ct.rowSrc.Next() { + select { + case err = <-ct.readerErrChan: + return 0, err + default: + } + + if len(wbuf.buf) > 65536 { + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + // Directly manipulate wbuf to reset to reuse the same buffer + wbuf.buf = wbuf.buf[0:5] + wbuf.buf[0] = copyData + wbuf.sizeIdx = 1 + } + + sentCount++ + + values, err := ct.rowSrc.Values() + if err != nil { + ct.cancelCopyIn() + return 0, err + } + if len(values) != len(ct.columnNames) { + ct.cancelCopyIn() + return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + wbuf.WriteInt16(int16(len(ct.columnNames))) + for i, val := range values { + err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + if err != nil { + ct.cancelCopyIn() + return 0, err + } + + } + } + + if ct.rowSrc.Err() != nil { + ct.cancelCopyIn() + return 0, ct.rowSrc.Err() + } + + wbuf.WriteInt16(-1) // terminate the copy stream + + wbuf.startMsg(copyDone) + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + err = ct.waitForReaderDone() + if err != nil { + return 0, err + } + return sentCount, nil +} + +func (c *Conn) readUntilCopyInResponse() error { + for { + var t byte + var r *msgReader + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case copyInResponse: + return nil + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } +} + +func (ct *copyFrom) cancelCopyIn() error { + wbuf := newWriteBuf(ct.conn, copyFail) + wbuf.WriteCString("client error: abort") + wbuf.closeMsg() + _, err := ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return err + } + + return nil +} + +// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. +// It returns the number of rows copied and an error. +// +// CopyFrom requires all values use the binary format. Almost all types +// implemented by pgx use the binary format by default. Types implementing +// Encoder can only be used if they encode to the binary format. +func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { + ct := ©From{ + conn: c, + tableName: tableName, + columnNames: columnNames, + rowSrc: rowSrc, + readerErrChan: make(chan error), + } + + return ct.run() +} diff --git a/copy_from_test.go b/copy_from_test.go new file mode 100644 index 00000000..54da6989 --- /dev/null +++ b/copy_from_test.go @@ -0,0 +1,428 @@ +package pgx_test + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func TestConnCopyFromSmall(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromLarge(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz, + h bytea + )`) + + inputRows := [][]interface{}{} + + for i := 0; i < 10000; i++ { + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal") + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromJSON(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + if _, ok := conn.PgTypes[oid]; !ok { + return // No JSON/JSONB type -- must be running against old PostgreSQL + } + } + + mustExec(t, conn, `create temporary table foo( + a json, + b jsonb + )`) + + inputRows := [][]interface{}{ + {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + {nil, nil}, + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromFailServerSideMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b varchar not null + )`) + + inputRows := [][]interface{}{ + {int32(1), "abc"}, + {int32(2), nil}, // this row should trigger a failure + {int32(3), "def"}, + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type failSource struct { + count int +} + +func (fs *failSource) Next() bool { + time.Sleep(time.Millisecond * 100) + fs.count++ + return fs.count < 100 +} + +func (fs *failSource) Values() ([]interface{}, error) { + if fs.count == 3 { + return []interface{}{nil}, nil + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (fs *failSource) Err() error { + return nil +} + +func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + startTime := time.Now() + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &failSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + endTime := time.Now() + copyTime := endTime.Sub(startTime) + if copyTime > time.Second { + t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFailSource struct { + count int + err error +} + +func (cfs *clientFailSource) Next() bool { + cfs.count++ + return cfs.count < 100 +} + +func (cfs *clientFailSource) Values() ([]interface{}, error) { + if cfs.count == 3 { + cfs.err = fmt.Errorf("client error") + return nil, cfs.err + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFailSource) Err() error { + return cfs.err +} + +func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFinalErrSource struct { + count int +} + +func (cfs *clientFinalErrSource) Next() bool { + cfs.count++ + return cfs.count < 5 +} + +func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFinalErrSource) Err() error { + return fmt.Errorf("final error") +} + +func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} diff --git a/copy_to.go b/copy_to.go index dd70ada3..b6cf16c8 100644 --- a/copy_to.go +++ b/copy_to.go @@ -5,8 +5,8 @@ import ( "fmt" ) -// CopyToRows returns a CopyToSource interface over the provided rows slice -// making it usable by *Conn.CopyTo. +// Deprecated. Use CopyFromRows instead. CopyToRows returns a CopyToSource +// interface over the provided rows slice making it usable by *Conn.CopyTo. func CopyToRows(rows [][]interface{}) CopyToSource { return ©ToRows{rows: rows, idx: -1} } @@ -29,7 +29,8 @@ func (ctr *copyToRows) Err() error { return nil } -// CopyToSource is the interface used by *Conn.CopyTo as the source for copy data. +// Deprecated. Use CopyFromSource instead. CopyToSource is the interface used by +// *Conn.CopyTo as the source for copy data. type CopyToSource interface { // Next returns true if there is another row and makes the next row data // available to Values(). When there are no more rows available or an error @@ -187,27 +188,6 @@ func (ct *copyTo) run() (int, error) { return sentCount, nil } -func (c *Conn) readUntilCopyInResponse() error { - for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() - if err != nil { - return err - } - - switch t { - case copyInResponse: - return nil - default: - err = c.processContextFreeMsg(t, r) - if err != nil { - return err - } - } - } -} - func (ct *copyTo) cancelCopyIn() error { wbuf := newWriteBuf(ct.conn, copyFail) wbuf.WriteCString("client error: abort") @@ -221,8 +201,9 @@ func (ct *copyTo) cancelCopyIn() error { return nil } -// CopyTo uses the PostgreSQL copy protocol to perform bulk data insertion. -// It returns the number of rows copied and an error. +// Deprecated. Use CopyFrom instead. CopyTo uses the PostgreSQL copy protocol to +// perform bulk data insertion. It returns the number of rows copied and an +// error. // // CopyTo requires all values use the binary format. Almost all types // implemented by pgx use the binary format by default. Types implementing diff --git a/copy_to_test.go b/copy_to_test.go index b65ea0f9..afe22ca2 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -1,7 +1,6 @@ package pgx_test import ( - "fmt" "reflect" "testing" "time" @@ -228,27 +227,6 @@ func TestConnCopyToFailServerSideMidway(t *testing.T) { ensureConnValid(t, conn) } -type failSource struct { - count int -} - -func (fs *failSource) Next() bool { - time.Sleep(time.Millisecond * 100) - fs.count++ - return fs.count < 100 -} - -func (fs *failSource) Values() ([]interface{}, error) { - if fs.count == 3 { - return []interface{}{nil}, nil - } - return []interface{}{make([]byte, 100000)}, nil -} - -func (fs *failSource) Err() error { - return nil -} - func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { t.Parallel() @@ -303,28 +281,6 @@ func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { ensureConnValid(t, conn) } -type clientFailSource struct { - count int - err error -} - -func (cfs *clientFailSource) Next() bool { - cfs.count++ - return cfs.count < 100 -} - -func (cfs *clientFailSource) Values() ([]interface{}, error) { - if cfs.count == 3 { - cfs.err = fmt.Errorf("client error") - return nil, cfs.err - } - return []interface{}{make([]byte, 100000)}, nil -} - -func (cfs *clientFailSource) Err() error { - return cfs.err -} - func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { t.Parallel() @@ -368,23 +324,6 @@ func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { ensureConnValid(t, conn) } -type clientFinalErrSource struct { - count int -} - -func (cfs *clientFinalErrSource) Next() bool { - cfs.count++ - return cfs.count < 5 -} - -func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { - return []interface{}{make([]byte, 100000)}, nil -} - -func (cfs *clientFinalErrSource) Err() error { - return fmt.Errorf("final error") -} - func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) { t.Parallel()