mirror of https://github.com/jackc/pgx.git
Add NamedArgs
https://github.com/jackc/pgx/issues/1186 https://github.com/jackc/pgx/issues/387non-blocking v5.0.0-alpha.3
parent
b72b0daa5a
commit
107196ab0c
2
conn.go
2
conn.go
|
@ -692,7 +692,7 @@ type QueryResultFormatsByOID map[uint32]int16
|
||||||
|
|
||||||
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
||||||
type QueryRewriter interface {
|
type QueryRewriter interface {
|
||||||
RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any)
|
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The
|
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The
|
||||||
|
|
|
@ -235,7 +235,7 @@ type testQueryRewriter struct {
|
||||||
args []any
|
args []any
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) {
|
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||||
return qr.sql, qr.args
|
return qr.sql, qr.args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,266 @@
|
||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
|
||||||
|
// ordinal placeholder and construct the appropriate arguments.
|
||||||
|
//
|
||||||
|
// For example, the following two queries are equivalent:
|
||||||
|
//
|
||||||
|
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}))
|
||||||
|
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2}))
|
||||||
|
type NamedArgs map[string]any
|
||||||
|
|
||||||
|
// RewriteQuery implements the QueryRewriter interface.
|
||||||
|
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||||
|
l := &sqlLexer{
|
||||||
|
src: sql,
|
||||||
|
stateFn: rawState,
|
||||||
|
nameToOrdinal: make(map[namedArg]int, len(na)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for l.stateFn != nil {
|
||||||
|
l.stateFn = l.stateFn(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
sb := strings.Builder{}
|
||||||
|
for _, p := range l.parts {
|
||||||
|
switch p := p.(type) {
|
||||||
|
case string:
|
||||||
|
sb.WriteString(p)
|
||||||
|
case namedArg:
|
||||||
|
sb.WriteRune('$')
|
||||||
|
sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newArgs = make([]any, len(l.nameToOrdinal))
|
||||||
|
for name, ordinal := range l.nameToOrdinal {
|
||||||
|
newArgs[ordinal-1] = na[string(name)]
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), newArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
type namedArg string
|
||||||
|
|
||||||
|
type sqlLexer struct {
|
||||||
|
src string
|
||||||
|
start int
|
||||||
|
pos int
|
||||||
|
nested int // multiline comment nesting level.
|
||||||
|
stateFn stateFn
|
||||||
|
parts []any
|
||||||
|
|
||||||
|
nameToOrdinal map[namedArg]int
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateFn func(*sqlLexer) stateFn
|
||||||
|
|
||||||
|
func rawState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case 'e', 'E':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '\'' {
|
||||||
|
l.pos += width
|
||||||
|
return escapeStringState
|
||||||
|
}
|
||||||
|
case '\'':
|
||||||
|
return singleQuoteState
|
||||||
|
case '"':
|
||||||
|
return doubleQuoteState
|
||||||
|
case '@':
|
||||||
|
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if isLetter(nextRune) {
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||||
|
}
|
||||||
|
l.start = l.pos
|
||||||
|
return namedArgState
|
||||||
|
}
|
||||||
|
case '-':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '-' {
|
||||||
|
l.pos += width
|
||||||
|
return oneLineCommentState
|
||||||
|
}
|
||||||
|
case '/':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '*' {
|
||||||
|
l.pos += width
|
||||||
|
return multilineCommentState
|
||||||
|
}
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLetter(r rune) bool {
|
||||||
|
return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
|
||||||
|
}
|
||||||
|
|
||||||
|
func namedArgState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
if r == utf8.RuneError {
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
na := namedArg(l.src[l.start:l.pos])
|
||||||
|
if _, found := l.nameToOrdinal[na]; !found {
|
||||||
|
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||||
|
}
|
||||||
|
l.parts = append(l.parts, na)
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else if !(isLetter(r) || (r >= '0' && r <= '9')) {
|
||||||
|
l.pos -= width
|
||||||
|
na := namedArg(l.src[l.start:l.pos])
|
||||||
|
if _, found := l.nameToOrdinal[na]; !found {
|
||||||
|
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||||
|
}
|
||||||
|
l.parts = append(l.parts, namedArg(na))
|
||||||
|
l.start = l.pos
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleQuoteState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\'':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '\'' {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.pos += width
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '"':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '"' {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.pos += width
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func escapeStringState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\\':
|
||||||
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
case '\'':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '\'' {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.pos += width
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func oneLineCommentState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\\':
|
||||||
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
case '\n', '\r':
|
||||||
|
return rawState
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func multilineCommentState(l *sqlLexer) stateFn {
|
||||||
|
for {
|
||||||
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
l.pos += width
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '/':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune == '*' {
|
||||||
|
l.pos += width
|
||||||
|
l.nested++
|
||||||
|
}
|
||||||
|
case '*':
|
||||||
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||||
|
if nextRune != '/' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
l.pos += width
|
||||||
|
if l.nested == 0 {
|
||||||
|
return rawState
|
||||||
|
}
|
||||||
|
l.nested--
|
||||||
|
|
||||||
|
case utf8.RuneError:
|
||||||
|
if l.pos-l.start > 0 {
|
||||||
|
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||||
|
l.start = l.pos
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,96 @@
|
||||||
|
package pgx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for i, tt := range []struct {
|
||||||
|
sql string
|
||||||
|
args []any
|
||||||
|
namedArgs pgx.NamedArgs
|
||||||
|
expectedSQL string
|
||||||
|
expectedArgs []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sql: "select * from users where id = @id",
|
||||||
|
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||||
|
expectedSQL: "select * from users where id = $1",
|
||||||
|
expectedArgs: []any{int32(42)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "select * from t where foo < @abc and baz = @def and bar < @abc",
|
||||||
|
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
|
||||||
|
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
|
||||||
|
expectedArgs: []any{int32(42), int32(1)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "select @a::int, @b::text",
|
||||||
|
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||||
|
expectedSQL: "select $1::int, $2::text",
|
||||||
|
expectedArgs: []any{int32(42), "foo"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "at end @",
|
||||||
|
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||||
|
expectedSQL: "at end @",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "ignores without letter after @ foo bar",
|
||||||
|
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||||
|
expectedSQL: "ignores without letter after @ foo bar",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: "name must start with letter @1 foo bar",
|
||||||
|
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||||
|
expectedSQL: "name must start with letter @1 foo bar",
|
||||||
|
expectedArgs: []any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: `select *, '@foo' as "@bar" from users where id = @id`,
|
||||||
|
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||||
|
expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`,
|
||||||
|
expectedArgs: []any{int32(42)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: `select * -- @foo
|
||||||
|
from users -- @single line comments
|
||||||
|
where id = @id;`,
|
||||||
|
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||||
|
expectedSQL: `select * -- @foo
|
||||||
|
from users -- @single line comments
|
||||||
|
where id = $1;`,
|
||||||
|
expectedArgs: []any{int32(42)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sql: `select * /* @multi line
|
||||||
|
@comment
|
||||||
|
*/
|
||||||
|
/* /* with @nesting */ */
|
||||||
|
from users
|
||||||
|
where id = @id;`,
|
||||||
|
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||||
|
expectedSQL: `select * /* @multi line
|
||||||
|
@comment
|
||||||
|
*/
|
||||||
|
/* /* with @nesting */ */
|
||||||
|
from users
|
||||||
|
where id = $1;`,
|
||||||
|
expectedArgs: []any{int32(42)},
|
||||||
|
},
|
||||||
|
|
||||||
|
// test comments and quotes
|
||||||
|
} {
|
||||||
|
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
|
||||||
|
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||||
|
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue