mirror of https://github.com/jackc/pgx.git
603 lines
14 KiB
Go
603 lines
14 KiB
Go
package pgtype
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5/internal/pgio"
|
|
)
|
|
|
|
type (
	// CompositeIndexGetter is implemented by values that expose their fields by position so that they can be encoded
	// as a PostgreSQL composite.
	CompositeIndexGetter interface {
		// Index returns the value of the field at position i. The composite builders write a nil result as a NULL
		// field.
		Index(i int) any

		// IsNull reports whether the composite as a whole is SQL NULL.
		IsNull() bool
	}

	// CompositeIndexScanner is implemented by values that receive the fields of a PostgreSQL composite by position.
	CompositeIndexScanner interface {
		// ScanIndex returns the scan target for the field at position i. The composite scan plans skip a field whose
		// target is nil.
		ScanIndex(i int) any

		// ScanNull is called when the composite as a whole is SQL NULL.
		ScanNull() error
	}
)
|
|
|
|
type CompositeCodecField struct {
|
|
Name string
|
|
Type *Type
|
|
}
|
|
|
|
type CompositeCodec struct {
|
|
Fields []CompositeCodecField
|
|
}
|
|
|
|
func (c *CompositeCodec) FormatSupported(format int16) bool {
|
|
for _, f := range c.Fields {
|
|
if !f.Type.Codec.FormatSupported(format) {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (c *CompositeCodec) PreferredFormat() int16 {
|
|
if c.FormatSupported(BinaryFormatCode) {
|
|
return BinaryFormatCode
|
|
}
|
|
return TextFormatCode
|
|
}
|
|
|
|
func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
|
if _, ok := value.(CompositeIndexGetter); !ok {
|
|
return nil
|
|
}
|
|
|
|
switch format {
|
|
case BinaryFormatCode:
|
|
return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, m: m}
|
|
case TextFormatCode:
|
|
return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, m: m}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type encodePlanCompositeCodecCompositeIndexGetterToBinary struct {
|
|
cc *CompositeCodec
|
|
m *Map
|
|
}
|
|
|
|
func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
|
getter := value.(CompositeIndexGetter)
|
|
|
|
if getter.IsNull() {
|
|
return nil, nil
|
|
}
|
|
|
|
builder := NewCompositeBinaryBuilder(plan.m, buf)
|
|
for i, field := range plan.cc.Fields {
|
|
builder.AppendValue(field.Type.OID, getter.Index(i))
|
|
}
|
|
|
|
return builder.Finish()
|
|
}
|
|
|
|
type encodePlanCompositeCodecCompositeIndexGetterToText struct {
|
|
cc *CompositeCodec
|
|
m *Map
|
|
}
|
|
|
|
func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
|
getter := value.(CompositeIndexGetter)
|
|
|
|
if getter.IsNull() {
|
|
return nil, nil
|
|
}
|
|
|
|
b := NewCompositeTextBuilder(plan.m, buf)
|
|
for i, field := range plan.cc.Fields {
|
|
b.AppendValue(field.Type.OID, getter.Index(i))
|
|
}
|
|
|
|
return b.Finish()
|
|
}
|
|
|
|
func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
|
switch format {
|
|
case BinaryFormatCode:
|
|
switch target.(type) {
|
|
case CompositeIndexScanner:
|
|
return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m}
|
|
}
|
|
case TextFormatCode:
|
|
switch target.(type) {
|
|
case CompositeIndexScanner:
|
|
return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type scanPlanBinaryCompositeToCompositeIndexScanner struct {
|
|
cc *CompositeCodec
|
|
m *Map
|
|
}
|
|
|
|
func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
|
|
targetScanner := (target).(CompositeIndexScanner)
|
|
|
|
if src == nil {
|
|
return targetScanner.ScanNull()
|
|
}
|
|
|
|
scanner := NewCompositeBinaryScanner(plan.m, src)
|
|
for i, field := range plan.cc.Fields {
|
|
if scanner.Next() {
|
|
fieldTarget := targetScanner.ScanIndex(i)
|
|
if fieldTarget != nil {
|
|
fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget)
|
|
if fieldPlan == nil {
|
|
return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID)
|
|
}
|
|
|
|
err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
return errors.New("read past end of composite")
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type scanPlanTextCompositeToCompositeIndexScanner struct {
|
|
cc *CompositeCodec
|
|
m *Map
|
|
}
|
|
|
|
func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
|
|
targetScanner := (target).(CompositeIndexScanner)
|
|
|
|
if src == nil {
|
|
return targetScanner.ScanNull()
|
|
}
|
|
|
|
scanner := NewCompositeTextScanner(plan.m, src)
|
|
for i, field := range plan.cc.Fields {
|
|
if scanner.Next() {
|
|
fieldTarget := targetScanner.ScanIndex(i)
|
|
if fieldTarget != nil {
|
|
fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget)
|
|
if fieldPlan == nil {
|
|
return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID)
|
|
}
|
|
|
|
err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
return errors.New("read past end of composite")
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
|
if src == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
switch format {
|
|
case TextFormatCode:
|
|
return string(src), nil
|
|
case BinaryFormatCode:
|
|
buf := make([]byte, len(src))
|
|
copy(buf, src)
|
|
return buf, nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown format code %d", format)
|
|
}
|
|
}
|
|
|
|
func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
|
if src == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
switch format {
|
|
case TextFormatCode:
|
|
scanner := NewCompositeTextScanner(m, src)
|
|
values := make(map[string]any, len(c.Fields))
|
|
for i := 0; scanner.Next() && i < len(c.Fields); i++ {
|
|
var v any
|
|
fieldPlan := m.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v)
|
|
if fieldPlan == nil {
|
|
return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v)
|
|
}
|
|
|
|
err := fieldPlan.Scan(scanner.Bytes(), &v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
values[c.Fields[i].Name] = v
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return values, nil
|
|
case BinaryFormatCode:
|
|
scanner := NewCompositeBinaryScanner(m, src)
|
|
values := make(map[string]any, len(c.Fields))
|
|
for i := 0; scanner.Next() && i < len(c.Fields); i++ {
|
|
var v any
|
|
fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v)
|
|
if fieldPlan == nil {
|
|
return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v)
|
|
}
|
|
|
|
err := fieldPlan.Scan(scanner.Bytes(), &v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
values[c.Fields[i].Name] = v
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return values, nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown format code %d", format)
|
|
}
|
|
|
|
}
|
|
|
|
type CompositeBinaryScanner struct {
|
|
m *Map
|
|
rp int
|
|
src []byte
|
|
|
|
fieldCount int32
|
|
fieldBytes []byte
|
|
fieldOID uint32
|
|
err error
|
|
}
|
|
|
|
// NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
|
|
func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner {
|
|
rp := 0
|
|
if len(src[rp:]) < 4 {
|
|
return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
|
|
}
|
|
|
|
fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
|
|
rp += 4
|
|
|
|
return &CompositeBinaryScanner{
|
|
m: m,
|
|
rp: rp,
|
|
src: src,
|
|
fieldCount: fieldCount,
|
|
}
|
|
}
|
|
|
|
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
|
|
// Next returns false, the Err method can be called to check if any errors occurred.
|
|
func (cfs *CompositeBinaryScanner) Next() bool {
|
|
if cfs.err != nil {
|
|
return false
|
|
}
|
|
|
|
if cfs.rp == len(cfs.src) {
|
|
return false
|
|
}
|
|
|
|
if len(cfs.src[cfs.rp:]) < 8 {
|
|
cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
|
|
return false
|
|
}
|
|
cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
|
|
cfs.rp += 4
|
|
|
|
fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
|
|
cfs.rp += 4
|
|
|
|
if fieldLen >= 0 {
|
|
if len(cfs.src[cfs.rp:]) < fieldLen {
|
|
cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
|
|
return false
|
|
}
|
|
cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
|
|
cfs.rp += fieldLen
|
|
} else {
|
|
cfs.fieldBytes = nil
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (cfs *CompositeBinaryScanner) FieldCount() int {
|
|
return int(cfs.fieldCount)
|
|
}
|
|
|
|
// Bytes returns the bytes of the field most recently read by Scan().
|
|
func (cfs *CompositeBinaryScanner) Bytes() []byte {
|
|
return cfs.fieldBytes
|
|
}
|
|
|
|
// OID returns the OID of the field most recently read by Scan().
|
|
func (cfs *CompositeBinaryScanner) OID() uint32 {
|
|
return cfs.fieldOID
|
|
}
|
|
|
|
// Err returns any error encountered by the scanner.
|
|
func (cfs *CompositeBinaryScanner) Err() error {
|
|
return cfs.err
|
|
}
|
|
|
|
type CompositeTextScanner struct {
|
|
m *Map
|
|
rp int
|
|
src []byte
|
|
|
|
fieldBytes []byte
|
|
err error
|
|
}
|
|
|
|
// NewCompositeTextScanner a scanner over a text encoded composite value.
|
|
func NewCompositeTextScanner(m *Map, src []byte) *CompositeTextScanner {
|
|
if len(src) < 2 {
|
|
return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
|
|
}
|
|
|
|
if src[0] != '(' {
|
|
return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
|
|
}
|
|
|
|
if src[len(src)-1] != ')' {
|
|
return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
|
|
}
|
|
|
|
return &CompositeTextScanner{
|
|
m: m,
|
|
rp: 1,
|
|
src: src,
|
|
}
|
|
}
|
|
|
|
// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
|
|
// Next returns false, the Err method can be called to check if any errors occurred.
|
|
func (cfs *CompositeTextScanner) Next() bool {
|
|
if cfs.err != nil {
|
|
return false
|
|
}
|
|
|
|
if cfs.rp == len(cfs.src) {
|
|
return false
|
|
}
|
|
|
|
switch cfs.src[cfs.rp] {
|
|
case ',', ')': // null
|
|
cfs.rp++
|
|
cfs.fieldBytes = nil
|
|
return true
|
|
case '"': // quoted value
|
|
cfs.rp++
|
|
cfs.fieldBytes = make([]byte, 0, 16)
|
|
for {
|
|
ch := cfs.src[cfs.rp]
|
|
|
|
if ch == '"' {
|
|
cfs.rp++
|
|
if cfs.src[cfs.rp] == '"' {
|
|
cfs.fieldBytes = append(cfs.fieldBytes, '"')
|
|
cfs.rp++
|
|
} else {
|
|
break
|
|
}
|
|
} else if ch == '\\' {
|
|
cfs.rp++
|
|
cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
|
|
cfs.rp++
|
|
} else {
|
|
cfs.fieldBytes = append(cfs.fieldBytes, ch)
|
|
cfs.rp++
|
|
}
|
|
}
|
|
cfs.rp++
|
|
return true
|
|
default: // unquoted value
|
|
start := cfs.rp
|
|
for {
|
|
ch := cfs.src[cfs.rp]
|
|
if ch == ',' || ch == ')' {
|
|
break
|
|
}
|
|
cfs.rp++
|
|
}
|
|
cfs.fieldBytes = cfs.src[start:cfs.rp]
|
|
cfs.rp++
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Bytes returns the bytes of the field most recently read by Scan().
|
|
func (cfs *CompositeTextScanner) Bytes() []byte {
|
|
return cfs.fieldBytes
|
|
}
|
|
|
|
// Err returns any error encountered by the scanner.
|
|
func (cfs *CompositeTextScanner) Err() error {
|
|
return cfs.err
|
|
}
|
|
|
|
type CompositeBinaryBuilder struct {
|
|
m *Map
|
|
buf []byte
|
|
startIdx int
|
|
fieldCount uint32
|
|
err error
|
|
}
|
|
|
|
func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder {
|
|
startIdx := len(buf)
|
|
buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields
|
|
return &CompositeBinaryBuilder{m: m, buf: buf, startIdx: startIdx}
|
|
}
|
|
|
|
func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) {
|
|
if b.err != nil {
|
|
return
|
|
}
|
|
|
|
if field == nil {
|
|
b.buf = pgio.AppendUint32(b.buf, oid)
|
|
b.buf = pgio.AppendInt32(b.buf, -1)
|
|
b.fieldCount++
|
|
return
|
|
}
|
|
|
|
plan := b.m.PlanEncode(oid, BinaryFormatCode, field)
|
|
if plan == nil {
|
|
b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid)
|
|
return
|
|
}
|
|
|
|
b.buf = pgio.AppendUint32(b.buf, oid)
|
|
lengthPos := len(b.buf)
|
|
b.buf = pgio.AppendInt32(b.buf, -1)
|
|
fieldBuf, err := plan.Encode(field, b.buf)
|
|
if err != nil {
|
|
b.err = err
|
|
return
|
|
}
|
|
if fieldBuf != nil {
|
|
binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf)))
|
|
b.buf = fieldBuf
|
|
}
|
|
|
|
b.fieldCount++
|
|
}
|
|
|
|
func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
|
|
if b.err != nil {
|
|
return nil, b.err
|
|
}
|
|
|
|
binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
|
|
return b.buf, nil
|
|
}
|
|
|
|
type CompositeTextBuilder struct {
|
|
m *Map
|
|
buf []byte
|
|
startIdx int
|
|
fieldCount uint32
|
|
err error
|
|
fieldBuf [32]byte
|
|
}
|
|
|
|
func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder {
|
|
buf = append(buf, '(') // allocate room for number of fields
|
|
return &CompositeTextBuilder{m: m, buf: buf}
|
|
}
|
|
|
|
func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) {
|
|
if b.err != nil {
|
|
return
|
|
}
|
|
|
|
if field == nil {
|
|
b.buf = append(b.buf, ',')
|
|
return
|
|
}
|
|
|
|
plan := b.m.PlanEncode(oid, TextFormatCode, field)
|
|
if plan == nil {
|
|
b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid)
|
|
return
|
|
}
|
|
|
|
fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0])
|
|
if err != nil {
|
|
b.err = err
|
|
return
|
|
}
|
|
if fieldBuf != nil {
|
|
b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
|
|
}
|
|
|
|
b.buf = append(b.buf, ',')
|
|
}
|
|
|
|
func (b *CompositeTextBuilder) Finish() ([]byte, error) {
|
|
if b.err != nil {
|
|
return nil, b.err
|
|
}
|
|
|
|
b.buf[len(b.buf)-1] = ')'
|
|
return b.buf, nil
|
|
}
|
|
|
|
// quoteCompositeReplacer backslash-escapes the two characters that are special inside a double-quoted composite
// field: the backslash and the double quote.
var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// quoteCompositeField wraps src in double quotes, escaping any backslashes and double quotes it contains.
func quoteCompositeField(src string) string {
	escaped := quoteCompositeReplacer.Replace(src)
	return `"` + escaped + `"`
}

// quoteCompositeFieldIfNeeded returns src unchanged when it can be written as an unquoted composite field. Otherwise
// it returns src quoted. Quoting is needed when src:
//   - is empty (an empty unquoted field would read back as NULL);
//   - starts or ends with a space;
//   - contains any of ( ) , " \.
func quoteCompositeFieldIfNeeded(src string) string {
	needsQuotes := len(src) == 0 ||
		strings.HasPrefix(src, " ") ||
		strings.HasSuffix(src, " ") ||
		strings.ContainsAny(src, `(),"\`)
	if !needsQuotes {
		return src
	}
	return quoteCompositeField(src)
}
|
|
|
|
// CompositeFields holds the field values of a composite and can serve both as an encoding source and as a scan
// target. Individual fields may be NULL (nil), but the composite as a whole cannot be scanned from NULL.
//
// When scanning, each element should be a pointer (or other scan target) for the corresponding field. Indexing is
// unchecked, so the slice must have at least as many elements as the composite has fields.
type CompositeFields []any

// SkipUnderlyingTypePlan is a marker method. It has an empty body, and CompositeFields implements it so it can be
// recognized as such a marker. NOTE(review): presumably it stops planning through the []any underlying type —
// confirm against the Map planner.
func (cf CompositeFields) SkipUnderlyingTypePlan() {}

// IsNull reports whether cf is a nil slice. A non-nil empty slice is not NULL.
func (cf CompositeFields) IsNull() bool { return cf == nil }

// Index returns the value of field i for encoding.
func (cf CompositeFields) Index(i int) any { return cf[i] }

// ScanNull always fails, because CompositeFields cannot represent a NULL composite.
func (cf CompositeFields) ScanNull() error {
	return fmt.Errorf("cannot scan NULL into CompositeFields")
}

// ScanIndex returns the scan target stored for field i.
func (cf CompositeFields) ScanIndex(i int) any { return cf[i] }
|