Add pgconn.Exec

pull/483/head
Jack Christensen 2018-12-31 13:32:26 -06:00
parent c552e2c028
commit c33441674f
3 changed files with 132 additions and 63 deletions

View File

@ -483,7 +483,7 @@ func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error {
return errors.New("show transaction_read_only failed")
}
if string(result.Value(0)) == "on" {
if string(result.Values()[0]) == "on" {
return errors.New("read only connection")
}

View File

@ -340,10 +340,9 @@ func (ct CommandTag) RowsAffected() int64 {
return n
}
// SendExec enqueues the execution of sql via the PostgreSQL simple query
// protocol. sql may contain multipe queries. Multiple queries will be processed
// within a single transation. It is only sent to the PostgreSQL server when
// Flush is called.
// SendExec enqueues the execution of sql via the PostgreSQL simple query protocol. sql may contain multiple queries.
// Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains
// transaction control statements. It is only sent to the PostgreSQL server when Flush is called.
func (pgConn *PgConn) SendExec(sql string) {
pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql)
pgConn.batchCount += 1
@ -359,30 +358,51 @@ func appendQuery(buf []byte, query string) []byte {
}
type PgResultReader struct {
pgConn *PgConn
fieldDescriptions []pgproto3.FieldDescription
rowValues [][]byte
commandTag CommandTag
err error
complete bool
pgConn *PgConn
fieldDescriptions []pgproto3.FieldDescription
rowValues [][]byte
commandTag CommandTag
err error
complete bool
preloadedRowValues bool
}
// GetResult returns a PgResultReader for the next result. If all results are
// consumed it returns nil. If an error occurs it will be reported on the
// returned PgResultReader.
func (pgConn *PgConn) GetResult() *PgResultReader {
if pgConn.pendingReadyForQueryCount == 0 {
return nil
for pgConn.pendingReadyForQueryCount > 0 {
msg, err := pgConn.ReceiveMessage()
if err != nil {
return &PgResultReader{pgConn: pgConn, err: err, complete: true}
}
switch msg := msg.(type) {
case *pgproto3.RowDescription:
return &PgResultReader{pgConn: pgConn, fieldDescriptions: msg.Fields}
case *pgproto3.DataRow:
return &PgResultReader{pgConn: pgConn, rowValues: msg.Values, preloadedRowValues: true}
case *pgproto3.CommandComplete:
return &PgResultReader{pgConn: pgConn, commandTag: CommandTag(msg.CommandTag), complete: true}
case *pgproto3.ErrorResponse:
return &PgResultReader{pgConn: pgConn, err: errorResponseToPgError(msg), complete: true}
}
}
return &PgResultReader{pgConn: pgConn}
return nil
}
func (rr *PgResultReader) NextRow() (present bool) {
// NextRow returns advances the PgResultReader to the next row and returns true if a row is available.
func (rr *PgResultReader) NextRow() bool {
if rr.complete {
return false
}
if rr.preloadedRowValues {
rr.preloadedRowValues = false
return true
}
for {
msg, err := rr.pgConn.ReceiveMessage()
if err != nil {
@ -396,6 +416,7 @@ func (rr *PgResultReader) NextRow() (present bool) {
rr.rowValues = msg.Values
return true
case *pgproto3.CommandComplete:
rr.rowValues = nil
rr.commandTag = CommandTag(msg.CommandTag)
rr.complete = true
return false
@ -407,8 +428,11 @@ func (rr *PgResultReader) NextRow() (present bool) {
}
}
func (rr *PgResultReader) Value(c int) []byte {
return rr.rowValues[c]
// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only
// valid until the next NextRow call or the PgResultReader is closed. However, the underlying byte data is safe to
// retain a reference to and mutate.
func (rr *PgResultReader) Values() [][]byte {
return rr.rowValues
}
// Close consumes any remaining result data and returns the command tag or
@ -418,6 +442,8 @@ func (rr *PgResultReader) Close() (CommandTag, error) {
return rr.commandTag, rr.err
}
rr.rowValues = nil
for {
msg, err := rr.pgConn.ReceiveMessage()
if err != nil {
@ -464,6 +490,57 @@ func (pgConn *PgConn) resetBatch() {
}
}
type PgResult struct {
Rows [][][]byte
CommandTag CommandTag
}
// Exec executes sql via the PostgreSQL simple query protocol, buffers the entire result, and returns it. sql may
// contain multiple queries, but only the last results will be returned. Execution is implicitly wrapped in a
// transactions unless a transaction is already in progress or sql contains transaction control statements.
//
// Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec).
func (pgConn *PgConn) Exec(sql string) (*PgResult, error) {
if pgConn.batchCount != 0 {
return nil, errors.New("unflushed previous sends")
}
if pgConn.pendingReadyForQueryCount != 0 {
return nil, errors.New("unread previous results")
}
pgConn.SendExec(sql)
err := pgConn.Flush()
if err != nil {
return nil, err
}
var result *PgResult
for resultReader := pgConn.GetResult(); resultReader != nil; resultReader = pgConn.GetResult() {
rows := [][][]byte{}
for resultReader.NextRow() {
row := make([][]byte, len(resultReader.Values()))
copy(row, resultReader.Values())
rows = append(rows, row)
}
commandTag, err := resultReader.Close()
if err != nil {
return nil, err
}
result = &PgResult{
Rows: rows,
CommandTag: commandTag,
}
}
if result == nil {
return nil, errors.New("unexpected missing result")
}
return result, nil
}
func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError {
return PgError{
Severity: msg.Severity,

View File

@ -126,36 +126,15 @@ func TestConnectWithRuntimeParams(t *testing.T) {
require.Nil(t, err)
defer closeConn(t, conn)
// TODO - refactor these selects once there are higher level query functions
conn.SendExec("show application_name")
conn.SendExec("show search_path")
err = conn.Flush()
result, err := conn.Exec("show application_name")
require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "pgxtest", string(result.Rows[0][0]))
result := conn.GetResult()
require.NotNil(t, result)
rowFound := result.NextRow()
assert.True(t, rowFound)
if rowFound {
assert.Equal(t, "pgxtest", string(result.Value(0)))
}
_, err = result.Close()
assert.Nil(t, err)
result = conn.GetResult()
require.NotNil(t, result)
rowFound = result.NextRow()
assert.True(t, rowFound)
if rowFound {
assert.Equal(t, "myschema", string(result.Value(0)))
}
_, err = result.Close()
assert.Nil(t, err)
result, err = conn.Exec("show search_path")
require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "myschema", string(result.Rows[0][0]))
}
func TestConnectWithFallback(t *testing.T) {
@ -239,26 +218,39 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) {
}
}
func TestSimple(t *testing.T) {
func TestExec(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
pgConn.SendExec("select current_database()")
err = pgConn.Flush()
result, err := pgConn.Exec("select current_database()")
require.Nil(t, err)
result := pgConn.GetResult()
require.NotNil(t, result)
rowFound := result.NextRow()
assert.True(t, rowFound)
if rowFound {
assert.Equal(t, "pgx_test", string(result.Value(0)))
}
_, err = result.Close()
assert.Nil(t, err)
err = pgConn.Close()
assert.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0]))
}
func TestExecMultipleQueries(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
result, err := pgConn.Exec("select current_database(); select 1")
require.Nil(t, err)
assert.Equal(t, 1, len(result.Rows))
assert.Equal(t, "1", string(result.Rows[0][0]))
}
func TestExecMultipleQueriesError(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
result, err := pgConn.Exec("select 1; select 1/0; select 1")
require.NotNil(t, err)
require.Nil(t, result)
if pgErr, ok := err.(pgconn.PgError); ok {
assert.Equal(t, "22012", pgErr.Code)
} else {
t.Errorf("unexpected error: %v", err)
}
}