From 20c02acd637eed8407fe5a77b49d0f5a459aba6e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Jul 2018 11:26:09 -0500 Subject: [PATCH] Fix deadlock when CopyFromSource panics fixes #433 --- copy_from.go | 86 ++++++++++++++++++++++++++++++++--------------- copy_from_test.go | 42 +++++++++++++++++++++++ 2 files changed, 101 insertions(+), 27 deletions(-) diff --git a/copy_from.go b/copy_from.go index 8b7c3d5b..13a80b50 100644 --- a/copy_from.go +++ b/copy_from.go @@ -115,8 +115,15 @@ func (ct *copyFrom) run() (int, error) { return 0, err } + panicked := true + go ct.readUntilReadyForQuery() defer ct.waitForReaderDone() + defer func() { + if panicked { + ct.conn.die(errors.New("panic while in copy from")) + } + }() buf := ct.conn.wbuf buf = append(buf, copyData) @@ -129,49 +136,40 @@ func (ct *copyFrom) run() (int, error) { var sentCount int - for ct.rowSrc.Next() { + moreRows := true + for moreRows { select { case err = <-ct.readerErrChan: + panicked = false return 0, err default: } - if len(buf) > 65536 { - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = ct.conn.conn.Write(buf) - if err != nil { - ct.conn.die(err) - return 0, err - } - - // Directly manipulate wbuf to reset to reuse the same buffer - buf = buf[0:5] - } - - sentCount++ - - values, err := ct.rowSrc.Values() + var addedRows int + var err error + moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps) if err != nil { + panicked = false ct.cancelCopyIn() return 0, err } - if len(values) != len(ct.columnNames) { - ct.cancelCopyIn() - return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + sentCount += addedRows + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + _, err = ct.conn.conn.Write(buf) + if err != nil { + panicked = false + ct.conn.die(err) + return 0, err } - buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) - for i, val := range values { - buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val) - if err != nil { - ct.cancelCopyIn() - return 0, err - } + // Directly manipulate wbuf to reset to reuse the same buffer + buf = buf[0:5] - } } if ct.rowSrc.Err() != nil { + panicked = false ct.cancelCopyIn() return 0, ct.rowSrc.Err() } @@ -184,17 +182,51 @@ func (ct *copyFrom) run() (int, error) { _, err = ct.conn.conn.Write(buf) if err != nil { + panicked = false ct.conn.die(err) return 0, err } err = ct.waitForReaderDone() if err != nil { + panicked = false return 0, err } + + panicked = false return sentCount, nil } +func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) { + var rowCount int + + for ct.rowSrc.Next() { + values, err := ct.rowSrc.Values() + if err != nil { + return false, nil, 0, err + } + if len(values) != len(ct.columnNames) { + return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) + for i, val := range values { + buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val) + if err != nil { + return false, nil, 0, err + } + } + + rowCount++ + + if len(buf) > 65536 { + return true, buf, rowCount, nil + } + } + + return false, buf, rowCount, nil +} + func (c *Conn) readUntilCopyInResponse() error { for { msg, err := c.rxMsg() diff --git a/copy_from_test.go b/copy_from_test.go index ec674855..e2c54af4 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -426,3 +426,45 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { ensureConnValid(t, conn) } + +type nextPanicSource struct { +} + +func (cfs *nextPanicSource) Next() bool { + panic("crash") +} + +func (cfs *nextPanicSource) Values() ([]interface{}, error) { + return []interface{}{nil}, nil // should never get here +} + +func (cfs *nextPanicSource) Err() error { + return nil // should never gets here +} + +func TestConnCopyFromCopyFromSourceNextPanic(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + caughtPanic := false + + func() { + defer func() { + if x := recover(); x != nil { + caughtPanic = true + } + }() + + conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &nextPanicSource{}) + }() + + if conn.IsAlive() { + t.Error("panic should have killed conn") + } +}