Handle driver.Valuers inside Map.Encode

query-exec-mode
Jack Christensen 2022-03-05 21:27:17 -06:00
parent 0905d1f452
commit c4b08378f2
3 changed files with 18 additions and 35 deletions

28
conn.go
View File

@ -472,22 +472,17 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []i
return commandTag, err
}
func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error {
if len(sd.ParamOIDs) != len(arguments) {
return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments))
func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, args []interface{}) error {
if len(sd.ParamOIDs) != len(args) {
return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))
}
c.eqb.Reset()
anynil.NormalizeSlice(arguments)
args, err := evaluateDriverValuers(arguments)
if err != nil {
return err
}
anynil.NormalizeSlice(args)
for i := range args {
err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i])
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i])
if err != nil {
return err
}
@ -675,11 +670,6 @@ optionLoop:
rows.sql = sd.SQL
anynil.NormalizeSlice(args)
args, err = evaluateDriverValuers(args)
if err != nil {
rows.fatal(err)
return rows, rows.err
}
for i := range args {
err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i])
@ -836,13 +826,9 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
}
anynil.NormalizeSlice(bi.arguments)
args, err := evaluateDriverValuers(bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
for i := range args {
err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i])
for i := range bi.arguments {
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}

View File

@ -1720,6 +1720,17 @@ func (m *Map) Encode(oid uint32, formatCode int16, value interface{}, buf []byte
plan := m.PlanEncode(oid, formatCode, value)
if plan == nil {
if dv, ok := value.(driver.Valuer); ok {
if dv == nil {
return nil, nil
}
v, err := dv.Value()
if err != nil {
return nil, err
}
return m.Encode(oid, formatCode, v, buf)
}
return nil, fmt.Errorf("unable to encode %#v into OID %d", value, oid)
}
return plan.Encode(value, buf)

View File

@ -159,17 +159,3 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) {
return nil, false
}
func evaluateDriverValuers(args []interface{}) ([]interface{}, error) {
for i, arg := range args {
switch arg := arg.(type) {
case driver.Valuer:
v, err := arg.Value()
if err != nil {
return nil, err
}
args[i] = v
}
}
return args, nil
}