mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
Restore record support
This commit is contained in:
parent
ef7114a8ce
commit
11223497b3
@ -86,6 +86,7 @@ const (
|
||||
VarbitArrayOID = 1563
|
||||
NumericOID = 1700
|
||||
RecordOID = 2249
|
||||
RecordArrayOID = 2287
|
||||
UUIDOID = 2950
|
||||
UUIDArrayOID = 2951
|
||||
JSONBOID = 3802
|
||||
@ -211,7 +212,6 @@ func NewConnInfo() *ConnInfo {
|
||||
// ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
|
||||
// ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID})
|
||||
// ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID})
|
||||
// ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID})
|
||||
// ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID})
|
||||
// ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID})
|
||||
// ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID})
|
||||
@ -245,6 +245,7 @@ func NewConnInfo() *ConnInfo {
|
||||
ci.RegisterDataType(DataType{Name: "path", OID: PathOID, Codec: PathCodec{}})
|
||||
ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}})
|
||||
ci.RegisterDataType(DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}})
|
||||
ci.RegisterDataType(DataType{Name: "record", OID: RecordOID, Codec: RecordCodec{}})
|
||||
ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}})
|
||||
ci.RegisterDataType(DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
|
||||
ci.RegisterDataType(DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
|
||||
@ -285,6 +286,7 @@ func NewConnInfo() *ConnInfo {
|
||||
ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PathOID]}})
|
||||
ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PointOID]}})
|
||||
ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PolygonOID]}})
|
||||
ci.RegisterDataType(DataType{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[RecordOID]}})
|
||||
ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TextOID]}})
|
||||
ci.RegisterDataType(DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TIDOID]}})
|
||||
ci.RegisterDataType(DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimeOID]}})
|
||||
|
116
pgtype/record_codec.go
Normal file
116
pgtype/record_codec.go
Normal file
@ -0,0 +1,116 @@
|
||||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ArrayGetter is a type that can be converted into a PostgreSQL array.
|
||||
|
||||
// RecordCodec is a codec for the generic PostgreSQL record type such as is created with the "row" function. Record can
|
||||
// only decode the binary format. The text format output format from PostgreSQL does not include type information and
|
||||
// is therefore impossible to decode. Encoding is impossible because PostgreSQL does not support input of generic
|
||||
// records.
|
||||
type RecordCodec struct{}
|
||||
|
||||
func (RecordCodec) FormatSupported(format int16) bool {
|
||||
return format == BinaryFormatCode
|
||||
}
|
||||
|
||||
func (RecordCodec) PreferredFormat() int16 {
|
||||
return BinaryFormatCode
|
||||
}
|
||||
|
||||
func (RecordCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (RecordCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
|
||||
if format == BinaryFormatCode {
|
||||
switch target.(type) {
|
||||
case CompositeIndexScanner:
|
||||
return &scanPlanBinaryRecordToCompositeIndexScanner{ci: ci}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type scanPlanBinaryRecordToCompositeIndexScanner struct {
|
||||
ci *ConnInfo
|
||||
}
|
||||
|
||||
func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target interface{}) error {
|
||||
targetScanner := (target).(CompositeIndexScanner)
|
||||
|
||||
if src == nil {
|
||||
return targetScanner.ScanNull()
|
||||
}
|
||||
|
||||
scanner := NewCompositeBinaryScanner(plan.ci, src)
|
||||
for i := 0; scanner.Next(); i++ {
|
||||
fieldTarget := targetScanner.ScanIndex(i)
|
||||
if fieldTarget != nil {
|
||||
fieldPlan := plan.ci.PlanScan(scanner.OID(), BinaryFormatCode, fieldTarget)
|
||||
if fieldPlan == nil {
|
||||
return fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), fieldTarget)
|
||||
}
|
||||
|
||||
err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (RecordCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (RecordCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch format {
|
||||
case TextFormatCode:
|
||||
return string(src), nil
|
||||
case BinaryFormatCode:
|
||||
scanner := NewCompositeBinaryScanner(ci, src)
|
||||
values := make([]interface{}, scanner.FieldCount())
|
||||
for i := 0; scanner.Next(); i++ {
|
||||
var v interface{}
|
||||
fieldPlan := ci.PlanScan(scanner.OID(), BinaryFormatCode, &v)
|
||||
if fieldPlan == nil {
|
||||
return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v)
|
||||
}
|
||||
|
||||
err := fieldPlan.Scan(scanner.Bytes(), &v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values[i] = v
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return values, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown format code %d", format)
|
||||
}
|
||||
|
||||
}
|
72
pgtype/record_codec_test.go
Normal file
72
pgtype/record_codec_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgtype/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecordCodec(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
var a string
|
||||
var b int32
|
||||
err := conn.QueryRow(context.Background(), `select row('foo'::text, 42::int4)`).Scan(pgtype.CompositeFields{&a, &b})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "foo", a)
|
||||
require.Equal(t, int32(42), b)
|
||||
}
|
||||
|
||||
func TestRecordCodecDecodeValue(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
for _, tt := range []struct {
|
||||
sql string
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
sql: `select row()`,
|
||||
expected: []interface{}{},
|
||||
},
|
||||
{
|
||||
sql: `select row('foo'::text, 42::int4)`,
|
||||
expected: []interface{}{"foo", int32(42)},
|
||||
},
|
||||
{
|
||||
sql: `select row(100.0::float4, 1.09::float4)`,
|
||||
expected: []interface{}{float32(100), float32(1.09)},
|
||||
},
|
||||
{
|
||||
sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`,
|
||||
expected: []interface{}{"foo", []interface{}{int32(1), int32(2), nil, int32(4)}, int32(42)},
|
||||
},
|
||||
{
|
||||
sql: `select row(null)`,
|
||||
expected: []interface{}{nil},
|
||||
},
|
||||
{
|
||||
sql: `select null::record`,
|
||||
expected: nil,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.sql, func(t *testing.T) {
|
||||
rows, err := conn.Query(context.Background(), tt.sql)
|
||||
require.NoError(t, err)
|
||||
|
||||
for rows.Next() {
|
||||
values, err := rows.Values()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, values, 1)
|
||||
require.Equal(t, tt.expected, values[0])
|
||||
}
|
||||
|
||||
require.NoError(t, rows.Err())
|
||||
})
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user