add flush request in pipeline

pull/2200/head
zenkovev 2024-12-17 11:49:13 +03:00
parent 3e6c719698
commit 76593f37f7
2 changed files with 97 additions and 0 deletions

View File

@ -2093,6 +2093,22 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para
p.conn.frontend.SendExecute(&pgproto3.Execute{})
}
// SendFlushRequest sends a request for the server to flush its output buffer.
//
// The server flushes its output buffer automatically as a result of Sync being called,
// or on any request when not in pipeline mode; this function is useful to cause the server
// to flush its output buffer in pipeline mode without establishing a synchronization point.
// Note that the request is not itself flushed to the server automatically; use Flush if
// necessary. This copies the behavior of libpq PQsendFlushRequest.
func (p *Pipeline) SendFlushRequest() {
if p.closed {
return
}
p.pendingSync = true
p.conn.frontend.Send(&pgproto3.Flush{})
}
// Flush flushes the queued requests without establishing a synchronization point.
func (p *Pipeline) Flush() error {
if p.closed {
@ -2157,6 +2173,23 @@ func (p *Pipeline) GetResults() (results any, err error) {
return p.getResults()
}
// GetResultsNotCheckSync gets the next results. If results are present, results may be a *ResultReader, *StatementDescription,
// or *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError.
//
// This method should be used only if the request was sent to the server via methods SendFlushRequest and Flush,
// without using Sync. In this case, you need to identify on your own when all results are received and
// there is no need to call the method anymore.
func (p *Pipeline) GetResultsNotCheckSync() (results any, err error) {
if p.closed {
if p.err != nil {
return nil, p.err
}
return nil, errors.New("pipeline closed")
}
return p.getResults()
}
func (p *Pipeline) getResults() (results any, err error) {
for {
msg, err := p.conn.receiveMessage()

View File

@ -3003,6 +3003,70 @@ func TestPipelinePrepareQuery(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestPipelinePrepareQueryWithFlush(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
pipeline := pgConn.StartPipeline(ctx)
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)
results, err := pipeline.GetResultsNotCheckSync()
require.NoError(t, err)
sd, ok := results.(*pgconn.StatementDescription)
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
require.Len(t, sd.Fields, 1)
require.Equal(t, "msg", string(sd.Fields[0].Name))
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
results, err = pipeline.GetResultsNotCheckSync()
require.NoError(t, err)
rr, ok := results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult := rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "hello", string(readResult.Rows[0][0]))
results, err = pipeline.GetResultsNotCheckSync()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "goodbye", string(readResult.Rows[0][0]))
err = pipeline.Sync()
require.NoError(t, err)
results, err = pipeline.GetResults()
require.NoError(t, err)
_, ok = results.(*pgconn.PipelineSync)
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)
err = pipeline.Close()
require.NoError(t, err)
ensureConnValid(t, pgConn)
}
func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
t.Parallel()