mirror of
https://github.com/jackc/pgx.git
synced 2025-08-05 00:00:38 +00:00
Add *Conn.CopyFrom
This replaces *Conn.CopyTo. CopyTo was named incorrectly. In PostgreSQL COPY FROM is the command that copies from the client to the server. In addition, CopyTo does not accept a schema qualified table name. This commit introduces the Identifier type which handles multi-part names and correctly quotes/sanitizes them. The new CopyFrom method uses this Identifier type. Conn.CopyTo is deprecated. refs #243 and #190
This commit is contained in:
parent
ba5f97176a
commit
5eb19bc66a
13
conn.go
13
conn.go
@ -146,6 +146,19 @@ func (ct CommandTag) RowsAffected() int64 {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
|
||||||
|
// multiple parts such as ["schema", "table"] or ["table", "column"].
|
||||||
|
type Identifier []string
|
||||||
|
|
||||||
|
// Sanitize returns a sanitized string safe for SQL interpolation.
|
||||||
|
func (ident Identifier) Sanitize() string {
|
||||||
|
parts := make([]string, len(ident))
|
||||||
|
for i := range ident {
|
||||||
|
parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"`
|
||||||
|
}
|
||||||
|
return strings.Join(parts, ".")
|
||||||
|
}
|
||||||
|
|
||||||
// ErrNoRows occurs when rows are expected but none are returned.
|
// ErrNoRows occurs when rows are expected but none are returned.
|
||||||
var ErrNoRows = errors.New("no rows in result set")
|
var ErrNoRows = errors.New("no rows in result set")
|
||||||
|
|
||||||
|
37
conn_test.go
37
conn_test.go
@ -1541,3 +1541,40 @@ func TestSetLogLevel(t *testing.T) {
|
|||||||
t.Fatal("Expected logger to be called, but it wasn't")
|
t.Fatal("Expected logger to be called, but it wasn't")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIdentifierSanitize(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
ident pgx.Identifier
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
ident: pgx.Identifier{`foo`},
|
||||||
|
expected: `"foo"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ident: pgx.Identifier{`select`},
|
||||||
|
expected: `"select"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ident: pgx.Identifier{`foo`, `bar`},
|
||||||
|
expected: `"foo"."bar"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ident: pgx.Identifier{`you should " not do this`},
|
||||||
|
expected: `"you should "" not do this"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ident: pgx.Identifier{`you should " not do this`, `please don't`},
|
||||||
|
expected: `"you should "" not do this"."please don't"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
qval := tt.ident.Sanitize()
|
||||||
|
if qval != tt.expected {
|
||||||
|
t.Errorf("%d. Expected Sanitize %v to return %v but it was %v", i, tt.ident, tt.expected, qval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
241
copy_from.go
Normal file
241
copy_from.go
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
||||||
|
// making it usable by *Conn.CopyFrom.
|
||||||
|
func CopyFromRows(rows [][]interface{}) CopyFromSource {
|
||||||
|
return ©FromRows{rows: rows, idx: -1}
|
||||||
|
}
|
||||||
|
|
||||||
|
type copyFromRows struct {
|
||||||
|
rows [][]interface{}
|
||||||
|
idx int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctr *copyFromRows) Next() bool {
|
||||||
|
ctr.idx++
|
||||||
|
return ctr.idx < len(ctr.rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctr *copyFromRows) Values() ([]interface{}, error) {
|
||||||
|
return ctr.rows[ctr.idx], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctr *copyFromRows) Err() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
|
||||||
|
type CopyFromSource interface {
|
||||||
|
// Next returns true if there is another row and makes the next row data
|
||||||
|
// available to Values(). When there are no more rows available or an error
|
||||||
|
// has occurred it returns false.
|
||||||
|
Next() bool
|
||||||
|
|
||||||
|
// Values returns the values for the current row.
|
||||||
|
Values() ([]interface{}, error)
|
||||||
|
|
||||||
|
// Err returns any error that has been encountered by the CopyFromSource. If
|
||||||
|
// this is not nil *Conn.CopyFrom will abort the copy.
|
||||||
|
Err() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type copyFrom struct {
|
||||||
|
conn *Conn
|
||||||
|
tableName Identifier
|
||||||
|
columnNames []string
|
||||||
|
rowSrc CopyFromSource
|
||||||
|
readerErrChan chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ct *copyFrom) readUntilReadyForQuery() {
|
||||||
|
for {
|
||||||
|
t, r, err := ct.conn.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
ct.readerErrChan <- err
|
||||||
|
close(ct.readerErrChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case readyForQuery:
|
||||||
|
ct.conn.rxReadyForQuery(r)
|
||||||
|
close(ct.readerErrChan)
|
||||||
|
return
|
||||||
|
case commandComplete:
|
||||||
|
case errorResponse:
|
||||||
|
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
|
||||||
|
default:
|
||||||
|
err = ct.conn.processContextFreeMsg(t, r)
|
||||||
|
if err != nil {
|
||||||
|
ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ct *copyFrom) waitForReaderDone() error {
|
||||||
|
var err error
|
||||||
|
for err = range ct.readerErrChan {
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ct *copyFrom) run() (int, error) {
|
||||||
|
quotedTableName := ct.tableName.Sanitize()
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
for i, cn := range ct.columnNames {
|
||||||
|
if i != 0 {
|
||||||
|
buf.WriteString(", ")
|
||||||
|
}
|
||||||
|
buf.WriteString(quoteIdentifier(cn))
|
||||||
|
}
|
||||||
|
quotedColumnNames := buf.String()
|
||||||
|
|
||||||
|
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ct.conn.readUntilCopyInResponse()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go ct.readUntilReadyForQuery()
|
||||||
|
defer ct.waitForReaderDone()
|
||||||
|
|
||||||
|
wbuf := newWriteBuf(ct.conn, copyData)
|
||||||
|
|
||||||
|
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000"))
|
||||||
|
wbuf.WriteInt32(0)
|
||||||
|
wbuf.WriteInt32(0)
|
||||||
|
|
||||||
|
var sentCount int
|
||||||
|
|
||||||
|
for ct.rowSrc.Next() {
|
||||||
|
select {
|
||||||
|
case err = <-ct.readerErrChan:
|
||||||
|
return 0, err
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(wbuf.buf) > 65536 {
|
||||||
|
wbuf.closeMsg()
|
||||||
|
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||||
|
if err != nil {
|
||||||
|
ct.conn.die(err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Directly manipulate wbuf to reset to reuse the same buffer
|
||||||
|
wbuf.buf = wbuf.buf[0:5]
|
||||||
|
wbuf.buf[0] = copyData
|
||||||
|
wbuf.sizeIdx = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
sentCount++
|
||||||
|
|
||||||
|
values, err := ct.rowSrc.Values()
|
||||||
|
if err != nil {
|
||||||
|
ct.cancelCopyIn()
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if len(values) != len(ct.columnNames) {
|
||||||
|
ct.cancelCopyIn()
|
||||||
|
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||||
|
}
|
||||||
|
|
||||||
|
wbuf.WriteInt16(int16(len(ct.columnNames)))
|
||||||
|
for i, val := range values {
|
||||||
|
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
|
||||||
|
if err != nil {
|
||||||
|
ct.cancelCopyIn()
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ct.rowSrc.Err() != nil {
|
||||||
|
ct.cancelCopyIn()
|
||||||
|
return 0, ct.rowSrc.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
wbuf.WriteInt16(-1) // terminate the copy stream
|
||||||
|
|
||||||
|
wbuf.startMsg(copyDone)
|
||||||
|
wbuf.closeMsg()
|
||||||
|
_, err = ct.conn.conn.Write(wbuf.buf)
|
||||||
|
if err != nil {
|
||||||
|
ct.conn.die(err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ct.waitForReaderDone()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return sentCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) readUntilCopyInResponse() error {
|
||||||
|
for {
|
||||||
|
var t byte
|
||||||
|
var r *msgReader
|
||||||
|
t, r, err := c.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case copyInResponse:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
err = c.processContextFreeMsg(t, r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ct *copyFrom) cancelCopyIn() error {
|
||||||
|
wbuf := newWriteBuf(ct.conn, copyFail)
|
||||||
|
wbuf.WriteCString("client error: abort")
|
||||||
|
wbuf.closeMsg()
|
||||||
|
_, err := ct.conn.conn.Write(wbuf.buf)
|
||||||
|
if err != nil {
|
||||||
|
ct.conn.die(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
|
||||||
|
// It returns the number of rows copied and an error.
|
||||||
|
//
|
||||||
|
// CopyFrom requires all values use the binary format. Almost all types
|
||||||
|
// implemented by pgx use the binary format by default. Types implementing
|
||||||
|
// Encoder can only be used if they encode to the binary format.
|
||||||
|
func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
|
||||||
|
ct := ©From{
|
||||||
|
conn: c,
|
||||||
|
tableName: tableName,
|
||||||
|
columnNames: columnNames,
|
||||||
|
rowSrc: rowSrc,
|
||||||
|
readerErrChan: make(chan error),
|
||||||
|
}
|
||||||
|
|
||||||
|
return ct.run()
|
||||||
|
}
|
428
copy_from_test.go
Normal file
428
copy_from_test.go
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
package pgx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnCopyFromSmall(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a int2,
|
||||||
|
b int4,
|
||||||
|
c int8,
|
||||||
|
d varchar,
|
||||||
|
e text,
|
||||||
|
f date,
|
||||||
|
g timestamptz
|
||||||
|
)`)
|
||||||
|
|
||||||
|
inputRows := [][]interface{}{
|
||||||
|
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)},
|
||||||
|
{nil, nil, nil, nil, nil, nil, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||||
|
}
|
||||||
|
if copyCount != len(inputRows) {
|
||||||
|
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := conn.Query("select * from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputRows [][]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
row, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||||
|
}
|
||||||
|
outputRows = append(outputRows, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||||
|
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromLarge(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a int2,
|
||||||
|
b int4,
|
||||||
|
c int8,
|
||||||
|
d varchar,
|
||||||
|
e text,
|
||||||
|
f date,
|
||||||
|
g timestamptz,
|
||||||
|
h bytea
|
||||||
|
)`)
|
||||||
|
|
||||||
|
inputRows := [][]interface{}{}
|
||||||
|
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}})
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||||
|
}
|
||||||
|
if copyCount != len(inputRows) {
|
||||||
|
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := conn.Query("select * from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputRows [][]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
row, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||||
|
}
|
||||||
|
outputRows = append(outputRows, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||||
|
t.Errorf("Input rows and output rows do not equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromJSON(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} {
|
||||||
|
if _, ok := conn.PgTypes[oid]; !ok {
|
||||||
|
return // No JSON/JSONB type -- must be running against old PostgreSQL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a json,
|
||||||
|
b jsonb
|
||||||
|
)`)
|
||||||
|
|
||||||
|
inputRows := [][]interface{}{
|
||||||
|
{map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}},
|
||||||
|
{nil, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||||
|
}
|
||||||
|
if copyCount != len(inputRows) {
|
||||||
|
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := conn.Query("select * from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputRows [][]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
row, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||||
|
}
|
||||||
|
outputRows = append(outputRows, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||||
|
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a int4,
|
||||||
|
b varchar not null
|
||||||
|
)`)
|
||||||
|
|
||||||
|
inputRows := [][]interface{}{
|
||||||
|
{int32(1), "abc"},
|
||||||
|
{int32(2), nil}, // this row should trigger a failure
|
||||||
|
{int32(3), "def"},
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||||
|
}
|
||||||
|
if _, ok := err.(pgx.PgError); !ok {
|
||||||
|
t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
|
||||||
|
}
|
||||||
|
if copyCount != 0 {
|
||||||
|
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := conn.Query("select * from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputRows [][]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
row, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||||
|
}
|
||||||
|
outputRows = append(outputRows, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(outputRows) != 0 {
|
||||||
|
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
type failSource struct {
|
||||||
|
count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *failSource) Next() bool {
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
fs.count++
|
||||||
|
return fs.count < 100
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *failSource) Values() ([]interface{}, error) {
|
||||||
|
if fs.count == 3 {
|
||||||
|
return []interface{}{nil}, nil
|
||||||
|
}
|
||||||
|
return []interface{}{make([]byte, 100000)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *failSource) Err() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a bytea not null
|
||||||
|
)`)
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||||
|
}
|
||||||
|
if _, ok := err.(pgx.PgError); !ok {
|
||||||
|
t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
|
||||||
|
}
|
||||||
|
if copyCount != 0 {
|
||||||
|
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
endTime := time.Now()
|
||||||
|
copyTime := endTime.Sub(startTime)
|
||||||
|
if copyTime > time.Second {
|
||||||
|
t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := conn.Query("select * from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputRows [][]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
row, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||||
|
}
|
||||||
|
outputRows = append(outputRows, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(outputRows) != 0 {
|
||||||
|
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientFailSource struct {
|
||||||
|
count int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfs *clientFailSource) Next() bool {
|
||||||
|
cfs.count++
|
||||||
|
return cfs.count < 100
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfs *clientFailSource) Values() ([]interface{}, error) {
|
||||||
|
if cfs.count == 3 {
|
||||||
|
cfs.err = fmt.Errorf("client error")
|
||||||
|
return nil, cfs.err
|
||||||
|
}
|
||||||
|
return []interface{}{make([]byte, 100000)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfs *clientFailSource) Err() error {
|
||||||
|
return cfs.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a bytea not null
|
||||||
|
)`)
|
||||||
|
|
||||||
|
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||||
|
}
|
||||||
|
if copyCount != 0 {
|
||||||
|
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := conn.Query("select * from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputRows [][]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
row, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||||
|
}
|
||||||
|
outputRows = append(outputRows, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(outputRows) != 0 {
|
||||||
|
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientFinalErrSource struct {
|
||||||
|
count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfs *clientFinalErrSource) Next() bool {
|
||||||
|
cfs.count++
|
||||||
|
return cfs.count < 5
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfs *clientFinalErrSource) Values() ([]interface{}, error) {
|
||||||
|
return []interface{}{make([]byte, 100000)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfs *clientFinalErrSource) Err() error {
|
||||||
|
return fmt.Errorf("final error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a bytea not null
|
||||||
|
)`)
|
||||||
|
|
||||||
|
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||||
|
}
|
||||||
|
if copyCount != 0 {
|
||||||
|
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := conn.Query("select * from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for Query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputRows [][]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
row, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||||
|
}
|
||||||
|
outputRows = append(outputRows, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(outputRows) != 0 {
|
||||||
|
t.Errorf("Expected 0 rows, but got %v", outputRows)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
33
copy_to.go
33
copy_to.go
@ -5,8 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CopyToRows returns a CopyToSource interface over the provided rows slice
|
// Deprecated. Use CopyFromRows instead. CopyToRows returns a CopyToSource
|
||||||
// making it usable by *Conn.CopyTo.
|
// interface over the provided rows slice making it usable by *Conn.CopyTo.
|
||||||
func CopyToRows(rows [][]interface{}) CopyToSource {
|
func CopyToRows(rows [][]interface{}) CopyToSource {
|
||||||
return ©ToRows{rows: rows, idx: -1}
|
return ©ToRows{rows: rows, idx: -1}
|
||||||
}
|
}
|
||||||
@ -29,7 +29,8 @@ func (ctr *copyToRows) Err() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CopyToSource is the interface used by *Conn.CopyTo as the source for copy data.
|
// Deprecated. Use CopyFromSource instead. CopyToSource is the interface used by
|
||||||
|
// *Conn.CopyTo as the source for copy data.
|
||||||
type CopyToSource interface {
|
type CopyToSource interface {
|
||||||
// Next returns true if there is another row and makes the next row data
|
// Next returns true if there is another row and makes the next row data
|
||||||
// available to Values(). When there are no more rows available or an error
|
// available to Values(). When there are no more rows available or an error
|
||||||
@ -187,27 +188,6 @@ func (ct *copyTo) run() (int, error) {
|
|||||||
return sentCount, nil
|
return sentCount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) readUntilCopyInResponse() error {
|
|
||||||
for {
|
|
||||||
var t byte
|
|
||||||
var r *msgReader
|
|
||||||
t, r, err := c.rxMsg()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch t {
|
|
||||||
case copyInResponse:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
err = c.processContextFreeMsg(t, r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ct *copyTo) cancelCopyIn() error {
|
func (ct *copyTo) cancelCopyIn() error {
|
||||||
wbuf := newWriteBuf(ct.conn, copyFail)
|
wbuf := newWriteBuf(ct.conn, copyFail)
|
||||||
wbuf.WriteCString("client error: abort")
|
wbuf.WriteCString("client error: abort")
|
||||||
@ -221,8 +201,9 @@ func (ct *copyTo) cancelCopyIn() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CopyTo uses the PostgreSQL copy protocol to perform bulk data insertion.
|
// Deprecated. Use CopyFrom instead. CopyTo uses the PostgreSQL copy protocol to
|
||||||
// It returns the number of rows copied and an error.
|
// perform bulk data insertion. It returns the number of rows copied and an
|
||||||
|
// error.
|
||||||
//
|
//
|
||||||
// CopyTo requires all values use the binary format. Almost all types
|
// CopyTo requires all values use the binary format. Almost all types
|
||||||
// implemented by pgx use the binary format by default. Types implementing
|
// implemented by pgx use the binary format by default. Types implementing
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -228,27 +227,6 @@ func TestConnCopyToFailServerSideMidway(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
type failSource struct {
|
|
||||||
count int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fs *failSource) Next() bool {
|
|
||||||
time.Sleep(time.Millisecond * 100)
|
|
||||||
fs.count++
|
|
||||||
return fs.count < 100
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fs *failSource) Values() ([]interface{}, error) {
|
|
||||||
if fs.count == 3 {
|
|
||||||
return []interface{}{nil}, nil
|
|
||||||
}
|
|
||||||
return []interface{}{make([]byte, 100000)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fs *failSource) Err() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -303,28 +281,6 @@ func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientFailSource struct {
|
|
||||||
count int
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfs *clientFailSource) Next() bool {
|
|
||||||
cfs.count++
|
|
||||||
return cfs.count < 100
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfs *clientFailSource) Values() ([]interface{}, error) {
|
|
||||||
if cfs.count == 3 {
|
|
||||||
cfs.err = fmt.Errorf("client error")
|
|
||||||
return nil, cfs.err
|
|
||||||
}
|
|
||||||
return []interface{}{make([]byte, 100000)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfs *clientFailSource) Err() error {
|
|
||||||
return cfs.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) {
|
func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -368,23 +324,6 @@ func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientFinalErrSource struct {
|
|
||||||
count int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfs *clientFinalErrSource) Next() bool {
|
|
||||||
cfs.count++
|
|
||||||
return cfs.count < 5
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfs *clientFinalErrSource) Values() ([]interface{}, error) {
|
|
||||||
return []interface{}{make([]byte, 100000)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfs *clientFinalErrSource) Err() error {
|
|
||||||
return fmt.Errorf("final error")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) {
|
func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user