diff --git a/dialect.go b/dialect.go index 655c29a..8a5711c 100644 --- a/dialect.go +++ b/dialect.go @@ -70,10 +70,7 @@ func (sqlite3Dialect) Placeholder(idx int) string { // provided driver string, if the drive is not supported // it returns an error func GetDriverDialect(driver string) (Dialect, error) { - dialect, found := map[string]Dialect{ - "postgres": &postgresDialect{}, - "sqlite3": &sqlite3Dialect{}, - }[driver] + dialect, found := supportedDialects[driver] if !found { return nil, fmt.Errorf("unsupported driver `%s`", driver) } diff --git a/dialect_test.go b/dialect_test.go new file mode 100644 index 0000000..bad5471 --- /dev/null +++ b/dialect_test.go @@ -0,0 +1,24 @@ +package ksql + +import ( + "testing" + + tt "github.com/vingarcia/ksql/internal/testtools" +) + +func TestGetDriverDialect(t *testing.T) { + t.Run("should work for all registered drivers", func(t *testing.T) { + for drivername, expectedDialect := range supportedDialects { + t.Run(drivername, func(t *testing.T) { + dialect, err := GetDriverDialect(drivername) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, dialect, expectedDialect) + }) + } + }) + + t.Run("should report error if no driver is found", func(t *testing.T) { + _, err := GetDriverDialect("non-existing-driver") + tt.AssertErrContains(t, err, "unsupported driver", "non-existing-driver") + }) +}