Unwatch and close connection on a batch write error

Previously, a conn.Write error would simply unlock pgconn, leaving the
connection Idle and reusable while the multiResultReader was marked as
closed. From this state, calling multiResultReader.Close won't try to
receiveMessage, and thus won't unwatch and close the connection, since
the reader is already closed. This leaves the connection "open", and the
next time it's used a "Watch already in progress" panic could be triggered.

This patch fixes the issue by unwatching and closing the connection on a
batch write error. The same is done on a Sync.Encode error, even though
that path is unreachable as Sync.Encode never returns an error.
pull/2240/head
Anthonin Bonnefoy 2025-01-23 11:57:42 +01:00
parent 0bc29e3000
commit 228cfffc20
2 changed files with 52 additions and 4 deletions

View File

@ -1773,9 +1773,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
pgConn.contextWatcher.Unwatch()
multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
multiResult.closed = true
multiResult.err = batch.err
pgConn.unlock()
pgConn.asyncClose()
return multiResult
}
@ -1783,9 +1784,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
defer pgConn.exitPotentialWriteReadDeadlock()
_, err := pgConn.conn.Write(batch.buf)
if err != nil {
pgConn.contextWatcher.Unwatch()
multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
multiResult.closed = true
multiResult.err = err
pgConn.unlock()
pgConn.asyncClose()
return multiResult
}

View File

@ -1420,6 +1420,52 @@ func TestConnExecBatch(t *testing.T) {
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
}
// mockConnection wraps a net.Conn and delays every Write by a configurable
// amount. The latency is held behind a pointer so that whoever owns the
// pointed-to value can change the delay after the connection has been created.
type mockConnection struct {
	net.Conn
	writeLatency *time.Duration
}

// Write sleeps for the currently configured latency and then forwards b to
// the wrapped connection, returning that connection's result unchanged.
func (mc mockConnection) Write(b []byte) (int, error) {
	delay := *mc.writeLatency
	time.Sleep(delay)
	return mc.Conn.Write(b)
}
// TestConnExecBatchWriteError checks that a batch whose write fails reports
// the failure from MultiResultReader.Close and leaves the connection closed
// rather than idle and reusable.
//
// The failure is provoked by routing the connection through a mockConnection
// whose write latency (2s) is raised above the batch context's timeout (1s)
// just before ExecBatch is called.
func TestConnExecBatchWriteError(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	// writeLatency is shared with the mock by pointer so it can stay at zero
	// while connecting and be raised only for the batch write.
	var writeLatency time.Duration
	config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
		// Dial with the supplied context so the connect timeout is honored,
		// and never hand back a wrapper around a nil connection on failure.
		var d net.Dialer
		conn, err := d.DialContext(ctx, network, address)
		if err != nil {
			return nil, err
		}
		return mockConnection{conn, &writeLatency}, nil
	}

	pgConn, err := pgconn.ConnectConfig(ctx, config)
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	batch := &pgconn.Batch{}
	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)

	ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel2()

	// From here on every write sleeps longer than ctx2 allows, so the batch
	// write is expected to fail with a deadline error.
	writeLatency = 2 * time.Second
	mrr := pgConn.ExecBatch(ctx2, batch)
	err = mrr.Close()
	require.Error(t, err)
	assert.ErrorIs(t, err, context.DeadlineExceeded)
	require.True(t, pgConn.IsClosed())
}
func TestConnExecBatchDeferredError(t *testing.T) {
t.Parallel()