mirror of https://github.com/jackc/pgx.git
Add more ColumnType support
parent
d49a78dd73
commit
0f84f73c7b
57
messages.go
57
messages.go
|
@ -1,14 +1,19 @@
|
||||||
package pgx
|
package pgx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"math"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/pgio"
|
"github.com/jackc/pgx/pgio"
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
copyData = 'd'
|
copyData = 'd'
|
||||||
copyFail = 'f'
|
copyFail = 'f'
|
||||||
copyDone = 'c'
|
copyDone = 'c'
|
||||||
|
varHeaderSize = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
type FieldDescription struct {
|
type FieldDescription struct {
|
||||||
|
@ -22,6 +27,52 @@ type FieldDescription struct {
|
||||||
FormatCode int16
|
FormatCode int16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fd FieldDescription) Length() (int64, bool) {
|
||||||
|
switch fd.DataType {
|
||||||
|
case pgtype.TextOID, pgtype.ByteaOID:
|
||||||
|
return math.MaxInt64, true
|
||||||
|
case pgtype.VarcharOID:
|
||||||
|
return int64(fd.Modifier - varHeaderSize), true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) {
|
||||||
|
switch fd.DataType {
|
||||||
|
case pgtype.NumericOID:
|
||||||
|
mod := fd.Modifier - varHeaderSize
|
||||||
|
precision = int64((mod >> 16) & 0xffff)
|
||||||
|
scale = int64(mod & 0xffff)
|
||||||
|
return precision, scale, true
|
||||||
|
default:
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fd FieldDescription) Type() reflect.Type {
|
||||||
|
switch fd.DataType {
|
||||||
|
case pgtype.Int8OID:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
case pgtype.Int4OID:
|
||||||
|
return reflect.TypeOf(int32(0))
|
||||||
|
case pgtype.Int2OID:
|
||||||
|
return reflect.TypeOf(int16(0))
|
||||||
|
case pgtype.VarcharOID, pgtype.TextOID:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case pgtype.BoolOID:
|
||||||
|
return reflect.TypeOf(false)
|
||||||
|
case pgtype.NumericOID:
|
||||||
|
return reflect.TypeOf(float64(0))
|
||||||
|
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
case pgtype.ByteaOID:
|
||||||
|
return reflect.TypeOf([]byte(nil))
|
||||||
|
default:
|
||||||
|
return reflect.TypeOf(new(interface{})).Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// PgError represents an error reported by the PostgreSQL server. See
|
// PgError represents an error reported by the PostgreSQL server. See
|
||||||
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
||||||
// detailed field description.
|
// detailed field description.
|
||||||
|
|
|
@ -46,6 +46,7 @@ const (
|
||||||
DateArrayOID = 1182
|
DateArrayOID = 1182
|
||||||
TimestamptzOID = 1184
|
TimestamptzOID = 1184
|
||||||
TimestamptzArrayOID = 1185
|
TimestamptzArrayOID = 1185
|
||||||
|
NumericOID = 1700
|
||||||
RecordOID = 2249
|
RecordOID = 2249
|
||||||
UUIDOID = 2950
|
UUIDOID = 2950
|
||||||
JSONBOID = 3802
|
JSONBOID = 3802
|
||||||
|
|
|
@ -70,6 +70,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -415,10 +416,29 @@ func (r *Rows) Columns() []string {
|
||||||
return names
|
return names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ColumnTypeDatabaseTypeName return the database system type name.
|
||||||
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
|
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
|
||||||
return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName)
|
return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ColumnTypeLength returns the length of the column type if the column is a
|
||||||
|
// variable length type. If the column is not a variable length type ok
|
||||||
|
// should return false.
|
||||||
|
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
|
||||||
|
return r.rows.FieldDescriptions()[index].Length()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnTypePrecisionScale should return the precision and scale for decimal
|
||||||
|
// types. If not applicable, ok should be false.
|
||||||
|
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
||||||
|
return r.rows.FieldDescriptions()[index].PrecisionScale()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnTypeScanType returns the value type that can be used to scan types into.
|
||||||
|
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
|
||||||
|
return r.rows.FieldDescriptions()[index].Type()
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Rows) Close() error {
|
func (r *Rows) Close() error {
|
||||||
r.rows.Close()
|
r.rows.Close()
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -1258,3 +1260,128 @@ func TestStmtQueryContextCancel(t *testing.T) {
|
||||||
t.Errorf("mock server err: %v", err)
|
t.Errorf("mock server err: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRowsColumnTypes(t *testing.T) {
|
||||||
|
columnTypesTests := []struct {
|
||||||
|
Name string
|
||||||
|
TypeName string
|
||||||
|
Length struct {
|
||||||
|
Len int64
|
||||||
|
OK bool
|
||||||
|
}
|
||||||
|
DecimalSize struct {
|
||||||
|
Precision int64
|
||||||
|
Scale int64
|
||||||
|
OK bool
|
||||||
|
}
|
||||||
|
ScanType reflect.Type
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "a",
|
||||||
|
TypeName: "INT4",
|
||||||
|
Length: struct {
|
||||||
|
Len int64
|
||||||
|
OK bool
|
||||||
|
}{
|
||||||
|
Len: 0,
|
||||||
|
OK: false,
|
||||||
|
},
|
||||||
|
DecimalSize: struct {
|
||||||
|
Precision int64
|
||||||
|
Scale int64
|
||||||
|
OK bool
|
||||||
|
}{
|
||||||
|
Precision: 0,
|
||||||
|
Scale: 0,
|
||||||
|
OK: false,
|
||||||
|
},
|
||||||
|
ScanType: reflect.TypeOf(int32(0)),
|
||||||
|
}, {
|
||||||
|
Name: "bar",
|
||||||
|
TypeName: "TEXT",
|
||||||
|
Length: struct {
|
||||||
|
Len int64
|
||||||
|
OK bool
|
||||||
|
}{
|
||||||
|
Len: math.MaxInt64,
|
||||||
|
OK: true,
|
||||||
|
},
|
||||||
|
DecimalSize: struct {
|
||||||
|
Precision int64
|
||||||
|
Scale int64
|
||||||
|
OK bool
|
||||||
|
}{
|
||||||
|
Precision: 0,
|
||||||
|
Scale: 0,
|
||||||
|
OK: false,
|
||||||
|
},
|
||||||
|
ScanType: reflect.TypeOf(""),
|
||||||
|
}, {
|
||||||
|
Name: "dec",
|
||||||
|
TypeName: "NUMERIC",
|
||||||
|
Length: struct {
|
||||||
|
Len int64
|
||||||
|
OK bool
|
||||||
|
}{
|
||||||
|
Len: 0,
|
||||||
|
OK: false,
|
||||||
|
},
|
||||||
|
DecimalSize: struct {
|
||||||
|
Precision int64
|
||||||
|
Scale int64
|
||||||
|
OK bool
|
||||||
|
}{
|
||||||
|
Precision: 9,
|
||||||
|
Scale: 2,
|
||||||
|
OK: true,
|
||||||
|
},
|
||||||
|
ScanType: reflect.TypeOf(float64(0)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db := openDB(t)
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
columns, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(columns) != 3 {
|
||||||
|
t.Errorf("expected 3 columns found %d", len(columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range columnTypesTests {
|
||||||
|
c := columns[i]
|
||||||
|
if c.Name() != tt.Name {
|
||||||
|
t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
|
||||||
|
}
|
||||||
|
if c.DatabaseTypeName() != tt.TypeName {
|
||||||
|
t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
|
||||||
|
}
|
||||||
|
l, ok := c.Length()
|
||||||
|
if l != tt.Length.Len {
|
||||||
|
t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
|
||||||
|
}
|
||||||
|
if ok != tt.Length.OK {
|
||||||
|
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
|
||||||
|
}
|
||||||
|
p, s, ok := c.DecimalSize()
|
||||||
|
if p != tt.DecimalSize.Precision {
|
||||||
|
t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
|
||||||
|
}
|
||||||
|
if s != tt.DecimalSize.Scale {
|
||||||
|
t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
|
||||||
|
}
|
||||||
|
if ok != tt.DecimalSize.OK {
|
||||||
|
t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
|
||||||
|
}
|
||||||
|
if c.ScanType() != tt.ScanType {
|
||||||
|
t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue