mirror of https://github.com/jackc/pgx.git
Unify serialization code in values.go
parent
009cdfa0b1
commit
89bcc0670c
106
sanitize.go
106
sanitize.go
|
@ -1,106 +0,0 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SerializationError string
|
||||
|
||||
func (e SerializationError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// TextEncoder is an interface used to encode values in text format for
|
||||
// transmission to the PostgreSQL server. It is used by unprepared
|
||||
// queries and for prepared queries when the type does not implement
|
||||
// BinaryEncoder
|
||||
type TextEncoder interface {
|
||||
// EncodeText MUST sanitize (and quote, if necessary) the returned string.
|
||||
// It will be interpolated directly into the SQL string.
|
||||
EncodeText() (string, error)
|
||||
}
|
||||
|
||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||
|
||||
// QuoteString escapes and quotes a string making it safe for interpolation
|
||||
// into an SQL string.
|
||||
func QuoteString(input string) (output string) {
|
||||
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
||||
return
|
||||
}
|
||||
|
||||
// QuoteIdentifier escapes and quotes an identifier making it safe for
|
||||
// interpolation into an SQL string
|
||||
func QuoteIdentifier(input string) (output string) {
|
||||
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
|
||||
return
|
||||
}
|
||||
|
||||
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
|
||||
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
|
||||
// appropriate.
|
||||
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
||||
replacer := func(match string) (replacement string) {
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
||||
if int(n-1) >= len(args) {
|
||||
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
|
||||
return
|
||||
}
|
||||
|
||||
switch arg := args[n-1].(type) {
|
||||
case string:
|
||||
return QuoteString(arg)
|
||||
case int:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case time.Time:
|
||||
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case bool:
|
||||
return strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||
case nil:
|
||||
return "null"
|
||||
case TextEncoder:
|
||||
var s string
|
||||
s, err = arg.EncodeText()
|
||||
return s
|
||||
default:
|
||||
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
||||
return
|
||||
}
|
|
@ -1,67 +0,0 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQuoteString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if pgx.QuoteString("test") != "'test'" {
|
||||
t.Error("Failed to quote string")
|
||||
}
|
||||
|
||||
if pgx.QuoteString("Jack's") != "'Jack''s'" {
|
||||
t.Error("Failed to quote and escape string with embedded quote")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSql(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
successTests := []struct {
|
||||
sql string
|
||||
args []interface{}
|
||||
output string
|
||||
}{
|
||||
{"select $1", []interface{}{nil}, "select null"},
|
||||
{"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"},
|
||||
{"select $1", []interface{}{42}, "select 42"},
|
||||
{"select $1", []interface{}{1.23}, "select 1.23"},
|
||||
{"select $1", []interface{}{true}, "select true"},
|
||||
{"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"},
|
||||
{"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`},
|
||||
{"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"},
|
||||
}
|
||||
|
||||
for i, tt := range successTests {
|
||||
san, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args)
|
||||
}
|
||||
if san != tt.output {
|
||||
t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
sql string
|
||||
args []interface{}
|
||||
err string
|
||||
}{
|
||||
{"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"},
|
||||
{"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
_, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
||||
if err == nil {
|
||||
t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,6 +6,7 @@ import (
|
|||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
@ -29,10 +30,26 @@ const (
|
|||
BinaryFormatCode = 1
|
||||
)
|
||||
|
||||
type SerializationError string
|
||||
|
||||
func (e SerializationError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
type Scanner interface {
|
||||
Scan(qr *QueryResult, fd *FieldDescription, size int32) error
|
||||
}
|
||||
|
||||
// TextEncoder is an interface used to encode values in text format for
|
||||
// transmission to the PostgreSQL server. It is used by unprepared
|
||||
// queries and for prepared queries when the type does not implement
|
||||
// BinaryEncoder
|
||||
type TextEncoder interface {
|
||||
// EncodeText MUST sanitize (and quote, if necessary) the returned string.
|
||||
// It will be interpolated directly into the SQL string.
|
||||
EncodeText() (string, error)
|
||||
}
|
||||
|
||||
// BinaryEncoder is an interface used to encode values in binary format for
|
||||
// transmission to the PostgreSQL server. It is used by prepared queries.
|
||||
type BinaryEncoder interface {
|
||||
|
@ -46,6 +63,86 @@ type NullInt64 struct {
|
|||
Valid bool // Valid is true if Int64 is not NULL
|
||||
}
|
||||
|
||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||
|
||||
// QuoteString escapes and quotes a string making it safe for interpolation
|
||||
// into an SQL string.
|
||||
func QuoteString(input string) (output string) {
|
||||
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
||||
return
|
||||
}
|
||||
|
||||
// QuoteIdentifier escapes and quotes an identifier making it safe for
|
||||
// interpolation into an SQL string
|
||||
func QuoteIdentifier(input string) (output string) {
|
||||
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
|
||||
return
|
||||
}
|
||||
|
||||
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
|
||||
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
|
||||
// appropriate.
|
||||
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
||||
replacer := func(match string) (replacement string) {
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
||||
if int(n-1) >= len(args) {
|
||||
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
|
||||
return
|
||||
}
|
||||
|
||||
switch arg := args[n-1].(type) {
|
||||
case string:
|
||||
return QuoteString(arg)
|
||||
case int:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(int64(arg), 10)
|
||||
case time.Time:
|
||||
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(arg), 10)
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case bool:
|
||||
return strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||
case nil:
|
||||
return "null"
|
||||
case TextEncoder:
|
||||
var s string
|
||||
s, err = arg.EncodeText()
|
||||
return s
|
||||
default:
|
||||
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
||||
return
|
||||
}
|
||||
|
||||
func (n *NullInt64) Scan(qr *QueryResult, fd *FieldDescription, size int32) error {
|
||||
if size == -1 {
|
||||
n.Int64, n.Valid = 0, false
|
|
@ -1,10 +1,72 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestQuoteString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if pgx.QuoteString("test") != "'test'" {
|
||||
t.Error("Failed to quote string")
|
||||
}
|
||||
|
||||
if pgx.QuoteString("Jack's") != "'Jack''s'" {
|
||||
t.Error("Failed to quote and escape string with embedded quote")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSql(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
successTests := []struct {
|
||||
sql string
|
||||
args []interface{}
|
||||
output string
|
||||
}{
|
||||
{"select $1", []interface{}{nil}, "select null"},
|
||||
{"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"},
|
||||
{"select $1", []interface{}{42}, "select 42"},
|
||||
{"select $1", []interface{}{1.23}, "select 1.23"},
|
||||
{"select $1", []interface{}{true}, "select true"},
|
||||
{"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"},
|
||||
{"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`},
|
||||
{"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"},
|
||||
}
|
||||
|
||||
for i, tt := range successTests {
|
||||
san, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args)
|
||||
}
|
||||
if san != tt.output {
|
||||
t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
sql string
|
||||
args []interface{}
|
||||
err string
|
||||
}{
|
||||
{"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"},
|
||||
{"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
_, err := pgx.SanitizeSql(tt.sql, tt.args...)
|
||||
if err == nil {
|
||||
t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
Loading…
Reference in New Issue