mirror of https://github.com/jackc/pgx.git
StrictNamedArgs
parent
1b6227af11
commit
b6e5548341
|
@ -2,6 +2,7 @@ package pgx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
@ -21,6 +22,34 @@ type NamedArgs map[string]any
|
||||||
|
|
||||||
// RewriteQuery implements the QueryRewriter interface.
|
// RewriteQuery implements the QueryRewriter interface.
|
||||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||||
|
return rewriteQuery(na, sql, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
|
||||||
|
// named arguments that the sql query uses, and no extra arguments.
|
||||||
|
type StrictNamedArgs map[string]any
|
||||||
|
|
||||||
|
// RewriteQuery implements the QueryRewriter interface.
|
||||||
|
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||||
|
return rewriteQuery(sna, sql, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
type namedArg string
|
||||||
|
|
||||||
|
type sqlLexer struct {
|
||||||
|
src string
|
||||||
|
start int
|
||||||
|
pos int
|
||||||
|
nested int // multiline comment nesting level.
|
||||||
|
stateFn stateFn
|
||||||
|
parts []any
|
||||||
|
|
||||||
|
nameToOrdinal map[namedArg]int
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateFn func(*sqlLexer) stateFn
|
||||||
|
|
||||||
|
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
|
||||||
l := &sqlLexer{
|
l := &sqlLexer{
|
||||||
src: sql,
|
src: sql,
|
||||||
stateFn: rawState,
|
stateFn: rawState,
|
||||||
|
@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
|
||||||
|
|
||||||
newArgs = make([]any, len(l.nameToOrdinal))
|
newArgs = make([]any, len(l.nameToOrdinal))
|
||||||
for name, ordinal := range l.nameToOrdinal {
|
for name, ordinal := range l.nameToOrdinal {
|
||||||
newArgs[ordinal-1] = na[string(name)]
|
var found bool
|
||||||
|
newArgs[ordinal-1], found = na[string(name)]
|
||||||
|
if isStrict && !found {
|
||||||
|
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isStrict {
|
||||||
|
for name := range na {
|
||||||
|
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
|
||||||
|
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), newArgs, nil
|
return sb.String(), newArgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type namedArg string
|
|
||||||
|
|
||||||
type sqlLexer struct {
|
|
||||||
src string
|
|
||||||
start int
|
|
||||||
pos int
|
|
||||||
nested int // multiline comment nesting level.
|
|
||||||
stateFn stateFn
|
|
||||||
parts []any
|
|
||||||
|
|
||||||
nameToOrdinal map[namedArg]int
|
|
||||||
}
|
|
||||||
|
|
||||||
type stateFn func(*sqlLexer) stateFn
|
|
||||||
|
|
||||||
func rawState(l *sqlLexer) stateFn {
|
func rawState(l *sqlLexer) stateFn {
|
||||||
for {
|
for {
|
||||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
|
|
@ -93,6 +93,18 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||||
where id = $1;`,
|
where id = $1;`,
|
||||||
expectedArgs: []any{int32(42)},
|
expectedArgs: []any{int32(42)},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
sql: "extra provided argument",
|
||||||
|
namedArgs: pgx.NamedArgs{"extra": int32(1)},
|
||||||
|
expectedSQL: "extra provided argument",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "@missing argument",
|
||||||
|
namedArgs: pgx.NamedArgs{},
|
||||||
|
expectedSQL: "$1 argument",
|
||||||
|
expectedArgs: []any{nil},
|
||||||
|
},
|
||||||
|
|
||||||
// test comments and quotes
|
// test comments and quotes
|
||||||
} {
|
} {
|
||||||
|
@ -102,3 +114,49 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStrictNamedArgsRewriteQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for i, tt := range []struct {
|
||||||
|
sql string
|
||||||
|
namedArgs pgx.StrictNamedArgs
|
||||||
|
expectedSQL string
|
||||||
|
expectedArgs []any
|
||||||
|
isExpectedError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sql: "no arguments",
|
||||||
|
namedArgs: pgx.StrictNamedArgs{},
|
||||||
|
expectedSQL: "no arguments",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
isExpectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "@all @matches",
|
||||||
|
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
|
||||||
|
expectedSQL: "$1 $2",
|
||||||
|
expectedArgs: []any{int32(1), int32(2)},
|
||||||
|
isExpectedError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "extra provided argument",
|
||||||
|
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
|
||||||
|
isExpectedError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "@missing argument",
|
||||||
|
namedArgs: pgx.StrictNamedArgs{},
|
||||||
|
isExpectedError: true,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
|
||||||
|
if tt.isExpectedError {
|
||||||
|
assert.Errorf(t, err, "%d", i)
|
||||||
|
} else {
|
||||||
|
require.NoErrorf(t, err, "%d", i)
|
||||||
|
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||||
|
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue