diff --git a/stdlib/sql.go b/stdlib/sql.go index c7947eaf..76dd4578 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -56,6 +56,7 @@ import ( "io" "math" "reflect" + "strconv" "strings" "sync" "time" @@ -423,13 +424,13 @@ func (r *Rows) Columns() []string { return names } -// ColumnTypeDatabaseTypeName return the database system type name. +// ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { if dt, ok := r.conn.conn.ConnInfo().DataTypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { return strings.ToUpper(dt.Name) } - return "" + return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10) } const varHeaderSize = 4 @@ -481,8 +482,6 @@ func (r *Rows) ColumnTypeScanType(index int) reflect.Type { return reflect.TypeOf(int32(0)) case pgtype.Int2OID: return reflect.TypeOf(int16(0)) - case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID: - return reflect.TypeOf("") case pgtype.BoolOID: return reflect.TypeOf(false) case pgtype.NumericOID: @@ -492,7 +491,7 @@ func (r *Rows) ColumnTypeScanType(index int) reflect.Type { case pgtype.ByteaOID: return reflect.TypeOf([]byte(nil)) default: - return reflect.TypeOf(new(interface{})).Elem() + return reflect.TypeOf("") } } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index a551e34b..3f6958f2 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -834,17 +834,35 @@ func TestRowsColumnTypes(t *testing.T) { OK: true, }, ScanType: reflect.TypeOf(float64(0)), + }, { + Name: "d", + TypeName: "1266", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(""), }, } - rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") + rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d") require.NoError(t, err) columns, err := rows.ColumnTypes() require.NoError(t, err) - if len(columns) != 3 { - t.Errorf("expected 3 columns found %d", len(columns)) - } + assert.Len(t, columns, 4) for i, tt := range columnTypesTests { c := columns[i]