diff --git a/conn.go b/conn.go index d49eb436..0d06098b 100644 --- a/conn.go +++ b/conn.go @@ -56,7 +56,7 @@ type Conn struct { type PreparedStatement struct { Name string SQL string - FieldDescriptions []FieldDescription + FieldDescriptions []pgproto3.FieldDescription ParameterOIDs []pgtype.OID } @@ -213,15 +213,12 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState Name: psd.Name, SQL: psd.SQL, ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), - FieldDescriptions: make([]FieldDescription, len(psd.Fields)), + FieldDescriptions: psd.Fields, } for i := range ps.ParameterOIDs { ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) } - for i := range ps.FieldDescriptions { - pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i]) - } if name != "" { c.preparedStatements[name] = ps @@ -416,7 +413,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) ( resultFormats := make([]int16, len(ps.FieldDescriptions)) for i := range resultFormats { - if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { + if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { resultFormats[i] = BinaryFormatCode } else { @@ -453,15 +450,12 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) ( Name: psd.Name, SQL: psd.SQL, ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), - FieldDescriptions: make([]FieldDescription, len(psd.Fields)), + FieldDescriptions: psd.Fields, } for i := range ps.ParameterOIDs { ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) } - for i := range ps.FieldDescriptions { - pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i]) - } arguments, err = convertDriverValuers(arguments) if err != nil { @@ -481,7 +475,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) ( resultFormats := make([]int16, len(ps.FieldDescriptions)) for i := range resultFormats { - if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { + if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { resultFormats[i] = BinaryFormatCode } else { @@ -549,22 +543,6 @@ func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } -// pgproto3FieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a -// FieldDescription. -func pgproto3FieldDescriptionToPgxFieldDescription(connInfo *pgtype.ConnInfo, src *pgproto3.FieldDescription, dst *FieldDescription) { - dst.Name = string(src.Name) - dst.Table = pgtype.OID(src.TableOID) - dst.AttributeNumber = src.TableAttributeNumber - dst.DataType = pgtype.OID(src.DataTypeOID) - dst.DataTypeSize = src.DataTypeSize - dst.Modifier = src.TypeModifier - dst.FormatCode = src.Format - - if dt, ok := connInfo.DataTypeForOID(dst.DataType); ok { - dst.DataTypeName = dt.Name - } -} - func (c *Conn) getRows(sql string, args []interface{}) *connRows { if len(c.preallocatedRows) == 0 { c.preallocatedRows = make([]connRows, 64) @@ -628,15 +606,12 @@ optionLoop: Name: psd.Name, SQL: psd.SQL, ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), - FieldDescriptions: make([]FieldDescription, len(psd.Fields)), + FieldDescriptions: psd.Fields, } for i := range ps.ParameterOIDs { ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) } - for i := range ps.FieldDescriptions { - pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i]) - } } rows.sql = ps.SQL @@ -658,13 +633,13 @@ optionLoop: if resultFormatsByOID != nil { resultFormats = make([]int16, len(ps.FieldDescriptions)) for i := range resultFormats { - resultFormats[i] = resultFormatsByOID[ps.FieldDescriptions[i].DataType] + resultFormats[i] = resultFormatsByOID[pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)] } } if resultFormats == nil { for i := range ps.FieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { + if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { c.eqb.AppendResultFormat(BinaryFormatCode) } else { @@ -725,7 +700,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { if resultFormats == nil { resultFormats = make([]int16, len(ps.FieldDescriptions)) for i := range resultFormats { - if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { + if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { resultFormats[i] = BinaryFormatCode } else { diff --git a/copy_from.go b/copy_from.go index 368d4c33..ec56b11f 100644 --- a/copy_from.go +++ b/copy_from.go @@ -7,6 +7,7 @@ import ( "io" "github.com/jackc/pgio" + "github.com/jackc/pgtype" errors "golang.org/x/xerrors" ) @@ -129,7 +130,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val) + buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, pgtype.OID(ps.FieldDescriptions[i].DataTypeOID), val) if err != nil { return false, nil, err } diff --git a/messages.go b/messages.go index 45dda1ad..8e00f65f 100644 --- a/messages.go +++ b/messages.go @@ -2,82 +2,17 @@ package pgx import ( "database/sql/driver" - "math" - "reflect" - "time" "github.com/jackc/pgio" "github.com/jackc/pgtype" ) const ( - copyData = 'd' - copyFail = 'f' - copyDone = 'c' - varHeaderSize = 4 + copyData = 'd' + copyFail = 'f' + copyDone = 'c' ) -type FieldDescription struct { - Name string - Table pgtype.OID - AttributeNumber uint16 - DataType pgtype.OID - DataTypeSize int16 - DataTypeName string - Modifier int32 - FormatCode int16 -} - -func (fd FieldDescription) Length() (int64, bool) { - switch fd.DataType { - case pgtype.TextOID, pgtype.ByteaOID: - return math.MaxInt64, true - case pgtype.VarcharOID, pgtype.BPCharArrayOID: - return int64(fd.Modifier - varHeaderSize), true - default: - return 0, false - } -} - -func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) { - switch fd.DataType { - case pgtype.NumericOID: - mod := fd.Modifier - varHeaderSize - precision = int64((mod >> 16) & 0xffff) - scale = int64(mod & 0xffff) - return precision, scale, true - default: - return 0, 0, false - } -} - -func (fd FieldDescription) Type() reflect.Type { - switch fd.DataType { - case pgtype.Float8OID: - return reflect.TypeOf(float64(0)) - case pgtype.Float4OID: - return reflect.TypeOf(float32(0)) - case pgtype.Int8OID: - return reflect.TypeOf(int64(0)) - case pgtype.Int4OID: - return reflect.TypeOf(int32(0)) - case pgtype.Int2OID: - return reflect.TypeOf(int16(0)) - case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID: - return reflect.TypeOf("") - case pgtype.BoolOID: - return reflect.TypeOf(false) - case pgtype.NumericOID: - return reflect.TypeOf(float64(0)) - case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: - return reflect.TypeOf(time.Time{}) - case pgtype.ByteaOID: - return reflect.TypeOf([]byte(nil)) - default: - return reflect.TypeOf(new(interface{})).Elem() - } -} - func convertDriverValuers(args []interface{}) ([]interface{}, error) { for i, arg := range args { switch arg := arg.(type) { diff --git a/pool/rows.go b/pool/rows.go index 930f21e1..99933db8 100644 --- a/pool/rows.go +++ b/pool/rows.go @@ -1,6 +1,7 @@ package pool import ( + "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v4" ) @@ -8,12 +9,12 @@ type errRows struct { err error } -func (errRows) Close() {} -func (e errRows) Err() error { return e.err } -func (errRows) FieldDescriptions() []pgx.FieldDescription { return nil } -func (errRows) Next() bool { return false } -func (e errRows) Scan(dest ...interface{}) error { return e.err } -func (e errRows) Values() ([]interface{}, error) { return nil, e.err } +func (errRows) Close() {} +func (e errRows) Err() error { return e.err } +func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } +func (errRows) Next() bool { return false } +func (e errRows) Scan(dest ...interface{}) error { return e.err } +func (e errRows) Values() ([]interface{}, error) { return nil, e.err } type errRow struct { err error @@ -42,7 +43,7 @@ func (rows *poolRows) Err() error { return rows.r.Err() } -func (rows *poolRows) FieldDescriptions() []pgx.FieldDescription { +func (rows *poolRows) FieldDescriptions() []pgproto3.FieldDescription { return rows.r.FieldDescriptions() } diff --git a/replication_test.go b/replication_test.go index cd9b917a..d84b390e 100644 --- a/replication_test.go +++ b/replication_test.go @@ -248,7 +248,7 @@ func TestIdentifySystem(t *testing.T) { } defer r.Close() for _, fd := range r.FieldDescriptions() { - t.Logf("Field: %s of type %v", fd.Name, fd.DataType) + t.Logf("Field: %s of type %v", fd.Name, fd.DataTypeOID) } var rowCount int @@ -307,7 +307,7 @@ func TestGetTimelineHistory(t *testing.T) { defer r.Close() for _, fd := range r.FieldDescriptions() { - t.Logf("Field: %s of type %v", fd.Name, fd.DataType) + t.Logf("Field: %s of type %v", fd.Name, fd.DataTypeOID) } var rowCount int diff --git a/rows.go b/rows.go index b57f16fa..e237f6bc 100644 --- a/rows.go +++ b/rows.go @@ -8,6 +8,7 @@ import ( errors "golang.org/x/xerrors" "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" "github.com/jackc/pgtype" ) @@ -20,7 +21,7 @@ type Rows interface { Close() Err() error - FieldDescriptions() []FieldDescription + FieldDescriptions() []pgproto3.FieldDescription // Next prepares the next row for reading. It returns true if there is another // row and false if no more rows are available. It automatically closes rows @@ -77,7 +78,6 @@ type connRows struct { logger rowLog connInfo *pgtype.ConnInfo values [][]byte - fields []FieldDescription rowCount int columnIdx int err error @@ -89,8 +89,8 @@ type connRows struct { resultReader *pgconn.ResultReader } -func (rows *connRows) FieldDescriptions() []FieldDescription { - return rows.fields +func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription { + return rows.resultReader.FieldDescriptions() } func (rows *connRows) Close() { @@ -140,13 +140,6 @@ func (rows *connRows) Next() bool { } if rows.resultReader.NextRow() { - if rows.fields == nil { - rrFieldDescriptions := rows.resultReader.FieldDescriptions() - rows.fields = make([]FieldDescription, len(rrFieldDescriptions)) - for i := range rrFieldDescriptions { - pgproto3FieldDescriptionToPgxFieldDescription(rows.connInfo, &rrFieldDescriptions[i], &rows.fields[i]) - } - } rows.rowCount++ rows.columnIdx = 0 rows.values = rows.resultReader.Values() @@ -157,24 +150,24 @@ func (rows *connRows) Next() bool { } } -func (rows *connRows) nextColumn() ([]byte, *FieldDescription, bool) { +func (rows *connRows) nextColumn() ([]byte, *pgproto3.FieldDescription, bool) { if rows.closed { return nil, nil, false } - if len(rows.fields) <= rows.columnIdx { + if len(rows.FieldDescriptions()) <= rows.columnIdx { rows.fatal(ProtocolError("No next column available")) return nil, nil, false } buf := rows.values[rows.columnIdx] - fd := &rows.fields[rows.columnIdx] + fd := &rows.FieldDescriptions()[rows.columnIdx] rows.columnIdx++ return buf, fd, true } func (rows *connRows) Scan(dest ...interface{}) error { - if len(rows.fields) != len(dest) { - err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) + if len(rows.FieldDescriptions()) != len(dest) { + err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.FieldDescriptions())) rows.fatal(err) return err } @@ -186,7 +179,7 @@ func (rows *connRows) Scan(dest ...interface{}) error { continue } - err := rows.connInfo.Scan(fd.DataType, fd.FormatCode, buf, d) + err := rows.connInfo.Scan(pgtype.OID(fd.DataTypeOID), fd.Format, buf, d) if err != nil { rows.fatal(scanArgError{col: i, err: err}) return err @@ -201,9 +194,9 @@ func (rows *connRows) Values() ([]interface{}, error) { return nil, errors.New("rows is closed") } - values := make([]interface{}, 0, len(rows.fields)) + values := make([]interface{}, 0, len(rows.FieldDescriptions())) - for range rows.fields { + for range rows.FieldDescriptions() { buf, fd, _ := rows.nextColumn() if buf == nil { @@ -211,10 +204,10 @@ func (rows *connRows) Values() ([]interface{}, error) { continue } - if dt, ok := rows.connInfo.DataTypeForOID(fd.DataType); ok { + if dt, ok := rows.connInfo.DataTypeForOID(pgtype.OID(fd.DataTypeOID)); ok { value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value) - switch fd.FormatCode { + switch fd.Format { case TextFormatCode: decoder := value.(pgtype.TextDecoder) if decoder == nil { diff --git a/stdlib/sql.go b/stdlib/sql.go index df2d0572..1bde8c4e 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -74,6 +74,7 @@ import ( "database/sql/driver" "fmt" "io" + "math" "net" "reflect" "strings" @@ -260,7 +261,7 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr // Preload first row because otherwise we won't know what columns are available when database/sql asks. more := rows.Next() - return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil + return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil } func (c *Conn) Ping(ctx context.Context) error { @@ -301,6 +302,7 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri } type Rows struct { + conn *Conn rows pgx.Rows values []interface{} skipNext bool @@ -311,32 +313,82 @@ func (r *Rows) Columns() []string { fieldDescriptions := r.rows.FieldDescriptions() names := make([]string, 0, len(fieldDescriptions)) for _, fd := range fieldDescriptions { - names = append(names, fd.Name) + names = append(names, string(fd.Name)) } return names } // ColumnTypeDatabaseTypeName return the database system type name. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { - return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName) + if dt, ok := r.conn.conn.ConnInfo.DataTypeForOID(pgtype.OID(r.rows.FieldDescriptions()[index].DataTypeOID)); ok { + return strings.ToUpper(dt.Name) + } + + return "" } +const varHeaderSize = 4 + // ColumnTypeLength returns the length of the column type if the column is a // variable length type. If the column is not a variable length type ok // should return false. func (r *Rows) ColumnTypeLength(index int) (int64, bool) { - return r.rows.FieldDescriptions()[index].Length() + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.TextOID, pgtype.ByteaOID: + return math.MaxInt64, true + case pgtype.VarcharOID, pgtype.BPCharArrayOID: + return int64(fd.TypeModifier - varHeaderSize), true + default: + return 0, false + } } // ColumnTypePrecisionScale should return the precision and scale for decimal // types. If not applicable, ok should be false. func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { - return r.rows.FieldDescriptions()[index].PrecisionScale() + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.NumericOID: + mod := fd.TypeModifier - varHeaderSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } } // ColumnTypeScanType returns the value type that can be used to scan types into. func (r *Rows) ColumnTypeScanType(index int) reflect.Type { - return r.rows.FieldDescriptions()[index].Type() + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.Float8OID: + return reflect.TypeOf(float64(0)) + case pgtype.Float4OID: + return reflect.TypeOf(float32(0)) + case pgtype.Int8OID: + return reflect.TypeOf(int64(0)) + case pgtype.Int4OID: + return reflect.TypeOf(int32(0)) + case pgtype.Int2OID: + return reflect.TypeOf(int16(0)) + case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID: + return reflect.TypeOf("") + case pgtype.BoolOID: + return reflect.TypeOf(false) + case pgtype.NumericOID: + return reflect.TypeOf(float64(0)) + case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: + return reflect.TypeOf(time.Time{}) + case pgtype.ByteaOID: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf(new(interface{})).Elem() + } } func (r *Rows) Close() error { @@ -348,7 +400,7 @@ func (r *Rows) Next(dest []driver.Value) error { if r.values == nil { r.values = make([]interface{}, len(r.rows.FieldDescriptions())) for i, fd := range r.rows.FieldDescriptions() { - switch fd.DataType { + switch fd.DataTypeOID { case pgtype.BoolOID: r.values[i] = &pgtype.Bool{} case pgtype.ByteaOID: