mirror of https://github.com/jackc/pgx.git
Add pipeline mode to pgconn
parent
ed3e9f1dd4
commit
ae2881a23c
|
@ -16,6 +16,8 @@ pgconn now uses non-blocking IO. This is a significant internal restructuring, b
|
|||
|
||||
`CheckConn()` checks a connection's liveness by doing a non-blocking read. This can be used to detect database restarts or network interruptions without executing a query or a ping.
|
||||
|
||||
pgconn now supports pipeline mode.
|
||||
|
||||
## pgtype
|
||||
|
||||
The `pgtype` package has been significantly changed.
|
||||
|
|
|
@ -18,6 +18,11 @@ Executing Multiple Queries in a Single Round Trip
|
|||
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
|
||||
result. The ReadAll method reads all query results into memory.
|
||||
|
||||
Pipeline Mode
|
||||
|
||||
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows
|
||||
control of exactly how many and when network round trips occur.
|
||||
|
||||
Context Support
|
||||
|
||||
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
|
||||
|
|
266
pgconn/pgconn.go
266
pgconn/pgconn.go
|
@ -81,6 +81,7 @@ type PgConn struct {
|
|||
// Reusable / preallocated resources
|
||||
resultReader ResultReader
|
||||
multiResultReader MultiResultReader
|
||||
pipeline Pipeline
|
||||
contextWatcher *ctxwatch.ContextWatcher
|
||||
|
||||
cleanupDone chan struct{}
|
||||
|
@ -1244,6 +1245,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
type MultiResultReader struct {
|
||||
pgConn *PgConn
|
||||
ctx context.Context
|
||||
pipeline *Pipeline
|
||||
|
||||
rr *ResultReader
|
||||
|
||||
|
@ -1276,9 +1278,13 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
|
|||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
mrr.closed = true
|
||||
if mrr.pipeline != nil {
|
||||
mrr.pipeline.expectedReadyForQueryCount--
|
||||
} else {
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
mrr.pgConn.unlock()
|
||||
}
|
||||
case *pgproto3.ErrorResponse:
|
||||
mrr.err = ErrorResponseToPgError(msg)
|
||||
}
|
||||
|
@ -1341,6 +1347,7 @@ func (mrr *MultiResultReader) Close() error {
|
|||
type ResultReader struct {
|
||||
pgConn *PgConn
|
||||
multiResultReader *MultiResultReader
|
||||
pipeline *Pipeline
|
||||
ctx context.Context
|
||||
|
||||
fieldDescriptions []pgproto3.FieldDescription
|
||||
|
@ -1429,7 +1436,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if rr.multiResultReader == nil {
|
||||
if rr.multiResultReader == nil && rr.pipeline == nil {
|
||||
for {
|
||||
msg, err := rr.receiveMessage()
|
||||
if err != nil {
|
||||
|
@ -1539,7 +1546,8 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor
|
|||
}
|
||||
|
||||
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
|
||||
// transaction is already in progress or SQL contains transaction control statements.
|
||||
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
|
||||
// multiple queries in a single round trip than using pipeline mode.
|
||||
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return &MultiResultReader{
|
||||
|
@ -1676,3 +1684,255 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
|
|||
|
||||
return pgConn, nil
|
||||
}
|
||||
|
||||
// Pipeline represents a connection in pipeline mode.
|
||||
//
|
||||
// SendPrepare, SendQueryParam, and SendQueryPrepared queue requests to the server. These requests are not written until
|
||||
// pipeline is flushed by Flush or Sync. Sync must be called after the last request is queued. Requests between
|
||||
// synchronization points are implicitly transactional unless explicit transaction control statements have been issued.
|
||||
//
|
||||
// The context the pipeline was started with is in effect for the entire life of the Pipeline.
|
||||
//
|
||||
// For a deeper understanding of pipeline mode see the PostgreSQL documentation for the extended query protocol
|
||||
// (https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY) and the libpq pipeline mode
|
||||
// (https://www.postgresql.org/docs/current/libpq-pipeline-mode.html).
|
||||
type Pipeline struct {
|
||||
conn *PgConn
|
||||
ctx context.Context
|
||||
|
||||
expectedReadyForQueryCount int
|
||||
pendingSync bool
|
||||
|
||||
err error
|
||||
closed bool
|
||||
}
|
||||
|
||||
// PipelineSync is returned by GetResults when a ReadyForQuery message is received.
|
||||
type PipelineSync struct{}
|
||||
|
||||
// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
|
||||
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
|
||||
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
|
||||
// CancelRequest and Close. ctx is in effect for entire life of the *Pipeline.
|
||||
//
|
||||
// Prefer ExecBatch when only sending one group of queries at once.
|
||||
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return &Pipeline{
|
||||
closed: true,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
pgConn.pipeline = Pipeline{
|
||||
conn: pgConn,
|
||||
ctx: ctx,
|
||||
}
|
||||
pipeline := &pgConn.pipeline
|
||||
|
||||
if ctx != context.Background() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pipeline.closed = true
|
||||
pipeline.err = newContextAlreadyDoneError(ctx)
|
||||
pgConn.unlock()
|
||||
return pipeline
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
}
|
||||
|
||||
return pipeline
|
||||
}
|
||||
|
||||
// SendPrepare is the pipeline version of *PgConn.Prepare.
|
||||
func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
|
||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
|
||||
}
|
||||
|
||||
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
|
||||
func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
|
||||
p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
||||
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
||||
}
|
||||
|
||||
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
|
||||
func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
||||
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
||||
}
|
||||
|
||||
// Flush flushes the queued requests without establishing a synchronization point.
|
||||
func (p *Pipeline) Flush() error {
|
||||
if p.closed {
|
||||
if p.err != nil {
|
||||
return p.err
|
||||
}
|
||||
return errors.New("pipeline closed")
|
||||
}
|
||||
|
||||
err := p.conn.frontend.Flush()
|
||||
if err != nil {
|
||||
err = preferContextOverNetTimeoutError(p.ctx, err)
|
||||
|
||||
p.conn.asyncClose()
|
||||
|
||||
p.conn.contextWatcher.Unwatch()
|
||||
p.conn.unlock()
|
||||
p.closed = true
|
||||
p.err = err
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync establishes a synchronization point and flushes the queued requests.
|
||||
func (p *Pipeline) Sync() error {
|
||||
p.conn.frontend.SendSync(&pgproto3.Sync{})
|
||||
err := p.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.pendingSync = false
|
||||
p.expectedReadyForQueryCount++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or
|
||||
// *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no
|
||||
// results are available, results and err will both be nil.
|
||||
func (p *Pipeline) GetResults() (results any, err error) {
|
||||
if p.expectedReadyForQueryCount == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := p.conn.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.RowDescription:
|
||||
p.conn.resultReader = ResultReader{
|
||||
pgConn: p.conn,
|
||||
pipeline: p,
|
||||
ctx: p.ctx,
|
||||
fieldDescriptions: msg.Fields,
|
||||
}
|
||||
return &p.conn.resultReader, nil
|
||||
case *pgproto3.CommandComplete:
|
||||
p.conn.resultReader = ResultReader{
|
||||
commandTag: p.conn.makeCommandTag(msg.CommandTag),
|
||||
commandConcluded: true,
|
||||
closed: true,
|
||||
}
|
||||
return &p.conn.resultReader, nil
|
||||
case *pgproto3.ParseComplete:
|
||||
peekedMsg, err := p.conn.peekMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok {
|
||||
return p.getResultsPrepare()
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
p.expectedReadyForQueryCount--
|
||||
return &PipelineSync{}, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr := ErrorResponseToPgError(msg)
|
||||
return nil, pgErr
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
|
||||
psd := &StatementDescription{}
|
||||
|
||||
for {
|
||||
msg, err := p.conn.receiveMessage()
|
||||
if err != nil {
|
||||
p.conn.asyncClose()
|
||||
return nil, preferContextOverNetTimeoutError(p.ctx, err)
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ParameterDescription:
|
||||
psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs))
|
||||
copy(psd.ParamOIDs, msg.ParameterOIDs)
|
||||
case *pgproto3.RowDescription:
|
||||
psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields))
|
||||
copy(psd.Fields, msg.Fields)
|
||||
return psd, nil
|
||||
|
||||
// These should never happen here. But don't take chances that could lead to a deadlock.
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr := ErrorResponseToPgError(msg)
|
||||
return nil, pgErr
|
||||
case *pgproto3.CommandComplete:
|
||||
p.conn.asyncClose()
|
||||
return nil, errors.New("BUG: received CommandComplete while handling Describe")
|
||||
case *pgproto3.ReadyForQuery:
|
||||
p.conn.asyncClose()
|
||||
return nil, errors.New("BUG: received ReadyForQuery while handling Describe")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the pipeline and returns the connection to normal mode.
|
||||
func (p *Pipeline) Close() error {
|
||||
if p.closed {
|
||||
return p.err
|
||||
}
|
||||
p.closed = true
|
||||
|
||||
if p.pendingSync {
|
||||
p.conn.asyncClose()
|
||||
p.err = errors.New("pipeline has unsynced requests")
|
||||
p.conn.contextWatcher.Unwatch()
|
||||
p.conn.unlock()
|
||||
|
||||
return p.err
|
||||
}
|
||||
|
||||
for p.expectedReadyForQueryCount > 0 {
|
||||
_, err := p.GetResults()
|
||||
if err != nil {
|
||||
var pgErr *PgError
|
||||
if !errors.As(err, &pgErr) {
|
||||
p.conn.asyncClose()
|
||||
p.err = err
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.conn.contextWatcher.Unwatch()
|
||||
p.conn.unlock()
|
||||
|
||||
return p.err
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/jackc/pgx/v5/internal/pgmock"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -2094,6 +2095,417 @@ func TestConnCheckConn(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPipelinePrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(context.Background())
|
||||
pipeline.SendPrepare("selectInt", "select $1::int as a", nil)
|
||||
pipeline.SendPrepare("selectText", "select $1::text as b", nil)
|
||||
pipeline.SendPrepare("selectNoParams", "select 42 as c", nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
sd, ok := results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, string(sd.Fields[0].Name), "a")
|
||||
require.Equal(t, []uint32{pgtype.Int4OID}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
sd, ok = results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, string(sd.Fields[0].Name), "b")
|
||||
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
sd, ok = results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, string(sd.Fields[0].Name), "c")
|
||||
require.Equal(t, []uint32{}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelinePrepareError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(context.Background())
|
||||
pipeline.SendPrepare("selectInt", "select $1::int as a", nil)
|
||||
pipeline.SendPrepare("selectError", "bad", nil)
|
||||
pipeline.SendPrepare("selectText", "select $1::text as b", nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
sd, ok := results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, string(sd.Fields[0].Name), "a")
|
||||
require.Equal(t, []uint32{pgtype.Int4OID}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
var pgErr *pgconn.PgError
|
||||
require.ErrorAs(t, err, &pgErr)
|
||||
require.Nil(t, results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(context.Background())
|
||||
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "4", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelinePrepareQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(context.Background())
|
||||
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
sd, ok := results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, string(sd.Fields[0].Name), "msg")
|
||||
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "hello", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "goodbye", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(context.Background())
|
||||
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
var pgErr *pgconn.PgError
|
||||
require.ErrorAs(t, readResult.Err, &pgErr)
|
||||
require.Equal(t, "22012", pgErr.Code)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "6", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineCloseReadsUnreadResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(context.Background())
|
||||
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(context.Background())
|
||||
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
err = pipeline.Close()
|
||||
require.EqualError(t, err, "pipeline has unsynced requests")
|
||||
}
|
||||
|
||||
func Example() {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue