mirror of https://github.com/jackc/pgx.git
parent
3cbe92ebb5
commit
20c02acd63
80
copy_from.go
80
copy_from.go
|
@ -115,8 +115,15 @@ func (ct *copyFrom) run() (int, error) {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
panicked := true
|
||||||
|
|
||||||
go ct.readUntilReadyForQuery()
|
go ct.readUntilReadyForQuery()
|
||||||
defer ct.waitForReaderDone()
|
defer ct.waitForReaderDone()
|
||||||
|
defer func() {
|
||||||
|
if panicked {
|
||||||
|
ct.conn.die(errors.New("panic while in copy from"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
buf := ct.conn.wbuf
|
buf := ct.conn.wbuf
|
||||||
buf = append(buf, copyData)
|
buf = append(buf, copyData)
|
||||||
|
@ -129,49 +136,40 @@ func (ct *copyFrom) run() (int, error) {
|
||||||
|
|
||||||
var sentCount int
|
var sentCount int
|
||||||
|
|
||||||
for ct.rowSrc.Next() {
|
moreRows := true
|
||||||
|
for moreRows {
|
||||||
select {
|
select {
|
||||||
case err = <-ct.readerErrChan:
|
case err = <-ct.readerErrChan:
|
||||||
|
panicked = false
|
||||||
return 0, err
|
return 0, err
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(buf) > 65536 {
|
var addedRows int
|
||||||
|
var err error
|
||||||
|
moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps)
|
||||||
|
if err != nil {
|
||||||
|
panicked = false
|
||||||
|
ct.cancelCopyIn()
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sentCount += addedRows
|
||||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||||
|
|
||||||
_, err = ct.conn.conn.Write(buf)
|
_, err = ct.conn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
panicked = false
|
||||||
ct.conn.die(err)
|
ct.conn.die(err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Directly manipulate wbuf to reset to reuse the same buffer
|
// Directly manipulate wbuf to reset to reuse the same buffer
|
||||||
buf = buf[0:5]
|
buf = buf[0:5]
|
||||||
}
|
|
||||||
|
|
||||||
sentCount++
|
|
||||||
|
|
||||||
values, err := ct.rowSrc.Values()
|
|
||||||
if err != nil {
|
|
||||||
ct.cancelCopyIn()
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if len(values) != len(ct.columnNames) {
|
|
||||||
ct.cancelCopyIn()
|
|
||||||
return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
|
||||||
for i, val := range values {
|
|
||||||
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
|
|
||||||
if err != nil {
|
|
||||||
ct.cancelCopyIn()
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if ct.rowSrc.Err() != nil {
|
if ct.rowSrc.Err() != nil {
|
||||||
|
panicked = false
|
||||||
ct.cancelCopyIn()
|
ct.cancelCopyIn()
|
||||||
return 0, ct.rowSrc.Err()
|
return 0, ct.rowSrc.Err()
|
||||||
}
|
}
|
||||||
|
@ -184,17 +182,51 @@ func (ct *copyFrom) run() (int, error) {
|
||||||
|
|
||||||
_, err = ct.conn.conn.Write(buf)
|
_, err = ct.conn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
panicked = false
|
||||||
ct.conn.die(err)
|
ct.conn.die(err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ct.waitForReaderDone()
|
err = ct.waitForReaderDone()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
panicked = false
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
panicked = false
|
||||||
return sentCount, nil
|
return sentCount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) {
|
||||||
|
var rowCount int
|
||||||
|
|
||||||
|
for ct.rowSrc.Next() {
|
||||||
|
values, err := ct.rowSrc.Values()
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, 0, err
|
||||||
|
}
|
||||||
|
if len(values) != len(ct.columnNames) {
|
||||||
|
return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||||
|
for i, val := range values {
|
||||||
|
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rowCount++
|
||||||
|
|
||||||
|
if len(buf) > 65536 {
|
||||||
|
return true, buf, rowCount, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, buf, rowCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) readUntilCopyInResponse() error {
|
func (c *Conn) readUntilCopyInResponse() error {
|
||||||
for {
|
for {
|
||||||
msg, err := c.rxMsg()
|
msg, err := c.rxMsg()
|
||||||
|
|
|
@ -426,3 +426,45 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type nextPanicSource struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfs *nextPanicSource) Next() bool {
|
||||||
|
panic("crash")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfs *nextPanicSource) Values() ([]interface{}, error) {
|
||||||
|
return []interface{}{nil}, nil // should never get here
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfs *nextPanicSource) Err() error {
|
||||||
|
return nil // should never gets here
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromCopyFromSourceNextPanic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
mustExec(t, conn, `create temporary table foo(
|
||||||
|
a bytea not null
|
||||||
|
)`)
|
||||||
|
|
||||||
|
caughtPanic := false
|
||||||
|
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if x := recover(); x != nil {
|
||||||
|
caughtPanic = true
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &nextPanicSource{})
|
||||||
|
}()
|
||||||
|
|
||||||
|
if conn.IsAlive() {
|
||||||
|
t.Error("panic should have killed conn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue