Refactor dialect.go so its easier to add new dialects

pull/2/head
Vinícius Garcia 2021-05-08 11:56:57 -03:00
parent 398f7f43d7
commit 3a90b03a37
4 changed files with 55 additions and 33 deletions

View File

@ -2,13 +2,32 @@ package ksql
import "strconv"
type insertMethod int
const (
insertWithReturning insertMethod = iota
insertWithLastInsertID
insertWithNoIDRetrieval
)
var supportedDialects = map[string]dialect{
"postgres": &postgresDialect{},
"sqlite3": &sqlite3Dialect{},
// "mysql": &mysqlDialect{},
}
type dialect interface {
InsertMethod() insertMethod
Escape(str string) string
Placeholder(idx int) string
}
type postgresDialect struct{}
func (postgresDialect) InsertMethod() insertMethod {
return insertWithReturning
}
func (postgresDialect) Escape(str string) string {
return `"` + str + `"`
}
@ -19,6 +38,10 @@ func (postgresDialect) Placeholder(idx int) string {
type sqlite3Dialect struct{}
func (sqlite3Dialect) InsertMethod() insertMethod {
return insertWithLastInsertID
}
func (sqlite3Dialect) Escape(str string) string {
return "`" + str + "`"
}
@ -27,9 +50,16 @@ func (sqlite3Dialect) Placeholder(idx int) string {
return "?"
}
func getDriverDialect(driver string) dialect {
return map[string]dialect{
"postgres": &postgresDialect{},
"sqlite3": &sqlite3Dialect{},
}[driver]
type mysqlDialect struct{}
func (mysqlDialect) InsertMethod() insertMethod {
return insertWithLastInsertID
}
func (mysqlDialect) Escape(str string) string {
return "`" + str + "`"
}
func (mysqlDialect) Placeholder(idx int) string {
return "?"
}

View File

@ -15,3 +15,11 @@ services:
environment:
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=postgres
mysql:
image: mysql
restart: always
ports:
- "127.0.0.1:3306:3306"
environment:
MYSQL_ROOT_PASSWORD: mysql

22
ksql.go
View File

@ -32,14 +32,6 @@ type sqlProvider interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
type insertMethod int
const (
insertWithReturning insertMethod = iota
insertWithLastInsertID
insertWithNoIDRetrieval
)
// Config describes the optional arguments accepted
// by the ksql.New() function.
type Config struct {
@ -75,7 +67,7 @@ func New(
db.SetMaxOpenConns(config.MaxOpenConns)
dialect := getDriverDialect(dbDriver)
dialect := supportedDialects[dbDriver]
if dialect == nil {
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
}
@ -84,18 +76,10 @@ func New(
config.IDColumns = []string{"id"}
}
var insertMethod insertMethod
switch dbDriver {
case "sqlite3":
insertMethod = insertWithLastInsertID
if len(config.IDColumns) > 1 {
insertMethod := dialect.InsertMethod()
if len(config.IDColumns) > 1 && insertMethod == insertWithLastInsertID {
insertMethod = insertWithNoIDRetrieval
}
case "postgres":
insertMethod = insertWithReturning
default:
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
}
return DB{
dialect: dialect,

View File

@ -33,7 +33,7 @@ type Address struct {
}
func TestQuery(t *testing.T) {
for _, driver := range []string{"sqlite3", "postgres"} {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
t.Run("using slice of structs", func(t *testing.T) {
err := createTable(driver)
@ -222,7 +222,7 @@ func TestQuery(t *testing.T) {
}
func TestQueryOne(t *testing.T) {
for _, driver := range []string{"sqlite3", "postgres"} {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
err := createTable(driver)
if err != nil {
@ -318,7 +318,7 @@ func TestQueryOne(t *testing.T) {
}
func TestInsert(t *testing.T) {
for _, driver := range []string{"sqlite3", "postgres"} {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
t.Run("using slice of structs", func(t *testing.T) {
err := createTable(driver)
@ -446,7 +446,7 @@ func TestInsert(t *testing.T) {
}
func TestDelete(t *testing.T) {
for _, driver := range []string{"sqlite3", "postgres"} {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
err := createTable(driver)
if err != nil {
@ -589,7 +589,7 @@ func TestDelete(t *testing.T) {
}
func TestUpdate(t *testing.T) {
for _, driver := range []string{"sqlite3", "postgres"} {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
err := createTable(driver)
if err != nil {
@ -758,7 +758,7 @@ func TestUpdate(t *testing.T) {
}
func TestQueryChunks(t *testing.T) {
for _, driver := range []string{"sqlite3", "postgres"} {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
t.Run("should query a single row correctly", func(t *testing.T) {
err := createTable(driver)
@ -1156,7 +1156,7 @@ func TestQueryChunks(t *testing.T) {
}
func TestTransaction(t *testing.T) {
for _, driver := range []string{"sqlite3", "postgres"} {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
t.Run("should query a single row correctly", func(t *testing.T) {
err := createTable(driver)
@ -1391,7 +1391,7 @@ func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB {
return DB{
driver: driver,
dialect: getDriverDialect(driver),
dialect: supportedDialects[driver],
db: db,
tableName: tableName,