mirror of https://github.com/jackc/pgx.git
107 lines
2.9 KiB
Go
107 lines
2.9 KiB
Go
package pgx
|
|
|
|
import (
|
|
"encoding/hex"
|
|
"fmt"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// SerializationError reports that a Go value could not be converted into an
// SQL literal. The string value itself is the error message.
type SerializationError string

// Error implements the error interface by returning the stored message as-is.
func (e SerializationError) Error() string {
	msg := string(e)
	return msg
}
|
|
|
|
// TextEncoder is implemented by values that can render themselves in the
// PostgreSQL text format for transmission to the server. It is used by
// unprepared queries, and by prepared queries when the type does not
// implement BinaryEncoder.
type TextEncoder interface {
	// EncodeText returns the value as SQL text. The result is interpolated
	// verbatim into the SQL string, so implementations MUST sanitize it and
	// add any quoting the value requires.
	EncodeText() (string, error)
}
|
|
|
|
// literalPattern matches positional placeholders: a '$' followed by one or
// more ASCII digits, e.g. $1 or $12. The type is inferred from
// regexp.MustCompile, so it is not repeated in the declaration.
var literalPattern = regexp.MustCompile(`\$\d+`)
|
|
|
|
// QuoteString returns input as a single-quoted SQL string literal, doubling
// every embedded single quote so the value cannot end the literal early. The
// result is intended for interpolation into an SQL string.
//
// NOTE(review): backslashes are passed through unchanged, so this relies on
// the server treating backslashes literally inside '...' strings
// (standard_conforming_strings = on) — confirm for the target servers.
func QuoteString(input string) (output string) {
	const quote = "'"
	escaped := strings.Replace(input, quote, quote+quote, -1)
	output = quote + escaped + quote
	return output
}
|
|
|
|
// QuoteIdentifier returns input as a double-quoted SQL identifier, doubling
// every embedded double quote so the whole value is read as a single
// identifier when it is interpolated into an SQL string.
func QuoteIdentifier(input string) (output string) {
	const quote = `"`
	escaped := strings.Replace(input, quote, quote+quote, -1)
	output = quote + escaped + quote
	return output
}
|
|
|
|
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
|
|
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
|
|
// appropriate.
|
|
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
|
replacer := func(match string) (replacement string) {
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
|
if int(n-1) >= len(args) {
|
|
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
|
|
return
|
|
}
|
|
|
|
switch arg := args[n-1].(type) {
|
|
case string:
|
|
return QuoteString(arg)
|
|
case int:
|
|
return strconv.FormatInt(int64(arg), 10)
|
|
case int8:
|
|
return strconv.FormatInt(int64(arg), 10)
|
|
case int16:
|
|
return strconv.FormatInt(int64(arg), 10)
|
|
case int32:
|
|
return strconv.FormatInt(int64(arg), 10)
|
|
case int64:
|
|
return strconv.FormatInt(int64(arg), 10)
|
|
case time.Time:
|
|
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
|
|
case uint:
|
|
return strconv.FormatUint(uint64(arg), 10)
|
|
case uint8:
|
|
return strconv.FormatUint(uint64(arg), 10)
|
|
case uint16:
|
|
return strconv.FormatUint(uint64(arg), 10)
|
|
case uint32:
|
|
return strconv.FormatUint(uint64(arg), 10)
|
|
case uint64:
|
|
return strconv.FormatUint(uint64(arg), 10)
|
|
case float32:
|
|
return strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
|
case float64:
|
|
return strconv.FormatFloat(arg, 'f', -1, 64)
|
|
case bool:
|
|
return strconv.FormatBool(arg)
|
|
case []byte:
|
|
return `E'\\x` + hex.EncodeToString(arg) + `'`
|
|
case nil:
|
|
return "null"
|
|
case TextEncoder:
|
|
var s string
|
|
s, err = arg.EncodeText()
|
|
return s
|
|
default:
|
|
err = SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
|
return ""
|
|
}
|
|
}
|
|
|
|
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
|
return
|
|
}
|