From 3a90b03a37d88584e7c53d1cfa7139424ae3ed1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sat, 8 May 2021 11:56:57 -0300 Subject: [PATCH] Refactor dialect.go so its easier to add new dialects --- dialect.go | 40 +++++++++++++++++++++++++++++++++++----- docker-compose.yml | 8 ++++++++ ksql.go | 24 ++++-------------------- ksql_test.go | 16 ++++++++-------- 4 files changed, 55 insertions(+), 33 deletions(-) diff --git a/dialect.go b/dialect.go index 5522fb8..6640c34 100644 --- a/dialect.go +++ b/dialect.go @@ -2,13 +2,32 @@ package ksql import "strconv" +type insertMethod int + +const ( + insertWithReturning insertMethod = iota + insertWithLastInsertID + insertWithNoIDRetrieval +) + +var supportedDialects = map[string]dialect{ + "postgres": &postgresDialect{}, + "sqlite3": &sqlite3Dialect{}, + // "mysql": &mysqlDialect{}, +} + type dialect interface { + InsertMethod() insertMethod Escape(str string) string Placeholder(idx int) string } type postgresDialect struct{} +func (postgresDialect) InsertMethod() insertMethod { + return insertWithReturning +} + func (postgresDialect) Escape(str string) string { return `"` + str + `"` } @@ -19,6 +38,10 @@ func (postgresDialect) Placeholder(idx int) string { type sqlite3Dialect struct{} +func (sqlite3Dialect) InsertMethod() insertMethod { + return insertWithLastInsertID +} + func (sqlite3Dialect) Escape(str string) string { return "`" + str + "`" } @@ -27,9 +50,16 @@ func (sqlite3Dialect) Placeholder(idx int) string { return "?" } -func getDriverDialect(driver string) dialect { - return map[string]dialect{ - "postgres": &postgresDialect{}, - "sqlite3": &sqlite3Dialect{}, - }[driver] +type mysqlDialect struct{} + +func (mysqlDialect) InsertMethod() insertMethod { + return insertWithLastInsertID +} + +func (mysqlDialect) Escape(str string) string { + return "`" + str + "`" +} + +func (mysqlDialect) Placeholder(idx int) string { + return "?" } diff --git a/docker-compose.yml b/docker-compose.yml index 878d61d..6d0932c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,3 +15,11 @@ services: environment: - POSTGRES_USER=postgres - POSTGRES_PASSWORD=postgres + + mysql: + image: mysql + restart: always + ports: + - "127.0.0.1:3306:3306" + environment: + MYSQL_ROOT_PASSWORD: mysql diff --git a/ksql.go b/ksql.go index e426514..0386b68 100644 --- a/ksql.go +++ b/ksql.go @@ -32,14 +32,6 @@ type sqlProvider interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } -type insertMethod int - -const ( - insertWithReturning insertMethod = iota - insertWithLastInsertID - insertWithNoIDRetrieval -) - // Config describes the optional arguments accepted // by the ksql.New() function. type Config struct { @@ -75,7 +67,7 @@ func New( db.SetMaxOpenConns(config.MaxOpenConns) - dialect := getDriverDialect(dbDriver) + dialect := supportedDialects[dbDriver] if dialect == nil { return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver) } @@ -84,17 +76,9 @@ func New( config.IDColumns = []string{"id"} } - var insertMethod insertMethod - switch dbDriver { - case "sqlite3": - insertMethod = insertWithLastInsertID - if len(config.IDColumns) > 1 { - insertMethod = insertWithNoIDRetrieval - } - case "postgres": - insertMethod = insertWithReturning - default: - return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver) + insertMethod := dialect.InsertMethod() + if len(config.IDColumns) > 1 && insertMethod == insertWithLastInsertID { + insertMethod = insertWithNoIDRetrieval } return DB{ diff --git a/ksql_test.go b/ksql_test.go index 903acff..f0514b9 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -33,7 +33,7 @@ type Address struct { } func TestQuery(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { err := createTable(driver) @@ -222,7 +222,7 @@ func TestQuery(t *testing.T) { } func TestQueryOne(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { err := createTable(driver) if err != nil { @@ -318,7 +318,7 @@ func TestQueryOne(t *testing.T) { } func TestInsert(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { err := createTable(driver) @@ -446,7 +446,7 @@ func TestInsert(t *testing.T) { } func TestDelete(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { err := createTable(driver) if err != nil { @@ -589,7 +589,7 @@ func TestDelete(t *testing.T) { } func TestUpdate(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { err := createTable(driver) if err != nil { @@ -758,7 +758,7 @@ func TestUpdate(t *testing.T) { } func TestQueryChunks(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { err := createTable(driver) @@ -1156,7 +1156,7 @@ func TestQueryChunks(t *testing.T) { } func TestTransaction(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { err := createTable(driver) @@ -1391,7 +1391,7 @@ func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB { return DB{ driver: driver, - dialect: getDriverDialect(driver), + dialect: supportedDialects[driver], db: db, tableName: tableName,