addressing the comments

add copy methods to the Tx struct
pull/438/head
Murat Kabilov 2018-08-07 23:44:02 +03:00
parent 5315995dfa
commit 4e9a696434
6 changed files with 54 additions and 35 deletions

View File

@ -542,26 +542,26 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
return c.CopyFrom(tableName, columnNames, rowSrc)
}
// CopyFromTextual acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFromTextual(r io.Reader, sql string, args ...interface{}) error {
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error {
c, err := p.Acquire()
if err != nil {
return err
}
defer p.Release(c)
return c.CopyFromTextual(r, sql, args...)
return c.CopyFromReader(r, sql)
}
// CopyToTextual acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyToTextual(w io.Writer, sql string, args ...interface{}) error {
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
c, err := p.Acquire()
if err != nil {
return err
}
defer p.Release(c)
return c.CopyToTextual(w, sql, args...)
return c.CopyToWriter(w, sql, args...)
}
// BeginBatch acquires a connection and begins a batch on that connection. When

View File

@ -283,9 +283,9 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
return ct.run()
}
// CopyFromTextual uses the PostgreSQL textual format of the copy protocol
func (c *Conn) CopyFromTextual(r io.Reader, sql string, args ...interface{}) error {
if err := c.sendSimpleQuery(sql, args...); err != nil {
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
if err := c.sendSimpleQuery(sql); err != nil {
return err
}

View File

@ -67,8 +67,8 @@ func TestConnCopyFromSmall(t *testing.T) {
mustExec(t, conn, "truncate foo")
if err := conn.CopyFromTextual(inputReader, "copy foo from stdin"); err != nil {
t.Errorf("Unexpected error for CopyFromTextual: %v", err)
if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err)
}
rows, err = conn.Query("select * from foo")
@ -155,8 +155,8 @@ func TestConnCopyFromLarge(t *testing.T) {
mustExec(t, conn, "truncate foo")
if err := conn.CopyFromTextual(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil {
t.Errorf("Unexpected error for CopyFromTextual: %v", err)
if err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err)
}
rows, err = conn.Query("select * from foo")
@ -239,7 +239,7 @@ func TestConnCopyFromJSON(t *testing.T) {
mustExec(t, conn, "truncate foo")
if err := conn.CopyFromTextual(inputReader, "copy foo from stdin"); err != nil {
if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
@ -343,12 +343,12 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
mustExec(t, conn, "truncate foo")
err = conn.CopyFromTextual(inputReader, "copy foo from stdin")
err = conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil {
t.Errorf("Expected CopyFromTextual return error, but it did not")
t.Errorf("Expected CopyFromReader return error, but it did not")
}
if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err)
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
rows, err = conn.Query("select * from foo")
@ -600,7 +600,7 @@ func TestConnCopyFromCopyFromSourceNextPanic(t *testing.T) {
}
}
func TestConnCopyFromTextualQueryError(t *testing.T) {
func TestConnCopyFromReaderQueryError(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -608,19 +608,19 @@ func TestConnCopyFromTextualQueryError(t *testing.T) {
inputReader := strings.NewReader("")
err := conn.CopyFromTextual(inputReader, "cropy foo from stdin")
err := conn.CopyFromReader(inputReader, "cropy foo from stdin")
if err == nil {
t.Errorf("Expected CopyFromTextual return error, but it did not")
t.Errorf("Expected CopyFromReader return error, but it did not")
}
if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err)
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
ensureConnValid(t, conn)
}
func TestConnCopyFromTextualNoTableError(t *testing.T) {
func TestConnCopyFromReaderNoTableError(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -628,13 +628,13 @@ func TestConnCopyFromTextualNoTableError(t *testing.T) {
inputReader := strings.NewReader("")
err := conn.CopyFromTextual(inputReader, "copy foo from stdin")
err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil {
t.Errorf("Expected CopyFromTextual return error, but it did not")
t.Errorf("Expected CopyFromReader return error, but it did not")
}
if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err)
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
ensureConnValid(t, conn)

View File

@ -25,7 +25,7 @@ func (c *Conn) readUntilCopyOutResponse() error {
}
}
func (c *Conn) CopyToTextual(w io.Writer, sql string, args ...interface{}) error {
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
if err := c.sendSimpleQuery(sql, args...); err != nil {
return err
}

View File

@ -7,7 +7,7 @@ import (
"github.com/jackc/pgx"
)
func TestConnCopyToTextualSmall(t *testing.T) {
func TestConnCopyToWriterSmall(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -30,8 +30,8 @@ func TestConnCopyToTextualSmall(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
if err := conn.CopyToTextual(outputWriter, "copy foo to stdout"); err != nil {
t.Errorf("Unexpected error for CopyToTextual: %v", err)
if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil {
t.Errorf("Unexpected error for CopyToWriter: %v", err)
}
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
@ -41,7 +41,7 @@ func TestConnCopyToTextualSmall(t *testing.T) {
ensureConnValid(t, conn)
}
func TestConnCopyToTextualLarge(t *testing.T) {
func TestConnCopyToWriterLarge(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -66,7 +66,7 @@ func TestConnCopyToTextualLarge(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
if err := conn.CopyToTextual(outputWriter, "copy foo to stdout"); err != nil {
if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
@ -77,7 +77,7 @@ func TestConnCopyToTextualLarge(t *testing.T) {
ensureConnValid(t, conn)
}
func TestConnCopyToTextualQueryError(t *testing.T) {
func TestConnCopyToWriterQueryError(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -85,13 +85,13 @@ func TestConnCopyToTextualQueryError(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0))
err := conn.CopyToTextual(outputWriter, "cropy foo to stdout")
err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
if err == nil {
t.Errorf("Expected CopyFromTextual return error, but it did not")
t.Errorf("Expected CopyFromReader return error, but it did not")
}
if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err)
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
ensureConnValid(t, conn)

19
tx.go
View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/pkg/errors"
@ -237,6 +238,24 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
return tx.conn.CopyFrom(tableName, columnNames, rowSrc)
}
// CopyFromReader delegates to the underlying *Conn
func (tx *Tx) CopyFromReader(r io.Reader, sql string) error {
if tx.status != TxStatusInProgress {
return ErrTxClosed
}
return tx.conn.CopyFromReader(r, sql)
}
// CopyToWriter delegates to the underlying *Conn
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
if tx.status != TxStatusInProgress {
return ErrTxClosed
}
return tx.conn.CopyToWriter(w, sql, args...)
}
// Status returns the status of the transaction from the set of
// pgx.TxStatus* constants.
func (tx *Tx) Status() int8 {