mirror of https://github.com/jackc/pgx.git
Add more ColumnType support
parent
d49a78dd73
commit
0f84f73c7b
57
messages.go
57
messages.go
|
@ -1,14 +1,19 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
const (
|
||||
copyData = 'd'
|
||||
copyFail = 'f'
|
||||
copyDone = 'c'
|
||||
copyData = 'd'
|
||||
copyFail = 'f'
|
||||
copyDone = 'c'
|
||||
varHeaderSize = 4
|
||||
)
|
||||
|
||||
type FieldDescription struct {
|
||||
|
@ -22,6 +27,52 @@ type FieldDescription struct {
|
|||
FormatCode int16
|
||||
}
|
||||
|
||||
func (fd FieldDescription) Length() (int64, bool) {
|
||||
switch fd.DataType {
|
||||
case pgtype.TextOID, pgtype.ByteaOID:
|
||||
return math.MaxInt64, true
|
||||
case pgtype.VarcharOID:
|
||||
return int64(fd.Modifier - varHeaderSize), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) {
|
||||
switch fd.DataType {
|
||||
case pgtype.NumericOID:
|
||||
mod := fd.Modifier - varHeaderSize
|
||||
precision = int64((mod >> 16) & 0xffff)
|
||||
scale = int64(mod & 0xffff)
|
||||
return precision, scale, true
|
||||
default:
|
||||
return 0, 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (fd FieldDescription) Type() reflect.Type {
|
||||
switch fd.DataType {
|
||||
case pgtype.Int8OID:
|
||||
return reflect.TypeOf(int64(0))
|
||||
case pgtype.Int4OID:
|
||||
return reflect.TypeOf(int32(0))
|
||||
case pgtype.Int2OID:
|
||||
return reflect.TypeOf(int16(0))
|
||||
case pgtype.VarcharOID, pgtype.TextOID:
|
||||
return reflect.TypeOf("")
|
||||
case pgtype.BoolOID:
|
||||
return reflect.TypeOf(false)
|
||||
case pgtype.NumericOID:
|
||||
return reflect.TypeOf(float64(0))
|
||||
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
|
||||
return reflect.TypeOf(time.Time{})
|
||||
case pgtype.ByteaOID:
|
||||
return reflect.TypeOf([]byte(nil))
|
||||
default:
|
||||
return reflect.TypeOf(new(interface{})).Elem()
|
||||
}
|
||||
}
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
||||
// detailed field description.
|
||||
|
|
|
@ -46,6 +46,7 @@ const (
|
|||
DateArrayOID = 1182
|
||||
TimestamptzOID = 1184
|
||||
TimestamptzArrayOID = 1185
|
||||
NumericOID = 1700
|
||||
RecordOID = 2249
|
||||
UUIDOID = 2950
|
||||
JSONBOID = 3802
|
||||
|
|
|
@ -70,6 +70,7 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
@ -415,10 +416,29 @@ func (r *Rows) Columns() []string {
|
|||
return names
|
||||
}
|
||||
|
||||
// ColumnTypeDatabaseTypeName return the database system type name.
|
||||
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
|
||||
return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName)
|
||||
}
|
||||
|
||||
// ColumnTypeLength returns the length of the column type if the column is a
|
||||
// variable length type. If the column is not a variable length type ok
|
||||
// should return false.
|
||||
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
|
||||
return r.rows.FieldDescriptions()[index].Length()
|
||||
}
|
||||
|
||||
// ColumnTypePrecisionScale should return the precision and scale for decimal
|
||||
// types. If not applicable, ok should be false.
|
||||
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
||||
return r.rows.FieldDescriptions()[index].PrecisionScale()
|
||||
}
|
||||
|
||||
// ColumnTypeScanType returns the value type that can be used to scan types into.
|
||||
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
|
||||
return r.rows.FieldDescriptions()[index].Type()
|
||||
}
|
||||
|
||||
func (r *Rows) Close() error {
|
||||
r.rows.Close()
|
||||
return nil
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -1258,3 +1260,128 @@ func TestStmtQueryContextCancel(t *testing.T) {
|
|||
t.Errorf("mock server err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowsColumnTypes(t *testing.T) {
|
||||
columnTypesTests := []struct {
|
||||
Name string
|
||||
TypeName string
|
||||
Length struct {
|
||||
Len int64
|
||||
OK bool
|
||||
}
|
||||
DecimalSize struct {
|
||||
Precision int64
|
||||
Scale int64
|
||||
OK bool
|
||||
}
|
||||
ScanType reflect.Type
|
||||
}{
|
||||
{
|
||||
Name: "a",
|
||||
TypeName: "INT4",
|
||||
Length: struct {
|
||||
Len int64
|
||||
OK bool
|
||||
}{
|
||||
Len: 0,
|
||||
OK: false,
|
||||
},
|
||||
DecimalSize: struct {
|
||||
Precision int64
|
||||
Scale int64
|
||||
OK bool
|
||||
}{
|
||||
Precision: 0,
|
||||
Scale: 0,
|
||||
OK: false,
|
||||
},
|
||||
ScanType: reflect.TypeOf(int32(0)),
|
||||
}, {
|
||||
Name: "bar",
|
||||
TypeName: "TEXT",
|
||||
Length: struct {
|
||||
Len int64
|
||||
OK bool
|
||||
}{
|
||||
Len: math.MaxInt64,
|
||||
OK: true,
|
||||
},
|
||||
DecimalSize: struct {
|
||||
Precision int64
|
||||
Scale int64
|
||||
OK bool
|
||||
}{
|
||||
Precision: 0,
|
||||
Scale: 0,
|
||||
OK: false,
|
||||
},
|
||||
ScanType: reflect.TypeOf(""),
|
||||
}, {
|
||||
Name: "dec",
|
||||
TypeName: "NUMERIC",
|
||||
Length: struct {
|
||||
Len int64
|
||||
OK bool
|
||||
}{
|
||||
Len: 0,
|
||||
OK: false,
|
||||
},
|
||||
DecimalSize: struct {
|
||||
Precision int64
|
||||
Scale int64
|
||||
OK bool
|
||||
}{
|
||||
Precision: 9,
|
||||
Scale: 2,
|
||||
OK: true,
|
||||
},
|
||||
ScanType: reflect.TypeOf(float64(0)),
|
||||
},
|
||||
}
|
||||
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
|
||||
rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
columns, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(columns) != 3 {
|
||||
t.Errorf("expected 3 columns found %d", len(columns))
|
||||
}
|
||||
|
||||
for i, tt := range columnTypesTests {
|
||||
c := columns[i]
|
||||
if c.Name() != tt.Name {
|
||||
t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
|
||||
}
|
||||
if c.DatabaseTypeName() != tt.TypeName {
|
||||
t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
|
||||
}
|
||||
l, ok := c.Length()
|
||||
if l != tt.Length.Len {
|
||||
t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
|
||||
}
|
||||
if ok != tt.Length.OK {
|
||||
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
|
||||
}
|
||||
p, s, ok := c.DecimalSize()
|
||||
if p != tt.DecimalSize.Precision {
|
||||
t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
|
||||
}
|
||||
if s != tt.DecimalSize.Scale {
|
||||
t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
|
||||
}
|
||||
if ok != tt.DecimalSize.OK {
|
||||
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
|
||||
}
|
||||
if c.ScanType() != tt.ScanType {
|
||||
t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue