diff --git a/Makefile b/Makefile index 49beb88..5b72354 100644 --- a/Makefile +++ b/Makefile @@ -1,32 +1,38 @@ args= path=./... -GOPATH=$(shell go env GOPATH) +GOBIN=$(shell go env GOPATH)/bin TIME=1s test: setup - $(GOPATH)/bin/richgo test $(path) $(args) + $(GOBIN)/richgo test $(path) $(args) bench: go test -bench=. -benchtime=$(TIME) + @echo "Benchmark executed at: $$(date --iso)" + @echo "Benchmark executed on commit: $$(git rev-parse HEAD)" lint: setup - @$(GOPATH)/bin/golint -set_exit_status -min_confidence 0.9 $(path) $(args) + @$(GOBIN)/golint -set_exit_status -min_confidence 0.9 $(path) $(args) @go vet $(path) $(args) @echo "Golint & Go Vet found no problems on your code!" mock: setup - mockgen -package=exampleservice -source=contracts.go -destination=examples/example_service/mocks.go + $(GOBIN)/mockgen -package=exampleservice -source=contracts.go -destination=examples/example_service/mocks.go -setup: .make.setup -.make.setup: +setup: $(GOBIN)/richgo $(GOBIN)/golint $(GOBIN)/mockgen + +$(GOBIN)/richgo: go get github.com/kyoh86/richgo + +$(GOBIN)/golint: go get golang.org/x/lint + +$(GOBIN)/mockgen: @# (Gomock is used on examples/example_service) go get github.com/golang/mock/gomock go get github.com/golang/mock/mockgen - touch .make.setup # Running examples: exampleservice: mock diff --git a/README.md b/README.md index ec351fd..2fe2de1 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,11 @@ # KissSQL -Welcome to the KissSQL project, the Keep It Stupid Simple sql package. +If the thing you hate the most when coding is having too much unnecessary +abstractions and the second thing you hate the most is having verbose +and repetitive code for routine tasks this library is probably for you. + +Welcome to the KissSQL project, the "Keep It Stupid Simple" SQL client for Go. This package was created to be used by any developer efficiently and safely. The goals were: @@ -10,14 +14,19 @@ The goals were: - To be hard to make mistakes - To have a small API so it's easy to learn - To be easy to mock and test (very easy) -- To be above all readable. +- And above all to be readable. **Supported Drivers:** -Currently we only support 2 Drivers: +Currently we support only the 4 most popular Golang database drivers: - `"postgres"` - `"sqlite3"` +- `"mysql"` +- `"sqlserver"` + +If you need a new one included please open an issue or make +your own implementation and submit a Pull Request. ### Why KissSQL? @@ -50,27 +59,27 @@ in order to save development time for your team, i.e.: - Less time spent learning (few methods to learn) - Less time spent testing (helper tools made to help you) -- less time spent debugging (simple apis are easier to debug) +- Less time spent debugging (simple apis are easier to debug) - and less time reading & understanding the code ### Kiss Interface -The current interface is as follows and we plan to keep +The current interface is as follows and we plan on keeping it with as little functions as possible, so don't expect many additions: ```go -// SQLProvider describes the public behavior of this ORM -type SQLProvider interface { - Insert(ctx context.Context, record interface{}) error - Update(ctx context.Context, record interface{}) error - Delete(ctx context.Context, idsOrRecords ...interface{}) error +// Provider describes the public behavior of this ORM +type Provider interface { + Insert(ctx context.Context, table Table, record interface{}) error + Update(ctx context.Context, table Table, record interface{}) error + Delete(ctx context.Context, table Table, idsOrRecords ...interface{}) error Query(ctx context.Context, records interface{}, query string, params ...interface{}) error QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error QueryChunks(ctx context.Context, parser ChunkParser) error Exec(ctx context.Context, query string, params ...interface{}) error - Transaction(ctx context.Context, fn func(SQLProvider) error) error + Transaction(ctx context.Context, fn func(Provider) error) error } ``` @@ -79,6 +88,10 @@ type SQLProvider interface { This example is also available [here](./examples/crud/crud.go) if you want to compile it yourself. +Also we have a small feature for building the "SELECT" part of the query if +you rather not use `SELECT *` queries, you may skip to the +[Select Generator Feature](#Select-Generator-Feature) which is very clean too. + ```Go package main @@ -115,11 +128,14 @@ type Address struct { City string `json:"city"` } +// UsersTable informs ksql the name of the table and that it can +// use the default value for the primary key column name: "id" +var UsersTable = ksql.NewTable("users") + func main() { ctx := context.Background() db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { panic(err.Error()) @@ -144,14 +160,14 @@ func main() { State: "MG", }, } - err = db.Insert(ctx, &alison) + err = db.Insert(ctx, UsersTable, &alison) if err != nil { panic(err.Error()) } fmt.Println("Alison ID:", alison.ID) // Inserting inline: - err = db.Insert(ctx, &User{ + err = db.Insert(ctx, UsersTable, &User{ Name: "Cristina", Age: 27, Address: Address{ @@ -163,7 +179,7 @@ func main() { } // Deleting Alison: - err = db.Delete(ctx, alison.ID) + err = db.Delete(ctx, UsersTable, alison.ID) if err != nil { panic(err.Error()) } @@ -178,12 +194,12 @@ func main() { // Updating all fields from Cristina: cris.Name = "Cris" - err = db.Update(ctx, cris) + err = db.Update(ctx, UsersTable, cris) // Changing the age of Cristina but not touching any other fields: // Partial update technique 1: - err = db.Update(ctx, struct { + err = db.Update(ctx, UsersTable, struct { ID int `ksql:"id"` Age int `ksql:"age"` }{ID: cris.ID, Age: 28}) @@ -192,7 +208,7 @@ func main() { } // Partial update technique 2: - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris.ID, Age: nullable.Int(28), }) @@ -216,7 +232,7 @@ func main() { } // Making transactions: - err = db.Transaction(ctx, func(db ksql.SQLProvider) error { + err = db.Transaction(ctx, func(db ksql.Provider) error { var cris2 User err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = ?", cris.ID) if err != nil { @@ -224,7 +240,7 @@ func main() { return err } - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris2.ID, Age: nullable.Int(29), }) @@ -245,48 +261,287 @@ func main() { } ``` +### Query Chunks Feature + +It's very unsual for us to need to load a number of records from the +database that might be too big for fitting in memory, e.g. load all the +users and send them somewhere. But it might happen. + +For these cases it's best to load chunks of data at a time so +that we can work on a substantial amount of data at a time and never +overload our memory capacity. For this use case we have a specific +function called `QueryChunks`: + +```golang +err = db.QueryChunks(ctx, ksql.ChunkParser{ + Query: "SELECT * FROM users WHERE type = ?", + Params: []interface{}{usersType}, + ChunkSize: 100, + ForEachChunk: func(users []User) error { + err := sendUsersSomewhere(users) + if err != nil { + // This will abort the QueryChunks loop and return this error + return err + } + return nil + }, +}) +if err != nil { + panic(err.Error()) +} +``` + +It's signature is more complicated than the other two Query\* methods, +thus, it is adivisible to always prefer using the other two when possible +reserving this one for the rare use-case where you are actually +loading big sections of the database into memory. + +### Select Generator Feature + +There are good reasons not to use `SELECT *` queries the most important +of them is that you might end up loading more information than you are actually +going to use putting more pressure in your database for no good reason. + +To prevent that `ksql` has a feature specifically for building the `SELECT` +part of the query using the tags from the input struct. +Using it is very simple and it works with all the 3 Query\* functions: + +Querying a single user: + +```golang +var user User +err = db.QueryOne(ctx, &user, "FROM users WHERE id = ?", userID) +if err != nil { + panic(err.Error()) +} +``` + +Querying a page of users: + +```golang +var users []User +err = db.Query(ctx, &users, "FROM users WHERE type = ? ORDER BY id LIMIT ? OFFSET ?", "Cristina", limit, offset) +if err != nil { + panic(err.Error()) +} +``` + +Querying all the users, or any potentially big number of users, from the database (not usual, but supported): + +```golang +err = db.QueryChunks(ctx, ksql.ChunkParser{ + Query: "FROM users WHERE type = ?", + Params: []interface{}{usersType}, + ChunkSize: 100, + ForEachChunk: func(users []User) error { + err := sendUsersSomewhere(users) + if err != nil { + // This will abort the QueryChunks loop and return this error + return err + } + return nil + }, +}) +if err != nil { + panic(err.Error()) +} +``` + +The implementation of this feature is actually simple internally. +First we check if the query is starting with the word `FROM`, +if it is then we just get the `ksql` tags from the struct and +then use it for building the `SELECT` statement. + +The `SELECT` statement is then cached so we don't have to build it again +the next time in order to keep the library efficient even when +using this feature. + +### Select Generation with Joins + +So there is one use-case that was not covered by `ksql` so far: + +What if you want to JOIN multiple tables for which you already have +structs defined? Would you need to create a new struct to represent +the joined columns of the two tables? no, we actually have this covered as well. + +`ksql` has a special feature for allowing the reuse of existing +structs by using composition in an anonymous struct, and then +generating the `SELECT` part of the query accordingly: + +Querying a single joined row: + +```golang +var row struct{ + User User `tablename:"u"` // (here the tablename must match the aliased tablename in the query) + Post Post `tablename:"posts"` // (if no alias is used you should use the actual name of the table) +} +err = db.QueryOne(ctx, &row, "FROM users as u JOIN posts ON u.id = posts.user_id WHERE u.id = ?", userID) +if err != nil { + panic(err.Error()) +} +``` + +Querying a page of joined rows: + +```golang +var rows []struct{ + User User `tablename:"u"` + Post Post `tablename:"p"` +} +err = db.Query(ctx, &rows, + "FROM users as u JOIN posts as p ON u.id = p.user_id WHERE name = ? LIMIT ? OFFSET ?", + "Cristina", limit, offset, +) +if err != nil { + panic(err.Error()) +} +``` + +Querying all the users, or any potentially big number of users, from the database (not usual, but supported): + +```golang +err = db.QueryChunks(ctx, ksql.ChunkParser{ + Query: "FROM users as u JOIN posts as p ON u.id = p.user_id WHERE type = ?", + Params: []interface{}{usersType}, + ChunkSize: 100, + ForEachChunk: func(rows []struct{ + User User `tablename:"u"` + Post Post `tablename:"p"` + }) error { + err := sendRowsSomewhere(rows) + if err != nil { + // This will abort the QueryChunks loop and return this error + return err + } + return nil + }, +}) +if err != nil { + panic(err.Error()) +} +``` + +As advanced as this feature might seem we don't do any parsing of the query, +and all the work is done only once and then cached. + +What actually happens is that we use the "tablename" tag to build the `SELECT` +part of the query like this: + +- `SELECT u.id, u.name, u.age, p.id, p.title ` + +This is then cached, and when we need it again we concatenate it with the rest +of the query. + +This feature has two important limitations: + +1. It is not possible to use `tablename` tags together with normal `ksql` tags. + Doing so will cause the `tablename` tags to be ignored in favor of the `ksql` ones. +2. It is not possible to use it without omitting the `SELECT` part of the query. + While in normal queries we match the selected field with the attribute by name, + in queries joining multiple tables we can't use this strategy because + different tables might have columns with the same name, and we don't + really have access to the full name of these columns making, for example, + it impossible to differentiate between `u.id` and `p.id` except by the + order in which these fields were passed. Thus, it is necessary that + the library itself writes the `SELECT` part of the query when using + this technique so that we can control the order or the selected fields. + +Ok, but what if I don't want to use this feature? + +You are not forced to, and there are a few use-cases where you would prefer not to, e.g.: + +```golang +var rows []struct{ + UserName string `ksql:"name"` + PostTitle string `ksql:"title"` +} +err := db.Query(ctx, &rows, "SELECT u.name, p.title FROM users u JOIN posts p ON u.id = p.user_id LIMIT 10") +if err != nil { + panic(err.Error()) +} +``` + +In the example above, since we are only interested in a couple of columns it +is far simpler and more efficient for the database to only select the columns +that we actually care about, so it's better not to use composite kstructs. + ### Testing Examples This library has a few helper functions for helping your tests: -- `ksql.FillStructWith(struct interface{}, dbRow map[string]interface{}) error` -- `ksql.FillSliceWith(structSlice interface{}, dbRows []map[string]interface{}) error` -- `ksql.StructToMap(struct interface{}) (map[string]interface{}, error)` +- `kstructs.FillStructWith(struct interface{}, dbRow map[string]interface{}) error` +- `kstructs.FillSliceWith(structSlice interface{}, dbRows []map[string]interface{}) error` +- `kstructs.StructToMap(struct interface{}) (map[string]interface{}, error)` +- `kstructs.CallFunctionWithRows(fn interface{}, rows []map[string]interface{}) (map[string]interface{}, error)` If you want to see examples (we have examples for all the public functions) just read the example tests available on our [example service](./examples/example_service) ### Benchmark Comparison -The benchmark is not bad, as far the code is in average as fast as sqlx: +The benchmark is very good, the code is, in practical terms, as fast as sqlx: ```bash -$ make bench TIME=3s -go test -bench=. -benchtime=3s +$ make bench TIME=5s +go test -bench=. -benchtime=5s goos: linux goarch: amd64 pkg: github.com/vingarcia/ksql cpu: Intel(R) Core(TM) i5-3210M CPU @ 2.50GHz -BenchmarkInsert/ksql-setup/insert-one-4 4302 776648 ns/op -BenchmarkInsert/sqlx-setup/insert-one-4 4716 762358 ns/op -BenchmarkQuery/ksql-setup/single-row-4 12204 293858 ns/op -BenchmarkQuery/ksql-setup/multiple-rows-4 11145 323571 ns/op -BenchmarkQuery/sqlx-setup/single-row-4 12440 290937 ns/op -BenchmarkQuery/sqlx-setup/multiple-rows-4 10000 310314 ns/op +BenchmarkInsert/ksql-setup/insert-one-4 5293 960859 ns/op +BenchmarkInsert/pgx-adapter-setup/insert-one-4 7982 736973 ns/op +BenchmarkInsert/sqlx-setup/insert-one-4 6854 857824 ns/op +BenchmarkQuery/ksql-setup/single-row-4 12596 407116 ns/op +BenchmarkQuery/ksql-setup/multiple-rows-4 15883 391135 ns/op +BenchmarkQuery/pgx-adapter-setup/single-row-4 34008 165604 ns/op +BenchmarkQuery/pgx-adapter-setup/multiple-rows-4 22579 280673 ns/op +BenchmarkQuery/sqlx-setup/single-row-4 10000 512741 ns/op +BenchmarkQuery/sqlx-setup/multiple-rows-4 10779 596377 ns/op PASS -ok github.com/vingarcia/ksql 34.251s +ok github.com/vingarcia/ksql 94.951s +Benchmark executed at: 2021-08-01 +Benchmark executed on commit: 37298e2c243f1ec66e88dd92ed7c4542f7820b4f +``` + +### Running the ksql tests (for contributors) + +The tests run in dockerized database instances so the easiest way +to have them working is to just start them using docker-compose: + +```bash +docker-compose up -d +``` + +And then for each of them you will need to run the command: + +```sql +CREATE DATABASE ksql; +``` + +After that you can just run the tests by using: + +```bash +make test ``` ### TODO List -- Implement support for nested objects with prefixed table names -- Improve error messages - Add tests for tables using composite keys - Add support for serializing structs as other formats such as YAML -- Update structs.FillStructWith to work with `json` tagged attributes +- Update `kstructs.FillStructWith` to work with `ksql:"..,json"` tagged attributes +- Make testing easier by exposing the connection strings in an .env file +- Make testing easier by automatically creating the `ksql` database +- Create a way for users to submit user defined dialects +- Improve error messages +- Add support for the update function to work with maps for partial updates +- Add support for the insert function to work with maps +- Add support for a `ksql.Array(params ...interface{})` for allowing queries like this: + `db.Query(ctx, &user, "SELECT * FROM user WHERE id in (?)", ksql.Array(1,2,3))` ### Optimization Oportunities - Test if using a pointer on the field info is faster or not - Consider passing the cached structInfo as argument for all the functions that use it, so that we don't need to get it more than once in the same call. +- Use a cache to store all queries after they are built +- Preload the insert method for all dialects inside `ksql.NewTable()` diff --git a/benchmark_test.go b/benchmark_test.go index 8914511..605ee7a 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -12,6 +12,8 @@ import ( "github.com/vingarcia/ksql" ) +var UsersTable = ksql.NewTable("users") + func BenchmarkInsert(b *testing.B) { ctx := context.Background() @@ -20,7 +22,6 @@ func BenchmarkInsert(b *testing.B) { ksqlDB, err := ksql.New(driver, connStr, ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { b.FailNow() @@ -40,7 +41,33 @@ func BenchmarkInsert(b *testing.B) { b.Run("insert-one", func(b *testing.B) { for i := 0; i < b.N; i++ { - err := ksqlDB.Insert(ctx, &User{ + err := ksqlDB.Insert(ctx, UsersTable, &User{ + Name: strconv.Itoa(i), + Age: i, + }) + if err != nil { + b.Fatalf("insert error: %s", err.Error()) + } + } + }) + }) + + pgxDB, err := ksql.NewWithPGX(ctx, connStr, ksql.Config{ + MaxOpenConns: 1, + }) + if err != nil { + b.Fatalf("error creating pgx client: %s", err) + } + + b.Run("pgx-adapter-setup", func(b *testing.B) { + err := recreateTable(connStr) + if err != nil { + b.Fatalf("error creating table: %s", err.Error()) + } + + b.Run("insert-one", func(b *testing.B) { + for i := 0; i < b.N; i++ { + err := pgxDB.Insert(ctx, UsersTable, &User{ Name: strconv.Itoa(i), Age: i, }) @@ -92,7 +119,6 @@ func BenchmarkQuery(b *testing.B) { ksqlDB, err := ksql.New(driver, connStr, ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { b.FailNow() @@ -139,6 +165,48 @@ func BenchmarkQuery(b *testing.B) { }) }) + pgxDB, err := ksql.NewWithPGX(ctx, connStr, ksql.Config{ + MaxOpenConns: 1, + }) + if err != nil { + b.Fatalf("error creating pgx client: %s", err) + } + + b.Run("pgx-adapter-setup", func(b *testing.B) { + err := recreateTable(connStr) + if err != nil { + b.Fatalf("error creating table: %s", err.Error()) + } + + err = insertUsers(connStr, 100) + if err != nil { + b.Fatalf("error inserting users: %s", err.Error()) + } + + b.Run("single-row", func(b *testing.B) { + for i := 0; i < b.N; i++ { + var user User + err := pgxDB.QueryOne(ctx, &user, `SELECT * FROM users OFFSET $1 LIMIT 1`, i%100) + if err != nil { + b.Fatalf("query error: %s", err.Error()) + } + } + }) + + b.Run("multiple-rows", func(b *testing.B) { + for i := 0; i < b.N; i++ { + var users []User + err := pgxDB.Query(ctx, &users, `SELECT * FROM users OFFSET $1 LIMIT 10`, i%90) + if err != nil { + b.Fatalf("query error: %s", err.Error()) + } + if len(users) < 10 { + b.Fatalf("expected 10 scanned users, but got: %d", len(users)) + } + } + }) + }) + sqlxDB, err := sqlx.Open(driver, connStr) sqlxDB.SetMaxOpenConns(1) diff --git a/contracts.go b/contracts.go index fc52d92..0599770 100644 --- a/contracts.go +++ b/contracts.go @@ -14,18 +14,72 @@ var ErrRecordNotFound error = errors.Wrap(sql.ErrNoRows, "ksql: the query return // ErrAbortIteration ... var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be used inside QueryChunks function") -// SQLProvider describes the public behavior of this ORM -type SQLProvider interface { - Insert(ctx context.Context, record interface{}) error - Update(ctx context.Context, record interface{}) error - Delete(ctx context.Context, idsOrRecords ...interface{}) error +// Provider describes the public behavior of this ORM +type Provider interface { + Insert(ctx context.Context, table Table, record interface{}) error + Update(ctx context.Context, table Table, record interface{}) error + Delete(ctx context.Context, table Table, idsOrRecords ...interface{}) error Query(ctx context.Context, records interface{}, query string, params ...interface{}) error QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error QueryChunks(ctx context.Context, parser ChunkParser) error Exec(ctx context.Context, query string, params ...interface{}) error - Transaction(ctx context.Context, fn func(SQLProvider) error) error + Transaction(ctx context.Context, fn func(Provider) error) error +} + +// Table describes the required information for inserting, updating and +// deleting entities from the database by ID using the 3 helper functions +// created for that purpose. +type Table struct { + // this name must be set in order to use the Insert, Delete and Update helper + // functions. If you only intend to make queries or to use the Exec function + // it is safe to leave this field unset. + name string + + // IDColumns defaults to []string{"id"} if unset + idColumns []string +} + +// NewTable returns a Table instance that stores +// the tablename and the names of columns used as ID, +// if no column name is passed it defaults to using +// the `"id"` column. +// +// This Table is required only for using the helper methods: +// +// - Insert +// - Update +// - Delete +// +// Passing multiple ID columns will be interpreted +// as a single composite key, if you want +// to use the helper functions with different +// keys you'll need to create multiple Table instances +// for the same database table, each with a different +// set of ID columns, but this is usually not necessary. +func NewTable(tableName string, ids ...string) Table { + if len(ids) == 0 { + ids = []string{"id"} + } + + return Table{ + name: tableName, + idColumns: ids, + } +} + +func (t Table) insertMethodFor(dialect Dialect) insertMethod { + if len(t.idColumns) == 1 { + return dialect.InsertMethod() + } + + insertMethod := dialect.InsertMethod() + if insertMethod == insertWithLastInsertID { + return insertWithNoIDRetrieval + } + + return insertMethod } // ChunkParser stores the arguments of the QueryChunks function diff --git a/dialect.go b/dialect.go index 32c2e21..655c29a 100644 --- a/dialect.go +++ b/dialect.go @@ -5,15 +5,41 @@ import ( "strconv" ) +type insertMethod int + +const ( + insertWithReturning insertMethod = iota + insertWithOutput + insertWithLastInsertID + insertWithNoIDRetrieval +) + +var supportedDialects = map[string]Dialect{ + "postgres": &postgresDialect{}, + "sqlite3": &sqlite3Dialect{}, + "mysql": &mysqlDialect{}, + "sqlserver": &sqlserverDialect{}, +} + // Dialect is used to represent the different ways // of writing SQL queries used by each SQL driver. type Dialect interface { + InsertMethod() insertMethod Escape(str string) string Placeholder(idx int) string + DriverName() string } type postgresDialect struct{} +func (postgresDialect) DriverName() string { + return "postgres" +} + +func (postgresDialect) InsertMethod() insertMethod { + return insertWithReturning +} + func (postgresDialect) Escape(str string) string { return `"` + str + `"` } @@ -24,6 +50,14 @@ func (postgresDialect) Placeholder(idx int) string { type sqlite3Dialect struct{} +func (sqlite3Dialect) DriverName() string { + return "sqlite3" +} + +func (sqlite3Dialect) InsertMethod() insertMethod { + return insertWithLastInsertID +} + func (sqlite3Dialect) Escape(str string) string { return "`" + str + "`" } @@ -46,3 +80,39 @@ func GetDriverDialect(driver string) (Dialect, error) { return dialect, nil } + +type mysqlDialect struct{} + +func (mysqlDialect) DriverName() string { + return "mysql" +} + +func (mysqlDialect) InsertMethod() insertMethod { + return insertWithLastInsertID +} + +func (mysqlDialect) Escape(str string) string { + return "`" + str + "`" +} + +func (mysqlDialect) Placeholder(idx int) string { + return "?" +} + +type sqlserverDialect struct{} + +func (sqlserverDialect) DriverName() string { + return "sqlserver" +} + +func (sqlserverDialect) InsertMethod() insertMethod { + return insertWithOutput +} + +func (sqlserverDialect) Escape(str string) string { + return `[` + str + `]` +} + +func (sqlserverDialect) Placeholder(idx int) string { + return "@p" + strconv.Itoa(idx+1) +} diff --git a/docker-compose.yml b/docker-compose.yml index 878d61d..715b450 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,3 +15,22 @@ services: environment: - POSTGRES_USER=postgres - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=${DB_NAME:-ksql} + + mysql: + image: mysql + restart: always + ports: + - "127.0.0.1:3306:3306" + environment: + MYSQL_ROOT_PASSWORD: mysql + + sqlserver: + image: microsoft/mssql-server-linux:2017-latest + restart: always + ports: + - "127.0.0.1:1433:1433" + - "127.0.0.1:1434:1434" + environment: + SA_PASSWORD: "Sqls3rv3r" + ACCEPT_EULA: "Y" diff --git a/examples/crud/crud.go b/examples/crud/crud.go index 75fa363..6ef5df5 100644 --- a/examples/crud/crud.go +++ b/examples/crud/crud.go @@ -33,11 +33,14 @@ type Address struct { City string `json:"city"` } +// UsersTable informs ksql the name of the table and that it can +// use the default value for the primary key column name: "id" +var UsersTable = ksql.NewTable("users") + func main() { ctx := context.Background() db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { panic(err.Error()) @@ -62,14 +65,14 @@ func main() { State: "MG", }, } - err = db.Insert(ctx, &alison) + err = db.Insert(ctx, UsersTable, &alison) if err != nil { panic(err.Error()) } fmt.Println("Alison ID:", alison.ID) // Inserting inline: - err = db.Insert(ctx, &User{ + err = db.Insert(ctx, UsersTable, &User{ Name: "Cristina", Age: 27, Address: Address{ @@ -81,7 +84,7 @@ func main() { } // Deleting Alison: - err = db.Delete(ctx, alison.ID) + err = db.Delete(ctx, UsersTable, alison.ID) if err != nil { panic(err.Error()) } @@ -96,12 +99,12 @@ func main() { // Updating all fields from Cristina: cris.Name = "Cris" - err = db.Update(ctx, cris) + err = db.Update(ctx, UsersTable, cris) // Changing the age of Cristina but not touching any other fields: // Partial update technique 1: - err = db.Update(ctx, struct { + err = db.Update(ctx, UsersTable, struct { ID int `ksql:"id"` Age int `ksql:"age"` }{ID: cris.ID, Age: 28}) @@ -110,7 +113,7 @@ func main() { } // Partial update technique 2: - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris.ID, Age: nullable.Int(28), }) @@ -134,7 +137,7 @@ func main() { } // Making transactions: - err = db.Transaction(ctx, func(db ksql.SQLProvider) error { + err = db.Transaction(ctx, func(db ksql.Provider) error { var cris2 User err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = ?", cris.ID) if err != nil { @@ -142,7 +145,7 @@ func main() { return err } - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris2.ID, Age: nullable.Int(29), }) diff --git a/examples/example_service/example_service.go b/examples/example_service/example_service.go index 85bd18f..bbee499 100644 --- a/examples/example_service/example_service.go +++ b/examples/example_service/example_service.go @@ -8,11 +8,8 @@ import ( "github.com/vingarcia/ksql/nullable" ) -// Service ... -type Service struct { - usersTable ksql.SQLProvider - streamChunkSize int -} +// UsersTable informs ksql that the ID column is named "id" +var UsersTable = ksql.NewTable("users", "id") // UserEntity represents a domain user, // the pointer fields represent optional fields that @@ -41,17 +38,23 @@ type Address struct { Country string `json:"country"` } +// Service ... +type Service struct { + db ksql.Provider + streamChunkSize int +} + // NewUserService ... -func NewUserService(usersTable ksql.SQLProvider) Service { +func NewUserService(db ksql.Provider) Service { return Service{ - usersTable: usersTable, + db: db, streamChunkSize: 100, } } // CreateUser ... func (s Service) CreateUser(ctx context.Context, u UserEntity) error { - return s.usersTable.Insert(ctx, &u) + return s.db.Insert(ctx, UsersTable, &u) } // UpdateUserScore update the user score adding scoreChange with the current @@ -60,12 +63,12 @@ func (s Service) UpdateUserScore(ctx context.Context, uID int, scoreChange int) var scoreRow struct { Score int `ksql:"score"` } - err := s.usersTable.QueryOne(ctx, &scoreRow, "SELECT score FROM users WHERE id = ?", uID) + err := s.db.QueryOne(ctx, &scoreRow, "SELECT score FROM users WHERE id = ?", uID) if err != nil { return err } - return s.usersTable.Update(ctx, &UserEntity{ + return s.db.Update(ctx, UsersTable, &UserEntity{ ID: uID, Score: nullable.Int(scoreRow.Score + scoreChange), }) @@ -76,12 +79,12 @@ func (s Service) ListUsers(ctx context.Context, offset, limit int) (total int, u var countRow struct { Count int `ksql:"count"` } - err = s.usersTable.QueryOne(ctx, &countRow, "SELECT count(*) as count FROM users") + err = s.db.QueryOne(ctx, &countRow, "SELECT count(*) as count FROM users") if err != nil { return 0, nil, err } - return countRow.Count, users, s.usersTable.Query(ctx, &users, "SELECT * FROM users OFFSET ? LIMIT ?", offset, limit) + return countRow.Count, users, s.db.Query(ctx, &users, "SELECT * FROM users OFFSET ? LIMIT ?", offset, limit) } // StreamAllUsers sends all users from the database to an external client @@ -91,7 +94,7 @@ func (s Service) ListUsers(ctx context.Context, offset, limit int) (total int, u // function only when the ammount of data loaded might exceed the available memory and/or // when you can't put an upper limit on the number of values returned. func (s Service) StreamAllUsers(ctx context.Context, sendUser func(u UserEntity) error) error { - return s.usersTable.QueryChunks(ctx, ksql.ChunkParser{ + return s.db.QueryChunks(ctx, ksql.ChunkParser{ Query: "SELECT * FROM users", Params: []interface{}{}, ChunkSize: s.streamChunkSize, @@ -110,5 +113,5 @@ func (s Service) StreamAllUsers(ctx context.Context, sendUser func(u UserEntity) // DeleteUser deletes a user by its ID func (s Service) DeleteUser(ctx context.Context, uID int) error { - return s.usersTable.Delete(ctx, uID) + return s.db.Delete(ctx, UsersTable, uID) } diff --git a/examples/example_service/example_service_test.go b/examples/example_service/example_service_test.go index ae142f5..f19d799 100644 --- a/examples/example_service/example_service_test.go +++ b/examples/example_service/example_service_test.go @@ -5,11 +5,10 @@ import ( "testing" gomock "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" "github.com/tj/assert" "github.com/vingarcia/ksql" + "github.com/vingarcia/ksql/kstructs" "github.com/vingarcia/ksql/nullable" - "github.com/vingarcia/ksql/structs" ) func TestCreateUser(t *testing.T) { @@ -17,16 +16,16 @@ func TestCreateUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var users []interface{} - usersTableMock.EXPECT().Insert(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, records ...interface{}) error { + mockDB.EXPECT().Insert(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error { users = append(users, records...) return nil }) @@ -43,23 +42,23 @@ func TestCreateUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var users []map[string]interface{} - usersTableMock.EXPECT().Insert(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, records ...interface{}) error { + mockDB.EXPECT().Insert(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error { for _, record := range records { // The StructToMap function will convert a struct with `ksql` tags // into a map using the ksql attr names as keys. // // If you are inserting an anonymous struct (not usual) this function // can make your tests shorter: - uMap, err := structs.StructToMap(record) + uMap, err := kstructs.StructToMap(record) if err != nil { return err } @@ -83,28 +82,28 @@ func TestUpdateUserScore(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var users []interface{} gomock.InOrder( - usersTableMock.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockDB.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { // This function will use reflection to fill the // struct fields with the values from the map - return structs.FillStructWith(result, map[string]interface{}{ + return kstructs.FillStructWith(result, map[string]interface{}{ // Use int this map the keys you set on the ksql tags, e.g. `ksql:"score"` // Each of these fields represent the database rows returned // by the query. "score": 42, }) }), - usersTableMock.EXPECT().Update(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, records ...interface{}) error { + mockDB.EXPECT().Update(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error { users = append(users, records...) return nil }), @@ -127,28 +126,28 @@ func TestListUsers(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } gomock.InOrder( - usersTableMock.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockDB.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { // This function will use reflection to fill the // struct fields with the values from the map - return structs.FillStructWith(result, map[string]interface{}{ + return kstructs.FillStructWith(result, map[string]interface{}{ // Use int this map the keys you set on the ksql tags, e.g. `ksql:"score"` // Each of these fields represent the database rows returned // by the query. "count": 420, }) }), - usersTableMock.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockDB.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, results interface{}, query string, params ...interface{}) error { - return structs.FillSliceWith(results, []map[string]interface{}{ + return kstructs.FillSliceWith(results, []map[string]interface{}{ { "id": 1, "name": "fake name", @@ -189,28 +188,26 @@ func TestStreamAllUsers(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 2, } - usersTableMock.EXPECT().QueryChunks(gomock.Any(), gomock.Any()). + mockDB.EXPECT().QueryChunks(gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, parser ksql.ChunkParser) error { - fn, ok := parser.ForEachChunk.(func(users []UserEntity) error) - require.True(t, ok) // Chunk 1: - err := fn([]UserEntity{ + err := kstructs.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ { - ID: 1, - Name: nullable.String("fake name"), - Age: nullable.Int(42), + "id": 1, + "name": "fake name", + "age": 42, }, { - ID: 2, - Name: nullable.String("another fake name"), - Age: nullable.Int(43), + "id": 2, + "name": "another fake name", + "age": 43, }, }) if err != nil { @@ -218,11 +215,11 @@ func TestStreamAllUsers(t *testing.T) { } // Chunk 2: - err = fn([]UserEntity{ + err = kstructs.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ { - ID: 3, - Name: nullable.String("yet another fake name"), - Age: nullable.Int(44), + "id": 3, + "name": "yet another fake name", + "age": 44, }, }) return err @@ -263,16 +260,16 @@ func TestDeleteUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var ids []interface{} - usersTableMock.EXPECT().Delete(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, idArgs ...interface{}) error { + mockDB.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, idArgs ...interface{}) error { ids = append(ids, idArgs...) return nil }) diff --git a/examples/example_service/mocks.go b/examples/example_service/mocks.go index 44b7baf..1307ab2 100644 --- a/examples/example_service/mocks.go +++ b/examples/example_service/mocks.go @@ -12,33 +12,33 @@ import ( ksql "github.com/vingarcia/ksql" ) -// MockSQLProvider is a mock of SQLProvider interface. -type MockSQLProvider struct { +// MockProvider is a mock of Provider interface. +type MockProvider struct { ctrl *gomock.Controller - recorder *MockSQLProviderMockRecorder + recorder *MockProviderMockRecorder } -// MockSQLProviderMockRecorder is the mock recorder for MockSQLProvider. -type MockSQLProviderMockRecorder struct { - mock *MockSQLProvider +// MockProviderMockRecorder is the mock recorder for MockProvider. +type MockProviderMockRecorder struct { + mock *MockProvider } -// NewMockSQLProvider creates a new mock instance. -func NewMockSQLProvider(ctrl *gomock.Controller) *MockSQLProvider { - mock := &MockSQLProvider{ctrl: ctrl} - mock.recorder = &MockSQLProviderMockRecorder{mock} +// NewMockProvider creates a new mock instance. +func NewMockProvider(ctrl *gomock.Controller) *MockProvider { + mock := &MockProvider{ctrl: ctrl} + mock.recorder = &MockProviderMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSQLProvider) EXPECT() *MockSQLProviderMockRecorder { +func (m *MockProvider) EXPECT() *MockProviderMockRecorder { return m.recorder } // Delete mocks base method. -func (m *MockSQLProvider) Delete(ctx context.Context, idsOrRecords ...interface{}) error { +func (m *MockProvider) Delete(ctx context.Context, table ksql.Table, idsOrRecords ...interface{}) error { m.ctrl.T.Helper() - varargs := []interface{}{ctx} + varargs := []interface{}{ctx, table} for _, a := range idsOrRecords { varargs = append(varargs, a) } @@ -48,14 +48,14 @@ func (m *MockSQLProvider) Delete(ctx context.Context, idsOrRecords ...interface{ } // Delete indicates an expected call of Delete. -func (mr *MockSQLProviderMockRecorder) Delete(ctx interface{}, idsOrRecords ...interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Delete(ctx, table interface{}, idsOrRecords ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx}, idsOrRecords...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSQLProvider)(nil).Delete), varargs...) + varargs := append([]interface{}{ctx, table}, idsOrRecords...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockProvider)(nil).Delete), varargs...) } // Exec mocks base method. -func (m *MockSQLProvider) Exec(ctx context.Context, query string, params ...interface{}) error { +func (m *MockProvider) Exec(ctx context.Context, query string, params ...interface{}) error { m.ctrl.T.Helper() varargs := []interface{}{ctx, query} for _, a := range params { @@ -67,28 +67,28 @@ func (m *MockSQLProvider) Exec(ctx context.Context, query string, params ...inte } // Exec indicates an expected call of Exec. -func (mr *MockSQLProviderMockRecorder) Exec(ctx, query interface{}, params ...interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Exec(ctx, query interface{}, params ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{ctx, query}, params...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLProvider)(nil).Exec), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockProvider)(nil).Exec), varargs...) } // Insert mocks base method. -func (m *MockSQLProvider) Insert(ctx context.Context, record interface{}) error { +func (m *MockProvider) Insert(ctx context.Context, table ksql.Table, record interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Insert", ctx, record) + ret := m.ctrl.Call(m, "Insert", ctx, table, record) ret0, _ := ret[0].(error) return ret0 } // Insert indicates an expected call of Insert. -func (mr *MockSQLProviderMockRecorder) Insert(ctx, record interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Insert(ctx, table, record interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), ctx, record) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockProvider)(nil).Insert), ctx, table, record) } // Query mocks base method. -func (m *MockSQLProvider) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error { +func (m *MockProvider) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error { m.ctrl.T.Helper() varargs := []interface{}{ctx, records, query} for _, a := range params { @@ -100,14 +100,14 @@ func (m *MockSQLProvider) Query(ctx context.Context, records interface{}, query } // Query indicates an expected call of Query. -func (mr *MockSQLProviderMockRecorder) Query(ctx, records, query interface{}, params ...interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Query(ctx, records, query interface{}, params ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{ctx, records, query}, params...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockSQLProvider)(nil).Query), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockProvider)(nil).Query), varargs...) } // QueryChunks mocks base method. -func (m *MockSQLProvider) QueryChunks(ctx context.Context, parser ksql.ChunkParser) error { +func (m *MockProvider) QueryChunks(ctx context.Context, parser ksql.ChunkParser) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "QueryChunks", ctx, parser) ret0, _ := ret[0].(error) @@ -115,13 +115,13 @@ func (m *MockSQLProvider) QueryChunks(ctx context.Context, parser ksql.ChunkPars } // QueryChunks indicates an expected call of QueryChunks. -func (mr *MockSQLProviderMockRecorder) QueryChunks(ctx, parser interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) QueryChunks(ctx, parser interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryChunks", reflect.TypeOf((*MockSQLProvider)(nil).QueryChunks), ctx, parser) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryChunks", reflect.TypeOf((*MockProvider)(nil).QueryChunks), ctx, parser) } // QueryOne mocks base method. -func (m *MockSQLProvider) QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error { +func (m *MockProvider) QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error { m.ctrl.T.Helper() varargs := []interface{}{ctx, record, query} for _, a := range params { @@ -133,14 +133,14 @@ func (m *MockSQLProvider) QueryOne(ctx context.Context, record interface{}, quer } // QueryOne indicates an expected call of QueryOne. -func (mr *MockSQLProviderMockRecorder) QueryOne(ctx, record, query interface{}, params ...interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) QueryOne(ctx, record, query interface{}, params ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{ctx, record, query}, params...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryOne", reflect.TypeOf((*MockSQLProvider)(nil).QueryOne), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryOne", reflect.TypeOf((*MockProvider)(nil).QueryOne), varargs...) } // Transaction mocks base method. -func (m *MockSQLProvider) Transaction(ctx context.Context, fn func(ksql.SQLProvider) error) error { +func (m *MockProvider) Transaction(ctx context.Context, fn func(ksql.Provider) error) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Transaction", ctx, fn) ret0, _ := ret[0].(error) @@ -148,21 +148,21 @@ func (m *MockSQLProvider) Transaction(ctx context.Context, fn func(ksql.SQLProvi } // Transaction indicates an expected call of Transaction. -func (mr *MockSQLProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockSQLProvider)(nil).Transaction), ctx, fn) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockProvider)(nil).Transaction), ctx, fn) } // Update mocks base method. -func (m *MockSQLProvider) Update(ctx context.Context, record interface{}) error { +func (m *MockProvider) Update(ctx context.Context, table ksql.Table, record interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Update", ctx, record) + ret := m.ctrl.Call(m, "Update", ctx, table, record) ret0, _ := ret[0].(error) return ret0 } // Update indicates an expected call of Update. -func (mr *MockSQLProviderMockRecorder) Update(ctx, record interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Update(ctx, table, record interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), ctx, record) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockProvider)(nil).Update), ctx, table, record) } diff --git a/go.mod b/go.mod index b218375..e53b423 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,17 @@ module github.com/vingarcia/ksql go 1.14 require ( + github.com/denisenkom/go-mssqldb v0.10.0 github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 + github.com/go-sql-driver/mysql v1.4.0 github.com/golang/mock v1.5.0 + github.com/jackc/pgconn v1.10.0 // indirect + github.com/jackc/pgx/v4 v4.13.0 github.com/jmoiron/sqlx v1.2.0 - github.com/lib/pq v1.1.1 + github.com/lib/pq v1.10.2 github.com/mattn/go-sqlite3 v1.14.6 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.6.1 + github.com/stretchr/testify v1.7.0 // indirect github.com/tj/assert v0.0.3 google.golang.org/appengine v1.6.7 // indirect ) diff --git a/go.sum b/go.sum index 2e20171..af807c8 100644 --- a/go.sum +++ b/go.sum @@ -1,51 +1,208 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= +github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 h1:QsFkVafcKOaZoAB4WcyUHdkPbwh+VYwZgYJb/rU6EIM= github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018/go.mod h1:5C3SWkut69TSdkerzRDxXMRM5x73PGWNcRLe/xKjXhs= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.10.0 h1:4EYhlDVEMsJ30nNj0mmgwIUXoq7e9sMJrVC2ED6QlCU= +github.com/jackc/pgconn v1.10.0/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.8.1 h1:9k0IXtdJXHJbyAWQgbWr1lU+MEhPXZz6RIXxfR5oxXs= +github.com/jackc/pgtype v1.8.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.13.0 h1:JCjhT5vmhMAf/YwBHLvrBn4OGdIQBiFG6ym8Zmdx570= +github.com/jackc/pgx/v4 v4.13.0/go.mod h1:9P4X524sErlaxj0XSGZk7s+LD0eOyu1ZDUrrpznYDF0= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3 h1:JnPg/5Q9xVJGfjsO5CPUOjnJps1JaRUm8I9FXVCFK94= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA= github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= -github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/json.go b/json.go index aaf2ef9..03fdda1 100644 --- a/json.go +++ b/json.go @@ -11,7 +11,8 @@ import ( // input attributes to be convertible to and from JSON // before sending or receiving it from the database. type jsonSerializable struct { - Attr interface{} + DriverName string + Attr interface{} } // Scan Implements the Scanner interface in order to load @@ -40,5 +41,9 @@ func (j *jsonSerializable) Scan(value interface{}) error { // Value Implements the Valuer interface in order to save // this field as JSON on the database. func (j jsonSerializable) Value() (driver.Value, error) { - return json.Marshal(j.Attr) + b, err := json.Marshal(j.Attr) + if j.DriverName == "sqlserver" { + return string(b), err + } + return b, err } diff --git a/kbuilder/kbuilder.go b/kbuilder/kbuilder.go index b17b808..e809ddc 100644 --- a/kbuilder/kbuilder.go +++ b/kbuilder/kbuilder.go @@ -8,13 +8,20 @@ import ( "github.com/pkg/errors" "github.com/vingarcia/ksql" - "github.com/vingarcia/ksql/structs" + "github.com/vingarcia/ksql/kstructs" ) +// Builder is the basic container for injecting +// query builder configurations. +// +// All the Query structs can also be called +// directly without this builder, but we kept it +// here for convenience. type Builder struct { dialect ksql.Dialect } +// New creates a new Builder container. func New(driver string) (Builder, error) { dialect, err := ksql.GetDriverDialect(driver) return Builder{ @@ -22,6 +29,8 @@ func New(driver string) (Builder, error) { }, err } +// Build receives a query builder struct, injects it with the configurations +// build the query according to its arguments. func (builder *Builder) Build(query Query) (sqlQuery string, params []interface{}, _ error) { var b strings.Builder @@ -66,6 +75,7 @@ func (builder *Builder) Build(query Query) (sqlQuery string, params []interface{ return b.String(), params, nil } +// Query is is the struct template for building SELECT queries. type Query struct { // Select expects either a struct using the `ksql` tags // or a string listing the column names using SQL syntax, @@ -84,6 +94,7 @@ type Query struct { OrderBy OrderByQuery } +// WhereQuery represents a single condition in a WHERE expression. type WhereQuery struct { // Accepts any SQL boolean expression // This expression may optionally contain @@ -99,6 +110,8 @@ type WhereQuery struct { params []interface{} } +// WhereQueries is the helper for creating complex WHERE queries +// in a dynamic way. type WhereQueries []WhereQuery func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interface{}) { @@ -116,6 +129,8 @@ func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interf return strings.Join(conds, " AND "), params } +// Where adds a new bollean condition to an existing +// WhereQueries helper. func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries { return append(w, WhereQuery{ cond: cond, @@ -123,6 +138,7 @@ func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries { }) } +// WhereIf condionally adds a new boolean expression to the WhereQueries helper. func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries { if param == nil || reflect.ValueOf(param).IsNil() { return w @@ -134,6 +150,8 @@ func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries { }) } +// Where adds a new bollean condition to an existing +// WhereQueries helper. func Where(cond string, params ...interface{}) WhereQueries { return WhereQueries{{ cond: cond, @@ -141,6 +159,7 @@ func Where(cond string, params ...interface{}) WhereQueries { }} } +// WhereIf condionally adds a new boolean expression to the WhereQueries helper func WhereIf(cond string, param interface{}) WhereQueries { if param == nil || reflect.ValueOf(param).IsNil() { return WhereQueries{} @@ -152,11 +171,14 @@ func WhereIf(cond string, param interface{}) WhereQueries { }} } +// OrderByQuery represents the ORDER BY part of the query type OrderByQuery struct { fields string desc bool } +// Desc is a setter function for configuring the +// ORDER BY part of the query as DESC func (o OrderByQuery) Desc() OrderByQuery { return OrderByQuery{ fields: o.fields, @@ -164,6 +186,8 @@ func (o OrderByQuery) Desc() OrderByQuery { } } +// OrderBy is a helper for building the ORDER BY +// part of the query. func OrderBy(fields string) OrderByQuery { return OrderByQuery{ fields: fields, @@ -188,7 +212,7 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { return query, nil } - info := structs.GetTagInfo(t) + info := kstructs.GetTagInfo(t) var escapedNames []string for i := 0; i < info.NumFields(); i++ { diff --git a/ksql.go b/ksql.go index f8504fe..a0e127a 100644 --- a/ksql.go +++ b/ksql.go @@ -6,53 +6,74 @@ import ( "fmt" "reflect" "strings" + "unicode" + "github.com/jackc/pgx/v4/pgxpool" "github.com/pkg/errors" - "github.com/vingarcia/ksql/structs" + "github.com/vingarcia/ksql/kstructs" ) +var selectQueryCache = map[string]map[reflect.Type]string{} + +func init() { + for dname := range supportedDialects { + selectQueryCache[dname] = map[reflect.Type]string{} + } +} + // DB represents the ksql client responsible for // interfacing with the "database/sql" package implementing -// the KissSQL interface `SQLProvider`. +// the KissSQL interface `ksql.Provider`. type DB struct { - driver string - dialect Dialect - tableName string - db sqlProvider - - // Most dbs have a single primary key, - // But in future ksql should work with compound keys as well - idCols []string - - insertMethod insertMethod + driver string + dialect Dialect + db DBAdapter } -type sqlProvider interface { - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +// DBAdapter is minimalistic interface to decouple our implementation +// from database/sql, i.e. if any struct implements the functions below +// with the exact same semantic as the sql package it will work with ksql. +// +// To create a new client using this adapter use ksql.NewWithDB() +type DBAdapter interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) } -type insertMethod int +// TxBeginner needs to be implemented by the DBAdapter in order to make it possible +// to use the `ksql.Transaction()` function. +type TxBeginner interface { + BeginTx(ctx context.Context) (Tx, error) +} -const ( - insertWithReturning insertMethod = iota - insertWithLastInsertID - insertWithNoIDRetrieval -) +// Result stores information about the result of an Exec query +type Result interface { + LastInsertId() (int64, error) + RowsAffected() (int64, error) +} + +// Rows represents the results from a call to Query() +type Rows interface { + Scan(...interface{}) error + Close() error + Next() bool + Err() error + Columns() ([]string, error) +} + +// Tx represents a transaction and is expected to be returned by the DBAdapter.BeginTx function +type Tx interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) + Rollback(ctx context.Context) error + Commit(ctx context.Context) error +} // Config describes the optional arguments accepted // by the ksql.New() function. type Config struct { // MaxOpenCons defaults to 1 if not set MaxOpenConns int - - // TableName must be set in order to use the Insert, Delete and Update helper - // functions. If you only intend to make queries or to use the Exec function - // it is safe to leave this field unset. - TableName string - - // IDColumns defaults to []string{"id"} if unset - IDColumns []string } // New instantiates a new KissSQL client @@ -75,36 +96,51 @@ func New( db.SetMaxOpenConns(config.MaxOpenConns) - dialect, err := GetDriverDialect(dbDriver) + return NewWithAdapter(SQLAdapter{db}, dbDriver, connectionString) +} + +// NewWithPGX instantiates a new KissSQL client using the pgx +// library in the backend +func NewWithPGX( + ctx context.Context, + connectionString string, + config Config, +) (db DB, err error) { + pgxConf, err := pgxpool.ParseConfig(connectionString) if err != nil { return DB{}, err } - if len(config.IDColumns) == 0 { - config.IDColumns = []string{"id"} + pgxConf.MaxConns = int32(config.MaxOpenConns) + + pool, err := pgxpool.ConnectConfig(ctx, pgxConf) + if err != nil { + return DB{}, err + } + if err = pool.Ping(ctx); err != nil { + return DB{}, err } - var insertMethod insertMethod - switch dbDriver { - case "sqlite3": - insertMethod = insertWithLastInsertID - if len(config.IDColumns) > 1 { - insertMethod = insertWithNoIDRetrieval - } - case "postgres": - insertMethod = insertWithReturning - default: + db, err = NewWithAdapter(PGXAdapter{pool}, "postgres", connectionString) + return db, err +} + +// NewWithAdapter allows the user to insert a custom implementation +// of the DBAdapter interface +func NewWithAdapter( + db DBAdapter, + dbDriver string, + connectionString string, +) (DB, error) { + dialect := supportedDialects[dbDriver] + if dialect == nil { return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver) } return DB{ - dialect: dialect, - driver: dbDriver, - db: db, - tableName: config.TableName, - - idCols: config.IDColumns, - insertMethod: insertMethod, + dialect: dialect, + driver: dbDriver, + db: db, }, nil } @@ -128,7 +164,7 @@ func (c DB) Query( } sliceType := slicePtrType.Elem() slice := slicePtr.Elem() - structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(sliceType) + structType, isSliceOfPtrs, err := kstructs.DecodeAsSliceOfStructs(sliceType) if err != nil { return err } @@ -140,6 +176,22 @@ func (c DB) Query( slice = slice.Slice(0, 0) } + info := kstructs.GetTagInfo(structType) + + firstToken := strings.ToUpper(getFirstToken(query)) + if info.IsNestedStruct && firstToken == "SELECT" { + // This error check is necessary, since if we can't build the select part of the query this feature won't work. + return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query") + } + + if firstToken == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) + if err != nil { + return err + } + query = selectPrefix + query + } + rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return fmt.Errorf("error running query: %s", err.Error()) @@ -164,7 +216,7 @@ func (c DB) Query( elemPtr = elemPtr.Elem() } - err = scanRows(rows, elemPtr.Interface()) + err = scanRows(c.dialect, rows, elemPtr.Interface()) if err != nil { return err } @@ -205,6 +257,22 @@ func (c DB) QueryOne( return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) } + info := kstructs.GetTagInfo(t) + + firstToken := strings.ToUpper(getFirstToken(query)) + if info.IsNestedStruct && firstToken == "SELECT" { + // This error check is necessary, since if we can't build the select part of the query this feature won't work. + return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query") + } + + if firstToken == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, t, info, selectQueryCache[c.dialect.DriverName()]) + if err != nil { + return err + } + query = selectPrefix + query + } + rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return err @@ -218,7 +286,7 @@ func (c DB) QueryOne( return ErrRecordNotFound } - err = scanRows(rows, record) + err = scanRows(c.dialect, rows, record) if err != nil { return err } @@ -247,18 +315,34 @@ func (c DB) QueryChunks( parser ChunkParser, ) error { fnValue := reflect.ValueOf(parser.ForEachChunk) - chunkType, err := parseInputFunc(parser.ForEachChunk) + chunkType, err := kstructs.ParseInputFunc(parser.ForEachChunk) if err != nil { return err } chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize) - structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(chunkType) + structType, isSliceOfPtrs, err := kstructs.DecodeAsSliceOfStructs(chunkType) if err != nil { return err } + info := kstructs.GetTagInfo(structType) + + firstToken := strings.ToUpper(getFirstToken(parser.Query)) + if info.IsNestedStruct && firstToken == "SELECT" { + // This error check is necessary, since if we can't build the select part of the query this feature won't work. + return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query") + } + + if firstToken == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) + if err != nil { + return err + } + parser.Query = selectPrefix + parser.Query + } + rows, err := c.db.QueryContext(ctx, parser.Query, parser.Params...) if err != nil { return err @@ -278,7 +362,7 @@ func (c DB) QueryChunks( chunk = reflect.Append(chunk, elemValue) } - err = scanRows(rows, chunk.Index(idx).Addr().Interface()) + err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface()) if err != nil { return err } @@ -330,22 +414,19 @@ func (c DB) QueryChunks( // the ID is automatically updated after insertion is completed. func (c DB) Insert( ctx context.Context, + table Table, record interface{}, ) error { - if c.tableName == "" { - return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Insert method") - } - - query, params, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...) + query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, record, table.idColumns...) if err != nil { return err } - switch c.insertMethod { - case insertWithReturning: - err = c.insertWithReturningID(ctx, record, query, params, c.idCols) + switch table.insertMethodFor(c.dialect) { + case insertWithReturning, insertWithOutput: + err = c.insertReturningIDs(ctx, record, query, params, scanValues, table.idColumns) case insertWithLastInsertID: - err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0]) + err = c.insertWithLastInsertID(ctx, record, query, params, table.idColumns[0]) case insertWithNoIDRetrieval: err = c.insertWithNoIDRetrieval(ctx, record, query, params) default: @@ -357,19 +438,14 @@ func (c DB) Insert( return err } -func (c DB) insertWithReturningID( +func (c DB) insertReturningIDs( ctx context.Context, record interface{}, query string, params []interface{}, + scanValues []interface{}, idNames []string, ) error { - escapedIDNames := []string{} - for _, id := range idNames { - escapedIDNames = append(escapedIDNames, c.dialect.Escape(id)) - } - query += " RETURNING " + strings.Join(idNames, ", ") - rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return err @@ -385,21 +461,7 @@ func (c DB) insertWithReturningID( return err } - v := reflect.ValueOf(record) - t := v.Type() - if err = assertStructPtr(t); err != nil { - return errors.Wrap(err, "can't write id field") - } - info := structs.GetTagInfo(t.Elem()) - - var scanFields []interface{} - for _, id := range idNames { - scanFields = append( - scanFields, - v.Elem().Field(info.ByName(id).Index).Addr().Interface(), - ) - } - err = rows.Scan(scanFields...) + err = rows.Scan(scanValues...) if err != nil { return err } @@ -425,7 +487,7 @@ func (c DB) insertWithLastInsertID( return errors.Wrap(err, "can't write to `"+idName+"` field") } - info := structs.GetTagInfo(t.Elem()) + info := kstructs.GetTagInfo(t.Elem()) id, err := result.LastInsertId() if err != nil { @@ -473,27 +535,24 @@ func assertStructPtr(t reflect.Type) error { // Delete deletes one or more instances from the database by id func (c DB) Delete( ctx context.Context, + table Table, ids ...interface{}, ) error { - if c.tableName == "" { - return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Delete method") - } - if len(ids) == 0 { return nil } - idMaps, err := normalizeIDsAsMaps(c.idCols, ids) + idMaps, err := normalizeIDsAsMaps(table.idColumns, ids) if err != nil { return err } var query string var params []interface{} - if len(c.idCols) == 1 { - query, params = buildSingleKeyDeleteQuery(c.dialect, c.tableName, c.idCols[0], idMaps) + if len(table.idColumns) == 1 { + query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps) } else { - query, params = buildCompositeKeyDeleteQuery(c.dialect, c.tableName, c.idCols, idMaps) + query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps) } _, err = c.db.ExecContext(ctx, query, params...) @@ -511,7 +570,7 @@ func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]inter t := reflect.TypeOf(ids[i]) switch t.Kind() { case reflect.Struct: - m, err := structs.StructToMap(ids[i]) + m, err := kstructs.StructToMap(ids[i]) if err != nil { return nil, errors.Wrapf(err, "could not get ID(s) from record on idx %d", i) } @@ -545,42 +604,63 @@ func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]inter // Partial updates are supported, i.e. it will ignore nil pointer attributes func (c DB) Update( ctx context.Context, + table Table, record interface{}, ) error { - if c.tableName == "" { - return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Update method") - } - - query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, c.idCols...) + query, params, err := buildUpdateQuery(c.dialect, table.name, record, table.idColumns...) if err != nil { return err } - _, err = c.db.ExecContext(ctx, query, params...) + result, err := c.db.ExecContext(ctx, query, params...) + if err != nil { + return err + } - return err + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf( + "unexpected error: unable to fetch how many rows were affected by the update: %s", + err, + ) + } + if n < 1 { + return ErrRecordNotFound + } + + return nil } func buildInsertQuery( dialect Dialect, tableName string, record interface{}, - idFieldNames ...string, -) (query string, params []interface{}, err error) { - recordMap, err := structs.StructToMap(record) + idNames ...string, +) (query string, params []interface{}, scanValues []interface{}, err error) { + v := reflect.ValueOf(record) + t := v.Type() + if err = assertStructPtr(t); err != nil { + return "", nil, nil, fmt.Errorf( + "ksql: expected record to be a pointer to struct, but got: %T", + record, + ) + } + + info := kstructs.GetTagInfo(t.Elem()) + + recordMap, err := kstructs.StructToMap(record) if err != nil { - return "", nil, err + return "", nil, nil, err } - t := reflect.TypeOf(record) - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - info := structs.GetTagInfo(t) + for _, fieldName := range idNames { + field, found := recordMap[fieldName] + if !found { + continue + } - for _, fieldName := range idFieldNames { // Remove any ID field that was not set: - if reflect.ValueOf(recordMap[fieldName]).IsZero() { + if reflect.ValueOf(field).IsZero() { delete(recordMap, fieldName) } } @@ -596,7 +676,10 @@ func buildInsertQuery( recordValue := recordMap[col] params[i] = recordValue if info.ByName(col).SerializeAsJSON { - params[i] = jsonSerializable{Attr: recordValue} + params[i] = jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: recordValue, + } } valuesQuery[i] = dialect.Placeholder(i) @@ -608,14 +691,48 @@ func buildInsertQuery( escapedColumnNames = append(escapedColumnNames, dialect.Escape(col)) } + var returningQuery, outputQuery string + switch dialect.InsertMethod() { + case insertWithReturning: + escapedIDNames := []string{} + for _, id := range idNames { + escapedIDNames = append(escapedIDNames, dialect.Escape(id)) + } + returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ") + + for _, id := range idNames { + scanValues = append( + scanValues, + v.Elem().Field(info.ByName(id).Index).Addr().Interface(), + ) + } + case insertWithOutput: + escapedIDNames := []string{} + for _, id := range idNames { + escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id)) + } + outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ") + + for _, id := range idNames { + scanValues = append( + scanValues, + v.Elem().Field(info.ByName(id).Index).Addr().Interface(), + ) + } + } + + // Note that the outputQuery and the returningQuery depend + // on the selected driver, thus, they might be empty strings. query = fmt.Sprintf( - "INSERT INTO %s (%s) VALUES (%s)", + "INSERT INTO %s (%s)%s VALUES (%s)%s", dialect.Escape(tableName), strings.Join(escapedColumnNames, ", "), + outputQuery, strings.Join(valuesQuery, ", "), + returningQuery, ) - return query, params, nil + return query, params, scanValues, nil } func buildUpdateQuery( @@ -624,7 +741,7 @@ func buildUpdateQuery( record interface{}, idFieldNames ...string, ) (query string, args []interface{}, err error) { - recordMap, err := structs.StructToMap(record) + recordMap, err := kstructs.StructToMap(record) if err != nil { return "", nil, err } @@ -654,13 +771,16 @@ func buildUpdateQuery( if t.Kind() == reflect.Ptr { t = t.Elem() } - info := structs.GetTagInfo(t) + info := kstructs.GetTagInfo(t) var setQuery []string for i, k := range keys { recordValue := recordMap[k] if info.ByName(k).SerializeAsJSON { - recordValue = jsonSerializable{Attr: recordValue} + recordValue = jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: recordValue, + } } args[i] = recordValue setQuery = append(setQuery, fmt.Sprintf( @@ -687,18 +807,18 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error } // Transaction just runs an SQL command on the database returning no rows. -func (c DB) Transaction(ctx context.Context, fn func(SQLProvider) error) error { - switch db := c.db.(type) { - case *sql.Tx: +func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { + switch txBeginner := c.db.(type) { + case Tx: return fn(c) - case *sql.DB: - tx, err := db.BeginTx(ctx, nil) + case TxBeginner: + tx, err := txBeginner.BeginTx(ctx) if err != nil { return err } defer func() { if r := recover(); r != nil { - rollbackErr := tx.Rollback() + rollbackErr := tx.Rollback(ctx) if rollbackErr != nil { r = errors.Wrap(rollbackErr, fmt.Sprintf("unable to rollback after panic with value: %v", r), @@ -713,7 +833,7 @@ func (c DB) Transaction(ctx context.Context, fn func(SQLProvider) error) error { err = fn(ormCopy) if err != nil { - rollbackErr := tx.Rollback() + rollbackErr := tx.Rollback(ctx) if rollbackErr != nil { err = errors.Wrap(rollbackErr, fmt.Sprintf("unable to rollback after error: %s", err.Error()), @@ -722,45 +842,13 @@ func (c DB) Transaction(ctx context.Context, fn func(SQLProvider) error) error { return err } - return tx.Commit() + return tx.Commit(ctx) default: - return fmt.Errorf("unexpected error on ksql: db attribute has an invalid type") + return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBegginner interface") } } -var errType = reflect.TypeOf(new(error)).Elem() - -func parseInputFunc(fn interface{}) (reflect.Type, error) { - if fn == nil { - return nil, fmt.Errorf("the ForEachChunk attribute is required and cannot be nil") - } - - t := reflect.TypeOf(fn) - - if t.Kind() != reflect.Func { - return nil, fmt.Errorf("the ForEachChunk callback must be a function") - } - if t.NumIn() != 1 { - return nil, fmt.Errorf("the ForEachChunk callback must have 1 argument") - } - - if t.NumOut() != 1 { - return nil, fmt.Errorf("the ForEachChunk callback must have a single return value") - } - - if t.Out(0) != errType { - return nil, fmt.Errorf("the return value of the ForEachChunk callback must be of type error") - } - - argsType := t.In(0) - if argsType.Kind() != reflect.Slice { - return nil, fmt.Errorf("the argument of the ForEachChunk callback must a slice of structs") - } - - return argsType, nil -} - type nopScanner struct{} var nopScannerValue = reflect.ValueOf(&nopScanner{}).Interface() @@ -769,12 +857,7 @@ func (nopScanner) Scan(value interface{}) error { return nil } -func scanRows(rows *sql.Rows, record interface{}) error { - names, err := rows.Columns() - if err != nil { - return err - } - +func scanRows(dialect Dialect, rows Rows, record interface{}) error { v := reflect.ValueOf(record) t := v.Type() if t.Kind() != reflect.Ptr { @@ -788,8 +871,55 @@ func scanRows(rows *sql.Rows, record interface{}) error { return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record) } - info := structs.GetTagInfo(t) + info := kstructs.GetTagInfo(t) + var scanArgs []interface{} + if info.IsNestedStruct { + // This version is positional meaning that it expect the arguments + // to follow an specific order. It's ok because we don't allow the + // user to type the "SELECT" part of the query for nested kstructs. + scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info) + } else { + names, err := rows.Columns() + if err != nil { + return err + } + // Since this version uses the names of the columns it works + // with any order of attributes/columns. + scanArgs = getScanArgsFromNames(dialect, names, v, info) + } + + return rows.Scan(scanArgs...) +} + +func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} { + scanArgs := []interface{}{} + for i := 0; i < v.NumField(); i++ { + // TODO(vingarcia00): Handle case where type is pointer + nestedStructInfo := kstructs.GetTagInfo(t.Field(i).Type) + nestedStructValue := v.Field(i) + for j := 0; j < nestedStructValue.NumField(); j++ { + fieldInfo := nestedStructInfo.ByIndex(j) + + valueScanner := nopScannerValue + if fieldInfo.Valid { + valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface() + if fieldInfo.SerializeAsJSON { + valueScanner = &jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: valueScanner, + } + } + } + + scanArgs = append(scanArgs, valueScanner) + } + } + + return scanArgs +} + +func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info kstructs.StructInfo) []interface{} { scanArgs := []interface{}{} for _, name := range names { fieldInfo := info.ByName(name) @@ -798,14 +928,17 @@ func scanRows(rows *sql.Rows, record interface{}) error { if fieldInfo.Valid { valueScanner = v.Field(fieldInfo.Index).Addr().Interface() if fieldInfo.SerializeAsJSON { - valueScanner = &jsonSerializable{Attr: valueScanner} + valueScanner = &jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: valueScanner, + } } } scanArgs = append(scanArgs, valueScanner) } - return rows.Scan(scanArgs...) + return scanArgs } func buildSingleKeyDeleteQuery( @@ -856,3 +989,83 @@ func buildCompositeKeyDeleteQuery( strings.Join(values, ","), ), params } + +// We implemented this function instead of using +// a regex or strings.Fields because we wanted +// to preserve the performance of the package. +func getFirstToken(s string) string { + s = strings.TrimLeftFunc(s, unicode.IsSpace) + + var token strings.Builder + for _, c := range s { + if unicode.IsSpace(c) { + break + } + token.WriteRune(c) + } + return token.String() +} + +func buildSelectQuery( + dialect Dialect, + structType reflect.Type, + info kstructs.StructInfo, + selectQueryCache map[reflect.Type]string, +) (query string, err error) { + if selectQuery, found := selectQueryCache[structType]; found { + return selectQuery, nil + } + + if info.IsNestedStruct { + query, err = buildSelectQueryForNestedStructs(dialect, structType, info) + if err != nil { + return "", err + } + } else { + query = buildSelectQueryForPlainStructs(dialect, structType, info) + } + + selectQueryCache[structType] = query + return query, nil +} + +func buildSelectQueryForPlainStructs( + dialect Dialect, + structType reflect.Type, + info kstructs.StructInfo, +) string { + var fields []string + for i := 0; i < structType.NumField(); i++ { + fields = append(fields, dialect.Escape(info.ByIndex(i).Name)) + } + + return "SELECT " + strings.Join(fields, ", ") + " " +} + +func buildSelectQueryForNestedStructs( + dialect Dialect, + structType reflect.Type, + info kstructs.StructInfo, +) (string, error) { + var fields []string + for i := 0; i < structType.NumField(); i++ { + nestedStructName := info.ByIndex(i).Name + nestedStructType := structType.Field(i).Type + if nestedStructType.Kind() != reflect.Struct { + return "", fmt.Errorf( + "expected nested struct with `tablename:\"%s\"` to be a kind of Struct, but got %v", + nestedStructName, nestedStructType, + ) + } + + nestedStructInfo := kstructs.GetTagInfo(nestedStructType) + for j := 0; j < structType.Field(i).Type.NumField(); j++ { + fields = append( + fields, + dialect.Escape(nestedStructName)+"."+dialect.Escape(nestedStructInfo.ByIndex(j).Name), + ) + } + } + + return "SELECT " + strings.Join(fields, ", ") + " ", nil +} diff --git a/ksql_test.go b/ksql_test.go index 583fece..0e9b739 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -6,10 +6,14 @@ import ( "encoding/json" "errors" "fmt" + "io" "strings" "testing" + _ "github.com/denisenkom/go-mssqldb" "github.com/ditointernet/go-assert" + _ "github.com/go-sql-driver/mysql" + "github.com/jackc/pgx/v4/pgxpool" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/vingarcia/ksql/nullable" @@ -23,6 +27,8 @@ type User struct { Address Address `ksql:"address,json"` } +var UsersTable = NewTable("users") + type Address struct { Street string `json:"street"` Number string `json:"number"` @@ -32,166 +38,333 @@ type Address struct { Country string `json:"country"` } +type Post struct { + ID int `ksql:"id"` + UserID uint `ksql:"user_id"` + Title string `ksql:"title"` +} + +var PostsTable = NewTable("posts") + +type testConfig struct { + driver string + adapterName string +} + +var supportedConfigs = []testConfig{ + { + driver: "sqlite3", + adapterName: "sql", + }, + { + driver: "postgres", + adapterName: "sql", + }, + { + driver: "mysql", + adapterName: "sql", + }, + { + driver: "sqlserver", + adapterName: "sql", + }, + { + driver: "postgres", + adapterName: "pgx", + }, +} + func TestQuery(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { - t.Run(driver, func(t *testing.T) { - t.Run("using slice of structs", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { + variations := []struct { + desc string + queryPrefix string + }{ + { + desc: "with select *", + queryPrefix: "SELECT * ", + }, + { + desc: "building the SELECT part of the query internally", + queryPrefix: "", + }, + } + for _, variation := range variations { + t.Run(variation.desc, func(t *testing.T) { + t.Run("using slice of structs", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } - t.Run("should return 0 results correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run("should return 0 results correctly", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []User - err := c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, nil, err) - assert.Equal(t, []User(nil), users) + ctx := context.Background() + c := newTestDB(db, config.driver) + var users []User + err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, nil, err) + assert.Equal(t, 0, len(users)) - users = []User{} - err = c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, nil, err) - assert.Equal(t, []User{}, users) + users = []User{} + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, nil, err) + assert.Equal(t, 0, len(users)) + }) + + t.Run("should return a user correctly", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + _, err := db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, config.driver) + var users []User + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") + + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "Bia", users[0].Name) + assert.Equal(t, "BR", users[0].Address.Country) + }) + + t.Run("should return multiple users correctly", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + _, err := db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + + _, err = db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, config.driver) + var users []User + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "João Garcia", users[0].Name) + assert.Equal(t, "US", users[0].Address.Country) + + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "Bia Garcia", users[1].Name) + assert.Equal(t, "BR", users[1].Address.Country) + }) + + t.Run("should query joined tables correctly", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + // This test only makes sense with no query prefix + if variation.queryPrefix != "" { + return + } + + _, err := db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joao User + getUserByName(db, config.driver, &joao, "João Ribeiro") + assert.Equal(t, nil, err) + + _, err = db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + var bia User + getUserByName(db, config.driver, &bia, "Bia Ribeiro") + + _, err = db.ExecContext(context.TODO(), fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post1')`)) + assert.Equal(t, nil, err) + _, err = db.ExecContext(context.TODO(), fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post2')`)) + assert.Equal(t, nil, err) + _, err = db.ExecContext(context.TODO(), fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`)) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, config.driver) + var rows []struct { + User User `tablename:"u"` + Post Post `tablename:"p"` + } + err = c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(rows)) + + assert.Equal(t, joao.ID, rows[0].User.ID) + assert.Equal(t, "João Ribeiro", rows[0].User.Name) + assert.Equal(t, "João Post1", rows[0].Post.Title) + + assert.Equal(t, bia.ID, rows[1].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[1].User.Name) + assert.Equal(t, "Bia Post1", rows[1].Post.Title) + + assert.Equal(t, bia.ID, rows[2].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[2].User.Name) + assert.Equal(t, "Bia Post2", rows[2].Post.Title) + }) + }) + + t.Run("using slice of pointers to structs", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + t.Run("should return 0 results correctly", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + var users []*User + err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, nil, err) + assert.Equal(t, 0, len(users)) + + users = []*User{} + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, nil, err) + assert.Equal(t, 0, len(users)) + }) + + t.Run("should return a user correctly", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) + var users []*User + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") + + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "Bia", users[0].Name) + assert.Equal(t, "BR", users[0].Address.Country) + }) + + t.Run("should return multiple users correctly", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) + var users []*User + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "João Garcia", users[0].Name) + assert.Equal(t, "US", users[0].Address.Country) + + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "Bia Garcia", users[1].Name) + assert.Equal(t, "BR", users[1].Address.Country) + }) + + t.Run("should query joined tables correctly", func(t *testing.T) { + // This test only makes sense with no query prefix + if variation.queryPrefix != "" { + return + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joao User + getUserByName(db, config.driver, &joao, "João Ribeiro") + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + var bia User + getUserByName(db, config.driver, &bia, "Bia Ribeiro") + + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post1')`)) + assert.Equal(t, nil, err) + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post2')`)) + assert.Equal(t, nil, err) + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`)) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) + var rows []*struct { + User User `tablename:"u"` + Post Post `tablename:"p"` + } + err = c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(rows)) + + assert.Equal(t, joao.ID, rows[0].User.ID) + assert.Equal(t, "João Ribeiro", rows[0].User.Name) + assert.Equal(t, "João Post1", rows[0].Post.Title) + + assert.Equal(t, bia.ID, rows[1].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[1].User.Name) + assert.Equal(t, "Bia Post1", rows[1].Post.Title) + + assert.Equal(t, bia.ID, rows[2].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[2].User.Name) + assert.Equal(t, "Bia Post2", rows[2].Post.Title) + }) + }) }) - - t.Run("should return a user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []User - err = c.Query(ctx, &users, `SELECT * FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") - - assert.Equal(t, nil, err) - assert.Equal(t, 1, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "Bia", users[0].Name) - assert.Equal(t, "BR", users[0].Address.Country) - }) - - t.Run("should return multiple users correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []User - err = c.Query(ctx, &users, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") - - assert.Equal(t, nil, err) - assert.Equal(t, 2, len(users)) - - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "João Garcia", users[0].Name) - assert.Equal(t, "US", users[0].Address.Country) - - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "Bia Garcia", users[1].Name) - assert.Equal(t, "BR", users[1].Address.Country) - }) - }) - - t.Run("using slice of pointers to structs", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - t.Run("should return 0 results correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []*User - err := c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, nil, err) - assert.Equal(t, []*User(nil), users) - - users = []*User{} - err = c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, nil, err) - assert.Equal(t, []*User{}, users) - }) - - t.Run("should return a user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []*User - err = c.Query(ctx, &users, `SELECT * FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") - - assert.Equal(t, nil, err) - assert.Equal(t, 1, len(users)) - assert.Equal(t, "Bia", users[0].Name) - assert.NotEqual(t, uint(0), users[0].ID) - }) - - t.Run("should return multiple users correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []*User - err = c.Query(ctx, &users, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") - - assert.Equal(t, nil, err) - assert.Equal(t, 2, len(users)) - - assert.Equal(t, "João Garcia", users[0].Name) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "US", users[0].Address.Country) - - assert.Equal(t, "Bia Garcia", users[1].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "BR", users[1].Address.Country) - }) - }) + } t.Run("testing error cases", func(t *testing.T) { - err := createTable(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should report error if input is not a pointer to a slice of structs", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Andréa Sá', 0)`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age) VALUES ('Caio Sá', 0)`) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Andréa Sá', 0)`) + assert.Equal(t, nil, err) + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Caio Sá', 0)`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) err = c.Query(ctx, &User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) @@ -207,94 +380,212 @@ func TestQuery(t *testing.T) { }) t.Run("should report error if the query is not valid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) var users []User - err = c.Query(ctx, &users, `SELECT * FROM not a valid query`) + err := c.Query(ctx, &users, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) + + t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + var rows []struct { + User User `tablename:"users"` + Post Post `tablename:"posts"` + } + err := c.Query(ctx, &rows, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`) + assert.NotEqual(t, nil, err) + assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error()) + assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error()) + }) + + t.Run("should report error for nested structs with invalid types", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + var rows []struct { + Foo int `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "int"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) + + t.Run("*struct", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + var rows []struct { + Foo *User `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "*ksql.User"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) + }) }) }) } } func TestQueryOne(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { - t.Run(driver, func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { + variations := []struct { + desc string + queryPrefix string + }{ + { + desc: "with select *", + queryPrefix: "SELECT * ", + }, + { + desc: "building the SELECT part of the query internally", + queryPrefix: "", + }, + } + for _, variation := range variations { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + t.Run(variation.desc, func(t *testing.T) { + t.Run("should return RecordNotFoundErr when there are no results", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + u := User{} + err := c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("should return a user correctly", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) + u := User{} + err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") + + assert.Equal(t, nil, err) + assert.NotEqual(t, uint(0), u.ID) + assert.Equal(t, "Bia", u.Name) + assert.Equal(t, Address{ + Country: "BR", + }, u.Address) + }) + + t.Run("should return only the first result on multiples matches", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) + + var u User + err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá") + assert.Equal(t, nil, err) + assert.Equal(t, "Andréa Sá", u.Name) + assert.Equal(t, 0, u.Age) + assert.Equal(t, Address{ + Country: "US", + }, u.Address) + }) + + t.Run("should query joined tables correctly", func(t *testing.T) { + // This test only makes sense with no query prefix + if variation.queryPrefix != "" { + return + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joao User + getUserByName(db, config.driver, &joao, "João Ribeiro") + + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`)) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) + var row struct { + User User `tablename:"u"` + Post Post `tablename:"p"` + } + err = c.QueryOne(ctx, &row, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.Equal(t, nil, err) + assert.Equal(t, joao.ID, row.User.ID) + assert.Equal(t, "João Ribeiro", row.User.Name) + assert.Equal(t, "João Post1", row.Post.Title) + }) + }) } - t.Run("should return RecordNotFoundErr when there are no results", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - u := User{} - err := c.QueryOne(ctx, &u, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, ErrRecordNotFound, err) - }) - - t.Run("should return a user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - u := User{} - err = c.QueryOne(ctx, &u, `SELECT * FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") - - assert.Equal(t, nil, err) - assert.NotEqual(t, uint(0), u.ID) - assert.Equal(t, "Bia", u.Name) - assert.Equal(t, Address{ - Country: "BR", - }, u.Address) - }) - - t.Run("should return only the first result on multiples matches", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - var u User - err = c.QueryOne(ctx, &u, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá") - assert.Equal(t, nil, err) - assert.Equal(t, "Andréa Sá", u.Name) - assert.Equal(t, 0, u.Age) - assert.Equal(t, Address{ - Country: "US", - }, u.Address) - }) - t.Run("should report error if input is not a pointer to struct", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) err = c.QueryOne(ctx, &[]User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) @@ -304,34 +595,50 @@ func TestQueryOne(t *testing.T) { }) t.Run("should report error if the query is not valid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) var user User - err = c.QueryOne(ctx, &user, `SELECT * FROM not a valid query`) + err := c.QueryOne(ctx, &user, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) + + t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + var row struct { + User User `tablename:"users"` + Post Post `tablename:"posts"` + } + err := c.QueryOne(ctx, &row, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id LIMIT 1`) + assert.NotEqual(t, nil, err) + assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error()) + assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error()) + }) }) } } func TestInsert(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { - t.Run(driver, func(t *testing.T) { - t.Run("using slice of structs", func(t *testing.T) { - err := createTable(driver) + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { + t.Run("success cases", func(t *testing.T) { + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should insert one user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) u := User{ Name: "Fernanda", @@ -340,7 +647,7 @@ func TestInsert(t *testing.T) { }, } - err := c.Insert(ctx, &u) + err := c.Insert(ctx, UsersTable, &u) assert.Equal(t, nil, err) assert.NotEqual(t, 0, u.ID) @@ -353,19 +660,16 @@ func TestInsert(t *testing.T) { }) t.Run("should insert ignoring the ID for sqlite and multiple ids", func(t *testing.T) { - if driver != "sqlite3" { + if supportedDialects[config.driver].InsertMethod() != insertWithLastInsertID { return } - db := connectDB(t, driver) - defer db.Close() - ctx := context.Background() + // Using columns "id" and "name" as IDs: - c, err := New(driver, connectionString[driver], Config{ - TableName: "users", - IDColumns: []string{"id", "name"}, - }) + table := NewTable("users", "id", "name") + + c, err := New(config.driver, connectionString[config.driver], Config{}) assert.Equal(t, nil, err) u := User{ @@ -377,39 +681,74 @@ func TestInsert(t *testing.T) { }, } - err = c.Insert(ctx, &u) + err = c.Insert(ctx, table, &u) assert.Equal(t, nil, err) assert.Equal(t, uint(0), u.ID) result := User{} - err = getUserByName(c.db, c.dialect, &result, "No ID returned") + err = getUserByName(c.db, config.driver, &result, "No ID returned") assert.Equal(t, nil, err) assert.Equal(t, u.Age, result.Age) assert.Equal(t, u.Address, result.Address) }) + + t.Run("should work with anonymous structs", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + err = c.Insert(ctx, UsersTable, &struct { + ID int `ksql:"id"` + Name string `ksql:"name"` + Address map[string]interface{} `ksql:"address,json"` + }{Name: "fake-name", Address: map[string]interface{}{"city": "bar"}}) + assert.Equal(t, nil, err) + }) + + t.Run("should work with preset IDs", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + usersByName := NewTable("users", "name") + + err = c.Insert(ctx, usersByName, &struct { + Name string `ksql:"name"` + Age int `ksql:"age"` + }{Name: "Preset Name", Age: 5455}) + assert.Equal(t, nil, err) + + var inserted User + err := getUserByName(db, config.driver, &inserted, "Preset Name") + assert.Equal(t, nil, err) + assert.Equal(t, 5455, inserted.Age) + }) }) t.Run("testing error cases", func(t *testing.T) { - err := createTable(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should report error for invalid input types", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) - err = c.Insert(ctx, "foo") + err = c.Insert(ctx, UsersTable, "foo") assert.NotEqual(t, nil, err) - err = c.Insert(ctx, nullable.String("foo")) + err = c.Insert(ctx, UsersTable, nullable.String("foo")) assert.NotEqual(t, nil, err) - err = c.Insert(ctx, map[string]interface{}{ + err = c.Insert(ctx, UsersTable, map[string]interface{}{ "name": "foo", "age": 12, }) @@ -419,52 +758,124 @@ func TestInsert(t *testing.T) { &User{Name: "foo", Age: 22}, &User{Name: "bar", Age: 32}, } - err = c.Insert(ctx, ifUserForgetToExpandList) + err = c.Insert(ctx, UsersTable, ifUserForgetToExpandList) assert.NotEqual(t, nil, err) // We might want to support this in the future, but not for now: - err = c.Insert(ctx, User{Name: "not a ptr to user", Age: 42}) + err = c.Insert(ctx, UsersTable, User{Name: "not a ptr to user", Age: 42}) assert.NotEqual(t, nil, err) }) t.Run("should report error if for some reason the insertMethod is invalid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) // This is an invalid value: - c.insertMethod = insertMethod(42) + c.dialect = brokenDialect{} - err = c.Insert(ctx, &User{Name: "foo"}) + err = c.Insert(ctx, UsersTable, &User{Name: "foo"}) assert.NotEqual(t, nil, err) }) + + t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err = c.Insert(ctx, UsersTable, &struct { + ID string `ksql:"id"` + NonExistingColumn int `ksql:"non_existing"` + Name string `ksql:"name"` + }{NonExistingColumn: 42, Name: "fake-name"}) + assert.NotEqual(t, nil, err) + msg := err.Error() + assert.Equal(t, true, strings.Contains(msg, "column")) + assert.Equal(t, true, strings.Contains(msg, "non_existing")) + }) + + t.Run("should not panic if the ID column doesn't exist in the database", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + brokenTable := NewTable("users", "non_existing_id") + + _ = c.Insert(ctx, brokenTable, &struct { + ID string `ksql:"non_existing_id"` + Age int `ksql:"age"` + Name string `ksql:"name"` + }{Age: 42, Name: "fake-name"}) + }) + + t.Run("should not panic if the ID column is missing in the struct", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err = c.Insert(ctx, UsersTable, &struct { + Age int `ksql:"age"` + Name string `ksql:"name"` + }{Age: 42, Name: "Inserted With no ID"}) + assert.Equal(t, nil, err) + + var u User + err = getUserByName(db, config.driver, &u, "Inserted With no ID") + assert.Equal(t, nil, err) + assert.NotEqual(t, uint(0), u.ID) + assert.Equal(t, 42, u.Age) + }) }) }) } } +type brokenDialect struct{} + +func (brokenDialect) InsertMethod() insertMethod { + return insertMethod(42) +} + +func (brokenDialect) Escape(str string) string { + return str +} + +func (brokenDialect) Placeholder(idx int) string { + return "?" +} + +func (brokenDialect) DriverName() string { + return "fake-driver-name" +} + func TestDelete(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { - t.Run(driver, func(t *testing.T) { - err := createTable(driver) + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should ignore empty lists of ids", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) u := User{ Name: "Won't be deleted", } - err := c.Insert(ctx, &u) + err := c.Insert(ctx, UsersTable, &u) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) @@ -474,7 +885,7 @@ func TestDelete(t *testing.T) { assert.Equal(t, u.ID, result.ID) - err = c.Delete(ctx) + err = c.Delete(ctx, UsersTable) assert.Equal(t, nil, err) result = User{} @@ -484,17 +895,17 @@ func TestDelete(t *testing.T) { }) t.Run("should delete one id correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) u1 := User{ Name: "Fernanda", } - err := c.Insert(ctx, &u1) + err := c.Insert(ctx, UsersTable, &u1) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u1.ID) @@ -507,7 +918,7 @@ func TestDelete(t *testing.T) { Name: "Won't be deleted", } - err = c.Insert(ctx, &u2) + err = c.Insert(ctx, UsersTable, &u2) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u2.ID) @@ -516,7 +927,7 @@ func TestDelete(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, u2.ID, result.ID) - err = c.Delete(ctx, u1.ID) + err = c.Delete(ctx, UsersTable, u1.ID) assert.Equal(t, nil, err) result = User{} @@ -532,30 +943,30 @@ func TestDelete(t *testing.T) { }) t.Run("should delete multiple ids correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) u1 := User{ Name: "Fernanda", } - err := c.Insert(ctx, &u1) + err := c.Insert(ctx, UsersTable, &u1) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u1.ID) u2 := User{ Name: "Juliano", } - err = c.Insert(ctx, &u2) + err = c.Insert(ctx, UsersTable, &u2) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u2.ID) u3 := User{ Name: "This won't be deleted", } - err = c.Insert(ctx, &u3) + err = c.Insert(ctx, UsersTable, &u3) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u3.ID) @@ -574,7 +985,7 @@ func TestDelete(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, u3.ID, result.ID) - err = c.Delete(ctx, u1.ID, u2.ID) + err = c.Delete(ctx, UsersTable, u1.ID, u2.ID) assert.Equal(t, nil, err) results := []User{} @@ -589,33 +1000,31 @@ func TestDelete(t *testing.T) { } func TestUpdate(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { - t.Run(driver, func(t *testing.T) { - err := createTable(driver) + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should update one user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) u := User{ Name: "Letícia", } - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 0)`) + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`) assert.Equal(t, nil, err) - row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) - assert.Equal(t, nil, row.Err()) - err = row.Scan(&u.ID) + err = getUserByName(db, config.driver, &u, "Letícia") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) - err = c.Update(ctx, User{ + err = c.Update(ctx, UsersTable, User{ ID: u.ID, Name: "Thayane", }) @@ -628,25 +1037,23 @@ func TestUpdate(t *testing.T) { }) t.Run("should update one user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) u := User{ Name: "Letícia", } - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 0)`) + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`) assert.Equal(t, nil, err) - row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) - assert.Equal(t, nil, row.Err()) - err = row.Scan(&u.ID) + err = getUserByName(db, config.driver, &u, "Letícia") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) - err = c.Update(ctx, User{ + err = c.Update(ctx, UsersTable, User{ ID: u.ID, Name: "Thayane", }) @@ -659,31 +1066,27 @@ func TestUpdate(t *testing.T) { }) t.Run("should ignore null pointers on partial updates", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) type partialUser struct { ID uint `ksql:"id"` Name string `ksql:"name"` Age *int `ksql:"age"` } - u := partialUser{ - Name: "Letícia", - Age: nullable.Int(22), - } - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 22)`) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`) assert.Equal(t, nil, err) - row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) - assert.Equal(t, nil, row.Err()) - err = row.Scan(&u.ID) + var u User + err = getUserByName(db, config.driver, &u, "Letícia") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) - err = c.Update(ctx, partialUser{ + err = c.Update(ctx, UsersTable, partialUser{ ID: u.ID, // Should be updated because it is not null, just empty: Name: "", @@ -700,32 +1103,28 @@ func TestUpdate(t *testing.T) { }) t.Run("should update valid pointers on partial updates", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) type partialUser struct { ID uint `ksql:"id"` Name string `ksql:"name"` Age *int `ksql:"age"` } - u := partialUser{ - Name: "Letícia", - Age: nullable.Int(22), - } - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 22)`) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`) assert.Equal(t, nil, err) - row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) - assert.Equal(t, nil, row.Err()) - err = row.Scan(&u.ID) + var u User + err = getUserByName(db, config.driver, &u, "Letícia") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) // Should update all fields: - err = c.Update(ctx, partialUser{ + err = c.Update(ctx, UsersTable, partialUser{ ID: u.ID, Name: "Thay", Age: nullable.Int(42), @@ -740,14 +1139,28 @@ func TestUpdate(t *testing.T) { assert.Equal(t, 42, result.Age) }) - t.Run("should report database errors correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run("should return ErrRecordNotFound when asked to update an inexistent user", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "non_existing_table") + c := newTestDB(db, config.driver) - err = c.Update(ctx, User{ + err = c.Update(ctx, UsersTable, User{ + ID: 4200, + Name: "Thayane", + }) + assert.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("should report database errors correctly", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err = c.Update(ctx, NewTable("non_existing_table"), User{ ID: 1, Name: "Thayane", }) @@ -758,423 +1171,537 @@ func TestUpdate(t *testing.T) { } func TestQueryChunks(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { - t.Run(driver, func(t *testing.T) { - t.Run("should query a single row correctly", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{ - Name: "User1", - Address: Address{Country: "BR"}, - }) - - var length int - var u User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `SELECT * FROM users WHERE name = ` + c.dialect.Placeholder(0), - Params: []interface{}{"User1"}, - - ChunkSize: 100, - ForEachChunk: func(users []User) error { - length = len(users) - if length > 0 { - u = users[0] + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { + variations := []struct { + desc string + queryPrefix string + }{ + { + desc: "with select *", + queryPrefix: "SELECT * ", + }, + { + desc: "building the SELECT part of the query internally", + queryPrefix: "", + }, + } + for _, variation := range variations { + t.Run(variation.desc, func(t *testing.T) { + t.Run("should query a single row correctly", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) } - return nil - }, - }) - assert.Equal(t, nil, err) - assert.Equal(t, 1, length) - assert.NotEqual(t, uint(0), u.ID) - assert.Equal(t, "User1", u.Name) - assert.Equal(t, "BR", u.Address.Country) - }) + db, closer := connectDB(t, config) + defer closer.Close() - t.Run("should query one chunk correctly", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } + ctx := context.Background() + c := newTestDB(db, config.driver) - db := connectDB(t, driver) - defer db.Close() + _ = c.Insert(ctx, UsersTable, &User{ + Name: "User1", + Address: Address{Country: "BR"}, + }) - ctx := context.Background() - c := newTestDB(db, driver, "users") + var length int + var u User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `FROM users WHERE name = ` + c.dialect.Placeholder(0), + Params: []interface{}{"User1"}, - _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) - _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) + ChunkSize: 100, + ForEachChunk: func(users []User) error { + length = len(users) + if length > 0 { + u = users[0] + } + return nil + }, + }) - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - users = append(users, buffer...) - lengths = append(lengths, len(buffer)) - return nil - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 1, len(lengths)) - assert.Equal(t, 2, lengths[0]) - - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.Equal(t, "US", users[0].Address.Country) - - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.Equal(t, "BR", users[1].Address.Country) - }) - - t.Run("should query chunks of 1 correctly", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) - _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) - - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 1, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - return nil - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 2, len(users)) - assert.Equal(t, []int{1, 1}, lengths) - - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.Equal(t, "US", users[0].Address.Country) - - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.Equal(t, "BR", users[1].Address.Country) - }) - - t.Run("should load partially filled chunks correctly", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - return nil - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 3, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.NotEqual(t, uint(0), users[2].ID) - assert.Equal(t, "User3", users[2].Name) - assert.Equal(t, []int{2, 1}, lengths) - }) - - t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - return ErrAbortIteration - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 2, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.Equal(t, []int{2}, lengths) - }) - - t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - returnVals := []error{nil, ErrAbortIteration} - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - - return shiftErrSlice(&returnVals) - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 3, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.NotEqual(t, uint(0), users[2].ID) - assert.Equal(t, "User3", users[2].Name) - assert.Equal(t, []int{2, 1}, lengths) - }) - - t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - return errors.New("fake error msg") - }, - }) - - assert.NotEqual(t, nil, err) - assert.Equal(t, 2, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.Equal(t, []int{2}, lengths) - }) - - t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - returnVals := []error{nil, errors.New("fake error msg")} - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - - return shiftErrSlice(&returnVals) - }, - }) - - assert.NotEqual(t, nil, err) - assert.Equal(t, 3, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.NotEqual(t, uint(0), users[2].ID) - assert.Equal(t, "User3", users[2].Name) - assert.Equal(t, []int{2, 1}, lengths) - }) - - t.Run("should report error if the input function is invalid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - funcs := []interface{}{ - nil, - "not a function", - func() error { - return nil - }, - func(extraInputValue []User, extra []User) error { - return nil - }, - func(invalidArgType string) error { - return nil - }, - func(missingReturnType []User) { - return - }, - func(users []User) string { - return "" - }, - func(extraReturnValue []User) ([]User, error) { - return nil, nil - }, - func(notSliceOfStructs []string) error { - return nil - }, - } - - for _, fn := range funcs { - err := c.QueryChunks(ctx, ChunkParser{ - Query: `SELECT * FROM users`, - Params: []interface{}{}, - - ChunkSize: 2, - ForEachChunk: fn, + assert.Equal(t, nil, err) + assert.Equal(t, 1, length) + assert.NotEqual(t, uint(0), u.ID) + assert.Equal(t, "User1", u.Name) + assert.Equal(t, "BR", u.Address.Country) }) - assert.NotEqual(t, nil, err) - } - }) - t.Run("should report error if the query is not valid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run("should query one chunk correctly", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } - ctx := context.Background() - c := newTestDB(db, driver, "users") - err := c.QueryChunks(ctx, ChunkParser{ - Query: `SELECT * FROM not a valid query`, - Params: []interface{}{}, + db, closer := connectDB(t, config) + defer closer.Close() - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - return nil - }, + ctx := context.Background() + c := newTestDB(db, config.driver) + + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + users = append(users, buffer...) + lengths = append(lengths, len(buffer)) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(lengths)) + assert.Equal(t, 2, lengths[0]) + + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.Equal(t, "US", users[0].Address.Country) + + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, "BR", users[1].Address.Country) + }) + + t.Run("should query chunks of 1 correctly", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 1, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.Equal(t, []int{1, 1}, lengths) + + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.Equal(t, "US", users[0].Address.Country) + + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, "BR", users[1].Address.Country) + }) + + t.Run("should load partially filled chunks correctly", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, uint(0), users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) + + // xxx + t.Run("should query joined tables correctly", func(t *testing.T) { + // This test only makes sense with no query prefix + if variation.queryPrefix != "" { + return + } + + db, closer := connectDB(t, config) + defer closer.Close() + + joao := User{ + Name: "Thiago Ribeiro", + Age: 24, + } + thatiana := User{ + Name: "Thatiana Ribeiro", + Age: 20, + } + + ctx := context.Background() + c := newTestDB(db, config.driver) + _ = c.Insert(ctx, UsersTable, &joao) + _ = c.Insert(ctx, UsersTable, &thatiana) + + _, err := db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post1')`)) + assert.Equal(t, nil, err) + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post2')`)) + assert.Equal(t, nil, err) + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'Thiago Post1')`)) + assert.Equal(t, nil, err) + + var lengths []int + var users []User + var posts []Post + err = c.QueryChunks(ctx, ChunkParser{ + Query: fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), + Params: []interface{}{"% Ribeiro"}, + + ChunkSize: 2, + ForEachChunk: func(chunk []struct { + User User `tablename:"u"` + Post Post `tablename:"p"` + }) error { + lengths = append(lengths, len(chunk)) + for _, row := range chunk { + users = append(users, row.User) + posts = append(posts, row.Post) + } + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(posts)) + + assert.Equal(t, joao.ID, users[0].ID) + assert.Equal(t, "Thiago Ribeiro", users[0].Name) + assert.Equal(t, "Thiago Post1", posts[0].Title) + + assert.Equal(t, thatiana.ID, users[1].ID) + assert.Equal(t, "Thatiana Ribeiro", users[1].Name) + assert.Equal(t, "Thatiana Post1", posts[1].Title) + + assert.Equal(t, thatiana.ID, users[2].ID) + assert.Equal(t, "Thatiana Ribeiro", users[2].Name) + assert.Equal(t, "Thatiana Post2", posts[2].Title) + }) + + t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return ErrAbortIteration + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, []int{2}, lengths) + }) + + t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) + + returnVals := []error{nil, ErrAbortIteration} + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + + return shiftErrSlice(&returnVals) + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, uint(0), users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) + + t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return errors.New("fake error msg") + }, + }) + + assert.NotEqual(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, []int{2}, lengths) + }) + + t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) { + err := createTables(config.driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) + + returnVals := []error{nil, errors.New("fake error msg")} + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + + return shiftErrSlice(&returnVals) + }, + }) + + assert.NotEqual(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, uint(0), users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) + + t.Run("should report error if the input function is invalid", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + funcs := []interface{}{ + nil, + "not a function", + func() error { + return nil + }, + func(extraInputValue []User, extra []User) error { + return nil + }, + func(invalidArgType string) error { + return nil + }, + func(missingReturnType []User) { + return + }, + func(users []User) string { + return "" + }, + func(extraReturnValue []User) ([]User, error) { + return nil, nil + }, + func(notSliceOfStructs []string) error { + return nil + }, + } + + for _, fn := range funcs { + err := c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `FROM users`, + Params: []interface{}{}, + + ChunkSize: 2, + ForEachChunk: fn, + }) + assert.NotEqual(t, nil, err) + } + }) + + t.Run("should report error if the query is not valid", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + err := c.QueryChunks(ctx, ChunkParser{ + Query: `SELECT * FROM not a valid query`, + Params: []interface{}{}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + return nil + }, + }) + assert.NotEqual(t, nil, err) + }) + + t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.QueryChunks(ctx, ChunkParser{ + Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`, + Params: []interface{}{}, + + ChunkSize: 2, + ForEachChunk: func(buffer []struct { + User User `tablename:"users"` + Post Post `tablename:"posts"` + }) error { + return nil + }, + }) + + assert.NotEqual(t, nil, err) + assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error()) + assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error()) + }) }) - assert.NotEqual(t, nil, err) - }) + } }) } } func TestTransaction(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { - t.Run(driver, func(t *testing.T) { + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) var users []User - err = c.Transaction(ctx, func(db SQLProvider) error { + err = c.Transaction(ctx, func(db Provider) error { db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC") return nil }) @@ -1186,26 +1713,26 @@ func TestTransaction(t *testing.T) { }) t.Run("should rollback when there are errors", func(t *testing.T) { - err := createTable(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, config.driver) u1 := User{Name: "User1", Age: 42} u2 := User{Name: "User2", Age: 42} - _ = c.Insert(ctx, &u1) - _ = c.Insert(ctx, &u2) + _ = c.Insert(ctx, UsersTable, &u1) + _ = c.Insert(ctx, UsersTable, &u2) - err = c.Transaction(ctx, func(db SQLProvider) error { - err = db.Insert(ctx, &User{Name: "User3"}) + err = c.Transaction(ctx, func(db Provider) error { + err = db.Insert(ctx, UsersTable, &User{Name: "User3"}) assert.Equal(t, nil, err) - err = db.Insert(ctx, &User{Name: "User4"}) + err = db.Insert(ctx, UsersTable, &User{Name: "User4"}) assert.Equal(t, nil, err) err = db.Exec(ctx, "UPDATE users SET age = 22") assert.Equal(t, nil, err) @@ -1227,18 +1754,22 @@ func TestTransaction(t *testing.T) { func TestScanRows(t *testing.T) { t.Run("should scan users correctly", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() - c := newTestDB(db, "sqlite3", "users") - _ = c.Insert(ctx, &User{Name: "User1", Age: 22}) - _ = c.Insert(ctx, &User{Name: "User2", Age: 14}) - _ = c.Insert(ctx, &User{Name: "User3", Age: 43}) + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() + c := newTestDB(db, "sqlite3") + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Age: 14}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3", Age: 43}) rows, err := db.QueryContext(ctx, "select * from users where name='User2'") assert.Equal(t, nil, err) @@ -1247,7 +1778,7 @@ func TestScanRows(t *testing.T) { assert.Equal(t, true, rows.Next()) var u User - err = scanRows(rows, &u) + err = scanRows(dialect, rows, &u) assert.Equal(t, nil, err) assert.Equal(t, "User2", u.Name) @@ -1255,16 +1786,20 @@ func TestScanRows(t *testing.T) { }) t.Run("should ignore extra columns from query", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() - c := newTestDB(db, "sqlite3", "users") - _ = c.Insert(ctx, &User{Name: "User1", Age: 22}) + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() + c := newTestDB(db, "sqlite3") + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22}) rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User1'") assert.Equal(t, nil, err) @@ -1279,21 +1814,25 @@ func TestScanRows(t *testing.T) { // Omitted for testing purposes: // Name string `ksql:"name"` } - err = scanRows(rows, &user) + err = scanRows(dialect, rows, &user) assert.Equal(t, nil, err) assert.Equal(t, 22, user.Age) }) t.Run("should report error for closed rows", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() rows, err := db.QueryContext(ctx, "select * from users where name='User2'") assert.Equal(t, nil, err) @@ -1301,53 +1840,63 @@ func TestScanRows(t *testing.T) { var u User err = rows.Close() assert.Equal(t, nil, err) - err = scanRows(rows, &u) + err = scanRows(dialect, rows, &u) assert.NotEqual(t, nil, err) }) t.Run("should report if record is not a pointer", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() rows, err := db.QueryContext(ctx, "select * from users where name='User2'") assert.Equal(t, nil, err) var u User - err = scanRows(rows, u) + err = scanRows(dialect, rows, u) assert.NotEqual(t, nil, err) }) t.Run("should report if record is not a pointer to struct", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() rows, err := db.QueryContext(ctx, "select * from users where name='User2'") assert.Equal(t, nil, err) var u map[string]interface{} - err = scanRows(rows, &u) + err = scanRows(dialect, rows, &u) assert.NotEqual(t, nil, err) }) } var connectionString = map[string]string{ - "postgres": "host=localhost port=5432 user=postgres password=postgres dbname=ksql sslmode=disable", - "sqlite3": "/tmp/ksql.db", + "postgres": "host=localhost port=5432 user=postgres password=postgres dbname=ksql sslmode=disable", + "sqlite3": "/tmp/ksql.db", + "mysql": "root:mysql@(127.0.0.1:3306)/ksql?timeout=30s", + "sqlserver": "sqlserver://sa:Sqls3rv3r@127.0.0.1:1433?databaseName=ksql", } -func createTable(driver string) error { +func createTables(driver string) error { connStr := connectionString[driver] if connStr == "" { return fmt.Errorf("unsupported driver: '%s'", driver) @@ -1376,6 +1925,52 @@ func createTable(driver string) error { name VARCHAR(50), address jsonb )`) + case "mysql": + _, err = db.Exec(`CREATE TABLE users ( + id INT AUTO_INCREMENT PRIMARY KEY, + age INT, + name VARCHAR(50), + address JSON + )`) + case "sqlserver": + _, err = db.Exec(`CREATE TABLE users ( + id INT IDENTITY(1,1) PRIMARY KEY, + age INT, + name VARCHAR(50), + address NVARCHAR(4000) + )`) + } + if err != nil { + return fmt.Errorf("failed to create new users table: %s", err.Error()) + } + + db.Exec(`DROP TABLE posts`) + + switch driver { + case "sqlite3": + _, err = db.Exec(`CREATE TABLE posts ( + id INTEGER PRIMARY KEY, + user_id INTEGER, + title TEXT + )`) + case "postgres": + _, err = db.Exec(`CREATE TABLE posts ( + id serial PRIMARY KEY, + user_id INT, + title VARCHAR(50) + )`) + case "mysql": + _, err = db.Exec(`CREATE TABLE posts ( + id INT AUTO_INCREMENT PRIMARY KEY, + user_id INT, + title VARCHAR(50) + )`) + case "sqlserver": + _, err = db.Exec(`CREATE TABLE posts ( + id INT IDENTITY(1,1) PRIMARY KEY, + user_id INT, + title VARCHAR(50) + )`) } if err != nil { return fmt.Errorf("failed to create new users table: %s", err.Error()) @@ -1384,41 +1979,41 @@ func createTable(driver string) error { return nil } -func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB { - if len(ids) == 0 { - ids = []string{"id"} - } - - dialect, err := GetDriverDialect(driver) - if err != nil { - panic(err) - } - +func newTestDB(db DBAdapter, driver string) DB { return DB{ - driver: driver, - dialect: dialect, - db: db, - tableName: tableName, - - idCols: ids, - insertMethod: map[string]insertMethod{ - "sqlite3": insertWithLastInsertID, - "postgres": insertWithReturning, - }[driver], + driver: driver, + dialect: supportedDialects[driver], + db: db, } } -func connectDB(t *testing.T, driver string) *sql.DB { - connStr := connectionString[driver] +type NopCloser struct{} + +func (NopCloser) Close() error { return nil } + +func connectDB(t *testing.T, config testConfig) (DBAdapter, io.Closer) { + connStr := connectionString[config.driver] if connStr == "" { - panic(fmt.Sprintf("unsupported driver: '%s'", driver)) + panic(fmt.Sprintf("unsupported driver: '%s'", config.driver)) } - db, err := sql.Open(driver, connStr) - if err != nil { - t.Fatal(err.Error()) + switch config.adapterName { + case "sql": + db, err := sql.Open(config.driver, connStr) + if err != nil { + t.Fatal(err.Error()) + } + return SQLAdapter{db}, db + case "pgx": + pool, err := pgxpool.Connect(context.TODO(), connStr) + if err != nil { + t.Fatal(err.Error()) + } + return PGXAdapter{pool}, NopCloser{} } - return db + + t.Fatalf("unsupported adapter: %s", config.adapterName) + return nil, nil } func shiftErrSlice(errs *[]error) error { @@ -1427,9 +2022,7 @@ func shiftErrSlice(errs *[]error) error { return err } -func getUsersByID(dbi sqlProvider, dialect Dialect, resultsPtr *[]User, ids ...uint) error { - db := dbi.(*sql.DB) - +func getUsersByID(db DBAdapter, dialect Dialect, resultsPtr *[]User, ids ...uint) error { placeholders := make([]string, len(ids)) params := make([]interface{}, len(ids)) for i := range ids { @@ -1438,7 +2031,8 @@ func getUsersByID(dbi sqlProvider, dialect Dialect, resultsPtr *[]User, ids ...u } results := []User{} - rows, err := db.Query( + rows, err := db.QueryContext( + context.TODO(), fmt.Sprintf( "SELECT id, name, age FROM users WHERE id IN (%s)", strings.Join(placeholders, ", "), @@ -1469,37 +2063,51 @@ func getUsersByID(dbi sqlProvider, dialect Dialect, resultsPtr *[]User, ids ...u return nil } -func getUserByID(dbi sqlProvider, dialect Dialect, result *User, id uint) error { - db := dbi.(*sql.DB) +func getUserByID(db DBAdapter, dialect Dialect, result *User, id uint) error { + rows, err := db.QueryContext(context.TODO(), `SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) + if err != nil { + return err + } + defer rows.Close() - row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) - if row.Err() != nil { - return row.Err() + if rows.Next() == false { + if rows.Err() != nil { + return rows.Err() + } + return sql.ErrNoRows + } + + value := jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: &result.Address, + } + + err = rows.Scan(&result.ID, &result.Name, &result.Age, &value) + if err != nil { + return err + } + + return nil +} + +func getUserByName(db DBAdapter, driver string, result *User, name string) error { + dialect := supportedDialects[driver] + + rows, err := db.QueryContext(context.TODO(), `SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name) + if err != nil { + return err + } + defer rows.Close() + + if rows.Next() == false { + if rows.Err() != nil { + return rows.Err() + } + return sql.ErrNoRows } var rawAddr []byte - err := row.Scan(&result.ID, &result.Name, &result.Age, &rawAddr) - if err != nil { - return err - } - - if rawAddr == nil { - return nil - } - - return json.Unmarshal(rawAddr, &result.Address) -} - -func getUserByName(dbi sqlProvider, dialect Dialect, result *User, name string) error { - db := dbi.(*sql.DB) - - row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name) - if row.Err() != nil { - return row.Err() - } - - var rawAddr []byte - err := row.Scan(&result.ID, &result.Name, &result.Age, &rawAddr) + err = rows.Scan(&result.ID, &result.Name, &result.Age, &rawAddr) if err != nil { return err } diff --git a/kstructs/func_parser.go b/kstructs/func_parser.go new file mode 100644 index 0000000..56550f2 --- /dev/null +++ b/kstructs/func_parser.go @@ -0,0 +1,40 @@ +package kstructs + +import ( + "fmt" + "reflect" +) + +var errType = reflect.TypeOf(new(error)).Elem() + +// ParseInputFunc is used exclusively for parsing +// the ForEachChunk function used on the QueryChunks method. +func ParseInputFunc(fn interface{}) (reflect.Type, error) { + if fn == nil { + return nil, fmt.Errorf("the ForEachChunk attribute is required and cannot be nil") + } + + t := reflect.TypeOf(fn) + + if t.Kind() != reflect.Func { + return nil, fmt.Errorf("the ForEachChunk callback must be a function") + } + if t.NumIn() != 1 { + return nil, fmt.Errorf("the ForEachChunk callback must have 1 argument") + } + + if t.NumOut() != 1 { + return nil, fmt.Errorf("the ForEachChunk callback must have a single return value") + } + + if t.Out(0) != errType { + return nil, fmt.Errorf("the return value of the ForEachChunk callback must be of type error") + } + + argsType := t.In(0) + if argsType.Kind() != reflect.Slice { + return nil, fmt.Errorf("the argument of the ForEachChunk callback must a slice of structs") + } + + return argsType, nil +} diff --git a/structs/structs.go b/kstructs/structs.go similarity index 61% rename from structs/structs.go rename to kstructs/structs.go index c85e0b7..80959c0 100644 --- a/structs/structs.go +++ b/kstructs/structs.go @@ -1,48 +1,58 @@ -package structs +package kstructs import ( "fmt" "reflect" "strings" - - "github.com/pkg/errors" ) -type structInfo struct { - byIndex map[int]*fieldInfo - byName map[string]*fieldInfo +// StructInfo stores metainformation of the struct +// parser in order to help the ksql library to work +// efectively and efficiently with reflection. +type StructInfo struct { + IsNestedStruct bool + byIndex map[int]*FieldInfo + byName map[string]*FieldInfo } -type fieldInfo struct { +// FieldInfo contains reflection and tags +// information regarding a specific field +// of a struct. +type FieldInfo struct { Name string Index int Valid bool SerializeAsJSON bool } -func (s structInfo) ByIndex(idx int) *fieldInfo { +// ByIndex returns either the *FieldInfo of a valid +// empty struct with Valid set to false +func (s StructInfo) ByIndex(idx int) *FieldInfo { field, found := s.byIndex[idx] if !found { - return &fieldInfo{} + return &FieldInfo{} } return field } -func (s structInfo) ByName(name string) *fieldInfo { +// ByName returns either the *FieldInfo of a valid +// empty struct with Valid set to false +func (s StructInfo) ByName(name string) *FieldInfo { field, found := s.byName[name] if !found { - return &fieldInfo{} + return &FieldInfo{} } return field } -func (s structInfo) Add(field fieldInfo) { +func (s StructInfo) add(field FieldInfo) { field.Valid = true s.byIndex[field.Index] = &field s.byName[field.Name] = &field } -func (s structInfo) NumFields() int { +// NumFields ... +func (s StructInfo) NumFields() int { return len(s.byIndex) } @@ -50,7 +60,7 @@ func (s structInfo) NumFields() int { // because the total number of types on a program // should be finite. So keeping a single cache here // works fine. -var tagInfoCache = map[reflect.Type]structInfo{} +var tagInfoCache = map[reflect.Type]StructInfo{} // GetTagInfo efficiently returns the type information // using a global private cache @@ -58,16 +68,17 @@ var tagInfoCache = map[reflect.Type]structInfo{} // In the future we might move this cache inside // a struct, but for now this accessor is the one // we are using -func GetTagInfo(key reflect.Type) structInfo { +func GetTagInfo(key reflect.Type) StructInfo { return getCachedTagInfo(tagInfoCache, key) } -func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo { - info, found := tagInfoCache[key] - if !found { - info = getTagNames(key) - tagInfoCache[key] = info +func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) StructInfo { + if info, found := tagInfoCache[key]; found { + return info } + + info := getTagNames(key) + tagInfoCache[key] = info return info } @@ -112,56 +123,6 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) { return m, nil } -// FillStructWith is meant to be used on unit tests to mock -// the response from the database. -// -// The first argument is any struct you are passing to a ksql func, -// and the second is a map representing a database row you want -// to use to update this struct. -func FillStructWith(record interface{}, dbRow map[string]interface{}) error { - v := reflect.ValueOf(record) - t := v.Type() - - if t.Kind() != reflect.Ptr { - return fmt.Errorf( - "FillStructWith: expected input to be a pointer to struct but got %T", - record, - ) - } - - t = t.Elem() - v = v.Elem() - - if t.Kind() != reflect.Struct { - return fmt.Errorf( - "FillStructWith: expected input kind to be a struct but got %T", - record, - ) - } - - info := getCachedTagInfo(tagInfoCache, t) - for colName, rawSrc := range dbRow { - fieldInfo := info.ByName(colName) - if !fieldInfo.Valid { - // Ignore columns not tagged with `ksql:"..."` - continue - } - - src := NewPtrConverter(rawSrc) - dest := v.Field(fieldInfo.Index) - destType := t.Field(fieldInfo.Index).Type - - destValue, err := src.Convert(destType) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("FillStructWith: error on field `%s`", colName)) - } - - dest.Set(destValue) - } - - return nil -} - // PtrConverter was created to make it easier // to handle conversion between ptr and non ptr types, e.g.: // @@ -243,58 +204,15 @@ func (p PtrConverter) Convert(destType reflect.Type) (reflect.Value, error) { return destValue, nil } -// FillSliceWith is meant to be used on unit tests to mock -// the response from the database. -// -// The first argument is any slice of structs you are passing to a ksql func, -// and the second is a slice of maps representing the database rows you want -// to use to update this struct. -func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error { - sliceRef := reflect.ValueOf(entities) - sliceType := sliceRef.Type() - if sliceType.Kind() != reflect.Ptr { - return fmt.Errorf( - "FillSliceWith: expected input to be a pointer to struct but got %v", - sliceType, - ) - } - - structType, isSliceOfPtrs, err := DecodeAsSliceOfStructs(sliceType.Elem()) - if err != nil { - return errors.Wrap(err, "FillSliceWith") - } - - slice := sliceRef.Elem() - for idx, row := range dbRows { - if slice.Len() <= idx { - var elemValue reflect.Value - elemValue = reflect.New(structType) - if !isSliceOfPtrs { - elemValue = elemValue.Elem() - } - slice = reflect.Append(slice, elemValue) - } - - err := FillStructWith(slice.Index(idx).Addr().Interface(), row) - if err != nil { - return errors.Wrap(err, "FillSliceWith") - } - } - - sliceRef.Elem().Set(slice) - - return nil -} - // This function collects only the names // that will be used from the input type. // // This should save several calls to `Field(i).Tag.Get("foo")` // which improves performance by a lot. -func getTagNames(t reflect.Type) structInfo { - info := structInfo{ - byIndex: map[int]*fieldInfo{}, - byName: map[string]*fieldInfo{}, +func getTagNames(t reflect.Type) StructInfo { + info := StructInfo{ + byIndex: map[int]*FieldInfo{}, + byName: map[string]*FieldInfo{}, } for i := 0; i < t.NumField(); i++ { name := t.Field(i).Tag.Get("ksql") @@ -309,13 +227,36 @@ func getTagNames(t reflect.Type) structInfo { serializeAsJSON = tags[1] == "json" } - info.Add(fieldInfo{ + info.add(FieldInfo{ Name: name, Index: i, SerializeAsJSON: serializeAsJSON, }) } + // If there were `ksql` tags present, then we are finished: + if len(info.byIndex) > 0 { + return info + } + + // If there are no `ksql` tags in the struct, lets assume + // it is a struct tagged with `tablename` for allowing JOINs + for i := 0; i < t.NumField(); i++ { + name := t.Field(i).Tag.Get("tablename") + if name == "" { + continue + } + + info.add(FieldInfo{ + Name: name, + Index: i, + }) + } + + if len(info.byIndex) > 0 { + info.IsNestedStruct = true + } + return info } diff --git a/structs/structs_test.go b/kstructs/structs_test.go similarity index 99% rename from structs/structs_test.go rename to kstructs/structs_test.go index c11c3c1..3887be2 100644 --- a/structs/structs_test.go +++ b/kstructs/structs_test.go @@ -1,4 +1,4 @@ -package structs +package kstructs import ( "testing" diff --git a/kstructs/testhelpers.go b/kstructs/testhelpers.go new file mode 100644 index 0000000..caa7bb7 --- /dev/null +++ b/kstructs/testhelpers.go @@ -0,0 +1,125 @@ +package kstructs + +import ( + "fmt" + "reflect" + + "github.com/pkg/errors" +) + +// FillStructWith is meant to be used on unit tests to mock +// the response from the database. +// +// The first argument is any struct you are passing to a ksql func, +// and the second is a map representing a database row you want +// to use to update this struct. +func FillStructWith(record interface{}, dbRow map[string]interface{}) error { + v := reflect.ValueOf(record) + t := v.Type() + + if t.Kind() != reflect.Ptr { + return fmt.Errorf( + "FillStructWith: expected input to be a pointer to struct but got %T", + record, + ) + } + + t = t.Elem() + v = v.Elem() + + if t.Kind() != reflect.Struct { + return fmt.Errorf( + "FillStructWith: expected input kind to be a struct but got %T", + record, + ) + } + + info := GetTagInfo(t) + for colName, rawSrc := range dbRow { + fieldInfo := info.ByName(colName) + if !fieldInfo.Valid { + // Ignore columns not tagged with `ksql:"..."` + continue + } + + src := NewPtrConverter(rawSrc) + dest := v.Field(fieldInfo.Index) + destType := t.Field(fieldInfo.Index).Type + + destValue, err := src.Convert(destType) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("FillStructWith: error on field `%s`", colName)) + } + + dest.Set(destValue) + } + + return nil +} + +// FillSliceWith is meant to be used on unit tests to mock +// the response from the database. +// +// The first argument is any slice of structs you are passing to a ksql func, +// and the second is a slice of maps representing the database rows you want +// to use to update this struct. +func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error { + sliceRef := reflect.ValueOf(entities) + sliceType := sliceRef.Type() + if sliceType.Kind() != reflect.Ptr { + return fmt.Errorf( + "FillSliceWith: expected input to be a pointer to a slice of structs but got %v", + sliceType, + ) + } + + structType, isSliceOfPtrs, err := DecodeAsSliceOfStructs(sliceType.Elem()) + if err != nil { + return errors.Wrap(err, "FillSliceWith") + } + + slice := sliceRef.Elem() + for idx, row := range dbRows { + if slice.Len() <= idx { + var elemValue reflect.Value + elemValue = reflect.New(structType) + if !isSliceOfPtrs { + elemValue = elemValue.Elem() + } + slice = reflect.Append(slice, elemValue) + } + + err := FillStructWith(slice.Index(idx).Addr().Interface(), row) + if err != nil { + return errors.Wrap(err, "FillSliceWith") + } + } + + sliceRef.Elem().Set(slice) + + return nil +} + +// CallFunctionWithRows was created for helping test the QueryChunks method +func CallFunctionWithRows(fn interface{}, rows []map[string]interface{}) error { + fnValue := reflect.ValueOf(fn) + chunkType, err := ParseInputFunc(fn) + if err != nil { + return err + } + + chunk := reflect.MakeSlice(chunkType, 0, len(rows)) + + // Create a pointer to a slice (required by FillSliceWith) + chunkPtr := reflect.New(chunkType) + chunkPtr.Elem().Set(chunk) + + err = FillSliceWith(chunkPtr.Interface(), rows) + if err != nil { + return err + } + + err, _ = fnValue.Call([]reflect.Value{chunkPtr.Elem()})[0].Interface().(error) + + return err +} diff --git a/mocks.go b/mocks.go index eab1770..97911cd 100644 --- a/mocks.go +++ b/mocks.go @@ -2,58 +2,58 @@ package ksql import "context" -var _ SQLProvider = MockSQLProvider{} +var _ Provider = Mock{} -// MockSQLProvider ... -type MockSQLProvider struct { - InsertFn func(ctx context.Context, record interface{}) error - UpdateFn func(ctx context.Context, record interface{}) error - DeleteFn func(ctx context.Context, ids ...interface{}) error +// Mock ... +type Mock struct { + InsertFn func(ctx context.Context, table Table, record interface{}) error + UpdateFn func(ctx context.Context, table Table, record interface{}) error + DeleteFn func(ctx context.Context, table Table, ids ...interface{}) error QueryFn func(ctx context.Context, records interface{}, query string, params ...interface{}) error QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error QueryChunksFn func(ctx context.Context, parser ChunkParser) error ExecFn func(ctx context.Context, query string, params ...interface{}) error - TransactionFn func(ctx context.Context, fn func(db SQLProvider) error) error + TransactionFn func(ctx context.Context, fn func(db Provider) error) error } // Insert ... -func (m MockSQLProvider) Insert(ctx context.Context, record interface{}) error { - return m.InsertFn(ctx, record) +func (m Mock) Insert(ctx context.Context, table Table, record interface{}) error { + return m.InsertFn(ctx, table, record) } // Update ... -func (m MockSQLProvider) Update(ctx context.Context, record interface{}) error { - return m.UpdateFn(ctx, record) +func (m Mock) Update(ctx context.Context, table Table, record interface{}) error { + return m.UpdateFn(ctx, table, record) } // Delete ... -func (m MockSQLProvider) Delete(ctx context.Context, ids ...interface{}) error { - return m.DeleteFn(ctx, ids...) +func (m Mock) Delete(ctx context.Context, table Table, ids ...interface{}) error { + return m.DeleteFn(ctx, table, ids...) } // Query ... -func (m MockSQLProvider) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error { +func (m Mock) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error { return m.QueryFn(ctx, records, query, params...) } // QueryOne ... -func (m MockSQLProvider) QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error { +func (m Mock) QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error { return m.QueryOneFn(ctx, record, query, params...) } // QueryChunks ... -func (m MockSQLProvider) QueryChunks(ctx context.Context, parser ChunkParser) error { +func (m Mock) QueryChunks(ctx context.Context, parser ChunkParser) error { return m.QueryChunksFn(ctx, parser) } // Exec ... -func (m MockSQLProvider) Exec(ctx context.Context, query string, params ...interface{}) error { +func (m Mock) Exec(ctx context.Context, query string, params ...interface{}) error { return m.ExecFn(ctx, query, params...) } // Transaction ... -func (m MockSQLProvider) Transaction(ctx context.Context, fn func(db SQLProvider) error) error { +func (m Mock) Transaction(ctx context.Context, fn func(db Provider) error) error { return m.TransactionFn(ctx, fn) } diff --git a/pgx_adapter.go b/pgx_adapter.go new file mode 100644 index 0000000..944dff7 --- /dev/null +++ b/pgx_adapter.go @@ -0,0 +1,106 @@ +package ksql + +import ( + "context" + "fmt" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" +) + +// PGXAdapter adapts the sql.DB type to be compatible with the `DBAdapter` interface +type PGXAdapter struct { + db *pgxpool.Pool +} + +var _ DBAdapter = PGXAdapter{} + +// ExecContext implements the DBAdapter interface +func (p PGXAdapter) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + result, err := p.db.Exec(ctx, query, args...) + return PGXResult{result}, err +} + +// QueryContext implements the DBAdapter interface +func (p PGXAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + rows, err := p.db.Query(ctx, query, args...) + return PGXRows{rows}, err +} + +// BeginTx implements the Tx interface +func (p PGXAdapter) BeginTx(ctx context.Context) (Tx, error) { + tx, err := p.db.Begin(ctx) + return PGXTx{tx}, err +} + +// PGXResult is used to implement the DBAdapter interface and implements +// the Result interface +type PGXResult struct { + tag pgconn.CommandTag +} + +// RowsAffected implements the Result interface +func (p PGXResult) RowsAffected() (int64, error) { + return p.tag.RowsAffected(), nil +} + +// LastInsertId implements the Result interface +func (p PGXResult) LastInsertId() (int64, error) { + return 0, fmt.Errorf( + "LastInsertId is not implemented in the pgx adapter, use the `RETURNING` statement instead", + ) +} + +// PGXTx is used to implement the DBAdapter interface and implements +// the Tx interface +type PGXTx struct { + tx pgx.Tx +} + +// ExecContext implements the Tx interface +func (p PGXTx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + result, err := p.tx.Exec(ctx, query, args...) + return PGXResult{result}, err +} + +// QueryContext implements the Tx interface +func (p PGXTx) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + rows, err := p.tx.Query(ctx, query, args...) + return PGXRows{rows}, err +} + +// Rollback implements the Tx interface +func (p PGXTx) Rollback(ctx context.Context) error { + return p.tx.Rollback(ctx) +} + +// Commit implements the Tx interface +func (p PGXTx) Commit(ctx context.Context) error { + return p.tx.Commit(ctx) +} + +var _ Tx = PGXTx{} + +// PGXRows implements the Rows interface and is used to help +// the PGXAdapter to implement the DBAdapter interface. +type PGXRows struct { + pgx.Rows +} + +var _ Rows = PGXRows{} + +// Columns implements the Rows interface +func (p PGXRows) Columns() ([]string, error) { + var names []string + for _, desc := range p.Rows.FieldDescriptions() { + names = append(names, string(desc.Name)) + } + return names, nil +} + +// Close implements the Rows interface +func (p PGXRows) Close() error { + p.Rows.Close() + return nil +} diff --git a/sql_adapter.go b/sql_adapter.go new file mode 100644 index 0000000..b3fa8a1 --- /dev/null +++ b/sql_adapter.go @@ -0,0 +1,57 @@ +package ksql + +import ( + "context" + "database/sql" +) + +// SQLAdapter adapts the sql.DB type to be compatible with the `DBAdapter` interface +type SQLAdapter struct { + *sql.DB +} + +var _ DBAdapter = SQLAdapter{} + +// ExecContext implements the DBAdapter interface +func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + return s.DB.ExecContext(ctx, query, args...) +} + +// QueryContext implements the DBAdapter interface +func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + return s.DB.QueryContext(ctx, query, args...) +} + +// BeginTx implements the Tx interface +func (s SQLAdapter) BeginTx(ctx context.Context) (Tx, error) { + tx, err := s.DB.BeginTx(ctx, nil) + return SQLTx{Tx: tx}, err +} + +// SQLTx is used to implement the DBAdapter interface and implements +// the Tx interface +type SQLTx struct { + *sql.Tx +} + +// ExecContext implements the Tx interface +func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + return s.Tx.ExecContext(ctx, query, args...) +} + +// QueryContext implements the Tx interface +func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + return s.Tx.QueryContext(ctx, query, args...) +} + +// Rollback implements the Tx interface +func (s SQLTx) Rollback(ctx context.Context) error { + return s.Tx.Rollback() +} + +// Commit implements the Tx interface +func (s SQLTx) Commit(ctx context.Context) error { + return s.Tx.Commit() +} + +var _ Tx = SQLTx{}