mirror of https://github.com/jackc/pgx.git
Merge branch 'nvorobev-rows-affected-for-copyto-copyfrom' into v4-experimental
commit
25f21a597c
|
@ -553,10 +553,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C
|
|||
}
|
||||
|
||||
// CopyFromReader acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error {
|
||||
func (p *ConnPool) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
|
@ -564,10 +564,10 @@ func (p *ConnPool) CopyFromReader(r io.Reader, sql string) error {
|
|||
}
|
||||
|
||||
// CopyToWriter acquires a connection, delegates the call to that connection, and releases the connection
|
||||
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
|
||||
func (p *ConnPool) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
|
|
21
copy_from.go
21
copy_from.go
|
@ -284,13 +284,13 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
|
|||
}
|
||||
|
||||
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
|
||||
func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
|
||||
func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) {
|
||||
if err := c.sendSimpleQuery(sql); err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.readUntilCopyInResponse(); err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
buf := c.wbuf
|
||||
|
||||
|
@ -305,7 +305,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
|
|||
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||
|
||||
if _, err := c.BaseConn.NetConn.Write(buf); err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -314,26 +314,25 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
|
|||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
if _, err := c.BaseConn.NetConn.Write(buf); err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return nil
|
||||
return "", err
|
||||
case *pgproto3.CommandComplete:
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return c.rxErrorResponse(msg)
|
||||
return "", c.rxErrorResponse(msg)
|
||||
default:
|
||||
return c.processContextFreeMsg(msg)
|
||||
return "", c.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -72,9 +72,14 @@ func TestConnCopyFromSmall(t *testing.T) {
|
|||
|
||||
mustExec(t, conn, "truncate foo")
|
||||
|
||||
if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil {
|
||||
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFromReader: %v", err)
|
||||
}
|
||||
copyCount = int(res.RowsAffected())
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err = conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
|
@ -160,9 +165,14 @@ func TestConnCopyFromLarge(t *testing.T) {
|
|||
|
||||
mustExec(t, conn, "truncate foo")
|
||||
|
||||
if err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin"); err != nil {
|
||||
res, err := conn.CopyFromReader(strings.NewReader(inputStringRows), "copy foo from stdin")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFromReader: %v", err)
|
||||
}
|
||||
copyCount = int(res.RowsAffected())
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err = conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
|
@ -244,9 +254,14 @@ func TestConnCopyFromJSON(t *testing.T) {
|
|||
|
||||
mustExec(t, conn, "truncate foo")
|
||||
|
||||
if err := conn.CopyFromReader(inputReader, "copy foo from stdin"); err != nil {
|
||||
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
copyCount = int(res.RowsAffected())
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyFromReader to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err = conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
|
@ -348,13 +363,17 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
|
|||
|
||||
mustExec(t, conn, "truncate foo")
|
||||
|
||||
err = conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFromReader return error, but it did not")
|
||||
}
|
||||
if _, ok := err.(pgx.PgError); !ok {
|
||||
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
copyCount = int(res.RowsAffected())
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
rows, err = conn.Query("select * from foo")
|
||||
if err != nil {
|
||||
|
@ -613,7 +632,7 @@ func TestConnCopyFromReaderQueryError(t *testing.T) {
|
|||
|
||||
inputReader := strings.NewReader("")
|
||||
|
||||
err := conn.CopyFromReader(inputReader, "cropy foo from stdin")
|
||||
res, err := conn.CopyFromReader(inputReader, "cropy foo from stdin")
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFromReader return error, but it did not")
|
||||
}
|
||||
|
@ -622,6 +641,11 @@ func TestConnCopyFromReaderQueryError(t *testing.T) {
|
|||
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
|
@ -633,7 +657,7 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) {
|
|||
|
||||
inputReader := strings.NewReader("")
|
||||
|
||||
err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
res, err := conn.CopyFromReader(inputReader, "copy foo from stdin")
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFromReader return error, but it did not")
|
||||
}
|
||||
|
@ -642,6 +666,11 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) {
|
|||
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyFromReader to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
|
@ -688,11 +717,16 @@ func TestConnCopyFromGzipReader(t *testing.T) {
|
|||
t.Fatalf("Unexpected error for gzip.NewReader: %v", err)
|
||||
}
|
||||
|
||||
err = conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||
res, err := conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFromReader: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != len(inputRows) {
|
||||
t.Errorf("Expected CopyFromReader to return 1000 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
err = gr.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for gr.Close: %v", err)
|
||||
|
|
19
copy_to.go
19
copy_to.go
|
@ -25,19 +25,19 @@ func (c *Conn) readUntilCopyOutResponse() error {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
|
||||
func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) (CommandTag, error) {
|
||||
if err := c.sendSimpleQuery(sql, args...); err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.readUntilCopyOutResponse(); err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
|
@ -47,18 +47,17 @@ func (c *Conn) CopyToWriter(w io.Writer, sql string, args ...interface{}) error
|
|||
_, err := w.Write(msg.Data)
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return nil
|
||||
return "", nil
|
||||
case *pgproto3.CommandComplete:
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return c.rxErrorResponse(msg)
|
||||
return "", c.rxErrorResponse(msg)
|
||||
default:
|
||||
return c.processContextFreeMsg(msg)
|
||||
return "", c.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -30,10 +30,16 @@ func TestConnCopyToWriterSmall(t *testing.T) {
|
|||
|
||||
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
||||
|
||||
if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil {
|
||||
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyToWriter: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 2 {
|
||||
t.Errorf("Expected CopyToWriter to return 2 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
|
||||
t.Errorf("Input rows and output rows do not equal:\n%q\n%q", string(inputBytes), string(outputWriter.Bytes()))
|
||||
}
|
||||
|
@ -66,10 +72,16 @@ func TestConnCopyToWriterLarge(t *testing.T) {
|
|||
|
||||
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
||||
|
||||
if err := conn.CopyToWriter(outputWriter, "copy foo to stdout"); err != nil {
|
||||
res, err := conn.CopyToWriter(outputWriter, "copy foo to stdout")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 1000 {
|
||||
t.Errorf("Expected CopyToWriter to return 1 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
if i := bytes.Compare(inputBytes, outputWriter.Bytes()); i != 0 {
|
||||
t.Errorf("Input rows and output rows do not equal")
|
||||
}
|
||||
|
@ -85,13 +97,18 @@ func TestConnCopyToWriterQueryError(t *testing.T) {
|
|||
|
||||
outputWriter := bytes.NewBuffer(make([]byte, 0))
|
||||
|
||||
err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
|
||||
res, err := conn.CopyToWriter(outputWriter, "cropy foo to stdout")
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFromReader return error, but it did not")
|
||||
t.Errorf("Expected CopyToWriter return error, but it did not")
|
||||
}
|
||||
|
||||
if _, ok := err.(pgx.PgError); !ok {
|
||||
t.Errorf("Expected CopyFromReader return pgx.PgError, but instead it returned: %v", err)
|
||||
t.Errorf("Expected CopyToWriter return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
|
||||
copyCount := int(res.RowsAffected())
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyToWriter to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
|
|
8
tx.go
8
tx.go
|
@ -240,18 +240,18 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr
|
|||
}
|
||||
|
||||
// CopyFromReader delegates to the underlying *Conn
|
||||
func (tx *Tx) CopyFromReader(r io.Reader, sql string) error {
|
||||
func (tx *Tx) CopyFromReader(r io.Reader, sql string) (commandTag CommandTag, err error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return ErrTxClosed
|
||||
return CommandTag(""), ErrTxClosed
|
||||
}
|
||||
|
||||
return tx.conn.CopyFromReader(r, sql)
|
||||
}
|
||||
|
||||
// CopyToWriter delegates to the underlying *Conn
|
||||
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) error {
|
||||
func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (commandTag CommandTag, err error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return ErrTxClosed
|
||||
return CommandTag(""), ErrTxClosed
|
||||
}
|
||||
|
||||
return tx.conn.CopyToWriter(w, sql, args...)
|
||||
|
|
Loading…
Reference in New Issue