Add PgConn.CopyTo

pull/483/head
Jack Christensen 2019-01-19 14:49:39 -06:00
parent e97dbe1b22
commit 5907f222ee
2 changed files with 177 additions and 0 deletions

View File

@ -747,6 +747,71 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
return result
}
// CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
select {
case <-ctx.Done():
return "", ctx.Err()
case pgConn.controller <- pgConn:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
// Send copy to command
var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
cleanupContextDeadline()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
}
// Read results
var commandTag CommandTag
var pgErr error
for {
msg, err := pgConn.ReceiveMessage()
if err != nil {
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeout()
} else {
<-pgConn.controller
}
return "", preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
case *pgproto3.CopyDone:
case *pgproto3.CopyData:
_, err := w.Write(msg.Data)
if err != nil {
// This isn't actually a timeout, but we want the same behavior. Abort the request and cleanup.
cleanupContextDeadline()
go pgConn.recoverFromTimeout()
return "", err
}
case *pgproto3.ReadyForQuery:
<-pgConn.controller
return commandTag, pgErr
case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg)
}
}
}
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
pgConn *PgConn

View File

@ -1,6 +1,7 @@
package pgconn_test
import (
"bytes"
"context"
"crypto/tls"
"fmt"
@ -679,6 +680,117 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnCopyToSmall(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g json
)`).ReadAll()
require.Nil(t, err)
_, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll()
require.Nil(t, err)
_, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll()
require.Nil(t, err)
inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
require.Nil(t, err)
assert.Equal(t, int64(2), res.RowsAffected())
assert.Equal(t, inputBytes, outputWriter.Bytes())
ensureConnValid(t, pgConn)
}
func TestConnCopyToLarge(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g json,
h bytea
)`).ReadAll()
require.Nil(t, err)
inputBytes := make([]byte, 0)
for i := 0; i < 1000; i++ {
_, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll()
require.Nil(t, err)
inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
}
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
require.Nil(t, err)
assert.Equal(t, int64(1000), res.RowsAffected())
assert.Equal(t, inputBytes, outputWriter.Bytes())
ensureConnValid(t, pgConn)
}
func TestConnCopyToQueryError(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
outputWriter := bytes.NewBuffer(make([]byte, 0))
res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout")
require.Error(t, err)
assert.IsType(t, &pgconn.PgError{}, err)
assert.Equal(t, int64(0), res.RowsAffected())
ensureConnValid(t, pgConn)
}
func TestConnCopyToCanceled(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
outputWriter := &bytes.Buffer{}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.Equal(t, context.DeadlineExceeded, err)
assert.Equal(t, pgconn.CommandTag(""), res)
ensureConnValid(t, pgConn)
}
func Example() {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
if err != nil {