Use result readers in next/get fashion

pull/483/head
Jack Christensen 2019-01-01 14:10:16 -06:00
parent b12b579814
commit 0330052b0a
3 changed files with 43 additions and 39 deletions

View File

@ -538,10 +538,9 @@ type PgResultReader struct {
cleanupContext func()
}
// GetResult returns a PgResultReader for the next result. If all results are consumed it returns nil. If an error
// occurs it will be reported on the returned PgResultReader. Returned PgResultReader is only valid until next call of
// GetResult.
func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader {
// NextResult reads until a result is ready to be read or no results are pending. Returns true if a result is available.
// Use ResultReader() to acquire a reader for the result.
func (pgConn *PgConn) NextResult(ctx context.Context) bool {
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
for pgConn.pendingReadyForQueryCount > 0 {
@ -549,29 +548,34 @@ func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader {
if err != nil {
cleanupContext()
pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true}
return &pgConn.resultReader
return true
}
switch msg := msg.(type) {
case *pgproto3.RowDescription:
pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields}
return &pgConn.resultReader
return true
case *pgproto3.DataRow:
pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true}
return &pgConn.resultReader
return true
case *pgproto3.CommandComplete:
cleanupContext()
pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true}
return &pgConn.resultReader
return true
case *pgproto3.ErrorResponse:
cleanupContext()
pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true}
return &pgConn.resultReader
return true
}
}
cleanupContext()
return nil
return false
}
// ResultReader returns the result reader prepared by next result. It is only valid until the result is completed.
func (pgConn *PgConn) ResultReader() *PgResultReader {
return &pgConn.resultReader
}
// NextRow returns advances the PgResultReader to the next row and returns true if a row is available.
@ -806,7 +810,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) {
func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) {
var result *PgResult
for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) {
for pgConn.NextResult(ctx) {
resultReader := pgConn.ResultReader()
rows := [][][]byte{}
for resultReader.NextRow() {
row := make([][]byte, len(resultReader.Values()))

View File

@ -84,10 +84,10 @@ func stressBatch(pgConn *pgconn.PgConn) error {
}
// Query 1
resultReader := pgConn.GetResult(context.Background())
if resultReader == nil {
return errors.New("missing resultReader")
if !pgConn.NextResult(context.Background()) {
return errors.New("missing result")
}
resultReader := pgConn.ResultReader()
for resultReader.NextRow() {
}
@ -97,10 +97,10 @@ func stressBatch(pgConn *pgconn.PgConn) error {
}
// Query 2
resultReader = pgConn.GetResult(context.Background())
if resultReader == nil {
return errors.New("missing resultReader")
if !pgConn.NextResult(context.Background()) {
return errors.New("missing result")
}
resultReader = pgConn.ResultReader()
for resultReader.NextRow() {
}
@ -110,8 +110,7 @@ func stressBatch(pgConn *pgconn.PgConn) error {
}
// No more
resultReader = pgConn.GetResult(context.Background())
if resultReader != nil {
if pgConn.NextResult(context.Background()) {
return errors.New("unexpected result reader")
}
@ -162,10 +161,10 @@ func stressBatchCanceled(pgConn *pgconn.PgConn) error {
}
// Query 1
resultReader := pgConn.GetResult(context.Background())
if resultReader == nil {
return errors.New("missing resultReader")
if !pgConn.NextResult(context.Background()) {
return errors.New("missing result")
}
resultReader := pgConn.ResultReader()
for resultReader.NextRow() {
}
@ -176,11 +175,11 @@ func stressBatchCanceled(pgConn *pgconn.PgConn) error {
// Query 2
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
resultReader = pgConn.GetResult(ctx)
cancel()
if resultReader == nil {
return errors.New("missing resultReader")
if !pgConn.NextResult(ctx) {
return errors.New("missing result")
}
cancel()
resultReader = pgConn.ResultReader()
for resultReader.NextRow() {
}

View File

@ -373,8 +373,8 @@ func TestConnBatchedQueries(t *testing.T) {
err = pgConn.Flush(context.Background())
// "select 'SendExec 1'"
resultReader := pgConn.GetResult(context.Background())
require.NotNil(t, resultReader)
require.True(t, pgConn.NextResult(context.Background()))
resultReader := pgConn.ResultReader()
rows := [][][]byte{}
for resultReader.NextRow() {
@ -391,8 +391,8 @@ func TestConnBatchedQueries(t *testing.T) {
assert.Nil(t, err)
// "SendExecParams 1"
resultReader = pgConn.GetResult(context.Background())
require.NotNil(t, resultReader)
require.True(t, pgConn.NextResult(context.Background()))
resultReader = pgConn.ResultReader()
rows = [][][]byte{}
for resultReader.NextRow() {
@ -409,8 +409,8 @@ func TestConnBatchedQueries(t *testing.T) {
assert.Nil(t, err)
// "SendExecPrepared 1"
resultReader = pgConn.GetResult(context.Background())
require.NotNil(t, resultReader)
require.True(t, pgConn.NextResult(context.Background()))
resultReader = pgConn.ResultReader()
rows = [][][]byte{}
for resultReader.NextRow() {
@ -427,8 +427,8 @@ func TestConnBatchedQueries(t *testing.T) {
assert.Nil(t, err)
// "SendExec 2"
resultReader = pgConn.GetResult(context.Background())
require.NotNil(t, resultReader)
require.True(t, pgConn.NextResult(context.Background()))
resultReader = pgConn.ResultReader()
rows = [][][]byte{}
for resultReader.NextRow() {
@ -445,8 +445,8 @@ func TestConnBatchedQueries(t *testing.T) {
assert.Nil(t, err)
// "SendExecParams 2"
resultReader = pgConn.GetResult(context.Background())
require.NotNil(t, resultReader)
require.True(t, pgConn.NextResult(context.Background()))
resultReader = pgConn.ResultReader()
rows = [][][]byte{}
for resultReader.NextRow() {
@ -463,8 +463,7 @@ func TestConnBatchedQueries(t *testing.T) {
assert.Nil(t, err)
// Done
resultReader = pgConn.GetResult(context.Background())
assert.Nil(t, resultReader)
require.False(t, pgConn.NextResult(context.Background()))
}
func TestConnRecoverFromTimeout(t *testing.T) {
@ -505,7 +504,8 @@ func TestConnCancelQuery(t *testing.T) {
err = pgConn.CancelRequest(context.Background())
require.Nil(t, err)
_, err = pgConn.GetResult(context.Background()).Close()
require.True(t, pgConn.NextResult(context.Background()))
_, err = pgConn.ResultReader().Close()
if err, ok := err.(*pgconn.PgError); ok {
assert.Equal(t, "57014", err.Code)
} else {