diff --git a/bench_test.go b/bench_test.go index 687c7650..c3863880 100644 --- a/bench_test.go +++ b/bench_test.go @@ -31,89 +31,6 @@ func createNarrowTestData(b *testing.B, conn *pgx.Conn) { mustPrepare(b, conn, "getMultipleNarrowByIdAsJSON", "select json_agg(row_to_json(narrow)) from narrow where id between $1 and $2") } -func removeBinaryEncoders() (encoders map[pgx.Oid]func(*pgx.MessageReader, int32) interface{}) { - encoders = make(map[pgx.Oid]func(*pgx.MessageReader, int32) interface{}) - for k, v := range pgx.ValueTranscoders { - encoders[k] = v.DecodeBinary - pgx.ValueTranscoders[k].DecodeBinary = nil - } - return -} - -func restoreBinaryEncoders(encoders map[pgx.Oid]func(*pgx.MessageReader, int32) interface{}) { - for k, v := range encoders { - pgx.ValueTranscoders[k].DecodeBinary = v - } -} - -func BenchmarkSelectRowSimpleNarrow(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createNarrowTestData(b, conn) - - // Get random ids outside of timing - ids := make([]int32, b.N) - for i := 0; i < b.N; i++ { - ids[i] = 1 + rand.Int31n(9999) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = mustSelectRow(b, conn, "select * from narrow where id=$1", ids[i]) - } -} - -func BenchmarkSelectRowPreparedNarrow(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createNarrowTestData(b, conn) - - // Get random ids outside of timing - ids := make([]int32, b.N) - for i := 0; i < b.N; i++ { - ids[i] = 1 + rand.Int31n(9999) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRow(b, conn, "getNarrowById", ids[i]) - } -} - -func BenchmarkSelectRowsSimpleNarrow(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createNarrowTestData(b, conn) - - // Get random ids outside of timing - ids := make([]int32, b.N) - for i := 0; i < b.N; i++ { - ids[i] = 1 + rand.Int31n(9999) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "select * from narrow where id between $1 and $2", ids[i], ids[i]+10) - } -} - -func BenchmarkSelectRowsPreparedNarrow(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createNarrowTestData(b, conn) - - // Get random ids outside of timing - ids := make([]int32, b.N) - for i := 0; i < b.N; i++ { - ids[i] = 1 + rand.Int31n(9999) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "getMultipleNarrowById", ids[i], ids[i]+10) - } -} - func BenchmarkSelectValuePreparedNarrow(b *testing.B) { conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) @@ -148,438 +65,6 @@ func BenchmarkSelectValueToPreparedNarrow(b *testing.B) { } } -func createJoinsTestData(b *testing.B, conn *pgx.Conn) { - mustExec(b, conn, ` - drop table if exists product_component; - drop table if exists component; - drop table if exists product; - - create table component( - id serial primary key, - filler1 varchar not null default '01234567890123456789', - filler2 varchar not null default '01234567890123456789', - filler3 varchar not null default '01234567890123456789', - weight int not null, - cost int not null - ); - - insert into component(weight, cost) - select (random()*100)::int, (random()*1000)::int - from generate_series(1, 1000) n; - - create index on component (weight); - create index on component (cost); - - create table product( - id serial primary key, - filler1 varchar not null default '01234567890123456789', - filler2 varchar not null default '01234567890123456789', - filler3 varchar not null default '01234567890123456789', - filler4 varchar not null default '01234567890123456789', - filler5 varchar not null default '01234567890123456789' - ); - - insert into product(id) - select n - from generate_series(1, 10000) n; - - create table product_component( - id serial primary key, - product_id int not null references product, - component_id int not null references component, - quantity int not null - ); - - insert into product_component(product_id, component_id, quantity) - select product.id, component.id, 1 + (random()*10)::int - from product - join component on (random() * 200)::int = 1; - - create unique index on product_component(product_id, component_id); - create index on product_component(product_id); - create index on product_component(component_id); - - analyze; - `) - - mustPrepare(b, conn, "joinAggregate", ` - select product.id, sum(cost*quantity) as total_cost - from product - join product_component on product.id=product_component.product_id - join component on component.id=product_component.component_id - group by product.id - having sum(weight*quantity) > 10 - order by total_cost desc - `) -} - -func BenchmarkSelectRowsSimpleJoins(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createJoinsTestData(b, conn) - - sql := ` - select product.id, sum(cost*quantity) as total_cost - from product - join product_component on product.id=product_component.product_id - join component on component.id=product_component.component_id - group by product.id - having sum(weight*quantity) > 10 - order by total_cost desc - ` - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, sql) - } -} - -func BenchmarkSelectRowsPreparedJoins(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createJoinsTestData(b, conn) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "joinAggregate") - } -} - -func createInt2TextVsBinaryTestData(b *testing.B, conn *pgx.Conn) { - mustExec(b, conn, ` - drop table if exists t; - - create temporary table t( - a int2 not null, - b int2 not null, - c int2 not null, - d int2 not null, - e int2 not null - ); - - insert into t(a, b, c, d, e) - select - (random() * 32000)::int2, (random() * 32000)::int2, (random() * 32000)::int2, (random() * 32000)::int2, (random() * 32000)::int2 - from generate_series(1, 10); - `) -} - -func BenchmarkInt2Text(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createInt2TextVsBinaryTestData(b, conn) - - encoders := removeBinaryEncoders() - defer func() { restoreBinaryEncoders(encoders) }() - - mustPrepare(b, conn, "selectInt16", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectInt16") - } -} - -func BenchmarkInt2Binary(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createInt2TextVsBinaryTestData(b, conn) - - mustPrepare(b, conn, "selectInt16", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectInt16") - } -} - -func createInt4TextVsBinaryTestData(b *testing.B, conn *pgx.Conn) { - mustExec(b, conn, ` - drop table if exists t; - - create temporary table t( - a int4 not null, - b int4 not null, - c int4 not null, - d int4 not null, - e int4 not null - ); - - insert into t(a, b, c, d, e) - select - (random() * 1000000)::int4, (random() * 1000000)::int4, (random() * 1000000)::int4, (random() * 1000000)::int4, (random() * 1000000)::int4 - from generate_series(1, 10); - `) -} - -func BenchmarkInt4Text(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createInt4TextVsBinaryTestData(b, conn) - - encoders := removeBinaryEncoders() - defer func() { restoreBinaryEncoders(encoders) }() - - mustPrepare(b, conn, "selectInt32", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectInt32") - } -} - -func BenchmarkInt4Binary(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createInt4TextVsBinaryTestData(b, conn) - - mustPrepare(b, conn, "selectInt32", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectInt32") - } -} - -func createInt8TextVsBinaryTestData(b *testing.B, conn *pgx.Conn) { - mustExec(b, conn, ` - drop table if exists t; - - create temporary table t( - a int8 not null, - b int8 not null, - c int8 not null, - d int8 not null, - e int8 not null - ); - - insert into t(a, b, c, d, e) - select - (random() * 1000000)::int8, (random() * 1000000)::int8, (random() * 1000000)::int8, (random() * 1000000)::int8, (random() * 1000000)::int8 - from generate_series(1, 10); - `) -} - -func BenchmarkInt8Text(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createInt8TextVsBinaryTestData(b, conn) - - encoders := removeBinaryEncoders() - defer func() { restoreBinaryEncoders(encoders) }() - - mustPrepare(b, conn, "selectInt64", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectInt64") - } -} - -func BenchmarkInt8Binary(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createInt8TextVsBinaryTestData(b, conn) - mustPrepare(b, conn, "selectInt64", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectInt64") - } -} - -func createFloat4TextVsBinaryTestData(b *testing.B, conn *pgx.Conn) { - mustExec(b, conn, ` - drop table if exists t; - - create temporary table t( - a float4 not null, - b float4 not null, - c float4 not null, - d float4 not null, - e float4 not null - ); - - insert into t(a, b, c, d, e) - select - (random() * 1000000)::float4, (random() * 1000000)::float4, (random() * 1000000)::float4, (random() * 1000000)::float4, (random() * 1000000)::float4 - from generate_series(1, 10); - `) -} - -func BenchmarkFloat4Text(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createFloat4TextVsBinaryTestData(b, conn) - - encoders := removeBinaryEncoders() - defer func() { restoreBinaryEncoders(encoders) }() - - mustPrepare(b, conn, "selectFloat32", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectFloat32") - } -} - -func BenchmarkFloat4Binary(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createFloat4TextVsBinaryTestData(b, conn) - mustPrepare(b, conn, "selectFloat32", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectFloat32") - } -} - -func createFloat8TextVsBinaryTestData(b *testing.B, conn *pgx.Conn) { - mustExec(b, conn, ` - drop table if exists t; - - create temporary table t( - a float8 not null, - b float8 not null, - c float8 not null, - d float8 not null, - e float8 not null - ); - - insert into t(a, b, c, d, e) - select - (random() * 1000000)::float8, (random() * 1000000)::float8, (random() * 1000000)::float8, (random() * 1000000)::float8, (random() * 1000000)::float8 - from generate_series(1, 10); - `) -} - -func BenchmarkFloat8Text(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createFloat8TextVsBinaryTestData(b, conn) - - encoders := removeBinaryEncoders() - defer func() { restoreBinaryEncoders(encoders) }() - - mustPrepare(b, conn, "selectFloat32", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectFloat32") - } -} - -func BenchmarkFloat8Binary(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createFloat8TextVsBinaryTestData(b, conn) - mustPrepare(b, conn, "selectFloat32", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectFloat32") - } -} - -func createBoolTextVsBinaryTestData(b *testing.B, conn *pgx.Conn) { - mustExec(b, conn, ` - drop table if exists t; - - create temporary table t( - a bool not null, - b bool not null, - c bool not null, - d bool not null, - e bool not null - ); - - insert into t(a, b, c, d, e) - select - random() > 0.5, random() > 0.5, random() > 0.5, random() > 0.5, random() > 0.5 - from generate_series(1, 10); - `) -} - -func BenchmarkBoolText(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createBoolTextVsBinaryTestData(b, conn) - - encoders := removeBinaryEncoders() - defer func() { restoreBinaryEncoders(encoders) }() - - mustPrepare(b, conn, "selectBool", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectBool") - } -} - -func BenchmarkBoolBinary(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createBoolTextVsBinaryTestData(b, conn) - mustPrepare(b, conn, "selectBool", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectBool") - } -} - -func createTimestampTzTextVsBinaryTestData(b *testing.B, conn *pgx.Conn) { - mustExec(b, conn, ` - drop table if exists t; - - create temporary table t( - a timestamptz not null, - b timestamptz not null, - c timestamptz not null, - d timestamptz not null, - e timestamptz not null - ); - - insert into t(a, b, c, d, e) - select - now() - '10 years'::interval * random(), - now() - '10 years'::interval * random(), - now() - '10 years'::interval * random(), - now() - '10 years'::interval * random(), - now() - '10 years'::interval * random() - from generate_series(1, 10); - `) -} - -func BenchmarkTimestampTzText(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createTimestampTzTextVsBinaryTestData(b, conn) - - encoders := removeBinaryEncoders() - defer func() { restoreBinaryEncoders(encoders) }() - - mustPrepare(b, conn, "selectTimestampTz", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectTimestampTz") - } -} - -func BenchmarkTimestampTzBinary(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createTimestampTzTextVsBinaryTestData(b, conn) - mustPrepare(b, conn, "selectTimestampTz", "select * from t") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectRows(b, conn, "selectTimestampTz") - } -} - func BenchmarkConnPool(b *testing.B) { config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} pool, err := pgx.NewConnPool(config) diff --git a/conn.go b/conn.go index 8dae0515..d68b62cf 100644 --- a/conn.go +++ b/conn.go @@ -64,7 +64,7 @@ type Conn struct { alive bool causeOfDeath error logger log.Logger - drr DataRowReader + qr QueryResult } type PreparedStatement struct { @@ -276,137 +276,6 @@ func ParseURI(uri string) (ConnConfig, error) { return cp, nil } -// SelectFunc executes sql and for each row returned calls onDataRow. sql can be -// either a prepared statement name or an SQL string. arguments will be sanitized -// before being interpolated into sql strings. arguments should be referenced -// positionally from the sql string as $1, $2, etc. -// -// SelectFunc calls onDataRow as the rows are received. This means that it does not -// need to simultaneously store the entire result set in memory. It also means that -// it is possible to process some rows and then for an error to occur. Callers -// should be aware of this possibility. -func (c *Conn) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) error { - startTime := time.Now() - err := c.selectFunc(sql, onDataRow, arguments...) - if err != nil { - c.logger.Error("SelectFunc", "sql", sql, "args", arguments, "error", err) - return err - } - - endTime := time.Now() - c.logger.Info("SelectFunc", "sql", sql, "args", arguments, "time", endTime.Sub(startTime)) - return nil -} - -func (c *Conn) selectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) { - var fields []FieldDescription - - if ps, present := c.preparedStatements[sql]; present { - fields = ps.FieldDescriptions - err = c.sendPreparedQuery(ps, arguments...) - } else { - err = c.sendSimpleQuery(sql, arguments...) - } - if err != nil { - return - } - - var softErr error - - for { - var t byte - var r *MessageReader - t, r, err = c.rxMsg() - if err != nil { - return err - } - - switch t { - case readyForQuery: - c.rxReadyForQuery(r) - return softErr - case rowDescription: - fields = c.rxRowDescription(r) - case dataRow: - if softErr == nil { - c.drr.mr = r - c.drr.FieldDescriptions = fields - c.drr.currentFieldIdx = 0 - - fieldCount := int(r.ReadInt16()) - if fieldCount != len(fields) { - softErr = ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(fields), fieldCount)) - } - if softErr == nil { - softErr = onDataRow(&c.drr) - } - } - case commandComplete: - case bindComplete: - default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { - softErr = e - } - } - } -} - -// SelectRows executes sql and returns a slice of maps representing the found rows. -// sql can be either a prepared statement name or an SQL string. arguments will be -// sanitized before being interpolated into sql strings. arguments should be referenced -// positionally from the sql string as $1, $2, etc. -func (c *Conn) SelectRows(sql string, arguments ...interface{}) ([]map[string]interface{}, error) { - startTime := time.Now() - - rows := make([]map[string]interface{}, 0, 8) - onDataRow := func(r *DataRowReader) error { - rows = append(rows, c.rxDataRow(r)) - return nil - } - err := c.selectFunc(sql, onDataRow, arguments...) - if err != nil { - c.logger.Error("SelectRows", "sql", sql, "args", arguments, "error", err) - return nil, err - } - - endTime := time.Now() - c.logger.Info("SelectRows", "sql", sql, "args", arguments, "rowsFound", len(rows), "time", endTime.Sub(startTime)) - return rows, nil -} - -// SelectRow executes sql and returns a map representing the found row. -// sql can be either a prepared statement name or an SQL string. arguments will be -// sanitized before being interpolated into sql strings. arguments should be referenced -// positionally from the sql string as $1, $2, etc. -// -// Returns a NotSingleRowError if exactly one row is not found -func (c *Conn) SelectRow(sql string, arguments ...interface{}) (map[string]interface{}, error) { - startTime := time.Now() - - var numRowsFound int64 - var row map[string]interface{} - - onDataRow := func(r *DataRowReader) error { - numRowsFound++ - row = c.rxDataRow(r) - return nil - } - err := c.selectFunc(sql, onDataRow, arguments...) - if err != nil { - c.logger.Error("SelectRow", "sql", sql, "args", arguments, "error", err) - return nil, err - } - if numRowsFound != 1 { - err = NotSingleRowError{RowCount: numRowsFound} - row = nil - } - - endTime := time.Now() - c.logger.Info("SelectRow", "sql", sql, "args", arguments, "rowsFound", numRowsFound, "time", endTime.Sub(startTime)) - - return row, err -} - // SelectValue executes sql and returns a single value. sql can be either a prepared // statement name or an SQL string. arguments will be sanitized before being // interpolated into sql strings. arguments should be referenced positionally from @@ -420,29 +289,31 @@ func (c *Conn) SelectValue(sql string, arguments ...interface{}) (interface{}, e var numRowsFound int64 var v interface{} - onDataRow := func(r *DataRowReader) error { - if len(r.FieldDescriptions) != 1 { - return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(r.FieldDescriptions))} + qr, _ := c.Query(sql, arguments...) + defer qr.Close() + + for qr.NextRow() { + if len(qr.fields) != 1 { + qr.Close() + return nil, UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(qr.fields))} } numRowsFound++ - v = r.ReadValue() - return nil + var rr RowReader + v = rr.ReadValue(qr) } - err := c.selectFunc(sql, onDataRow, arguments...) - if err != nil { - c.logger.Error("SelectValue", "sql", sql, "args", arguments, "error", err) - return nil, err + if qr.Err() != nil { + return nil, qr.Err() } + if numRowsFound != 1 { - err = NotSingleRowError{RowCount: numRowsFound} - v = nil + return nil, NotSingleRowError{RowCount: numRowsFound} } endTime := time.Now() c.logger.Info("SelectValue", "sql", sql, "args", arguments, "rowsFound", numRowsFound, "time", endTime.Sub(startTime)) - return v, err + return v, nil } // SelectValueTo executes sql that returns a single value and writes that value to w. @@ -561,35 +432,6 @@ func (c *Conn) rxDataRowValueTo(w io.Writer, bodySize int32) (err error) { return } -// SelectValues executes sql and returns a slice of values. sql can be either a prepared -// statement name or an SQL string. arguments will be sanitized before being -// interpolated into sql strings. arguments should be referenced positionally from -// the sql string as $1, $2, etc. -// -// Returns a UnexpectedColumnCountError if exactly one column is not found -func (c *Conn) SelectValues(sql string, arguments ...interface{}) ([]interface{}, error) { - startTime := time.Now() - - values := make([]interface{}, 0, 8) - onDataRow := func(r *DataRowReader) error { - if len(r.FieldDescriptions) != 1 { - return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(r.FieldDescriptions))} - } - - values = append(values, r.ReadValue()) - return nil - } - err := c.selectFunc(sql, onDataRow, arguments...) - if err != nil { - c.logger.Error("SelectValues", "sql", sql, "args", arguments, "error", err) - return nil, err - } - - endTime := time.Now() - c.logger.Info("SelectValues", "sql", sql, "args", arguments, "valuesFound", len(values), "time", endTime.Sub(startTime)) - return values, nil -} - // Prepare creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { @@ -647,8 +489,10 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { ps.FieldDescriptions = c.rxRowDescription(r) for i := range ps.FieldDescriptions { oid := ps.FieldDescriptions[i].DataType - if ValueTranscoders[oid] != nil && ValueTranscoders[oid].DecodeBinary != nil { - ps.FieldDescriptions[i].FormatCode = 1 + vt := ValueTranscoders[oid] + + if vt != nil { + ps.FieldDescriptions[i].FormatCode = vt.DecodeFormat } } case noData: @@ -744,6 +588,229 @@ func (c *Conn) CauseOfDeath() error { return c.causeOfDeath } +type RowReader struct{} + +// TODO - Read*... + +func (rr *RowReader) ReadInt32(qr *QueryResult) int32 { + fd, size := qr.NextColumn() + + // TODO - do something about nulls + if size == -1 { + panic("Can't handle nulls") + } + + return decodeInt4(qr, fd, size) +} + +func (rr *RowReader) ReadTime(qr *QueryResult) time.Time { + fd, size := qr.NextColumn() + + // TODO - do something about nulls + if size == -1 { + panic("Can't handle nulls") + } + + return decodeTimestampTz(qr, fd, size) +} + +func (rr *RowReader) ReadDate(qr *QueryResult) time.Time { + fd, size := qr.NextColumn() + + // TODO - do something about nulls + if size == -1 { + panic("Can't handle nulls") + } + + return decodeDate(qr, fd, size) +} + +func (rr *RowReader) ReadString(qr *QueryResult) string { + _, size := qr.NextColumn() + return qr.mr.ReadString(size) +} + +func (rr *RowReader) ReadValue(qr *QueryResult) interface{} { + fd, size := qr.NextColumn() + + if size > -1 { + if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil { + return vt.Decode(qr, fd, size) + } else { + return qr.mr.ReadString(size) + } + } else { + return nil + } +} + +type QueryResult struct { + pool *ConnPool + conn *Conn + mr *MessageReader + fields []FieldDescription + rowCount int + columnIdx int + err error + closed bool +} + +func (qr *QueryResult) FieldDescriptions() []FieldDescription { + return qr.fields +} + +func (qr *QueryResult) MessageReader() *MessageReader { + return qr.mr +} + +func (qr *QueryResult) close() { + if qr.pool != nil { + qr.pool.Release(qr.conn) + qr.pool = nil + } + + qr.closed = true +} + +func (qr *QueryResult) readUntilReadyForQuery() { + for { + t, r, err := qr.conn.rxMsg() + if err != nil { + qr.close() + return + } + + switch t { + case readyForQuery: + qr.conn.rxReadyForQuery(r) + qr.close() + return + case rowDescription: + case dataRow: + case commandComplete: + case bindComplete: + default: + err = qr.conn.processContextFreeMsg(t, r) + if err != nil { + qr.close() + return + } + } + } +} + +func (qr *QueryResult) Close() { + if qr.closed { + return + } + qr.readUntilReadyForQuery() + qr.close() +} + +func (qr *QueryResult) Err() error { + return qr.err +} + +func (qr *QueryResult) Fatal(err error) { + qr.err = err + qr.Close() +} + +func (qr *QueryResult) NextRow() bool { + if qr.closed { + return false + } + + qr.rowCount++ + qr.columnIdx = 0 + + for { + t, r, err := qr.conn.rxMsg() + if err != nil { + qr.Fatal(err) + return false + } + + switch t { + case readyForQuery: + qr.conn.rxReadyForQuery(r) + qr.close() + return false + case dataRow: + fieldCount := int(r.ReadInt16()) + if fieldCount != len(qr.fields) { + qr.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(qr.fields), fieldCount))) + return false + } + + qr.mr = r + return true + case commandComplete: + case bindComplete: + default: + err = qr.conn.processContextFreeMsg(t, r) + if err != nil { + qr.Fatal(err) + return false + } + } + } +} + +func (qr *QueryResult) NextColumn() (*FieldDescription, int32) { + fd := &qr.fields[qr.columnIdx] + qr.columnIdx++ + size := qr.mr.ReadInt32() + + return fd, size +} + +// TODO - document +func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) { + c.qr = QueryResult{conn: c} + qr := &c.qr + + // TODO - shouldn't be messing with qr.err and qr.closed directly + if ps, present := c.preparedStatements[sql]; present { + qr.fields = ps.FieldDescriptions + qr.err = c.sendPreparedQuery(ps, args...) + if qr.err != nil { + qr.closed = true + } + return qr, qr.err + } + + qr.err = c.sendSimpleQuery(sql, args...) + if qr.err != nil { + qr.closed = true + return qr, qr.err + } + + // Simple queries don't know the field descriptions of the result. + // Read until that is known before returning + for { + t, r, err := c.rxMsg() + if err != nil { + qr.err = err + qr.closed = true + return qr, qr.err + } + + switch t { + case rowDescription: + qr.fields = qr.conn.rxRowDescription(r) + return qr, nil + default: + err = qr.conn.processContextFreeMsg(t, r) + if err != nil { + qr.closed = true + qr.err = err + return qr, qr.err + } + } + } +} + func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { if ps, present := c.preparedStatements[sql]; present { return c.sendPreparedQuery(ps, arguments...) @@ -812,8 +879,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(int16(len(ps.FieldDescriptions))) for _, fd := range ps.FieldDescriptions { transcoder := ValueTranscoders[fd.DataType] - if transcoder != nil && transcoder.DecodeBinary != nil { - wbuf.WriteInt16(1) + if transcoder != nil { + wbuf.WriteInt16(transcoder.DecodeFormat) } else { wbuf.WriteInt16(0) } @@ -958,7 +1025,6 @@ func (c *Conn) rxMsg() (t byte, r *MessageReader, err error) { if body, err = c.rxMsgBody(bodySize); err != nil { return } - r = (*MessageReader)(body) return } @@ -1041,6 +1107,9 @@ func (c *Conn) rxErrorResponse(r *MessageReader) (err PgError) { case 'M': err.Message = r.ReadCString() case 0: // End of error message + if err.Severity == "FATAL" { + c.die(err) + } return default: // Ignore other error fields r.ReadCString() @@ -1082,16 +1151,6 @@ func (c *Conn) rxParameterDescription(r *MessageReader) (parameters []Oid) { return } -func (c *Conn) rxDataRow(r *DataRowReader) (row map[string]interface{}) { - fieldCount := len(r.FieldDescriptions) - - row = make(map[string]interface{}, fieldCount) - for i := 0; i < fieldCount; i++ { - row[r.FieldDescriptions[i].Name] = r.ReadValue() - } - return -} - func (c *Conn) rxCommandComplete(r *MessageReader) string { return r.ReadCString() } diff --git a/conn_pool.go b/conn_pool.go index 751e6012..db65a842 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -166,39 +166,6 @@ func (p *ConnPool) createConnection() (c *Conn, err error) { return } -// SelectFunc acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) { - var c *Conn - if c, err = p.Acquire(); err != nil { - return - } - defer p.Release(c) - - return c.SelectFunc(sql, onDataRow, arguments...) -} - -// SelectRows acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) SelectRows(sql string, arguments ...interface{}) (rows []map[string]interface{}, err error) { - var c *Conn - if c, err = p.Acquire(); err != nil { - return - } - defer p.Release(c) - - return c.SelectRows(sql, arguments...) -} - -// SelectRow acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) SelectRow(sql string, arguments ...interface{}) (row map[string]interface{}, err error) { - var c *Conn - if c, err = p.Acquire(); err != nil { - return - } - defer p.Release(c) - - return c.SelectRow(sql, arguments...) -} - // SelectValue acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnPool) SelectValue(sql string, arguments ...interface{}) (v interface{}, err error) { var c *Conn @@ -221,17 +188,6 @@ func (p *ConnPool) SelectValueTo(w io.Writer, sql string, arguments ...interface return c.SelectValueTo(w, sql, arguments...) } -// SelectValues acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) SelectValues(sql string, arguments ...interface{}) (values []interface{}, err error) { - var c *Conn - if c, err = p.Acquire(); err != nil { - return - } - defer p.Release(c) - - return c.SelectValues(sql, arguments...) -} - // Exec acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { var c *Conn @@ -243,6 +199,23 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman return c.Exec(sql, arguments...) } +func (p *ConnPool) Query(sql string, args ...interface{}) (*QueryResult, error) { + c, err := p.Acquire() + if err != nil { + // Because checking for errors can be deferred to the *QueryResult, build one with the error + return &QueryResult{closed: true, err: err}, err + } + + qr, err := c.Query(sql, args...) + if err != nil { + p.Release(c) + return qr, err + } + + qr.pool = p + return qr, nil +} + // Transaction acquires a connection, delegates the call to that connection, // and releases the connection. The call signature differs slightly from the // underlying Transaction in that the callback function accepts a *Conn diff --git a/conn_pool_test.go b/conn_pool_test.go index fcd8e9fb..92b66450 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -374,3 +374,44 @@ func TestPoolTransactionIso(t *testing.T) { t.Fatal("Transaction was not committed when it should have been") } } + +func TestConnPoolQuery(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + var sum, rowCount int32 + + qr, err := pool.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("pool.Query failed: %v", err) + } + + stats := pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 0 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } + + for qr.NextRow() { + var rr pgx.RowReader + sum += rr.ReadInt32(qr) + rowCount++ + } + + if qr.Err() != nil { + t.Fatalf("conn.Query failed: ", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } + if sum != 55 { + t.Error("Wrong values returned") + } + + stats = pool.Stat() + if stats.CurrentConnections != 1 || stats.AvailableConnections != 1 { + t.Fatalf("Unexpected connection pool stats: %v", stats) + } +} diff --git a/conn_test.go b/conn_test.go index 570e7abe..bc615dc4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -30,9 +30,8 @@ func TestConnect(t *testing.T) { t.Error("Backend secret key not stored") } - var rows []map[string]interface{} - rows, err = conn.SelectRows("select current_database()") - if err != nil || rows[0]["current_database"] != defaultConnConfig.Database { + currentDB, err := conn.SelectValue("select current_database()") + if err != nil || currentDB != defaultConnConfig.Database { t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) } @@ -278,23 +277,30 @@ func TestExecFailure(t *testing.T) { } } -func TestSelectFunc(t *testing.T) { +func TestConnQuery(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) var sum, rowCount int32 - onDataRow := func(r *pgx.DataRowReader) error { + + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + defer rows.Close() + + for rows.NextRow() { + var rr pgx.RowReader + sum += rr.ReadInt32(rows) rowCount++ - sum += r.ReadValue().(int32) - return nil } - err := conn.SelectFunc("select generate_series(1,$1)", onDataRow, 10) - if err != nil { - t.Fatal("Select failed: " + err.Error()) + if rows.Err() != nil { + t.Fatalf("conn.Query failed: ", err) } + if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") } @@ -303,124 +309,6 @@ func TestSelectFunc(t *testing.T) { } } -func TestSelectFuncFailure(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - // using SelectValue as it delegates to SelectFunc and is easier to work with - if _, err := conn.SelectValue("select;"); err == nil { - t.Fatal("Expected SQL syntax error") - } - - if _, err := conn.SelectValue("select 1"); err != nil { - t.Fatalf("SelectFunc failure appears to have broken connection: %v", err) - } -} - -func Example_connectionSelectFunc() { - conn, err := pgx.Connect(*defaultConnConfig) - if err != nil { - fmt.Printf("Unable to establish connection: %v", err) - return - } - - onDataRow := func(r *pgx.DataRowReader) error { - fmt.Println(r.ReadValue()) - return nil - } - - err = conn.SelectFunc("select generate_series(1,$1)", onDataRow, 5) - if err != nil { - fmt.Println(err) - } - // Output: - // 1 - // 2 - // 3 - // 4 - // 5 -} - -func TestSelectRows(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - rows := mustSelectRows(t, conn, "select $1 as name, null as position", "Jack") - - if len(rows) != 1 { - t.Fatal("Received wrong number of rows") - } - - if rows[0]["name"] != "Jack" { - t.Error("Received incorrect name") - } - - if value, presence := rows[0]["position"]; presence { - if value != nil { - t.Error("Should have received nil for null") - } - } else { - t.Error("Null value should have been present in map as nil") - } -} - -func Example_connectionSelectRows() { - conn, err := pgx.Connect(*defaultConnConfig) - if err != nil { - fmt.Printf("Unable to establish connection: %v", err) - return - } - - var rows []map[string]interface{} - if rows, err = conn.SelectRows("select generate_series(1,$1) as number", 5); err != nil { - fmt.Printf("Error selecting rows: %v", err) - return - } - for _, r := range rows { - fmt.Println(r["number"]) - } - // Output: - // 1 - // 2 - // 3 - // 4 - // 5 -} - -func TestSelectRow(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - row := mustSelectRow(t, conn, "select $1 as name, null as position", "Jack") - if row["name"] != "Jack" { - t.Error("Received incorrect name") - } - - if value, presence := row["position"]; presence { - if value != nil { - t.Error("Should have received nil for null") - } - } else { - t.Error("Null value should have been present in map as nil") - } - - _, err := conn.SelectRow("select 'Jack' as name where 1=2") - if _, ok := err.(pgx.NotSingleRowError); !ok { - t.Error("No matching row should have returned NotSingleRowError") - } - - _, err = conn.SelectRow("select * from (values ('Matthew'), ('Mark')) t") - if _, ok := err.(pgx.NotSingleRowError); !ok { - t.Error("Multiple matching rows should have returned NotSingleRowError") - } -} - func TestConnectionSelectValue(t *testing.T) { t.Parallel() @@ -438,6 +326,7 @@ func TestConnectionSelectValue(t *testing.T) { } } + fmt.Println("Starting test") test("select $1", "foo", "foo") test("select 'foo'", "foo") test("select true", true) @@ -515,41 +404,6 @@ func TestConnectionSelectValueTo(t *testing.T) { } -func TestSelectValues(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - test := func(sql string, expected []interface{}, arguments ...interface{}) { - values, err := conn.SelectValues(sql, arguments...) - if err != nil { - t.Errorf("%v while running %v", err, sql) - return - } - if len(values) != len(expected) { - t.Errorf("Expected: %#v Received: %#v", expected, values) - return - } - for i := 0; i < len(values); i++ { - if values[i] != expected[i] { - t.Errorf("Expected: %#v Received: %#v", expected, values) - return - } - } - } - - test("select * from (values ($1)) t", []interface{}{"Matthew"}, "Matthew") - test("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t", []interface{}{"Matthew", "Mark", "Luke", "John"}) - test("select * from (values ('Matthew'), (null)) t", []interface{}{"Matthew", nil}) - test("select * from (values (1::int4), (2::int4), (null), (3::int4)) t", []interface{}{int32(1), int32(2), nil, int32(3)}) - - _, err := conn.SelectValues("select 'Matthew', 'Mark'") - if _, ok := err.(pgx.UnexpectedColumnCountError); !ok { - t.Error("Multiple columns should have returned UnexpectedColumnCountError") - } -} - func TestPrepare(t *testing.T) { t.Parallel() @@ -872,11 +726,13 @@ func TestFatalTxError(t *testing.T) { } defer otherConn.Close() - if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.Pid); err != nil { + _, err = otherConn.Exec("select pg_terminate_backend($1)", conn.Pid) + if err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } - if _, err := conn.SelectValue("select 1"); err == nil { + _, err = conn.SelectValue("select 1") + if err == nil { t.Fatal("Expected error but none occurred") } diff --git a/data_row_reader.go b/data_row_reader.go deleted file mode 100644 index 2fc77e3b..00000000 --- a/data_row_reader.go +++ /dev/null @@ -1,40 +0,0 @@ -package pgx - -import ( - "fmt" -) - -// DataRowReader is used by SelectFunc to process incoming rows. -type DataRowReader struct { - mr *MessageReader - FieldDescriptions []FieldDescription - currentFieldIdx int -} - -func (r *DataRowReader) MessageReader() *MessageReader { - return r.mr -} - -// ReadValue returns the next value from the current row. -func (r *DataRowReader) ReadValue() interface{} { - fieldDescription := r.FieldDescriptions[r.currentFieldIdx] - r.currentFieldIdx++ - - size := r.mr.ReadInt32() - if size > -1 { - if vt, present := ValueTranscoders[fieldDescription.DataType]; present { - switch fieldDescription.FormatCode { - case 0: - return vt.DecodeText(r.mr, size) - case 1: - return vt.DecodeBinary(r.mr, size) - default: - return ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fieldDescription.FormatCode)) - } - } else { - return r.mr.ReadString(size) - } - } else { - return nil - } -} diff --git a/data_row_reader_test.go b/data_row_reader_test.go deleted file mode 100644 index 24f13863..00000000 --- a/data_row_reader_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package pgx_test - -import ( - "github.com/jackc/pgx" - "testing" -) - -func TestDataRowReaderReadValue(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - test := func(sql string, expected interface{}) { - var v interface{} - - onDataRow := func(r *pgx.DataRowReader) error { - v = r.ReadValue() - return nil - } - - err := conn.SelectFunc(sql, onDataRow) - if err != nil { - t.Fatalf("Select failed: %v", err) - } - if v != expected { - t.Errorf("Expected: %#v Received: %#v", expected, v) - } - } - - test("select null", nil) - test("select 'Jack'", "Jack") - test("select true", true) - test("select false", false) - test("select 1::int2", int16(1)) - test("select 1::int4", int32(1)) - test("select 1::int8", int64(1)) - test("select 1.23::float4", float32(1.23)) - test("select 1.23::float8", float64(1.23)) -} diff --git a/example_value_transcoder_test.go b/example_value_transcoder_test.go index 2d7b25a5..554c9343 100644 --- a/example_value_transcoder_test.go +++ b/example_value_transcoder_test.go @@ -7,6 +7,10 @@ import ( "strconv" ) +const ( + pointOid = 600 +) + var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) type Point struct { @@ -19,9 +23,11 @@ func (p Point) String() string { } func Example_customValueTranscoder() { - pgx.ValueTranscoders[pgx.Oid(600)] = &pgx.ValueTranscoder{ - DecodeText: decodePointFromText, - EncodeTo: encodePoint} + pgx.ValueTranscoders[pointOid] = &pgx.ValueTranscoder{ + Decode: func(qr *pgx.QueryResult, fd *pgx.FieldDescription, size int32) interface{} { + return decodePoint(qr, fd, size) + }, + EncodeTo: encodePoint} conn, err := pgx.Connect(*defaultConnConfig) if err != nil { @@ -35,24 +41,39 @@ func Example_customValueTranscoder() { // 1.5, 2.5 } -func decodePointFromText(mr *pgx.MessageReader, size int32) interface{} { - s := mr.ReadString(size) - match := pointRegexp.FindStringSubmatch(s) - if match == nil { - return pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s)) +func decodePoint(qr *pgx.QueryResult, fd *pgx.FieldDescription, size int32) Point { + var p Point + + if fd.DataType != pointOid { + qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Tried to read point but received: %v", fd.DataType))) + return p } - var err error - var p Point - p.x, err = strconv.ParseFloat(match[1], 64) - if err != nil { - return pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s)) + switch fd.FormatCode { + case pgx.TextFormatCode: + s := qr.MessageReader().ReadString(size) + match := pointRegexp.FindStringSubmatch(s) + if match == nil { + qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))) + return p + } + + var err error + p.x, err = strconv.ParseFloat(match[1], 64) + if err != nil { + qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))) + return p + } + p.y, err = strconv.ParseFloat(match[2], 64) + if err != nil { + qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))) + return p + } + return p + default: + qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return p } - p.y, err = strconv.ParseFloat(match[2], 64) - if err != nil { - return pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s)) - } - return p } func encodePoint(w *pgx.WriteBuf, value interface{}) error { diff --git a/helper_test.go b/helper_test.go index 821b5f98..ff5dd0a0 100644 --- a/helper_test.go +++ b/helper_test.go @@ -35,22 +35,6 @@ func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{} return } -func mustSelectRow(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (row map[string]interface{}) { - var err error - if row, err = conn.SelectRow(sql, arguments...); err != nil { - t.Fatalf("SelectRow unexpectedly failed with %v: %v", sql, err) - } - return -} - -func mustSelectRows(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (rows []map[string]interface{}) { - var err error - if rows, err = conn.SelectRows(sql, arguments...); err != nil { - t.Fatalf("SelectRows unexpected failed with %v: %v", sql, err) - } - return -} - func mustSelectValue(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (value interface{}) { var err error if value, err = conn.SelectValue(sql, arguments...); err != nil { diff --git a/stdlib/sql.go b/stdlib/sql.go index 2e08f710..58f99d58 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -133,44 +133,12 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { args := valueToInterface(argsV) - rowCount := 0 - columnsChan := make(chan []string) - errChan := make(chan error) - rowChan := make(chan []driver.Value, 8) - - go func() { - err := c.conn.SelectFunc(query, func(r *pgx.DataRowReader) error { - if rowCount == 0 { - fieldNames := make([]string, len(r.FieldDescriptions)) - for i, fd := range r.FieldDescriptions { - fieldNames[i] = fd.Name - } - columnsChan <- fieldNames - } - rowCount++ - - values := make([]driver.Value, len(r.FieldDescriptions)) - for i, _ := range r.FieldDescriptions { - values[i] = r.ReadValue() - } - rowChan <- values - - return nil - }, args...) - close(rowChan) - if err != nil { - errChan <- err - } - }() - - rows := Rows{rowChan: rowChan} - - select { - case rows.columnNames = <-columnsChan: - return &rows, nil - case err := <-errChan: + qr, err := c.conn.Query(query, args...) + if err != nil { return nil, err } + + return &Rows{qr: qr}, nil } type Stmt struct { @@ -194,29 +162,40 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return s.conn.Query(s.ps.Name, argsV) } +// TODO - rename to avoid alloc type Rows struct { - columnNames []string - rowChan chan []driver.Value + qr *pgx.QueryResult } func (r *Rows) Columns() []string { - return r.columnNames + fieldDescriptions := r.qr.FieldDescriptions() + names := make([]string, 0, len(fieldDescriptions)) + for _, fd := range fieldDescriptions { + names = append(names, fd.Name) + } + return names } func (r *Rows) Close() error { - for _ = range r.rowChan { - // Ensure all rows are read - } + r.qr.Close() return nil } func (r *Rows) Next(dest []driver.Value) error { - row, ok := <-r.rowChan - if !ok { - return io.EOF + more := r.qr.NextRow() + if !more { + if r.qr.Err() == nil { + return io.EOF + } else { + return r.qr.Err() + } + } + + var rr pgx.RowReader + for i, _ := range r.qr.FieldDescriptions() { + dest[i] = driver.Value(rr.ReadValue(r.qr)) } - copy(dest, row) return nil } diff --git a/value_transcoder.go b/value_transcoder.go index 3da5173d..357fd822 100644 --- a/value_transcoder.go +++ b/value_transcoder.go @@ -28,18 +28,23 @@ const ( TimestampTzOid = 1184 ) +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + // ValueTranscoder stores all the data necessary to encode and decode values from // a PostgreSQL server type ValueTranscoder struct { - // DecodeText decodes values returned from the server in text format - DecodeText func(*MessageReader, int32) interface{} - // DecodeBinary decodes values returned from the server in binary format - DecodeBinary func(*MessageReader, int32) interface{} + // Decode decodes values returned from the server + Decode func(qr *QueryResult, fd *FieldDescription, size int32) interface{} + // DecodeFormat is the preferred response format. + // Allowed values: TextFormatCode, BinaryFormatCode + DecodeFormat int16 // EncodeTo encodes values to send to the server EncodeTo func(*WriteBuf, interface{}) error // EncodeFormat is the format values are encoded for transmission. - // 0 = text - // 1 = binary + // Allowed values: TextFormatCode, BinaryFormatCode EncodeFormat int16 } @@ -55,86 +60,104 @@ func init() { // bool ValueTranscoders[BoolOid] = &ValueTranscoder{ - DecodeText: decodeBoolFromText, - DecodeBinary: decodeBoolFromBinary, + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeBool(qr, fd, size) }, + DecodeFormat: BinaryFormatCode, EncodeTo: encodeBool, - EncodeFormat: 1} + EncodeFormat: BinaryFormatCode} // bytea ValueTranscoders[ByteaOid] = &ValueTranscoder{ - DecodeText: decodeByteaFromText, + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeBytea(qr, fd, size) }, + DecodeFormat: TextFormatCode, EncodeTo: encodeBytea, - EncodeFormat: 1} + EncodeFormat: BinaryFormatCode} // int8 ValueTranscoders[Int8Oid] = &ValueTranscoder{ - DecodeText: decodeInt8FromText, - DecodeBinary: decodeInt8FromBinary, + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt8(qr, fd, size) }, + DecodeFormat: BinaryFormatCode, EncodeTo: encodeInt8, - EncodeFormat: 1} + EncodeFormat: BinaryFormatCode} // int2 ValueTranscoders[Int2Oid] = &ValueTranscoder{ - DecodeText: decodeInt2FromText, - DecodeBinary: decodeInt2FromBinary, + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt2(qr, fd, size) }, + DecodeFormat: BinaryFormatCode, EncodeTo: encodeInt2, - EncodeFormat: 1} + EncodeFormat: BinaryFormatCode} // int4 ValueTranscoders[Int4Oid] = &ValueTranscoder{ - DecodeText: decodeInt4FromText, - DecodeBinary: decodeInt4FromBinary, + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeInt4(qr, fd, size) }, + DecodeFormat: BinaryFormatCode, EncodeTo: encodeInt4, - EncodeFormat: 1} + EncodeFormat: BinaryFormatCode} // text ValueTranscoders[TextOid] = &ValueTranscoder{ - DecodeText: decodeTextFromText, - EncodeTo: encodeText} + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeText(qr, fd, size) }, + DecodeFormat: TextFormatCode, + EncodeTo: encodeText, + EncodeFormat: TextFormatCode} // float4 ValueTranscoders[Float4Oid] = &ValueTranscoder{ - DecodeText: decodeFloat4FromText, - DecodeBinary: decodeFloat4FromBinary, + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeFloat4(qr, fd, size) }, EncodeTo: encodeFloat4, - EncodeFormat: 1} + EncodeFormat: BinaryFormatCode} // float8 ValueTranscoders[Float8Oid] = &ValueTranscoder{ - DecodeText: decodeFloat8FromText, - DecodeBinary: decodeFloat8FromBinary, + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeFloat8(qr, fd, size) }, + DecodeFormat: BinaryFormatCode, EncodeTo: encodeFloat8, - EncodeFormat: 1} + EncodeFormat: BinaryFormatCode} // int2[] ValueTranscoders[Int2ArrayOid] = &ValueTranscoder{ - DecodeText: decodeInt2ArrayFromText, - EncodeTo: encodeInt2Array} + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { + return decodeInt2Array(qr, fd, size) + }, + DecodeFormat: TextFormatCode, + EncodeTo: encodeInt2Array, + EncodeFormat: TextFormatCode} // int4[] ValueTranscoders[Int4ArrayOid] = &ValueTranscoder{ - DecodeText: decodeInt4ArrayFromText, - EncodeTo: encodeInt4Array} + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { + return decodeInt4Array(qr, fd, size) + }, + DecodeFormat: TextFormatCode, + EncodeTo: encodeInt4Array, + EncodeFormat: TextFormatCode} // int8[] ValueTranscoders[Int8ArrayOid] = &ValueTranscoder{ - DecodeText: decodeInt8ArrayFromText, - EncodeTo: encodeInt8Array} + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { + return decodeInt8Array(qr, fd, size) + }, + DecodeFormat: TextFormatCode, + EncodeTo: encodeInt8Array, + EncodeFormat: TextFormatCode} // varchar -- same as text ValueTranscoders[VarcharOid] = ValueTranscoders[Oid(25)] // date ValueTranscoders[DateOid] = &ValueTranscoder{ - DecodeText: decodeDateFromText, - DecodeBinary: decodeDateFromBinary, - EncodeTo: encodeDate} + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { return decodeDate(qr, fd, size) }, + DecodeFormat: BinaryFormatCode, + EncodeTo: encodeDate, + EncodeFormat: TextFormatCode} // timestamptz ValueTranscoders[TimestampTzOid] = &ValueTranscoder{ - DecodeText: decodeTimestampTzFromText, - DecodeBinary: decodeTimestampTzFromBinary, - EncodeTo: encodeTimestampTz} + Decode: func(qr *QueryResult, fd *FieldDescription, size int32) interface{} { + return decodeTimestampTz(qr, fd, size) + }, + DecodeFormat: BinaryFormatCode, + EncodeTo: encodeTimestampTz, + EncodeFormat: TextFormatCode} // use text transcoder for anything we don't understand defaultTranscoder = ValueTranscoders[TextOid] @@ -158,26 +181,32 @@ func SplitArrayText(text string) (elements []string) { return } -func decodeBoolFromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - switch s { - case "t": - return true - case "f": - return false +func decodeBool(qr *QueryResult, fd *FieldDescription, size int32) bool { + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + switch s { + case "t": + return true + case "f": + return false + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid bool: %v", s))) + return false + } + case BinaryFormatCode: + if size != 1 { + qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", size))) + return false + } + b := qr.mr.ReadByte() + return b != 0 default: - return ProtocolError(fmt.Sprintf("Received invalid bool: %v", s)) + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return false } } -func decodeBoolFromBinary(mr *MessageReader, size int32) interface{} { - if size != 1 { - return ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", size)) - } - b := mr.ReadByte() - return b != 0 -} - func encodeBool(w *WriteBuf, value interface{}) error { v, ok := value.(bool) if !ok { @@ -196,20 +225,31 @@ func encodeBool(w *WriteBuf, value interface{}) error { return nil } -func decodeInt8FromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - n, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return ProtocolError(fmt.Sprintf("Received invalid int8: %v", s)) +func decodeInt8(qr *QueryResult, fd *FieldDescription, size int32) int64 { + if fd.DataType != Int8Oid { + qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read %v but received: %v", Int8Oid, fd.DataType))) + return 0 } - return n -} -func decodeInt8FromBinary(mr *MessageReader, size int32) interface{} { - if size != 8 { - return ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size)) + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8: %v", s))) + return 0 + } + return n + case BinaryFormatCode: + if size != 8 { + qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size))) + return 0 + } + return qr.mr.ReadInt64() + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return 0 } - return mr.ReadInt64() } func encodeInt8(w *WriteBuf, value interface{}) error { @@ -246,20 +286,31 @@ func encodeInt8(w *WriteBuf, value interface{}) error { return nil } -func decodeInt2FromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - n, err := strconv.ParseInt(s, 10, 16) - if err != nil { - return ProtocolError(fmt.Sprintf("Received invalid int2: %v", s)) +func decodeInt2(qr *QueryResult, fd *FieldDescription, size int32) int16 { + if fd.DataType != Int2Oid { + qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read %v but received: %v", Int2Oid, fd.DataType))) + return 0 } - return int16(n) -} -func decodeInt2FromBinary(mr *MessageReader, size int32) interface{} { - if size != 2 { - return ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", size)) + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + n, err := strconv.ParseInt(s, 10, 16) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2: %v", s))) + return 0 + } + return int16(n) + case BinaryFormatCode: + if size != 2 { + qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", size))) + return 0 + } + return qr.mr.ReadInt16() + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return 0 } - return mr.ReadInt16() } func encodeInt2(w *WriteBuf, value interface{}) error { @@ -311,20 +362,30 @@ func encodeInt2(w *WriteBuf, value interface{}) error { return nil } -func decodeInt4FromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - n, err := strconv.ParseInt(s, 10, 32) - if err != nil { - return ProtocolError(fmt.Sprintf("Received invalid int4: %v", s)) +func decodeInt4(qr *QueryResult, fd *FieldDescription, size int32) int32 { + if fd.DataType != Int4Oid { + qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read %v but received: %v", Int4Oid, fd.DataType))) + return 0 } - return int32(n) -} -func decodeInt4FromBinary(mr *MessageReader, size int32) interface{} { - if size != 4 { - return ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", size)) + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4: %v", s))) + } + return int32(n) + case BinaryFormatCode: + if size != 4 { + qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", size))) + return 0 + } + return qr.mr.ReadInt32() + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return 0 } - return mr.ReadInt32() } func encodeInt4(w *WriteBuf, value interface{}) error { @@ -370,23 +431,29 @@ func encodeInt4(w *WriteBuf, value interface{}) error { return nil } -func decodeFloat4FromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - n, err := strconv.ParseFloat(s, 32) - if err != nil { - return ProtocolError(fmt.Sprintf("Received invalid float4: %v", s)) - } - return float32(n) -} +func decodeFloat4(qr *QueryResult, fd *FieldDescription, size int32) float32 { + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + n, err := strconv.ParseFloat(s, 32) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float4: %v", s))) + return 0 + } + return float32(n) + case BinaryFormatCode: + if size != 4 { + qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", size))) + return 0 + } -func decodeFloat4FromBinary(mr *MessageReader, size int32) interface{} { - if size != 4 { - return ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", size)) + i := qr.mr.ReadInt32() + p := unsafe.Pointer(&i) + return *(*float32)(p) + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return 0 } - - i := mr.ReadInt32() - p := unsafe.Pointer(&i) - return *(*float32)(p) } func encodeFloat4(w *WriteBuf, value interface{}) error { @@ -411,23 +478,29 @@ func encodeFloat4(w *WriteBuf, value interface{}) error { return nil } -func decodeFloat8FromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - v, err := strconv.ParseFloat(s, 64) - if err != nil { - return ProtocolError(fmt.Sprintf("Received invalid float8: %v", s)) - } - return v -} +func decodeFloat8(qr *QueryResult, fd *FieldDescription, size int32) float64 { + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + v, err := strconv.ParseFloat(s, 64) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float8: %v", s))) + return 0 + } + return v + case BinaryFormatCode: + if size != 8 { + qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", size))) + return 0 + } -func decodeFloat8FromBinary(mr *MessageReader, size int32) interface{} { - if size != 8 { - return ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", size)) + i := qr.mr.ReadInt64() + p := unsafe.Pointer(&i) + return *(*float64)(p) + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return 0 } - - i := mr.ReadInt64() - p := unsafe.Pointer(&i) - return *(*float64)(p) } func encodeFloat8(w *WriteBuf, value interface{}) error { @@ -449,8 +522,8 @@ func encodeFloat8(w *WriteBuf, value interface{}) error { return nil } -func decodeTextFromText(mr *MessageReader, size int32) interface{} { - return mr.ReadString(size) +func decodeText(qr *QueryResult, fd *FieldDescription, size int32) string { + return qr.mr.ReadString(size) } func encodeText(w *WriteBuf, value interface{}) error { @@ -465,13 +538,20 @@ func encodeText(w *WriteBuf, value interface{}) error { return nil } -func decodeByteaFromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - b, err := hex.DecodeString(s[2:]) - if err != nil { - return ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s)) +func decodeBytea(qr *QueryResult, fd *FieldDescription, size int32) []byte { + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + b, err := hex.DecodeString(s[2:]) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s))) + return nil + } + return b + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return nil } - return b } func encodeBytea(w *WriteBuf, value interface{}) error { @@ -486,18 +566,33 @@ func encodeBytea(w *WriteBuf, value interface{}) error { return nil } -func decodeDateFromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - t, err := time.ParseInLocation("2006-01-02", s, time.Local) - if err != nil { - return ProtocolError(fmt.Sprintf("Can't decode date: %v", s)) - } - return t -} +func decodeDate(qr *QueryResult, fd *FieldDescription, size int32) time.Time { + var zeroTime time.Time -func decodeDateFromBinary(mr *MessageReader, size int32) interface{} { - dayOffset := mr.ReadInt32() - return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) + if fd.DataType != DateOid { + qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read date but received: %v", fd.DataType))) + return zeroTime + } + + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + t, err := time.ParseInLocation("2006-01-02", s, time.Local) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Can't decode date: %v", s))) + return zeroTime + } + return t + case BinaryFormatCode: + if size != 4 { + qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", size))) + } + dayOffset := qr.mr.ReadInt32() + return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return zeroTime + } } func encodeDate(w *WriteBuf, value interface{}) error { @@ -510,27 +605,35 @@ func encodeDate(w *WriteBuf, value interface{}) error { return encodeText(w, s) } -func decodeTimestampTzFromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - t, err := time.Parse("2006-01-02 15:04:05.999999-07", s) - if err != nil { - return ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s)) +func decodeTimestampTz(qr *QueryResult, fd *FieldDescription, size int32) time.Time { + var zeroTime time.Time + + if fd.DataType != TimestampTzOid { + qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read timestamptz but received: %v", fd.DataType))) + return zeroTime } - return t -} -func decodeTimestampTzFromBinary(mr *MessageReader, size int32) interface{} { - if size != 8 { - return ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size)) + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + t, err := time.Parse("2006-01-02 15:04:05.999999-07", s) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s))) + return zeroTime + } + return t + case BinaryFormatCode: + if size != 8 { + qr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", size))) + } + microsecFromUnixEpochToY2K := int64(946684800 * 1000000) + microsecSinceY2K := qr.mr.ReadInt64() + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return zeroTime } - microsecFromUnixEpochToY2K := int64(946684800 * 1000000) - microsecSinceY2K := mr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) - - // 2000-01-01 00:00:00 in 946684800 - // 946684800 * 1000000 - } func encodeTimestampTz(w *WriteBuf, value interface{}) error { @@ -543,22 +646,34 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error { return encodeText(w, s) } -func decodeInt2ArrayFromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - - elements := SplitArrayText(s) - - numbers := make([]int16, 0, len(elements)) - - for _, e := range elements { - n, err := strconv.ParseInt(e, 10, 16) - if err != nil { - return ProtocolError(fmt.Sprintf("Received invalid int2[]: %v", s)) - } - numbers = append(numbers, int16(n)) +func decodeInt2Array(qr *QueryResult, fd *FieldDescription, size int32) []int16 { + if fd.DataType != Int2ArrayOid { + qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int2[] but received: %v", fd.DataType))) + return nil } - return numbers + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + + elements := SplitArrayText(s) + + numbers := make([]int16, 0, len(elements)) + + for _, e := range elements { + n, err := strconv.ParseInt(e, 10, 16) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2[]: %v", s))) + return nil + } + numbers = append(numbers, int16(n)) + } + + return numbers + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return nil + } } func int16SliceToArrayString(nums []int16) (string, error) { @@ -604,22 +719,34 @@ func encodeInt2Array(w *WriteBuf, value interface{}) error { return encodeText(w, s) } -func decodeInt4ArrayFromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - - elements := SplitArrayText(s) - - numbers := make([]int32, 0, len(elements)) - - for _, e := range elements { - n, err := strconv.ParseInt(e, 10, 16) - if err != nil { - return ProtocolError(fmt.Sprintf("Received invalid int4[]: %v", s)) - } - numbers = append(numbers, int32(n)) +func decodeInt4Array(qr *QueryResult, fd *FieldDescription, size int32) []int32 { + if fd.DataType != Int4ArrayOid { + qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int4[] but received: %v", fd.DataType))) + return nil } - return numbers + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + + elements := SplitArrayText(s) + + numbers := make([]int32, 0, len(elements)) + + for _, e := range elements { + n, err := strconv.ParseInt(e, 10, 32) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4[]: %v", s))) + return nil + } + numbers = append(numbers, int32(n)) + } + + return numbers + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return nil + } } func int32SliceToArrayString(nums []int32) (string, error) { @@ -666,22 +793,34 @@ func encodeInt4Array(w *WriteBuf, value interface{}) error { return encodeText(w, s) } -func decodeInt8ArrayFromText(mr *MessageReader, size int32) interface{} { - s := mr.ReadString(size) - - elements := SplitArrayText(s) - - numbers := make([]int64, 0, len(elements)) - - for _, e := range elements { - n, err := strconv.ParseInt(e, 10, 16) - if err != nil { - return ProtocolError(fmt.Sprintf("Received invalid int8[]: %v", s)) - } - numbers = append(numbers, int64(n)) +func decodeInt8Array(qr *QueryResult, fd *FieldDescription, size int32) []int64 { + if fd.DataType != Int8ArrayOid { + qr.Fatal(ProtocolError(fmt.Sprintf("Tried to read int8[] but received: %v", fd.DataType))) + return nil } - return numbers + switch fd.FormatCode { + case TextFormatCode: + s := qr.mr.ReadString(size) + + elements := SplitArrayText(s) + + numbers := make([]int64, 0, len(elements)) + + for _, e := range elements { + n, err := strconv.ParseInt(e, 10, 64) + if err != nil { + qr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8[]: %v", s))) + return nil + } + numbers = append(numbers, int64(n)) + } + + return numbers + default: + qr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) + return nil + } } func int64SliceToArrayString(nums []int64) (string, error) {