mirror of https://github.com/jackc/pgx.git
444 lines
9.6 KiB
Go
444 lines
9.6 KiB
Go
package pgtype
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"github.com/jackc/pgx/v5/internal/pgio"
|
|
)
|
|
|
|
// MultirangeGetter is a type that can be converted into a PostgreSQL multirange.
//
// The multirange encode plans read a value through this interface: IsNull first, then Len, then Index for each
// position in order.
type MultirangeGetter interface {
	// IsNull returns true if the value is SQL NULL. A NULL multirange is encoded as a nil buffer.
	IsNull() bool

	// Len returns the number of elements in the multirange.
	Len() int

	// Index returns the element at i. The encode plans reject a nil element (or one that encodes as NULL) with an
	// error, because a multirange cannot contain NULL elements.
	Index(i int) any

	// IndexType returns a non-nil value of the type Index will return. MultirangeCodec.PlanEncode uses it to check
	// that an encode plan exists for the element type before committing to a multirange encode plan.
	IndexType() any
}
|
|
|
|
// MultirangeSetter is a type that can be set from a PostgreSQL multirange.
//
// The multirange scan plan writes into a value through this interface. It calls ScanNull for SQL NULL. Otherwise it
// calls SetLen once and then scans each element into the target returned by ScanIndex.
type MultirangeSetter interface {
	// ScanNull sets the value to SQL NULL.
	ScanNull() error

	// SetLen prepares the value such that ScanIndex can be called for each element. This will remove any existing
	// elements.
	SetLen(n int) error

	// ScanIndex returns a value usable as a scan target for i. SetLen must be called before ScanIndex.
	ScanIndex(i int) any

	// ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by
	// MultirangeCodec.PlanScan.
	ScanIndexType() any
}
|
|
|
|
// MultirangeCodec is a codec for any multirange type.
|
|
type MultirangeCodec struct {
|
|
ElementType *Type
|
|
}
|
|
|
|
func (c *MultirangeCodec) FormatSupported(format int16) bool {
|
|
return c.ElementType.Codec.FormatSupported(format)
|
|
}
|
|
|
|
func (c *MultirangeCodec) PreferredFormat() int16 {
|
|
return c.ElementType.Codec.PreferredFormat()
|
|
}
|
|
|
|
func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
|
multirangeValuer, ok := value.(MultirangeGetter)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
elementType := multirangeValuer.IndexType()
|
|
|
|
elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType)
|
|
if elementEncodePlan == nil {
|
|
return nil
|
|
}
|
|
|
|
switch format {
|
|
case BinaryFormatCode:
|
|
return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid}
|
|
case TextFormatCode:
|
|
return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type encodePlanMultirangeCodecText struct {
|
|
ac *MultirangeCodec
|
|
m *Map
|
|
oid uint32
|
|
}
|
|
|
|
func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
|
multirange := value.(MultirangeGetter)
|
|
|
|
if multirange.IsNull() {
|
|
return nil, nil
|
|
}
|
|
|
|
elementCount := multirange.Len()
|
|
|
|
buf = append(buf, '{')
|
|
|
|
var encodePlan EncodePlan
|
|
var lastElemType reflect.Type
|
|
inElemBuf := make([]byte, 0, 32)
|
|
for i := 0; i < elementCount; i++ {
|
|
if i > 0 {
|
|
buf = append(buf, ',')
|
|
}
|
|
|
|
elem := multirange.Index(i)
|
|
var elemBuf []byte
|
|
if elem != nil {
|
|
elemType := reflect.TypeOf(elem)
|
|
if lastElemType != elemType {
|
|
lastElemType = elemType
|
|
encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem)
|
|
if encodePlan == nil {
|
|
return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
|
|
}
|
|
}
|
|
elemBuf, err = encodePlan.Encode(elem, inElemBuf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if elemBuf == nil {
|
|
return nil, fmt.Errorf("multirange cannot contain NULL element")
|
|
} else {
|
|
buf = append(buf, elemBuf...)
|
|
}
|
|
}
|
|
|
|
buf = append(buf, '}')
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
type encodePlanMultirangeCodecBinary struct {
|
|
ac *MultirangeCodec
|
|
m *Map
|
|
oid uint32
|
|
}
|
|
|
|
func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
|
multirange := value.(MultirangeGetter)
|
|
|
|
if multirange.IsNull() {
|
|
return nil, nil
|
|
}
|
|
|
|
elementCount := multirange.Len()
|
|
|
|
buf = pgio.AppendInt32(buf, int32(elementCount))
|
|
|
|
var encodePlan EncodePlan
|
|
var lastElemType reflect.Type
|
|
for i := 0; i < elementCount; i++ {
|
|
sp := len(buf)
|
|
buf = pgio.AppendInt32(buf, -1)
|
|
|
|
elem := multirange.Index(i)
|
|
var elemBuf []byte
|
|
if elem != nil {
|
|
elemType := reflect.TypeOf(elem)
|
|
if lastElemType != elemType {
|
|
lastElemType = elemType
|
|
encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem)
|
|
if encodePlan == nil {
|
|
return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
|
|
}
|
|
}
|
|
elemBuf, err = encodePlan.Encode(elem, buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if elemBuf == nil {
|
|
return nil, fmt.Errorf("multirange cannot contain NULL element")
|
|
} else {
|
|
buf = elemBuf
|
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
|
}
|
|
}
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
|
multirangeScanner, ok := target.(MultirangeSetter)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
elementType := multirangeScanner.ScanIndexType()
|
|
|
|
elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType)
|
|
if _, ok := elementScanPlan.(*scanPlanFail); ok {
|
|
return nil
|
|
}
|
|
|
|
return &scanPlanMultirangeCodec{
|
|
multirangeCodec: c,
|
|
m: m,
|
|
oid: oid,
|
|
formatCode: format,
|
|
}
|
|
}
|
|
|
|
func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
|
|
rp := 0
|
|
|
|
elementCount := int(binary.BigEndian.Uint32(src[rp:]))
|
|
rp += 4
|
|
|
|
err := multirange.SetLen(elementCount)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if elementCount == 0 {
|
|
return nil
|
|
}
|
|
|
|
elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
|
|
if elementScanPlan == nil {
|
|
elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
|
|
}
|
|
|
|
for i := 0; i < elementCount; i++ {
|
|
elem := multirange.ScanIndex(i)
|
|
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
|
rp += 4
|
|
var elemSrc []byte
|
|
if elemLen >= 0 {
|
|
elemSrc = src[rp : rp+elemLen]
|
|
rp += elemLen
|
|
}
|
|
err = elementScanPlan.Scan(elemSrc, elem)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to scan multirange element %d: %w", i, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
|
|
elements, err := parseUntypedTextMultirange(src)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = multirange.SetLen(len(elements))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(elements) == 0 {
|
|
return nil
|
|
}
|
|
|
|
elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
|
|
if elementScanPlan == nil {
|
|
elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
|
|
}
|
|
|
|
for i, s := range elements {
|
|
elem := multirange.ScanIndex(i)
|
|
err = elementScanPlan.Scan([]byte(s), elem)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type scanPlanMultirangeCodec struct {
|
|
multirangeCodec *MultirangeCodec
|
|
m *Map
|
|
oid uint32
|
|
formatCode int16
|
|
elementScanPlan ScanPlan
|
|
}
|
|
|
|
func (spac *scanPlanMultirangeCodec) Scan(src []byte, dst any) error {
|
|
c := spac.multirangeCodec
|
|
m := spac.m
|
|
oid := spac.oid
|
|
formatCode := spac.formatCode
|
|
|
|
multirange := dst.(MultirangeSetter)
|
|
|
|
if src == nil {
|
|
return multirange.ScanNull()
|
|
}
|
|
|
|
switch formatCode {
|
|
case BinaryFormatCode:
|
|
return c.decodeBinary(m, oid, src, multirange)
|
|
case TextFormatCode:
|
|
return c.decodeText(m, oid, src, multirange)
|
|
default:
|
|
return fmt.Errorf("unknown format code %d", formatCode)
|
|
}
|
|
}
|
|
|
|
func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
|
if src == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
switch format {
|
|
case TextFormatCode:
|
|
return string(src), nil
|
|
case BinaryFormatCode:
|
|
buf := make([]byte, len(src))
|
|
copy(buf, src)
|
|
return buf, nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown format code %d", format)
|
|
}
|
|
}
|
|
|
|
func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
|
if src == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
var multirange Multirange[Range[any]]
|
|
err := m.PlanScan(oid, format, &multirange).Scan(src, &multirange)
|
|
return multirange, err
|
|
}
|
|
|
|
func parseUntypedTextMultirange(src []byte) ([]string, error) {
|
|
elements := make([]string, 0)
|
|
|
|
buf := bytes.NewBuffer(src)
|
|
|
|
skipWhitespace(buf)
|
|
|
|
r, _, err := buf.ReadRune()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid array: %w", err)
|
|
}
|
|
|
|
if r != '{' {
|
|
return nil, fmt.Errorf("invalid multirange, expected '{' got %v", r)
|
|
}
|
|
|
|
parseValueLoop:
|
|
for {
|
|
r, _, err = buf.ReadRune()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid multirange: %w", err)
|
|
}
|
|
|
|
switch r {
|
|
case ',': // skip range separator
|
|
case '}':
|
|
break parseValueLoop
|
|
default:
|
|
buf.UnreadRune()
|
|
value, err := parseRange(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid multirange value: %w", err)
|
|
}
|
|
elements = append(elements, value)
|
|
}
|
|
}
|
|
|
|
skipWhitespace(buf)
|
|
|
|
if buf.Len() > 0 {
|
|
return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
|
|
}
|
|
|
|
return elements, nil
|
|
|
|
}
|
|
|
|
// parseRange reads the text of one range from buf.
//
// It stops before the ',' that separates this range from the next one, or before the '}' that closes the multirange.
// That delimiter is left unread in buf.
//
// The first unquoted ',' separates the lower and upper bounds and is kept in the result. Following PostgreSQL's
// range I/O rules, delimiters inside double-quoted bound values are data. A backslash escapes the following rune
// both inside and outside quotes. Quotes and escapes are copied verbatim, so the element range parser still sees the
// original text. Any read error from buf (e.g. io.EOF on truncated input) is returned unchanged.
func parseRange(buf *bytes.Buffer) (string, error) {
	s := &bytes.Buffer{}

	boundSepRead := false
	inQuote := false
	for {
		r, _, err := buf.ReadRune()
		if err != nil {
			return "", err
		}

		switch {
		case r == '\\':
			// Copy the backslash, then fall through to copy the escaped rune. The escaped rune can neither close a
			// quote nor act as a delimiter.
			s.WriteRune(r)
			r, _, err = buf.ReadRune()
			if err != nil {
				return "", err
			}
		case r == '"':
			// Inside a quoted bound, a doubled quote ("") closes and immediately reopens the quote. That leaves the
			// quoting state unchanged, as required.
			inQuote = !inQuote
		case inQuote:
			// Commas and braces inside a quoted bound are data.
		case r == ',' && !boundSepRead:
			boundSepRead = true
		case r == ',' || r == '}':
			// UnreadRune cannot fail directly after a successful ReadRune.
			_ = buf.UnreadRune()
			return s.String(), nil
		}

		s.WriteRune(r)
	}
}
|
|
|
|
// Multirange is a generic multirange type.
|
|
//
|
|
// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to
|
|
// enforce the RangeScanner constraint.
|
|
type Multirange[T RangeValuer] []T
|
|
|
|
func (r Multirange[T]) IsNull() bool {
|
|
return r == nil
|
|
}
|
|
|
|
func (r Multirange[T]) Len() int {
|
|
return len(r)
|
|
}
|
|
|
|
func (r Multirange[T]) Index(i int) any {
|
|
return r[i]
|
|
}
|
|
|
|
func (r Multirange[T]) IndexType() any {
|
|
var zero T
|
|
return zero
|
|
}
|
|
|
|
func (r *Multirange[T]) ScanNull() error {
|
|
*r = nil
|
|
return nil
|
|
}
|
|
|
|
func (r *Multirange[T]) SetLen(n int) error {
|
|
*r = make([]T, n)
|
|
return nil
|
|
}
|
|
|
|
func (r Multirange[T]) ScanIndex(i int) any {
|
|
return &r[i]
|
|
}
|
|
|
|
func (r Multirange[T]) ScanIndexType() any {
|
|
return new(T)
|
|
}
|