mirror of https://github.com/jackc/pgx.git
Initial rebuilt composite support
parent
dc77e7c2da
commit
f5c3eeb813
|
@ -0,0 +1,551 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
// CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite.
|
||||
type CompositeIndexGetter interface {
|
||||
// IsNull returns true if the value is SQL NULL.
|
||||
IsNull() bool
|
||||
|
||||
// Index returns the element at i.
|
||||
Index(i int) interface{}
|
||||
}
|
||||
|
||||
// CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite.
|
||||
type CompositeIndexScanner interface {
|
||||
// ScanNull sets the value to SQL NULL.
|
||||
ScanNull() error
|
||||
|
||||
// ScanIndex returns a value usable as a scan target for i.
|
||||
ScanIndex(i int) interface{}
|
||||
}
|
||||
|
||||
type CompositeCodecField struct {
|
||||
Name string
|
||||
DataType *DataType
|
||||
}
|
||||
|
||||
type CompositeCodec struct {
|
||||
Fields []CompositeCodecField
|
||||
}
|
||||
|
||||
func (c *CompositeCodec) FormatSupported(format int16) bool {
|
||||
for _, f := range c.Fields {
|
||||
if !f.DataType.Codec.FormatSupported(format) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *CompositeCodec) PreferredFormat() int16 {
|
||||
if c.FormatSupported(BinaryFormatCode) {
|
||||
return BinaryFormatCode
|
||||
}
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
func (c *CompositeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan {
|
||||
if _, ok := value.(CompositeIndexGetter); !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, ci: ci}
|
||||
case TextFormatCode:
|
||||
return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, ci: ci}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodePlanCompositeCodecCompositeIndexGetterToBinary struct {
|
||||
cc *CompositeCodec
|
||||
ci *ConnInfo
|
||||
}
|
||||
|
||||
func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
|
||||
getter := value.(CompositeIndexGetter)
|
||||
|
||||
if getter.IsNull() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
builder := NewCompositeBinaryBuilder(plan.ci, buf)
|
||||
for i, field := range plan.cc.Fields {
|
||||
builder.AppendValue(field.DataType.OID, getter.Index(i))
|
||||
}
|
||||
|
||||
return builder.Finish()
|
||||
}
|
||||
|
||||
type encodePlanCompositeCodecCompositeIndexGetterToText struct {
|
||||
cc *CompositeCodec
|
||||
ci *ConnInfo
|
||||
}
|
||||
|
||||
func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
|
||||
getter := value.(CompositeIndexGetter)
|
||||
|
||||
if getter.IsNull() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
b := NewCompositeTextBuilder(plan.ci, buf)
|
||||
for i, field := range plan.cc.Fields {
|
||||
b.AppendValue(field.DataType.OID, getter.Index(i))
|
||||
}
|
||||
|
||||
return b.Finish()
|
||||
}
|
||||
|
||||
func (c *CompositeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
|
||||
switch format {
|
||||
case BinaryFormatCode:
|
||||
switch target.(type) {
|
||||
case CompositeIndexScanner:
|
||||
return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, ci: ci}
|
||||
}
|
||||
case TextFormatCode:
|
||||
switch target.(type) {
|
||||
case CompositeIndexScanner:
|
||||
return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, ci: ci}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type scanPlanBinaryCompositeToCompositeIndexScanner struct {
|
||||
cc *CompositeCodec
|
||||
ci *ConnInfo
|
||||
}
|
||||
|
||||
func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error {
|
||||
targetScanner := (target).(CompositeIndexScanner)
|
||||
|
||||
if src == nil {
|
||||
return targetScanner.ScanNull()
|
||||
}
|
||||
|
||||
scanner := NewCompositeBinaryScanner(plan.ci, src)
|
||||
for i, field := range plan.cc.Fields {
|
||||
if scanner.Next() {
|
||||
fieldTarget := targetScanner.ScanIndex(i)
|
||||
if fieldTarget != nil {
|
||||
fieldPlan := plan.ci.PlanScan(field.DataType.OID, BinaryFormatCode, fieldTarget)
|
||||
if fieldPlan == nil {
|
||||
return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.DataType.OID)
|
||||
}
|
||||
|
||||
err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return errors.New("read past end of composite")
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type scanPlanTextCompositeToCompositeIndexScanner struct {
|
||||
cc *CompositeCodec
|
||||
ci *ConnInfo
|
||||
}
|
||||
|
||||
func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error {
|
||||
targetScanner := (target).(CompositeIndexScanner)
|
||||
|
||||
if src == nil {
|
||||
return targetScanner.ScanNull()
|
||||
}
|
||||
|
||||
scanner := NewCompositeTextScanner(plan.ci, src)
|
||||
for i, field := range plan.cc.Fields {
|
||||
if scanner.Next() {
|
||||
fieldTarget := targetScanner.ScanIndex(i)
|
||||
if fieldTarget != nil {
|
||||
fieldPlan := plan.ci.PlanScan(field.DataType.OID, TextFormatCode, fieldTarget)
|
||||
if fieldPlan == nil {
|
||||
return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.DataType.OID)
|
||||
}
|
||||
|
||||
err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return errors.New("read past end of composite")
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CompositeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// var n int64
|
||||
// err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n)
|
||||
// return n, err
|
||||
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// var n int16
|
||||
// err := c.PlanScan(ci, oid, format, &n, true).Scan(ci, oid, format, src, &n)
|
||||
// return n, err
|
||||
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
type CompositeBinaryScanner struct {
|
||||
ci *ConnInfo
|
||||
rp int
|
||||
src []byte
|
||||
|
||||
fieldCount int32
|
||||
fieldBytes []byte
|
||||
fieldOID uint32
|
||||
err error
|
||||
}
|
||||
|
||||
// NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
|
||||
func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner {
|
||||
rp := 0
|
||||
if len(src[rp:]) < 4 {
|
||||
return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
|
||||
}
|
||||
|
||||
fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
return &CompositeBinaryScanner{
|
||||
ci: ci,
|
||||
rp: rp,
|
||||
src: src,
|
||||
fieldCount: fieldCount,
|
||||
}
|
||||
}
|
||||
|
||||
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
|
||||
// Next returns false, the Err method can be called to check if any errors occurred.
|
||||
func (cfs *CompositeBinaryScanner) Next() bool {
|
||||
if cfs.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if cfs.rp == len(cfs.src) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(cfs.src[cfs.rp:]) < 8 {
|
||||
cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
|
||||
return false
|
||||
}
|
||||
cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
|
||||
cfs.rp += 4
|
||||
|
||||
fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
|
||||
cfs.rp += 4
|
||||
|
||||
if fieldLen >= 0 {
|
||||
if len(cfs.src[cfs.rp:]) < fieldLen {
|
||||
cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
|
||||
return false
|
||||
}
|
||||
cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
|
||||
cfs.rp += fieldLen
|
||||
} else {
|
||||
cfs.fieldBytes = nil
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (cfs *CompositeBinaryScanner) FieldCount() int {
|
||||
return int(cfs.fieldCount)
|
||||
}
|
||||
|
||||
// Bytes returns the bytes of the field most recently read by Scan().
|
||||
func (cfs *CompositeBinaryScanner) Bytes() []byte {
|
||||
return cfs.fieldBytes
|
||||
}
|
||||
|
||||
// OID returns the OID of the field most recently read by Scan().
|
||||
func (cfs *CompositeBinaryScanner) OID() uint32 {
|
||||
return cfs.fieldOID
|
||||
}
|
||||
|
||||
// Err returns any error encountered by the scanner.
|
||||
func (cfs *CompositeBinaryScanner) Err() error {
|
||||
return cfs.err
|
||||
}
|
||||
|
||||
type CompositeTextScanner struct {
|
||||
ci *ConnInfo
|
||||
rp int
|
||||
src []byte
|
||||
|
||||
fieldBytes []byte
|
||||
err error
|
||||
}
|
||||
|
||||
// NewCompositeTextScanner a scanner over a text encoded composite value.
|
||||
func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
|
||||
if len(src) < 2 {
|
||||
return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
|
||||
}
|
||||
|
||||
if src[0] != '(' {
|
||||
return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
|
||||
}
|
||||
|
||||
if src[len(src)-1] != ')' {
|
||||
return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
|
||||
}
|
||||
|
||||
return &CompositeTextScanner{
|
||||
ci: ci,
|
||||
rp: 1,
|
||||
src: src,
|
||||
}
|
||||
}
|
||||
|
||||
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
|
||||
// Next returns false, the Err method can be called to check if any errors occurred.
|
||||
func (cfs *CompositeTextScanner) Next() bool {
|
||||
if cfs.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if cfs.rp == len(cfs.src) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch cfs.src[cfs.rp] {
|
||||
case ',', ')': // null
|
||||
cfs.rp++
|
||||
cfs.fieldBytes = nil
|
||||
return true
|
||||
case '"': // quoted value
|
||||
cfs.rp++
|
||||
cfs.fieldBytes = make([]byte, 0, 16)
|
||||
for {
|
||||
ch := cfs.src[cfs.rp]
|
||||
|
||||
if ch == '"' {
|
||||
cfs.rp++
|
||||
if cfs.src[cfs.rp] == '"' {
|
||||
cfs.fieldBytes = append(cfs.fieldBytes, '"')
|
||||
cfs.rp++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else if ch == '\\' {
|
||||
cfs.rp++
|
||||
cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
|
||||
cfs.rp++
|
||||
} else {
|
||||
cfs.fieldBytes = append(cfs.fieldBytes, ch)
|
||||
cfs.rp++
|
||||
}
|
||||
}
|
||||
cfs.rp++
|
||||
return true
|
||||
default: // unquoted value
|
||||
start := cfs.rp
|
||||
for {
|
||||
ch := cfs.src[cfs.rp]
|
||||
if ch == ',' || ch == ')' {
|
||||
break
|
||||
}
|
||||
cfs.rp++
|
||||
}
|
||||
cfs.fieldBytes = cfs.src[start:cfs.rp]
|
||||
cfs.rp++
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Bytes returns the bytes of the field most recently read by Scan().
|
||||
func (cfs *CompositeTextScanner) Bytes() []byte {
|
||||
return cfs.fieldBytes
|
||||
}
|
||||
|
||||
// Err returns any error encountered by the scanner.
|
||||
func (cfs *CompositeTextScanner) Err() error {
|
||||
return cfs.err
|
||||
}
|
||||
|
||||
type CompositeBinaryBuilder struct {
|
||||
ci *ConnInfo
|
||||
buf []byte
|
||||
startIdx int
|
||||
fieldCount uint32
|
||||
err error
|
||||
}
|
||||
|
||||
func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder {
|
||||
startIdx := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields
|
||||
return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx}
|
||||
}
|
||||
|
||||
func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if field == nil {
|
||||
b.buf = pgio.AppendUint32(b.buf, oid)
|
||||
b.buf = pgio.AppendInt32(b.buf, -1)
|
||||
b.fieldCount++
|
||||
return
|
||||
}
|
||||
|
||||
plan := b.ci.PlanEncode(oid, BinaryFormatCode, field)
|
||||
if plan == nil {
|
||||
b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid)
|
||||
return
|
||||
}
|
||||
|
||||
b.buf = pgio.AppendUint32(b.buf, oid)
|
||||
lengthPos := len(b.buf)
|
||||
b.buf = pgio.AppendInt32(b.buf, -1)
|
||||
fieldBuf, err := plan.Encode(field, b.buf)
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf)))
|
||||
b.buf = fieldBuf
|
||||
}
|
||||
|
||||
b.fieldCount++
|
||||
}
|
||||
|
||||
func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
|
||||
if b.err != nil {
|
||||
return nil, b.err
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
type CompositeTextBuilder struct {
|
||||
ci *ConnInfo
|
||||
buf []byte
|
||||
startIdx int
|
||||
fieldCount uint32
|
||||
err error
|
||||
fieldBuf [32]byte
|
||||
}
|
||||
|
||||
func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder {
|
||||
buf = append(buf, '(') // allocate room for number of fields
|
||||
return &CompositeTextBuilder{ci: ci, buf: buf}
|
||||
}
|
||||
|
||||
func (b *CompositeTextBuilder) AppendValue(oid uint32, field interface{}) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if field == nil {
|
||||
b.buf = append(b.buf, ',')
|
||||
return
|
||||
}
|
||||
|
||||
plan := b.ci.PlanEncode(oid, TextFormatCode, field)
|
||||
if plan == nil {
|
||||
b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid)
|
||||
return
|
||||
}
|
||||
|
||||
fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0])
|
||||
if err != nil {
|
||||
b.err = err
|
||||
return
|
||||
}
|
||||
if fieldBuf != nil {
|
||||
b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
|
||||
}
|
||||
|
||||
b.buf = append(b.buf, ',')
|
||||
}
|
||||
|
||||
func (b *CompositeTextBuilder) Finish() ([]byte, error) {
|
||||
if b.err != nil {
|
||||
return nil, b.err
|
||||
}
|
||||
|
||||
b.buf[len(b.buf)-1] = ')'
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
|
||||
|
||||
func quoteCompositeField(src string) string {
|
||||
return `"` + quoteCompositeReplacer.Replace(src) + `"`
|
||||
}
|
||||
|
||||
func quoteCompositeFieldIfNeeded(src string) string {
|
||||
if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) {
|
||||
return quoteCompositeField(src)
|
||||
}
|
||||
return src
|
||||
}
|
||||
|
||||
// CompositeFields represents the values of a composite value. It can be used as an encoding source or as a scan target.
|
||||
// It cannot scan a NULL, but the composite fields can be NULL.
|
||||
type CompositeFields []interface{}
|
||||
|
||||
func (cf CompositeFields) SkipUnderlyingTypePlan() {}
|
||||
|
||||
func (cf CompositeFields) IsNull() bool {
|
||||
return cf == nil
|
||||
}
|
||||
|
||||
func (cf CompositeFields) Index(i int) interface{} {
|
||||
return cf[i]
|
||||
}
|
||||
|
||||
func (cf CompositeFields) ScanNull() error {
|
||||
return fmt.Errorf("cannot scan NULL into CompositeFields")
|
||||
}
|
||||
|
||||
func (cf CompositeFields) ScanIndex(i int) interface{} {
|
||||
return cf[i]
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgtype/testutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompositeCodecTranscode(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
_, err := conn.Exec(context.Background(), `drop type if exists ct_test;
|
||||
|
||||
create type ct_test as (
|
||||
a text,
|
||||
b int4
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(context.Background(), "drop type ct_test")
|
||||
|
||||
var oid uint32
|
||||
err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Exec(context.Background(), "drop type ct_test")
|
||||
|
||||
textDataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.TextOID)
|
||||
require.True(t, ok)
|
||||
|
||||
int4DataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.Int4OID)
|
||||
require.True(t, ok)
|
||||
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{
|
||||
Name: "ct_test",
|
||||
OID: oid,
|
||||
Codec: &pgtype.CompositeCodec{
|
||||
Fields: []pgtype.CompositeCodecField{
|
||||
{
|
||||
Name: "a",
|
||||
DataType: textDataType,
|
||||
},
|
||||
{
|
||||
Name: "b",
|
||||
DataType: int4DataType,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
formats := []struct {
|
||||
name string
|
||||
code int16
|
||||
}{
|
||||
{name: "TextFormat", code: pgx.TextFormatCode},
|
||||
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
var a string
|
||||
var b int32
|
||||
|
||||
err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QueryResultFormats{format.code},
|
||||
pgtype.CompositeFields{"hi", int32(42)},
|
||||
).Scan(
|
||||
pgtype.CompositeFields{&a, &b},
|
||||
)
|
||||
require.NoErrorf(t, err, "%v", format.name)
|
||||
require.EqualValuesf(t, "hi", a, "%v", format.name)
|
||||
require.EqualValuesf(t, 42, b, "%v", format.name)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue