From f5c3eeb813aa682f1a8ae95582ca86b876e0e563 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Sat, 29 Jan 2022 15:43:18 -0600
Subject: [PATCH] Initial rebuilt composite support

---
 pgtype/composite.go      | 551 +++++++++++++++++++++++++++++++++++++++
 pgtype/composite_test.go |  76 ++++++
 2 files changed, 627 insertions(+)
 create mode 100644 pgtype/composite.go
 create mode 100644 pgtype/composite_test.go

diff --git a/pgtype/composite.go b/pgtype/composite.go
new file mode 100644
index 00000000..d21ab665
--- /dev/null
+++ b/pgtype/composite.go
@@ -0,0 +1,551 @@
+package pgtype
+
+import (
+	"database/sql/driver"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/jackc/pgio"
+)
+
+// CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite.
+type CompositeIndexGetter interface {
+	// IsNull returns true if the value is SQL NULL.
+	IsNull() bool
+
+	// Index returns the value of the field at zero-based position i. A nil result is encoded as a NULL field.
+	Index(i int) interface{}
+}
+
+// CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite.
+type CompositeIndexScanner interface {
+	// ScanNull sets the value to SQL NULL.
+	ScanNull() error
+
+	// ScanIndex returns a scan target for the field at zero-based position i. A nil result skips that field.
+	ScanIndex(i int) interface{}
+}
+
+type CompositeCodecField struct {
+	Name     string    // name of the composite field
+	DataType *DataType // supplies the OID and Codec used for the field's values
+}
+
+type CompositeCodec struct {
+	Fields []CompositeCodecField // fields of the composite; encode and scan visit them by index, in this order
+}
+
+// FormatSupported reports whether the codec of every field supports format.
+func (c *CompositeCodec) FormatSupported(format int16) bool {
+	for i := range c.Fields {
+		if codec := c.Fields[i].DataType.Codec; !codec.FormatSupported(format) {
+			return false
+		}
+	}
+	return true
+}
+
+func (c *CompositeCodec) PreferredFormat() int16 {
+	if !c.FormatSupported(BinaryFormatCode) {
+		return TextFormatCode // at least one field's codec does not support binary
+	}
+	return BinaryFormatCode
+}
+
+// PlanEncode returns a plan that encodes a CompositeIndexGetter in format, or nil if value or format is unsupported.
+func (c *CompositeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan {
+	_, isGetter := value.(CompositeIndexGetter)
+	switch {
+	case !isGetter:
+		return nil
+	case format == BinaryFormatCode:
+		return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, ci: ci}
+	case format == TextFormatCode:
+		return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, ci: ci}
+	default:
+		return nil
+	}
+}
+
+type encodePlanCompositeCodecCompositeIndexGetterToBinary struct {
+	cc *CompositeCodec // codec whose Fields supply the OID for each value
+	ci *ConnInfo       // handed to the builder, which plans each field's encoding
+}
+
+// Encode appends the binary form of value to buf. A NULL composite yields a nil buffer.
+func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
+	src := value.(CompositeIndexGetter)
+	if src.IsNull() {
+		return nil, nil
+	}
+
+	b := NewCompositeBinaryBuilder(plan.ci, buf)
+	for i := range plan.cc.Fields {
+		fieldOID := plan.cc.Fields[i].DataType.OID
+		b.AppendValue(fieldOID, src.Index(i))
+	}
+	return b.Finish()
+}
+
+type encodePlanCompositeCodecCompositeIndexGetterToText struct {
+	cc *CompositeCodec // codec whose Fields supply the OID for each value
+	ci *ConnInfo       // handed to the builder, which plans each field's encoding
+}
+
+// Encode appends the text form of value to buf. A NULL composite yields a nil buffer.
+func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
+	src := value.(CompositeIndexGetter)
+	if src.IsNull() {
+		return nil, nil
+	}
+
+	builder := NewCompositeTextBuilder(plan.ci, buf)
+	for i := range plan.cc.Fields {
+		fieldOID := plan.cc.Fields[i].DataType.OID
+		builder.AppendValue(fieldOID, src.Index(i))
+	}
+	return builder.Finish()
+}
+
+// PlanScan returns a plan that scans a composite in format into a CompositeIndexScanner, or nil when unsupported.
+func (c *CompositeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan {
+	if _, ok := target.(CompositeIndexScanner); !ok {
+		return nil
+	}
+
+	switch format {
+	case BinaryFormatCode:
+		return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, ci: ci}
+	case TextFormatCode:
+		return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, ci: ci}
+	}
+
+	// Neither wire format matched.
+	return nil
+}
+
+type scanPlanBinaryCompositeToCompositeIndexScanner struct {
+	cc *CompositeCodec // codec whose Fields determine how many values are read and their OIDs
+	ci *ConnInfo       // used to plan the scan of each individual field
+}
+
+// Scan decodes the binary composite in src into target, which must be a CompositeIndexScanner. A nil src is SQL NULL.
+func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error {
+	targetScanner := (target).(CompositeIndexScanner)
+
+	if src == nil {
+		return targetScanner.ScanNull()
+	}
+
+	scanner := NewCompositeBinaryScanner(plan.ci, src)
+	for i, field := range plan.cc.Fields {
+		if !scanner.Next() {
+			// Prefer the scanner's own error (e.g. truncated input) over the generic message.
+			if err := scanner.Err(); err != nil {
+				return err
+			}
+			return errors.New("read past end of composite")
+		}
+
+		fieldTarget := targetScanner.ScanIndex(i)
+		if fieldTarget == nil {
+			continue // no target for this field: skip it
+		}
+		fieldPlan := plan.ci.PlanScan(field.DataType.OID, BinaryFormatCode, fieldTarget)
+		if fieldPlan == nil {
+			return fmt.Errorf("unable to scan field %q (OID %d) in binary format into %T", field.Name, field.DataType.OID, fieldTarget)
+		}
+		if err := fieldPlan.Scan(scanner.Bytes(), fieldTarget); err != nil {
+			return err
+		}
+	}
+
+	return scanner.Err()
+}
+
+type scanPlanTextCompositeToCompositeIndexScanner struct {
+	cc *CompositeCodec // codec whose Fields determine how many values are read and their OIDs
+	ci *ConnInfo       // used to plan the scan of each individual field
+}
+
+// Scan decodes the text composite in src into target, which must be a CompositeIndexScanner. A nil src is SQL NULL.
+func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target interface{}) error {
+	targetScanner := (target).(CompositeIndexScanner)
+
+	if src == nil {
+		return targetScanner.ScanNull()
+	}
+
+	scanner := NewCompositeTextScanner(plan.ci, src)
+	for i, field := range plan.cc.Fields {
+		if !scanner.Next() {
+			// Prefer the scanner's own error (e.g. malformed input) over the generic message.
+			if err := scanner.Err(); err != nil {
+				return err
+			}
+			return errors.New("read past end of composite")
+		}
+
+		fieldTarget := targetScanner.ScanIndex(i)
+		if fieldTarget == nil {
+			continue // no target for this field: skip it
+		}
+		fieldPlan := plan.ci.PlanScan(field.DataType.OID, TextFormatCode, fieldTarget)
+		if fieldPlan == nil {
+			return fmt.Errorf("unable to scan field %q (OID %d) in text format into %T", field.Name, field.DataType.OID, fieldTarget)
+		}
+		if err := fieldPlan.Scan(scanner.Bytes(), fieldTarget); err != nil {
+			return err
+		}
+	}
+
+	return scanner.Err()
+}
+
+// DecodeDatabaseSQLValue returns nil for a NULL src. Decoding a non-NULL composite is not implemented yet.
+func (c *CompositeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) {
+	if src == nil {
+		return nil, nil
+	}
+
+	// TODO: decode src into a driver.Value. Until that exists every non-NULL
+	// composite is rejected with the error below.
+
+	return nil, fmt.Errorf("not implemented")
+}
+
+// DecodeValue returns nil for a NULL src. Decoding a non-NULL composite is not implemented yet.
+func (c *CompositeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) {
+	if src == nil {
+		return nil, nil
+	}
+
+	// TODO: decode src into a Go value. Until that exists every non-NULL
+	// composite is rejected with the error below.
+
+	return nil, fmt.Errorf("not implemented")
+}
+
+type CompositeBinaryScanner struct {
+	ci  *ConnInfo // as passed to NewCompositeBinaryScanner
+	rp  int       // read position: offset into src of the next unread byte
+	src []byte    // the complete binary encoded composite
+
+	fieldCount int32  // field count read from the first 4 bytes of src
+	fieldBytes []byte // bytes of the current field; nil when the field is NULL
+	fieldOID   uint32 // OID of the current field
+	err        error  // first error encountered, if any
+}
+
+// NewCompositeBinaryScanner returns a scanner over a binary encoded composite value. Bad input is reported by Err.
+func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner {
+	rp := 0
+	if len(src[rp:]) < 4 { // too short to hold the 4 byte field count
+		return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
+	}
+
+	fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) // big-endian field count
+	rp += 4
+
+	return &CompositeBinaryScanner{
+		ci:         ci,
+		rp:         rp,
+		src:        src,
+		fieldCount: fieldCount,
+	}
+}
+
+// Next moves the scanner to the following field and reports whether one was read. It returns false once the input is
+// exhausted or an error has occurred; call Err afterwards to tell the two cases apart.
+func (cfs *CompositeBinaryScanner) Next() bool {
+	if cfs.err != nil || cfs.rp == len(cfs.src) {
+		return false
+	}
+
+	// Every field starts with an 8 byte header: a 4 byte OID followed by a 4 byte signed length.
+	header := cfs.src[cfs.rp:]
+	if len(header) < 8 {
+		cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
+		return false
+	}
+
+	cfs.fieldOID = binary.BigEndian.Uint32(header)
+	size := int(int32(binary.BigEndian.Uint32(header[4:])))
+	cfs.rp += 8
+
+	if size < 0 {
+		// A negative length marks a NULL field, which carries no payload.
+		cfs.fieldBytes = nil
+		return true
+	}
+
+	payload := cfs.src[cfs.rp:]
+	if len(payload) < size {
+		cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
+		return false
+	}
+	cfs.fieldBytes = payload[:size]
+	cfs.rp += size
+
+	return true
+}
+
+func (cfs *CompositeBinaryScanner) FieldCount() int {
+	return int(cfs.fieldCount) // value read from the 4 byte header by NewCompositeBinaryScanner
+}
+
+// Bytes returns the bytes of the field most recently read by Next. It returns nil when that field is NULL.
+func (cfs *CompositeBinaryScanner) Bytes() []byte {
+	return cfs.fieldBytes
+}
+
+// OID returns the OID of the field most recently read by Next.
+func (cfs *CompositeBinaryScanner) OID() uint32 {
+	return cfs.fieldOID
+}
+
+// Err returns the first error encountered by the scanner, or nil if there was none.
+func (cfs *CompositeBinaryScanner) Err() error {
+	return cfs.err
+}
+
+type CompositeTextScanner struct {
+	ci  *ConnInfo // as passed to NewCompositeTextScanner
+	rp  int       // read position: offset into src of the next unread byte
+	src []byte    // the complete text encoded composite, including the surrounding parentheses
+
+	fieldBytes []byte // unescaped bytes of the current field; nil when the field is NULL
+	err        error  // first error encountered, if any
+}
+
+// NewCompositeTextScanner returns a scanner over a text encoded composite value. Bad input is reported by Err.
+func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
+	if len(src) < 2 { // the shortest accepted input is "()"
+		return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
+	}
+
+	if src[0] != '(' {
+		return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
+	}
+
+	if src[len(src)-1] != ')' {
+		return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
+	}
+
+	return &CompositeTextScanner{
+		ci:  ci,
+		rp:  1, // start just past the opening '('
+		src: src,
+	}
+}
+
+// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
+// Next returns false, the Err method can be called to check if any errors occurred, such as an unterminated quote.
+func (cfs *CompositeTextScanner) Next() bool {
+	if cfs.err != nil || cfs.rp == len(cfs.src) {
+		return false
+	}
+
+	switch cfs.src[cfs.rp] {
+	case ',', ')': // null
+		cfs.rp++
+		cfs.fieldBytes = nil
+		return true
+	case '"': // quoted value
+		cfs.rp++
+		cfs.fieldBytes = make([]byte, 0, 16)
+		for {
+			if cfs.rp == len(cfs.src) { // closing quote missing, or a backslash escaped the final ')'
+				cfs.err = errors.New("unterminated quoted field in composite text format")
+				return false
+			}
+			ch := cfs.src[cfs.rp]
+
+			if ch == '"' {
+				cfs.rp++ // in range: src ends in ')', so a byte always follows a '"'
+				if cfs.src[cfs.rp] == '"' {
+					cfs.fieldBytes = append(cfs.fieldBytes, '"')
+					cfs.rp++
+				} else {
+					break
+				}
+			} else if ch == '\\' {
+				cfs.rp++ // in range for the same reason: a backslash is never the final byte
+				cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
+				cfs.rp++
+			} else {
+				cfs.fieldBytes = append(cfs.fieldBytes, ch)
+				cfs.rp++
+			}
+		}
+		cfs.rp++ // skip the delimiter after the closing quote
+		return true
+	default: // unquoted value
+		start := cfs.rp
+		for { // always terminates: the constructor verified that src ends in ')'
+			ch := cfs.src[cfs.rp]
+			if ch == ',' || ch == ')' {
+				break
+			}
+			cfs.rp++
+		}
+		cfs.fieldBytes = cfs.src[start:cfs.rp]
+		cfs.rp++
+		return true
+	}
+}
+
+// Bytes returns the bytes of the field most recently read by Next. It returns nil when that field is NULL.
+func (cfs *CompositeTextScanner) Bytes() []byte {
+	return cfs.fieldBytes
+}
+
+// Err returns the first error encountered by the scanner, or nil if there was none.
+func (cfs *CompositeTextScanner) Err() error {
+	return cfs.err
+}
+
+type CompositeBinaryBuilder struct {
+	ci         *ConnInfo // used to plan the encoding of each appended value
+	buf        []byte    // output buffer, including any bytes present before the composite
+	startIdx   int       // offset in buf of the 4 byte field count reserved by the constructor
+	fieldCount uint32    // number of fields appended so far
+	err        error     // first error from AppendValue; returned by Finish
+}
+
+// NewCompositeBinaryBuilder returns a builder that appends a binary encoded composite to buf.
+func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder {
+	// Reserve 4 bytes at the current end of buf for the field count, which Finish fills in.
+	return &CompositeBinaryBuilder{ci: ci, buf: append(buf, 0, 0, 0, 0), startIdx: len(buf)}
+}
+
+// AppendValue encodes field as OID oid in binary format and appends it as the next field. A nil field, or a value whose
+// plan encodes to nil, is written as NULL. After the first failure the builder ignores further calls.
+func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) {
+	if b.err != nil {
+		return
+	}
+
+	if field == nil {
+		b.buf = pgio.AppendInt32(pgio.AppendUint32(b.buf, oid), -1)
+		b.fieldCount++
+		return
+	}
+
+	plan := b.ci.PlanEncode(oid, BinaryFormatCode, field)
+	if plan == nil {
+		b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid)
+		return
+	}
+
+	// The length is written as -1 first and patched once the payload size is known.
+	b.buf = pgio.AppendUint32(b.buf, oid)
+	lengthPos := len(b.buf)
+	b.buf = pgio.AppendInt32(b.buf, -1)
+	switch encoded, err := plan.Encode(field, b.buf); {
+	case err != nil:
+		b.err = err
+		return
+	case encoded != nil:
+		binary.BigEndian.PutUint32(encoded[lengthPos:], uint32(len(encoded)-len(b.buf)))
+		b.buf = encoded
+	}
+	b.fieldCount++
+}
+
+// Finish writes the field count into the reserved header and returns the buffer, or the first AppendValue error.
+func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
+	if err := b.err; err != nil {
+		return nil, err
+	}
+	binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
+	return b.buf, nil
+}
+
+type CompositeTextBuilder struct {
+	ci         *ConnInfo // used to plan the encoding of each appended value
+	buf        []byte    // output buffer, including any bytes present before the composite
+	startIdx   int       // NOTE(review): not set or read by the code visible here
+	fieldCount uint32    // NOTE(review): not set or read by the code visible here
+	err        error     // first error from AppendValue; returned by Finish
+	fieldBuf   [32]byte  // scratch space a field is encoded into before quoting
+}
+
+func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder {
+	buf = append(buf, '(') // open the composite; Finish writes the matching ')'
+	return &CompositeTextBuilder{ci: ci, buf: buf}
+}
+
+// AppendValue encodes field as OID oid in text format and appends it, followed by a ',' separator. A nil field, or a
+// value whose plan encodes to nil, is written as NULL, i.e. nothing precedes the separator.
+func (b *CompositeTextBuilder) AppendValue(oid uint32, field interface{}) {
+	if b.err != nil {
+		return
+	}
+
+	if field != nil {
+		plan := b.ci.PlanEncode(oid, TextFormatCode, field)
+		if plan == nil {
+			b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid)
+			return
+		}
+
+		// Encode into the builder's scratch space so the value can be quoted before it is appended.
+		encoded, err := plan.Encode(field, b.fieldBuf[:0])
+		if err != nil {
+			b.err = err
+			return
+		}
+		if encoded != nil {
+			b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(encoded))...)
+		}
+	}
+
+	b.buf = append(b.buf, ',')
+}
+
+// Finish overwrites the final byte (the ',' left by the last AppendValue) with ')' and returns the buffer, or the error.
+func (b *CompositeTextBuilder) Finish() ([]byte, error) {
+	if err := b.err; err != nil {
+		return nil, err
+	}
+	b.buf[len(b.buf)-1] = ')'
+	return b.buf, nil
+}
+
+var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) // prefixes each \ and " with a backslash
+
+func quoteCompositeField(src string) string {
+	return `"` + quoteCompositeReplacer.Replace(src) + `"` // escape first, then wrap in double quotes
+}
+
+func quoteCompositeFieldIfNeeded(src string) string {
+	if src != "" && !strings.HasPrefix(src, " ") && !strings.HasSuffix(src, " ") && !strings.ContainsAny(src, `(),"\`) {
+		return src // non-empty, no edge spaces, no special characters: returned unquoted
+	}
+	return quoteCompositeField(src)
+}
+
+// CompositeFields represents the values of a composite value. It can be used as an encoding source or as a scan target.
+// It cannot scan a NULL composite (ScanNull always fails), but the individual composite fields can be NULL.
+type CompositeFields []interface{}
+
+func (cf CompositeFields) SkipUnderlyingTypePlan() {} // NOTE(review): empty marker method; its consumer is not visible here
+
+func (cf CompositeFields) IsNull() bool {
+	return cf == nil // only a nil slice is NULL; an empty non-nil slice is not
+}
+
+func (cf CompositeFields) Index(i int) interface{} {
+	return cf[i] // panics if i is out of range
+}
+
+func (cf CompositeFields) ScanNull() error {
+	return fmt.Errorf("cannot scan NULL into CompositeFields") // always fails; see the type comment
+}
+
+func (cf CompositeFields) ScanIndex(i int) interface{} {
+	return cf[i] // the element itself is the scan target; panics if i is out of range
+}
diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go
new file mode 100644
index 00000000..ba91de80
--- /dev/null
+++ b/pgtype/composite_test.go
@@ -0,0 +1,76 @@
+package pgtype_test
+
+import (
+	"context"
+	"testing"
+
+	pgx "github.com/jackc/pgx/v5"
+	"github.com/jackc/pgx/v5/pgtype"
+	"github.com/jackc/pgx/v5/pgtype/testutil"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCompositeCodecTranscode(t *testing.T) {
+	conn := testutil.MustConnectPgx(t)
+	defer testutil.MustCloseContext(t, conn)
+
+	_, err := conn.Exec(context.Background(), `drop type if exists ct_test;
+
+create type ct_test as (
+	a text,
+  b int4
+);`)
+	require.NoError(t, err)
+	defer conn.Exec(context.Background(), "drop type ct_test")
+
+	var oid uint32
+	err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
+	require.NoError(t, err)
+
+	textDataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.TextOID)
+	require.True(t, ok)
+
+	int4DataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.Int4OID)
+	require.True(t, ok)
+
+	conn.ConnInfo().RegisterDataType(pgtype.DataType{
+		Name: "ct_test",
+		OID:  oid,
+		Codec: &pgtype.CompositeCodec{
+			Fields: []pgtype.CompositeCodecField{
+				{
+					Name:     "a",
+					DataType: textDataType,
+				},
+				{
+					Name:     "b",
+					DataType: int4DataType,
+				},
+			},
+		},
+	})
+
+	formats := []struct {
+		name string
+		code int16
+	}{
+		{name: "TextFormat", code: pgx.TextFormatCode},
+		{name: "BinaryFormat", code: pgx.BinaryFormatCode},
+	}
+
+	for _, format := range formats {
+		t.Run(format.name, func(t *testing.T) {
+			var a string
+			var b int32
+
+			err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QueryResultFormats{format.code},
+				pgtype.CompositeFields{"hi", int32(42)},
+			).Scan(
+				pgtype.CompositeFields{&a, &b},
+			)
+			require.NoError(t, err)
+			require.EqualValues(t, "hi", a)
+			require.EqualValues(t, 42, b)
+		})
+	}
+}