diff --git a/base/conn.go b/base/conn.go index b8867c18..386daaa5 100644 --- a/base/conn.go +++ b/base/conn.go @@ -16,9 +16,12 @@ import ( "strings" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) +const batchBufferSize = 4096 + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for // detailed field description. @@ -111,6 +114,13 @@ type PgConn struct { Frontend *pgproto3.Frontend Config ConnConfig + + batchBuf []byte + batchCount int32 + + pendingReadyForQueryCount int32 + + closed bool } func Connect(cc ConnConfig) (*PgConn, error) { @@ -258,16 +268,211 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { switch msg := msg.(type) { case *pgproto3.ReadyForQuery: + // Under normal circumstances pendingReadyForQueryCount will be > 0 when a + // ReadyForQuery is received. However, this is not the case on initial + // connection. + if pgConn.pendingReadyForQueryCount > 0 { + pgConn.pendingReadyForQueryCount -= 1 + } pgConn.TxStatus = msg.TxStatus case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value + case *pgproto3.ErrorResponse: + if msg.Severity == "FATAL" { + // TODO - close pgConn + return nil, errorResponseToPgError(msg) + } } return msg, nil } +// Close closes a connection. It is safe to call Close on a already closed +// connection. +func (pgConn *PgConn) Close() error { + if pgConn.closed { + return nil + } + pgConn.closed = true + + _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) + if err != nil { + pgConn.NetConn.Close() + return err + } + + _, err = pgConn.NetConn.Read(make([]byte, 1)) + if err != io.EOF { + pgConn.NetConn.Close() + return err + } + + return pgConn.NetConn.Close() +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { return pgConn.parameterStatuses[key] } + +// CommandTag is the result of an Exec function +type CommandTag string + +// RowsAffected returns the number of rows affected. If the CommandTag was not +// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. +func (ct CommandTag) RowsAffected() int64 { + s := string(ct) + index := strings.LastIndex(s, " ") + if index == -1 { + return 0 + } + n, _ := strconv.ParseInt(s[index+1:], 10, 64) + return n +} + +// SendExec enqueues the execution of sql via the PostgreSQL simple query +// protocol. sql may contain multipe queries. Multiple queries will be processed +// within a single transation. It is only sent to the PostgreSQL server when +// Flush is called. +func (pgConn *PgConn) SendExec(sql string) { + pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql) + pgConn.batchCount += 1 +} + +// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it. +func appendQuery(buf []byte, query string) []byte { + buf = append(buf, 'Q') + buf = pgio.AppendInt32(buf, int32(len(query)+5)) + buf = append(buf, query...) + buf = append(buf, 0) + return buf +} + +type PgResultReader struct { + pgConn *PgConn + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + err error + complete bool +} + +// GetResult returns a PgResultReader for the next result. If all results are +// consumed it returns nil. If an error occurs it will be reported on the +// returned PgResultReader. +func (pgConn *PgConn) GetResult() *PgResultReader { + if pgConn.pendingReadyForQueryCount == 0 { + return nil + } + + return &PgResultReader{pgConn: pgConn} +} + +func (rr *PgResultReader) NextRow() (present bool) { + if rr.complete { + return false + } + + for { + msg, err := rr.pgConn.ReceiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = msg.Fields + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + case *pgproto3.CommandComplete: + rr.commandTag = CommandTag(msg.CommandTag) + rr.complete = true + return false + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) + rr.complete = true + return false + } + } +} + +func (rr *PgResultReader) Value(c int) []byte { + return rr.rowValues[c] +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *PgResultReader) Close() (CommandTag, error) { + if rr.complete { + return rr.commandTag, rr.err + } + + for { + msg, err := rr.pgConn.ReceiveMessage() + if err != nil { + rr.err = err + rr.complete = true + return rr.commandTag, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + rr.commandTag = CommandTag(msg.CommandTag) + rr.complete = true + return rr.commandTag, rr.err + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) + rr.complete = true + return rr.commandTag, rr.err + } + } +} + +// Flush sends the enqueued execs to the server. +func (pgConn *PgConn) Flush() error { + defer pgConn.resetBatch() + + n, err := pgConn.NetConn.Write(pgConn.batchBuf) + if err != nil { + if n > 0 { + // TODO - kill connection - we sent a partial message + } + return err + } + + pgConn.pendingReadyForQueryCount += pgConn.batchCount + return nil +} + +func (pgConn *PgConn) resetBatch() { + pgConn.batchCount = 0 + if len(pgConn.batchBuf) > batchBufferSize { + pgConn.batchBuf = make([]byte, 0, batchBufferSize) + } else { + pgConn.batchBuf = pgConn.batchBuf[0:0] + } +} + +func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { + return PgError{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, + } +} diff --git a/base/pgconn_test.go b/base/pgconn_test.go new file mode 100644 index 00000000..ad1a9918 --- /dev/null +++ b/base/pgconn_test.go @@ -0,0 +1,34 @@ +package base_test + +import ( + "github.com/jackc/pgx/base" + + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSimple(t *testing.T) { + pgConn, err := base.Connect(base.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"}) + require.Nil(t, err) + + pgConn.SendExec("select current_database()") + err = pgConn.Flush() + require.Nil(t, err) + + result := pgConn.GetResult() + require.NotNil(t, result) + + rowFound := result.NextRow() + assert.True(t, rowFound) + if rowFound { + assert.Equal(t, "pgx_test", string(result.Value(0))) + } + + _, err = result.Close() + assert.Nil(t, err) + + err = pgConn.Close() + assert.Nil(t, err) +}