From ae2881a23c66209ca3525000547fb4debbe8baf4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 2 Jul 2022 21:48:16 -0500 Subject: [PATCH] Add pipeline mode to pgconn --- CHANGELOG.md | 2 + pgconn/doc.go | 5 + pgconn/pgconn.go | 272 +++++++++++++++++++++++++++- pgconn/pgconn_test.go | 412 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 685 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51582553..81b6cb26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ pgconn now uses non-blocking IO. This is a significant internal restructuring, b `CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping. +pgconn now supports pipeline mode. + ## pgtype The `pgtype` package has been significantly changed. diff --git a/pgconn/doc.go b/pgconn/doc.go index cde58cd8..e3242cf4 100644 --- a/pgconn/doc.go +++ b/pgconn/doc.go @@ -18,6 +18,11 @@ Executing Multiple Queries in a Single Round Trip Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query result. The ReadAll method reads all query results into memory. +Pipeline Mode + +Pipeline mode allows sending queries without having read the results of previously sent queries. It allows +control of exactly how many and when network round trips occur. + Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 306b2e16..b386a786 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -81,6 +81,7 @@ type PgConn struct { // Reusable / preallocated resources resultReader ResultReader multiResultReader MultiResultReader + pipeline Pipeline contextWatcher *ctxwatch.ContextWatcher cleanupDone chan struct{} @@ -1242,8 +1243,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { - pgConn *PgConn - ctx context.Context + pgConn *PgConn + ctx context.Context + pipeline *Pipeline rr *ResultReader @@ -1276,9 +1278,13 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mrr.pgConn.contextWatcher.Unwatch() mrr.closed = true - mrr.pgConn.unlock() + if mrr.pipeline != nil { + mrr.pipeline.expectedReadyForQueryCount-- + } else { + mrr.pgConn.contextWatcher.Unwatch() + mrr.pgConn.unlock() + } case *pgproto3.ErrorResponse: mrr.err = ErrorResponseToPgError(msg) } @@ -1341,6 +1347,7 @@ func (mrr *MultiResultReader) Close() error { type ResultReader struct { pgConn *PgConn multiResultReader *MultiResultReader + pipeline *Pipeline ctx context.Context fieldDescriptions []pgproto3.FieldDescription @@ -1429,7 +1436,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { } } - if rr.multiResultReader == nil { + if rr.multiResultReader == nil && rr.pipeline == nil { for { msg, err := rr.receiveMessage() if err != nil { @@ -1539,7 +1546,8 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor } // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a -// transaction is already in progress or SQL contains transaction control statements. +// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing +// multiple queries in a single round trip than using pipeline mode. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ @@ -1676,3 +1684,255 @@ func Construct(hc *HijackedConn) (*PgConn, error) { return pgConn, nil } + +// Pipeline represents a connection in pipeline mode. +// +// SendPrepare, SendQueryParam, and SendQueryPrepared queue requests to the server. These requests are not written until +// pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between +// synchronization points are implicitly transactional unless explicit transaction control statements have been issued. +// +// The context the pipeline was started with is in effect for the entire life of the Pipeline. +// +// For a deeper understanding of pipeline mode see the PostgreSQL documentation for the extended query protocol +// (https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY) and the libpq pipeline mode +// (https://www.postgresql.org/docs/current/libpq-pipeline-mode.html). +type Pipeline struct { + conn *PgConn + ctx context.Context + + expectedReadyForQueryCount int + pendingSync bool + + err error + closed bool +} + +// PipelineSync is returned by GetResults when a ReadyForQuery message is received. +type PipelineSync struct{} + +// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent +// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection +// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except +// CancelRequest and Close. ctx is in effect for entire life of the *Pipeline. +// +// Prefer ExecBatch when only sending one group of queries at once. +func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { + if err := pgConn.lock(); err != nil { + return &Pipeline{ + closed: true, + err: err, + } + } + + pgConn.pipeline = Pipeline{ + conn: pgConn, + ctx: ctx, + } + pipeline := &pgConn.pipeline + + if ctx != context.Background() { + select { + case <-ctx.Done(): + pipeline.closed = true + pipeline.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return pipeline + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return pipeline +} + +// SendPrepare is the pipeline version of *PgConn.Prepare. +func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) +} + +// SendQueryParams is the pipeline version of *PgConn.QueryParams. +func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) +} + +// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. +func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + if p.closed { + return + } + p.pendingSync = true + + p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + p.conn.frontend.SendExecute(&pgproto3.Execute{}) +} + +// Flush flushes the queued requests without establishing a synchronization point. +func (p *Pipeline) Flush() error { + if p.closed { + if p.err != nil { + return p.err + } + return errors.New("pipeline closed") + } + + err := p.conn.frontend.Flush() + if err != nil { + err = preferContextOverNetTimeoutError(p.ctx, err) + + p.conn.asyncClose() + + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + p.closed = true + p.err = err + return err + } + + return nil +} + +// Sync establishes a synchronization point and flushes the queued requests. +func (p *Pipeline) Sync() error { + p.conn.frontend.SendSync(&pgproto3.Sync{}) + err := p.Flush() + if err != nil { + return err + } + + p.pendingSync = false + p.expectedReadyForQueryCount++ + + return nil +} + +// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or +// *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no +// results are available, results and err will both be nil. +func (p *Pipeline) GetResults() (results any, err error) { + if p.expectedReadyForQueryCount == 0 { + return nil, nil + } + + for { + msg, err := p.conn.receiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + p.conn.resultReader = ResultReader{ + pgConn: p.conn, + pipeline: p, + ctx: p.ctx, + fieldDescriptions: msg.Fields, + } + return &p.conn.resultReader, nil + case *pgproto3.CommandComplete: + p.conn.resultReader = ResultReader{ + commandTag: p.conn.makeCommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return &p.conn.resultReader, nil + case *pgproto3.ParseComplete: + peekedMsg, err := p.conn.peekMessage() + if err != nil { + return nil, err + } + if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { + return p.getResultsPrepare() + } + case *pgproto3.ReadyForQuery: + p.expectedReadyForQueryCount-- + return &PipelineSync{}, nil + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + return nil, pgErr + } + + } + +} + +func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { + psd := &StatementDescription{} + + for { + msg, err := p.conn.receiveMessage() + if err != nil { + p.conn.asyncClose() + return nil, preferContextOverNetTimeoutError(p.ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.RowDescription: + psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) + copy(psd.Fields, msg.Fields) + return psd, nil + + // These should never happen here. But don't take chances that could lead to a deadlock. + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + return nil, pgErr + case *pgproto3.CommandComplete: + p.conn.asyncClose() + return nil, errors.New("BUG: received CommandComplete while handling Describe") + case *pgproto3.ReadyForQuery: + p.conn.asyncClose() + return nil, errors.New("BUG: received ReadyForQuery while handling Describe") + } + } +} + +// Close closes the pipeline and returns the connection to normal mode. +func (p *Pipeline) Close() error { + if p.closed { + return p.err + } + p.closed = true + + if p.pendingSync { + p.conn.asyncClose() + p.err = errors.New("pipeline has unsynced requests") + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + + return p.err + } + + for p.expectedReadyForQueryCount > 0 { + _, err := p.GetResults() + if err != nil { + var pgErr *PgError + if !errors.As(err, &pgErr) { + p.conn.asyncClose() + p.err = err + break + } + } + } + + p.conn.contextWatcher.Unwatch() + p.conn.unlock() + + return p.err +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index b47f17d6..c72ed6d6 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -20,6 +20,7 @@ import ( "github.com/jackc/pgx/v5/internal/pgmock" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -2094,6 +2095,417 @@ func TestConnCheckConn(t *testing.T) { require.Error(t, err) } +func TestPipelinePrepare(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("selectInt", "select $1::int as a", nil) + pipeline.SendPrepare("selectText", "select $1::text as b", nil) + pipeline.SendPrepare("selectNoParams", "select 42 as c", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, []uint32{pgtype.Int4OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "b") + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + sd, ok = results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "c") + require.Equal(t, []uint32{}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("selectInt", "select $1::int as a", nil) + pipeline.SendPrepare("selectError", "bad", nil) + pipeline.SendPrepare("selectText", "select $1::text as b", nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, []uint32{pgtype.Int4OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Nil(t, results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineQuery(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "4", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelinePrepareQuery(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendPrepare("ps", "select $1::text as msg", nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, string(sd.Fields[0].Name), "msg") + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "hello", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "goodbye", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + var pgErr *pgconn.PgError + require.ErrorAs(t, readResult.Err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "6", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseReadsUnreadResults(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(context.Background()) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.EqualError(t, err, "pipeline has unsynced requests") +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil {