ValueTranscoder uses new interfaces

query-exec-mode
Jack Christensen 2021-12-04 12:45:20 -06:00
parent 8f454e4cd6
commit e22675d20b
4 changed files with 123 additions and 48 deletions

View File

@ -129,12 +129,44 @@ func (src *ArrayType) AssignTo(dst interface{}) error {
}
}
func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
func (ArrayType) BinaryFormatSupported() bool {
return true
}
func (ArrayType) TextFormatSupported() bool {
return true
}
func (ArrayType) PreferredFormat() int16 {
return TextFormatCode
}
func (dst *ArrayType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error {
if src == nil {
dst.setNil()
return nil
}
switch format {
case BinaryFormatCode:
return dst.DecodeBinary(ci, src)
case TextFormatCode:
return dst.DecodeText(ci, src)
}
return fmt.Errorf("unknown format code %d", format)
}
func (src ArrayType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) {
switch format {
case BinaryFormatCode:
return src.EncodeBinary(ci, buf)
case TextFormatCode:
return src.EncodeText(ci, buf)
}
return nil, fmt.Errorf("unknown format code %d", format)
}
func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
uta, err := ParseUntypedTextArray(string(src))
if err != nil {
return err
@ -151,7 +183,7 @@ func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
if s != "NULL" {
elemSrc = []byte(s)
}
err = elem.DecodeText(ci, elemSrc)
err = elem.DecodeResult(ci, dst.elementOID, TextFormatCode, elemSrc)
if err != nil {
return err
}
@ -168,11 +200,6 @@ func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
}
func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
dst.setNil()
return nil
}
var arrayHeader ArrayHeader
rp, err := arrayHeader.DecodeBinary(ci, src)
if err != nil {
@ -204,7 +231,7 @@ func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error {
elemSrc = src[rp : rp+elemLen]
rp += elemLen
}
err = elem.DecodeBinary(ci, elemSrc)
err = elem.DecodeResult(ci, dst.elementOID, BinaryFormatCode, elemSrc)
if err != nil {
return err
}
@ -253,7 +280,7 @@ func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
}
}
elemBuf, err := elem.EncodeText(ci, inElemBuf)
elemBuf, err := elem.EncodeParam(ci, src.elementOID, TextFormatCode, inElemBuf)
if err != nil {
return nil, err
}
@ -296,7 +323,7 @@ func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.elements[i].EncodeBinary(ci, buf)
elemBuf, err := src.elements[i].EncodeParam(ci, src.elementOID, BinaryFormatCode, buf)
if err != nil {
return nil, err
}

View File

@ -59,8 +59,8 @@ func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
b := NewCompositeTextBuilder(ci, buf)
for _, f := range cf {
if textEncoder, ok := f.(TextEncoder); ok {
b.AppendEncoder(textEncoder)
if paramEncoder, ok := f.(ParamEncoder); ok {
b.AppendEncoder(paramEncoder)
} else {
b.AppendValue(f)
}
@ -88,15 +88,15 @@ func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error)
return nil, fmt.Errorf("Unknown OID for %#v", f)
}
if binaryEncoder, ok := f.(BinaryEncoder); ok {
b.AppendEncoder(dt.OID, binaryEncoder)
if paramEncoder, ok := f.(ParamEncoder); ok {
b.AppendEncoder(dt.OID, paramEncoder)
} else {
err := dt.Value.Set(f)
if err != nil {
return nil, err
}
if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok {
b.AppendEncoder(dt.OID, binaryEncoder)
if paramEncoder, ok := dt.Value.(ParamEncoder); ok {
b.AppendEncoder(dt.OID, paramEncoder)
} else {
return nil, fmt.Errorf("Cannot encode binary format for %v", f)
}

View File

@ -91,9 +91,13 @@ func (ct *CompositeType) Fields() []CompositeTypeField {
return ct.fields
}
func (dst *CompositeType) setNil() {
dst.valid = false
}
func (dst *CompositeType) Set(src interface{}) error {
if src == nil {
dst.valid = false
dst.setNil()
return nil
}
@ -110,7 +114,7 @@ func (dst *CompositeType) Set(src interface{}) error {
dst.valid = true
case *[]interface{}:
if value == nil {
dst.valid = false
dst.setNil()
return nil
}
return dst.Set(*value)
@ -213,6 +217,56 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) {
return true, nil
}
func (ct *CompositeType) BinaryFormatSupported() bool {
for _, vt := range ct.valueTranscoders {
if !vt.BinaryFormatSupported() {
return false
}
}
return true
}
func (ct *CompositeType) TextFormatSupported() bool {
for _, vt := range ct.valueTranscoders {
if !vt.TextFormatSupported() {
return false
}
}
return true
}
func (ct *CompositeType) PreferredFormat() int16 {
if ct.BinaryFormatSupported() {
return BinaryFormatCode
}
return TextFormatCode
}
func (dst *CompositeType) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error {
if src == nil {
dst.setNil()
return nil
}
switch format {
case BinaryFormatCode:
return dst.DecodeBinary(ci, src)
case TextFormatCode:
return dst.DecodeText(ci, src)
}
return fmt.Errorf("unknown format code %d", format)
}
func (src CompositeType) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) {
switch format {
case BinaryFormatCode:
return src.EncodeBinary(ci, buf)
case TextFormatCode:
return src.EncodeText(ci, buf)
}
return nil, fmt.Errorf("unknown format code %d", format)
}
func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
if !src.valid {
return nil, nil
@ -231,11 +285,6 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte,
// and decoding fails if SQL value can't be assigned due to
// type mismatch
func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
if buf == nil {
dst.valid = false
return nil
}
scanner := NewCompositeBinaryScanner(ci, buf)
for _, f := range dst.valueTranscoders {
@ -252,11 +301,6 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
}
func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
if buf == nil {
dst.valid = false
return nil
}
scanner := NewCompositeTextScanner(ci, buf)
for _, f := range dst.valueTranscoders {
@ -315,13 +359,13 @@ func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner
}
// ScanDecoder calls Next and decodes the result with d.
func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) {
func (cfs *CompositeBinaryScanner) ScanDecoder(d ResultDecoder) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes)
cfs.err = d.DecodeResult(cfs.ci, 0, BinaryFormatCode, cfs.fieldBytes)
} else {
cfs.err = errors.New("read past end of composite")
}
@ -425,13 +469,13 @@ func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
}
// ScanDecoder calls Next and decodes the result with d.
func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) {
func (cfs *CompositeTextScanner) ScanDecoder(d ResultDecoder) {
if cfs.err != nil {
return
}
if cfs.Next() {
cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes)
cfs.err = d.DecodeResult(cfs.ci, 0, TextFormatCode, cfs.fieldBytes)
} else {
cfs.err = errors.New("read past end of composite")
}
@ -547,16 +591,16 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) {
return
}
binaryEncoder, ok := dt.Value.(BinaryEncoder)
paramEncoder, ok := dt.Value.(ParamEncoder)
if !ok {
b.err = fmt.Errorf("unable to encode binary for OID: %d", oid)
b.err = fmt.Errorf("unable to encode for OID: %d", oid)
return
}
b.AppendEncoder(oid, binaryEncoder)
b.AppendEncoder(oid, paramEncoder)
}
func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) {
func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field ParamEncoder) {
if b.err != nil {
return
}
@ -564,7 +608,7 @@ func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder)
b.buf = pgio.AppendUint32(b.buf, oid)
lengthPos := len(b.buf)
b.buf = pgio.AppendInt32(b.buf, -1)
fieldBuf, err := field.EncodeBinary(b.ci, b.buf)
fieldBuf, err := field.EncodeParam(b.ci, oid, BinaryFormatCode, b.buf)
if err != nil {
b.err = err
return
@ -622,21 +666,21 @@ func (b *CompositeTextBuilder) AppendValue(field interface{}) {
return
}
textEncoder, ok := dt.Value.(TextEncoder)
paramEncoder, ok := dt.Value.(ParamEncoder)
if !ok {
b.err = fmt.Errorf("unable to encode text for value: %v", field)
b.err = fmt.Errorf("unable to encode for value: %v", field)
return
}
b.AppendEncoder(textEncoder)
b.AppendEncoder(paramEncoder)
}
func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) {
func (b *CompositeTextBuilder) AppendEncoder(field ParamEncoder) {
if b.err != nil {
return
}
fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0])
fieldBuf, err := field.EncodeParam(b.ci, 0, TextFormatCode, b.fieldBuf[0:0])
if err != nil {
b.err = err
return

View File

@ -147,10 +147,9 @@ type TypeValue interface {
// ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces.
type ValueTranscoder interface {
Value
TextEncoder
BinaryEncoder
TextDecoder
BinaryDecoder
FormatSupport
ParamEncoder
ResultDecoder
}
type FormatSupport interface {
@ -160,12 +159,17 @@ type FormatSupport interface {
}
type ParamEncoder interface {
FormatSupport
// EncodeParam should append the encoded value of self to buf. If self is the
// SQL value NULL then append nothing and return (nil, nil). The caller of
// EncodeText is responsible for writing the correct NULL value or the
// length of the data written.
EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error)
}
type ResultDecoder interface {
FormatSupport
// DecodeResult decodes src into ResultDecoder. If src is nil then the
// original SQL value is NULL. ResultDecoder takes ownership of src. The
// caller MUST not use it again.
DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error
}