mirror of https://github.com/jackc/pgx.git
Add beginning of PgConn
parent
44de49ffa1
commit
b89ba28919
205
base/conn.go
205
base/conn.go
|
@ -16,9 +16,12 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgio"
|
||||||
"github.com/jackc/pgx/pgproto3"
|
"github.com/jackc/pgx/pgproto3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const batchBufferSize = 4096
|
||||||
|
|
||||||
// PgError represents an error reported by the PostgreSQL server. See
|
// PgError represents an error reported by the PostgreSQL server. See
|
||||||
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
||||||
// detailed field description.
|
// detailed field description.
|
||||||
|
@ -111,6 +114,13 @@ type PgConn struct {
|
||||||
Frontend *pgproto3.Frontend
|
Frontend *pgproto3.Frontend
|
||||||
|
|
||||||
Config ConnConfig
|
Config ConnConfig
|
||||||
|
|
||||||
|
batchBuf []byte
|
||||||
|
batchCount int32
|
||||||
|
|
||||||
|
pendingReadyForQueryCount int32
|
||||||
|
|
||||||
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func Connect(cc ConnConfig) (*PgConn, error) {
|
func Connect(cc ConnConfig) (*PgConn, error) {
|
||||||
|
@ -258,16 +268,211 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
|
// Under normal circumstances pendingReadyForQueryCount will be > 0 when a
|
||||||
|
// ReadyForQuery is received. However, this is not the case on initial
|
||||||
|
// connection.
|
||||||
|
if pgConn.pendingReadyForQueryCount > 0 {
|
||||||
|
pgConn.pendingReadyForQueryCount -= 1
|
||||||
|
}
|
||||||
pgConn.TxStatus = msg.TxStatus
|
pgConn.TxStatus = msg.TxStatus
|
||||||
case *pgproto3.ParameterStatus:
|
case *pgproto3.ParameterStatus:
|
||||||
pgConn.parameterStatuses[msg.Name] = msg.Value
|
pgConn.parameterStatuses[msg.Name] = msg.Value
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
if msg.Severity == "FATAL" {
|
||||||
|
// TODO - close pgConn
|
||||||
|
return nil, errorResponseToPgError(msg)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes a connection. It is safe to call Close on a already closed
|
||||||
|
// connection.
|
||||||
|
func (pgConn *PgConn) Close() error {
|
||||||
|
if pgConn.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pgConn.closed = true
|
||||||
|
|
||||||
|
_, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4})
|
||||||
|
if err != nil {
|
||||||
|
pgConn.NetConn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = pgConn.NetConn.Read(make([]byte, 1))
|
||||||
|
if err != io.EOF {
|
||||||
|
pgConn.NetConn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return pgConn.NetConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// ParameterStatus returns the value of a parameter reported by the server (e.g.
|
// ParameterStatus returns the value of a parameter reported by the server (e.g.
|
||||||
// server_version). Returns an empty string for unknown parameters.
|
// server_version). Returns an empty string for unknown parameters.
|
||||||
func (pgConn *PgConn) ParameterStatus(key string) string {
|
func (pgConn *PgConn) ParameterStatus(key string) string {
|
||||||
return pgConn.parameterStatuses[key]
|
return pgConn.parameterStatuses[key]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CommandTag is the result of an Exec function
|
||||||
|
type CommandTag string
|
||||||
|
|
||||||
|
// RowsAffected returns the number of rows affected. If the CommandTag was not
|
||||||
|
// for a row affecting command (e.g. "CREATE TABLE") then it returns 0.
|
||||||
|
func (ct CommandTag) RowsAffected() int64 {
|
||||||
|
s := string(ct)
|
||||||
|
index := strings.LastIndex(s, " ")
|
||||||
|
if index == -1 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
n, _ := strconv.ParseInt(s[index+1:], 10, 64)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendExec enqueues the execution of sql via the PostgreSQL simple query
|
||||||
|
// protocol. sql may contain multipe queries. Multiple queries will be processed
|
||||||
|
// within a single transation. It is only sent to the PostgreSQL server when
|
||||||
|
// Flush is called.
|
||||||
|
func (pgConn *PgConn) SendExec(sql string) {
|
||||||
|
pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql)
|
||||||
|
pgConn.batchCount += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it.
|
||||||
|
func appendQuery(buf []byte, query string) []byte {
|
||||||
|
buf = append(buf, 'Q')
|
||||||
|
buf = pgio.AppendInt32(buf, int32(len(query)+5))
|
||||||
|
buf = append(buf, query...)
|
||||||
|
buf = append(buf, 0)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
type PgResultReader struct {
|
||||||
|
pgConn *PgConn
|
||||||
|
fieldDescriptions []pgproto3.FieldDescription
|
||||||
|
rowValues [][]byte
|
||||||
|
commandTag CommandTag
|
||||||
|
err error
|
||||||
|
complete bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetResult returns a PgResultReader for the next result. If all results are
|
||||||
|
// consumed it returns nil. If an error occurs it will be reported on the
|
||||||
|
// returned PgResultReader.
|
||||||
|
func (pgConn *PgConn) GetResult() *PgResultReader {
|
||||||
|
if pgConn.pendingReadyForQueryCount == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PgResultReader{pgConn: pgConn}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rr *PgResultReader) NextRow() (present bool) {
|
||||||
|
if rr.complete {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, err := rr.pgConn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case *pgproto3.RowDescription:
|
||||||
|
rr.fieldDescriptions = msg.Fields
|
||||||
|
case *pgproto3.DataRow:
|
||||||
|
rr.rowValues = msg.Values
|
||||||
|
return true
|
||||||
|
case *pgproto3.CommandComplete:
|
||||||
|
rr.commandTag = CommandTag(msg.CommandTag)
|
||||||
|
rr.complete = true
|
||||||
|
return false
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
rr.err = errorResponseToPgError(msg)
|
||||||
|
rr.complete = true
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rr *PgResultReader) Value(c int) []byte {
|
||||||
|
return rr.rowValues[c]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close consumes any remaining result data and returns the command tag or
|
||||||
|
// error.
|
||||||
|
func (rr *PgResultReader) Close() (CommandTag, error) {
|
||||||
|
if rr.complete {
|
||||||
|
return rr.commandTag, rr.err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, err := rr.pgConn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
rr.err = err
|
||||||
|
rr.complete = true
|
||||||
|
return rr.commandTag, rr.err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case *pgproto3.CommandComplete:
|
||||||
|
rr.commandTag = CommandTag(msg.CommandTag)
|
||||||
|
rr.complete = true
|
||||||
|
return rr.commandTag, rr.err
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
rr.err = errorResponseToPgError(msg)
|
||||||
|
rr.complete = true
|
||||||
|
return rr.commandTag, rr.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush sends the enqueued execs to the server.
|
||||||
|
func (pgConn *PgConn) Flush() error {
|
||||||
|
defer pgConn.resetBatch()
|
||||||
|
|
||||||
|
n, err := pgConn.NetConn.Write(pgConn.batchBuf)
|
||||||
|
if err != nil {
|
||||||
|
if n > 0 {
|
||||||
|
// TODO - kill connection - we sent a partial message
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pgConn.pendingReadyForQueryCount += pgConn.batchCount
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pgConn *PgConn) resetBatch() {
|
||||||
|
pgConn.batchCount = 0
|
||||||
|
if len(pgConn.batchBuf) > batchBufferSize {
|
||||||
|
pgConn.batchBuf = make([]byte, 0, batchBufferSize)
|
||||||
|
} else {
|
||||||
|
pgConn.batchBuf = pgConn.batchBuf[0:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError {
|
||||||
|
return PgError{
|
||||||
|
Severity: msg.Severity,
|
||||||
|
Code: msg.Code,
|
||||||
|
Message: msg.Message,
|
||||||
|
Detail: msg.Detail,
|
||||||
|
Hint: msg.Hint,
|
||||||
|
Position: msg.Position,
|
||||||
|
InternalPosition: msg.InternalPosition,
|
||||||
|
InternalQuery: msg.InternalQuery,
|
||||||
|
Where: msg.Where,
|
||||||
|
SchemaName: msg.SchemaName,
|
||||||
|
TableName: msg.TableName,
|
||||||
|
ColumnName: msg.ColumnName,
|
||||||
|
DataTypeName: msg.DataTypeName,
|
||||||
|
ConstraintName: msg.ConstraintName,
|
||||||
|
File: msg.File,
|
||||||
|
Line: msg.Line,
|
||||||
|
Routine: msg.Routine,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
package base_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/jackc/pgx/base"
|
||||||
|
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSimple(t *testing.T) {
|
||||||
|
pgConn, err := base.Connect(base.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
pgConn.SendExec("select current_database()")
|
||||||
|
err = pgConn.Flush()
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
result := pgConn.GetResult()
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
rowFound := result.NextRow()
|
||||||
|
assert.True(t, rowFound)
|
||||||
|
if rowFound {
|
||||||
|
assert.Equal(t, "pgx_test", string(result.Value(0)))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = result.Close()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = pgConn.Close()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
Loading…
Reference in New Issue