diff --git a/messages.go b/messages.go index 53a5a67c..5ffa5c06 100644 --- a/messages.go +++ b/messages.go @@ -1,14 +1,19 @@ package pgx import ( + "math" + "reflect" + "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgtype" ) const ( - copyData = 'd' - copyFail = 'f' - copyDone = 'c' + copyData = 'd' + copyFail = 'f' + copyDone = 'c' + varHeaderSize = 4 ) type FieldDescription struct { @@ -22,6 +27,52 @@ type FieldDescription struct { FormatCode int16 } +func (fd FieldDescription) Length() (int64, bool) { + switch fd.DataType { + case pgtype.TextOID, pgtype.ByteaOID: + return math.MaxInt64, true + case pgtype.VarcharOID: + return int64(fd.Modifier - varHeaderSize), true + default: + return 0, false + } +} + +func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) { + switch fd.DataType { + case pgtype.NumericOID: + mod := fd.Modifier - varHeaderSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } +} + +func (fd FieldDescription) Type() reflect.Type { + switch fd.DataType { + case pgtype.Int8OID: + return reflect.TypeOf(int64(0)) + case pgtype.Int4OID: + return reflect.TypeOf(int32(0)) + case pgtype.Int2OID: + return reflect.TypeOf(int16(0)) + case pgtype.VarcharOID, pgtype.TextOID: + return reflect.TypeOf("") + case pgtype.BoolOID: + return reflect.TypeOf(false) + case pgtype.NumericOID: + return reflect.TypeOf(float64(0)) + case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: + return reflect.TypeOf(time.Time{}) + case pgtype.ByteaOID: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf(new(interface{})).Elem() + } +} + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for // detailed field description. diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 6f8e7986..00175a30 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -46,6 +46,7 @@ const ( DateArrayOID = 1182 TimestamptzOID = 1184 TimestamptzArrayOID = 1185 + NumericOID = 1700 RecordOID = 2249 UUIDOID = 2950 JSONBOID = 3802 diff --git a/stdlib/sql.go b/stdlib/sql.go index 0c140343..c021e317 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -70,6 +70,7 @@ import ( "encoding/binary" "fmt" "io" + "reflect" "strings" "sync" @@ -415,10 +416,29 @@ func (r *Rows) Columns() []string { return names } +// ColumnTypeDatabaseTypeName return the database system type name. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName) } +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. +func (r *Rows) ColumnTypeLength(index int) (int64, bool) { + return r.rows.FieldDescriptions()[index].Length() +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return r.rows.FieldDescriptions()[index].PrecisionScale() +} + +// ColumnTypeScanType returns the value type that can be used to scan types into. +func (r *Rows) ColumnTypeScanType(index int) reflect.Type { + return r.rows.FieldDescriptions()[index].Type() +} + func (r *Rows) Close() error { r.rows.Close() return nil diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 65f80ac4..1880429d 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -5,6 +5,8 @@ import ( "context" "database/sql" "fmt" + "math" + "reflect" "testing" "time" @@ -1258,3 +1260,128 @@ func TestStmtQueryContextCancel(t *testing.T) { t.Errorf("mock server err: %v", err) } } + +func TestRowsColumnTypes(t *testing.T) { + columnTypesTests := []struct { + Name string + TypeName string + Length struct { + Len int64 + OK bool + } + DecimalSize struct { + Precision int64 + Scale int64 + OK bool + } + ScanType reflect.Type + }{ + { + Name: "a", + TypeName: "INT4", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(int32(0)), + }, { + Name: "bar", + TypeName: "TEXT", + Length: struct { + Len int64 + OK bool + }{ + Len: math.MaxInt64, + OK: true, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(""), + }, { + Name: "dec", + TypeName: "NUMERIC", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 9, + Scale: 2, + OK: true, + }, + ScanType: reflect.TypeOf(float64(0)), + }, + } + + db := openDB(t) + defer closeDB(t, db) + + rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") + if err != nil { + t.Fatal(err) + } + + columns, err := rows.ColumnTypes() + if err != nil { + t.Fatal(err) + } + if len(columns) != 3 { + t.Errorf("expected 3 columns found %d", len(columns)) + } + + for i, tt := range columnTypesTests { + c := columns[i] + if c.Name() != tt.Name { + t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) + } + if c.DatabaseTypeName() != tt.TypeName { + t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) + } + l, ok := c.Length() + if l != tt.Length.Len { + t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) + } + if ok != tt.Length.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) + } + p, s, ok := c.DecimalSize() + if p != tt.DecimalSize.Precision { + t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) + } + if s != tt.DecimalSize.Scale { + t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) + } + if ok != tt.DecimalSize.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) + } + if c.ScanType() != tt.ScanType { + t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) + } + } +}