mirror of https://github.com/jackc/pgx.git
pipeline queue for client requests
parent
76593f37f7
commit
de3f868c1d
220
pgconn/pgconn.go
220
pgconn/pgconn.go
|
@ -1,6 +1,7 @@
|
|||
package pgconn
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
|
@ -1408,9 +1409,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
|
||||
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
||||
type MultiResultReader struct {
|
||||
pgConn *PgConn
|
||||
ctx context.Context
|
||||
pipeline *Pipeline
|
||||
pgConn *PgConn
|
||||
ctx context.Context
|
||||
|
||||
rr *ResultReader
|
||||
|
||||
|
@ -1443,12 +1443,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
|
|||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
mrr.closed = true
|
||||
if mrr.pipeline != nil {
|
||||
mrr.pipeline.expectedReadyForQueryCount--
|
||||
} else {
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
mrr.pgConn.unlock()
|
||||
}
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
mrr.pgConn.unlock()
|
||||
case *pgproto3.ErrorResponse:
|
||||
mrr.err = ErrorResponseToPgError(msg)
|
||||
}
|
||||
|
@ -1672,7 +1668,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
|
|||
case *pgproto3.EmptyQueryResponse:
|
||||
rr.concludeCommand(CommandTag{}, nil)
|
||||
case *pgproto3.ErrorResponse:
|
||||
rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
|
||||
pgErr := ErrorResponseToPgError(msg)
|
||||
if rr.pipeline != nil {
|
||||
rr.pipeline.state.HandleError(pgErr)
|
||||
}
|
||||
rr.concludeCommand(CommandTag{}, pgErr)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
|
@ -1999,9 +1999,7 @@ type Pipeline struct {
|
|||
conn *PgConn
|
||||
ctx context.Context
|
||||
|
||||
expectedReadyForQueryCount int
|
||||
pendingSync bool
|
||||
|
||||
state pipelineState
|
||||
err error
|
||||
closed bool
|
||||
}
|
||||
|
@ -2012,6 +2010,122 @@ type PipelineSync struct{}
|
|||
// CloseComplete is returned by GetResults when a CloseComplete message is received.
|
||||
type CloseComplete struct{}
|
||||
|
||||
type pipelineRequestType int
|
||||
|
||||
const (
|
||||
PIPELINE_NIL pipelineRequestType = iota
|
||||
PIPELINE_PREPARE
|
||||
PIPELINE_QUERY_PARAMS
|
||||
PIPELINE_QUERY_PREPARED
|
||||
PIPELINE_DEALLOCATE
|
||||
PIPELINE_SYNC_REQUEST
|
||||
PIPELINE_FLUSH_REQUEST
|
||||
)
|
||||
|
||||
type pipelineRequestEvent struct {
|
||||
RequestType pipelineRequestType
|
||||
WasSentToServer bool
|
||||
BeforeFlushOrSync bool
|
||||
}
|
||||
|
||||
type pipelineState struct {
|
||||
requestEventQueue list.List
|
||||
lastRequestType pipelineRequestType
|
||||
pgErr *PgError
|
||||
expectedReadyForQueryCount int
|
||||
}
|
||||
|
||||
func (s *pipelineState) Init() {
|
||||
s.requestEventQueue.Init()
|
||||
s.lastRequestType = PIPELINE_NIL
|
||||
}
|
||||
|
||||
func (s *pipelineState) RegisterSendingToServer() {
|
||||
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
|
||||
val := elem.Value.(pipelineRequestEvent)
|
||||
if val.WasSentToServer {
|
||||
return
|
||||
}
|
||||
val.WasSentToServer = true
|
||||
elem.Value = val
|
||||
}
|
||||
}
|
||||
|
||||
func (s *pipelineState) registerFlushingBufferOnServer() {
|
||||
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
|
||||
val := elem.Value.(pipelineRequestEvent)
|
||||
if val.BeforeFlushOrSync {
|
||||
return
|
||||
}
|
||||
val.BeforeFlushOrSync = true
|
||||
elem.Value = val
|
||||
}
|
||||
}
|
||||
|
||||
func (s *pipelineState) PushBackRequestType(req pipelineRequestType) {
|
||||
if req == PIPELINE_NIL {
|
||||
return
|
||||
}
|
||||
|
||||
if req != PIPELINE_FLUSH_REQUEST {
|
||||
s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req})
|
||||
}
|
||||
if req == PIPELINE_FLUSH_REQUEST || req == PIPELINE_SYNC_REQUEST {
|
||||
s.registerFlushingBufferOnServer()
|
||||
}
|
||||
s.lastRequestType = req
|
||||
|
||||
if req == PIPELINE_SYNC_REQUEST {
|
||||
s.expectedReadyForQueryCount++
|
||||
}
|
||||
}
|
||||
|
||||
func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType {
|
||||
for {
|
||||
elem := s.requestEventQueue.Front()
|
||||
if elem == nil {
|
||||
return PIPELINE_NIL
|
||||
}
|
||||
val := elem.Value.(pipelineRequestEvent)
|
||||
if !(val.WasSentToServer && val.BeforeFlushOrSync) {
|
||||
return PIPELINE_NIL
|
||||
}
|
||||
|
||||
s.requestEventQueue.Remove(elem)
|
||||
if val.RequestType == PIPELINE_SYNC_REQUEST {
|
||||
s.pgErr = nil
|
||||
}
|
||||
if s.pgErr == nil {
|
||||
return val.RequestType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *pipelineState) HandleError(err *PgError) {
|
||||
s.pgErr = err
|
||||
}
|
||||
|
||||
func (s *pipelineState) HandleReadyForQuery() {
|
||||
s.expectedReadyForQueryCount--
|
||||
}
|
||||
|
||||
func (s *pipelineState) PendingSync() bool {
|
||||
var notPendingSync bool
|
||||
|
||||
if elem := s.requestEventQueue.Back(); elem != nil {
|
||||
val := elem.Value.(pipelineRequestEvent)
|
||||
notPendingSync = (val.RequestType == PIPELINE_SYNC_REQUEST) && val.WasSentToServer
|
||||
} else {
|
||||
notPendingSync = (s.lastRequestType == PIPELINE_SYNC_REQUEST) || (s.lastRequestType == PIPELINE_NIL)
|
||||
}
|
||||
|
||||
return !notPendingSync
|
||||
}
|
||||
|
||||
func (s *pipelineState) ExpectedReadyForQuery() int {
|
||||
return s.expectedReadyForQueryCount
|
||||
}
|
||||
|
||||
// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
|
||||
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
|
||||
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
|
||||
|
@ -2020,16 +2134,21 @@ type CloseComplete struct{}
|
|||
// Prefer ExecBatch when only sending one group of queries at once.
|
||||
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return &Pipeline{
|
||||
pipeline := &Pipeline{
|
||||
closed: true,
|
||||
err: err,
|
||||
}
|
||||
pipeline.state.Init()
|
||||
|
||||
return pipeline
|
||||
}
|
||||
|
||||
pgConn.pipeline = Pipeline{
|
||||
conn: pgConn,
|
||||
ctx: ctx,
|
||||
}
|
||||
pgConn.pipeline.state.Init()
|
||||
|
||||
pipeline := &pgConn.pipeline
|
||||
|
||||
if ctx != context.Background() {
|
||||
|
@ -2052,10 +2171,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
|
|||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
|
||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
|
||||
p.state.PushBackRequestType(PIPELINE_PREPARE)
|
||||
}
|
||||
|
||||
// SendDeallocate deallocates a prepared statement.
|
||||
|
@ -2063,9 +2182,9 @@ func (p *Pipeline) SendDeallocate(name string) {
|
|||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
|
||||
p.state.PushBackRequestType(PIPELINE_DEALLOCATE)
|
||||
}
|
||||
|
||||
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
|
||||
|
@ -2073,12 +2192,12 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [
|
|||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
|
||||
p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
||||
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
||||
p.state.PushBackRequestType(PIPELINE_QUERY_PARAMS)
|
||||
}
|
||||
|
||||
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
|
||||
|
@ -2086,11 +2205,11 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para
|
|||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
||||
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
||||
p.state.PushBackRequestType(PIPELINE_QUERY_PREPARED)
|
||||
}
|
||||
|
||||
// SendFlushRequest sends a request for the server to flush its output buffer.
|
||||
|
@ -2104,9 +2223,24 @@ func (p *Pipeline) SendFlushRequest() {
|
|||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.Send(&pgproto3.Flush{})
|
||||
p.state.PushBackRequestType(PIPELINE_FLUSH_REQUEST)
|
||||
}
|
||||
|
||||
// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message
|
||||
// without flushing the send buffer. This serves as the delimiter of an implicit
|
||||
// transaction and an error recovery point.
|
||||
//
|
||||
// Note that the request is not itself flushed to the server automatically; use Flush if
|
||||
// necessary. This copies the behavior of libpq PQsendPipelineSync.
|
||||
func (p *Pipeline) SendPipelineSync() {
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
|
||||
p.conn.frontend.SendSync(&pgproto3.Sync{})
|
||||
p.state.PushBackRequestType(PIPELINE_SYNC_REQUEST)
|
||||
}
|
||||
|
||||
// Flush flushes the queued requests without establishing a synchronization point.
|
||||
|
@ -2131,28 +2265,14 @@ func (p *Pipeline) Flush() error {
|
|||
return err
|
||||
}
|
||||
|
||||
p.state.RegisterSendingToServer()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync establishes a synchronization point and flushes the queued requests.
|
||||
func (p *Pipeline) Sync() error {
|
||||
if p.closed {
|
||||
if p.err != nil {
|
||||
return p.err
|
||||
}
|
||||
return errors.New("pipeline closed")
|
||||
}
|
||||
|
||||
p.conn.frontend.SendSync(&pgproto3.Sync{})
|
||||
err := p.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.pendingSync = false
|
||||
p.expectedReadyForQueryCount++
|
||||
|
||||
return nil
|
||||
p.SendPipelineSync()
|
||||
return p.Flush()
|
||||
}
|
||||
|
||||
// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or
|
||||
|
@ -2166,30 +2286,13 @@ func (p *Pipeline) GetResults() (results any, err error) {
|
|||
return nil, errors.New("pipeline closed")
|
||||
}
|
||||
|
||||
if p.expectedReadyForQueryCount == 0 {
|
||||
if p.state.ExtractFrontRequestType() == PIPELINE_NIL {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return p.getResults()
|
||||
}
|
||||
|
||||
// GetResultsNotCheckSync gets the next results. If results are present, results may be a *ResultReader, *StatementDescription,
|
||||
// or *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError.
|
||||
//
|
||||
// This method should be used only if the request was sent to the server via methods SendFlushRequest and Flush,
|
||||
// without using Sync. In this case, you need to identify on your own when all results are received and
|
||||
// there is no need to call the method anymore.
|
||||
func (p *Pipeline) GetResultsNotCheckSync() (results any, err error) {
|
||||
if p.closed {
|
||||
if p.err != nil {
|
||||
return nil, p.err
|
||||
}
|
||||
return nil, errors.New("pipeline closed")
|
||||
}
|
||||
|
||||
return p.getResults()
|
||||
}
|
||||
|
||||
func (p *Pipeline) getResults() (results any, err error) {
|
||||
for {
|
||||
msg, err := p.conn.receiveMessage()
|
||||
|
@ -2228,13 +2331,13 @@ func (p *Pipeline) getResults() (results any, err error) {
|
|||
case *pgproto3.CloseComplete:
|
||||
return &CloseComplete{}, nil
|
||||
case *pgproto3.ReadyForQuery:
|
||||
p.expectedReadyForQueryCount--
|
||||
p.state.HandleReadyForQuery()
|
||||
return &PipelineSync{}, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr := ErrorResponseToPgError(msg)
|
||||
p.state.HandleError(pgErr)
|
||||
return nil, pgErr
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2264,6 +2367,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
|
|||
// These should never happen here. But don't take chances that could lead to a deadlock.
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr := ErrorResponseToPgError(msg)
|
||||
p.state.HandleError(pgErr)
|
||||
return nil, pgErr
|
||||
case *pgproto3.CommandComplete:
|
||||
p.conn.asyncClose()
|
||||
|
@ -2283,7 +2387,7 @@ func (p *Pipeline) Close() error {
|
|||
|
||||
p.closed = true
|
||||
|
||||
if p.pendingSync {
|
||||
if p.state.PendingSync() {
|
||||
p.conn.asyncClose()
|
||||
p.err = errors.New("pipeline has unsynced requests")
|
||||
p.conn.contextWatcher.Unwatch()
|
||||
|
@ -2292,7 +2396,7 @@ func (p *Pipeline) Close() error {
|
|||
return p.err
|
||||
}
|
||||
|
||||
for p.expectedReadyForQueryCount > 0 {
|
||||
for p.state.ExpectedReadyForQuery() > 0 {
|
||||
_, err := p.getResults()
|
||||
if err != nil {
|
||||
p.err = err
|
||||
|
|
|
@ -3003,70 +3003,6 @@ func TestPipelinePrepareQuery(t *testing.T) {
|
|||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelinePrepareQueryWithFlush(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(ctx)
|
||||
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResultsNotCheckSync()
|
||||
require.NoError(t, err)
|
||||
sd, ok := results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, "msg", string(sd.Fields[0].Name))
|
||||
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResultsNotCheckSync()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "hello", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResultsNotCheckSync()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "goodbye", string(readResult.Rows[0][0]))
|
||||
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -3174,6 +3110,344 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
|
|||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineFlushForSingleRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(ctx)
|
||||
|
||||
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
sd, ok := results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, "msg", string(sd.Fields[0].Name))
|
||||
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "hello", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendDeallocate("ps")
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.CloseComplete)
|
||||
require.Truef(t, ok, "expected CloseComplete, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineFlushForRequestSeries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(ctx)
|
||||
pipeline.SendPrepare("ps", "select $1::bigint as num", nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
sd, ok := results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, "num", string(sd.Fields[0].Name))
|
||||
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("1")}, nil, nil)
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("2")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("3")}, nil, nil)
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("4")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "4", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("5")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("6")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "6", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineFlushWithError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(ctx)
|
||||
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
var pgErr *pgconn.PgError
|
||||
require.ErrorAs(t, readResult.Err, &pgErr)
|
||||
require.Equal(t, "22012", pgErr.Code)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
||||
pipeline.SendPipelineSync()
|
||||
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineCloseReadsUnreadResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue