mirror of https://github.com/VinGarcia/ksql.git
Add feature of using string placeholders (%s) instead of the db specific ones
parent
d280eb1eb1
commit
7d06b3dfe8
12
README.md
12
README.md
|
@ -136,6 +136,14 @@ func main() {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{
|
db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{
|
||||||
MaxOpenConns: 1,
|
MaxOpenConns: 1,
|
||||||
|
|
||||||
|
// UseGolangPlaceholders allows you to use the same placeholder `%s`
|
||||||
|
// for all databases which is useful if you want your code to work in
|
||||||
|
// different platforms.
|
||||||
|
//
|
||||||
|
// Ignore or set this argument to false if you prefer
|
||||||
|
// using the database specific placeholders like `$1`, `?` or `@p1`
|
||||||
|
UseGolangPlaceholders: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err.Error())
|
panic(err.Error())
|
||||||
|
@ -186,7 +194,7 @@ func main() {
|
||||||
|
|
||||||
// Retrieving Cristina:
|
// Retrieving Cristina:
|
||||||
var cris User
|
var cris User
|
||||||
err = db.QueryOne(ctx, &cris, "SELECT * FROM users WHERE name = ? ORDER BY id", "Cristina")
|
err = db.QueryOne(ctx, &cris, "SELECT * FROM users WHERE name = %s ORDER BY id", "Cristina")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err.Error())
|
panic(err.Error())
|
||||||
}
|
}
|
||||||
|
@ -234,7 +242,7 @@ func main() {
|
||||||
// Making transactions:
|
// Making transactions:
|
||||||
err = db.Transaction(ctx, func(db ksql.Provider) error {
|
err = db.Transaction(ctx, func(db ksql.Provider) error {
|
||||||
var cris2 User
|
var cris2 User
|
||||||
err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = ?", cris.ID)
|
err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = %s", cris.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// This will cause an automatic rollback:
|
// This will cause an automatic rollback:
|
||||||
return err
|
return err
|
||||||
|
|
30
dialect.go
30
dialect.go
|
@ -21,6 +21,21 @@ var supportedDialects = map[string]Dialect{
|
||||||
"sqlserver": &sqlserverDialect{},
|
"sqlserver": &sqlserverDialect{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDriverDialect instantiantes the dialect for the
|
||||||
|
// provided driver string, if the drive is not supported
|
||||||
|
// it returns an error
|
||||||
|
func GetDriverDialect(driver string) (Dialect, error) {
|
||||||
|
dialect, found := map[string]Dialect{
|
||||||
|
"postgres": &postgresDialect{},
|
||||||
|
"sqlite3": &sqlite3Dialect{},
|
||||||
|
}[driver]
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("unsupported driver `%s`", driver)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dialect, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Dialect is used to represent the different ways
|
// Dialect is used to represent the different ways
|
||||||
// of writing SQL queries used by each SQL driver.
|
// of writing SQL queries used by each SQL driver.
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
|
@ -66,21 +81,6 @@ func (sqlite3Dialect) Placeholder(idx int) string {
|
||||||
return "?"
|
return "?"
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDriverDialect instantiantes the dialect for the
|
|
||||||
// provided driver string, if the drive is not supported
|
|
||||||
// it returns an error
|
|
||||||
func GetDriverDialect(driver string) (Dialect, error) {
|
|
||||||
dialect, found := map[string]Dialect{
|
|
||||||
"postgres": &postgresDialect{},
|
|
||||||
"sqlite3": &sqlite3Dialect{},
|
|
||||||
}[driver]
|
|
||||||
if !found {
|
|
||||||
return nil, fmt.Errorf("unsupported driver `%s`", driver)
|
|
||||||
}
|
|
||||||
|
|
||||||
return dialect, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mysqlDialect struct{}
|
type mysqlDialect struct{}
|
||||||
|
|
||||||
func (mysqlDialect) DriverName() string {
|
func (mysqlDialect) DriverName() string {
|
||||||
|
|
|
@ -41,6 +41,14 @@ func main() {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{
|
db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{
|
||||||
MaxOpenConns: 1,
|
MaxOpenConns: 1,
|
||||||
|
|
||||||
|
// UseGolangPlaceholders allows you to use the same placeholder `%s`
|
||||||
|
// for all databases which is useful if you want your code to work in
|
||||||
|
// different platforms.
|
||||||
|
//
|
||||||
|
// Ignore or set this argument to false if you prefer
|
||||||
|
// using the database specific placeholders like `$1`, `?` or `@p1`
|
||||||
|
UseGolangPlaceholders: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err.Error())
|
panic(err.Error())
|
||||||
|
@ -91,7 +99,7 @@ func main() {
|
||||||
|
|
||||||
// Retrieving Cristina:
|
// Retrieving Cristina:
|
||||||
var cris User
|
var cris User
|
||||||
err = db.QueryOne(ctx, &cris, "SELECT * FROM users WHERE name = ? ORDER BY id", "Cristina")
|
err = db.QueryOne(ctx, &cris, "SELECT * FROM users WHERE name = %s ORDER BY id", "Cristina")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err.Error())
|
panic(err.Error())
|
||||||
}
|
}
|
||||||
|
@ -139,7 +147,7 @@ func main() {
|
||||||
// Making transactions:
|
// Making transactions:
|
||||||
err = db.Transaction(ctx, func(db ksql.Provider) error {
|
err = db.Transaction(ctx, func(db ksql.Provider) error {
|
||||||
var cris2 User
|
var cris2 User
|
||||||
err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = ?", cris.ID)
|
err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = %s", cris.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// This will cause an automatic rollback:
|
// This will cause an automatic rollback:
|
||||||
return err
|
return err
|
||||||
|
|
38
ksql.go
38
ksql.go
|
@ -74,6 +74,11 @@ type Tx interface {
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// MaxOpenCons defaults to 1 if not set
|
// MaxOpenCons defaults to 1 if not set
|
||||||
MaxOpenConns int
|
MaxOpenConns int
|
||||||
|
|
||||||
|
// If set to true, it will expect the queries to use %s as placeholder for the params
|
||||||
|
// instead of `$1`, `?` or `@p1`, this is not set by default because it is an
|
||||||
|
// extra `fmt.Sprintf()` operation that not all users might want.
|
||||||
|
UseGolangPlaceholders bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New instantiates a new KissSQL client
|
// New instantiates a new KissSQL client
|
||||||
|
@ -81,13 +86,13 @@ func New(
|
||||||
dbDriver string,
|
dbDriver string,
|
||||||
connectionString string,
|
connectionString string,
|
||||||
config Config,
|
config Config,
|
||||||
) (DB, error) {
|
) (Provider, error) {
|
||||||
db, err := sql.Open(dbDriver, connectionString)
|
db, err := sql.Open(dbDriver, connectionString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DB{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = db.Ping(); err != nil {
|
if err = db.Ping(); err != nil {
|
||||||
return DB{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.MaxOpenConns == 0 {
|
if config.MaxOpenConns == 0 {
|
||||||
|
@ -96,7 +101,17 @@ func New(
|
||||||
|
|
||||||
db.SetMaxOpenConns(config.MaxOpenConns)
|
db.SetMaxOpenConns(config.MaxOpenConns)
|
||||||
|
|
||||||
return NewWithAdapter(SQLAdapter{db}, dbDriver)
|
kdb, err := NewWithAdapter(SQLAdapter{db}, dbDriver)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.UseGolangPlaceholders {
|
||||||
|
dialect, _ := GetDriverDialect("postgres")
|
||||||
|
kdb = newPlaceholderAdapter(kdb, dialect)
|
||||||
|
}
|
||||||
|
|
||||||
|
return kdb, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWithPGX instantiates a new KissSQL client using the pgx
|
// NewWithPGX instantiates a new KissSQL client using the pgx
|
||||||
|
@ -105,7 +120,7 @@ func NewWithPGX(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
connectionString string,
|
connectionString string,
|
||||||
config Config,
|
config Config,
|
||||||
) (db DB, err error) {
|
) (db Provider, err error) {
|
||||||
pgxConf, err := pgxpool.ParseConfig(connectionString)
|
pgxConf, err := pgxpool.ParseConfig(connectionString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DB{}, err
|
return DB{}, err
|
||||||
|
@ -122,7 +137,16 @@ func NewWithPGX(
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err = NewWithAdapter(PGXAdapter{pool}, "postgres")
|
db, err = NewWithAdapter(PGXAdapter{pool}, "postgres")
|
||||||
return db, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.UseGolangPlaceholders {
|
||||||
|
dialect, _ := GetDriverDialect("postgres")
|
||||||
|
db = newPlaceholderAdapter(db, dialect)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWithAdapter allows the user to insert a custom implementation
|
// NewWithAdapter allows the user to insert a custom implementation
|
||||||
|
@ -130,7 +154,7 @@ func NewWithPGX(
|
||||||
func NewWithAdapter(
|
func NewWithAdapter(
|
||||||
db DBAdapter,
|
db DBAdapter,
|
||||||
dbDriver string,
|
dbDriver string,
|
||||||
) (DB, error) {
|
) (Provider, error) {
|
||||||
dialect := supportedDialects[dbDriver]
|
dialect := supportedDialects[dbDriver]
|
||||||
if dialect == nil {
|
if dialect == nil {
|
||||||
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
|
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
|
||||||
|
|
|
@ -675,13 +675,14 @@ func TestInsert(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Using columns "id" and "name" as IDs:
|
// Using columns "id" and "name" as IDs:
|
||||||
table := NewTable("users", "id", "name")
|
table := NewTable("users", "id", "name")
|
||||||
|
|
||||||
c, err := New(config.driver, connectionString[config.driver], Config{})
|
db, closer := connectDB(t, config)
|
||||||
assert.Equal(t, nil, err)
|
defer closer.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := newTestDB(db, config.driver)
|
||||||
|
|
||||||
u := User{
|
u := User{
|
||||||
Name: "No ID returned",
|
Name: "No ID returned",
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
package ksql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type placeholderAdapter struct {
|
||||||
|
Provider
|
||||||
|
|
||||||
|
dialect Dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPlaceholderAdapter(db Provider, dialect Dialect) placeholderAdapter {
|
||||||
|
return placeholderAdapter{
|
||||||
|
Provider: db,
|
||||||
|
dialect: dialect,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p placeholderAdapter) Query(ctx context.Context, records interface{}, queryFormat string, params ...interface{}) error {
|
||||||
|
query := fmt.Sprintf(queryFormat, buildPlaceholderList(p.dialect, len(params))...)
|
||||||
|
return p.Provider.Query(ctx, records, query, params...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p placeholderAdapter) QueryOne(ctx context.Context, record interface{}, queryFormat string, params ...interface{}) error {
|
||||||
|
query := fmt.Sprintf(queryFormat, buildPlaceholderList(p.dialect, len(params))...)
|
||||||
|
return p.Provider.QueryOne(ctx, record, query, params...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p placeholderAdapter) QueryChunks(ctx context.Context, parser ChunkParser) error {
|
||||||
|
parser.Query = fmt.Sprintf(parser.Query, buildPlaceholderList(p.dialect, len(parser.Params))...)
|
||||||
|
return p.Provider.QueryChunks(ctx, parser)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p placeholderAdapter) Exec(ctx context.Context, queryFormat string, params ...interface{}) error {
|
||||||
|
query := fmt.Sprintf(queryFormat, buildPlaceholderList(p.dialect, len(params))...)
|
||||||
|
return p.Provider.Exec(ctx, query, params...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p placeholderAdapter) Transaction(ctx context.Context, fn func(Provider) error) error {
|
||||||
|
return p.Provider.Transaction(ctx, func(db Provider) error {
|
||||||
|
db = newPlaceholderAdapter(db, p.dialect)
|
||||||
|
return fn(db)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPlaceholderList(dialect Dialect, numElements int) []interface{} {
|
||||||
|
placeholders := []interface{}{}
|
||||||
|
for i := 0; i < numElements; i++ {
|
||||||
|
placeholders = append(placeholders, dialect.Placeholder(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
return placeholders
|
||||||
|
}
|
Loading…
Reference in New Issue