Abstract the DBAdapter so that we can support other sql adapters

This was done for a few different reasons:

1. This allows us to work on top of the pgx client in the future
2. This would allow our users to implement their own DBAdapters
   to use with our tool.
3. This gives the users the option of using advanced configs
   of any sql client they want to use and just feed us with it
   after the configuration is done, which means we will not have
   to worry about supporting a growing number of configurations
   as we try to add support to more drivers or if we get issues
   asking for more advanced config options.
pull/2/head
Vinícius Garcia 2021-07-31 18:54:57 -03:00
parent f420553e0b
commit e73db4a216
4 changed files with 99 additions and 20 deletions

View File

@ -523,11 +523,15 @@ make test
- Add tests for tables using composite keys
- Add support for serializing structs as other formats such as YAML
- Update `kstructs.FillStructWith` to work with `json` tagged attributes
- Update `kstructs.FillStructWith` to work with `ksql:"..,json"` tagged attributes
- Make testing easier by exposing the connection strings in an .env file
- Make testing easier by automatically creating the `ksql` database
- Create a way for users to submit user defined dialects
- Improve error messages
- Add support for the update function to work with maps for partial updates
- Add support for the insert function to work with maps
- Add support for a `ksql.Array(params ...interface{})` for allowing queries like this:
`db.Query(ctx, &user, "SELECT * FROM user WHERE id in (?)", ksql.Array(1,2,3))`
### Optimization Oportunities

47
adapters.go Normal file
View File

@ -0,0 +1,47 @@
package ksql
import (
"context"
"database/sql"
)
// SQLAdapter adapts the sql.DB type to be compatible with the `DBAdapter` interface
type SQLAdapter struct {
*sql.DB
}
var _ DBAdapter = SQLAdapter{}
// ExecContext implements the DBAdapter interface
func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
return s.DB.ExecContext(ctx, query, args...)
}
// QueryContext implements the DBAdapter interface
func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
return s.DB.QueryContext(ctx, query, args...)
}
// SQLTx is used to implement the DBAdapter interface and implements
// the Tx interface
type SQLTx struct {
*sql.Tx
}
// ExecContext implements the Tx interface
func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
return s.Tx.ExecContext(ctx, query, args...)
}
// QueryContext implements the Tx interface
func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
return s.Tx.QueryContext(ctx, query, args...)
}
var _ Tx = SQLTx{}
// BeginTx implements the Tx interface
func (s SQLAdapter) BeginTx(ctx context.Context) (Tx, error) {
tx, err := s.DB.BeginTx(ctx, nil)
return SQLTx{Tx: tx}, err
}

48
ksql.go
View File

@ -35,8 +35,36 @@ type DB struct {
//
// To create a new client using this adapter use ksql.NewWithDB()
type DBAdapter interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
}
// TxBeginner needs to be implemented by the DBAdapter in order to make it possible
// to use the `ksql.Transaction()` function.
type TxBeginner interface {
BeginTx(ctx context.Context) (Tx, error)
}
// Result stores information about the result of an Exec query
type Result = sql.Result
// Rows represents the results from a call to Query()
type Rows interface {
Scan(...interface{}) error
Close() error
Next() bool
Err() error
Columns() ([]string, error)
}
var _ Rows = &sql.Rows{}
// Tx represents a transaction and is expected to be returned by the DBAdapter.BeginTx function
type Tx interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
Rollback() error
Commit() error
}
// Config describes the optional arguments accepted
@ -66,7 +94,7 @@ func New(
db.SetMaxOpenConns(config.MaxOpenConns)
return newWithDB(db, dbDriver, connectionString)
return newWithDB(SQLAdapter{db}, dbDriver, connectionString)
}
// NewWithDB allows the user to insert a custom implementation
@ -752,11 +780,11 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error
// Transaction just runs an SQL command on the database returning no rows.
func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
switch db := c.db.(type) {
case *sql.Tx:
switch txBeginner := c.db.(type) {
case Tx:
return fn(c)
case *sql.DB:
tx, err := db.BeginTx(ctx, nil)
case TxBeginner:
tx, err := txBeginner.BeginTx(ctx)
if err != nil {
return err
}
@ -789,7 +817,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
return tx.Commit()
default:
return fmt.Errorf("unexpected error on ksql: db attribute has an invalid type")
return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBegginner interface")
}
}
@ -801,7 +829,7 @@ func (nopScanner) Scan(value interface{}) error {
return nil
}
func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
func scanRows(dialect dialect, rows Rows, record interface{}) error {
v := reflect.ValueOf(record)
t := v.Type()
if t.Kind() != reflect.Ptr {
@ -836,7 +864,7 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
return rows.Scan(scanArgs...)
}
func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} {
func getScanArgsForNestedStructs(dialect dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} {
scanArgs := []interface{}{}
for i := 0; i < v.NumField(); i++ {
// TODO(vingarcia00): Handle case where type is pointer

View File

@ -687,7 +687,7 @@ func TestInsert(t *testing.T) {
assert.Equal(t, nil, err)
var inserted User
err := getUserByName(db, c.dialect, &inserted, "Preset Name")
err := getUserByName(SQLAdapter{db}, c.dialect, &inserted, "Preset Name")
assert.Equal(t, nil, err)
assert.Equal(t, 5455, inserted.Age)
})
@ -792,7 +792,7 @@ func TestInsert(t *testing.T) {
assert.Equal(t, nil, err)
var u User
err = getUserByName(db, c.dialect, &u, "Inserted With no ID")
err = getUserByName(SQLAdapter{db}, c.dialect, &u, "Inserted With no ID")
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u.ID)
assert.Equal(t, 42, u.Age)
@ -1944,7 +1944,7 @@ func newTestDB(db *sql.DB, driver string) DB {
return DB{
driver: driver,
dialect: supportedDialects[driver],
db: db,
db: SQLAdapter{db},
}
}
@ -1968,7 +1968,7 @@ func shiftErrSlice(errs *[]error) error {
}
func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uint) error {
db := dbi.(*sql.DB)
db := dbi.(SQLAdapter)
placeholders := make([]string, len(ids))
params := make([]interface{}, len(ids))
@ -1978,7 +1978,7 @@ func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uin
}
results := []User{}
rows, err := db.Query(
rows, err := db.DB.Query(
fmt.Sprintf(
"SELECT id, name, age FROM users WHERE id IN (%s)",
strings.Join(placeholders, ", "),
@ -2010,9 +2010,9 @@ func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uin
}
func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error {
db := dbi.(*sql.DB)
db := dbi.(SQLAdapter)
row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id)
row := db.DB.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id)
if row.Err() != nil {
return row.Err()
}
@ -2031,9 +2031,9 @@ func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error {
}
func getUserByName(dbi DBAdapter, dialect dialect, result *User, name string) error {
db := dbi.(*sql.DB)
db := dbi.(SQLAdapter)
row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name)
row := db.DB.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name)
if row.Err() != nil {
return row.Err()
}