From ad79dccd99a2506e049372314bc5b47f4a2bbc6c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 23:44:53 -0600 Subject: [PATCH] Builtin types are automatically wrapped if necessary --- pgtype/array_codec.go | 52 +++++-- pgtype/builtin_wrappers.go | 291 +++++++++++++++++++++++++++++++++++++ pgtype/convert.go | 136 ----------------- pgtype/int.go | 252 +++++++++++++++++++++----------- pgtype/int.go.erb | 84 +++++++---- pgtype/pgtype.go | 176 ++++++++++++++++++++++ 6 files changed, 725 insertions(+), 266 deletions(-) create mode 100644 pgtype/builtin_wrappers.go diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index e8c2b2ed..1e506a43 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/binary" "fmt" + "reflect" "github.com/jackc/pgio" ) @@ -88,6 +89,8 @@ func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf dimElemCounts[i] = int(dimensions[i].Length) * dimElemCounts[i+1] } + var encodePlan EncodePlan + var lastElemType reflect.Type inElemBuf := make([]byte, 0, 32) for i := 0; i < elementCount; i++ { if i > 0 { @@ -100,14 +103,23 @@ func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf } } - encodePlan := p.ac.ElementCodec.PlanEncode(p.ci, p.ac.ElementOID, TextFormatCode, array.Index(i)) - if encodePlan == nil { - return nil, fmt.Errorf("unable to encode %v", array.Index(i)) - } - elemBuf, err := encodePlan.Encode(array.Index(i), inElemBuf) - if err != nil { - return nil, err + elem := array.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.ci.PlanEncode(p.ac.ElementOID, TextFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, inElemBuf) + if err != nil { + return nil, err + } } + if elemBuf == nil { buf = append(buf, `NULL`...) } else { @@ -151,18 +163,30 @@ func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newB buf = arrayHeader.EncodeBinary(p.ci, buf) elementCount := cardinality(dimensions) + + var encodePlan EncodePlan + var lastElemType reflect.Type for i := 0; i < elementCount; i++ { sp := len(buf) buf = pgio.AppendInt32(buf, -1) - encodePlan := p.ac.ElementCodec.PlanEncode(p.ci, p.ac.ElementOID, BinaryFormatCode, array.Index(i)) - if encodePlan == nil { - return nil, fmt.Errorf("unable to encode %v", array.Index(i)) - } - elemBuf, err := encodePlan.Encode(array.Index(i), buf) - if err != nil { - return nil, err + elem := array.Index(i) + var elemBuf []byte + if elem != nil { + elemType := reflect.TypeOf(elem) + if lastElemType != elemType { + lastElemType = elemType + encodePlan = p.ci.PlanEncode(p.ac.ElementOID, BinaryFormatCode, elem) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + } + elemBuf, err = encodePlan.Encode(elem, buf) + if err != nil { + return nil, err + } } + if elemBuf == nil { pgio.SetInt32(buf[containsNullIndex:], 1) } else { diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go new file mode 100644 index 00000000..17fe4535 --- /dev/null +++ b/pgtype/builtin_wrappers.go @@ -0,0 +1,291 @@ +package pgtype + +import ( + "fmt" + "math" + "strconv" +) + +type int8Wrapper int8 + +func (n *int8Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int8") + } + + if v.Int < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", v.Int) + } + if v.Int > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", v.Int) + } + *n = int8Wrapper(v.Int) + + return nil +} + +func (n int8Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type int16Wrapper int16 + +func (n *int16Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int16") + } + + if v.Int < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", v.Int) + } + if v.Int > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", v.Int) + } + *n = int16Wrapper(v.Int) + + return nil +} + +func (n int16Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type int32Wrapper int32 + +func (n *int32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int32") + } + + if v.Int < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", v.Int) + } + if v.Int > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", v.Int) + } + *n = int32Wrapper(v.Int) + + return nil +} + +func (n int32Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type int64Wrapper int64 + +func (n *int64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int64") + } + + *n = int64Wrapper(v.Int) + + return nil +} + +func (n int64Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type intWrapper int + +func (n *intWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *int") + } + + if v.Int < math.MinInt { + return fmt.Errorf("%d is less than minimum value for int", v.Int) + } + if v.Int > math.MaxInt { + return fmt.Errorf("%d is greater than maximum value for int", v.Int) + } + + *n = intWrapper(v.Int) + + return nil +} + +func (n intWrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type uint8Wrapper uint8 + +func (n *uint8Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint8") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint8", v.Int) + } + if v.Int > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", v.Int) + } + *n = uint8Wrapper(v.Int) + + return nil +} + +func (n uint8Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type uint16Wrapper uint16 + +func (n *uint16Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint16") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint16", v.Int) + } + if v.Int > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", v.Int) + } + *n = uint16Wrapper(v.Int) + + return nil +} + +func (n uint16Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type uint32Wrapper uint32 + +func (n *uint32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint32") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint32", v.Int) + } + if v.Int > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", v.Int) + } + *n = uint32Wrapper(v.Int) + + return nil +} + +func (n uint32Wrapper) Int64Value() (Int8, error) { + return Int8{Int: int64(n), Valid: true}, nil +} + +type uint64Wrapper uint64 + +func (n *uint64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int) + } + + *n = uint64Wrapper(v.Int) + + return nil +} + +func (n uint64Wrapper) Int64Value() (Int8, error) { + if uint64(n) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", n) + } + + return Int8{Int: int64(n), Valid: true}, nil +} + +type uintWrapper uint + +func (n *uintWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *uint64") + } + + if v.Int < 0 { + return fmt.Errorf("%d is less than minimum value for uint64", v.Int) + } + + if uint64(v.Int) > math.MaxUint { + return fmt.Errorf("%d is greater than maximum value for uint", v.Int) + } + + *n = uintWrapper(v.Int) + + return nil +} + +func (n uintWrapper) Int64Value() (Int8, error) { + if uint64(n) > uint64(math.MaxInt64) { + return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", n) + } + + return Int8{Int: int64(n), Valid: true}, nil +} + +type float32Wrapper float32 + +func (n *float32Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float32") + } + + *n = float32Wrapper(v.Int) + + return nil +} + +func (n float32Wrapper) Int64Value() (Int8, error) { + if n > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", n) + } + + return Int8{Int: int64(n), Valid: true}, nil +} + +type float64Wrapper float64 + +func (n *float64Wrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *float64") + } + + *n = float64Wrapper(v.Int) + + return nil +} + +func (n float64Wrapper) Int64Value() (Int8, error) { + if n > math.MaxInt64 { + return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", n) + } + + return Int8{Int: int64(n), Valid: true}, nil +} + +type stringWrapper string + +func (s *stringWrapper) ScanInt64(v Int8) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *string") + } + + *s = stringWrapper(strconv.FormatInt(v.Int, 10)) + + return nil +} + +func (s stringWrapper) Int64Value() (Int8, error) { + num, err := strconv.ParseInt(string(s), 10, 64) + if err != nil { + return Int8{}, err + } + + return Int8{Int: int64(num), Valid: true}, nil +} diff --git a/pgtype/convert.go b/pgtype/convert.go index ee5ba393..21e208f5 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -5,7 +5,6 @@ import ( "fmt" "math" "reflect" - "strconv" "time" ) @@ -453,141 +452,6 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { return nil, false } -func convertToInt64ForEncode(v interface{}) (n int64, valid bool, err error) { - if v == nil { - return 0, false, nil - } - - switch v := v.(type) { - case int8: - return int64(v), true, nil - case uint8: - return int64(v), true, nil - case int16: - return int64(v), true, nil - case uint16: - return int64(v), true, nil - case int32: - return int64(v), true, nil - case uint32: - return int64(v), true, nil - case int64: - return int64(v), true, nil - case uint64: - if v > math.MaxInt64 { - return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v) - } - return int64(v), true, nil - case int: - return int64(v), true, nil - case uint: - if v > math.MaxInt64 { - return 0, false, fmt.Errorf("%d is greater than maximum value for int64", v) - } - return int64(v), true, nil - case string: - num, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return 0, false, err - } - return num, true, nil - case float32: - if v > math.MaxInt64 { - return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v) - } - return int64(v), true, nil - case float64: - if v > math.MaxInt64 { - return 0, false, fmt.Errorf("%f is greater than maximum value for int64", v) - } - return int64(v), true, nil - case *int8: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint8: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *int16: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint16: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *int32: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint32: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *int64: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint64: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *int: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *uint: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *string: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *float32: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - case *float64: - if v == nil { - return 0, false, nil - } else { - return convertToInt64ForEncode(*v) - } - - default: - if originalvalue, ok := underlyingNumberType(v); ok { - return convertToInt64ForEncode(originalvalue) - } - return 0, false, fmt.Errorf("cannot convert %v to int64", v) - } -} - func init() { kindTypes = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), diff --git a/pgtype/int.go b/pgtype/int.go index 54898420..553d4dd0 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -130,54 +130,80 @@ func (Int2Codec) PreferredFormat() int16 { func (Int2Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanInt2CodecBinary{} + switch value.(type) { + case int16: + return encodePlanInt2CodecBinaryInt16{} + case Int64Valuer: + return encodePlanInt2CodecBinaryInt64Valuer{} + } case TextFormatCode: - return encodePlanInt2CodecText{} + switch value.(type) { + case int16: + return encodePlanInt2CodecTextInt16{} + case Int64Valuer: + return encodePlanInt2CodecTextInt64Valuer{} + } } return nil } -type encodePlanInt2CodecBinary struct{} - -func (encodePlanInt2CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) - } - if !valid { - return nil, nil - } - - if n > math.MaxInt16 { - return nil, fmt.Errorf("%d is greater than maximum value for int2", n) - } - if n < math.MinInt16 { - return nil, fmt.Errorf("%d is less than minimum value for int2", n) - } +type encodePlanInt2CodecBinaryInt16 struct{} +func (encodePlanInt2CodecBinaryInt16) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int16) return pgio.AppendInt16(buf, int16(n)), nil } -type encodePlanInt2CodecText struct{} +type encodePlanInt2CodecTextInt16 struct{} -func (encodePlanInt2CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) +func (encodePlanInt2CodecTextInt16) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int16) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt2CodecBinaryInt64Valuer struct{} + +func (encodePlanInt2CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) + return nil, err } - if !valid { + + if !n.Valid { return nil, nil } - if n > math.MaxInt16 { - return nil, fmt.Errorf("%d is greater than maximum value for int2", n) + if n.Int > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int) } - if n < math.MinInt16 { - return nil, fmt.Errorf("%d is less than minimum value for int2", n) + if n.Int < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int) } - return append(buf, strconv.FormatInt(n, 10)...), nil + return pgio.AppendInt16(buf, int16(n.Int)), nil +} + +type encodePlanInt2CodecTextInt64Valuer struct{} + +func (encodePlanInt2CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n.Int) + } + if n.Int < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n.Int) + } + + return append(buf, strconv.FormatInt(n.Int, 10)...), nil } func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { @@ -641,54 +667,80 @@ func (Int4Codec) PreferredFormat() int16 { func (Int4Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanInt4CodecBinary{} + switch value.(type) { + case int32: + return encodePlanInt4CodecBinaryInt32{} + case Int64Valuer: + return encodePlanInt4CodecBinaryInt64Valuer{} + } case TextFormatCode: - return encodePlanInt4CodecText{} + switch value.(type) { + case int32: + return encodePlanInt4CodecTextInt32{} + case Int64Valuer: + return encodePlanInt4CodecTextInt64Valuer{} + } } return nil } -type encodePlanInt4CodecBinary struct{} - -func (encodePlanInt4CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) - } - if !valid { - return nil, nil - } - - if n > math.MaxInt32 { - return nil, fmt.Errorf("%d is greater than maximum value for int4", n) - } - if n < math.MinInt32 { - return nil, fmt.Errorf("%d is less than minimum value for int4", n) - } +type encodePlanInt4CodecBinaryInt32 struct{} +func (encodePlanInt4CodecBinaryInt32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int32) return pgio.AppendInt32(buf, int32(n)), nil } -type encodePlanInt4CodecText struct{} +type encodePlanInt4CodecTextInt32 struct{} -func (encodePlanInt4CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) +func (encodePlanInt4CodecTextInt32) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int32) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt4CodecBinaryInt64Valuer struct{} + +func (encodePlanInt4CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) + return nil, err } - if !valid { + + if !n.Valid { return nil, nil } - if n > math.MaxInt32 { - return nil, fmt.Errorf("%d is greater than maximum value for int4", n) + if n.Int > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int) } - if n < math.MinInt32 { - return nil, fmt.Errorf("%d is less than minimum value for int4", n) + if n.Int < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int) } - return append(buf, strconv.FormatInt(n, 10)...), nil + return pgio.AppendInt32(buf, int32(n.Int)), nil +} + +type encodePlanInt4CodecTextInt64Valuer struct{} + +func (encodePlanInt4CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n.Int) + } + if n.Int < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n.Int) + } + + return append(buf, strconv.FormatInt(n.Int, 10)...), nil } func (Int4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { @@ -1163,54 +1215,80 @@ func (Int8Codec) PreferredFormat() int16 { func (Int8Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanInt8CodecBinary{} + switch value.(type) { + case int64: + return encodePlanInt8CodecBinaryInt64{} + case Int64Valuer: + return encodePlanInt8CodecBinaryInt64Valuer{} + } case TextFormatCode: - return encodePlanInt8CodecText{} + switch value.(type) { + case int64: + return encodePlanInt8CodecTextInt64{} + case Int64Valuer: + return encodePlanInt8CodecTextInt64Valuer{} + } } return nil } -type encodePlanInt8CodecBinary struct{} - -func (encodePlanInt8CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) - } - if !valid { - return nil, nil - } - - if n > math.MaxInt64 { - return nil, fmt.Errorf("%d is greater than maximum value for int8", n) - } - if n < math.MinInt64 { - return nil, fmt.Errorf("%d is less than minimum value for int8", n) - } +type encodePlanInt8CodecBinaryInt64 struct{} +func (encodePlanInt8CodecBinaryInt64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int64) return pgio.AppendInt64(buf, int64(n)), nil } -type encodePlanInt8CodecText struct{} +type encodePlanInt8CodecTextInt64 struct{} -func (encodePlanInt8CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) +func (encodePlanInt8CodecTextInt64) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int64) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt8CodecBinaryInt64Valuer struct{} + +func (encodePlanInt8CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) + return nil, err } - if !valid { + + if !n.Valid { return nil, nil } - if n > math.MaxInt64 { - return nil, fmt.Errorf("%d is greater than maximum value for int8", n) + if n.Int > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int) } - if n < math.MinInt64 { - return nil, fmt.Errorf("%d is less than minimum value for int8", n) + if n.Int < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int) } - return append(buf, strconv.FormatInt(n, 10)...), nil + return pgio.AppendInt64(buf, int64(n.Int)), nil +} + +type encodePlanInt8CodecTextInt64Valuer struct{} + +func (encodePlanInt8CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n.Int) + } + if n.Int < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n.Int) + } + + return append(buf, strconv.FormatInt(n.Int, 10)...), nil } func (Int8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 0d88dd42..6aecb761 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -131,54 +131,80 @@ func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { func (Int<%= pg_byte_size %>Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { switch format { case BinaryFormatCode: - return encodePlanInt<%= pg_byte_size %>CodecBinary{} + switch value.(type) { + case int<%= pg_bit_size %>: + return encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>{} + case Int64Valuer: + return encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer{} + } case TextFormatCode: - return encodePlanInt<%= pg_byte_size %>CodecText{} + switch value.(type) { + case int<%= pg_bit_size %>: + return encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>{} + case Int64Valuer: + return encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer{} + } } return nil } -type encodePlanInt<%= pg_byte_size %>CodecBinary struct{} - -func (encodePlanInt<%= pg_byte_size %>CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) - if err != nil { - return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) - } - if !valid { - return nil, nil - } - - if n > math.MaxInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n) - } - if n < math.MinInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n) - } +type encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %> struct{} +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt<%= pg_bit_size %>) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int<%= pg_bit_size %>) return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil } -type encodePlanInt<%= pg_byte_size %>CodecText struct{} +type encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %> struct{} -func (encodePlanInt<%= pg_byte_size %>CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { - n, valid, err := convertToInt64ForEncode(value) +func (encodePlanInt<%= pg_byte_size %>CodecTextInt<%= pg_bit_size %>) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n := value.(int<%= pg_bit_size %>) + return append(buf, strconv.FormatInt(int64(n), 10)...), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecBinaryInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() if err != nil { - return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) + return nil, err } - if !valid { + + if !n.Valid { return nil, nil } - if n > math.MaxInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n) + if n.Int > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int) } - if n < math.MinInt<%= pg_bit_size %> { - return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n) + if n.Int < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int) } - return append(buf, strconv.FormatInt(n, 10)...), nil + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n.Int)), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer struct{} + +func (encodePlanInt<%= pg_byte_size %>CodecTextInt64Valuer) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !n.Valid { + return nil, nil + } + + if n.Int > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n.Int) + } + if n.Int < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n.Int) + } + + return append(buf, strconv.FormatInt(n.Int, 10)...), nil } func (Int<%= pg_byte_size %>Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7fff7dd5..bb0d2a9d 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1011,6 +1011,14 @@ func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) Enco } } + if wrapperPlan, nextValue, ok := tryWrapBuiltinTypeEncodePlan(value); ok { + if nextPlan := ci.PlanEncode(oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + + } + } return nil @@ -1074,6 +1082,174 @@ func tryBaseTypeEncodePlan(value interface{}) (plan *baseTypeEncodePlan, nextVal return nil, nil, false } +type WrappedEncodePlanNextSetter interface { + SetNext(EncodePlan) + EncodePlan +} + +func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { + switch value.(type) { + case int8: + return &wrapInt8EncodePlan{}, int8Wrapper(value.(int8)), true + case int16: + return &wrapInt16EncodePlan{}, int16Wrapper(value.(int16)), true + case int32: + return &wrapInt32EncodePlan{}, int32Wrapper(value.(int32)), true + case int64: + return &wrapInt64EncodePlan{}, int64Wrapper(value.(int64)), true + case int: + return &wrapIntEncodePlan{}, intWrapper(value.(int)), true + case uint8: + return &wrapUint8EncodePlan{}, uint8Wrapper(value.(uint8)), true + case uint16: + return &wrapUint16EncodePlan{}, uint16Wrapper(value.(uint16)), true + case uint32: + return &wrapUint32EncodePlan{}, uint32Wrapper(value.(uint32)), true + case uint64: + return &wrapUint64EncodePlan{}, uint64Wrapper(value.(uint64)), true + case uint: + return &wrapUintEncodePlan{}, uintWrapper(value.(uint)), true + case float32: + return &wrapFloat32EncodePlan{}, float32Wrapper(value.(float32)), true + case float64: + return &wrapFloat64EncodePlan{}, float64Wrapper(value.(float64)), true + case string: + return &wrapStringEncodePlan{}, stringWrapper(value.(string)), true + } + + return nil, nil, false +} + +type wrapInt8EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt8EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt8EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int8Wrapper(value.(int8)), buf) +} + +type wrapInt16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt16EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int16Wrapper(value.(int16)), buf) +} + +type wrapInt32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt32EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int32Wrapper(value.(int32)), buf) +} + +type wrapInt64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapInt64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapInt64EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(int64Wrapper(value.(int64)), buf) +} + +type wrapIntEncodePlan struct { + next EncodePlan +} + +func (plan *wrapIntEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapIntEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(intWrapper(value.(int)), buf) +} + +type wrapUint8EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint8EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint8EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint8Wrapper(value.(uint8)), buf) +} + +type wrapUint16EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint16EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint16EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint16Wrapper(value.(uint16)), buf) +} + +type wrapUint32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint32EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint32Wrapper(value.(uint32)), buf) +} + +type wrapUint64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapUint64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUint64EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uint64Wrapper(value.(uint64)), buf) +} + +type wrapUintEncodePlan struct { + next EncodePlan +} + +func (plan *wrapUintEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapUintEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(uintWrapper(value.(uint)), buf) +} + +type wrapFloat32EncodePlan struct { + next EncodePlan +} + +func (plan *wrapFloat32EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFloat32EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(float32Wrapper(value.(float32)), buf) +} + +type wrapFloat64EncodePlan struct { + next EncodePlan +} + +func (plan *wrapFloat64EncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapFloat64EncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(float64Wrapper(value.(float64)), buf) +} + +type wrapStringEncodePlan struct { + next EncodePlan +} + +func (plan *wrapStringEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapStringEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(stringWrapper(value.(string)), buf) +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written.