diff --git a/README.md b/README.md index ebdcf7a..ca1f2ad 100644 --- a/README.md +++ b/README.md @@ -523,11 +523,15 @@ make test - Add tests for tables using composite keys - Add support for serializing structs as other formats such as YAML -- Update `kstructs.FillStructWith` to work with `json` tagged attributes +- Update `kstructs.FillStructWith` to work with `ksql:"..,json"` tagged attributes - Make testing easier by exposing the connection strings in an .env file - Make testing easier by automatically creating the `ksql` database - Create a way for users to submit user defined dialects - Improve error messages +- Add support for the update function to work with maps for partial updates +- Add support for the insert function to work with maps +- Add support for a `ksql.Array(params ...interface{})` for allowing queries like this: + `db.Query(ctx, &user, "SELECT * FROM user WHERE id in (?)", ksql.Array(1,2,3))` ### Optimization Oportunities diff --git a/adapters.go b/adapters.go new file mode 100644 index 0000000..2e041a2 --- /dev/null +++ b/adapters.go @@ -0,0 +1,47 @@ +package ksql + +import ( + "context" + "database/sql" +) + +// SQLAdapter adapts the sql.DB type to be compatible with the `DBAdapter` interface +type SQLAdapter struct { + *sql.DB +} + +var _ DBAdapter = SQLAdapter{} + +// ExecContext implements the DBAdapter interface +func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + return s.DB.ExecContext(ctx, query, args...) +} + +// QueryContext implements the DBAdapter interface +func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + return s.DB.QueryContext(ctx, query, args...) +} + +// SQLTx is used to implement the DBAdapter interface and implements +// the Tx interface +type SQLTx struct { + *sql.Tx +} + +// ExecContext implements the Tx interface +func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + return s.Tx.ExecContext(ctx, query, args...) +} + +// QueryContext implements the Tx interface +func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + return s.Tx.QueryContext(ctx, query, args...) +} + +var _ Tx = SQLTx{} + +// BeginTx implements the Tx interface +func (s SQLAdapter) BeginTx(ctx context.Context) (Tx, error) { + tx, err := s.DB.BeginTx(ctx, nil) + return SQLTx{Tx: tx}, err +} diff --git a/ksql.go b/ksql.go index 0e87e26..1477d25 100644 --- a/ksql.go +++ b/ksql.go @@ -35,8 +35,36 @@ type DB struct { // // To create a new client using this adapter use ksql.NewWithDB() type DBAdapter interface { - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) +} + +// TxBeginner needs to be implemented by the DBAdapter in order to make it possible +// to use the `ksql.Transaction()` function. +type TxBeginner interface { + BeginTx(ctx context.Context) (Tx, error) +} + +// Result stores information about the result of an Exec query +type Result = sql.Result + +// Rows represents the results from a call to Query() +type Rows interface { + Scan(...interface{}) error + Close() error + Next() bool + Err() error + Columns() ([]string, error) +} + +var _ Rows = &sql.Rows{} + +// Tx represents a transaction and is expected to be returned by the DBAdapter.BeginTx function +type Tx interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) + Rollback() error + Commit() error } // Config describes the optional arguments accepted @@ -66,7 +94,7 @@ func New( db.SetMaxOpenConns(config.MaxOpenConns) - return newWithDB(db, dbDriver, connectionString) + return newWithDB(SQLAdapter{db}, dbDriver, connectionString) } // NewWithDB allows the user to insert a custom implementation @@ -752,11 +780,11 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error // Transaction just runs an SQL command on the database returning no rows. func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { - switch db := c.db.(type) { - case *sql.Tx: + switch txBeginner := c.db.(type) { + case Tx: return fn(c) - case *sql.DB: - tx, err := db.BeginTx(ctx, nil) + case TxBeginner: + tx, err := txBeginner.BeginTx(ctx) if err != nil { return err } @@ -789,7 +817,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { return tx.Commit() default: - return fmt.Errorf("unexpected error on ksql: db attribute has an invalid type") + return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBegginner interface") } } @@ -801,7 +829,7 @@ func (nopScanner) Scan(value interface{}) error { return nil } -func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { +func scanRows(dialect dialect, rows Rows, record interface{}) error { v := reflect.ValueOf(record) t := v.Type() if t.Kind() != reflect.Ptr { @@ -836,7 +864,7 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { return rows.Scan(scanArgs...) } -func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} { +func getScanArgsForNestedStructs(dialect dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} { scanArgs := []interface{}{} for i := 0; i < v.NumField(); i++ { // TODO(vingarcia00): Handle case where type is pointer diff --git a/ksql_test.go b/ksql_test.go index f57ba46..d940800 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -687,7 +687,7 @@ func TestInsert(t *testing.T) { assert.Equal(t, nil, err) var inserted User - err := getUserByName(db, c.dialect, &inserted, "Preset Name") + err := getUserByName(SQLAdapter{db}, c.dialect, &inserted, "Preset Name") assert.Equal(t, nil, err) assert.Equal(t, 5455, inserted.Age) }) @@ -792,7 +792,7 @@ func TestInsert(t *testing.T) { assert.Equal(t, nil, err) var u User - err = getUserByName(db, c.dialect, &u, "Inserted With no ID") + err = getUserByName(SQLAdapter{db}, c.dialect, &u, "Inserted With no ID") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) assert.Equal(t, 42, u.Age) @@ -1944,7 +1944,7 @@ func newTestDB(db *sql.DB, driver string) DB { return DB{ driver: driver, dialect: supportedDialects[driver], - db: db, + db: SQLAdapter{db}, } } @@ -1968,7 +1968,7 @@ func shiftErrSlice(errs *[]error) error { } func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uint) error { - db := dbi.(*sql.DB) + db := dbi.(SQLAdapter) placeholders := make([]string, len(ids)) params := make([]interface{}, len(ids)) @@ -1978,7 +1978,7 @@ func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uin } results := []User{} - rows, err := db.Query( + rows, err := db.DB.Query( fmt.Sprintf( "SELECT id, name, age FROM users WHERE id IN (%s)", strings.Join(placeholders, ", "), @@ -2010,9 +2010,9 @@ func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uin } func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error { - db := dbi.(*sql.DB) + db := dbi.(SQLAdapter) - row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) + row := db.DB.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) if row.Err() != nil { return row.Err() } @@ -2031,9 +2031,9 @@ func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error { } func getUserByName(dbi DBAdapter, dialect dialect, result *User, name string) error { - db := dbi.(*sql.DB) + db := dbi.(SQLAdapter) - row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name) + row := db.DB.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name) if row.Err() != nil { return row.Err() }