Add feature of using string placeholders (%s) instead of the db specific ones

pull/7/head
Vinícius Garcia 2021-10-11 19:20:28 -03:00
parent d280eb1eb1
commit 7d06b3dfe8
6 changed files with 126 additions and 30 deletions

View File

@ -136,6 +136,14 @@ func main() {
ctx := context.Background()
db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{
MaxOpenConns: 1,
// UseGolangPlaceholders allows you to use the same placeholder `%s`
// for all databases which is useful if you want your code to work in
// different platforms.
//
// Ignore or set this argument to false if you prefer
// using the database specific placeholders like `$1`, `?` or `@p1`
UseGolangPlaceholders: true,
})
if err != nil {
panic(err.Error())
@ -186,7 +194,7 @@ func main() {
// Retrieving Cristina:
var cris User
err = db.QueryOne(ctx, &cris, "SELECT * FROM users WHERE name = ? ORDER BY id", "Cristina")
err = db.QueryOne(ctx, &cris, "SELECT * FROM users WHERE name = %s ORDER BY id", "Cristina")
if err != nil {
panic(err.Error())
}
@ -234,7 +242,7 @@ func main() {
// Making transactions:
err = db.Transaction(ctx, func(db ksql.Provider) error {
var cris2 User
err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = ?", cris.ID)
err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = %s", cris.ID)
if err != nil {
// This will cause an automatic rollback:
return err

View File

@ -21,6 +21,21 @@ var supportedDialects = map[string]Dialect{
"sqlserver": &sqlserverDialect{},
}
// GetDriverDialect instantiantes the dialect for the
// provided driver string, if the drive is not supported
// it returns an error
func GetDriverDialect(driver string) (Dialect, error) {
dialect, found := map[string]Dialect{
"postgres": &postgresDialect{},
"sqlite3": &sqlite3Dialect{},
}[driver]
if !found {
return nil, fmt.Errorf("unsupported driver `%s`", driver)
}
return dialect, nil
}
// Dialect is used to represent the different ways
// of writing SQL queries used by each SQL driver.
type Dialect interface {
@ -66,21 +81,6 @@ func (sqlite3Dialect) Placeholder(idx int) string {
return "?"
}
// GetDriverDialect instantiantes the dialect for the
// provided driver string, if the drive is not supported
// it returns an error
func GetDriverDialect(driver string) (Dialect, error) {
dialect, found := map[string]Dialect{
"postgres": &postgresDialect{},
"sqlite3": &sqlite3Dialect{},
}[driver]
if !found {
return nil, fmt.Errorf("unsupported driver `%s`", driver)
}
return dialect, nil
}
type mysqlDialect struct{}
func (mysqlDialect) DriverName() string {

View File

@ -41,6 +41,14 @@ func main() {
ctx := context.Background()
db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{
MaxOpenConns: 1,
// UseGolangPlaceholders allows you to use the same placeholder `%s`
// for all databases which is useful if you want your code to work in
// different platforms.
//
// Ignore or set this argument to false if you prefer
// using the database specific placeholders like `$1`, `?` or `@p1`
UseGolangPlaceholders: true,
})
if err != nil {
panic(err.Error())
@ -91,7 +99,7 @@ func main() {
// Retrieving Cristina:
var cris User
err = db.QueryOne(ctx, &cris, "SELECT * FROM users WHERE name = ? ORDER BY id", "Cristina")
err = db.QueryOne(ctx, &cris, "SELECT * FROM users WHERE name = %s ORDER BY id", "Cristina")
if err != nil {
panic(err.Error())
}
@ -139,7 +147,7 @@ func main() {
// Making transactions:
err = db.Transaction(ctx, func(db ksql.Provider) error {
var cris2 User
err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = ?", cris.ID)
err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = %s", cris.ID)
if err != nil {
// This will cause an automatic rollback:
return err

38
ksql.go
View File

@ -74,6 +74,11 @@ type Tx interface {
type Config struct {
// MaxOpenCons defaults to 1 if not set
MaxOpenConns int
// If set to true, it will expect the queries to use %s as placeholder for the params
// instead of `$1`, `?` or `@p1`, this is not set by default because it is an
// extra `fmt.Sprintf()` operation that not all users might want.
UseGolangPlaceholders bool
}
// New instantiates a new KissSQL client
@ -81,13 +86,13 @@ func New(
dbDriver string,
connectionString string,
config Config,
) (DB, error) {
) (Provider, error) {
db, err := sql.Open(dbDriver, connectionString)
if err != nil {
return DB{}, err
return nil, err
}
if err = db.Ping(); err != nil {
return DB{}, err
return nil, err
}
if config.MaxOpenConns == 0 {
@ -96,7 +101,17 @@ func New(
db.SetMaxOpenConns(config.MaxOpenConns)
return NewWithAdapter(SQLAdapter{db}, dbDriver)
kdb, err := NewWithAdapter(SQLAdapter{db}, dbDriver)
if err != nil {
return nil, err
}
if config.UseGolangPlaceholders {
dialect, _ := GetDriverDialect("postgres")
kdb = newPlaceholderAdapter(kdb, dialect)
}
return kdb, nil
}
// NewWithPGX instantiates a new KissSQL client using the pgx
@ -105,7 +120,7 @@ func NewWithPGX(
ctx context.Context,
connectionString string,
config Config,
) (db DB, err error) {
) (db Provider, err error) {
pgxConf, err := pgxpool.ParseConfig(connectionString)
if err != nil {
return DB{}, err
@ -122,7 +137,16 @@ func NewWithPGX(
}
db, err = NewWithAdapter(PGXAdapter{pool}, "postgres")
return db, err
if err != nil {
return nil, err
}
if config.UseGolangPlaceholders {
dialect, _ := GetDriverDialect("postgres")
db = newPlaceholderAdapter(db, dialect)
}
return db, nil
}
// NewWithAdapter allows the user to insert a custom implementation
@ -130,7 +154,7 @@ func NewWithPGX(
func NewWithAdapter(
db DBAdapter,
dbDriver string,
) (DB, error) {
) (Provider, error) {
dialect := supportedDialects[dbDriver]
if dialect == nil {
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)

View File

@ -675,13 +675,14 @@ func TestInsert(t *testing.T) {
return
}
ctx := context.Background()
// Using columns "id" and "name" as IDs:
table := NewTable("users", "id", "name")
c, err := New(config.driver, connectionString[config.driver], Config{})
assert.Equal(t, nil, err)
db, closer := connectDB(t, config)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, config.driver)
u := User{
Name: "No ID returned",

55
placeholder_adapter.go Normal file
View File

@ -0,0 +1,55 @@
package ksql
import (
"context"
"fmt"
)
type placeholderAdapter struct {
Provider
dialect Dialect
}
func newPlaceholderAdapter(db Provider, dialect Dialect) placeholderAdapter {
return placeholderAdapter{
Provider: db,
dialect: dialect,
}
}
func (p placeholderAdapter) Query(ctx context.Context, records interface{}, queryFormat string, params ...interface{}) error {
query := fmt.Sprintf(queryFormat, buildPlaceholderList(p.dialect, len(params))...)
return p.Provider.Query(ctx, records, query, params...)
}
func (p placeholderAdapter) QueryOne(ctx context.Context, record interface{}, queryFormat string, params ...interface{}) error {
query := fmt.Sprintf(queryFormat, buildPlaceholderList(p.dialect, len(params))...)
return p.Provider.QueryOne(ctx, record, query, params...)
}
func (p placeholderAdapter) QueryChunks(ctx context.Context, parser ChunkParser) error {
parser.Query = fmt.Sprintf(parser.Query, buildPlaceholderList(p.dialect, len(parser.Params))...)
return p.Provider.QueryChunks(ctx, parser)
}
func (p placeholderAdapter) Exec(ctx context.Context, queryFormat string, params ...interface{}) error {
query := fmt.Sprintf(queryFormat, buildPlaceholderList(p.dialect, len(params))...)
return p.Provider.Exec(ctx, query, params...)
}
func (p placeholderAdapter) Transaction(ctx context.Context, fn func(Provider) error) error {
return p.Provider.Transaction(ctx, func(db Provider) error {
db = newPlaceholderAdapter(db, p.dialect)
return fn(db)
})
}
func buildPlaceholderList(dialect Dialect, numElements int) []interface{} {
placeholders := []interface{}{}
for i := 0; i < numElements; i++ {
placeholders = append(placeholders, dialect.Placeholder(i))
}
return placeholders
}