Merge pull request #2240 from bonnefoa/fix-watch-panic

Unwatch and close connection on a batch write error
pull/2257/head
Jack Christensen 2025-01-25 08:38:33 -06:00 committed by GitHub
commit 1abf7d9050
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 4 deletions

View File

@ -1773,9 +1773,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
pgConn.contextWatcher.Unwatch()
multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
multiResult.closed = true
multiResult.err = batch.err
pgConn.unlock()
pgConn.asyncClose()
return multiResult
}
@ -1783,9 +1784,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
defer pgConn.exitPotentialWriteReadDeadlock()
_, err := pgConn.conn.Write(batch.buf)
if err != nil {
pgConn.contextWatcher.Unwatch()
multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
multiResult.closed = true
multiResult.err = err
pgConn.unlock()
pgConn.asyncClose()
return multiResult
}

View File

@ -1420,6 +1420,52 @@ func TestConnExecBatch(t *testing.T) {
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
}
type mockConnection struct {
net.Conn
writeLatency *time.Duration
}
func (m mockConnection) Write(b []byte) (n int, err error) {
time.Sleep(*m.writeLatency)
return m.Conn.Write(b)
}
func TestConnExecBatchWriteError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
var mockConn mockConnection
writeLatency := 0 * time.Second
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := net.Dial(network, address)
mockConn = mockConnection{conn, &writeLatency}
return mockConn, err
}
pgConn, err := pgconn.ConnectConfig(ctx, config)
require.NoError(t, err)
defer closeConn(t, pgConn)
batch := &pgconn.Batch{}
pgConn.Conn()
ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel2()
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
writeLatency = 2 * time.Second
mrr := pgConn.ExecBatch(ctx2, batch)
err = mrr.Close()
require.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
require.True(t, pgConn.IsClosed())
}
func TestConnExecBatchDeferredError(t *testing.T) {
t.Parallel()