mirror of https://github.com/jackc/pgx.git
380 lines
9.6 KiB
Go
380 lines
9.6 KiB
Go
package pgtype
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"fmt"
|
|
|
|
"github.com/jackc/pgx/v5/internal/pgio"
|
|
)
|
|
|
|
// RangeValuer is a type that can be converted into a PostgreSQL range.
|
|
type RangeValuer interface {
|
|
// IsNull returns true if the value is SQL NULL.
|
|
IsNull() bool
|
|
|
|
// BoundTypes returns the lower and upper bound types.
|
|
BoundTypes() (lower, upper BoundType)
|
|
|
|
// Bounds returns the lower and upper range values.
|
|
Bounds() (lower, upper any)
|
|
}
|
|
|
|
// RangeScanner is a type can be scanned from a PostgreSQL range.
|
|
type RangeScanner interface {
|
|
// ScanNull sets the value to SQL NULL.
|
|
ScanNull() error
|
|
|
|
// ScanBounds returns values usable as a scan target. The returned values may not be scanned if the range is empty or
|
|
// the bound type is unbounded.
|
|
ScanBounds() (lowerTarget, upperTarget any)
|
|
|
|
// SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned
|
|
// (if appropriate) before SetBoundTypes is called. If the bound types are unbounded or empty this method must
|
|
// also set the bound values.
|
|
SetBoundTypes(lower, upper BoundType) error
|
|
}
|
|
|
|
// RangeCodec is a codec for any range type.
|
|
type RangeCodec struct {
|
|
ElementType *Type
|
|
}
|
|
|
|
func (c *RangeCodec) FormatSupported(format int16) bool {
|
|
return c.ElementType.Codec.FormatSupported(format)
|
|
}
|
|
|
|
func (c *RangeCodec) PreferredFormat() int16 {
|
|
if c.FormatSupported(BinaryFormatCode) {
|
|
return BinaryFormatCode
|
|
}
|
|
return TextFormatCode
|
|
}
|
|
|
|
func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
|
|
if _, ok := value.(RangeValuer); !ok {
|
|
return nil
|
|
}
|
|
|
|
switch format {
|
|
case BinaryFormatCode:
|
|
return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m}
|
|
case TextFormatCode:
|
|
return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type encodePlanRangeCodecRangeValuerToBinary struct {
|
|
rc *RangeCodec
|
|
m *Map
|
|
}
|
|
|
|
func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
|
getter := value.(RangeValuer)
|
|
|
|
if getter.IsNull() {
|
|
return nil, nil
|
|
}
|
|
|
|
lowerType, upperType := getter.BoundTypes()
|
|
lower, upper := getter.Bounds()
|
|
|
|
var rangeType byte
|
|
switch lowerType {
|
|
case Inclusive:
|
|
rangeType |= lowerInclusiveMask
|
|
case Unbounded:
|
|
rangeType |= lowerUnboundedMask
|
|
case Exclusive:
|
|
case Empty:
|
|
return append(buf, emptyMask), nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown LowerType: %v", lowerType)
|
|
}
|
|
|
|
switch upperType {
|
|
case Inclusive:
|
|
rangeType |= upperInclusiveMask
|
|
case Unbounded:
|
|
rangeType |= upperUnboundedMask
|
|
case Exclusive:
|
|
default:
|
|
return nil, fmt.Errorf("unknown UpperType: %v", upperType)
|
|
}
|
|
|
|
buf = append(buf, rangeType)
|
|
|
|
if lowerType != Unbounded {
|
|
if lower == nil {
|
|
return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
|
|
}
|
|
|
|
sp := len(buf)
|
|
buf = pgio.AppendInt32(buf, -1)
|
|
|
|
lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, lower)
|
|
if lowerPlan == nil {
|
|
return nil, fmt.Errorf("cannot encode %v as element of range", lower)
|
|
}
|
|
|
|
buf, err = lowerPlan.Encode(lower, buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err)
|
|
}
|
|
if buf == nil {
|
|
return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
|
|
}
|
|
|
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
|
}
|
|
|
|
if upperType != Unbounded {
|
|
if upper == nil {
|
|
return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
|
|
}
|
|
|
|
sp := len(buf)
|
|
buf = pgio.AppendInt32(buf, -1)
|
|
|
|
upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, upper)
|
|
if upperPlan == nil {
|
|
return nil, fmt.Errorf("cannot encode %v as element of range", upper)
|
|
}
|
|
|
|
buf, err = upperPlan.Encode(upper, buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err)
|
|
}
|
|
if buf == nil {
|
|
return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
|
|
}
|
|
|
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
|
}
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
type encodePlanRangeCodecRangeValuerToText struct {
|
|
rc *RangeCodec
|
|
m *Map
|
|
}
|
|
|
|
func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
|
getter := value.(RangeValuer)
|
|
|
|
if getter.IsNull() {
|
|
return nil, nil
|
|
}
|
|
|
|
lowerType, upperType := getter.BoundTypes()
|
|
lower, upper := getter.Bounds()
|
|
|
|
switch lowerType {
|
|
case Exclusive, Unbounded:
|
|
buf = append(buf, '(')
|
|
case Inclusive:
|
|
buf = append(buf, '[')
|
|
case Empty:
|
|
return append(buf, "empty"...), nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown lower bound type %v", lowerType)
|
|
}
|
|
|
|
if lowerType != Unbounded {
|
|
if lower == nil {
|
|
return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
|
|
}
|
|
|
|
lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower)
|
|
if lowerPlan == nil {
|
|
return nil, fmt.Errorf("cannot encode %v as element of range", lower)
|
|
}
|
|
|
|
buf, err = lowerPlan.Encode(lower, buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err)
|
|
}
|
|
if buf == nil {
|
|
return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
|
|
}
|
|
}
|
|
|
|
buf = append(buf, ',')
|
|
|
|
if upperType != Unbounded {
|
|
if upper == nil {
|
|
return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
|
|
}
|
|
|
|
upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper)
|
|
if upperPlan == nil {
|
|
return nil, fmt.Errorf("cannot encode %v as element of range", upper)
|
|
}
|
|
|
|
buf, err = upperPlan.Encode(upper, buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err)
|
|
}
|
|
if buf == nil {
|
|
return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
|
|
}
|
|
}
|
|
|
|
switch upperType {
|
|
case Exclusive, Unbounded:
|
|
buf = append(buf, ')')
|
|
case Inclusive:
|
|
buf = append(buf, ']')
|
|
default:
|
|
return nil, fmt.Errorf("unknown upper bound type %v", upperType)
|
|
}
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
|
switch format {
|
|
case BinaryFormatCode:
|
|
switch target.(type) {
|
|
case RangeScanner:
|
|
return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m}
|
|
}
|
|
case TextFormatCode:
|
|
switch target.(type) {
|
|
case RangeScanner:
|
|
return &scanPlanTextRangeToRangeScanner{rc: c, m: m}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type scanPlanBinaryRangeToRangeScanner struct {
|
|
rc *RangeCodec
|
|
m *Map
|
|
}
|
|
|
|
func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error {
|
|
rangeScanner := (target).(RangeScanner)
|
|
|
|
if src == nil {
|
|
return rangeScanner.ScanNull()
|
|
}
|
|
|
|
ubr, err := parseUntypedBinaryRange(src)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if ubr.LowerType == Empty {
|
|
return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
|
|
}
|
|
|
|
lowerTarget, upperTarget := rangeScanner.ScanBounds()
|
|
|
|
if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive {
|
|
lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, lowerTarget)
|
|
if lowerPlan == nil {
|
|
return fmt.Errorf("cannot scan into %v from range element", lowerTarget)
|
|
}
|
|
|
|
err = lowerPlan.Scan(ubr.Lower, lowerTarget)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err)
|
|
}
|
|
}
|
|
|
|
if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive {
|
|
upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, upperTarget)
|
|
if upperPlan == nil {
|
|
return fmt.Errorf("cannot scan into %v from range element", upperTarget)
|
|
}
|
|
|
|
err = upperPlan.Scan(ubr.Upper, upperTarget)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err)
|
|
}
|
|
}
|
|
|
|
return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
|
|
}
|
|
|
|
type scanPlanTextRangeToRangeScanner struct {
|
|
rc *RangeCodec
|
|
m *Map
|
|
}
|
|
|
|
func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error {
|
|
rangeScanner := (target).(RangeScanner)
|
|
|
|
if src == nil {
|
|
return rangeScanner.ScanNull()
|
|
}
|
|
|
|
utr, err := parseUntypedTextRange(string(src))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if utr.LowerType == Empty {
|
|
return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
|
|
}
|
|
|
|
lowerTarget, upperTarget := rangeScanner.ScanBounds()
|
|
|
|
if utr.LowerType == Inclusive || utr.LowerType == Exclusive {
|
|
lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, lowerTarget)
|
|
if lowerPlan == nil {
|
|
return fmt.Errorf("cannot scan into %v from range element", lowerTarget)
|
|
}
|
|
|
|
err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err)
|
|
}
|
|
}
|
|
|
|
if utr.UpperType == Inclusive || utr.UpperType == Exclusive {
|
|
upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, upperTarget)
|
|
if upperPlan == nil {
|
|
return fmt.Errorf("cannot scan into %v from range element", upperTarget)
|
|
}
|
|
|
|
err = upperPlan.Scan([]byte(utr.Upper), upperTarget)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err)
|
|
}
|
|
}
|
|
|
|
return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
|
|
}
|
|
|
|
func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
|
|
if src == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
switch format {
|
|
case TextFormatCode:
|
|
return string(src), nil
|
|
case BinaryFormatCode:
|
|
buf := make([]byte, len(src))
|
|
copy(buf, src)
|
|
return buf, nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown format code %d", format)
|
|
}
|
|
}
|
|
|
|
func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
|
|
if src == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
var r Range[any]
|
|
err := c.PlanScan(m, oid, format, &r).Scan(src, &r)
|
|
return r, err
|
|
}
|