diff --git a/pgconn/config.go b/pgconn/config.go index 38144be7..a446a67e 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -483,7 +483,7 @@ func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { return errors.New("show transaction_read_only failed") } - if string(result.Value(0)) == "on" { + if string(result.Values()[0]) == "on" { return errors.New("read only connection") } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 94397759..c243d2f6 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -340,10 +340,9 @@ func (ct CommandTag) RowsAffected() int64 { return n } -// SendExec enqueues the execution of sql via the PostgreSQL simple query -// protocol. sql may contain multipe queries. Multiple queries will be processed -// within a single transation. It is only sent to the PostgreSQL server when -// Flush is called. +// SendExec enqueues the execution of sql via the PostgreSQL simple query protocol. sql may contain multiple queries. +// Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains +// transaction control statements. It is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExec(sql string) { pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql) pgConn.batchCount += 1 @@ -359,30 +358,51 @@ func appendQuery(buf []byte, query string) []byte { } type PgResultReader struct { - pgConn *PgConn - fieldDescriptions []pgproto3.FieldDescription - rowValues [][]byte - commandTag CommandTag - err error - complete bool + pgConn *PgConn + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + err error + complete bool + preloadedRowValues bool } // GetResult returns a PgResultReader for the next result. If all results are // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. func (pgConn *PgConn) GetResult() *PgResultReader { - if pgConn.pendingReadyForQueryCount == 0 { - return nil + for pgConn.pendingReadyForQueryCount > 0 { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return &PgResultReader{pgConn: pgConn, err: err, complete: true} + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + return &PgResultReader{pgConn: pgConn, fieldDescriptions: msg.Fields} + case *pgproto3.DataRow: + return &PgResultReader{pgConn: pgConn, rowValues: msg.Values, preloadedRowValues: true} + case *pgproto3.CommandComplete: + return &PgResultReader{pgConn: pgConn, commandTag: CommandTag(msg.CommandTag), complete: true} + case *pgproto3.ErrorResponse: + return &PgResultReader{pgConn: pgConn, err: errorResponseToPgError(msg), complete: true} + } } - return &PgResultReader{pgConn: pgConn} + return nil } -func (rr *PgResultReader) NextRow() (present bool) { +// NextRow returns advances the PgResultReader to the next row and returns true if a row is available. +func (rr *PgResultReader) NextRow() bool { if rr.complete { return false } + if rr.preloadedRowValues { + rr.preloadedRowValues = false + return true + } + for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { @@ -396,6 +416,7 @@ func (rr *PgResultReader) NextRow() (present bool) { rr.rowValues = msg.Values return true case *pgproto3.CommandComplete: + rr.rowValues = nil rr.commandTag = CommandTag(msg.CommandTag) rr.complete = true return false @@ -407,8 +428,11 @@ func (rr *PgResultReader) NextRow() (present bool) { } } -func (rr *PgResultReader) Value(c int) []byte { - return rr.rowValues[c] +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the PgResultReader is closed. However, the underlying byte data is safe to +// retain a reference to and mutate. +func (rr *PgResultReader) Values() [][]byte { + return rr.rowValues } // Close consumes any remaining result data and returns the command tag or @@ -418,6 +442,8 @@ func (rr *PgResultReader) Close() (CommandTag, error) { return rr.commandTag, rr.err } + rr.rowValues = nil + for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { @@ -464,6 +490,57 @@ func (pgConn *PgConn) resetBatch() { } } +type PgResult struct { + Rows [][][]byte + CommandTag CommandTag +} + +// Exec executes sql via the PostgreSQL simple query protocol, buffers the entire result, and returns it. sql may +// contain multiple queries, but only the last results will be returned. Execution is implicitly wrapped in a +// transactions unless a transaction is already in progress or sql contains transaction control statements. +// +// Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). +func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { + if pgConn.batchCount != 0 { + return nil, errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return nil, errors.New("unread previous results") + } + + pgConn.SendExec(sql) + err := pgConn.Flush() + if err != nil { + return nil, err + } + + var result *PgResult + + for resultReader := pgConn.GetResult(); resultReader != nil; resultReader = pgConn.GetResult() { + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + + commandTag, err := resultReader.Close() + if err != nil { + return nil, err + } + + result = &PgResult{ + Rows: rows, + CommandTag: commandTag, + } + } + if result == nil { + return nil, errors.New("unexpected missing result") + } + + return result, nil +} + func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { return PgError{ Severity: msg.Severity, diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 0dccc99f..f3f22d42 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -126,36 +126,15 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - // TODO - refactor these selects once there are higher level query functions - - conn.SendExec("show application_name") - conn.SendExec("show search_path") - err = conn.Flush() + result, err := conn.Exec("show application_name") require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result := conn.GetResult() - require.NotNil(t, result) - - rowFound := result.NextRow() - assert.True(t, rowFound) - if rowFound { - assert.Equal(t, "pgxtest", string(result.Value(0))) - } - - _, err = result.Close() - assert.Nil(t, err) - - result = conn.GetResult() - require.NotNil(t, result) - - rowFound = result.NextRow() - assert.True(t, rowFound) - if rowFound { - assert.Equal(t, "myschema", string(result.Value(0))) - } - - _, err = result.Close() - assert.Nil(t, err) + result, err = conn.Exec("show search_path") + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) } func TestConnectWithFallback(t *testing.T) { @@ -239,26 +218,39 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } -func TestSimple(t *testing.T) { +func TestExec(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) + defer closeConn(t, pgConn) - pgConn.SendExec("select current_database()") - err = pgConn.Flush() + result, err := pgConn.Exec("select current_database()") require.Nil(t, err) - - result := pgConn.GetResult() - require.NotNil(t, result) - - rowFound := result.NextRow() - assert.True(t, rowFound) - if rowFound { - assert.Equal(t, "pgx_test", string(result.Value(0))) - } - - _, err = result.Close() - assert.Nil(t, err) - - err = pgConn.Close() - assert.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) +} + +func TestExecMultipleQueries(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + result, err := pgConn.Exec("select current_database(); select 1") + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "1", string(result.Rows[0][0])) +} + +func TestExecMultipleQueriesError(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + result, err := pgConn.Exec("select 1; select 1/0; select 1") + require.NotNil(t, err) + require.Nil(t, result) + if pgErr, ok := err.(pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } }