Merge branch 'nvorobev-rows-affected-for-copyto-copyfrom' into v4-experimental

v4-experimental
Jack Christensen 2018-12-15 17:26:03 -06:00
commit 25f21a597c
6 changed files with 90 additions and 41 deletions

View File

@ -553,10 +553,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
}
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error {
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
c, err := p.Acquire()
if err != nil {
return err
return "", err
}
defer p.Release(c)
@ -564,10 +564,10 @@ func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error {
}
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
c, err := p.Acquire()
if err != nil {
return err
return "", err
}
defer p.Release(c)

View File

@ -284,13 +284,13 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
}
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
if err := c.sendSimpleQuery(sql); err != nil {
return err
return "", err
}
if err := c.readUntilCopyInResponse(); err != nil {
return err
return "", err
}
buf := c.wbuf
@ -305,7 +305,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
pgio.SetInt32(buf[sp:], int32(n+4))
if _, err := c.BaseConn.NetConn.Write(buf); err != nil {
return err
return "", err
}
}
@ -314,26 +314,25 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
buf = pgio.AppendInt32(buf, 4)
if _, err := c.BaseConn.NetConn.Write(buf); err != nil {
return err
return "", err
}
for {
msg, err := c.rxMsg()
if err != nil {
return err
return "", err
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
return nil
return "", err
case *pgproto3.CommandComplete:
return CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse:
return c.rxErrorResponse(msg)
return "", c.rxErrorResponse(msg)
default:
return c.processContextFreeMsg(msg)
return "", c.processContextFreeMsg(msg)
}
}
return nil
}

View File

@ -72,9 +72,14 @@ func TestConnCopyFromSmall(t *testing.T) {
mustExec(t, conn, "truncate foo")
if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil {
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err)
}
copyCount = int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err = conn.Query("select * from foo")
if err != nil {
@ -160,9 +165,14 @@ func TestConnCopyFromLarge(t *testing.T) {
mustExec(t, conn, "truncate foo")
if err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil {
res, err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin")
if err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err)
}
copyCount = int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err = conn.Query("select * from foo")
if err != nil {
@ -244,9 +254,14 @@ func TestConnCopyFromJSON(t *testing.T) {
mustExec(t, conn, "truncate foo")
if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil {
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
copyCount = int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err = conn.Query("select * from foo")
if err != nil {
@ -348,13 +363,17 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
mustExec(t, conn, "truncate foo")
err = conn.CopyFromReader(inputReader, "copy foo from stdin")
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not")
}
if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
copyCount = int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
}
rows, err = conn.Query("select * from foo")
if err != nil {
@ -613,7 +632,7 @@ func TestConnCopyFromReaderQueryError(t *testing.T) {
inputReader := strings.NewReader("")
err := conn.CopyFromReader(inputReader, "cropy foo from stdin")
res, err := conn.CopyFromReader(inputReader, "cropy foo from stdin")
if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not")
}
@ -622,6 +641,11 @@ func TestConnCopyFromReaderQueryError(t *testing.T) {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn)
}
@ -633,7 +657,7 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) {
inputReader := strings.NewReader("")
err := conn.CopyFromReader(inputReader, "copy foo from stdin")
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not")
}
@ -642,6 +666,11 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn)
}
@ -688,11 +717,16 @@ func TestConnCopyFromGzipReader(t *testing.T) {
t.Fatalf("Unexpected error for gzip.NewReader: %v", err)
}
err = conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
res, err := conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
if err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return 1000 copied rows, but got %d", copyCount)
}
err = gr.Close()
if err != nil {
t.Errorf("Unexpected error for gr.Close: %v", err)

View File

@ -25,19 +25,19 @@ func (c *Conn) readUntilCopyOutResponse() error {
}
}
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
if err := c.sendSimpleQuery(sql, args...); err != nil {
return err
return "", err
}
if err := c.readUntilCopyOutResponse(); err != nil {
return err
return "", err
}
for {
msg, err := c.rxMsg()
if err != nil {
return err
return "", err
}
switch msg := msg.(type) {
@ -47,18 +47,17 @@ func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error
_, err := w.Write(msg.Data)
if err != nil {
c.die(err)
return err
return "", err
}
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
return nil
return "", nil
case *pgproto3.CommandComplete:
return CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse:
return c.rxErrorResponse(msg)
return "", c.rxErrorResponse(msg)
default:
return c.processContextFreeMsg(msg)
return "", c.processContextFreeMsg(msg)
}
}
return nil
}

View File

@ -30,10 +30,16 @@ func TestConnCopyToWriterSmall(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil {
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
if err != nil {
t.Errorf("Unexpected error for CopyToWriter: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 2 {
t.Errorf("Expected CopyToWriter to return 2 copied rows, but got %d", copyCount)
}
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes()))
}
@ -66,10 +72,16 @@ func TestConnCopyToWriterLarge(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil {
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 1000 {
t.Errorf("Expected CopyToWriter to return 1 copied rows, but got %d", copyCount)
}
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
t.Errorf("Input rows and output rows do not equal")
}
@ -85,13 +97,18 @@ func TestConnCopyToWriterQueryError(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0))
err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
res, err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not")
t.Errorf("Expected CopyToWriter return error, but it did not")
}
if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
t.Errorf("Expected CopyToWriter return pgx.PgError, but instead it returned: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyToWriter to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn)

8
tx.go
View File

@ -240,18 +240,18 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
}
// CopyFromReader delegates to the underlying *Conn
func (tx *Tx) CopyFromReader(r io.Reader, sql string) error {
func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag CommandTag, err error) {
if tx.status != TxStatusInProgress {
return ErrTxClosed
return CommandTag(""), ErrTxClosed
}
return tx.conn.CopyFromReader(r, sql)
}
// CopyToWriter delegates to the underlying *Conn
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag CommandTag, err error) {
if tx.status != TxStatusInProgress {
return ErrTxClosed
return CommandTag(""), ErrTxClosed
}
return tx.conn.CopyToWriter(w, sql, args...)