From 4e9a69643473f1b9e844a140116a3c4b64d333ed Mon Sep 17 00:00:00 2001 From: Murat Kabilov Date: Tue, 7 Aug 2018 23:44:02 +0300 Subject: [PATCH] addressing the comments add copy methods to the Tx struct --- conn_pool.go | 12 ++++++------ copy_from.go | 6 +++--- copy_from_test.go | 32 ++++++++++++++++---------------- copy_to.go | 2 +- copy_to_test.go | 18 +++++++++--------- tx.go | 19 +++++++++++++++++++ 6 files changed, 54 insertions(+), 35 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index 1e786a1f..b97ccb28 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -542,26 +542,26 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C return c.CopyFrom(tableName, columnNames, rowSrc) } -// CopyFromTextual acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) CopyFromTextual(r io.Reader, sql string, args ...interface{}) error { +// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection +func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error { c, err := p.Acquire() if err != nil { return err } defer p.Release(c) - return c.CopyFromTextual(r, sql, args...) + return c.CopyFromReader(r, sql) } -// CopyToTextual acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) CopyToTextual(w io.Writer, sql string, args ...interface{}) error { +// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection +func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { c, err := p.Acquire() if err != nil { return err } defer p.Release(c) - return c.CopyToTextual(w, sql, args...) + return c.CopyToWriter(w, sql, args...) } // BeginBatch acquires a connection and begins a batch on that connection. When diff --git a/copy_from.go b/copy_from.go index 4332c042..314d441f 100644 --- a/copy_from.go +++ b/copy_from.go @@ -283,9 +283,9 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF return ct.run() } -// CopyFromTextual uses the PostgreSQL textual format of the copy protocol -func (c *Conn) CopyFromTextual(r io.Reader, sql string, args ...interface{}) error { - if err := c.sendSimpleQuery(sql, args...); err != nil { +// CopyFromReader uses the PostgreSQL textual format of the copy protocol +func (c *Conn) CopyFromReader(r io.Reader, sql string) error { + if err := c.sendSimpleQuery(sql); err != nil { return err } diff --git a/copy_from_test.go b/copy_from_test.go index 67a0e7eb..0ed88b72 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -67,8 +67,8 @@ func TestConnCopyFromSmall(t *testing.T) { mustExec(t, conn, "truncate foo") - if err := conn.CopyFromTextual(inputReader, "copy foo from stdin"); err != nil { - t.Errorf("Unexpected error for CopyFromTextual: %v", err) + if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil { + t.Errorf("Unexpected error for CopyFromReader: %v", err) } rows, err = conn.Query("select * from foo") @@ -155,8 +155,8 @@ func TestConnCopyFromLarge(t *testing.T) { mustExec(t, conn, "truncate foo") - if err := conn.CopyFromTextual(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil { - t.Errorf("Unexpected error for CopyFromTextual: %v", err) + if err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil { + t.Errorf("Unexpected error for CopyFromReader: %v", err) } rows, err = conn.Query("select * from foo") @@ -239,7 +239,7 @@ func TestConnCopyFromJSON(t *testing.T) { mustExec(t, conn, "truncate foo") - if err := conn.CopyFromTextual(inputReader, "copy foo from stdin"); err != nil { + if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } @@ -343,12 +343,12 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { mustExec(t, conn, "truncate foo") - err = conn.CopyFromTextual(inputReader, "copy foo from stdin") + err = conn.CopyFromReader(inputReader, "copy foo from stdin") if err == nil { - t.Errorf("Expected CopyFromTextual return error, but it did not") + t.Errorf("Expected CopyFromReader return error, but it did not") } if _, ok := err.(pgx.PgError); !ok { - t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) + t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) } rows, err = conn.Query("select * from foo") @@ -600,7 +600,7 @@ func TestConnCopyFromCopyFromSourceNextPanic(t *testing.T) { } } -func TestConnCopyFromTextualQueryError(t *testing.T) { +func TestConnCopyFromReaderQueryError(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -608,19 +608,19 @@ func TestConnCopyFromTextualQueryError(t *testing.T) { inputReader := strings.NewReader("") - err := conn.CopyFromTextual(inputReader, "cropy foo from stdin") + err := conn.CopyFromReader(inputReader, "cropy foo from stdin") if err == nil { - t.Errorf("Expected CopyFromTextual return error, but it did not") + t.Errorf("Expected CopyFromReader return error, but it did not") } if _, ok := err.(pgx.PgError); !ok { - t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) + t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) } ensureConnValid(t, conn) } -func TestConnCopyFromTextualNoTableError(t *testing.T) { +func TestConnCopyFromReaderNoTableError(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -628,13 +628,13 @@ func TestConnCopyFromTextualNoTableError(t *testing.T) { inputReader := strings.NewReader("") - err := conn.CopyFromTextual(inputReader, "copy foo from stdin") + err := conn.CopyFromReader(inputReader, "copy foo from stdin") if err == nil { - t.Errorf("Expected CopyFromTextual return error, but it did not") + t.Errorf("Expected CopyFromReader return error, but it did not") } if _, ok := err.(pgx.PgError); !ok { - t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) + t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) } ensureConnValid(t, conn) diff --git a/copy_to.go b/copy_to.go index 2ea08810..0e11a6ed 100644 --- a/copy_to.go +++ b/copy_to.go @@ -25,7 +25,7 @@ func (c *Conn) readUntilCopyOutResponse() error { } } -func (c *Conn) CopyToTextual(w io.Writer, sql string, args ...interface{}) error { +func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { if err := c.sendSimpleQuery(sql, args...); err != nil { return err } diff --git a/copy_to_test.go b/copy_to_test.go index eab44280..5a479a8d 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -7,7 +7,7 @@ import ( "github.com/jackc/pgx" ) -func TestConnCopyToTextualSmall(t *testing.T) { +func TestConnCopyToWriterSmall(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -30,8 +30,8 @@ func TestConnCopyToTextualSmall(t *testing.T) { outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - if err := conn.CopyToTextual(outputWriter, "copy foo to stdout"); err != nil { - t.Errorf("Unexpected error for CopyToTextual: %v", err) + if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil { + t.Errorf("Unexpected error for CopyToWriter: %v", err) } if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 { @@ -41,7 +41,7 @@ func TestConnCopyToTextualSmall(t *testing.T) { ensureConnValid(t, conn) } -func TestConnCopyToTextualLarge(t *testing.T) { +func TestConnCopyToWriterLarge(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -66,7 +66,7 @@ func TestConnCopyToTextualLarge(t *testing.T) { outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - if err := conn.CopyToTextual(outputWriter, "copy foo to stdout"); err != nil { + if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } @@ -77,7 +77,7 @@ func TestConnCopyToTextualLarge(t *testing.T) { ensureConnValid(t, conn) } -func TestConnCopyToTextualQueryError(t *testing.T) { +func TestConnCopyToWriterQueryError(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -85,13 +85,13 @@ func TestConnCopyToTextualQueryError(t *testing.T) { outputWriter := bytes.NewBuffer(make([]byte, 0)) - err := conn.CopyToTextual(outputWriter, "cropy foo to stdout") + err := conn.CopyToWriter(outputWriter, "cropy foo to stdout") if err == nil { - t.Errorf("Expected CopyFromTextual return error, but it did not") + t.Errorf("Expected CopyFromReader return error, but it did not") } if _, ok := err.(pgx.PgError); !ok { - t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) + t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) } ensureConnValid(t, conn) diff --git a/tx.go b/tx.go index c7731dde..eb6b6805 100644 --- a/tx.go +++ b/tx.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "time" "github.com/pkg/errors" @@ -237,6 +238,24 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr return tx.conn.CopyFrom(tableName, columnNames, rowSrc) } +// CopyFromReader delegates to the underlying *Conn +func (tx *Tx) CopyFromReader(r io.Reader, sql string) error { + if tx.status != TxStatusInProgress { + return ErrTxClosed + } + + return tx.conn.CopyFromReader(r, sql) +} + +// CopyToWriter delegates to the underlying *Conn +func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { + if tx.status != TxStatusInProgress { + return ErrTxClosed + } + + return tx.conn.CopyToWriter(w, sql, args...) +} + // Status returns the status of the transaction from the set of // pgx.TxStatus* constants. func (tx *Tx) Status() int8 {