From c6f3e03a6103ca6c4a524ad33b26e69585f89db2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 17:01:32 -0600 Subject: [PATCH] BoolCodec EncodePlan actually plans --- pgtype/bool.go | 103 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 34 deletions(-) diff --git a/pgtype/bool.go b/pgtype/bool.go index 21cd9889..3dd7efd3 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -8,7 +8,11 @@ import ( ) type BoolScanner interface { - ScanBool(v bool, valid bool) error + ScanBool(v Bool) error +} + +type BoolValuer interface { + BoolValue() (Bool, error) } type Bool struct { @@ -16,18 +20,15 @@ type Bool struct { Valid bool } -// ScanBool implements the BoolScanner interface. -func (dst *Bool) ScanBool(v bool, valid bool) error { - if !valid { - *dst = Bool{} - return nil - } - - *dst = Bool{Bool: v, Valid: true} - +func (b *Bool) ScanBool(v Bool) error { + *b = v return nil } +func (b Bool) BoolValue() (Bool, error) { + return b, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Bool) Scan(src interface{}) error { if src == nil { @@ -108,27 +109,28 @@ func (BoolCodec) PreferredFormat() int16 { func (BoolCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanBoolCodecBinary{} + switch value.(type) { + case bool: + return encodePlanBoolCodecBinaryBool{} + case BoolValuer: + return encodePlanBoolCodecBinaryBoolScanner{} + } case TextFormatCode: - return encodePlanBoolCodecText{} + switch value.(type) { + case bool: + return encodePlanBoolCodecTextBool{} + case BoolValuer: + return encodePlanBoolCodecTextBoolScanner{} + } } return nil } -type encodePlanBoolCodecBinary struct{} +type encodePlanBoolCodecBinaryBool struct{} -func (encodePlanBoolCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - v, valid, err := convertToBoolForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) - } - if !valid { - return nil, nil - } - if value == nil { - return nil, nil - } +func (encodePlanBoolCodecBinaryBool) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v := value.(bool) if v { buf = append(buf, 1) @@ -139,20 +141,53 @@ func (encodePlanBoolCodecBinary) Encode(value interface{}, buf []byte) (newBuf [ return buf, nil } -type encodePlanBoolCodecText struct{} +type encodePlanBoolCodecTextBoolScanner struct{} -func (encodePlanBoolCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - v, valid, err := convertToBoolForEncode(value) +func (encodePlanBoolCodecTextBoolScanner) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b, err := value.(BoolValuer).BoolValue() if err != nil { - return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) + return nil, err } - if !valid { + + if !b.Valid { return nil, nil } - if value == nil { + + if b.Bool { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + + return buf, nil +} + +type encodePlanBoolCodecBinaryBoolScanner struct{} + +func (encodePlanBoolCodecBinaryBoolScanner) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + b, err := value.(BoolValuer).BoolValue() + if err != nil { + return nil, err + } + + if !b.Valid { return nil, nil } + if b.Bool { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + + return buf, nil +} + +type encodePlanBoolCodecTextBool struct{} + +func (encodePlanBoolCodecTextBool) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v := value.(bool) + if v { buf = append(buf, 't') } else { @@ -288,14 +323,14 @@ func (scanPlanBinaryBoolToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode } if src == nil { - return s.ScanBool(false, false) + return s.ScanBool(Bool{}) } if len(src) != 1 { return fmt.Errorf("invalid length for bool: %v", len(src)) } - return s.ScanBool(src[0] == 1, true) + return s.ScanBool(Bool{Bool: src[0] == 1, Valid: true}) } type scanPlanTextAnyToBoolScanner struct{} @@ -307,12 +342,12 @@ func (scanPlanTextAnyToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode in } if src == nil { - return s.ScanBool(false, false) + return s.ScanBool(Bool{}) } if len(src) != 1 { return fmt.Errorf("invalid length for bool: %v", len(src)) } - return s.ScanBool(src[0] == 't', true) + return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true}) }