Unify serialization code in values.go

scan-io
Jack Christensen 2014-07-09 08:05:03 -05:00
parent 009cdfa0b1
commit 89bcc0670c
4 changed files with 159 additions and 173 deletions

View File

@ -1,106 +0,0 @@
package pgx
import (
"encoding/hex"
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
type SerializationError string
func (e SerializationError) Error() string {
return string(e)
}
// TextEncoder is an interface used to encode values in text format for
// transmission to the PostgreSQL server. It is used by unprepared
// queries and for prepared queries when the type does not implement
// BinaryEncoder
type TextEncoder interface {
// EncodeText MUST sanitize (and quote, if necessary) the returned string.
// It will be interpolated directly into the SQL string.
EncodeText() (string, error)
}
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
// QuoteString escapes and quotes a string making it safe for interpolation
// into an SQL string.
func QuoteString(input string) (output string) {
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
return
}
// QuoteIdentifier escapes and quotes an identifier making it safe for
// interpolation into an SQL string
func QuoteIdentifier(input string) (output string) {
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
return
}
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
// appropriate.
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
replacer := func(match string) (replacement string) {
if err != nil {
return ""
}
n, _ := strconv.ParseInt(match[1:], 10, 0)
if int(n-1) >= len(args) {
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
return
}
switch arg := args[n-1].(type) {
case string:
return QuoteString(arg)
case int:
return strconv.FormatInt(int64(arg), 10)
case int8:
return strconv.FormatInt(int64(arg), 10)
case int16:
return strconv.FormatInt(int64(arg), 10)
case int32:
return strconv.FormatInt(int64(arg), 10)
case int64:
return strconv.FormatInt(int64(arg), 10)
case time.Time:
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
case uint:
return strconv.FormatUint(uint64(arg), 10)
case uint8:
return strconv.FormatUint(uint64(arg), 10)
case uint16:
return strconv.FormatUint(uint64(arg), 10)
case uint32:
return strconv.FormatUint(uint64(arg), 10)
case uint64:
return strconv.FormatUint(uint64(arg), 10)
case float32:
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
case float64:
return strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
return strconv.FormatBool(arg)
case []byte:
return `E'\\x` + hex.EncodeToString(arg) + `'`
case nil:
return "null"
case TextEncoder:
var s string
s, err = arg.EncodeText()
return s
default:
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
return ""
}
}
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
return
}

View File

@ -1,67 +0,0 @@
package pgx_test
import (
"github.com/jackc/pgx"
"strings"
"testing"
)
func TestQuoteString(t *testing.T) {
t.Parallel()
if pgx.QuoteString("test") != "'test'" {
t.Error("Failed to quote string")
}
if pgx.QuoteString("Jack's") != "'Jack''s'" {
t.Error("Failed to quote and escape string with embedded quote")
}
}
func TestSanitizeSql(t *testing.T) {
t.Parallel()
successTests := []struct {
sql string
args []interface{}
output string
}{
{"select $1", []interface{}{nil}, "select null"},
{"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"},
{"select $1", []interface{}{42}, "select 42"},
{"select $1", []interface{}{1.23}, "select 1.23"},
{"select $1", []interface{}{true}, "select true"},
{"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"},
{"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`},
{"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"},
}
for i, tt := range successTests {
san, err := pgx.SanitizeSql(tt.sql, tt.args...)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args)
}
if san != tt.output {
t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args)
}
}
errorTests := []struct {
sql string
args []interface{}
err string
}{
{"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"},
{"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"},
}
for i, tt := range errorTests {
_, err := pgx.SanitizeSql(tt.sql, tt.args...)
if err == nil {
t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args)
}
}
}

View File

@ -6,6 +6,7 @@ import (
"math"
"regexp"
"strconv"
"strings"
"time"
"unsafe"
)
@ -29,10 +30,26 @@ const (
BinaryFormatCode = 1
)
type SerializationError string
func (e SerializationError) Error() string {
return string(e)
}
type Scanner interface {
Scan(qr *QueryResult, fd *FieldDescription, size int32) error
}
// TextEncoder is an interface used to encode values in text format for
// transmission to the PostgreSQL server. It is used by unprepared
// queries and for prepared queries when the type does not implement
// BinaryEncoder
type TextEncoder interface {
// EncodeText MUST sanitize (and quote, if necessary) the returned string.
// It will be interpolated directly into the SQL string.
EncodeText() (string, error)
}
// BinaryEncoder is an interface used to encode values in binary format for
// transmission to the PostgreSQL server. It is used by prepared queries.
type BinaryEncoder interface {
@ -46,6 +63,86 @@ type NullInt64 struct {
Valid bool // Valid is true if Int64 is not NULL
}
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
// QuoteString escapes and quotes a string making it safe for interpolation
// into an SQL string.
func QuoteString(input string) (output string) {
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
return
}
// QuoteIdentifier escapes and quotes an identifier making it safe for
// interpolation into an SQL string
func QuoteIdentifier(input string) (output string) {
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
return
}
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
// appropriate.
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
replacer := func(match string) (replacement string) {
if err != nil {
return ""
}
n, _ := strconv.ParseInt(match[1:], 10, 0)
if int(n-1) >= len(args) {
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
return
}
switch arg := args[n-1].(type) {
case string:
return QuoteString(arg)
case int:
return strconv.FormatInt(int64(arg), 10)
case int8:
return strconv.FormatInt(int64(arg), 10)
case int16:
return strconv.FormatInt(int64(arg), 10)
case int32:
return strconv.FormatInt(int64(arg), 10)
case int64:
return strconv.FormatInt(int64(arg), 10)
case time.Time:
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
case uint:
return strconv.FormatUint(uint64(arg), 10)
case uint8:
return strconv.FormatUint(uint64(arg), 10)
case uint16:
return strconv.FormatUint(uint64(arg), 10)
case uint32:
return strconv.FormatUint(uint64(arg), 10)
case uint64:
return strconv.FormatUint(uint64(arg), 10)
case float32:
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
case float64:
return strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
return strconv.FormatBool(arg)
case []byte:
return `E'\\x` + hex.EncodeToString(arg) + `'`
case nil:
return "null"
case TextEncoder:
var s string
s, err = arg.EncodeText()
return s
default:
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
return ""
}
}
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
return
}
func (n *NullInt64) Scan(qr *QueryResult, fd *FieldDescription, size int32) error {
if size == -1 {
n.Int64, n.Valid = 0, false

View File

@ -1,10 +1,72 @@
package pgx_test
import (
"github.com/jackc/pgx"
"strings"
"testing"
"time"
)
func TestQuoteString(t *testing.T) {
t.Parallel()
if pgx.QuoteString("test") != "'test'" {
t.Error("Failed to quote string")
}
if pgx.QuoteString("Jack's") != "'Jack''s'" {
t.Error("Failed to quote and escape string with embedded quote")
}
}
func TestSanitizeSql(t *testing.T) {
t.Parallel()
successTests := []struct {
sql string
args []interface{}
output string
}{
{"select $1", []interface{}{nil}, "select null"},
{"select $1", []interface{}{"Jack's"}, "select 'Jack''s'"},
{"select $1", []interface{}{42}, "select 42"},
{"select $1", []interface{}{1.23}, "select 1.23"},
{"select $1", []interface{}{true}, "select true"},
{"select $1, $2, $3", []interface{}{"Jack's", 42, 1.23}, "select 'Jack''s', 42, 1.23"},
{"select $1", []interface{}{[]byte{0, 15, 255, 17}}, `select E'\\x000fff11'`},
{"select $1", []interface{}{&pgx.NullInt64{Int64: 1, Valid: true}}, "select 1"},
}
for i, tt := range successTests {
san, err := pgx.SanitizeSql(tt.sql, tt.args...)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v, args -> %v)", i, err, tt.sql, tt.args)
}
if san != tt.output {
t.Errorf("%d. Expected %v, got %v (sql -> %v, args -> %v)", i, tt.output, san, tt.sql, tt.args)
}
}
errorTests := []struct {
sql string
args []interface{}
err string
}{
{"select $1", []interface{}{t}, "is not a core type and it does not implement TextEncoder"},
{"select $1, $2", []interface{}{}, "Cannot interpolate $1, only 0 arguments provided"},
}
for i, tt := range errorTests {
_, err := pgx.SanitizeSql(tt.sql, tt.args...)
if err == nil {
t.Errorf("%d. Unexpected success (sql -> %v, args -> %v)", i, tt.sql, tt.args, err)
}
if !strings.Contains(err.Error(), tt.err) {
t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, args -> %v)", i, tt.err, err, tt.sql, tt.args)
}
}
}
func TestEncodeError(t *testing.T) {
t.Parallel()