From c4b08378f235fdc9f8928f5f31316aff5602bfaf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Mar 2022 21:27:17 -0600 Subject: [PATCH] Handle driver.Valuers inside Map.Encode --- conn.go | 28 +++++++--------------------- pgtype/pgtype.go | 11 +++++++++++ values.go | 14 -------------- 3 files changed, 18 insertions(+), 35 deletions(-) diff --git a/conn.go b/conn.go index 20968892..9f2fdcf0 100644 --- a/conn.go +++ b/conn.go @@ -472,22 +472,17 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i return commandTag, err } -func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error { - if len(sd.ParamOIDs) != len(arguments) { - return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments)) +func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, args []interface{}) error { + if len(sd.ParamOIDs) != len(args) { + return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)) } c.eqb.Reset() - anynil.NormalizeSlice(arguments) - - args, err := evaluateDriverValuers(arguments) - if err != nil { - return err - } + anynil.NormalizeSlice(args) for i := range args { - err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) + err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) if err != nil { return err } @@ -675,11 +670,6 @@ optionLoop: rows.sql = sd.SQL anynil.NormalizeSlice(args) - args, err = evaluateDriverValuers(args) - if err != nil { - rows.fatal(err) - return rows, rows.err - } for i := range args { err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) @@ -836,13 +826,9 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } anynil.NormalizeSlice(bi.arguments) - args, err := evaluateDriverValuers(bi.arguments) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - for i := range args { - err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) + for i := range bi.arguments { + err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3244b504..1cc809b1 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1720,6 +1720,17 @@ func (m *Map) Encode(oid uint32, formatCode int16, value interface{}, buf []byte plan := m.PlanEncode(oid, formatCode, value) if plan == nil { + if dv, ok := value.(driver.Valuer); ok { + if dv == nil { + return nil, nil + } + v, err := dv.Value() + if err != nil { + return nil, err + } + return m.Encode(oid, formatCode, v, buf) + } + return nil, fmt.Errorf("unable to encode %#v into OID %d", value, oid) } return plan.Encode(value, buf) diff --git a/values.go b/values.go index 0f34b6a6..a3343d81 100644 --- a/values.go +++ b/values.go @@ -159,17 +159,3 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { return nil, false } - -func evaluateDriverValuers(args []interface{}) ([]interface{}, error) { - for i, arg := range args { - switch arg := arg.(type) { - case driver.Valuer: - v, err := arg.Value() - if err != nil { - return nil, err - } - args[i] = v - } - } - return args, nil -}