diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 72368926..28ee01a7 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1,6 +1,7 @@ package pgconn import ( + "container/list" "context" "crypto/md5" "crypto/tls" @@ -1408,9 +1409,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { - pgConn *PgConn - ctx context.Context - pipeline *Pipeline + pgConn *PgConn + ctx context.Context rr *ResultReader @@ -1443,12 +1443,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: mrr.closed = true - if mrr.pipeline != nil { - mrr.pipeline.expectedReadyForQueryCount-- - } else { - mrr.pgConn.contextWatcher.Unwatch() - mrr.pgConn.unlock() - } + mrr.pgConn.contextWatcher.Unwatch() + mrr.pgConn.unlock() case *pgproto3.ErrorResponse: mrr.err = ErrorResponseToPgError(msg) } @@ -1672,7 +1668,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.EmptyQueryResponse: rr.concludeCommand(CommandTag{}, nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg)) + pgErr := ErrorResponseToPgError(msg) + if rr.pipeline != nil { + rr.pipeline.state.HandleError(pgErr) + } + rr.concludeCommand(CommandTag{}, pgErr) } return msg, nil @@ -1999,9 +1999,7 @@ type Pipeline struct { conn *PgConn ctx context.Context - expectedReadyForQueryCount int - pendingSync bool - + state pipelineState err error closed bool } @@ -2012,6 +2010,122 @@ type PipelineSync struct{} // CloseComplete is returned by GetResults when a CloseComplete message is received. type CloseComplete struct{} +type pipelineRequestType int + +const ( + PIPELINE_NIL pipelineRequestType = iota + PIPELINE_PREPARE + PIPELINE_QUERY_PARAMS + PIPELINE_QUERY_PREPARED + PIPELINE_DEALLOCATE + PIPELINE_SYNC_REQUEST + PIPELINE_FLUSH_REQUEST +) + +type pipelineRequestEvent struct { + RequestType pipelineRequestType + WasSentToServer bool + BeforeFlushOrSync bool +} + +type pipelineState struct { + requestEventQueue list.List + lastRequestType pipelineRequestType + pgErr *PgError + expectedReadyForQueryCount int +} + +func (s *pipelineState) Init() { + s.requestEventQueue.Init() + s.lastRequestType = PIPELINE_NIL +} + +func (s *pipelineState) RegisterSendingToServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.WasSentToServer { + return + } + val.WasSentToServer = true + elem.Value = val + } +} + +func (s *pipelineState) registerFlushingBufferOnServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.BeforeFlushOrSync { + return + } + val.BeforeFlushOrSync = true + elem.Value = val + } +} + +func (s *pipelineState) PushBackRequestType(req pipelineRequestType) { + if req == PIPELINE_NIL { + return + } + + if req != PIPELINE_FLUSH_REQUEST { + s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req}) + } + if req == PIPELINE_FLUSH_REQUEST || req == PIPELINE_SYNC_REQUEST { + s.registerFlushingBufferOnServer() + } + s.lastRequestType = req + + if req == PIPELINE_SYNC_REQUEST { + s.expectedReadyForQueryCount++ + } +} + +func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType { + for { + elem := s.requestEventQueue.Front() + if elem == nil { + return PIPELINE_NIL + } + val := elem.Value.(pipelineRequestEvent) + if !(val.WasSentToServer && val.BeforeFlushOrSync) { + return PIPELINE_NIL + } + + s.requestEventQueue.Remove(elem) + if val.RequestType == PIPELINE_SYNC_REQUEST { + s.pgErr = nil + } + if s.pgErr == nil { + return val.RequestType + } + } +} + +func (s *pipelineState) HandleError(err *PgError) { + s.pgErr = err +} + +func (s *pipelineState) HandleReadyForQuery() { + s.expectedReadyForQueryCount-- +} + +func (s *pipelineState) PendingSync() bool { + var notPendingSync bool + + if elem := s.requestEventQueue.Back(); elem != nil { + val := elem.Value.(pipelineRequestEvent) + notPendingSync = (val.RequestType == PIPELINE_SYNC_REQUEST) && val.WasSentToServer + } else { + notPendingSync = (s.lastRequestType == PIPELINE_SYNC_REQUEST) || (s.lastRequestType == PIPELINE_NIL) + } + + return !notPendingSync +} + +func (s *pipelineState) ExpectedReadyForQuery() int { + return s.expectedReadyForQueryCount +} + // StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent // to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection // to normal mode. While in pipeline mode, no methods that communicate with the server may be called except @@ -2020,16 +2134,21 @@ type CloseComplete struct{} // Prefer ExecBatch when only sending one group of queries at once. func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { if err := pgConn.lock(); err != nil { - return &Pipeline{ + pipeline := &Pipeline{ closed: true, err: err, } + pipeline.state.Init() + + return pipeline } pgConn.pipeline = Pipeline{ conn: pgConn, ctx: ctx, } + pgConn.pipeline.state.Init() + pipeline := &pgConn.pipeline if ctx != context.Background() { @@ -2052,10 +2171,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(PIPELINE_PREPARE) } // SendDeallocate deallocates a prepared statement. @@ -2063,9 +2182,9 @@ func (p *Pipeline) SendDeallocate(name string) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(PIPELINE_DEALLOCATE) } // SendQueryParams is the pipeline version of *PgConn.QueryParams. @@ -2073,12 +2192,12 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [ if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(PIPELINE_QUERY_PARAMS) } // SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. @@ -2086,11 +2205,11 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para if p.closed { return } - p.pendingSync = true p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(PIPELINE_QUERY_PREPARED) } // SendFlushRequest sends a request for the server to flush its output buffer. @@ -2104,9 +2223,24 @@ func (p *Pipeline) SendFlushRequest() { if p.closed { return } - p.pendingSync = true p.conn.frontend.Send(&pgproto3.Flush{}) + p.state.PushBackRequestType(PIPELINE_FLUSH_REQUEST) +} + +// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message +// without flushing the send buffer. This serves as the delimiter of an implicit +// transaction and an error recovery point. +// +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendPipelineSync. +func (p *Pipeline) SendPipelineSync() { + if p.closed { + return + } + + p.conn.frontend.SendSync(&pgproto3.Sync{}) + p.state.PushBackRequestType(PIPELINE_SYNC_REQUEST) } // Flush flushes the queued requests without establishing a synchronization point. @@ -2131,28 +2265,14 @@ func (p *Pipeline) Flush() error { return err } + p.state.RegisterSendingToServer() return nil } // Sync establishes a synchronization point and flushes the queued requests. func (p *Pipeline) Sync() error { - if p.closed { - if p.err != nil { - return p.err - } - return errors.New("pipeline closed") - } - - p.conn.frontend.SendSync(&pgproto3.Sync{}) - err := p.Flush() - if err != nil { - return err - } - - p.pendingSync = false - p.expectedReadyForQueryCount++ - - return nil + p.SendPipelineSync() + return p.Flush() } // GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or @@ -2166,30 +2286,13 @@ func (p *Pipeline) GetResults() (results any, err error) { return nil, errors.New("pipeline closed") } - if p.expectedReadyForQueryCount == 0 { + if p.state.ExtractFrontRequestType() == PIPELINE_NIL { return nil, nil } return p.getResults() } -// GetResultsNotCheckSync gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, -// or *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. -// -// This method should be used only if the request was sent to the server via methods SendFlushRequest and Flush, -// without using Sync. In this case, you need to identify on your own when all results are received and -// there is no need to call the method anymore. -func (p *Pipeline) GetResultsNotCheckSync() (results any, err error) { - if p.closed { - if p.err != nil { - return nil, p.err - } - return nil, errors.New("pipeline closed") - } - - return p.getResults() -} - func (p *Pipeline) getResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() @@ -2228,13 +2331,13 @@ func (p *Pipeline) getResults() (results any, err error) { case *pgproto3.CloseComplete: return &CloseComplete{}, nil case *pgproto3.ReadyForQuery: - p.expectedReadyForQueryCount-- + p.state.HandleReadyForQuery() return &PipelineSync{}, nil case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) return nil, pgErr } - } } @@ -2264,6 +2367,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { // These should never happen here. But don't take chances that could lead to a deadlock. case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) return nil, pgErr case *pgproto3.CommandComplete: p.conn.asyncClose() @@ -2283,7 +2387,7 @@ func (p *Pipeline) Close() error { p.closed = true - if p.pendingSync { + if p.state.PendingSync() { p.conn.asyncClose() p.err = errors.New("pipeline has unsynced requests") p.conn.contextWatcher.Unwatch() @@ -2292,7 +2396,7 @@ func (p *Pipeline) Close() error { return p.err } - for p.expectedReadyForQueryCount > 0 { + for p.state.ExpectedReadyForQuery() > 0 { _, err := p.getResults() if err != nil { p.err = err diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index b56d11fd..f800677d 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3003,70 +3003,6 @@ func TestPipelinePrepareQuery(t *testing.T) { ensureConnValid(t, pgConn) } -func TestPipelinePrepareQueryWithFlush(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - defer cancel() - - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - pipeline := pgConn.StartPipeline(ctx) - pipeline.SendPrepare("ps", "select $1::text as msg", nil) - pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) - pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil) - pipeline.SendFlushRequest() - err = pipeline.Flush() - require.NoError(t, err) - - results, err := pipeline.GetResultsNotCheckSync() - require.NoError(t, err) - sd, ok := results.(*pgconn.StatementDescription) - require.Truef(t, ok, "expected StatementDescription, got: %#v", results) - require.Len(t, sd.Fields, 1) - require.Equal(t, "msg", string(sd.Fields[0].Name)) - require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) - - results, err = pipeline.GetResultsNotCheckSync() - require.NoError(t, err) - rr, ok := results.(*pgconn.ResultReader) - require.Truef(t, ok, "expected ResultReader, got: %#v", results) - readResult := rr.Read() - require.NoError(t, readResult.Err) - require.Len(t, readResult.Rows, 1) - require.Len(t, readResult.Rows[0], 1) - require.Equal(t, "hello", string(readResult.Rows[0][0])) - - results, err = pipeline.GetResultsNotCheckSync() - require.NoError(t, err) - rr, ok = results.(*pgconn.ResultReader) - require.Truef(t, ok, "expected ResultReader, got: %#v", results) - readResult = rr.Read() - require.NoError(t, readResult.Err) - require.Len(t, readResult.Rows, 1) - require.Len(t, readResult.Rows[0], 1) - require.Equal(t, "goodbye", string(readResult.Rows[0][0])) - - err = pipeline.Sync() - require.NoError(t, err) - - results, err = pipeline.GetResults() - require.NoError(t, err) - _, ok = results.(*pgconn.PipelineSync) - require.Truef(t, ok, "expected PipelineSync, got: %#v", results) - - results, err = pipeline.GetResults() - require.NoError(t, err) - require.Nil(t, results) - - err = pipeline.Close() - require.NoError(t, err) - - ensureConnValid(t, pgConn) -} - func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { t.Parallel() @@ -3174,6 +3110,344 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { ensureConnValid(t, pgConn) } +func TestPipelineFlushForSingleRequests(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + + pipeline.SendPrepare("ps", "select $1::text as msg", nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "msg", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "hello", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendDeallocate("ps") + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.CloseComplete) + require.Truef(t, ok, "expected CloseComplete, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineFlushForRequestSeries(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendPrepare("ps", "select $1::bigint as num", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "num", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("1")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("2")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("3")}, nil, nil) + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("4")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "4", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("5")}, nil, nil) + pipeline.SendFlushRequest() + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("6")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "6", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineFlushWithError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + var pgErr *pgconn.PgError + require.ErrorAs(t, readResult.Err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendPipelineSync() + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + func TestPipelineCloseReadsUnreadResults(t *testing.T) { t.Parallel()