mirror of https://github.com/jackc/pgx.git
StrictNamedArgs
parent
1b6227af11
commit
b6e5548341
|
@ -2,6 +2,7 @@ package pgx
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
@ -21,6 +22,34 @@ type NamedArgs map[string]any
|
|||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||
return rewriteQuery(na, sql, false)
|
||||
}
|
||||
|
||||
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
|
||||
// named arguments that the sql query uses, and no extra arguments.
|
||||
type StrictNamedArgs map[string]any
|
||||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||
return rewriteQuery(sna, sql, true)
|
||||
}
|
||||
|
||||
type namedArg string
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []any
|
||||
|
||||
nameToOrdinal map[namedArg]int
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
|
@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
|
|||
|
||||
newArgs = make([]any, len(l.nameToOrdinal))
|
||||
for name, ordinal := range l.nameToOrdinal {
|
||||
newArgs[ordinal-1] = na[string(name)]
|
||||
var found bool
|
||||
newArgs[ordinal-1], found = na[string(name)]
|
||||
if isStrict && !found {
|
||||
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
|
||||
}
|
||||
}
|
||||
|
||||
if isStrict {
|
||||
for name := range na {
|
||||
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
|
||||
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), newArgs, nil
|
||||
}
|
||||
|
||||
type namedArg string
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []any
|
||||
|
||||
nameToOrdinal map[namedArg]int
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
|
|
|
@ -93,6 +93,18 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
|||
where id = $1;`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: "extra provided argument",
|
||||
namedArgs: pgx.NamedArgs{"extra": int32(1)},
|
||||
expectedSQL: "extra provided argument",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "@missing argument",
|
||||
namedArgs: pgx.NamedArgs{},
|
||||
expectedSQL: "$1 argument",
|
||||
expectedArgs: []any{nil},
|
||||
},
|
||||
|
||||
// test comments and quotes
|
||||
} {
|
||||
|
@ -102,3 +114,49 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
|||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrictNamedArgsRewriteQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i, tt := range []struct {
|
||||
sql string
|
||||
namedArgs pgx.StrictNamedArgs
|
||||
expectedSQL string
|
||||
expectedArgs []any
|
||||
isExpectedError bool
|
||||
}{
|
||||
{
|
||||
sql: "no arguments",
|
||||
namedArgs: pgx.StrictNamedArgs{},
|
||||
expectedSQL: "no arguments",
|
||||
expectedArgs: []any{},
|
||||
isExpectedError: false,
|
||||
},
|
||||
{
|
||||
sql: "@all @matches",
|
||||
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
|
||||
expectedSQL: "$1 $2",
|
||||
expectedArgs: []any{int32(1), int32(2)},
|
||||
isExpectedError: false,
|
||||
},
|
||||
{
|
||||
sql: "extra provided argument",
|
||||
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
|
||||
isExpectedError: true,
|
||||
},
|
||||
{
|
||||
sql: "@missing argument",
|
||||
namedArgs: pgx.StrictNamedArgs{},
|
||||
isExpectedError: true,
|
||||
},
|
||||
} {
|
||||
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
|
||||
if tt.isExpectedError {
|
||||
assert.Errorf(t, err, "%d", i)
|
||||
} else {
|
||||
require.NoErrorf(t, err, "%d", i)
|
||||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue