mirror of https://github.com/VinGarcia/ksql.git
Refactor dialect.go so its easier to add new dialects
parent
398f7f43d7
commit
3a90b03a37
40
dialect.go
40
dialect.go
|
@ -2,13 +2,32 @@ package ksql
|
||||||
|
|
||||||
import "strconv"
|
import "strconv"
|
||||||
|
|
||||||
|
type insertMethod int
|
||||||
|
|
||||||
|
const (
|
||||||
|
insertWithReturning insertMethod = iota
|
||||||
|
insertWithLastInsertID
|
||||||
|
insertWithNoIDRetrieval
|
||||||
|
)
|
||||||
|
|
||||||
|
var supportedDialects = map[string]dialect{
|
||||||
|
"postgres": &postgresDialect{},
|
||||||
|
"sqlite3": &sqlite3Dialect{},
|
||||||
|
// "mysql": &mysqlDialect{},
|
||||||
|
}
|
||||||
|
|
||||||
type dialect interface {
|
type dialect interface {
|
||||||
|
InsertMethod() insertMethod
|
||||||
Escape(str string) string
|
Escape(str string) string
|
||||||
Placeholder(idx int) string
|
Placeholder(idx int) string
|
||||||
}
|
}
|
||||||
|
|
||||||
type postgresDialect struct{}
|
type postgresDialect struct{}
|
||||||
|
|
||||||
|
func (postgresDialect) InsertMethod() insertMethod {
|
||||||
|
return insertWithReturning
|
||||||
|
}
|
||||||
|
|
||||||
func (postgresDialect) Escape(str string) string {
|
func (postgresDialect) Escape(str string) string {
|
||||||
return `"` + str + `"`
|
return `"` + str + `"`
|
||||||
}
|
}
|
||||||
|
@ -19,6 +38,10 @@ func (postgresDialect) Placeholder(idx int) string {
|
||||||
|
|
||||||
type sqlite3Dialect struct{}
|
type sqlite3Dialect struct{}
|
||||||
|
|
||||||
|
func (sqlite3Dialect) InsertMethod() insertMethod {
|
||||||
|
return insertWithLastInsertID
|
||||||
|
}
|
||||||
|
|
||||||
func (sqlite3Dialect) Escape(str string) string {
|
func (sqlite3Dialect) Escape(str string) string {
|
||||||
return "`" + str + "`"
|
return "`" + str + "`"
|
||||||
}
|
}
|
||||||
|
@ -27,9 +50,16 @@ func (sqlite3Dialect) Placeholder(idx int) string {
|
||||||
return "?"
|
return "?"
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDriverDialect(driver string) dialect {
|
type mysqlDialect struct{}
|
||||||
return map[string]dialect{
|
|
||||||
"postgres": &postgresDialect{},
|
func (mysqlDialect) InsertMethod() insertMethod {
|
||||||
"sqlite3": &sqlite3Dialect{},
|
return insertWithLastInsertID
|
||||||
}[driver]
|
}
|
||||||
|
|
||||||
|
func (mysqlDialect) Escape(str string) string {
|
||||||
|
return "`" + str + "`"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mysqlDialect) Placeholder(idx int) string {
|
||||||
|
return "?"
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,3 +15,11 @@ services:
|
||||||
environment:
|
environment:
|
||||||
- POSTGRES_USER=postgres
|
- POSTGRES_USER=postgres
|
||||||
- POSTGRES_PASSWORD=postgres
|
- POSTGRES_PASSWORD=postgres
|
||||||
|
|
||||||
|
mysql:
|
||||||
|
image: mysql
|
||||||
|
restart: always
|
||||||
|
ports:
|
||||||
|
- "127.0.0.1:3306:3306"
|
||||||
|
environment:
|
||||||
|
MYSQL_ROOT_PASSWORD: mysql
|
||||||
|
|
24
ksql.go
24
ksql.go
|
@ -32,14 +32,6 @@ type sqlProvider interface {
|
||||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type insertMethod int
|
|
||||||
|
|
||||||
const (
|
|
||||||
insertWithReturning insertMethod = iota
|
|
||||||
insertWithLastInsertID
|
|
||||||
insertWithNoIDRetrieval
|
|
||||||
)
|
|
||||||
|
|
||||||
// Config describes the optional arguments accepted
|
// Config describes the optional arguments accepted
|
||||||
// by the ksql.New() function.
|
// by the ksql.New() function.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -75,7 +67,7 @@ func New(
|
||||||
|
|
||||||
db.SetMaxOpenConns(config.MaxOpenConns)
|
db.SetMaxOpenConns(config.MaxOpenConns)
|
||||||
|
|
||||||
dialect := getDriverDialect(dbDriver)
|
dialect := supportedDialects[dbDriver]
|
||||||
if dialect == nil {
|
if dialect == nil {
|
||||||
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
|
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
|
||||||
}
|
}
|
||||||
|
@ -84,17 +76,9 @@ func New(
|
||||||
config.IDColumns = []string{"id"}
|
config.IDColumns = []string{"id"}
|
||||||
}
|
}
|
||||||
|
|
||||||
var insertMethod insertMethod
|
insertMethod := dialect.InsertMethod()
|
||||||
switch dbDriver {
|
if len(config.IDColumns) > 1 && insertMethod == insertWithLastInsertID {
|
||||||
case "sqlite3":
|
insertMethod = insertWithNoIDRetrieval
|
||||||
insertMethod = insertWithLastInsertID
|
|
||||||
if len(config.IDColumns) > 1 {
|
|
||||||
insertMethod = insertWithNoIDRetrieval
|
|
||||||
}
|
|
||||||
case "postgres":
|
|
||||||
insertMethod = insertWithReturning
|
|
||||||
default:
|
|
||||||
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return DB{
|
return DB{
|
||||||
|
|
16
ksql_test.go
16
ksql_test.go
|
@ -33,7 +33,7 @@ type Address struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQuery(t *testing.T) {
|
func TestQuery(t *testing.T) {
|
||||||
for _, driver := range []string{"sqlite3", "postgres"} {
|
for driver := range supportedDialects {
|
||||||
t.Run(driver, func(t *testing.T) {
|
t.Run(driver, func(t *testing.T) {
|
||||||
t.Run("using slice of structs", func(t *testing.T) {
|
t.Run("using slice of structs", func(t *testing.T) {
|
||||||
err := createTable(driver)
|
err := createTable(driver)
|
||||||
|
@ -222,7 +222,7 @@ func TestQuery(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryOne(t *testing.T) {
|
func TestQueryOne(t *testing.T) {
|
||||||
for _, driver := range []string{"sqlite3", "postgres"} {
|
for driver := range supportedDialects {
|
||||||
t.Run(driver, func(t *testing.T) {
|
t.Run(driver, func(t *testing.T) {
|
||||||
err := createTable(driver)
|
err := createTable(driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -318,7 +318,7 @@ func TestQueryOne(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInsert(t *testing.T) {
|
func TestInsert(t *testing.T) {
|
||||||
for _, driver := range []string{"sqlite3", "postgres"} {
|
for driver := range supportedDialects {
|
||||||
t.Run(driver, func(t *testing.T) {
|
t.Run(driver, func(t *testing.T) {
|
||||||
t.Run("using slice of structs", func(t *testing.T) {
|
t.Run("using slice of structs", func(t *testing.T) {
|
||||||
err := createTable(driver)
|
err := createTable(driver)
|
||||||
|
@ -446,7 +446,7 @@ func TestInsert(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDelete(t *testing.T) {
|
func TestDelete(t *testing.T) {
|
||||||
for _, driver := range []string{"sqlite3", "postgres"} {
|
for driver := range supportedDialects {
|
||||||
t.Run(driver, func(t *testing.T) {
|
t.Run(driver, func(t *testing.T) {
|
||||||
err := createTable(driver)
|
err := createTable(driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -589,7 +589,7 @@ func TestDelete(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdate(t *testing.T) {
|
func TestUpdate(t *testing.T) {
|
||||||
for _, driver := range []string{"sqlite3", "postgres"} {
|
for driver := range supportedDialects {
|
||||||
t.Run(driver, func(t *testing.T) {
|
t.Run(driver, func(t *testing.T) {
|
||||||
err := createTable(driver)
|
err := createTable(driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -758,7 +758,7 @@ func TestUpdate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryChunks(t *testing.T) {
|
func TestQueryChunks(t *testing.T) {
|
||||||
for _, driver := range []string{"sqlite3", "postgres"} {
|
for driver := range supportedDialects {
|
||||||
t.Run(driver, func(t *testing.T) {
|
t.Run(driver, func(t *testing.T) {
|
||||||
t.Run("should query a single row correctly", func(t *testing.T) {
|
t.Run("should query a single row correctly", func(t *testing.T) {
|
||||||
err := createTable(driver)
|
err := createTable(driver)
|
||||||
|
@ -1156,7 +1156,7 @@ func TestQueryChunks(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransaction(t *testing.T) {
|
func TestTransaction(t *testing.T) {
|
||||||
for _, driver := range []string{"sqlite3", "postgres"} {
|
for driver := range supportedDialects {
|
||||||
t.Run(driver, func(t *testing.T) {
|
t.Run(driver, func(t *testing.T) {
|
||||||
t.Run("should query a single row correctly", func(t *testing.T) {
|
t.Run("should query a single row correctly", func(t *testing.T) {
|
||||||
err := createTable(driver)
|
err := createTable(driver)
|
||||||
|
@ -1391,7 +1391,7 @@ func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB {
|
||||||
|
|
||||||
return DB{
|
return DB{
|
||||||
driver: driver,
|
driver: driver,
|
||||||
dialect: getDriverDialect(driver),
|
dialect: supportedDialects[driver],
|
||||||
db: db,
|
db: db,
|
||||||
tableName: tableName,
|
tableName: tableName,
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue