Use pgproto3.FieldDescription instead of the pgx version

This allows removing a malloc and memcpy.
pull/483/head
Jack Christensen 2019-05-04 13:47:03 -05:00
parent ea31df3b50
commit 583c8d3b25
7 changed files with 97 additions and 140 deletions
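For context, the user-visible effect of this change is that FieldDescriptions() (on Rows and PreparedStatement) now returns pgproto3.FieldDescription values directly: the column name is a []byte and the type OID is the uint32 DataTypeOID field, so callers convert with string(fd.Name) and pgtype.OID(fd.DataTypeOID) where needed. A minimal caller-side sketch, assuming the pgx v4 and pgproto3 v2 APIs as used in this commit; the connection string and query are placeholders, not part of the commit:

// Sketch only: illustrates reading column metadata after this change.
// DATABASE_URL and the example query are placeholders.
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/jackc/pgtype"
	"github.com/jackc/pgx/v4"
)

func main() {
	ctx := context.Background()

	conn, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	rows, err := conn.Query(ctx, "select 1 as n, 'x' as s")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for _, fd := range rows.FieldDescriptions() {
		// fd is a pgproto3.FieldDescription: Name is []byte and the type
		// OID is the uint32 DataTypeOID, so convert before using pgtype APIs.
		fmt.Println(string(fd.Name), pgtype.OID(fd.DataTypeOID))
	}
}

Because the pgproto3 slice is returned as-is instead of being copied into a pgx-specific []FieldDescription, the allocation and copy mentioned above go away; the trade-off is that callers now do the name and OID conversions themselves, as the diff below shows for conn.go, rows.go, and the database/sql driver code.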

conn.go

@@ -56,7 +56,7 @@ type Conn struct {
type PreparedStatement struct {
Name string
SQL string
FieldDescriptions []FieldDescription
FieldDescriptions []pgproto3.FieldDescription
ParameterOIDs []pgtype.OID
}
@@ -213,15 +213,12 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
Name: psd.Name,
SQL: psd.SQL,
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
FieldDescriptions: psd.Fields,
}
for i := range ps.ParameterOIDs {
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
}
for i := range ps.FieldDescriptions {
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
}
if name != "" {
c.preparedStatements[name] = ps
@@ -416,7 +413,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
resultFormats := make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
resultFormats[i] = BinaryFormatCode
} else {
@@ -453,15 +450,12 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
Name: psd.Name,
SQL: psd.SQL,
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
FieldDescriptions: psd.Fields,
}
for i := range ps.ParameterOIDs {
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
}
for i := range ps.FieldDescriptions {
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
}
arguments, err = convertDriverValuers(arguments)
if err != nil {
@@ -481,7 +475,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
resultFormats := make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
resultFormats[i] = BinaryFormatCode
} else {
@@ -549,22 +543,6 @@ func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
// pgproto3FieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
// FieldDescription.
func pgproto3FieldDescriptionToPgxFieldDescription(connInfo *pgtype.ConnInfo, src *pgproto3.FieldDescription, dst *FieldDescription) {
dst.Name = string(src.Name)
dst.Table = pgtype.OID(src.TableOID)
dst.AttributeNumber = src.TableAttributeNumber
dst.DataType = pgtype.OID(src.DataTypeOID)
dst.DataTypeSize = src.DataTypeSize
dst.Modifier = src.TypeModifier
dst.FormatCode = src.Format
if dt, ok := connInfo.DataTypeForOID(dst.DataType); ok {
dst.DataTypeName = dt.Name
}
}
func (c *Conn) getRows(sql string, args []interface{}) *connRows {
if len(c.preallocatedRows) == 0 {
c.preallocatedRows = make([]connRows, 64)
@@ -628,15 +606,12 @@ optionLoop:
Name: psd.Name,
SQL: psd.SQL,
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
FieldDescriptions: psd.Fields,
}
for i := range ps.ParameterOIDs {
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
}
for i := range ps.FieldDescriptions {
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
}
}
rows.sql = ps.SQL
@@ -658,13 +633,13 @@ optionLoop:
if resultFormatsByOID != nil {
resultFormats = make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
resultFormats[i] = resultFormatsByOID[ps.FieldDescriptions[i].DataType]
resultFormats[i] = resultFormatsByOID[pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)]
}
}
if resultFormats == nil {
for i := range ps.FieldDescriptions {
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
c.eqb.AppendResultFormat(BinaryFormatCode)
} else {
@@ -725,7 +700,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
if resultFormats == nil {
resultFormats = make([]int16, len(ps.FieldDescriptions))
for i := range resultFormats {
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
resultFormats[i] = BinaryFormatCode
} else {

View File

@@ -7,6 +7,7 @@ import (
"io"
"github.com/jackc/pgio"
"github.com/jackc/pgtype"
errors "golang.org/x/xerrors"
)
@@ -129,7 +130,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, pgtype.OID(ps.FieldDescriptions[i].DataTypeOID), val)
if err != nil {
return false, nil, err
}

View File

@@ -2,82 +2,17 @@ package pgx
import (
"database/sql/driver"
"math"
"reflect"
"time"
"github.com/jackc/pgio"
"github.com/jackc/pgtype"
)
const (
copyData = 'd'
copyFail = 'f'
copyDone = 'c'
varHeaderSize = 4
copyData = 'd'
copyFail = 'f'
copyDone = 'c'
)
type FieldDescription struct {
Name string
Table pgtype.OID
AttributeNumber uint16
DataType pgtype.OID
DataTypeSize int16
DataTypeName string
Modifier int32
FormatCode int16
}
func (fd FieldDescription) Length() (int64, bool) {
switch fd.DataType {
case pgtype.TextOID, pgtype.ByteaOID:
return math.MaxInt64, true
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
return int64(fd.Modifier - varHeaderSize), true
default:
return 0, false
}
}
func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) {
switch fd.DataType {
case pgtype.NumericOID:
mod := fd.Modifier - varHeaderSize
precision = int64((mod >> 16) & 0xffff)
scale = int64(mod & 0xffff)
return precision, scale, true
default:
return 0, 0, false
}
}
func (fd FieldDescription) Type() reflect.Type {
switch fd.DataType {
case pgtype.Float8OID:
return reflect.TypeOf(float64(0))
case pgtype.Float4OID:
return reflect.TypeOf(float32(0))
case pgtype.Int8OID:
return reflect.TypeOf(int64(0))
case pgtype.Int4OID:
return reflect.TypeOf(int32(0))
case pgtype.Int2OID:
return reflect.TypeOf(int16(0))
case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID:
return reflect.TypeOf("")
case pgtype.BoolOID:
return reflect.TypeOf(false)
case pgtype.NumericOID:
return reflect.TypeOf(float64(0))
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
return reflect.TypeOf(time.Time{})
case pgtype.ByteaOID:
return reflect.TypeOf([]byte(nil))
default:
return reflect.TypeOf(new(interface{})).Elem()
}
}
func convertDriverValuers(args []interface{}) ([]interface{}, error) {
for i, arg := range args {
switch arg := arg.(type) {

View File

@@ -1,6 +1,7 @@
package pool
import (
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgx/v4"
)
@@ -8,12 +9,12 @@ type errRows struct {
err error
}
func (errRows) Close() {}
func (e errRows) Err() error { return e.err }
func (errRows) FieldDescriptions() []pgx.FieldDescription { return nil }
func (errRows) Next() bool { return false }
func (e errRows) Scan(dest ...interface{}) error { return e.err }
func (e errRows) Values() ([]interface{}, error) { return nil, e.err }
func (errRows) Close() {}
func (e errRows) Err() error { return e.err }
func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil }
func (errRows) Next() bool { return false }
func (e errRows) Scan(dest ...interface{}) error { return e.err }
func (e errRows) Values() ([]interface{}, error) { return nil, e.err }
type errRow struct {
err error
@@ -42,7 +43,7 @@ func (rows *poolRows) Err() error {
return rows.r.Err()
}
func (rows *poolRows) FieldDescriptions() []pgx.FieldDescription {
func (rows *poolRows) FieldDescriptions() []pgproto3.FieldDescription {
return rows.r.FieldDescriptions()
}

View File

@@ -248,7 +248,7 @@ func TestIdentifySystem(t *testing.T) {
}
defer r.Close()
for _, fd := range r.FieldDescriptions() {
t.Logf("Field: %s of type %v", fd.Name, fd.DataType)
t.Logf("Field: %s of type %v", fd.Name, fd.DataTypeOID)
}
var rowCount int
@@ -307,7 +307,7 @@ func TestGetTimelineHistory(t *testing.T) {
defer r.Close()
for _, fd := range r.FieldDescriptions() {
t.Logf("Field: %s of type %v", fd.Name, fd.DataType)
t.Logf("Field: %s of type %v", fd.Name, fd.DataTypeOID)
}
var rowCount int

rows.go

@@ -8,6 +8,7 @@ import (
errors "golang.org/x/xerrors"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgtype"
)
@@ -20,7 +21,7 @@ type Rows interface {
Close()
Err() error
FieldDescriptions() []FieldDescription
FieldDescriptions() []pgproto3.FieldDescription
// Next prepares the next row for reading. It returns true if there is another
// row and false if no more rows are available. It automatically closes rows
@@ -77,7 +78,6 @@ type connRows struct {
logger rowLog
connInfo *pgtype.ConnInfo
values [][]byte
fields []FieldDescription
rowCount int
columnIdx int
err error
@@ -89,8 +89,8 @@ type connRows struct {
resultReader *pgconn.ResultReader
}
func (rows *connRows) FieldDescriptions() []FieldDescription {
return rows.fields
func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription {
return rows.resultReader.FieldDescriptions()
}
func (rows *connRows) Close() {
@@ -140,13 +140,6 @@ func (rows *connRows) Next() bool {
}
if rows.resultReader.NextRow() {
if rows.fields == nil {
rrFieldDescriptions := rows.resultReader.FieldDescriptions()
rows.fields = make([]FieldDescription, len(rrFieldDescriptions))
for i := range rrFieldDescriptions {
pgproto3FieldDescriptionToPgxFieldDescription(rows.connInfo, &rrFieldDescriptions[i], &rows.fields[i])
}
}
rows.rowCount++
rows.columnIdx = 0
rows.values = rows.resultReader.Values()
@@ -157,24 +150,24 @@ func (rows *connRows) Next() bool {
}
}
func (rows *connRows) nextColumn() ([]byte, *FieldDescription, bool) {
func (rows *connRows) nextColumn() ([]byte, *pgproto3.FieldDescription, bool) {
if rows.closed {
return nil, nil, false
}
if len(rows.fields) <= rows.columnIdx {
if len(rows.FieldDescriptions()) <= rows.columnIdx {
rows.fatal(ProtocolError("No next column available"))
return nil, nil, false
}
buf := rows.values[rows.columnIdx]
fd := &rows.fields[rows.columnIdx]
fd := &rows.FieldDescriptions()[rows.columnIdx]
rows.columnIdx++
return buf, fd, true
}
func (rows *connRows) Scan(dest ...interface{}) error {
if len(rows.fields) != len(dest) {
err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
if len(rows.FieldDescriptions()) != len(dest) {
err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.FieldDescriptions()))
rows.fatal(err)
return err
}
@@ -186,7 +179,7 @@ func (rows *connRows) Scan(dest ...interface{}) error {
continue
}
err := rows.connInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
err := rows.connInfo.Scan(pgtype.OID(fd.DataTypeOID), fd.Format, buf, d)
if err != nil {
rows.fatal(scanArgError{col: i, err: err})
return err
@@ -201,9 +194,9 @@ func (rows *connRows) Values() ([]interface{}, error) {
return nil, errors.New("rows is closed")
}
values := make([]interface{}, 0, len(rows.fields))
values := make([]interface{}, 0, len(rows.FieldDescriptions()))
for range rows.fields {
for range rows.FieldDescriptions() {
buf, fd, _ := rows.nextColumn()
if buf == nil {
@@ -211,10 +204,10 @@ func (rows *connRows) Values() ([]interface{}, error) {
continue
}
if dt, ok := rows.connInfo.DataTypeForOID(fd.DataType); ok {
if dt, ok := rows.connInfo.DataTypeForOID(pgtype.OID(fd.DataTypeOID)); ok {
value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value)
switch fd.FormatCode {
switch fd.Format {
case TextFormatCode:
decoder := value.(pgtype.TextDecoder)
if decoder == nil {

View File

@@ -74,6 +74,7 @@ import (
"database/sql/driver"
"fmt"
"io"
"math"
"net"
"reflect"
"strings"
@@ -260,7 +261,7 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil
}
func (c *Conn) Ping(ctx context.Context) error {
@@ -301,6 +302,7 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri
}
type Rows struct {
conn *Conn
rows pgx.Rows
values []interface{}
skipNext bool
@@ -311,32 +313,82 @@ func (r *Rows) Columns() []string {
fieldDescriptions := r.rows.FieldDescriptions()
names := make([]string, 0, len(fieldDescriptions))
for _, fd := range fieldDescriptions {
names = append(names, fd.Name)
names = append(names, string(fd.Name))
}
return names
}
// ColumnTypeDatabaseTypeName return the database system type name.
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName)
if dt, ok := r.conn.conn.ConnInfo.DataTypeForOID(pgtype.OID(r.rows.FieldDescriptions()[index].DataTypeOID)); ok {
return strings.ToUpper(dt.Name)
}
return ""
}
const varHeaderSize = 4
// ColumnTypeLength returns the length of the column type if the column is a
// variable length type. If the column is not a variable length type ok
// should return false.
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
return r.rows.FieldDescriptions()[index].Length()
fd := r.rows.FieldDescriptions()[index]
switch fd.DataTypeOID {
case pgtype.TextOID, pgtype.ByteaOID:
return math.MaxInt64, true
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
return int64(fd.TypeModifier - varHeaderSize), true
default:
return 0, false
}
}
// ColumnTypePrecisionScale should return the precision and scale for decimal
// types. If not applicable, ok should be false.
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
return r.rows.FieldDescriptions()[index].PrecisionScale()
fd := r.rows.FieldDescriptions()[index]
switch fd.DataTypeOID {
case pgtype.NumericOID:
mod := fd.TypeModifier - varHeaderSize
precision = int64((mod >> 16) & 0xffff)
scale = int64(mod & 0xffff)
return precision, scale, true
default:
return 0, 0, false
}
}
// ColumnTypeScanType returns the value type that can be used to scan types into.
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
return r.rows.FieldDescriptions()[index].Type()
fd := r.rows.FieldDescriptions()[index]
switch fd.DataTypeOID {
case pgtype.Float8OID:
return reflect.TypeOf(float64(0))
case pgtype.Float4OID:
return reflect.TypeOf(float32(0))
case pgtype.Int8OID:
return reflect.TypeOf(int64(0))
case pgtype.Int4OID:
return reflect.TypeOf(int32(0))
case pgtype.Int2OID:
return reflect.TypeOf(int16(0))
case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID:
return reflect.TypeOf("")
case pgtype.BoolOID:
return reflect.TypeOf(false)
case pgtype.NumericOID:
return reflect.TypeOf(float64(0))
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
return reflect.TypeOf(time.Time{})
case pgtype.ByteaOID:
return reflect.TypeOf([]byte(nil))
default:
return reflect.TypeOf(new(interface{})).Elem()
}
}
func (r *Rows) Close() error {
@@ -348,7 +400,7 @@ func (r *Rows) Next(dest []driver.Value) error {
if r.values == nil {
r.values = make([]interface{}, len(r.rows.FieldDescriptions()))
for i, fd := range r.rows.FieldDescriptions() {
switch fd.DataType {
switch fd.DataTypeOID {
case pgtype.BoolOID:
r.values[i] = &pgtype.Bool{}
case pgtype.ByteaOID: