From 5315995dfa75238f83b5a93602f6d336770631eb Mon Sep 17 00:00:00 2001 From: Murat Kabilov Date: Mon, 30 Jul 2018 17:29:26 +0200 Subject: [PATCH] Add *Conn. CopyFromTextual, CopyToTextual, which use textual format for copying data --- conn_pool.go | 23 ++++++ copy_from.go | 56 ++++++++++++++ copy_from_test.go | 171 +++++++++++++++++++++++++++++++++++++++++- copy_to.go | 64 ++++++++++++++++ copy_to_test.go | 98 ++++++++++++++++++++++++ pgproto3/copy_done.go | 30 ++++++++ pgproto3/frontend.go | 3 + 7 files changed, 443 insertions(+), 2 deletions(-) create mode 100644 copy_to.go create mode 100644 copy_to_test.go create mode 100644 pgproto3/copy_done.go diff --git a/conn_pool.go b/conn_pool.go index 6ca0ee01..1e786a1f 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -2,6 +2,7 @@ package pgx import ( "context" + "io" "sync" "time" @@ -541,6 +542,28 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C return c.CopyFrom(tableName, columnNames, rowSrc) } +// CopyFromTextual acquires a connection, delegates the call to that connection, and releases the connection +func (p *ConnPool) CopyFromTextual(r io.Reader, sql string, args ...interface{}) error { + c, err := p.Acquire() + if err != nil { + return err + } + defer p.Release(c) + + return c.CopyFromTextual(r, sql, args...) +} + +// CopyToTextual acquires a connection, delegates the call to that connection, and releases the connection +func (p *ConnPool) CopyToTextual(w io.Writer, sql string, args ...interface{}) error { + c, err := p.Acquire() + if err != nil { + return err + } + defer p.Release(c) + + return c.CopyToTextual(w, sql, args...) +} + // BeginBatch acquires a connection and begins a batch on that connection. When // *Batch is finished, the connection is released automatically. func (p *ConnPool) BeginBatch() *Batch { diff --git a/copy_from.go b/copy_from.go index 13a80b50..4332c042 100644 --- a/copy_from.go +++ b/copy_from.go @@ -3,6 +3,7 @@ package pgx import ( "bytes" "fmt" + "io" "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" @@ -281,3 +282,58 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF return ct.run() } + +// CopyFromTextual uses the PostgreSQL textual format of the copy protocol +func (c *Conn) CopyFromTextual(r io.Reader, sql string, args ...interface{}) error { + if err := c.sendSimpleQuery(sql, args...); err != nil { + return err + } + + if err := c.readUntilCopyInResponse(); err != nil { + return err + } + buf := c.wbuf + + buf = append(buf, copyData) + sp := len(buf) + for { + n, err := r.Read(buf[5:cap(buf)]) + if err == io.EOF { + break + } + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) + + if _, err := c.conn.Write(buf); err != nil { + return err + } + } + + buf = buf[:0] + buf = append(buf, copyDone) + buf = pgio.AppendInt32(buf, 4) + + if _, err := c.conn.Write(buf); err != nil { + return err + } + + for { + msg, err := c.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) + return nil + case *pgproto3.CommandComplete: + case *pgproto3.ErrorResponse: + return c.rxErrorResponse(msg) + default: + return c.processContextFreeMsg(msg) + } + } + + return nil +} diff --git a/copy_from_test.go b/copy_from_test.go index 3fd6c78c..67a0e7eb 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "reflect" + "strings" "testing" "time" @@ -25,10 +26,14 @@ func TestConnCopyFromSmall(t *testing.T) { g timestamptz )`) + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + inputRows := [][]interface{}{ - {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } + inputReader := strings.NewReader("0\t1\t2\tabc\tefg\t2000-01-01\t" + tzedTime.Format(time.RFC3339Nano) + "\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) if err != nil { @@ -60,6 +65,34 @@ func TestConnCopyFromSmall(t *testing.T) { t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) } + mustExec(t, conn, "truncate foo") + + if err := conn.CopyFromTextual(inputReader, "copy foo from stdin"); err != nil { + t.Errorf("Unexpected error for CopyFromTextual: %v", err) + } + + rows, err = conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + outputRows = make([][]interface{}, 0) + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + ensureConnValid(t, conn) } @@ -80,10 +113,14 @@ func TestConnCopyFromLarge(t *testing.T) { h bytea )`) + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + inputRows := [][]interface{}{} + inputStringRows := "" for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}}) + inputStringRows += "0\t1\t2\tabc\tefg\t2000-01-01\t" + tzedTime.Format(time.RFC3339Nano) + "\toooo\n" } copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) @@ -116,6 +153,34 @@ func TestConnCopyFromLarge(t *testing.T) { t.Errorf("Input rows and output rows do not equal") } + mustExec(t, conn, "truncate foo") + + if err := conn.CopyFromTextual(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil { + t.Errorf("Unexpected error for CopyFromTextual: %v", err) + } + + rows, err = conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + outputRows = make([][]interface{}, 0) + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal") + } + ensureConnValid(t, conn) } @@ -140,6 +205,7 @@ func TestConnCopyFromJSON(t *testing.T) { {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, {nil, nil}, } + inputReader := strings.NewReader("{\"foo\":\"bar\"}\t{\"bar\":\"quz\"}\n\\N\t\\N\n") copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) if err != nil { @@ -171,6 +237,34 @@ func TestConnCopyFromJSON(t *testing.T) { t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) } + mustExec(t, conn, "truncate foo") + + if err := conn.CopyFromTextual(inputReader, "copy foo from stdin"); err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + + rows, err = conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + outputRows = make([][]interface{}, 0) + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + ensureConnValid(t, conn) } @@ -212,6 +306,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { {int32(2), nil}, // this row should trigger a failure {int32(3), "def"}, } + inputReader := strings.NewReader("1\tabc\n2\t\\N\n3\tdef\n") copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) if err == nil { @@ -246,6 +341,38 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { t.Errorf("Expected 0 rows, but got %v", outputRows) } + mustExec(t, conn, "truncate foo") + + err = conn.CopyFromTextual(inputReader, "copy foo from stdin") + if err == nil { + t.Errorf("Expected CopyFromTextual return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) + } + + rows, err = conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + outputRows = make([][]interface{}, 0) + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + ensureConnValid(t, conn) } @@ -472,3 +599,43 @@ func TestConnCopyFromCopyFromSourceNextPanic(t *testing.T) { t.Error("panic should have killed conn") } } + +func TestConnCopyFromTextualQueryError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + inputReader := strings.NewReader("") + + err := conn.CopyFromTextual(inputReader, "cropy foo from stdin") + if err == nil { + t.Errorf("Expected CopyFromTextual return error, but it did not") + } + + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromTextualNoTableError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + inputReader := strings.NewReader("") + + err := conn.CopyFromTextual(inputReader, "copy foo from stdin") + if err == nil { + t.Errorf("Expected CopyFromTextual return error, but it did not") + } + + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) + } + + ensureConnValid(t, conn) +} diff --git a/copy_to.go b/copy_to.go new file mode 100644 index 00000000..2ea08810 --- /dev/null +++ b/copy_to.go @@ -0,0 +1,64 @@ +package pgx + +import ( + "io" + + "github.com/jackc/pgx/pgproto3" +) + +func (c *Conn) readUntilCopyOutResponse() error { + for { + msg, err := c.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.CopyOutResponse: + return nil + default: + err = c.processContextFreeMsg(msg) + if err != nil { + return err + } + } + } +} + +func (c *Conn) CopyToTextual(w io.Writer, sql string, args ...interface{}) error { + if err := c.sendSimpleQuery(sql, args...); err != nil { + return err + } + + if err := c.readUntilCopyOutResponse(); err != nil { + return err + } + + for { + msg, err := c.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.CopyDone: + break + case *pgproto3.CopyData: + _, err := w.Write(msg.Data) + if err != nil { + c.die(err) + return err + } + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) + return nil + case *pgproto3.CommandComplete: + case *pgproto3.ErrorResponse: + return c.rxErrorResponse(msg) + default: + return c.processContextFreeMsg(msg) + } + } + + return nil +} diff --git a/copy_to_test.go b/copy_to_test.go new file mode 100644 index 00000000..eab44280 --- /dev/null +++ b/copy_to_test.go @@ -0,0 +1,98 @@ +package pgx_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx" +) + +func TestConnCopyToTextualSmall(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`) + mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`) + mustExec(t, conn, `insert into foo values (null, null, null, null, null, null, null)`) + + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + if err := conn.CopyToTextual(outputWriter, "copy foo to stdout"); err != nil { + t.Errorf("Unexpected error for CopyToTextual: %v", err) + } + + if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 { + t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes())) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToTextualLarge(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`) + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + mustExec(t, conn, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + if err := conn.CopyToTextual(outputWriter, "copy foo to stdout"); err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + + if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 { + t.Errorf("Input rows and output rows do not equal") + } + + ensureConnValid(t, conn) +} + +func TestConnCopyToTextualQueryError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + outputWriter := bytes.NewBuffer(make([]byte, 0)) + + err := conn.CopyToTextual(outputWriter, "cropy foo to stdout") + if err == nil { + t.Errorf("Expected CopyFromTextual return error, but it did not") + } + + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) + } + + ensureConnValid(t, conn) +} diff --git a/pgproto3/copy_done.go b/pgproto3/copy_done.go new file mode 100644 index 00000000..92481908 --- /dev/null +++ b/pgproto3/copy_done.go @@ -0,0 +1,30 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CopyDone struct { +} + +func (*CopyDone) Backend() {} + +func (dst *CopyDone) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *CopyDone) Encode(dst []byte) []byte { + return append(dst, 'c', 0, 0, 0, 4) +} + +func (src *CopyDone) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CopyDone", + }) +} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index d803d362..d1541c74 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -22,6 +22,7 @@ type Frontend struct { copyData CopyData copyInResponse CopyInResponse copyOutResponse CopyOutResponse + copyDone CopyDone dataRow DataRow emptyQueryResponse EmptyQueryResponse errorResponse ErrorResponse @@ -72,6 +73,8 @@ func (b *Frontend) Receive() (BackendMessage, error) { msg = &b.closeComplete case 'A': msg = &b.notificationResponse + case 'c': + msg = &b.copyDone case 'C': msg = &b.commandComplete case 'd':