diff --git a/extended_query_builder.go b/extended_query_builder.go index 9c9de5b2..0056cec7 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -161,58 +161,51 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui // no way to safely use binary or to specify the parameter OIDs. func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error { for _, arg := range args { - if arg == nil { - err := eqb.appendParam(m, 0, TextFormatCode, arg) - if err != nil { - return err - } - } else { - dt, ok := m.TypeForValue(arg) - if !ok { - var tv pgtype.TextValuer - if tv, ok = arg.(pgtype.TextValuer); ok { - t, err := tv.TextValue() - if err != nil { - return err - } + oid, modArg, err := eqb.oidAndArgForQueryExecModeExec(m, arg) + if err != nil { + return err + } - dt, ok = m.TypeForOID(pgtype.TextOID) - if ok { - arg = t - } - } - } - if !ok { - var dv driver.Valuer - if dv, ok = arg.(driver.Valuer); ok { - v, err := dv.Value() - if err != nil { - return err - } - dt, ok = m.TypeForValue(v) - if ok { - arg = v - } - } - } - if !ok { - var str fmt.Stringer - if str, ok = arg.(fmt.Stringer); ok { - dt, ok = m.TypeForOID(pgtype.TextOID) - if ok { - arg = str.String() - } - } - } - if !ok { - return &unknownArgumentTypeQueryExecModeExecError{arg: arg} - } - err := eqb.appendParam(m, dt.OID, TextFormatCode, arg) - if err != nil { - return err - } + err = eqb.appendParam(m, oid, pgtype.TextFormatCode, modArg) + if err != nil { + return err } } return nil } + +func (eqb *ExtendedQueryBuilder) oidAndArgForQueryExecModeExec(m *pgtype.Map, arg any) (uint32, any, error) { + if arg == nil { + return 0, arg, nil + } + + if dt, ok := m.TypeForValue(arg); ok { + return dt.OID, arg, nil + } + + if textValuer, ok := arg.(pgtype.TextValuer); ok { + tv, err := textValuer.TextValue() + if err != nil { + return 0, nil, err + } + + return pgtype.TextOID, tv, nil + } + + if dv, ok := arg.(driver.Valuer); ok { + v, err := dv.Value() + if err != nil { + return 0, nil, err + } + if dt, ok := m.TypeForValue(v); ok { + return dt.OID, v, nil + } + } + + if str, ok := arg.(fmt.Stringer); ok { + return pgtype.TextOID, str.String(), nil + } + + return 0, nil, &unknownArgumentTypeQueryExecModeExecError{arg: arg} +}