From 38dd42de4bc5ba4b7492b3a07ae4e472fab9517d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 10 May 2020 14:09:02 -0500 Subject: [PATCH] Support new pgtype format preferences --- extended_query_builder.go | 61 +++++++++++++++++++++------------------ go.mod | 2 +- go.sum | 2 ++ values.go | 4 ++- 4 files changed, 39 insertions(+), 30 deletions(-) diff --git a/extended_query_builder.go b/extended_query_builder.go index b6a85a9e..09419f0d 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -21,7 +21,7 @@ func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, ar f := chooseParameterFormatCode(ci, oid, arg) eqb.paramFormats = append(eqb.paramFormats, f) - v, err := eqb.encodeExtendedParamValue(ci, oid, arg) + v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg) if err != nil { return err } @@ -62,7 +62,7 @@ func (eqb *extendedQueryBuilder) Reset() { } -func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, arg interface{}) ([]byte, error) { +func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { if arg == nil { return nil, nil } @@ -82,36 +82,41 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o var buf []byte pos := len(eqb.paramValueBytes) - switch arg := arg.(type) { - case pgtype.BinaryEncoder: - buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil - case pgtype.TextEncoder: - buf, err = arg.EncodeText(ci, eqb.paramValueBytes) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - eqb.paramValueBytes = buf - return eqb.paramValueBytes[pos:], nil - case string: + if arg, ok := arg.(string); ok { return []byte(arg), nil } + if formatCode == TextFormatCode { + if arg, ok := arg.(pgtype.TextEncoder); ok { + buf, err = arg.EncodeText(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil + } + } else if formatCode == BinaryFormatCode { + if arg, ok := arg.(pgtype.BinaryEncoder); ok { + buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil + } + } + if argIsPtr { // We have already checked that arg is not pointing to nil, // so it is safe to dereference here. arg = refVal.Elem().Interface() - return eqb.encodeExtendedParamValue(ci, oid, arg) + return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) } if dt, ok := ci.DataTypeForOID(oid); ok { @@ -124,14 +129,14 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o if err != nil { return nil, err } - return eqb.encodeExtendedParamValue(ci, oid, v) + return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) } } return nil, err } - return eqb.encodeExtendedParamValue(ci, oid, value) + return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) } // There is no data type registered for the destination OID, but maybe there is data type registered for the arg @@ -157,7 +162,7 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o } if strippedArg, ok := stripNamedType(&refVal); ok { - return eqb.encodeExtendedParamValue(ci, oid, strippedArg) + return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) } return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } diff --git a/go.mod b/go.mod index ef6e1568..87013a7a 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/jackc/pgconn v1.5.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 - github.com/jackc/pgtype v1.3.1-0.20200510045248-7e66ab1e146c + github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a github.com/jackc/puddle v1.1.1 github.com/rs/zerolog v1.15.0 github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc diff --git a/go.sum b/go.sum index 84abfb55..0c9e4b20 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,8 @@ github.com/jackc/pgtype v1.3.1-0.20200508211315-97bbe6ae20e2 h1:Y6cErz3hUojOwnjU github.com/jackc/pgtype v1.3.1-0.20200508211315-97bbe6ae20e2/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgtype v1.3.1-0.20200510045248-7e66ab1e146c h1:id5j6vOwHhbR7BYdGyb0sDMQjNsKTO+mXWaJxiwKu5M= github.com/jackc/pgtype v1.3.1-0.20200510045248-7e66ab1e146c/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= +github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a h1:XUNeoL8E15IgWouQ8gfA6EPHOfTqVetdxBhAKMYKNGo= +github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= diff --git a/values.go b/values.go index 355510b6..c9ea5637 100644 --- a/values.go +++ b/values.go @@ -226,7 +226,9 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32 // argument to a prepared statement. It defaults to TextFormatCode if no // determination can be made. func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { - switch arg.(type) { + switch arg := arg.(type) { + case pgtype.ParamFormatPreferrer: + return arg.PreferredParamFormat() case pgtype.BinaryEncoder: return BinaryFormatCode case string, *string, pgtype.TextEncoder: