addressing the comments

add copy methods to the Tx struct
pull/438/head
Murat Kabilov 2018-08-07 23:44:02 +03:00
parent 5315995dfa
commit 4e9a696434
6 changed files with 54 additions and 35 deletions

View File

@ -542,26 +542,26 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
return c.CopyFrom(tableName, columnNames, rowSrc) return c.CopyFrom(tableName, columnNames, rowSrc)
} }
// CopyFromTextual acquires a connection, delegates the call to that connection, and releases the connection // CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFromTextual(r io.Reader, sql string, args ...interface{}) error { func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error {
c, err := p.Acquire() c, err := p.Acquire()
if err != nil { if err != nil {
return err return err
} }
defer p.Release(c) defer p.Release(c)
return c.CopyFromTextual(r, sql, args...) return c.CopyFromReader(r, sql)
} }
// CopyToTextual acquires a connection, delegates the call to that connection, and releases the connection // CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyToTextual(w io.Writer, sql string, args ...interface{}) error { func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
c, err := p.Acquire() c, err := p.Acquire()
if err != nil { if err != nil {
return err return err
} }
defer p.Release(c) defer p.Release(c)
return c.CopyToTextual(w, sql, args...) return c.CopyToWriter(w, sql, args...)
} }
// BeginBatch acquires a connection and begins a batch on that connection. When // BeginBatch acquires a connection and begins a batch on that connection. When

View File

@ -283,9 +283,9 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
return ct.run() return ct.run()
} }
// CopyFromTextual uses the PostgreSQL textual format of the copy protocol // CopyFromReader uses the PostgreSQL textual format of the copy protocol
func (c *Conn) CopyFromTextual(r io.Reader, sql string, args ...interface{}) error { func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
if err := c.sendSimpleQuery(sql, args...); err != nil { if err := c.sendSimpleQuery(sql); err != nil {
return err return err
} }

View File

@ -67,8 +67,8 @@ func TestConnCopyFromSmall(t *testing.T) {
mustExec(t, conn, "truncate foo") mustExec(t, conn, "truncate foo")
if err := conn.CopyFromTextual(inputReader, "copy foo from stdin"); err != nil { if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil {
t.Errorf("Unexpected error for CopyFromTextual: %v", err) t.Errorf("Unexpected error for CopyFromReader: %v", err)
} }
rows, err = conn.Query("select * from foo") rows, err = conn.Query("select * from foo")
@ -155,8 +155,8 @@ func TestConnCopyFromLarge(t *testing.T) {
mustExec(t, conn, "truncate foo") mustExec(t, conn, "truncate foo")
if err := conn.CopyFromTextual(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil { if err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil {
t.Errorf("Unexpected error for CopyFromTextual: %v", err) t.Errorf("Unexpected error for CopyFromReader: %v", err)
} }
rows, err = conn.Query("select * from foo") rows, err = conn.Query("select * from foo")
@ -239,7 +239,7 @@ func TestConnCopyFromJSON(t *testing.T) {
mustExec(t, conn, "truncate foo") mustExec(t, conn, "truncate foo")
if err := conn.CopyFromTextual(inputReader, "copy foo from stdin"); err != nil { if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err) t.Errorf("Unexpected error for CopyFrom: %v", err)
} }
@ -343,12 +343,12 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
mustExec(t, conn, "truncate foo") mustExec(t, conn, "truncate foo")
err = conn.CopyFromTextual(inputReader, "copy foo from stdin") err = conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil { if err == nil {
t.Errorf("Expected CopyFromTextual return error, but it did not") t.Errorf("Expected CopyFromReader return error, but it did not")
} }
if _, ok := err.(pgx.PgError); !ok { if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
} }
rows, err = conn.Query("select * from foo") rows, err = conn.Query("select * from foo")
@ -600,7 +600,7 @@ func TestConnCopyFromCopyFromSourceNextPanic(t *testing.T) {
} }
} }
func TestConnCopyFromTextualQueryError(t *testing.T) { func TestConnCopyFromReaderQueryError(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
@ -608,19 +608,19 @@ func TestConnCopyFromTextualQueryError(t *testing.T) {
inputReader := strings.NewReader("") inputReader := strings.NewReader("")
err := conn.CopyFromTextual(inputReader, "cropy foo from stdin") err := conn.CopyFromReader(inputReader, "cropy foo from stdin")
if err == nil { if err == nil {
t.Errorf("Expected CopyFromTextual return error, but it did not") t.Errorf("Expected CopyFromReader return error, but it did not")
} }
if _, ok := err.(pgx.PgError); !ok { if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
} }
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestConnCopyFromTextualNoTableError(t *testing.T) { func TestConnCopyFromReaderNoTableError(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
@ -628,13 +628,13 @@ func TestConnCopyFromTextualNoTableError(t *testing.T) {
inputReader := strings.NewReader("") inputReader := strings.NewReader("")
err := conn.CopyFromTextual(inputReader, "copy foo from stdin") err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil { if err == nil {
t.Errorf("Expected CopyFromTextual return error, but it did not") t.Errorf("Expected CopyFromReader return error, but it did not")
} }
if _, ok := err.(pgx.PgError); !ok { if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
} }
ensureConnValid(t, conn) ensureConnValid(t, conn)

View File

@ -25,7 +25,7 @@ func (c *Conn) readUntilCopyOutResponse() error {
} }
} }
func (c *Conn) CopyToTextual(w io.Writer, sql string, args ...interface{}) error { func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
if err := c.sendSimpleQuery(sql, args...); err != nil { if err := c.sendSimpleQuery(sql, args...); err != nil {
return err return err
} }

View File

@ -7,7 +7,7 @@ import (
"github.com/jackc/pgx" "github.com/jackc/pgx"
) )
func TestConnCopyToTextualSmall(t *testing.T) { func TestConnCopyToWriterSmall(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
@ -30,8 +30,8 @@ func TestConnCopyToTextualSmall(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
if err := conn.CopyToTextual(outputWriter, "copy foo to stdout"); err != nil { if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil {
t.Errorf("Unexpected error for CopyToTextual: %v", err) t.Errorf("Unexpected error for CopyToWriter: %v", err)
} }
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 { if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
@ -41,7 +41,7 @@ func TestConnCopyToTextualSmall(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestConnCopyToTextualLarge(t *testing.T) { func TestConnCopyToWriterLarge(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
@ -66,7 +66,7 @@ func TestConnCopyToTextualLarge(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
if err := conn.CopyToTextual(outputWriter, "copy foo to stdout"); err != nil { if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err) t.Errorf("Unexpected error for CopyFrom: %v", err)
} }
@ -77,7 +77,7 @@ func TestConnCopyToTextualLarge(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestConnCopyToTextualQueryError(t *testing.T) { func TestConnCopyToWriterQueryError(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
@ -85,13 +85,13 @@ func TestConnCopyToTextualQueryError(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0)) outputWriter := bytes.NewBuffer(make([]byte, 0))
err := conn.CopyToTextual(outputWriter, "cropy foo to stdout") err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
if err == nil { if err == nil {
t.Errorf("Expected CopyFromTextual return error, but it did not") t.Errorf("Expected CopyFromReader return error, but it did not")
} }
if _, ok := err.(pgx.PgError); !ok { if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromTextual return pgx.PgError, but instead it returned: %v", err) t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
} }
ensureConnValid(t, conn) ensureConnValid(t, conn)

19
tx.go
View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -237,6 +238,24 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
return tx.conn.CopyFrom(tableName, columnNames, rowSrc) return tx.conn.CopyFrom(tableName, columnNames, rowSrc)
} }
// CopyFromReader delegates to the underlying *Conn
func (tx *Tx) CopyFromReader(r io.Reader, sql string) error {
if tx.status != TxStatusInProgress {
return ErrTxClosed
}
return tx.conn.CopyFromReader(r, sql)
}
// CopyToWriter delegates to the underlying *Conn
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
if tx.status != TxStatusInProgress {
return ErrTxClosed
}
return tx.conn.CopyToWriter(w, sql, args...)
}
// Status returns the status of the transaction from the set of // Status returns the status of the transaction from the set of
// pgx.TxStatus* constants. // pgx.TxStatus* constants.
func (tx *Tx) Status() int8 { func (tx *Tx) Status() int8 {