diff --git a/extended_query_builder.go b/extended_query_builder.go index 480e35d3..36447c99 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -131,7 +131,7 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o } return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) } else if dt.Codec != nil { - buf, err := dt.Codec.Encode(ci, oid, formatCode, arg, eqb.paramValueBytes) + buf, err := ci.Encode(oid, formatCode, arg, eqb.paramValueBytes) if err != nil { return nil, err } diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index 16ce7382..e8c2b2ed 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -42,63 +42,29 @@ func (c *ArrayCodec) PreferredFormat() int16 { return c.ElementCodec.PreferredFormat() } -func (c *ArrayCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { - if value == nil { - return nil, nil +func (c *ArrayCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanArrayCodecBinary{ac: c, ci: ci, oid: oid} + case TextFormatCode: + return &encodePlanArrayCodecText{ac: c, ci: ci, oid: oid} } + return nil +} + +type encodePlanArrayCodecText struct { + ac *ArrayCodec + ci *ConnInfo + oid uint32 +} + +func (p *encodePlanArrayCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { array, err := makeArrayGetter(value) if err != nil { return nil, err } - switch format { - case BinaryFormatCode: - return c.encodeBinary(ci, oid, array, buf) - case TextFormatCode: - return c.encodeText(ci, oid, array, buf) - default: - return nil, fmt.Errorf("unknown format code: %v", format) - } - -} - -func (c *ArrayCodec) encodeBinary(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) { - dimensions := array.Dimensions() - if dimensions == nil { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: dimensions, - ElementOID: int32(c.ElementOID), - } - - containsNullIndex := len(buf) + 4 - - buf = arrayHeader.EncodeBinary(ci, buf) - - elementCount := cardinality(dimensions) - for i := 0; i < elementCount; i++ { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, BinaryFormatCode, array.Index(i), buf) - if err != nil { - return nil, err - } - if elemBuf == nil { - pgio.SetInt32(buf[containsNullIndex:], 1) - } else { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf []byte) (newBuf []byte, err error) { dimensions := array.Dimensions() if dimensions == nil { return nil, nil @@ -134,7 +100,11 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf } } - elemBuf, err := c.ElementCodec.Encode(ci, c.ElementOID, TextFormatCode, array.Index(i), inElemBuf) + encodePlan := p.ac.ElementCodec.PlanEncode(p.ci, p.ac.ElementOID, TextFormatCode, array.Index(i)) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + elemBuf, err := encodePlan.Encode(array.Index(i), inElemBuf) if err != nil { return nil, err } @@ -154,6 +124,56 @@ func (c *ArrayCodec) encodeText(ci *ConnInfo, oid uint32, array ArrayGetter, buf return buf, nil } +type encodePlanArrayCodecBinary struct { + ac *ArrayCodec + ci *ConnInfo + oid uint32 +} + +func (p *encodePlanArrayCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + array, err := makeArrayGetter(value) + if err != nil { + return nil, err + } + + dimensions := array.Dimensions() + if dimensions == nil { + return nil, nil + } + + arrayHeader := ArrayHeader{ + Dimensions: dimensions, + ElementOID: int32(p.ac.ElementOID), + } + + containsNullIndex := len(buf) + 4 + + buf = arrayHeader.EncodeBinary(p.ci, buf) + + elementCount := cardinality(dimensions) + for i := 0; i < elementCount; i++ { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + encodePlan := p.ac.ElementCodec.PlanEncode(p.ci, p.ac.ElementOID, BinaryFormatCode, array.Index(i)) + if encodePlan == nil { + return nil, fmt.Errorf("unable to encode %v", array.Index(i)) + } + elemBuf, err := encodePlan.Encode(array.Index(i), buf) + if err != nil { + return nil, err + } + if elemBuf == nil { + pgio.SetInt32(buf[containsNullIndex:], 1) + } else { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + func (c *ArrayCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { _, err := makeArraySetter(target) if err != nil { diff --git a/pgtype/bool.go b/pgtype/bool.go index 36d29d40..d2c3cdc3 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -105,7 +105,20 @@ func (BoolCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BoolCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (BoolCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanBoolCodecBinary{} + case TextFormatCode: + return &encodePlanBoolCodecText{} + } + + return nil +} + +type encodePlanBoolCodecBinary struct{} + +func (p *encodePlanBoolCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { v, valid, err := convertToBoolForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) @@ -117,24 +130,36 @@ func (BoolCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ return nil, nil } - switch format { - case BinaryFormatCode: - if v { - buf = append(buf, 1) - } else { - buf = append(buf, 0) - } - return buf, nil - case TextFormatCode: - if v { - buf = append(buf, 't') - } else { - buf = append(buf, 'f') - } - return buf, nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + if v { + buf = append(buf, 1) + } else { + buf = append(buf, 0) } + + return buf, nil +} + +type encodePlanBoolCodecText struct{} + +func (p *encodePlanBoolCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + v, valid, err := convertToBoolForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) + } + if !valid { + return nil, nil + } + if value == nil { + return nil, nil + } + + if v { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + + return buf, nil } func (BoolCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/box.go b/pgtype/box.go index 7db7d5a2..b5c30ed3 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -50,7 +50,7 @@ func (dst *Box) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Box) Value() (driver.Value, error) { - buf, err := BoxCodec{}.Encode(nil, 0, TextFormatCode, src, nil) + buf, err := BoxCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { return nil, err } @@ -67,44 +67,51 @@ func (BoxCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (BoxCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { - if value == nil { - return nil, nil - } - - var box Box - if v, ok := value.(BoxValuer); ok { - b, err := v.BoxValue() - if err != nil { - return nil, err - } - box = b - } else { - return nil, fmt.Errorf("cannot convert %v to box: %v", value, err) - } - - if !box.Valid { - return nil, nil +func (BoxCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(BoxValuer); !ok { + return nil } switch format { case BinaryFormatCode: - buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].Y)) - return buf, nil + return &encodePlanBoxCodecBinary{} case TextFormatCode: - buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, - strconv.FormatFloat(box.P[0].X, 'f', -1, 64), - strconv.FormatFloat(box.P[0].Y, 'f', -1, 64), - strconv.FormatFloat(box.P[1].X, 'f', -1, 64), - strconv.FormatFloat(box.P[1].Y, 'f', -1, 64), - )...) - return buf, nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return &encodePlanBoxCodecText{} } + + return nil +} + +type encodePlanBoxCodecBinary struct{} + +func (p *encodePlanBoxCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + box, err := value.(BoxValuer).BoxValue() + if err != nil { + return nil, err + } + + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(box.P[1].Y)) + return buf, nil +} + +type encodePlanBoxCodecText struct{} + +func (p *encodePlanBoxCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + box, err := value.(BoxValuer).BoxValue() + if err != nil { + return nil, err + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(box.P[0].X, 'f', -1, 64), + strconv.FormatFloat(box.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(box.P[1].X, 'f', -1, 64), + strconv.FormatFloat(box.P[1].Y, 'f', -1, 64), + )...) + return buf, nil } func (BoxCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/circle.go b/pgtype/circle.go index f1f66175..f214a070 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -51,7 +51,7 @@ func (dst *Circle) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src Circle) Value() (driver.Value, error) { - buf, err := CircleCodec{}.Encode(nil, 0, TextFormatCode, src, nil) + buf, err := CircleCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) if err != nil { return nil, err } @@ -68,42 +68,49 @@ func (CircleCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (CircleCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { - if value == nil { - return nil, nil - } - - var circle Circle - if v, ok := value.(CircleValuer); ok { - c, err := v.CircleValue() - if err != nil { - return nil, err - } - circle = c - } else { - return nil, fmt.Errorf("cannot convert %v to circle: %v", value, err) - } - - if !circle.Valid { - return nil, nil +func (CircleCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(CircleValuer); !ok { + return nil } switch format { case BinaryFormatCode: - buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(circle.R)) - return buf, nil + return &encodePlanCircleCodecBinary{} case TextFormatCode: - buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, - strconv.FormatFloat(circle.P.X, 'f', -1, 64), - strconv.FormatFloat(circle.P.Y, 'f', -1, 64), - strconv.FormatFloat(circle.R, 'f', -1, 64), - )...) - return buf, nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return &encodePlanCircleCodecText{} } + + return nil +} + +type encodePlanCircleCodecBinary struct{} + +func (p *encodePlanCircleCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + circle, err := value.(CircleValuer).CircleValue() + if err != nil { + return nil, err + } + + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.R)) + return buf, nil +} + +type encodePlanCircleCodecText struct{} + +func (p *encodePlanCircleCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + circle, err := value.(CircleValuer).CircleValue() + if err != nil { + return nil, err + } + + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, + strconv.FormatFloat(circle.P.X, 'f', -1, 64), + strconv.FormatFloat(circle.P.Y, 'f', -1, 64), + strconv.FormatFloat(circle.R, 'f', -1, 64), + )...) + return buf, nil } func (CircleCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/int.go b/pgtype/int.go index 21259beb..18b1ba90 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -119,7 +119,20 @@ func (Int2Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (Int2Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanInt2CodecBinary{} + case TextFormatCode: + return &encodePlanInt2CodecText{} + } + + return nil +} + +type encodePlanInt2CodecBinary struct{} + +func (p *encodePlanInt2CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) @@ -135,14 +148,28 @@ func (Int2Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ return nil, fmt.Errorf("%d is less than minimum value for int2", n) } - switch format { - case BinaryFormatCode: - return pgio.AppendInt16(buf, int16(n)), nil - case TextFormatCode: - return append(buf, strconv.FormatInt(n, 10)...), nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return pgio.AppendInt16(buf, int16(n)), nil +} + +type encodePlanInt2CodecText struct{} + +func (p *encodePlanInt2CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int2: %v", value, err) } + if !valid { + return nil, nil + } + + if n > math.MaxInt16 { + return nil, fmt.Errorf("%d is greater than maximum value for int2", n) + } + if n < math.MinInt16 { + return nil, fmt.Errorf("%d is less than minimum value for int2", n) + } + + return append(buf, strconv.FormatInt(n, 10)...), nil } func (Int2Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { @@ -599,7 +626,20 @@ func (Int4Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int4Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (Int4Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanInt4CodecBinary{} + case TextFormatCode: + return &encodePlanInt4CodecText{} + } + + return nil +} + +type encodePlanInt4CodecBinary struct{} + +func (p *encodePlanInt4CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) @@ -615,14 +655,28 @@ func (Int4Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ return nil, fmt.Errorf("%d is less than minimum value for int4", n) } - switch format { - case BinaryFormatCode: - return pgio.AppendInt32(buf, int32(n)), nil - case TextFormatCode: - return append(buf, strconv.FormatInt(n, 10)...), nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return pgio.AppendInt32(buf, int32(n)), nil +} + +type encodePlanInt4CodecText struct{} + +func (p *encodePlanInt4CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int4: %v", value, err) } + if !valid { + return nil, nil + } + + if n > math.MaxInt32 { + return nil, fmt.Errorf("%d is greater than maximum value for int4", n) + } + if n < math.MinInt32 { + return nil, fmt.Errorf("%d is less than minimum value for int4", n) + } + + return append(buf, strconv.FormatInt(n, 10)...), nil } func (Int4Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { @@ -1090,7 +1144,20 @@ func (Int8Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int8Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (Int8Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanInt8CodecBinary{} + case TextFormatCode: + return &encodePlanInt8CodecText{} + } + + return nil +} + +type encodePlanInt8CodecBinary struct{} + +func (p *encodePlanInt8CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) @@ -1106,14 +1173,28 @@ func (Int8Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{ return nil, fmt.Errorf("%d is less than minimum value for int8", n) } - switch format { - case BinaryFormatCode: - return pgio.AppendInt64(buf, int64(n)), nil - case TextFormatCode: - return append(buf, strconv.FormatInt(n, 10)...), nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return pgio.AppendInt64(buf, int64(n)), nil +} + +type encodePlanInt8CodecText struct{} + +func (p *encodePlanInt8CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int8: %v", value, err) } + if !valid { + return nil, nil + } + + if n > math.MaxInt64 { + return nil, fmt.Errorf("%d is greater than maximum value for int8", n) + } + if n < math.MinInt64 { + return nil, fmt.Errorf("%d is less than minimum value for int8", n) + } + + return append(buf, strconv.FormatInt(n, 10)...), nil } func (Int8Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 546494d4..3f15dfce 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -120,7 +120,20 @@ func (Int<%= pg_byte_size %>Codec) PreferredFormat() int16 { return BinaryFormatCode } -func (Int<%= pg_byte_size %>Codec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { +func (Int<%= pg_byte_size %>Codec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + switch format { + case BinaryFormatCode: + return &encodePlanInt<%= pg_byte_size %>CodecBinary{} + case TextFormatCode: + return &encodePlanInt<%= pg_byte_size %>CodecText{} + } + + return nil +} + +type encodePlanInt<%= pg_byte_size %>CodecBinary struct{} + +func (p *encodePlanInt<%= pg_byte_size %>CodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { n, valid, err := convertToInt64ForEncode(value) if err != nil { return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) @@ -136,14 +149,28 @@ func (Int<%= pg_byte_size %>Codec) Encode(ci *ConnInfo, oid uint32, format int16 return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n) } - switch format { - case BinaryFormatCode: - return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil - case TextFormatCode: - return append(buf, strconv.FormatInt(n, 10)...), nil - default: - return nil, fmt.Errorf("unknown format code: %v", format) + return pgio.AppendInt<%= pg_bit_size %>(buf, int<%= pg_bit_size %>(n)), nil +} + +type encodePlanInt<%= pg_byte_size %>CodecText struct{} + +func (p *encodePlanInt<%= pg_byte_size %>CodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + n, valid, err := convertToInt64ForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to int<%= pg_byte_size %>: %v", value, err) } + if !valid { + return nil, nil + } + + if n > math.MaxInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is greater than maximum value for int<%= pg_byte_size %>", n) + } + if n < math.MinInt<%= pg_bit_size %> { + return nil, fmt.Errorf("%d is less than minimum value for int<%= pg_byte_size %>", n) + } + + return append(buf, strconv.FormatInt(n, 10)...), nil } func (Int<%= pg_byte_size %>Codec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3d863373..a49d29c9 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -155,10 +155,9 @@ type Codec interface { // PreferredFormat returns the preferred format. PreferredFormat() int16 - // Encode appends the encoded bytes of value to buf. If value is the SQL NULL then append nothing and return - // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data - // written. - Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) + // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be + // found then nil is returned. + PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan // PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If // actualTarget is true then the returned ScanPlan may be optimized to directly scan into target. If no plan can be @@ -172,12 +171,6 @@ type Codec interface { DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) } -// ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from -// whether it is also a BinaryDecoder. -type ResultFormatPreferrer interface { - PreferredResultFormat() int16 -} - type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder takes ownership of src. The @@ -462,6 +455,14 @@ func (ci *ConnInfo) PreferAssignToOverSQLScannerForType(value interface{}) { ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(value)] = struct{}{} } +// EncodePlan is a precompiled plan to encode a particular type into a particular OID and format. +type EncodePlan interface { + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return + // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data + // written. + Encode(value interface{}, buf []byte) (newBuf []byte, err error) +} + // ScanPlan is a precompiled plan to scan into a type of destination. type ScanPlan interface { // Scan scans src into dst. If the dst type has changed in an incompatible way a ScanPlan should automatically @@ -929,10 +930,51 @@ func codecDecodeToTextFormat(codec Codec, ci *ConnInfo, oid uint32, format int16 if err != nil { return nil, err } - buf, err := codec.Encode(ci, oid, TextFormatCode, value, nil) + buf, err := ci.Encode(oid, TextFormatCode, value, nil) if err != nil { return nil, err } return string(buf), nil } } + +// PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be +// found then nil is returned. +func (ci *ConnInfo) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan { + + var dt *DataType + + if oid == 0 { + if dataType, ok := ci.DataTypeForValue(value); ok { + dt = dataType + } + } else { + if dataType, ok := ci.DataTypeForOID(oid); ok { + dt = dataType + } + } + + if dt != nil && dt.Codec != nil { + if plan := dt.Codec.PlanEncode(ci, oid, format, value); plan != nil { + return plan + } + + } + + return nil +} + +// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return +// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data +// written. +func (ci *ConnInfo) Encode(oid uint32, formatCode int16, value interface{}, buf []byte) (newBuf []byte, err error) { + if value == nil { + return nil, nil + } + + plan := ci.PlanEncode(oid, formatCode, value) + if plan == nil { + return nil, fmt.Errorf("unable to encode %v", value) + } + return plan.Encode(value, buf) +} diff --git a/values.go b/values.go index e084a69b..a60d4129 100644 --- a/values.go +++ b/values.go @@ -130,7 +130,7 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e } return string(buf), nil } else if dt.Codec != nil { - buf, err := dt.Codec.Encode(ci, 0, TextFormatCode, arg, nil) + buf, err := ci.Encode(0, TextFormatCode, arg, nil) if err != nil { return nil, err } @@ -230,7 +230,7 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 } else if dt.Codec != nil { sp := len(buf) buf = pgio.AppendInt32(buf, -1) - argBuf, err := dt.Codec.Encode(ci, oid, BinaryFormatCode, arg, buf) + argBuf, err := ci.Encode(oid, BinaryFormatCode, arg, buf) if err != nil { return nil, err }