mirror of https://github.com/jackc/pgx.git
Merge pull request #2240 from bonnefoa/fix-watch-panic
Unwatch and close connection on a batch write errorpull/2257/head
commit
1abf7d9050
|
@ -1773,9 +1773,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||||
|
|
||||||
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
|
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||||
if batch.err != nil {
|
if batch.err != nil {
|
||||||
|
pgConn.contextWatcher.Unwatch()
|
||||||
|
multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = batch.err
|
pgConn.asyncClose()
|
||||||
pgConn.unlock()
|
|
||||||
return multiResult
|
return multiResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1783,9 +1784,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||||
defer pgConn.exitPotentialWriteReadDeadlock()
|
defer pgConn.exitPotentialWriteReadDeadlock()
|
||||||
_, err := pgConn.conn.Write(batch.buf)
|
_, err := pgConn.conn.Write(batch.buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
pgConn.contextWatcher.Unwatch()
|
||||||
|
multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = err
|
pgConn.asyncClose()
|
||||||
pgConn.unlock()
|
|
||||||
return multiResult
|
return multiResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1420,6 +1420,52 @@ func TestConnExecBatch(t *testing.T) {
|
||||||
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
|
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockConnection struct {
|
||||||
|
net.Conn
|
||||||
|
writeLatency *time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConnection) Write(b []byte) (n int, err error) {
|
||||||
|
time.Sleep(*m.writeLatency)
|
||||||
|
return m.Conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnExecBatchWriteError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var mockConn mockConnection
|
||||||
|
writeLatency := 0 * time.Second
|
||||||
|
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
conn, err := net.Dial(network, address)
|
||||||
|
mockConn = mockConnection{conn, &writeLatency}
|
||||||
|
return mockConn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
batch := &pgconn.Batch{}
|
||||||
|
pgConn.Conn()
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
||||||
|
writeLatency = 2 * time.Second
|
||||||
|
mrr := pgConn.ExecBatch(ctx2, batch)
|
||||||
|
err = mrr.Close()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
require.True(t, pgConn.IsClosed())
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnExecBatchDeferredError(t *testing.T) {
|
func TestConnExecBatchDeferredError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue