Adds RowsAffected for CopyToWriter and CopyFromReader

pull/493/head
Nikolay Vorobev 2018-12-12 18:41:47 +03:00
parent c59c9cac59
commit a0331e7409
6 changed files with 90 additions and 41 deletions

View File

@ -553,10 +553,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
} }
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection // CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error { func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
c, err := p.Acquire() c, err := p.Acquire()
if err != nil { if err != nil {
return err return "", err
} }
defer p.Release(c) defer p.Release(c)
@ -564,10 +564,10 @@ func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error {
} }
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection // CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
c, err := p.Acquire() c, err := p.Acquire()
if err != nil { if err != nil {
return err return "", err
} }
defer p.Release(c) defer p.Release(c)

View File

@ -284,13 +284,13 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
} }
// CopyFromReader uses the PostgreSQL textual format of the copy protocol // CopyFromReader uses the PostgreSQL textual format of the copy protocol
func (c *Conn) CopyFromReader(r io.Reader, sql string) error { func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
if err := c.sendSimpleQuery(sql); err != nil { if err := c.sendSimpleQuery(sql); err != nil {
return err return "", err
} }
if err := c.readUntilCopyInResponse(); err != nil { if err := c.readUntilCopyInResponse(); err != nil {
return err return "", err
} }
buf := c.wbuf buf := c.wbuf
@ -305,7 +305,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
pgio.SetInt32(buf[sp:], int32(n+4)) pgio.SetInt32(buf[sp:], int32(n+4))
if _, err := c.conn.Write(buf); err != nil { if _, err := c.conn.Write(buf); err != nil {
return err return "", err
} }
} }
@ -314,26 +314,25 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
buf = pgio.AppendInt32(buf, 4) buf = pgio.AppendInt32(buf, 4)
if _, err := c.conn.Write(buf); err != nil { if _, err := c.conn.Write(buf); err != nil {
return err return "", err
} }
for { for {
msg, err := c.rxMsg() msg, err := c.rxMsg()
if err != nil { if err != nil {
return err return "", err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg) c.rxReadyForQuery(msg)
return nil return "", err
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
return CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
return c.rxErrorResponse(msg) return "", c.rxErrorResponse(msg)
default: default:
return c.processContextFreeMsg(msg) return "", c.processContextFreeMsg(msg)
} }
} }
return nil
} }

View File

@ -72,9 +72,14 @@ func TestConnCopyFromSmall(t *testing.T) {
mustExec(t, conn, "truncate foo") mustExec(t, conn, "truncate foo")
if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil { res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err) t.Errorf("Unexpected error for CopyFromReader: %v", err)
} }
copyCount = int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err = conn.Query("select * from foo") rows, err = conn.Query("select * from foo")
if err != nil { if err != nil {
@ -160,9 +165,14 @@ func TestConnCopyFromLarge(t *testing.T) {
mustExec(t, conn, "truncate foo") mustExec(t, conn, "truncate foo")
if err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil { res, err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin")
if err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err) t.Errorf("Unexpected error for CopyFromReader: %v", err)
} }
copyCount = int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err = conn.Query("select * from foo") rows, err = conn.Query("select * from foo")
if err != nil { if err != nil {
@ -244,9 +254,14 @@ func TestConnCopyFromJSON(t *testing.T) {
mustExec(t, conn, "truncate foo") mustExec(t, conn, "truncate foo")
if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil { res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err) t.Errorf("Unexpected error for CopyFrom: %v", err)
} }
copyCount = int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err = conn.Query("select * from foo") rows, err = conn.Query("select * from foo")
if err != nil { if err != nil {
@ -348,13 +363,17 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
mustExec(t, conn, "truncate foo") mustExec(t, conn, "truncate foo")
err = conn.CopyFromReader(inputReader, "copy foo from stdin") res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil { if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not") t.Errorf("Expected CopyFromReader return error, but it did not")
} }
if _, ok := err.(pgx.PgError); !ok { if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
} }
copyCount = int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
}
rows, err = conn.Query("select * from foo") rows, err = conn.Query("select * from foo")
if err != nil { if err != nil {
@ -613,7 +632,7 @@ func TestConnCopyFromReaderQueryError(t *testing.T) {
inputReader := strings.NewReader("") inputReader := strings.NewReader("")
err := conn.CopyFromReader(inputReader, "cropy foo from stdin") res, err := conn.CopyFromReader(inputReader, "cropy foo from stdin")
if err == nil { if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not") t.Errorf("Expected CopyFromReader return error, but it did not")
} }
@ -622,6 +641,11 @@ func TestConnCopyFromReaderQueryError(t *testing.T) {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
} }
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
@ -633,7 +657,7 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) {
inputReader := strings.NewReader("") inputReader := strings.NewReader("")
err := conn.CopyFromReader(inputReader, "copy foo from stdin") res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
if err == nil { if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not") t.Errorf("Expected CopyFromReader return error, but it did not")
} }
@ -642,6 +666,11 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
} }
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
@ -688,11 +717,16 @@ func TestConnCopyFromGzipReader(t *testing.T) {
t.Fatalf("Unexpected error for gzip.NewReader: %v", err) t.Fatalf("Unexpected error for gzip.NewReader: %v", err)
} }
err = conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)") res, err := conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
if err != nil { if err != nil {
t.Errorf("Unexpected error for CopyFromReader: %v", err) t.Errorf("Unexpected error for CopyFromReader: %v", err)
} }
copyCount := int(res.RowsAffected())
if copyCount != len(inputRows) {
t.Errorf("Expected CopyFromReader to return 1000 copied rows, but got %d", copyCount)
}
err = gr.Close() err = gr.Close()
if err != nil { if err != nil {
t.Errorf("Unexpected error for gr.Close: %v", err) t.Errorf("Unexpected error for gr.Close: %v", err)

View File

@ -25,19 +25,19 @@ func (c *Conn) readUntilCopyOutResponse() error {
} }
} }
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
if err := c.sendSimpleQuery(sql, args...); err != nil { if err := c.sendSimpleQuery(sql, args...); err != nil {
return err return "", err
} }
if err := c.readUntilCopyOutResponse(); err != nil { if err := c.readUntilCopyOutResponse(); err != nil {
return err return "", err
} }
for { for {
msg, err := c.rxMsg() msg, err := c.rxMsg()
if err != nil { if err != nil {
return err return "", err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@ -47,18 +47,17 @@ func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error
_, err := w.Write(msg.Data) _, err := w.Write(msg.Data)
if err != nil { if err != nil {
c.die(err) c.die(err)
return err return "", err
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg) c.rxReadyForQuery(msg)
return nil return "", nil
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
return CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
return c.rxErrorResponse(msg) return "", c.rxErrorResponse(msg)
default: default:
return c.processContextFreeMsg(msg) return "", c.processContextFreeMsg(msg)
} }
} }
return nil
} }

View File

@ -30,10 +30,16 @@ func TestConnCopyToWriterSmall(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil { res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
if err != nil {
t.Errorf("Unexpected error for CopyToWriter: %v", err) t.Errorf("Unexpected error for CopyToWriter: %v", err)
} }
copyCount := int(res.RowsAffected())
if copyCount != 2 {
t.Errorf("Expected CopyToWriter to return 2 copied rows, but got %d", copyCount)
}
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 { if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes())) t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes()))
} }
@ -66,10 +72,16 @@ func TestConnCopyToWriterLarge(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil { res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err) t.Errorf("Unexpected error for CopyFrom: %v", err)
} }
copyCount := int(res.RowsAffected())
if copyCount != 1000 {
t.Errorf("Expected CopyToWriter to return 1 copied rows, but got %d", copyCount)
}
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 { if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
t.Errorf("Input rows and output rows do not equal") t.Errorf("Input rows and output rows do not equal")
} }
@ -85,13 +97,18 @@ func TestConnCopyToWriterQueryError(t *testing.T) {
outputWriter := bytes.NewBuffer(make([]byte, 0)) outputWriter := bytes.NewBuffer(make([]byte, 0))
err := conn.CopyToWriter(outputWriter, "cropy foo to stdout") res, err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
if err == nil { if err == nil {
t.Errorf("Expected CopyFromReader return error, but it did not") t.Errorf("Expected CopyToWriter return error, but it did not")
} }
if _, ok := err.(pgx.PgError); !ok { if _, ok := err.(pgx.PgError); !ok {
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err) t.Errorf("Expected CopyToWriter return pgx.PgError, but instead it returned: %v", err)
}
copyCount := int(res.RowsAffected())
if copyCount != 0 {
t.Errorf("Expected CopyToWriter to return 0 copied rows, but got %d", copyCount)
} }
ensureConnValid(t, conn) ensureConnValid(t, conn)

8
tx.go
View File

@ -240,18 +240,18 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
} }
// CopyFromReader delegates to the underlying *Conn // CopyFromReader delegates to the underlying *Conn
func (tx *Tx) CopyFromReader(r io.Reader, sql string) error { func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag CommandTag, err error) {
if tx.status != TxStatusInProgress { if tx.status != TxStatusInProgress {
return ErrTxClosed return CommandTag(""), ErrTxClosed
} }
return tx.conn.CopyFromReader(r, sql) return tx.conn.CopyFromReader(r, sql)
} }
// CopyToWriter delegates to the underlying *Conn // CopyToWriter delegates to the underlying *Conn
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag CommandTag, err error) {
if tx.status != TxStatusInProgress { if tx.status != TxStatusInProgress {
return ErrTxClosed return CommandTag(""), ErrTxClosed
} }
return tx.conn.CopyToWriter(w, sql, args...) return tx.conn.CopyToWriter(w, sql, args...)