mirror of https://github.com/jackc/pgx.git
Merge pull request #2240 from bonnefoa/fix-watch-panic
Unwatch and close connection on a batch write errorpull/2257/head
commit
1abf7d9050
|
@ -1773,9 +1773,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
|
||||
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
|
||||
multiResult.closed = true
|
||||
multiResult.err = batch.err
|
||||
pgConn.unlock()
|
||||
pgConn.asyncClose()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
|
@ -1783,9 +1784,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
defer pgConn.exitPotentialWriteReadDeadlock()
|
||||
_, err := pgConn.conn.Write(batch.buf)
|
||||
if err != nil {
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
|
||||
multiResult.closed = true
|
||||
multiResult.err = err
|
||||
pgConn.unlock()
|
||||
pgConn.asyncClose()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
|
|
|
@ -1420,6 +1420,52 @@ func TestConnExecBatch(t *testing.T) {
|
|||
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
|
||||
}
|
||||
|
||||
type mockConnection struct {
|
||||
net.Conn
|
||||
writeLatency *time.Duration
|
||||
}
|
||||
|
||||
func (m mockConnection) Write(b []byte) (n int, err error) {
|
||||
time.Sleep(*m.writeLatency)
|
||||
return m.Conn.Write(b)
|
||||
}
|
||||
|
||||
func TestConnExecBatchWriteError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var mockConn mockConnection
|
||||
writeLatency := 0 * time.Second
|
||||
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
mockConn = mockConnection{conn, &writeLatency}
|
||||
return mockConn, err
|
||||
}
|
||||
|
||||
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
batch := &pgconn.Batch{}
|
||||
pgConn.Conn()
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
||||
writeLatency = 2 * time.Second
|
||||
mrr := pgConn.ExecBatch(ctx2, batch)
|
||||
err = mrr.Close()
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
||||
func TestConnExecBatchDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue