mirror of https://github.com/jackc/pgx.git
Fix race condition in CopyFrom
In case of an error it was possible for the goroutine that builds the copy stream to still be running after CopyFrom returned. Since that goroutine uses the connections ConnInfo data types to encode the copy data it was possible for those types to be concurrently used in an unsafe fashion. CopyFrom will no longer return until that goroutine has completed.pull/681/head
parent
8c9d1cc15b
commit
3b9f79e2f3
|
@ -75,8 +75,11 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
|||
}
|
||||
|
||||
r, w := io.Pipe()
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
|
||||
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
|
||||
buf := ct.conn.wbuf
|
||||
|
||||
|
@ -114,6 +117,9 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
|||
|
||||
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
|
||||
r.Close()
|
||||
<-doneChan
|
||||
|
||||
return commandTag.RowsAffected(), err
|
||||
}
|
||||
|
||||
|
|
|
@ -333,6 +333,52 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
type slowFailRaceSource struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (fs *slowFailRaceSource) Next() bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
fs.count++
|
||||
return fs.count < 1000
|
||||
}
|
||||
|
||||
func (fs *slowFailRaceSource) Values() ([]interface{}, error) {
|
||||
if fs.count == 500 {
|
||||
return []interface{}{nil, nil}, nil
|
||||
}
|
||||
return []interface{}{1, make([]byte, 1000)}, nil
|
||||
}
|
||||
|
||||
func (fs *slowFailRaceSource) Err() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConnCopyFromSlowFailRace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int not null,
|
||||
b bytea not null
|
||||
)`)
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected CopyFrom return error, but it did not")
|
||||
}
|
||||
if _, ok := err.(*pgconn.PgError); !ok {
|
||||
t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
|
||||
}
|
||||
if copyCount != 0 {
|
||||
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue