diff --git a/conn_pool.go b/conn_pool.go index 947ebe1c..068a6886 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -553,10 +553,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C } // CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error { +func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) { c, err := p.Acquire() if err != nil { - return err + return "", err } defer p.Release(c) @@ -564,10 +564,10 @@ func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error { } // CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { +func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) { c, err := p.Acquire() if err != nil { - return err + return "", err } defer p.Release(c) diff --git a/copy_from.go b/copy_from.go index a4d4d91c..1e9a3c77 100644 --- a/copy_from.go +++ b/copy_from.go @@ -284,13 +284,13 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF } // CopyFromReader uses the PostgreSQL textual format of the copy protocol -func (c *Conn) CopyFromReader(r io.Reader, sql string) error { +func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) { if err := c.sendSimpleQuery(sql); err != nil { - return err + return "", err } if err := c.readUntilCopyInResponse(); err != nil { - return err + return "", err } buf := c.wbuf @@ -305,7 +305,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error { pgio.SetInt32(buf[sp:], int32(n+4)) if _, err := c.BaseConn.NetConn.Write(buf); err != nil { - return err + return "", err } } @@ -314,26 +314,25 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error { buf = pgio.AppendInt32(buf, 4) if _, err := c.BaseConn.NetConn.Write(buf); err != nil { - return err + return "", err } for { msg, err := c.rxMsg() if err != nil { - return err + return "", err } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: c.rxReadyForQuery(msg) - return nil + return "", err case *pgproto3.CommandComplete: + return CommandTag(msg.CommandTag), nil case *pgproto3.ErrorResponse: - return c.rxErrorResponse(msg) + return "", c.rxErrorResponse(msg) default: - return c.processContextFreeMsg(msg) + return "", c.processContextFreeMsg(msg) } } - - return nil } diff --git a/copy_from_test.go b/copy_from_test.go index 4c239b05..df9503dc 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -72,9 +72,14 @@ func TestConnCopyFromSmall(t *testing.T) { mustExec(t, conn, "truncate foo") - if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil { + res, err := conn.CopyFromReader(inputReader, "copy foo from stdin") + if err != nil { t.Errorf("Unexpected error for CopyFromReader: %v", err) } + copyCount = int(res.RowsAffected()) + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount) + } rows, err = conn.Query("select * from foo") if err != nil { @@ -160,9 +165,14 @@ func TestConnCopyFromLarge(t *testing.T) { mustExec(t, conn, "truncate foo") - if err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil { + res, err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin") + if err != nil { t.Errorf("Unexpected error for CopyFromReader: %v", err) } + copyCount = int(res.RowsAffected()) + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount) + } rows, err = conn.Query("select * from foo") if err != nil { @@ -244,9 +254,14 @@ func TestConnCopyFromJSON(t *testing.T) { mustExec(t, conn, "truncate foo") - if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil { + res, err := conn.CopyFromReader(inputReader, "copy foo from stdin") + if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } + copyCount = int(res.RowsAffected()) + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount) + } rows, err = conn.Query("select * from foo") if err != nil { @@ -348,13 +363,17 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) { mustExec(t, conn, "truncate foo") - err = conn.CopyFromReader(inputReader, "copy foo from stdin") + res, err := conn.CopyFromReader(inputReader, "copy foo from stdin") if err == nil { t.Errorf("Expected CopyFromReader return error, but it did not") } if _, ok := err.(pgx.PgError); !ok { t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) } + copyCount = int(res.RowsAffected()) + if copyCount != 0 { + t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount) + } rows, err = conn.Query("select * from foo") if err != nil { @@ -613,7 +632,7 @@ func TestConnCopyFromReaderQueryError(t *testing.T) { inputReader := strings.NewReader("") - err := conn.CopyFromReader(inputReader, "cropy foo from stdin") + res, err := conn.CopyFromReader(inputReader, "cropy foo from stdin") if err == nil { t.Errorf("Expected CopyFromReader return error, but it did not") } @@ -622,6 +641,11 @@ func TestConnCopyFromReaderQueryError(t *testing.T) { t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) } + copyCount := int(res.RowsAffected()) + if copyCount != 0 { + t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount) + } + ensureConnValid(t, conn) } @@ -633,7 +657,7 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) { inputReader := strings.NewReader("") - err := conn.CopyFromReader(inputReader, "copy foo from stdin") + res, err := conn.CopyFromReader(inputReader, "copy foo from stdin") if err == nil { t.Errorf("Expected CopyFromReader return error, but it did not") } @@ -642,6 +666,11 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) { t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) } + copyCount := int(res.RowsAffected()) + if copyCount != 0 { + t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount) + } + ensureConnValid(t, conn) } @@ -688,11 +717,16 @@ func TestConnCopyFromGzipReader(t *testing.T) { t.Fatalf("Unexpected error for gzip.NewReader: %v", err) } - err = conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + res, err := conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)") if err != nil { t.Errorf("Unexpected error for CopyFromReader: %v", err) } + copyCount := int(res.RowsAffected()) + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFromReader to return 1000 copied rows, but got %d", copyCount) + } + err = gr.Close() if err != nil { t.Errorf("Unexpected error for gr.Close: %v", err) diff --git a/copy_to.go b/copy_to.go index 0e11a6ed..c85a0aa8 100644 --- a/copy_to.go +++ b/copy_to.go @@ -25,19 +25,19 @@ func (c *Conn) readUntilCopyOutResponse() error { } } -func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { +func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) { if err := c.sendSimpleQuery(sql, args...); err != nil { - return err + return "", err } if err := c.readUntilCopyOutResponse(); err != nil { - return err + return "", err } for { msg, err := c.rxMsg() if err != nil { - return err + return "", err } switch msg := msg.(type) { @@ -47,18 +47,17 @@ func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error _, err := w.Write(msg.Data) if err != nil { c.die(err) - return err + return "", err } case *pgproto3.ReadyForQuery: c.rxReadyForQuery(msg) - return nil + return "", nil case *pgproto3.CommandComplete: + return CommandTag(msg.CommandTag), nil case *pgproto3.ErrorResponse: - return c.rxErrorResponse(msg) + return "", c.rxErrorResponse(msg) default: - return c.processContextFreeMsg(msg) + return "", c.processContextFreeMsg(msg) } } - - return nil } diff --git a/copy_to_test.go b/copy_to_test.go index 5a479a8d..cff837cb 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -30,10 +30,16 @@ func TestConnCopyToWriterSmall(t *testing.T) { outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil { + res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout") + if err != nil { t.Errorf("Unexpected error for CopyToWriter: %v", err) } + copyCount := int(res.RowsAffected()) + if copyCount != 2 { + t.Errorf("Expected CopyToWriter to return 2 copied rows, but got %d", copyCount) + } + if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 { t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes())) } @@ -66,10 +72,16 @@ func TestConnCopyToWriterLarge(t *testing.T) { outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil { + res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout") + if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } + copyCount := int(res.RowsAffected()) + if copyCount != 1000 { + t.Errorf("Expected CopyToWriter to return 1 copied rows, but got %d", copyCount) + } + if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 { t.Errorf("Input rows and output rows do not equal") } @@ -85,13 +97,18 @@ func TestConnCopyToWriterQueryError(t *testing.T) { outputWriter := bytes.NewBuffer(make([]byte, 0)) - err := conn.CopyToWriter(outputWriter, "cropy foo to stdout") + res, err := conn.CopyToWriter(outputWriter, "cropy foo to stdout") if err == nil { - t.Errorf("Expected CopyFromReader return error, but it did not") + t.Errorf("Expected CopyToWriter return error, but it did not") } if _, ok := err.(pgx.PgError); !ok { - t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) + t.Errorf("Expected CopyToWriter return pgx.PgError, but instead it returned: %v", err) + } + + copyCount := int(res.RowsAffected()) + if copyCount != 0 { + t.Errorf("Expected CopyToWriter to return 0 copied rows, but got %d", copyCount) } ensureConnValid(t, conn) diff --git a/tx.go b/tx.go index 611d3f9f..123f82b9 100644 --- a/tx.go +++ b/tx.go @@ -240,18 +240,18 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr } // CopyFromReader delegates to the underlying *Conn -func (tx *Tx) CopyFromReader(r io.Reader, sql string) error { +func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag CommandTag, err error) { if tx.status != TxStatusInProgress { - return ErrTxClosed + return CommandTag(""), ErrTxClosed } return tx.conn.CopyFromReader(r, sql) } // CopyToWriter delegates to the underlying *Conn -func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { +func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag CommandTag, err error) { if tx.status != TxStatusInProgress { - return ErrTxClosed + return CommandTag(""), ErrTxClosed } return tx.conn.CopyToWriter(w, sql, args...)