diff --git a/kiss_orm.go b/kiss_orm.go index e3eb846..afd9918 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -13,7 +13,12 @@ type DB struct { driver string dialect dialect tableName string - db *sql.DB + db sqlProvider +} + +type sqlProvider interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } // New instantiates a new client diff --git a/kiss_orm_test.go b/kiss_orm_test.go index f3b829b..a21d597 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -952,7 +952,9 @@ func shiftErrSlice(errs *[]error) error { return err } -func getUsersByID(db *sql.DB, dialect dialect, resultsPtr *[]User, ids ...uint) error { +func getUsersByID(dbi sqlProvider, dialect dialect, resultsPtr *[]User, ids ...uint) error { + db := dbi.(*sql.DB) + placeholders := make([]string, len(ids)) params := make([]interface{}, len(ids)) for i := range ids { @@ -992,7 +994,9 @@ func getUsersByID(db *sql.DB, dialect dialect, resultsPtr *[]User, ids ...uint) return nil } -func getUserByID(db *sql.DB, dialect dialect, result *User, id uint) error { +func getUserByID(dbi sqlProvider, dialect dialect, result *User, id uint) error { + db := dbi.(*sql.DB) + row := db.QueryRow(`SELECT id, name, age FROM users WHERE id=`+dialect.Placeholder(0), id) if row.Err() != nil { return row.Err()