Restore record support

This commit is contained in:
Jack Christensen 2022-01-31 20:42:12 -06:00
parent ef7114a8ce
commit 11223497b3
3 changed files with 191 additions and 1 deletions

View File

@ -86,6 +86,7 @@ const (
VarbitArrayOID = 1563
NumericOID = 1700
RecordOID = 2249
RecordArrayOID = 2287
UUIDOID = 2950
UUIDArrayOID = 2951
JSONBOID = 3802
@ -211,7 +212,6 @@ func NewConnInfo() *ConnInfo {
// ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
// ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID})
// ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID})
// ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID})
// ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID})
// ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID})
// ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID})
@ -245,6 +245,7 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Name: "path", OID: PathOID, Codec: PathCodec{}})
ci.RegisterDataType(DataType{Name: "point", OID: PointOID, Codec: PointCodec{}})
ci.RegisterDataType(DataType{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}})
ci.RegisterDataType(DataType{Name: "record", OID: RecordOID, Codec: RecordCodec{}})
ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}})
ci.RegisterDataType(DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
ci.RegisterDataType(DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
@ -285,6 +286,7 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PathOID]}})
ci.RegisterDataType(DataType{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PointOID]}})
ci.RegisterDataType(DataType{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[PolygonOID]}})
ci.RegisterDataType(DataType{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[RecordOID]}})
ci.RegisterDataType(DataType{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TextOID]}})
ci.RegisterDataType(DataType{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TIDOID]}})
ci.RegisterDataType(DataType{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[TimeOID]}})

116
pgtype/record_codec.go Normal file
View File

@ -0,0 +1,116 @@
package pgtype
import (
"database/sql/driver"
"fmt"
)
// ArrayGetter is a type that can be converted into a PostgreSQL array.
// RecordCodec is a codec for the generic PostgreSQL record type such as is created with the "row" function. Record can
// only decode the binary format. The text format output format from PostgreSQL does not include type information and
// is therefore impossible to decode. Encoding is impossible because PostgreSQL does not support input of generic
// records.
type RecordCodec struct{}
func (RecordCodec) FormatSupported(format int16) bool {
return format == BinaryFormatCode
}
func (RecordCodec) PreferredFormat() int16 {
return BinaryFormatCode
}
func (RecordCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan {
return nil
}
func (RecordCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
if format == BinaryFormatCode {
switch target.(type) {
case CompositeIndexScanner:
return &scanPlanBinaryRecordToCompositeIndexScanner{ci: ci}
}
}
return nil
}
type scanPlanBinaryRecordToCompositeIndexScanner struct {
ci *ConnInfo
}
func (plan *scanPlanBinaryRecordToCompositeIndexScanner) Scan(src []byte, target interface{}) error {
targetScanner := (target).(CompositeIndexScanner)
if src == nil {
return targetScanner.ScanNull()
}
scanner := NewCompositeBinaryScanner(plan.ci, src)
for i := 0; scanner.Next(); i++ {
fieldTarget := targetScanner.ScanIndex(i)
if fieldTarget != nil {
fieldPlan := plan.ci.PlanScan(scanner.OID(), BinaryFormatCode, fieldTarget)
if fieldPlan == nil {
return fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), fieldTarget)
}
err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
if err != nil {
return err
}
}
}
if err := scanner.Err(); err != nil {
return err
}
return nil
}
func (RecordCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
return nil, fmt.Errorf("not implemented")
}
func (RecordCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
if src == nil {
return nil, nil
}
switch format {
case TextFormatCode:
return string(src), nil
case BinaryFormatCode:
scanner := NewCompositeBinaryScanner(ci, src)
values := make([]interface{}, scanner.FieldCount())
for i := 0; scanner.Next(); i++ {
var v interface{}
fieldPlan := ci.PlanScan(scanner.OID(), BinaryFormatCode, &v)
if fieldPlan == nil {
return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v)
}
err := fieldPlan.Scan(scanner.Bytes(), &v)
if err != nil {
return nil, err
}
values[i] = v
}
if err := scanner.Err(); err != nil {
return nil, err
}
return values, nil
default:
return nil, fmt.Errorf("unknown format code %d", format)
}
}

View File

@ -0,0 +1,72 @@
package pgtype_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgtype/testutil"
"github.com/stretchr/testify/require"
)
func TestRecordCodec(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
var a string
var b int32
err := conn.QueryRow(context.Background(), `select row('foo'::text, 42::int4)`).Scan(pgtype.CompositeFields{&a, &b})
require.NoError(t, err)
require.Equal(t, "foo", a)
require.Equal(t, int32(42), b)
}
func TestRecordCodecDecodeValue(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
for _, tt := range []struct {
sql string
expected interface{}
}{
{
sql: `select row()`,
expected: []interface{}{},
},
{
sql: `select row('foo'::text, 42::int4)`,
expected: []interface{}{"foo", int32(42)},
},
{
sql: `select row(100.0::float4, 1.09::float4)`,
expected: []interface{}{float32(100), float32(1.09)},
},
{
sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`,
expected: []interface{}{"foo", []interface{}{int32(1), int32(2), nil, int32(4)}, int32(42)},
},
{
sql: `select row(null)`,
expected: []interface{}{nil},
},
{
sql: `select null::record`,
expected: nil,
},
} {
t.Run(tt.sql, func(t *testing.T) {
rows, err := conn.Query(context.Background(), tt.sql)
require.NoError(t, err)
for rows.Next() {
values, err := rows.Values()
require.NoError(t, err)
require.Len(t, values, 1)
require.Equal(t, tt.expected, values[0])
}
require.NoError(t, rows.Err())
})
}
}