From 3a90b03a37d88584e7c53d1cfa7139424ae3ed1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sat, 8 May 2021 11:56:57 -0300 Subject: [PATCH 01/40] Refactor dialect.go so its easier to add new dialects --- dialect.go | 40 +++++++++++++++++++++++++++++++++++----- docker-compose.yml | 8 ++++++++ ksql.go | 24 ++++-------------------- ksql_test.go | 16 ++++++++-------- 4 files changed, 55 insertions(+), 33 deletions(-) diff --git a/dialect.go b/dialect.go index 5522fb8..6640c34 100644 --- a/dialect.go +++ b/dialect.go @@ -2,13 +2,32 @@ package ksql import "strconv" +type insertMethod int + +const ( + insertWithReturning insertMethod = iota + insertWithLastInsertID + insertWithNoIDRetrieval +) + +var supportedDialects = map[string]dialect{ + "postgres": &postgresDialect{}, + "sqlite3": &sqlite3Dialect{}, + // "mysql": &mysqlDialect{}, +} + type dialect interface { + InsertMethod() insertMethod Escape(str string) string Placeholder(idx int) string } type postgresDialect struct{} +func (postgresDialect) InsertMethod() insertMethod { + return insertWithReturning +} + func (postgresDialect) Escape(str string) string { return `"` + str + `"` } @@ -19,6 +38,10 @@ func (postgresDialect) Placeholder(idx int) string { type sqlite3Dialect struct{} +func (sqlite3Dialect) InsertMethod() insertMethod { + return insertWithLastInsertID +} + func (sqlite3Dialect) Escape(str string) string { return "`" + str + "`" } @@ -27,9 +50,16 @@ func (sqlite3Dialect) Placeholder(idx int) string { return "?" } -func getDriverDialect(driver string) dialect { - return map[string]dialect{ - "postgres": &postgresDialect{}, - "sqlite3": &sqlite3Dialect{}, - }[driver] +type mysqlDialect struct{} + +func (mysqlDialect) InsertMethod() insertMethod { + return insertWithLastInsertID +} + +func (mysqlDialect) Escape(str string) string { + return "`" + str + "`" +} + +func (mysqlDialect) Placeholder(idx int) string { + return "?" } diff --git a/docker-compose.yml b/docker-compose.yml index 878d61d..6d0932c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,3 +15,11 @@ services: environment: - POSTGRES_USER=postgres - POSTGRES_PASSWORD=postgres + + mysql: + image: mysql + restart: always + ports: + - "127.0.0.1:3306:3306" + environment: + MYSQL_ROOT_PASSWORD: mysql diff --git a/ksql.go b/ksql.go index e426514..0386b68 100644 --- a/ksql.go +++ b/ksql.go @@ -32,14 +32,6 @@ type sqlProvider interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } -type insertMethod int - -const ( - insertWithReturning insertMethod = iota - insertWithLastInsertID - insertWithNoIDRetrieval -) - // Config describes the optional arguments accepted // by the ksql.New() function. type Config struct { @@ -75,7 +67,7 @@ func New( db.SetMaxOpenConns(config.MaxOpenConns) - dialect := getDriverDialect(dbDriver) + dialect := supportedDialects[dbDriver] if dialect == nil { return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver) } @@ -84,17 +76,9 @@ func New( config.IDColumns = []string{"id"} } - var insertMethod insertMethod - switch dbDriver { - case "sqlite3": - insertMethod = insertWithLastInsertID - if len(config.IDColumns) > 1 { - insertMethod = insertWithNoIDRetrieval - } - case "postgres": - insertMethod = insertWithReturning - default: - return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver) + insertMethod := dialect.InsertMethod() + if len(config.IDColumns) > 1 && insertMethod == insertWithLastInsertID { + insertMethod = insertWithNoIDRetrieval } return DB{ diff --git a/ksql_test.go b/ksql_test.go index 903acff..f0514b9 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -33,7 +33,7 @@ type Address struct { } func TestQuery(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { err := createTable(driver) @@ -222,7 +222,7 @@ func TestQuery(t *testing.T) { } func TestQueryOne(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { err := createTable(driver) if err != nil { @@ -318,7 +318,7 @@ func TestQueryOne(t *testing.T) { } func TestInsert(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { err := createTable(driver) @@ -446,7 +446,7 @@ func TestInsert(t *testing.T) { } func TestDelete(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { err := createTable(driver) if err != nil { @@ -589,7 +589,7 @@ func TestDelete(t *testing.T) { } func TestUpdate(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { err := createTable(driver) if err != nil { @@ -758,7 +758,7 @@ func TestUpdate(t *testing.T) { } func TestQueryChunks(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { err := createTable(driver) @@ -1156,7 +1156,7 @@ func TestQueryChunks(t *testing.T) { } func TestTransaction(t *testing.T) { - for _, driver := range []string{"sqlite3", "postgres"} { + for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { err := createTable(driver) @@ -1391,7 +1391,7 @@ func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB { return DB{ driver: driver, - dialect: getDriverDialect(driver), + dialect: supportedDialects[driver], db: db, tableName: tableName, From bbad31ce4d3ef263b2d55415d03ed5728f6c858e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sat, 8 May 2021 12:43:11 -0300 Subject: [PATCH 02/40] Add support to the mysql driver --- dialect.go | 2 +- go.mod | 1 + ksql_test.go | 12 +++++++++++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/dialect.go b/dialect.go index 6640c34..c46869d 100644 --- a/dialect.go +++ b/dialect.go @@ -13,7 +13,7 @@ const ( var supportedDialects = map[string]dialect{ "postgres": &postgresDialect{}, "sqlite3": &sqlite3Dialect{}, - // "mysql": &mysqlDialect{}, + "mysql": &mysqlDialect{}, } type dialect interface { diff --git a/go.mod b/go.mod index b218375..2b82c83 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.14 require ( github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 + github.com/go-sql-driver/mysql v1.4.0 // indirect github.com/golang/mock v1.5.0 github.com/jmoiron/sqlx v1.2.0 github.com/lib/pq v1.1.1 diff --git a/ksql_test.go b/ksql_test.go index f0514b9..85a4691 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/ditointernet/go-assert" + _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/vingarcia/ksql/nullable" @@ -353,7 +354,7 @@ func TestInsert(t *testing.T) { }) t.Run("should insert ignoring the ID for sqlite and multiple ids", func(t *testing.T) { - if driver != "sqlite3" { + if supportedDialects[driver].InsertMethod() != insertWithLastInsertID { return } @@ -1345,6 +1346,7 @@ func TestScanRows(t *testing.T) { var connectionString = map[string]string{ "postgres": "host=localhost port=5432 user=postgres password=postgres dbname=ksql sslmode=disable", "sqlite3": "/tmp/ksql.db", + "mysql": "root:mysql@(127.0.0.1:3306)/ksql?timeout=30s", } func createTable(driver string) error { @@ -1376,6 +1378,13 @@ func createTable(driver string) error { name VARCHAR(50), address jsonb )`) + case "mysql": + _, err = db.Exec(`CREATE TABLE users ( + id INT AUTO_INCREMENT PRIMARY KEY, + age INT, + name VARCHAR(50), + address JSON + )`) } if err != nil { return fmt.Errorf("failed to create new users table: %s", err.Error()) @@ -1399,6 +1408,7 @@ func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB { insertMethod: map[string]insertMethod{ "sqlite3": insertWithLastInsertID, "postgres": insertWithReturning, + "mysql": insertWithLastInsertID, }[driver], } } From 2dd55131d5b84986f21fccea4bd40c500dfd81b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sat, 8 May 2021 13:02:01 -0300 Subject: [PATCH 03/40] Add README instructions on how to run the tests --- README.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/README.md b/README.md index ec351fd..35b66e0 100644 --- a/README.md +++ b/README.md @@ -277,6 +277,27 @@ PASS ok github.com/vingarcia/ksql 34.251s ``` +### Running the tests + +The tests run in real database instances so the easiest way to have +them working is to just start them using docker-compose: + +```bash +docker-compose up -d +``` + +And then for each of them you will need to run the command: + +```sql +CREATE DATABASE ksql; +``` + +After that you can just run the tests by using: + +```bash +make test +``` + ### TODO List - Implement support for nested objects with prefixed table names From 56aa77135c78c4553c848fc3711c1e741506a71d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 9 May 2021 22:33:03 -0300 Subject: [PATCH 04/40] Add support to the `sqlserver` driver =] --- README.md | 4 +- dialect.go | 39 +++++++++++++-- docker-compose.yml | 10 ++++ go.mod | 1 + go.sum | 6 +++ json.go | 9 +++- ksql.go | 117 ++++++++++++++++++++++++++++----------------- ksql_test.go | 52 ++++++++++++-------- 8 files changed, 167 insertions(+), 71 deletions(-) diff --git a/README.md b/README.md index 35b66e0..24d428d 100644 --- a/README.md +++ b/README.md @@ -14,10 +14,12 @@ The goals were: **Supported Drivers:** -Currently we only support 2 Drivers: +Currently we support 4 Drivers: - `"postgres"` - `"sqlite3"` +- `"mysql"` +- `"sqlserver"` ### Why KissSQL? diff --git a/dialect.go b/dialect.go index c46869d..d245da7 100644 --- a/dialect.go +++ b/dialect.go @@ -6,24 +6,31 @@ type insertMethod int const ( insertWithReturning insertMethod = iota + insertWithOutput insertWithLastInsertID insertWithNoIDRetrieval ) var supportedDialects = map[string]dialect{ - "postgres": &postgresDialect{}, - "sqlite3": &sqlite3Dialect{}, - "mysql": &mysqlDialect{}, + "postgres": &postgresDialect{}, + "sqlite3": &sqlite3Dialect{}, + "mysql": &mysqlDialect{}, + "sqlserver": &sqlserverDialect{}, } type dialect interface { InsertMethod() insertMethod Escape(str string) string Placeholder(idx int) string + DriverName() string } type postgresDialect struct{} +func (postgresDialect) DriverName() string { + return "postgres" +} + func (postgresDialect) InsertMethod() insertMethod { return insertWithReturning } @@ -38,6 +45,10 @@ func (postgresDialect) Placeholder(idx int) string { type sqlite3Dialect struct{} +func (sqlite3Dialect) DriverName() string { + return "sqlite3" +} + func (sqlite3Dialect) InsertMethod() insertMethod { return insertWithLastInsertID } @@ -52,6 +63,10 @@ func (sqlite3Dialect) Placeholder(idx int) string { type mysqlDialect struct{} +func (mysqlDialect) DriverName() string { + return "mysql" +} + func (mysqlDialect) InsertMethod() insertMethod { return insertWithLastInsertID } @@ -63,3 +78,21 @@ func (mysqlDialect) Escape(str string) string { func (mysqlDialect) Placeholder(idx int) string { return "?" } + +type sqlserverDialect struct{} + +func (sqlserverDialect) DriverName() string { + return "sqlserver" +} + +func (sqlserverDialect) InsertMethod() insertMethod { + return insertWithOutput +} + +func (sqlserverDialect) Escape(str string) string { + return `[` + str + `]` +} + +func (sqlserverDialect) Placeholder(idx int) string { + return "@p" + strconv.Itoa(idx+1) +} diff --git a/docker-compose.yml b/docker-compose.yml index 6d0932c..6e70b12 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,3 +23,13 @@ services: - "127.0.0.1:3306:3306" environment: MYSQL_ROOT_PASSWORD: mysql + + sqlserver: + image: microsoft/mssql-server-linux:2017-latest + restart: always + ports: + - "127.0.0.1:1433:1433" + - "127.0.0.1:1434:1434" + environment: + SA_PASSWORD: "Sqls3rv3r" + ACCEPT_EULA: "Y" diff --git a/go.mod b/go.mod index 2b82c83..f09a85a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/vingarcia/ksql go 1.14 require ( + github.com/denisenkom/go-mssqldb v0.10.0 // indirect github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 github.com/go-sql-driver/mysql v1.4.0 // indirect github.com/golang/mock v1.5.0 diff --git a/go.sum b/go.sum index 2e20171..b80fabb 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,14 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= +github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 h1:QsFkVafcKOaZoAB4WcyUHdkPbwh+VYwZgYJb/rU6EIM= github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018/go.mod h1:5C3SWkut69TSdkerzRDxXMRM5x73PGWNcRLe/xKjXhs= github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -28,6 +32,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= diff --git a/json.go b/json.go index aaf2ef9..03fdda1 100644 --- a/json.go +++ b/json.go @@ -11,7 +11,8 @@ import ( // input attributes to be convertible to and from JSON // before sending or receiving it from the database. type jsonSerializable struct { - Attr interface{} + DriverName string + Attr interface{} } // Scan Implements the Scanner interface in order to load @@ -40,5 +41,9 @@ func (j *jsonSerializable) Scan(value interface{}) error { // Value Implements the Valuer interface in order to save // this field as JSON on the database. func (j jsonSerializable) Value() (driver.Value, error) { - return json.Marshal(j.Attr) + b, err := json.Marshal(j.Attr) + if j.DriverName == "sqlserver" { + return string(b), err + } + return b, err } diff --git a/ksql.go b/ksql.go index 0386b68..86ba670 100644 --- a/ksql.go +++ b/ksql.go @@ -148,7 +148,7 @@ func (c DB) Query( elemPtr = elemPtr.Elem() } - err = scanRows(rows, elemPtr.Interface()) + err = scanRows(c.dialect, rows, elemPtr.Interface()) if err != nil { return err } @@ -202,7 +202,7 @@ func (c DB) QueryOne( return ErrRecordNotFound } - err = scanRows(rows, record) + err = scanRows(c.dialect, rows, record) if err != nil { return err } @@ -262,7 +262,7 @@ func (c DB) QueryChunks( chunk = reflect.Append(chunk, elemValue) } - err = scanRows(rows, chunk.Index(idx).Addr().Interface()) + err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface()) if err != nil { return err } @@ -320,14 +320,14 @@ func (c DB) Insert( return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Insert method") } - query, params, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...) + query, params, scanValues, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...) if err != nil { return err } switch c.insertMethod { - case insertWithReturning: - err = c.insertWithReturningID(ctx, record, query, params, c.idCols) + case insertWithReturning, insertWithOutput: + err = c.insertReturningIDs(ctx, record, query, params, scanValues, c.idCols) case insertWithLastInsertID: err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0]) case insertWithNoIDRetrieval: @@ -341,19 +341,14 @@ func (c DB) Insert( return err } -func (c DB) insertWithReturningID( +func (c DB) insertReturningIDs( ctx context.Context, record interface{}, query string, params []interface{}, + scanValues []interface{}, idNames []string, ) error { - escapedIDNames := []string{} - for _, id := range idNames { - escapedIDNames = append(escapedIDNames, c.dialect.Escape(id)) - } - query += " RETURNING " + strings.Join(idNames, ", ") - rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return err @@ -369,21 +364,7 @@ func (c DB) insertWithReturningID( return err } - v := reflect.ValueOf(record) - t := v.Type() - if err = assertStructPtr(t); err != nil { - return errors.Wrap(err, "can't write id field") - } - info := structs.GetTagInfo(t.Elem()) - - var scanFields []interface{} - for _, id := range idNames { - scanFields = append( - scanFields, - v.Elem().Field(info.ByName(id).Index).Addr().Interface(), - ) - } - err = rows.Scan(scanFields...) + err = rows.Scan(scanValues...) if err != nil { return err } @@ -549,20 +530,25 @@ func buildInsertQuery( dialect dialect, tableName string, record interface{}, - idFieldNames ...string, -) (query string, params []interface{}, err error) { + idNames ...string, +) (query string, params []interface{}, scanValues []interface{}, err error) { + v := reflect.ValueOf(record) + t := v.Type() + if err = assertStructPtr(t); err != nil { + return "", nil, nil, fmt.Errorf( + "ksql: expected record to be a pointer to struct, but got: %T", + record, + ) + } + + info := structs.GetTagInfo(t.Elem()) + recordMap, err := structs.StructToMap(record) if err != nil { - return "", nil, err + return "", nil, nil, err } - t := reflect.TypeOf(record) - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - info := structs.GetTagInfo(t) - - for _, fieldName := range idFieldNames { + for _, fieldName := range idNames { // Remove any ID field that was not set: if reflect.ValueOf(recordMap[fieldName]).IsZero() { delete(recordMap, fieldName) @@ -580,7 +566,10 @@ func buildInsertQuery( recordValue := recordMap[col] params[i] = recordValue if info.ByName(col).SerializeAsJSON { - params[i] = jsonSerializable{Attr: recordValue} + params[i] = jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: recordValue, + } } valuesQuery[i] = dialect.Placeholder(i) @@ -592,14 +581,48 @@ func buildInsertQuery( escapedColumnNames = append(escapedColumnNames, dialect.Escape(col)) } + var returningQuery, outputQuery string + switch dialect.InsertMethod() { + case insertWithReturning: + escapedIDNames := []string{} + for _, id := range idNames { + escapedIDNames = append(escapedIDNames, dialect.Escape(id)) + } + returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ") + + for _, id := range idNames { + scanValues = append( + scanValues, + v.Elem().Field(info.ByName(id).Index).Addr().Interface(), + ) + } + case insertWithOutput: + escapedIDNames := []string{} + for _, id := range idNames { + escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id)) + } + outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ") + + for _, id := range idNames { + scanValues = append( + scanValues, + v.Elem().Field(info.ByName(id).Index).Addr().Interface(), + ) + } + } + + // Note that the outputQuery and the returningQuery depend + // on the selected driver, thus, they might be empty strings. query = fmt.Sprintf( - "INSERT INTO %s (%s) VALUES (%s)", + "INSERT INTO %s (%s)%s VALUES (%s)%s", dialect.Escape(tableName), strings.Join(escapedColumnNames, ", "), + outputQuery, strings.Join(valuesQuery, ", "), + returningQuery, ) - return query, params, nil + return query, params, scanValues, nil } func buildUpdateQuery( @@ -644,7 +667,10 @@ func buildUpdateQuery( for i, k := range keys { recordValue := recordMap[k] if info.ByName(k).SerializeAsJSON { - recordValue = jsonSerializable{Attr: recordValue} + recordValue = jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: recordValue, + } } args[i] = recordValue setQuery = append(setQuery, fmt.Sprintf( @@ -753,7 +779,7 @@ func (nopScanner) Scan(value interface{}) error { return nil } -func scanRows(rows *sql.Rows, record interface{}) error { +func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { names, err := rows.Columns() if err != nil { return err @@ -782,7 +808,10 @@ func scanRows(rows *sql.Rows, record interface{}) error { if fieldInfo.Valid { valueScanner = v.Field(fieldInfo.Index).Addr().Interface() if fieldInfo.SerializeAsJSON { - valueScanner = &jsonSerializable{Attr: valueScanner} + valueScanner = &jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: valueScanner, + } } } diff --git a/ksql_test.go b/ksql_test.go index 85a4691..b8132d5 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + _ "github.com/denisenkom/go-mssqldb" "github.com/ditointernet/go-assert" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" @@ -1233,6 +1234,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1248,7 +1250,7 @@ func TestScanRows(t *testing.T) { assert.Equal(t, true, rows.Next()) var u User - err = scanRows(rows, &u) + err = scanRows(dialect, rows, &u) assert.Equal(t, nil, err) assert.Equal(t, "User2", u.Name) @@ -1261,6 +1263,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1280,7 +1283,7 @@ func TestScanRows(t *testing.T) { // Omitted for testing purposes: // Name string `ksql:"name"` } - err = scanRows(rows, &user) + err = scanRows(dialect, rows, &user) assert.Equal(t, nil, err) assert.Equal(t, 22, user.Age) @@ -1292,6 +1295,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1302,7 +1306,7 @@ func TestScanRows(t *testing.T) { var u User err = rows.Close() assert.Equal(t, nil, err) - err = scanRows(rows, &u) + err = scanRows(dialect, rows, &u) assert.NotEqual(t, nil, err) }) @@ -1312,6 +1316,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1320,7 +1325,7 @@ func TestScanRows(t *testing.T) { assert.Equal(t, nil, err) var u User - err = scanRows(rows, u) + err = scanRows(dialect, rows, u) assert.NotEqual(t, nil, err) }) @@ -1330,6 +1335,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1338,15 +1344,16 @@ func TestScanRows(t *testing.T) { assert.Equal(t, nil, err) var u map[string]interface{} - err = scanRows(rows, &u) + err = scanRows(dialect, rows, &u) assert.NotEqual(t, nil, err) }) } var connectionString = map[string]string{ - "postgres": "host=localhost port=5432 user=postgres password=postgres dbname=ksql sslmode=disable", - "sqlite3": "/tmp/ksql.db", - "mysql": "root:mysql@(127.0.0.1:3306)/ksql?timeout=30s", + "postgres": "host=localhost port=5432 user=postgres password=postgres dbname=ksql sslmode=disable", + "sqlite3": "/tmp/ksql.db", + "mysql": "root:mysql@(127.0.0.1:3306)/ksql?timeout=30s", + "sqlserver": "sqlserver://sa:Sqls3rv3r@127.0.0.1:1433?databaseName=ksql", } func createTable(driver string) error { @@ -1385,6 +1392,13 @@ func createTable(driver string) error { name VARCHAR(50), address JSON )`) + case "sqlserver": + _, err = db.Exec(`CREATE TABLE users ( + id INT IDENTITY(1,1) PRIMARY KEY, + age INT, + name VARCHAR(50), + address NVARCHAR(4000) + )`) } if err != nil { return fmt.Errorf("failed to create new users table: %s", err.Error()) @@ -1404,12 +1418,8 @@ func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB { db: db, tableName: tableName, - idCols: ids, - insertMethod: map[string]insertMethod{ - "sqlite3": insertWithLastInsertID, - "postgres": insertWithReturning, - "mysql": insertWithLastInsertID, - }[driver], + idCols: ids, + insertMethod: supportedDialects[driver].InsertMethod(), } } @@ -1482,17 +1492,17 @@ func getUserByID(dbi sqlProvider, dialect dialect, result *User, id uint) error return row.Err() } - var rawAddr []byte - err := row.Scan(&result.ID, &result.Name, &result.Age, &rawAddr) + value := jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: &result.Address, + } + + err := row.Scan(&result.ID, &result.Name, &result.Age, &value) if err != nil { return err } - if rawAddr == nil { - return nil - } - - return json.Unmarshal(rawAddr, &result.Address) + return nil } func getUserByName(dbi sqlProvider, dialect dialect, result *User, name string) error { From cb84b02e2e198823674d8bbb7d38417ee98f10d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Mon, 10 May 2021 09:38:45 -0300 Subject: [PATCH 05/40] Improve README so the fact that the benchmark results are good is enphasized --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 24d428d..01e2b95 100644 --- a/README.md +++ b/README.md @@ -260,7 +260,7 @@ read the example tests available on our [example service](./examples/example_ser ### Benchmark Comparison -The benchmark is not bad, as far the code is in average as fast as sqlx: +The benchmark is very good, the code is, in practical terms, as fast as sqlx: ```bash $ make bench TIME=3s From d275555df527a26b2ca6eb3f204fe2fb0bf59dcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sat, 15 May 2021 10:53:12 -0300 Subject: [PATCH 06/40] Update TODO list --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 01e2b95..c2549ee 100644 --- a/README.md +++ b/README.md @@ -307,9 +307,13 @@ make test - Add tests for tables using composite keys - Add support for serializing structs as other formats such as YAML - Update structs.FillStructWith to work with `json` tagged attributes +- Make testing easier by exposing the connection strings in an .env file +- Make testing easier by automatically creating the ksql database +- Create a way for users to submit user defined dialects ### Optimization Oportunities - Test if using a pointer on the field info is faster or not - Consider passing the cached structInfo as argument for all the functions that use it, so that we don't need to get it more than once in the same call. +- Use a cache to store all queries after they are built From edecbf8191c88e70729e389b36e467c963198501 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 16 May 2021 17:05:21 -0300 Subject: [PATCH 07/40] Add feature of omiting the "SELECT" part of the query Now the 3 functions that allow you to write plain SQL queries also work if you omit the `SELECT ...` part of the query. If you do this the code will check and notice that the first token of the query is a "FROM" token and then automatically build the SELECT part of the query based on the tags of the struct. Everything is cached, so the impact on performance should be negligible. The affected functions are: - Query() - QueryOne() - QueryChunks() --- go.mod | 4 +- ksql.go | 69 +++ ksql_test.go | 1172 +++++++++++++++++++++++--------------------- structs/structs.go | 4 + 4 files changed, 687 insertions(+), 562 deletions(-) diff --git a/go.mod b/go.mod index f09a85a..3dac923 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module github.com/vingarcia/ksql go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.10.0 // indirect + github.com/denisenkom/go-mssqldb v0.10.0 github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 - github.com/go-sql-driver/mysql v1.4.0 // indirect + github.com/go-sql-driver/mysql v1.4.0 github.com/golang/mock v1.5.0 github.com/jmoiron/sqlx v1.2.0 github.com/lib/pq v1.1.1 diff --git a/ksql.go b/ksql.go index 86ba670..1c30a8b 100644 --- a/ksql.go +++ b/ksql.go @@ -6,11 +6,20 @@ import ( "fmt" "reflect" "strings" + "unicode" "github.com/pkg/errors" "github.com/vingarcia/ksql/structs" ) +var selectQueryCache = map[string]map[reflect.Type]string{} + +func init() { + for dname := range supportedDialects { + selectQueryCache[dname] = map[reflect.Type]string{} + } +} + // DB represents the ksql client responsible for // interfacing with the "database/sql" package implementing // the KissSQL interface `SQLProvider`. @@ -124,6 +133,14 @@ func (c DB) Query( slice = slice.Slice(0, 0) } + if strings.ToUpper(getFirstToken(query)) == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()]) + if err != nil { + return err + } + query = selectPrefix + query + } + rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return fmt.Errorf("error running query: %s", err.Error()) @@ -189,6 +206,14 @@ func (c DB) QueryOne( return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) } + if strings.ToUpper(getFirstToken(query)) == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, t, selectQueryCache[c.dialect.DriverName()]) + if err != nil { + return err + } + query = selectPrefix + query + } + rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return err @@ -243,6 +268,14 @@ func (c DB) QueryChunks( return err } + if strings.ToUpper(getFirstToken(parser.Query)) == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()]) + if err != nil { + return err + } + parser.Query = selectPrefix + parser.Query + } + rows, err := c.db.QueryContext(ctx, parser.Query, parser.Params...) if err != nil { return err @@ -869,3 +902,39 @@ func buildCompositeKeyDeleteQuery( strings.Join(values, ","), ), params } + +// We implemented this function instead of using +// a regex or strings.Fields because we wanted +// to preserve the performance of the package. +func getFirstToken(s string) string { + s = strings.TrimLeftFunc(s, unicode.IsSpace) + + var token strings.Builder + for _, c := range s { + if unicode.IsSpace(c) { + break + } + token.WriteRune(c) + } + return token.String() +} + +func buildSelectQuery( + dialect dialect, + structType reflect.Type, + selectQueryCache map[reflect.Type]string, +) (string, error) { + if selectQuery, found := selectQueryCache[structType]; found { + return selectQuery, nil + } + + info := structs.GetTagInfo(structType) + var fields []string + for _, field := range info.Fields() { + fields = append(fields, dialect.Escape(field.Name)) + } + + query := "SELECT " + strings.Join(fields, ", ") + " " + selectQueryCache[structType] = query + return query, nil +} diff --git a/ksql_test.go b/ksql_test.go index b8132d5..19a6cad 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -37,144 +37,162 @@ type Address struct { func TestQuery(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { - t.Run("using slice of structs", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } + variations := []struct { + desc string + queryPrefix string + }{ + { + desc: "with select *", + queryPrefix: "SELECT * ", + }, + { + desc: "building the SELECT part of the query internally", + queryPrefix: "", + }, + } + for _, variation := range variations { + t.Run(variation.desc, func(t *testing.T) { + t.Run("using slice of structs", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } - t.Run("should return 0 results correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run("should return 0 results correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []User - err := c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, nil, err) - assert.Equal(t, []User(nil), users) + ctx := context.Background() + c := newTestDB(db, driver, "users") + var users []User + err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, nil, err) + assert.Equal(t, 0, len(users)) - users = []User{} - err = c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, nil, err) - assert.Equal(t, []User{}, users) + users = []User{} + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, nil, err) + assert.Equal(t, 0, len(users)) + }) + + t.Run("should return a user correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var users []User + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") + + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "Bia", users[0].Name) + assert.Equal(t, "BR", users[0].Address.Country) + }) + + t.Run("should return multiple users correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + + _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var users []User + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "João Garcia", users[0].Name) + assert.Equal(t, "US", users[0].Address.Country) + + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "Bia Garcia", users[1].Name) + assert.Equal(t, "BR", users[1].Address.Country) + }) + }) + + t.Run("using slice of pointers to structs", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + t.Run("should return 0 results correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var users []*User + err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, nil, err) + assert.Equal(t, 0, len(users)) + + users = []*User{} + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, nil, err) + assert.Equal(t, 0, len(users)) + }) + + t.Run("should return a user correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var users []*User + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") + + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "Bia", users[0].Name) + assert.Equal(t, "BR", users[0].Address.Country) + }) + + t.Run("should return multiple users correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + + _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var users []*User + err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "João Garcia", users[0].Name) + assert.Equal(t, "US", users[0].Address.Country) + + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "Bia Garcia", users[1].Name) + assert.Equal(t, "BR", users[1].Address.Country) + }) + }) }) - - t.Run("should return a user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []User - err = c.Query(ctx, &users, `SELECT * FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") - - assert.Equal(t, nil, err) - assert.Equal(t, 1, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "Bia", users[0].Name) - assert.Equal(t, "BR", users[0].Address.Country) - }) - - t.Run("should return multiple users correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []User - err = c.Query(ctx, &users, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") - - assert.Equal(t, nil, err) - assert.Equal(t, 2, len(users)) - - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "João Garcia", users[0].Name) - assert.Equal(t, "US", users[0].Address.Country) - - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "Bia Garcia", users[1].Name) - assert.Equal(t, "BR", users[1].Address.Country) - }) - }) - - t.Run("using slice of pointers to structs", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - t.Run("should return 0 results correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []*User - err := c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, nil, err) - assert.Equal(t, []*User(nil), users) - - users = []*User{} - err = c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, nil, err) - assert.Equal(t, []*User{}, users) - }) - - t.Run("should return a user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []*User - err = c.Query(ctx, &users, `SELECT * FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") - - assert.Equal(t, nil, err) - assert.Equal(t, 1, len(users)) - assert.Equal(t, "Bia", users[0].Name) - assert.NotEqual(t, uint(0), users[0].ID) - }) - - t.Run("should return multiple users correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - - ctx := context.Background() - c := newTestDB(db, driver, "users") - var users []*User - err = c.Query(ctx, &users, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") - - assert.Equal(t, nil, err) - assert.Equal(t, 2, len(users)) - - assert.Equal(t, "João Garcia", users[0].Name) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "US", users[0].Address.Country) - - assert.Equal(t, "Bia Garcia", users[1].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "BR", users[1].Address.Country) - }) - }) + } t.Run("testing error cases", func(t *testing.T) { err := createTable(driver) @@ -226,64 +244,81 @@ func TestQuery(t *testing.T) { func TestQueryOne(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) + variations := []struct { + desc string + queryPrefix string + }{ + { + desc: "with select *", + queryPrefix: "SELECT * ", + }, + { + desc: "building the SELECT part of the query internally", + queryPrefix: "", + }, } + for _, variation := range variations { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } - t.Run("should return RecordNotFoundErr when there are no results", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run(variation.desc, func(t *testing.T) { + t.Run("should return RecordNotFoundErr when there are no results", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() - ctx := context.Background() - c := newTestDB(db, driver, "users") - u := User{} - err := c.QueryOne(ctx, &u, `SELECT * FROM users WHERE id=1;`) - assert.Equal(t, ErrRecordNotFound, err) - }) + ctx := context.Background() + c := newTestDB(db, driver, "users") + u := User{} + err := c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE id=1;`) + assert.Equal(t, ErrRecordNotFound, err) + }) - t.Run("should return a user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run("should return a user correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) - ctx := context.Background() - c := newTestDB(db, driver, "users") - u := User{} - err = c.QueryOne(ctx, &u, `SELECT * FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") + ctx := context.Background() + c := newTestDB(db, driver, "users") + u := User{} + err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") - assert.Equal(t, nil, err) - assert.NotEqual(t, uint(0), u.ID) - assert.Equal(t, "Bia", u.Name) - assert.Equal(t, Address{ - Country: "BR", - }, u.Address) - }) + assert.Equal(t, nil, err) + assert.NotEqual(t, uint(0), u.ID) + assert.Equal(t, "Bia", u.Name) + assert.Equal(t, Address{ + Country: "BR", + }, u.Address) + }) - t.Run("should return only the first result on multiples matches", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run("should return only the first result on multiples matches", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) + _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) - ctx := context.Background() - c := newTestDB(db, driver, "users") + ctx := context.Background() + c := newTestDB(db, driver, "users") - var u User - err = c.QueryOne(ctx, &u, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá") - assert.Equal(t, nil, err) - assert.Equal(t, "Andréa Sá", u.Name) - assert.Equal(t, 0, u.Age) - assert.Equal(t, Address{ - Country: "US", - }, u.Address) - }) + var u User + err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá") + assert.Equal(t, nil, err) + assert.Equal(t, "Andréa Sá", u.Name) + assert.Equal(t, 0, u.Age) + assert.Equal(t, Address{ + Country: "US", + }, u.Address) + }) + }) + } t.Run("should report error if input is not a pointer to struct", func(t *testing.T) { db := connectDB(t, driver) @@ -312,7 +347,7 @@ func TestQueryOne(t *testing.T) { ctx := context.Background() c := newTestDB(db, driver, "users") var user User - err = c.QueryOne(ctx, &user, `SELECT * FROM not a valid query`) + err := c.QueryOne(ctx, &user, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) }) @@ -762,397 +797,414 @@ func TestUpdate(t *testing.T) { func TestQueryChunks(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { - t.Run("should query a single row correctly", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{ - Name: "User1", - Address: Address{Country: "BR"}, - }) - - var length int - var u User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `SELECT * FROM users WHERE name = ` + c.dialect.Placeholder(0), - Params: []interface{}{"User1"}, - - ChunkSize: 100, - ForEachChunk: func(users []User) error { - length = len(users) - if length > 0 { - u = users[0] + variations := []struct { + desc string + queryPrefix string + }{ + { + desc: "with select *", + queryPrefix: "SELECT * ", + }, + { + desc: "building the SELECT part of the query internally", + queryPrefix: "", + }, + } + for _, variation := range variations { + t.Run(variation.desc, func(t *testing.T) { + t.Run("should query a single row correctly", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) } - return nil - }, - }) - assert.Equal(t, nil, err) - assert.Equal(t, 1, length) - assert.NotEqual(t, uint(0), u.ID) - assert.Equal(t, "User1", u.Name) - assert.Equal(t, "BR", u.Address.Country) - }) + db := connectDB(t, driver) + defer db.Close() - t.Run("should query one chunk correctly", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } + ctx := context.Background() + c := newTestDB(db, driver, "users") - db := connectDB(t, driver) - defer db.Close() + _ = c.Insert(ctx, &User{ + Name: "User1", + Address: Address{Country: "BR"}, + }) - ctx := context.Background() - c := newTestDB(db, driver, "users") + var length int + var u User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `FROM users WHERE name = ` + c.dialect.Placeholder(0), + Params: []interface{}{"User1"}, - _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) - _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) + ChunkSize: 100, + ForEachChunk: func(users []User) error { + length = len(users) + if length > 0 { + u = users[0] + } + return nil + }, + }) - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - users = append(users, buffer...) - lengths = append(lengths, len(buffer)) - return nil - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 1, len(lengths)) - assert.Equal(t, 2, lengths[0]) - - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.Equal(t, "US", users[0].Address.Country) - - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.Equal(t, "BR", users[1].Address.Country) - }) - - t.Run("should query chunks of 1 correctly", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) - _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) - - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 1, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - return nil - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 2, len(users)) - assert.Equal(t, []int{1, 1}, lengths) - - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.Equal(t, "US", users[0].Address.Country) - - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.Equal(t, "BR", users[1].Address.Country) - }) - - t.Run("should load partially filled chunks correctly", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - return nil - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 3, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.NotEqual(t, uint(0), users[2].ID) - assert.Equal(t, "User3", users[2].Name) - assert.Equal(t, []int{2, 1}, lengths) - }) - - t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - return ErrAbortIteration - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 2, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.Equal(t, []int{2}, lengths) - }) - - t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - returnVals := []error{nil, ErrAbortIteration} - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - - return shiftErrSlice(&returnVals) - }, - }) - - assert.Equal(t, nil, err) - assert.Equal(t, 3, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.NotEqual(t, uint(0), users[2].ID) - assert.Equal(t, "User3", users[2].Name) - assert.Equal(t, []int{2, 1}, lengths) - }) - - t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - return errors.New("fake error msg") - }, - }) - - assert.NotEqual(t, nil, err) - assert.Equal(t, 2, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.Equal(t, []int{2}, lengths) - }) - - t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) { - err := createTable(driver) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } - - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) - - returnVals := []error{nil, errors.New("fake error msg")} - var lengths []int - var users []User - err = c.QueryChunks(ctx, ChunkParser{ - Query: `select * from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, - Params: []interface{}{"User%"}, - - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - lengths = append(lengths, len(buffer)) - users = append(users, buffer...) - - return shiftErrSlice(&returnVals) - }, - }) - - assert.NotEqual(t, nil, err) - assert.Equal(t, 3, len(users)) - assert.NotEqual(t, uint(0), users[0].ID) - assert.Equal(t, "User1", users[0].Name) - assert.NotEqual(t, uint(0), users[1].ID) - assert.Equal(t, "User2", users[1].Name) - assert.NotEqual(t, uint(0), users[2].ID) - assert.Equal(t, "User3", users[2].Name) - assert.Equal(t, []int{2, 1}, lengths) - }) - - t.Run("should report error if the input function is invalid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - ctx := context.Background() - c := newTestDB(db, driver, "users") - - funcs := []interface{}{ - nil, - "not a function", - func() error { - return nil - }, - func(extraInputValue []User, extra []User) error { - return nil - }, - func(invalidArgType string) error { - return nil - }, - func(missingReturnType []User) { - return - }, - func(users []User) string { - return "" - }, - func(extraReturnValue []User) ([]User, error) { - return nil, nil - }, - func(notSliceOfStructs []string) error { - return nil - }, - } - - for _, fn := range funcs { - err := c.QueryChunks(ctx, ChunkParser{ - Query: `SELECT * FROM users`, - Params: []interface{}{}, - - ChunkSize: 2, - ForEachChunk: fn, + assert.Equal(t, nil, err) + assert.Equal(t, 1, length) + assert.NotEqual(t, uint(0), u.ID) + assert.Equal(t, "User1", u.Name) + assert.Equal(t, "BR", u.Address.Country) }) - assert.NotEqual(t, nil, err) - } - }) - t.Run("should report error if the query is not valid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run("should query one chunk correctly", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } - ctx := context.Background() - c := newTestDB(db, driver, "users") - err := c.QueryChunks(ctx, ChunkParser{ - Query: `SELECT * FROM not a valid query`, - Params: []interface{}{}, + db := connectDB(t, driver) + defer db.Close() - ChunkSize: 2, - ForEachChunk: func(buffer []User) error { - return nil - }, + ctx := context.Background() + c := newTestDB(db, driver, "users") + + _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) + _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + users = append(users, buffer...) + lengths = append(lengths, len(buffer)) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(lengths)) + assert.Equal(t, 2, lengths[0]) + + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.Equal(t, "US", users[0].Address.Country) + + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, "BR", users[1].Address.Country) + }) + + t.Run("should query chunks of 1 correctly", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + + _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) + _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 1, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.Equal(t, []int{1, 1}, lengths) + + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.Equal(t, "US", users[0].Address.Country) + + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, "BR", users[1].Address.Country) + }) + + t.Run("should load partially filled chunks correctly", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, uint(0), users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) + + t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return ErrAbortIteration + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, []int{2}, lengths) + }) + + t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + returnVals := []error{nil, ErrAbortIteration} + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + + return shiftErrSlice(&returnVals) + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, uint(0), users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) + + t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + return errors.New("fake error msg") + }, + }) + + assert.NotEqual(t, nil, err) + assert.Equal(t, 2, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.Equal(t, []int{2}, lengths) + }) + + t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) { + err := createTable(driver) + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + + _ = c.Insert(ctx, &User{Name: "User1"}) + _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, &User{Name: "User3"}) + + returnVals := []error{nil, errors.New("fake error msg")} + var lengths []int + var users []User + err = c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `from users where name like ` + c.dialect.Placeholder(0) + ` order by name asc;`, + Params: []interface{}{"User%"}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + lengths = append(lengths, len(buffer)) + users = append(users, buffer...) + + return shiftErrSlice(&returnVals) + }, + }) + + assert.NotEqual(t, nil, err) + assert.Equal(t, 3, len(users)) + assert.NotEqual(t, uint(0), users[0].ID) + assert.Equal(t, "User1", users[0].Name) + assert.NotEqual(t, uint(0), users[1].ID) + assert.Equal(t, "User2", users[1].Name) + assert.NotEqual(t, uint(0), users[2].ID) + assert.Equal(t, "User3", users[2].Name) + assert.Equal(t, []int{2, 1}, lengths) + }) + + t.Run("should report error if the input function is invalid", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + + funcs := []interface{}{ + nil, + "not a function", + func() error { + return nil + }, + func(extraInputValue []User, extra []User) error { + return nil + }, + func(invalidArgType string) error { + return nil + }, + func(missingReturnType []User) { + return + }, + func(users []User) string { + return "" + }, + func(extraReturnValue []User) ([]User, error) { + return nil, nil + }, + func(notSliceOfStructs []string) error { + return nil + }, + } + + for _, fn := range funcs { + err := c.QueryChunks(ctx, ChunkParser{ + Query: variation.queryPrefix + `FROM users`, + Params: []interface{}{}, + + ChunkSize: 2, + ForEachChunk: fn, + }) + assert.NotEqual(t, nil, err) + } + }) + + t.Run("should report error if the query is not valid", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + err := c.QueryChunks(ctx, ChunkParser{ + Query: `SELECT * FROM not a valid query`, + Params: []interface{}{}, + + ChunkSize: 2, + ForEachChunk: func(buffer []User) error { + return nil + }, + }) + assert.NotEqual(t, nil, err) + }) }) - assert.NotEqual(t, nil, err) - }) + } }) } } diff --git a/structs/structs.go b/structs/structs.go index 7c7941d..63d7d57 100644 --- a/structs/structs.go +++ b/structs/structs.go @@ -42,6 +42,10 @@ func (s structInfo) Add(field fieldInfo) { s.byName[field.Name] = &field } +func (s structInfo) Fields() map[int]*fieldInfo { + return s.byIndex +} + // This cache is kept as a pkg variable // because the total number of types on a program // should be finite. So keeping a single cache here From d8ca3cab8d390a6a1a9d033d8df8e879a96fd3dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Wed, 19 May 2021 23:44:03 -0300 Subject: [PATCH 08/40] Improve README intro --- README.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c2549ee..d7f3bb9 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,11 @@ # KissSQL -Welcome to the KissSQL project, the Keep It Stupid Simple sql package. +If the thing you hate the most when coding is having too much unnecessary +abstractions and the second thing you hate the most is having verbose +and repetitive code for routine tasks this library is probably for you. + +Welcome to the KissSQL project, the "Keep It Stupid Simple" sql client for Go. This package was created to be used by any developer efficiently and safely. The goals were: @@ -10,17 +14,20 @@ The goals were: - To be hard to make mistakes - To have a small API so it's easy to learn - To be easy to mock and test (very easy) -- To be above all readable. +- And above all to be readable. **Supported Drivers:** -Currently we support 4 Drivers: +Currently we support only the 4 most popular Golang database drivers: - `"postgres"` - `"sqlite3"` - `"mysql"` - `"sqlserver"` +If you need a new one included please open an issue or make +your own implementation and submit a Pull Request. + ### Why KissSQL? > Note: If you want numbers see our Benchmark section below From 0d3a75fe422e698248de457bdb1fd088c842ae62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 23 May 2021 11:28:16 -0300 Subject: [PATCH 09/40] Add feature of nesting structs so we can reuse existing structs --- ksql.go | 100 ++++++++++++++++++++++++++++---- ksql_test.go | 140 +++++++++++++++++++++++++++++++++++++-------- structs/structs.go | 81 ++++++++++++++++++-------- 3 files changed, 261 insertions(+), 60 deletions(-) diff --git a/ksql.go b/ksql.go index 1c30a8b..52674db 100644 --- a/ksql.go +++ b/ksql.go @@ -813,11 +813,6 @@ func (nopScanner) Scan(value interface{}) error { } func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { - names, err := rows.Columns() - if err != nil { - return err - } - v := reflect.ValueOf(record) t := v.Type() if t.Kind() != reflect.Ptr { @@ -833,6 +828,53 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { info := structs.GetTagInfo(t) + var scanArgs []interface{} + if info.IsNestedStruct { + // This version is positional meaning that it expect the arguments + // to follow an specific order. It's ok because we don't allow the + // user to type the "SELECT" part of the query for nested structs. + scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info) + } else { + names, err := rows.Columns() + if err != nil { + return err + } + // Since this version uses the names of the columns it works + // with any order of attributes/columns. + scanArgs = getScanArgsFromNames(dialect, names, v, info) + } + + return rows.Scan(scanArgs...) +} + +func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) []interface{} { + scanArgs := []interface{}{} + for i := 0; i < v.NumField(); i++ { + // TODO(vingarcia00): Handle case where type is pointer + nestedStructInfo := structs.GetTagInfo(t.Field(i).Type) + nestedStructValue := v.Field(i) + for j := 0; j < nestedStructValue.NumField(); j++ { + fieldInfo := nestedStructInfo.ByIndex(j) + + valueScanner := nopScannerValue + if fieldInfo.Valid { + valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface() + if fieldInfo.SerializeAsJSON { + valueScanner = &jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: valueScanner, + } + } + } + + scanArgs = append(scanArgs, valueScanner) + } + } + + return scanArgs +} + +func getScanArgsFromNames(dialect dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} { scanArgs := []interface{}{} for _, name := range names { fieldInfo := info.ByName(name) @@ -851,7 +893,7 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { scanArgs = append(scanArgs, valueScanner) } - return rows.Scan(scanArgs...) + return scanArgs } func buildSingleKeyDeleteQuery( @@ -923,18 +965,54 @@ func buildSelectQuery( dialect dialect, structType reflect.Type, selectQueryCache map[reflect.Type]string, -) (string, error) { +) (query string, err error) { if selectQuery, found := selectQueryCache[structType]; found { return selectQuery, nil } info := structs.GetTagInfo(structType) - var fields []string - for _, field := range info.Fields() { - fields = append(fields, dialect.Escape(field.Name)) + if info.IsNestedStruct { + query, err = buildSelectQueryForNestedStructs(dialect, structType, info) + if err != nil { + return "", err + } + } else { + query = buildSelectQueryForPlainStructs(dialect, structType, info) } - query := "SELECT " + strings.Join(fields, ", ") + " " selectQueryCache[structType] = query return query, nil } + +func buildSelectQueryForPlainStructs( + dialect dialect, + structType reflect.Type, + info structs.StructInfo, +) string { + var fields []string + for i := 0; i < structType.NumField(); i++ { + fields = append(fields, dialect.Escape(info.ByIndex(i).Name)) + } + + return "SELECT " + strings.Join(fields, ", ") + " " +} + +func buildSelectQueryForNestedStructs( + dialect dialect, + structType reflect.Type, + info structs.StructInfo, +) (string, error) { + var fields []string + for i := 0; i < structType.NumField(); i++ { + nestedStructName := info.ByIndex(i).Name + nestedStructInfo := structs.GetTagInfo(structType.Field(i).Type) + for j := 0; j < structType.Field(i).Type.NumField(); j++ { + fields = append( + fields, + dialect.Escape(nestedStructName)+"."+dialect.Escape(nestedStructInfo.ByIndex(j).Name), + ) + } + } + + return "SELECT " + strings.Join(fields, ", ") + " ", nil +} diff --git a/ksql_test.go b/ksql_test.go index 19a6cad..b7dfe4e 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -34,6 +34,12 @@ type Address struct { Country string `json:"country"` } +type Post struct { + ID int `ksql:"id"` + UserID uint `ksql:"user_id"` + Title string `ksql:"title"` +} + func TestQuery(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { @@ -53,7 +59,7 @@ func TestQuery(t *testing.T) { for _, variation := range variations { t.Run(variation.desc, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -123,7 +129,7 @@ func TestQuery(t *testing.T) { }) t.Run("using slice of pointers to structs", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -190,12 +196,66 @@ func TestQuery(t *testing.T) { assert.Equal(t, "Bia Garcia", users[1].Name) assert.Equal(t, "BR", users[1].Address.Country) }) + + t.Run("should query joined tables correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + // This test only makes sense with no query prefix + if variation.queryPrefix != "" { + return + } + + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joaoID uint + db.QueryRow(`SELECT id FROM users WHERE name = 'João Ribeiro'`).Scan(&joaoID) + + _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + var biaID uint + db.QueryRow(`SELECT id FROM users WHERE name = 'Bia Ribeiro'`).Scan(&biaID) + + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post1')`)) + assert.Equal(t, nil, err) + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post2')`)) + assert.Equal(t, nil, err) + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joaoID, `, 'João Post1')`)) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var rows []*struct { + User User `tablename:"u"` + Post Post `tablename:"p"` + } + err = c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(rows)) + + assert.Equal(t, joaoID, rows[0].User.ID) + assert.Equal(t, "João Ribeiro", rows[0].User.Name) + assert.Equal(t, "João Post1", rows[0].Post.Title) + + assert.Equal(t, biaID, rows[1].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[1].User.Name) + assert.Equal(t, "Bia Post1", rows[1].Post.Title) + + assert.Equal(t, biaID, rows[2].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[2].User.Name) + assert.Equal(t, "Bia Post2", rows[2].Post.Title) + }) }) }) } t.Run("testing error cases", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -258,7 +318,7 @@ func TestQueryOne(t *testing.T) { }, } for _, variation := range variations { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -358,7 +418,7 @@ func TestInsert(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -428,7 +488,7 @@ func TestInsert(t *testing.T) { }) t.Run("testing error cases", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -485,7 +545,7 @@ func TestInsert(t *testing.T) { func TestDelete(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -628,7 +688,7 @@ func TestDelete(t *testing.T) { func TestUpdate(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -813,7 +873,7 @@ func TestQueryChunks(t *testing.T) { for _, variation := range variations { t.Run(variation.desc, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -853,7 +913,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should query one chunk correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -895,7 +955,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should query chunks of 1 correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -937,7 +997,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should load partially filled chunks correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -978,7 +1038,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1017,7 +1077,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1060,7 +1120,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1099,7 +1159,7 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1213,7 +1273,7 @@ func TestTransaction(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1240,7 +1300,7 @@ func TestTransaction(t *testing.T) { }) t.Run("should rollback when there are errors", func(t *testing.T) { - err := createTable(driver) + err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1281,7 +1341,7 @@ func TestTransaction(t *testing.T) { func TestScanRows(t *testing.T) { t.Run("should scan users correctly", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1310,7 +1370,7 @@ func TestScanRows(t *testing.T) { }) t.Run("should ignore extra columns from query", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1342,7 +1402,7 @@ func TestScanRows(t *testing.T) { }) t.Run("should report error for closed rows", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1363,7 +1423,7 @@ func TestScanRows(t *testing.T) { }) t.Run("should report if record is not a pointer", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1382,7 +1442,7 @@ func TestScanRows(t *testing.T) { }) t.Run("should report if record is not a pointer to struct", func(t *testing.T) { - err := createTable("sqlite3") + err := createTables("sqlite3") if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } @@ -1408,7 +1468,7 @@ var connectionString = map[string]string{ "sqlserver": "sqlserver://sa:Sqls3rv3r@127.0.0.1:1433?databaseName=ksql", } -func createTable(driver string) error { +func createTables(driver string) error { connStr := connectionString[driver] if connStr == "" { return fmt.Errorf("unsupported driver: '%s'", driver) @@ -1456,6 +1516,38 @@ func createTable(driver string) error { return fmt.Errorf("failed to create new users table: %s", err.Error()) } + db.Exec(`DROP TABLE posts`) + + switch driver { + case "sqlite3": + _, err = db.Exec(`CREATE TABLE posts ( + id INTEGER PRIMARY KEY, + user_id INTEGER, + title TEXT + )`) + case "postgres": + _, err = db.Exec(`CREATE TABLE posts ( + id serial PRIMARY KEY, + user_id INT, + title VARCHAR(50) + )`) + case "mysql": + _, err = db.Exec(`CREATE TABLE posts ( + id INT AUTO_INCREMENT PRIMARY KEY, + user_id INT, + title VARCHAR(50) + )`) + case "sqlserver": + _, err = db.Exec(`CREATE TABLE posts ( + id INT IDENTITY(1,1) PRIMARY KEY, + user_id INT, + title VARCHAR(50) + )`) + } + if err != nil { + return fmt.Errorf("failed to create new users table: %s", err.Error()) + } + return nil } diff --git a/structs/structs.go b/structs/structs.go index 63d7d57..3106d62 100644 --- a/structs/structs.go +++ b/structs/structs.go @@ -8,49 +8,56 @@ import ( "github.com/pkg/errors" ) -type structInfo struct { - byIndex map[int]*fieldInfo - byName map[string]*fieldInfo +// StructInfo stores metainformation of the struct +// parser in order to help the ksql library to work +// efectively and efficiently with reflection. +type StructInfo struct { + IsNestedStruct bool + byIndex map[int]*FieldInfo + byName map[string]*FieldInfo } -type fieldInfo struct { +// FieldInfo contains reflection and tags +// information regarding a specific field +// of a struct. +type FieldInfo struct { Name string Index int Valid bool SerializeAsJSON bool } -func (s structInfo) ByIndex(idx int) *fieldInfo { +// ByIndex returns either the *FieldInfo of a valid +// empty struct with Valid set to false +func (s StructInfo) ByIndex(idx int) *FieldInfo { field, found := s.byIndex[idx] if !found { - return &fieldInfo{} + return &FieldInfo{} } return field } -func (s structInfo) ByName(name string) *fieldInfo { +// ByName returns either the *FieldInfo of a valid +// empty struct with Valid set to false +func (s StructInfo) ByName(name string) *FieldInfo { field, found := s.byName[name] if !found { - return &fieldInfo{} + return &FieldInfo{} } return field } -func (s structInfo) Add(field fieldInfo) { +func (s StructInfo) add(field FieldInfo) { field.Valid = true s.byIndex[field.Index] = &field s.byName[field.Name] = &field } -func (s structInfo) Fields() map[int]*fieldInfo { - return s.byIndex -} - // This cache is kept as a pkg variable // because the total number of types on a program // should be finite. So keeping a single cache here // works fine. -var tagInfoCache = map[reflect.Type]structInfo{} +var tagInfoCache = map[reflect.Type]StructInfo{} // GetTagInfo efficiently returns the type information // using a global private cache @@ -58,16 +65,17 @@ var tagInfoCache = map[reflect.Type]structInfo{} // In the future we might move this cache inside // a struct, but for now this accessor is the one // we are using -func GetTagInfo(key reflect.Type) structInfo { +func GetTagInfo(key reflect.Type) StructInfo { return getCachedTagInfo(tagInfoCache, key) } -func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo { - info, found := tagInfoCache[key] - if !found { - info = getTagNames(key) - tagInfoCache[key] = info +func getCachedTagInfo(tagInfoCache map[reflect.Type]StructInfo, key reflect.Type) StructInfo { + if info, found := tagInfoCache[key]; found { + return info } + + info := getTagNames(key) + tagInfoCache[key] = info return info } @@ -291,10 +299,10 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error // // This should save several calls to `Field(i).Tag.Get("foo")` // which improves performance by a lot. -func getTagNames(t reflect.Type) structInfo { - info := structInfo{ - byIndex: map[int]*fieldInfo{}, - byName: map[string]*fieldInfo{}, +func getTagNames(t reflect.Type) StructInfo { + info := StructInfo{ + byIndex: map[int]*FieldInfo{}, + byName: map[string]*FieldInfo{}, } for i := 0; i < t.NumField(); i++ { name := t.Field(i).Tag.Get("ksql") @@ -309,13 +317,36 @@ func getTagNames(t reflect.Type) structInfo { serializeAsJSON = tags[1] == "json" } - info.Add(fieldInfo{ + info.add(FieldInfo{ Name: name, Index: i, SerializeAsJSON: serializeAsJSON, }) } + // If there were `ksql` tags present, then we are finished: + if len(info.byIndex) > 0 { + return info + } + + // If there are no `ksql` tags in the struct, lets assume + // it is a struct tagged with `tablename` for allowing JOINs + for i := 0; i < t.NumField(); i++ { + name := t.Field(i).Tag.Get("tablename") + if name == "" { + continue + } + + info.add(FieldInfo{ + Name: name, + Index: i, + }) + } + + if len(info.byIndex) > 0 { + info.IsNestedStruct = true + } + return info } From 6bd61346d99878395aa29f608082f4b00764b699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 23 May 2021 11:32:23 -0300 Subject: [PATCH 10/40] Add more tests to the nested struct feature --- ksql_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/ksql_test.go b/ksql_test.go index b7dfe4e..be3aa30 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -126,6 +126,60 @@ func TestQuery(t *testing.T) { assert.Equal(t, "Bia Garcia", users[1].Name) assert.Equal(t, "BR", users[1].Address.Country) }) + + t.Run("should query joined tables correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + // This test only makes sense with no query prefix + if variation.queryPrefix != "" { + return + } + + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joaoID uint + db.QueryRow(`SELECT id FROM users WHERE name = 'João Ribeiro'`).Scan(&joaoID) + + _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + var biaID uint + db.QueryRow(`SELECT id FROM users WHERE name = 'Bia Ribeiro'`).Scan(&biaID) + + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post1')`)) + assert.Equal(t, nil, err) + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post2')`)) + assert.Equal(t, nil, err) + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joaoID, `, 'João Post1')`)) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var rows []struct { + User User `tablename:"u"` + Post Post `tablename:"p"` + } + err = c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(rows)) + + assert.Equal(t, joaoID, rows[0].User.ID) + assert.Equal(t, "João Ribeiro", rows[0].User.Name) + assert.Equal(t, "João Post1", rows[0].Post.Title) + + assert.Equal(t, biaID, rows[1].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[1].User.Name) + assert.Equal(t, "Bia Post1", rows[1].Post.Title) + + assert.Equal(t, biaID, rows[2].User.ID) + assert.Equal(t, "Bia Ribeiro", rows[2].User.Name) + assert.Equal(t, "Bia Post2", rows[2].Post.Title) + }) }) t.Run("using slice of pointers to structs", func(t *testing.T) { From 9e4583c3f8cc99b331b444bcbbf6642023f930d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 23 May 2021 12:25:35 -0300 Subject: [PATCH 11/40] Add error check for preventing reflection panics in nested structs --- ksql.go | 10 +++++++++- ksql_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/ksql.go b/ksql.go index 52674db..5ada2ef 100644 --- a/ksql.go +++ b/ksql.go @@ -1005,7 +1005,15 @@ func buildSelectQueryForNestedStructs( var fields []string for i := 0; i < structType.NumField(); i++ { nestedStructName := info.ByIndex(i).Name - nestedStructInfo := structs.GetTagInfo(structType.Field(i).Type) + nestedStructType := structType.Field(i).Type + if nestedStructType.Kind() != reflect.Struct { + return "", fmt.Errorf( + "expected nested struct with `tablename:\"%s\"` to be a kind of Struct, but got %v", + nestedStructName, nestedStructType, + ) + } + + nestedStructInfo := structs.GetTagInfo(nestedStructType) for j := 0; j < structType.Field(i).Type.NumField(); j++ { fields = append( fields, diff --git a/ksql_test.go b/ksql_test.go index be3aa30..94fa122 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -351,6 +351,52 @@ func TestQuery(t *testing.T) { assert.NotEqual(t, nil, err) }) }) + + t.Run("should report error for nested structs with invalid types", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var rows []struct { + Foo int `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "int"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) + + t.Run("*struct", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var rows []struct { + Foo *User `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "*ksql.User"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) + }) }) } } From 4e201031b7851355cbdac7f86f6654e1ce1b68e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 23 May 2021 15:04:24 -0300 Subject: [PATCH 12/40] Update README to include the feature of generating the SELECT part of the query --- README.md | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d7f3bb9..e18e15e 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ If the thing you hate the most when coding is having too much unnecessary abstractions and the second thing you hate the most is having verbose and repetitive code for routine tasks this library is probably for you. -Welcome to the KissSQL project, the "Keep It Stupid Simple" sql client for Go. +Welcome to the KissSQL project, the "Keep It Stupid Simple" SQL client for Go. This package was created to be used by any developer efficiently and safely. The goals were: @@ -88,6 +88,10 @@ type SQLProvider interface { This example is also available [here](./examples/crud/crud.go) if you want to compile it yourself. +Also we have a small feature for building the "SELECT" part of the query if +you rather not use `SELECT *` queries, you may skip to the +[Select Generator Feature](#Select-Generator-Feature) which is very clean too. + ```Go package main @@ -254,6 +258,90 @@ func main() { } ``` +### Query Chunks Feature + +It's very unsual for us to need to load a number of records from the +database that might be too big for fitting in memory, e.g. load all the +users and send them somewhere. But it might happen. + +For these cases it's best to load chunks of data at a time so +that we can work on a substantial amount of data at a time and never +overload our memory capacity. For this use case we have a specific +function called `QueryChunks`: + +```golang +err = db.QueryChunks(ctx, ksql.ChunkParser{ + Query: "SELECT * FROM users WHERE type = ?", + Params: []interface{}{usersType}, + ChunkSize: 100, + ForEachChunk: func(users []User) error { + err := sendUsersSomewhere(users) + if err != nil { + // This will abort the QueryChunks loop and return this error + return err + } + return nil + }, +}) +if err != nil { + panic(err.Error()) +} +``` + +It's signature is more complicated but the use-case is also +less common so it's as simple as it gets. + +### Select Generator Feature + +There are good reasons not to use `SELECT *` queries the most important +of them is that you might end up loading more information than you are actually +going to use putting more pressure in your database for no good reason. + +To prevent that `ksql` has a feature specifically for building the `SELECT` +part of the query for you using the tags from the input struct and using +it is very simple and it works with all the 3 Query\* functions: + +Querying a single user: + +```golang +var user User +err = db.QueryOne(ctx, &user, "FROM users WHERE id = ?", userID) +if err != nil { + panic(err.Error()) +} +``` + +Querying a page of users: + +```golang +var users []User +err = db.Query(ctx, &users, "FROM users WHERE type = ? ORDER BY id LIMIT ? OFFSET ?", "Cristina", limit, offset) +if err != nil { + panic(err.Error()) +} +``` + +Querying all the users, or any potentially big number of users, from the database (not usual, but supported): + +```golang +err = db.QueryChunks(ctx, ksql.ChunkParser{ + Query: "FROM users WHERE type = ?", + Params: []interface{}{usersType}, + ChunkSize: 100, + ForEachChunk: func(users []User) error { + err := sendUsersSomewhere(users) + if err != nil { + // This will abort the QueryChunks loop and return this error + return err + } + return nil + }, +}) +if err != nil { + panic(err.Error()) +} +``` + ### Testing Examples This library has a few helper functions for helping your tests: @@ -286,7 +374,7 @@ PASS ok github.com/vingarcia/ksql 34.251s ``` -### Running the tests +### Running the ksql tests (for contributors) The tests run in real database instances so the easiest way to have them working is to just start them using docker-compose: @@ -309,7 +397,6 @@ make test ### TODO List -- Implement support for nested objects with prefixed table names - Improve error messages - Add tests for tables using composite keys - Add support for serializing structs as other formats such as YAML From ac1f94a90b8fa6e500041a6e35d6dc326d7df27a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 23 May 2021 15:05:01 -0300 Subject: [PATCH 13/40] Add nested struct tests for TestQueryOne --- ksql_test.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/ksql_test.go b/ksql_test.go index 94fa122..6bc7aea 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -477,6 +477,41 @@ func TestQueryOne(t *testing.T) { Country: "US", }, u.Address) }) + + t.Run("should query joined tables correctly", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + // This test only makes sense with no query prefix + if variation.queryPrefix != "" { + return + } + + _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joaoID uint + db.QueryRow(`SELECT id FROM users WHERE name = 'João Ribeiro'`).Scan(&joaoID) + + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joaoID, `, 'João Post1')`)) + assert.Equal(t, nil, err) + + ctx := context.Background() + c := newTestDB(db, driver, "users") + var row struct { + User User `tablename:"u"` + Post Post `tablename:"p"` + } + err = c.QueryOne(ctx, &row, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.Equal(t, nil, err) + assert.Equal(t, joaoID, row.User.ID) + assert.Equal(t, "João Ribeiro", row.User.Name) + assert.Equal(t, "João Post1", row.Post.Title) + }) }) } From e5c7b44e38e0159d50bdeea14d37a8e5e097fd68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Thu, 3 Jun 2021 20:15:18 -0300 Subject: [PATCH 14/40] Add test to QueryChunks with SELECT generation --- ksql_test.go | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/ksql_test.go b/ksql_test.go index 6bc7aea..95d4461 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -1172,6 +1172,78 @@ func TestQueryChunks(t *testing.T) { assert.Equal(t, []int{2, 1}, lengths) }) + // xxx + t.Run("should query joined tables correctly", func(t *testing.T) { + // This test only makes sense with no query prefix + if variation.queryPrefix != "" { + return + } + + db := connectDB(t, driver) + defer db.Close() + + joao := User{ + Name: "Thiago Ribeiro", + Age: 24, + } + thatiana := User{ + Name: "Thatiana Ribeiro", + Age: 20, + } + + ctx := context.Background() + c := newTestDB(db, driver, "users") + _ = c.Insert(ctx, &joao) + _ = c.Insert(ctx, &thatiana) + + _, err := db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post1')`)) + assert.Equal(t, nil, err) + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post2')`)) + assert.Equal(t, nil, err) + _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'Thiago Post1')`)) + assert.Equal(t, nil, err) + + var lengths []int + var users []User + var posts []Post + err = c.QueryChunks(ctx, ChunkParser{ + Query: fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), + Params: []interface{}{"% Ribeiro"}, + + ChunkSize: 2, + ForEachChunk: func(chunk []struct { + User User `tablename:"u"` + Post Post `tablename:"p"` + }) error { + lengths = append(lengths, len(chunk)) + for _, row := range chunk { + users = append(users, row.User) + posts = append(posts, row.Post) + } + return nil + }, + }) + + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(posts)) + + assert.Equal(t, joao.ID, users[0].ID) + assert.Equal(t, "Thiago Ribeiro", users[0].Name) + assert.Equal(t, "Thiago Post1", posts[0].Title) + + assert.Equal(t, thatiana.ID, users[1].ID) + assert.Equal(t, "Thatiana Ribeiro", users[1].Name) + assert.Equal(t, "Thatiana Post1", posts[1].Title) + + assert.Equal(t, thatiana.ID, users[2].ID) + assert.Equal(t, "Thatiana Ribeiro", users[2].Name) + assert.Equal(t, "Thatiana Post2", posts[2].Title) + }) + t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { err := createTables(driver) if err != nil { From 2ad920968a7276bd8ad920cdb1d5b040b4f1d6fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Thu, 3 Jun 2021 22:54:39 -0300 Subject: [PATCH 15/40] Update README to explain the composite structs feature --- README.md | 140 ++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 130 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index e18e15e..e99cc77 100644 --- a/README.md +++ b/README.md @@ -214,14 +214,14 @@ func main() { } // Listing first 10 users from the database - // (each time you run this example a new Cristina is created) + // (each time you run this example a new "Cristina" is created) // // Note: Using this function it is recommended to set a LIMIT, since - // not doing so can load too many users on your computer's memory or - // cause an Out Of Memory Kill. + // not doing so can load too many users on your computer's memory + // causing an "Out Of Memory Kill". // // If you need to query very big numbers of users we recommend using - // the `QueryChunks` function. + // the `QueryChunks` function instead. var users []User err = db.Query(ctx, &users, "SELECT * FROM users LIMIT 10") if err != nil { @@ -288,8 +288,10 @@ if err != nil { } ``` -It's signature is more complicated but the use-case is also -less common so it's as simple as it gets. +It's signature is more complicated than the other two Query\* methods, +thus, it is adivisible to always prefer using the other two when possible +reserving this one for the rare use-case where you are actually +loading big sections of the database into memory. ### Select Generator Feature @@ -298,8 +300,8 @@ of them is that you might end up loading more information than you are actually going to use putting more pressure in your database for no good reason. To prevent that `ksql` has a feature specifically for building the `SELECT` -part of the query for you using the tags from the input struct and using -it is very simple and it works with all the 3 Query\* functions: +part of the query using the tags from the input struct. +Using it is very simple and it works with all the 3 Query\* functions: Querying a single user: @@ -342,6 +344,124 @@ if err != nil { } ``` +The implementation of this feature is actually simple internally. +First we check if the query is starting with the word `FROM`, +if it is then we just get the `ksql` tags from the struct and +then use it for building the `SELECT` statement. + +The `SELECT` statement is then cached so we don't have to build it again +the next time in order to keep the library efficient even when +using this feature. + +### Select Generation with Joins + +So there is one use-case that was not covered by `ksql` so far: + +What if you want to JOIN multiple tables for which you already have +structs defined? Would you need to create a new struct to represent +the joined columns of the two tables? no, we actually have this covered as well. + +`ksql` has a special feature for allowing the reuse of existing +structs by using composition in an anonymous struct, and then +generating the `SELECT` part of the query accordingly: + +Querying a single joined row: + +```golang +var row struct{ + User User `tablename:"u"` // (here the tablename must match the aliased tablename in the query) + Post Post `tablename:"p"` // (if no alias is used you should use the actual name of the table) +} +err = db.QueryOne(ctx, &row, "FROM users as u JOIN posts as p ON u.id = p.user_id WHERE u.id = ?", userID) +if err != nil { + panic(err.Error()) +} +``` + +Querying a page of joined rows: + +```golang +var rows []struct{ + User User `tablename:"u"` + Post Post `tablename:"p"` +} +err = db.Query(ctx, &rows, + "FROM users as u JOIN posts as p ON u.id = p.user_id WHERE name = ? LIMIT ? OFFSET ?", + "Cristina", limit, offset, +) +if err != nil { + panic(err.Error()) +} +``` + +Querying all the users, or any potentially big number of users, from the database (not usual, but supported): + +```golang +err = db.QueryChunks(ctx, ksql.ChunkParser{ + Query: "FROM users as u JOIN posts as p ON u.id = p.user_id WHERE type = ?", + Params: []interface{}{usersType}, + ChunkSize: 100, + ForEachChunk: func(rows []struct{ + User User `tablename:"u"` + Post Post `tablename:"p"` + }) error { + err := sendRowsSomewhere(rows) + if err != nil { + // This will abort the QueryChunks loop and return this error + return err + } + return nil + }, +}) +if err != nil { + panic(err.Error()) +} +``` + +As advanced as this feature might seem we don't do any parsing of the query, +and all the work is done only once and then cached. + +What actually happens is that we use the "tablename" tag to build the `SELECT` +part of the query like this: + +- `SELECT u.id, u.name, u.age, p.id, p.title ` + +This is then cached, and when we need it again we concatenate it with the rest +of the query. + +This feature has two important limitations: + +1. It is not possible to use `tablename` tags together with normal `ksql` tags. + Doing so will cause the `tablename` tags to be ignored in favor of the `ksql` ones. +2. It is not possible to use it without omitting the `SELECT` part of the query. + While in normal queries we match the selected field with the attribute by name, + in queries joining multiple tables we can't use this strategy because + different tables might have columns with the same name, and we don't + really have access to the full name of these columns making for exemple + it impossible to differentiate between `u.id` and `p.id` except by the + order in which these fields were passed. Thus, it is necessary to + leave the job of generating the `SELECT` for the library when using + this technique with composite anonymous structs. + +Ok, but what if I don't want to use this feature? + +You are not forced to, and there are a few use-cases where you would prefer not to, e.g.: + +```golang +var rows []struct{ + UserName string `ksql:"name"` + PostTitle string `ksql:"title"` +} +err := db.Query(ctx, &rows, "SELECT u.name, p.title FROM users u JOIN posts p ON u.id = p.user_id LIMIT 10") +if err != nil { + panic(err.Error()) +} +``` + +Here, since we are only interested in a couple of columns it is far +simpler and more efficient for the database to only select the columns +that we actually care about like in the example above. + ### Testing Examples This library has a few helper functions for helping your tests: @@ -376,8 +496,8 @@ ok github.com/vingarcia/ksql 34.251s ### Running the ksql tests (for contributors) -The tests run in real database instances so the easiest way to have -them working is to just start them using docker-compose: +The tests run in dockerized database instances so the easiest way +to have them working is to just start them using docker-compose: ```bash docker-compose up -d From 3c57bcf1d717c054323bf8ddb00a43ce2a9980d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Thu, 3 Jun 2021 22:56:50 -0300 Subject: [PATCH 16/40] Minor improvement in README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e99cc77..ae19ad9 100644 --- a/README.md +++ b/README.md @@ -520,9 +520,9 @@ make test - Improve error messages - Add tests for tables using composite keys - Add support for serializing structs as other formats such as YAML -- Update structs.FillStructWith to work with `json` tagged attributes +- Update `structs.FillStructWith` to work with `json` tagged attributes - Make testing easier by exposing the connection strings in an .env file -- Make testing easier by automatically creating the ksql database +- Make testing easier by automatically creating the `ksql` database - Create a way for users to submit user defined dialects ### Optimization Oportunities From 208ce07d6ec8895249b1648a2d5634a67944d803 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Thu, 3 Jun 2021 23:00:19 -0300 Subject: [PATCH 17/40] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ae19ad9..79212a1 100644 --- a/README.md +++ b/README.md @@ -59,12 +59,12 @@ in order to save development time for your team, i.e.: - Less time spent learning (few methods to learn) - Less time spent testing (helper tools made to help you) -- less time spent debugging (simple apis are easier to debug) +- Less time spent debugging (simple apis are easier to debug) - and less time reading & understanding the code ### Kiss Interface -The current interface is as follows and we plan to keep +The current interface is as follows and we plan on keeping it with as little functions as possible, so don't expect many additions: ```go From 1e434b0b78e0f98d1edee7306a7255a4aa4b3816 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Thu, 3 Jun 2021 23:21:53 -0300 Subject: [PATCH 18/40] Fix code identation on README.md --- README.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 79212a1..a185c70 100644 --- a/README.md +++ b/README.md @@ -369,8 +369,8 @@ Querying a single joined row: ```golang var row struct{ - User User `tablename:"u"` // (here the tablename must match the aliased tablename in the query) - Post Post `tablename:"p"` // (if no alias is used you should use the actual name of the table) + User User `tablename:"u"` // (here the tablename must match the aliased tablename in the query) + Post Post `tablename:"p"` // (if no alias is used you should use the actual name of the table) } err = db.QueryOne(ctx, &row, "FROM users as u JOIN posts as p ON u.id = p.user_id WHERE u.id = ?", userID) if err != nil { @@ -382,12 +382,12 @@ Querying a page of joined rows: ```golang var rows []struct{ - User User `tablename:"u"` - Post Post `tablename:"p"` + User User `tablename:"u"` + Post Post `tablename:"p"` } err = db.Query(ctx, &rows, - "FROM users as u JOIN posts as p ON u.id = p.user_id WHERE name = ? LIMIT ? OFFSET ?", - "Cristina", limit, offset, + "FROM users as u JOIN posts as p ON u.id = p.user_id WHERE name = ? LIMIT ? OFFSET ?", + "Cristina", limit, offset, ) if err != nil { panic(err.Error()) @@ -402,9 +402,9 @@ err = db.QueryChunks(ctx, ksql.ChunkParser{ Params: []interface{}{usersType}, ChunkSize: 100, ForEachChunk: func(rows []struct{ - User User `tablename:"u"` - Post Post `tablename:"p"` - }) error { + User User `tablename:"u"` + Post Post `tablename:"p"` + }) error { err := sendRowsSomewhere(rows) if err != nil { // This will abort the QueryChunks loop and return this error @@ -449,12 +449,12 @@ You are not forced to, and there are a few use-cases where you would prefer not ```golang var rows []struct{ - UserName string `ksql:"name"` - PostTitle string `ksql:"title"` + UserName string `ksql:"name"` + PostTitle string `ksql:"title"` } err := db.Query(ctx, &rows, "SELECT u.name, p.title FROM users u JOIN posts p ON u.id = p.user_id LIMIT 10") if err != nil { - panic(err.Error()) + panic(err.Error()) } ``` From 936d3872217557a42b08b1b321246256db06dee8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 6 Jun 2021 10:10:01 -0300 Subject: [PATCH 19/40] Minor improvment on README --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index a185c70..bb5e40d 100644 --- a/README.md +++ b/README.md @@ -437,11 +437,11 @@ This feature has two important limitations: While in normal queries we match the selected field with the attribute by name, in queries joining multiple tables we can't use this strategy because different tables might have columns with the same name, and we don't - really have access to the full name of these columns making for exemple + really have access to the full name of these columns making, for example, it impossible to differentiate between `u.id` and `p.id` except by the - order in which these fields were passed. Thus, it is necessary to - leave the job of generating the `SELECT` for the library when using - this technique with composite anonymous structs. + order in which these fields were passed. Thus, it is necessary that + the library itself writes the `SELECT` part of the query when using + this technique so that we can control the order or the selected fields. Ok, but what if I don't want to use this feature? @@ -458,9 +458,9 @@ if err != nil { } ``` -Here, since we are only interested in a couple of columns it is far -simpler and more efficient for the database to only select the columns -that we actually care about like in the example above. +In the example above, since we are only interested in a couple of columns it +is far simpler and more efficient for the database to only select the columns +that we actually care about, so it's better not to use composite structs. ### Testing Examples From 54f5b7b1eb04a34d9247fa32d0bb4a71e2d8164f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 6 Jun 2021 12:15:38 -0300 Subject: [PATCH 20/40] Improve Makefile --- Makefile | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 49beb88..7a17471 100644 --- a/Makefile +++ b/Makefile @@ -1,32 +1,36 @@ args= path=./... -GOPATH=$(shell go env GOPATH) +GOBIN=$(shell go env GOPATH)/bin TIME=1s test: setup - $(GOPATH)/bin/richgo test $(path) $(args) + $(GOBIN)/richgo test $(path) $(args) bench: go test -bench=. -benchtime=$(TIME) lint: setup - @$(GOPATH)/bin/golint -set_exit_status -min_confidence 0.9 $(path) $(args) + @$(GOBIN)/golint -set_exit_status -min_confidence 0.9 $(path) $(args) @go vet $(path) $(args) @echo "Golint & Go Vet found no problems on your code!" mock: setup - mockgen -package=exampleservice -source=contracts.go -destination=examples/example_service/mocks.go + $(GOBIN)/mockgen -package=exampleservice -source=contracts.go -destination=examples/example_service/mocks.go -setup: .make.setup -.make.setup: +setup: $(GOBIN)/richgo $(GOBIN)/golint $(GOBIN)/mockgen + +$(GOBIN)/richgo: go get github.com/kyoh86/richgo + +$(GOBIN)/golint: go get golang.org/x/lint + +$(GOBIN)/mockgen: @# (Gomock is used on examples/example_service) go get github.com/golang/mock/gomock go get github.com/golang/mock/mockgen - touch .make.setup # Running examples: exampleservice: mock From c0d7206dccb36e562e3f58155e10ff2b10887391 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 6 Jun 2021 20:51:13 -0300 Subject: [PATCH 21/40] Breaking change: Update SQLProvider interface so methods receive table info as argument --- benchmark_test.go | 6 +- contracts.go | 60 ++- examples/crud/crud.go | 19 +- examples/example_service/example_service.go | 31 +- .../example_service/example_service_test.go | 48 +-- examples/example_service/mocks.go | 24 +- go.sum | 3 + ksql.go | 111 +++--- ksql_test.go | 365 +++++++++++------- mocks.go | 18 +- 10 files changed, 403 insertions(+), 282 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 8914511..a31128c 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -12,6 +12,8 @@ import ( "github.com/vingarcia/ksql" ) +var UsersTable = ksql.NewTable("users") + func BenchmarkInsert(b *testing.B) { ctx := context.Background() @@ -20,7 +22,6 @@ func BenchmarkInsert(b *testing.B) { ksqlDB, err := ksql.New(driver, connStr, ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { b.FailNow() @@ -40,7 +41,7 @@ func BenchmarkInsert(b *testing.B) { b.Run("insert-one", func(b *testing.B) { for i := 0; i < b.N; i++ { - err := ksqlDB.Insert(ctx, &User{ + err := ksqlDB.Insert(ctx, UsersTable, &User{ Name: strconv.Itoa(i), Age: i, }) @@ -92,7 +93,6 @@ func BenchmarkQuery(b *testing.B) { ksqlDB, err := ksql.New(driver, connStr, ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { b.FailNow() diff --git a/contracts.go b/contracts.go index fc52d92..26866de 100644 --- a/contracts.go +++ b/contracts.go @@ -16,9 +16,9 @@ var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be // SQLProvider describes the public behavior of this ORM type SQLProvider interface { - Insert(ctx context.Context, record interface{}) error - Update(ctx context.Context, record interface{}) error - Delete(ctx context.Context, idsOrRecords ...interface{}) error + Insert(ctx context.Context, table Table, record interface{}) error + Update(ctx context.Context, table Table, record interface{}) error + Delete(ctx context.Context, table Table, idsOrRecords ...interface{}) error Query(ctx context.Context, records interface{}, query string, params ...interface{}) error QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error @@ -28,6 +28,60 @@ type SQLProvider interface { Transaction(ctx context.Context, fn func(SQLProvider) error) error } +// Table describes the required information for inserting, updating and +// deleting entities from the database by ID using the 3 helper functions +// created for that purpose. +type Table struct { + // this name must be set in order to use the Insert, Delete and Update helper + // functions. If you only intend to make queries or to use the Exec function + // it is safe to leave this field unset. + name string + + // IDColumns defaults to []string{"id"} if unset + idColumns []string +} + +// NewTable returns a Table instance that stores +// the tablename and the names of columns used as ID, +// if no column name is passed it defaults to using +// the `"id"` column. +// +// This Table is required only for using the helper methods: +// +// - Insert +// - Update +// - Delete +// +// Passing multiple ID columns will be interpreted +// as a single composite key, if you want +// to use the helper functions with different +// keys you'll need to create multiple Table instances +// for the same database table, each with a different +// set of ID columns, but this is usually not necessary. +func NewTable(tableName string, ids ...string) Table { + if len(ids) == 0 { + ids = []string{"id"} + } + + return Table{ + name: tableName, + idColumns: ids, + } +} + +func (t Table) insertMethodFor(dialect dialect) insertMethod { + if len(t.idColumns) == 1 { + return dialect.InsertMethod() + } + + insertMethod := dialect.InsertMethod() + if insertMethod == insertWithLastInsertID { + return insertWithNoIDRetrieval + } + + return insertMethod +} + // ChunkParser stores the arguments of the QueryChunks function type ChunkParser struct { // The Query and Params are used together to build a query with diff --git a/examples/crud/crud.go b/examples/crud/crud.go index 75fa363..20e78b0 100644 --- a/examples/crud/crud.go +++ b/examples/crud/crud.go @@ -33,11 +33,14 @@ type Address struct { City string `json:"city"` } +// UsersTable informs ksql the name of the table and that it can +// use the default value for the primary key column name: "id" +var UsersTable = ksql.NewTable("users") + func main() { ctx := context.Background() db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { panic(err.Error()) @@ -62,14 +65,14 @@ func main() { State: "MG", }, } - err = db.Insert(ctx, &alison) + err = db.Insert(ctx, UsersTable, &alison) if err != nil { panic(err.Error()) } fmt.Println("Alison ID:", alison.ID) // Inserting inline: - err = db.Insert(ctx, &User{ + err = db.Insert(ctx, UsersTable, &User{ Name: "Cristina", Age: 27, Address: Address{ @@ -81,7 +84,7 @@ func main() { } // Deleting Alison: - err = db.Delete(ctx, alison.ID) + err = db.Delete(ctx, UsersTable, alison.ID) if err != nil { panic(err.Error()) } @@ -96,12 +99,12 @@ func main() { // Updating all fields from Cristina: cris.Name = "Cris" - err = db.Update(ctx, cris) + err = db.Update(ctx, UsersTable, cris) // Changing the age of Cristina but not touching any other fields: // Partial update technique 1: - err = db.Update(ctx, struct { + err = db.Update(ctx, UsersTable, struct { ID int `ksql:"id"` Age int `ksql:"age"` }{ID: cris.ID, Age: 28}) @@ -110,7 +113,7 @@ func main() { } // Partial update technique 2: - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris.ID, Age: nullable.Int(28), }) @@ -142,7 +145,7 @@ func main() { return err } - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris2.ID, Age: nullable.Int(29), }) diff --git a/examples/example_service/example_service.go b/examples/example_service/example_service.go index 85bd18f..94ca156 100644 --- a/examples/example_service/example_service.go +++ b/examples/example_service/example_service.go @@ -8,11 +8,8 @@ import ( "github.com/vingarcia/ksql/nullable" ) -// Service ... -type Service struct { - usersTable ksql.SQLProvider - streamChunkSize int -} +// UsersTable informs ksql that the ID column is named "id" +var UsersTable = ksql.NewTable("users", "id") // UserEntity represents a domain user, // the pointer fields represent optional fields that @@ -41,17 +38,23 @@ type Address struct { Country string `json:"country"` } +// Service ... +type Service struct { + db ksql.SQLProvider + streamChunkSize int +} + // NewUserService ... -func NewUserService(usersTable ksql.SQLProvider) Service { +func NewUserService(db ksql.SQLProvider) Service { return Service{ - usersTable: usersTable, + db: db, streamChunkSize: 100, } } // CreateUser ... func (s Service) CreateUser(ctx context.Context, u UserEntity) error { - return s.usersTable.Insert(ctx, &u) + return s.db.Insert(ctx, UsersTable, &u) } // UpdateUserScore update the user score adding scoreChange with the current @@ -60,12 +63,12 @@ func (s Service) UpdateUserScore(ctx context.Context, uID int, scoreChange int) var scoreRow struct { Score int `ksql:"score"` } - err := s.usersTable.QueryOne(ctx, &scoreRow, "SELECT score FROM users WHERE id = ?", uID) + err := s.db.QueryOne(ctx, &scoreRow, "SELECT score FROM users WHERE id = ?", uID) if err != nil { return err } - return s.usersTable.Update(ctx, &UserEntity{ + return s.db.Update(ctx, UsersTable, &UserEntity{ ID: uID, Score: nullable.Int(scoreRow.Score + scoreChange), }) @@ -76,12 +79,12 @@ func (s Service) ListUsers(ctx context.Context, offset, limit int) (total int, u var countRow struct { Count int `ksql:"count"` } - err = s.usersTable.QueryOne(ctx, &countRow, "SELECT count(*) as count FROM users") + err = s.db.QueryOne(ctx, &countRow, "SELECT count(*) as count FROM users") if err != nil { return 0, nil, err } - return countRow.Count, users, s.usersTable.Query(ctx, &users, "SELECT * FROM users OFFSET ? LIMIT ?", offset, limit) + return countRow.Count, users, s.db.Query(ctx, &users, "SELECT * FROM users OFFSET ? LIMIT ?", offset, limit) } // StreamAllUsers sends all users from the database to an external client @@ -91,7 +94,7 @@ func (s Service) ListUsers(ctx context.Context, offset, limit int) (total int, u // function only when the ammount of data loaded might exceed the available memory and/or // when you can't put an upper limit on the number of values returned. func (s Service) StreamAllUsers(ctx context.Context, sendUser func(u UserEntity) error) error { - return s.usersTable.QueryChunks(ctx, ksql.ChunkParser{ + return s.db.QueryChunks(ctx, ksql.ChunkParser{ Query: "SELECT * FROM users", Params: []interface{}{}, ChunkSize: s.streamChunkSize, @@ -110,5 +113,5 @@ func (s Service) StreamAllUsers(ctx context.Context, sendUser func(u UserEntity) // DeleteUser deletes a user by its ID func (s Service) DeleteUser(ctx context.Context, uID int) error { - return s.usersTable.Delete(ctx, uID) + return s.db.Delete(ctx, UsersTable, uID) } diff --git a/examples/example_service/example_service_test.go b/examples/example_service/example_service_test.go index ae142f5..33b2f2c 100644 --- a/examples/example_service/example_service_test.go +++ b/examples/example_service/example_service_test.go @@ -17,16 +17,16 @@ func TestCreateUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var users []interface{} - usersTableMock.EXPECT().Insert(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, records ...interface{}) error { + mockDB.EXPECT().Insert(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error { users = append(users, records...) return nil }) @@ -43,16 +43,16 @@ func TestCreateUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var users []map[string]interface{} - usersTableMock.EXPECT().Insert(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, records ...interface{}) error { + mockDB.EXPECT().Insert(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error { for _, record := range records { // The StructToMap function will convert a struct with `ksql` tags // into a map using the ksql attr names as keys. @@ -83,16 +83,16 @@ func TestUpdateUserScore(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var users []interface{} gomock.InOrder( - usersTableMock.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockDB.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { // This function will use reflection to fill the // struct fields with the values from the map @@ -103,8 +103,8 @@ func TestUpdateUserScore(t *testing.T) { "score": 42, }) }), - usersTableMock.EXPECT().Update(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, records ...interface{}) error { + mockDB.EXPECT().Update(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error { users = append(users, records...) return nil }), @@ -127,15 +127,15 @@ func TestListUsers(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } gomock.InOrder( - usersTableMock.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockDB.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { // This function will use reflection to fill the // struct fields with the values from the map @@ -146,7 +146,7 @@ func TestListUsers(t *testing.T) { "count": 420, }) }), - usersTableMock.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockDB.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, results interface{}, query string, params ...interface{}) error { return structs.FillSliceWith(results, []map[string]interface{}{ { @@ -189,14 +189,14 @@ func TestStreamAllUsers(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 2, } - usersTableMock.EXPECT().QueryChunks(gomock.Any(), gomock.Any()). + mockDB.EXPECT().QueryChunks(gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, parser ksql.ChunkParser) error { fn, ok := parser.ForEachChunk.(func(users []UserEntity) error) require.True(t, ok) @@ -263,16 +263,16 @@ func TestDeleteUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var ids []interface{} - usersTableMock.EXPECT().Delete(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, idArgs ...interface{}) error { + mockDB.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, idArgs ...interface{}) error { ids = append(ids, idArgs...) return nil }) diff --git a/examples/example_service/mocks.go b/examples/example_service/mocks.go index 44b7baf..c170490 100644 --- a/examples/example_service/mocks.go +++ b/examples/example_service/mocks.go @@ -36,9 +36,9 @@ func (m *MockSQLProvider) EXPECT() *MockSQLProviderMockRecorder { } // Delete mocks base method. -func (m *MockSQLProvider) Delete(ctx context.Context, idsOrRecords ...interface{}) error { +func (m *MockSQLProvider) Delete(ctx context.Context, table ksql.Table, idsOrRecords ...interface{}) error { m.ctrl.T.Helper() - varargs := []interface{}{ctx} + varargs := []interface{}{ctx, table} for _, a := range idsOrRecords { varargs = append(varargs, a) } @@ -48,9 +48,9 @@ func (m *MockSQLProvider) Delete(ctx context.Context, idsOrRecords ...interface{ } // Delete indicates an expected call of Delete. -func (mr *MockSQLProviderMockRecorder) Delete(ctx interface{}, idsOrRecords ...interface{}) *gomock.Call { +func (mr *MockSQLProviderMockRecorder) Delete(ctx, table interface{}, idsOrRecords ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx}, idsOrRecords...) + varargs := append([]interface{}{ctx, table}, idsOrRecords...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSQLProvider)(nil).Delete), varargs...) } @@ -74,17 +74,17 @@ func (mr *MockSQLProviderMockRecorder) Exec(ctx, query interface{}, params ...in } // Insert mocks base method. -func (m *MockSQLProvider) Insert(ctx context.Context, record interface{}) error { +func (m *MockSQLProvider) Insert(ctx context.Context, table ksql.Table, record interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Insert", ctx, record) + ret := m.ctrl.Call(m, "Insert", ctx, table, record) ret0, _ := ret[0].(error) return ret0 } // Insert indicates an expected call of Insert. -func (mr *MockSQLProviderMockRecorder) Insert(ctx, record interface{}) *gomock.Call { +func (mr *MockSQLProviderMockRecorder) Insert(ctx, table, record interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), ctx, record) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), ctx, table, record) } // Query mocks base method. @@ -154,15 +154,15 @@ func (mr *MockSQLProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock. } // Update mocks base method. -func (m *MockSQLProvider) Update(ctx context.Context, record interface{}) error { +func (m *MockSQLProvider) Update(ctx context.Context, table ksql.Table, record interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Update", ctx, record) + ret := m.ctrl.Call(m, "Update", ctx, table, record) ret0, _ := ret[0].(error) return ret0 } // Update indicates an expected call of Update. -func (mr *MockSQLProviderMockRecorder) Update(ctx, record interface{}) *gomock.Call { +func (mr *MockSQLProviderMockRecorder) Update(ctx, table, record interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), ctx, record) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), ctx, table, record) } diff --git a/go.sum b/go.sum index b80fabb..67d781f 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -45,8 +46,10 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e h1:aZzprAO9/8oim3qStq3wc1Xuxx4QmAGriC4VU4ojemQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= diff --git a/ksql.go b/ksql.go index 5ada2ef..85e6e2a 100644 --- a/ksql.go +++ b/ksql.go @@ -24,16 +24,9 @@ func init() { // interfacing with the "database/sql" package implementing // the KissSQL interface `SQLProvider`. type DB struct { - driver string - dialect dialect - tableName string - db sqlProvider - - // Most dbs have a single primary key, - // But in future ksql should work with compound keys as well - idCols []string - - insertMethod insertMethod + driver string + dialect dialect + db sqlProvider } type sqlProvider interface { @@ -46,14 +39,6 @@ type sqlProvider interface { type Config struct { // MaxOpenCons defaults to 1 if not set MaxOpenConns int - - // TableName must be set in order to use the Insert, Delete and Update helper - // functions. If you only intend to make queries or to use the Exec function - // it is safe to leave this field unset. - TableName string - - // IDColumns defaults to []string{"id"} if unset - IDColumns []string } // New instantiates a new KissSQL client @@ -81,23 +66,10 @@ func New( return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver) } - if len(config.IDColumns) == 0 { - config.IDColumns = []string{"id"} - } - - insertMethod := dialect.InsertMethod() - if len(config.IDColumns) > 1 && insertMethod == insertWithLastInsertID { - insertMethod = insertWithNoIDRetrieval - } - return DB{ - dialect: dialect, - driver: dbDriver, - db: db, - tableName: config.TableName, - - idCols: config.IDColumns, - insertMethod: insertMethod, + dialect: dialect, + driver: dbDriver, + db: db, }, nil } @@ -133,8 +105,16 @@ func (c DB) Query( slice = slice.Slice(0, 0) } - if strings.ToUpper(getFirstToken(query)) == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()]) + info := structs.GetTagInfo(structType) + + firstToken := strings.ToUpper(getFirstToken(query)) + if info.IsNestedStruct && firstToken == "SELECT" { + // This error check is necessary, since if we can't build the select part of the query this feature won't work. + return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query") + } + + if firstToken == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -206,8 +186,16 @@ func (c DB) QueryOne( return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) } - if strings.ToUpper(getFirstToken(query)) == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, t, selectQueryCache[c.dialect.DriverName()]) + info := structs.GetTagInfo(t) + + firstToken := strings.ToUpper(getFirstToken(query)) + if info.IsNestedStruct && firstToken == "SELECT" { + // This error check is necessary, since if we can't build the select part of the query this feature won't work. + return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query") + } + + if firstToken == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, t, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -268,8 +256,16 @@ func (c DB) QueryChunks( return err } - if strings.ToUpper(getFirstToken(parser.Query)) == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()]) + info := structs.GetTagInfo(structType) + + firstToken := strings.ToUpper(getFirstToken(parser.Query)) + if info.IsNestedStruct && firstToken == "SELECT" { + // This error check is necessary, since if we can't build the select part of the query this feature won't work. + return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query") + } + + if firstToken == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -347,22 +343,19 @@ func (c DB) QueryChunks( // the ID is automatically updated after insertion is completed. func (c DB) Insert( ctx context.Context, + table Table, record interface{}, ) error { - if c.tableName == "" { - return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Insert method") - } - - query, params, scanValues, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...) + query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, record, table.idColumns...) if err != nil { return err } - switch c.insertMethod { + switch table.insertMethodFor(c.dialect) { case insertWithReturning, insertWithOutput: - err = c.insertReturningIDs(ctx, record, query, params, scanValues, c.idCols) + err = c.insertReturningIDs(ctx, record, query, params, scanValues, table.idColumns) case insertWithLastInsertID: - err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0]) + err = c.insertWithLastInsertID(ctx, record, query, params, table.idColumns[0]) case insertWithNoIDRetrieval: err = c.insertWithNoIDRetrieval(ctx, record, query, params) default: @@ -471,27 +464,24 @@ func assertStructPtr(t reflect.Type) error { // Delete deletes one or more instances from the database by id func (c DB) Delete( ctx context.Context, + table Table, ids ...interface{}, ) error { - if c.tableName == "" { - return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Delete method") - } - if len(ids) == 0 { return nil } - idMaps, err := normalizeIDsAsMaps(c.idCols, ids) + idMaps, err := normalizeIDsAsMaps(table.idColumns, ids) if err != nil { return err } var query string var params []interface{} - if len(c.idCols) == 1 { - query, params = buildSingleKeyDeleteQuery(c.dialect, c.tableName, c.idCols[0], idMaps) + if len(table.idColumns) == 1 { + query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps) } else { - query, params = buildCompositeKeyDeleteQuery(c.dialect, c.tableName, c.idCols, idMaps) + query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps) } _, err = c.db.ExecContext(ctx, query, params...) @@ -543,13 +533,10 @@ func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]inter // Partial updates are supported, i.e. it will ignore nil pointer attributes func (c DB) Update( ctx context.Context, + table Table, record interface{}, ) error { - if c.tableName == "" { - return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Update method") - } - - query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, c.idCols...) + query, params, err := buildUpdateQuery(c.dialect, table.name, record, table.idColumns...) if err != nil { return err } @@ -964,13 +951,13 @@ func getFirstToken(s string) string { func buildSelectQuery( dialect dialect, structType reflect.Type, + info structs.StructInfo, selectQueryCache map[reflect.Type]string, ) (query string, err error) { if selectQuery, found := selectQueryCache[structType]; found { return selectQuery, nil } - info := structs.GetTagInfo(structType) if info.IsNestedStruct { query, err = buildSelectQueryForNestedStructs(dialect, structType, info) if err != nil { diff --git a/ksql_test.go b/ksql_test.go index 95d4461..3eea91a 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -25,6 +25,8 @@ type User struct { Address Address `ksql:"address,json"` } +var UsersTable = NewTable("users") + type Address struct { Street string `json:"street"` Number string `json:"number"` @@ -40,6 +42,8 @@ type Post struct { Title string `ksql:"title"` } +var PostsTable = NewTable("posts") + func TestQuery(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { @@ -69,7 +73,7 @@ func TestQuery(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []User err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) assert.Equal(t, nil, err) @@ -89,7 +93,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") @@ -111,7 +115,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") @@ -154,7 +158,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var rows []struct { User User `tablename:"u"` Post Post `tablename:"p"` @@ -193,7 +197,7 @@ func TestQuery(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []*User err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) assert.Equal(t, nil, err) @@ -213,7 +217,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []*User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") @@ -235,7 +239,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []*User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") @@ -278,7 +282,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var rows []*struct { User User `tablename:"u"` Post Post `tablename:"p"` @@ -325,7 +329,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) err = c.Query(ctx, &User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) @@ -345,56 +349,72 @@ func TestQuery(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []User - err = c.Query(ctx, &users, `SELECT * FROM not a valid query`) + err := c.Query(ctx, &users, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) - }) - t.Run("should report error for nested structs with invalid types", func(t *testing.T) { - t.Run("int", func(t *testing.T) { + t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var rows []struct { - Foo int `tablename:"foo"` + User User `tablename:"users"` + Post Post `tablename:"posts"` } - err := c.Query(ctx, &rows, fmt.Sprint( - `FROM users u JOIN posts p ON p.user_id = u.id`, - ` WHERE u.name like `, c.dialect.Placeholder(0), - ` ORDER BY u.id, p.id`, - ), "% Ribeiro") - + err := c.Query(ctx, &rows, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`) assert.NotEqual(t, nil, err) - msg := err.Error() - for _, str := range []string{"foo", "int"} { - assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) - } + assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error()) + assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error()) }) - t.Run("*struct", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run("should report error for nested structs with invalid types", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() - ctx := context.Background() - c := newTestDB(db, driver, "users") - var rows []struct { - Foo *User `tablename:"foo"` - } - err := c.Query(ctx, &rows, fmt.Sprint( - `FROM users u JOIN posts p ON p.user_id = u.id`, - ` WHERE u.name like `, c.dialect.Placeholder(0), - ` ORDER BY u.id, p.id`, - ), "% Ribeiro") + ctx := context.Background() + c := newTestDB(db, driver) + var rows []struct { + Foo int `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") - assert.NotEqual(t, nil, err) - msg := err.Error() - for _, str := range []string{"foo", "*ksql.User"} { - assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) - } + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "int"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) + + t.Run("*struct", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + var rows []struct { + Foo *User `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "*ksql.User"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) }) }) }) @@ -429,7 +449,7 @@ func TestQueryOne(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{} err := c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE id=1;`) assert.Equal(t, ErrRecordNotFound, err) @@ -443,7 +463,7 @@ func TestQueryOne(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{} err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") @@ -466,7 +486,7 @@ func TestQueryOne(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var u User err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá") @@ -496,7 +516,7 @@ func TestQueryOne(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var row struct { User User `tablename:"u"` Post Post `tablename:"p"` @@ -526,7 +546,7 @@ func TestQueryOne(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) err = c.QueryOne(ctx, &[]User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) @@ -540,11 +560,27 @@ func TestQueryOne(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var user User err := c.QueryOne(ctx, &user, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) + + t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + var row struct { + User User `tablename:"users"` + Post Post `tablename:"posts"` + } + err := c.QueryOne(ctx, &row, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id LIMIT 1`) + assert.NotEqual(t, nil, err) + assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error()) + assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error()) + }) }) } } @@ -563,7 +599,7 @@ func TestInsert(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{ Name: "Fernanda", @@ -572,7 +608,7 @@ func TestInsert(t *testing.T) { }, } - err := c.Insert(ctx, &u) + err := c.Insert(ctx, UsersTable, &u) assert.Equal(t, nil, err) assert.NotEqual(t, 0, u.ID) @@ -593,11 +629,11 @@ func TestInsert(t *testing.T) { defer db.Close() ctx := context.Background() + // Using columns "id" and "name" as IDs: - c, err := New(driver, connectionString[driver], Config{ - TableName: "users", - IDColumns: []string{"id", "name"}, - }) + table := NewTable("users", "id", "name") + + c, err := New(driver, connectionString[driver], Config{}) assert.Equal(t, nil, err) u := User{ @@ -609,7 +645,7 @@ func TestInsert(t *testing.T) { }, } - err = c.Insert(ctx, &u) + err = c.Insert(ctx, table, &u) assert.Equal(t, nil, err) assert.Equal(t, uint(0), u.ID) @@ -633,15 +669,15 @@ func TestInsert(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - err = c.Insert(ctx, "foo") + err = c.Insert(ctx, UsersTable, "foo") assert.NotEqual(t, nil, err) - err = c.Insert(ctx, nullable.String("foo")) + err = c.Insert(ctx, UsersTable, nullable.String("foo")) assert.NotEqual(t, nil, err) - err = c.Insert(ctx, map[string]interface{}{ + err = c.Insert(ctx, UsersTable, map[string]interface{}{ "name": "foo", "age": 12, }) @@ -651,11 +687,11 @@ func TestInsert(t *testing.T) { &User{Name: "foo", Age: 22}, &User{Name: "bar", Age: 32}, } - err = c.Insert(ctx, ifUserForgetToExpandList) + err = c.Insert(ctx, UsersTable, ifUserForgetToExpandList) assert.NotEqual(t, nil, err) // We might want to support this in the future, but not for now: - err = c.Insert(ctx, User{Name: "not a ptr to user", Age: 42}) + err = c.Insert(ctx, UsersTable, User{Name: "not a ptr to user", Age: 42}) assert.NotEqual(t, nil, err) }) @@ -664,12 +700,12 @@ func TestInsert(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) // This is an invalid value: - c.insertMethod = insertMethod(42) + c.dialect = brokenDialect{} - err = c.Insert(ctx, &User{Name: "foo"}) + err = c.Insert(ctx, UsersTable, &User{Name: "foo"}) assert.NotEqual(t, nil, err) }) }) @@ -677,6 +713,24 @@ func TestInsert(t *testing.T) { } } +type brokenDialect struct{} + +func (brokenDialect) InsertMethod() insertMethod { + return insertMethod(42) +} + +func (brokenDialect) Escape(str string) string { + return str +} + +func (brokenDialect) Placeholder(idx int) string { + return "?" +} + +func (brokenDialect) DriverName() string { + return "fake-driver-name" +} + func TestDelete(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { @@ -690,13 +744,13 @@ func TestDelete(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{ Name: "Won't be deleted", } - err := c.Insert(ctx, &u) + err := c.Insert(ctx, UsersTable, &u) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) @@ -706,7 +760,7 @@ func TestDelete(t *testing.T) { assert.Equal(t, u.ID, result.ID) - err = c.Delete(ctx) + err = c.Delete(ctx, UsersTable) assert.Equal(t, nil, err) result = User{} @@ -720,13 +774,13 @@ func TestDelete(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u1 := User{ Name: "Fernanda", } - err := c.Insert(ctx, &u1) + err := c.Insert(ctx, UsersTable, &u1) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u1.ID) @@ -739,7 +793,7 @@ func TestDelete(t *testing.T) { Name: "Won't be deleted", } - err = c.Insert(ctx, &u2) + err = c.Insert(ctx, UsersTable, &u2) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u2.ID) @@ -748,7 +802,7 @@ func TestDelete(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, u2.ID, result.ID) - err = c.Delete(ctx, u1.ID) + err = c.Delete(ctx, UsersTable, u1.ID) assert.Equal(t, nil, err) result = User{} @@ -768,26 +822,26 @@ func TestDelete(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u1 := User{ Name: "Fernanda", } - err := c.Insert(ctx, &u1) + err := c.Insert(ctx, UsersTable, &u1) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u1.ID) u2 := User{ Name: "Juliano", } - err = c.Insert(ctx, &u2) + err = c.Insert(ctx, UsersTable, &u2) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u2.ID) u3 := User{ Name: "This won't be deleted", } - err = c.Insert(ctx, &u3) + err = c.Insert(ctx, UsersTable, &u3) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u3.ID) @@ -806,7 +860,7 @@ func TestDelete(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, u3.ID, result.ID) - err = c.Delete(ctx, u1.ID, u2.ID) + err = c.Delete(ctx, UsersTable, u1.ID, u2.ID) assert.Equal(t, nil, err) results := []User{} @@ -833,7 +887,7 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{ Name: "Letícia", @@ -847,7 +901,7 @@ func TestUpdate(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) - err = c.Update(ctx, User{ + err = c.Update(ctx, UsersTable, User{ ID: u.ID, Name: "Thayane", }) @@ -864,7 +918,7 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{ Name: "Letícia", @@ -878,7 +932,7 @@ func TestUpdate(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) - err = c.Update(ctx, User{ + err = c.Update(ctx, UsersTable, User{ ID: u.ID, Name: "Thayane", }) @@ -895,7 +949,7 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) type partialUser struct { ID uint `ksql:"id"` @@ -915,7 +969,7 @@ func TestUpdate(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) - err = c.Update(ctx, partialUser{ + err = c.Update(ctx, UsersTable, partialUser{ ID: u.ID, // Should be updated because it is not null, just empty: Name: "", @@ -936,7 +990,7 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) type partialUser struct { ID uint `ksql:"id"` @@ -957,7 +1011,7 @@ func TestUpdate(t *testing.T) { assert.NotEqual(t, uint(0), u.ID) // Should update all fields: - err = c.Update(ctx, partialUser{ + err = c.Update(ctx, UsersTable, partialUser{ ID: u.ID, Name: "Thay", Age: nullable.Int(42), @@ -977,9 +1031,9 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "non_existing_table") + c := newTestDB(db, driver) - err = c.Update(ctx, User{ + err = c.Update(ctx, NewTable("non_existing_table"), User{ ID: 1, Name: "Thayane", }) @@ -1017,9 +1071,9 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{ + _ = c.Insert(ctx, UsersTable, &User{ Name: "User1", Address: Address{Country: "BR"}, }) @@ -1057,10 +1111,10 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) - _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}}) var lengths []int var users []User @@ -1099,10 +1153,10 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) - _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}}) var lengths []int var users []User @@ -1141,11 +1195,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) var lengths []int var users []User @@ -1192,9 +1246,9 @@ func TestQueryChunks(t *testing.T) { } ctx := context.Background() - c := newTestDB(db, driver, "users") - _ = c.Insert(ctx, &joao) - _ = c.Insert(ctx, &thatiana) + c := newTestDB(db, driver) + _ = c.Insert(ctx, UsersTable, &joao) + _ = c.Insert(ctx, UsersTable, &thatiana) _, err := db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post1')`)) assert.Equal(t, nil, err) @@ -1254,11 +1308,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) var lengths []int var users []User @@ -1293,11 +1347,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) returnVals := []error{nil, ErrAbortIteration} var lengths []int @@ -1336,11 +1390,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) var lengths []int var users []User @@ -1375,11 +1429,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) returnVals := []error{nil, errors.New("fake error msg")} var lengths []int @@ -1413,7 +1467,7 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) funcs := []interface{}{ nil, @@ -1458,7 +1512,7 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) err := c.QueryChunks(ctx, ChunkParser{ Query: `SELECT * FROM not a valid query`, Params: []interface{}{}, @@ -1470,6 +1524,31 @@ func TestQueryChunks(t *testing.T) { }) assert.NotEqual(t, nil, err) }) + + t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + err := c.QueryChunks(ctx, ChunkParser{ + Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`, + Params: []interface{}{}, + + ChunkSize: 2, + ForEachChunk: func(buffer []struct { + User User `tablename:"users"` + Post Post `tablename:"posts"` + }) error { + return nil + }, + }) + + assert.NotEqual(t, nil, err) + assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error()) + assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error()) + }) }) } }) @@ -1489,10 +1568,10 @@ func TestTransaction(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) var users []User err = c.Transaction(ctx, func(db SQLProvider) error { @@ -1516,17 +1595,17 @@ func TestTransaction(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u1 := User{Name: "User1", Age: 42} u2 := User{Name: "User2", Age: 42} - _ = c.Insert(ctx, &u1) - _ = c.Insert(ctx, &u2) + _ = c.Insert(ctx, UsersTable, &u1) + _ = c.Insert(ctx, UsersTable, &u2) err = c.Transaction(ctx, func(db SQLProvider) error { - err = db.Insert(ctx, &User{Name: "User3"}) + err = db.Insert(ctx, UsersTable, &User{Name: "User3"}) assert.Equal(t, nil, err) - err = db.Insert(ctx, &User{Name: "User4"}) + err = db.Insert(ctx, UsersTable, &User{Name: "User4"}) assert.Equal(t, nil, err) err = db.Exec(ctx, "UPDATE users SET age = 22") assert.Equal(t, nil, err) @@ -1557,10 +1636,10 @@ func TestScanRows(t *testing.T) { ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() - c := newTestDB(db, "sqlite3", "users") - _ = c.Insert(ctx, &User{Name: "User1", Age: 22}) - _ = c.Insert(ctx, &User{Name: "User2", Age: 14}) - _ = c.Insert(ctx, &User{Name: "User3", Age: 43}) + c := newTestDB(db, "sqlite3") + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Age: 14}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3", Age: 43}) rows, err := db.QueryContext(ctx, "select * from users where name='User2'") assert.Equal(t, nil, err) @@ -1586,8 +1665,8 @@ func TestScanRows(t *testing.T) { ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() - c := newTestDB(db, "sqlite3", "users") - _ = c.Insert(ctx, &User{Name: "User1", Age: 22}) + c := newTestDB(db, "sqlite3") + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22}) rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User1'") assert.Equal(t, nil, err) @@ -1758,19 +1837,11 @@ func createTables(driver string) error { return nil } -func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB { - if len(ids) == 0 { - ids = []string{"id"} - } - +func newTestDB(db *sql.DB, driver string) DB { return DB{ - driver: driver, - dialect: supportedDialects[driver], - db: db, - tableName: tableName, - - idCols: ids, - insertMethod: supportedDialects[driver].InsertMethod(), + driver: driver, + dialect: supportedDialects[driver], + db: db, } } diff --git a/mocks.go b/mocks.go index eab1770..8fc7321 100644 --- a/mocks.go +++ b/mocks.go @@ -6,9 +6,9 @@ var _ SQLProvider = MockSQLProvider{} // MockSQLProvider ... type MockSQLProvider struct { - InsertFn func(ctx context.Context, record interface{}) error - UpdateFn func(ctx context.Context, record interface{}) error - DeleteFn func(ctx context.Context, ids ...interface{}) error + InsertFn func(ctx context.Context, table Table, record interface{}) error + UpdateFn func(ctx context.Context, table Table, record interface{}) error + DeleteFn func(ctx context.Context, table Table, ids ...interface{}) error QueryFn func(ctx context.Context, records interface{}, query string, params ...interface{}) error QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error @@ -19,18 +19,18 @@ type MockSQLProvider struct { } // Insert ... -func (m MockSQLProvider) Insert(ctx context.Context, record interface{}) error { - return m.InsertFn(ctx, record) +func (m MockSQLProvider) Insert(ctx context.Context, table Table, record interface{}) error { + return m.InsertFn(ctx, table, record) } // Update ... -func (m MockSQLProvider) Update(ctx context.Context, record interface{}) error { - return m.UpdateFn(ctx, record) +func (m MockSQLProvider) Update(ctx context.Context, table Table, record interface{}) error { + return m.UpdateFn(ctx, table, record) } // Delete ... -func (m MockSQLProvider) Delete(ctx context.Context, ids ...interface{}) error { - return m.DeleteFn(ctx, ids...) +func (m MockSQLProvider) Delete(ctx context.Context, table Table, ids ...interface{}) error { + return m.DeleteFn(ctx, table, ids...) } // Query ... From cc4e73dc6261800fd88397ae97c116df322c3735 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 6 Jun 2021 20:56:20 -0300 Subject: [PATCH 22/40] Update README to describe the new interface --- README.md | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index bb5e40d..9cb5b28 100644 --- a/README.md +++ b/README.md @@ -70,9 +70,9 @@ it with as little functions as possible, so don't expect many additions: ```go // SQLProvider describes the public behavior of this ORM type SQLProvider interface { - Insert(ctx context.Context, record interface{}) error - Update(ctx context.Context, record interface{}) error - Delete(ctx context.Context, idsOrRecords ...interface{}) error + Insert(ctx context.Context, table Table, record interface{}) error + Update(ctx context.Context, table Table, record interface{}) error + Delete(ctx context.Context, table Table, idsOrRecords ...interface{}) error Query(ctx context.Context, records interface{}, query string, params ...interface{}) error QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error @@ -128,11 +128,14 @@ type Address struct { City string `json:"city"` } +// UsersTable informs ksql the name of the table and that it can +// use the default value for the primary key column name: "id" +var UsersTable = ksql.NewTable("users") + func main() { ctx := context.Background() db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { panic(err.Error()) @@ -157,14 +160,14 @@ func main() { State: "MG", }, } - err = db.Insert(ctx, &alison) + err = db.Insert(ctx, UsersTable, &alison) if err != nil { panic(err.Error()) } fmt.Println("Alison ID:", alison.ID) // Inserting inline: - err = db.Insert(ctx, &User{ + err = db.Insert(ctx, UsersTable, &User{ Name: "Cristina", Age: 27, Address: Address{ @@ -176,7 +179,7 @@ func main() { } // Deleting Alison: - err = db.Delete(ctx, alison.ID) + err = db.Delete(ctx, UsersTable, alison.ID) if err != nil { panic(err.Error()) } @@ -191,12 +194,12 @@ func main() { // Updating all fields from Cristina: cris.Name = "Cris" - err = db.Update(ctx, cris) + err = db.Update(ctx, UsersTable, cris) // Changing the age of Cristina but not touching any other fields: // Partial update technique 1: - err = db.Update(ctx, struct { + err = db.Update(ctx, UsersTable, struct { ID int `ksql:"id"` Age int `ksql:"age"` }{ID: cris.ID, Age: 28}) @@ -205,7 +208,7 @@ func main() { } // Partial update technique 2: - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris.ID, Age: nullable.Int(28), }) @@ -214,14 +217,14 @@ func main() { } // Listing first 10 users from the database - // (each time you run this example a new "Cristina" is created) + // (each time you run this example a new Cristina is created) // // Note: Using this function it is recommended to set a LIMIT, since - // not doing so can load too many users on your computer's memory - // causing an "Out Of Memory Kill". + // not doing so can load too many users on your computer's memory or + // cause an Out Of Memory Kill. // // If you need to query very big numbers of users we recommend using - // the `QueryChunks` function instead. + // the `QueryChunks` function. var users []User err = db.Query(ctx, &users, "SELECT * FROM users LIMIT 10") if err != nil { @@ -237,7 +240,7 @@ func main() { return err } - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris2.ID, Age: nullable.Int(29), }) @@ -531,3 +534,4 @@ make test - Consider passing the cached structInfo as argument for all the functions that use it, so that we don't need to get it more than once in the same call. - Use a cache to store all queries after they are built +- Preload the insert method for all dialects inside `ksql.NewTable()` From 34a07b75b563b3928b57401fb0145c491da9a918 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 6 Jun 2021 20:58:12 -0300 Subject: [PATCH 23/40] Run go mod tidy --- go.sum | 3 --- 1 file changed, 3 deletions(-) diff --git a/go.sum b/go.sum index 67d781f..b80fabb 100644 --- a/go.sum +++ b/go.sum @@ -35,7 +35,6 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -46,10 +45,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e h1:aZzprAO9/8oim3qStq3wc1Xuxx4QmAGriC4VU4ojemQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= From 75330a12c5dac0d314983009d71cbdcef201d04f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 6 Jun 2021 21:11:05 -0300 Subject: [PATCH 24/40] Update benchmark on README --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 9cb5b28..ef941d5 100644 --- a/README.md +++ b/README.md @@ -487,14 +487,14 @@ goos: linux goarch: amd64 pkg: github.com/vingarcia/ksql cpu: Intel(R) Core(TM) i5-3210M CPU @ 2.50GHz -BenchmarkInsert/ksql-setup/insert-one-4 4302 776648 ns/op -BenchmarkInsert/sqlx-setup/insert-one-4 4716 762358 ns/op -BenchmarkQuery/ksql-setup/single-row-4 12204 293858 ns/op -BenchmarkQuery/ksql-setup/multiple-rows-4 11145 323571 ns/op -BenchmarkQuery/sqlx-setup/single-row-4 12440 290937 ns/op -BenchmarkQuery/sqlx-setup/multiple-rows-4 10000 310314 ns/op +BenchmarkInsert/ksql-setup/insert-one-4 4970 727724 ns/op +BenchmarkInsert/sqlx-setup/insert-one-4 4842 703503 ns/op +BenchmarkQuery/ksql-setup/single-row-4 12692 282544 ns/op +BenchmarkQuery/ksql-setup/multiple-rows-4 10000 313662 ns/op +BenchmarkQuery/sqlx-setup/single-row-4 12328 291965 ns/op +BenchmarkQuery/sqlx-setup/multiple-rows-4 10000 301910 ns/op PASS -ok github.com/vingarcia/ksql 34.251s +ok github.com/vingarcia/ksql 39.995s ``` ### Running the ksql tests (for contributors) From 5b9b0dd00df562762aa09fac8635ad31755a2ea6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Thu, 10 Jun 2021 15:57:46 -0300 Subject: [PATCH 25/40] Add CallFunctionWithRows() for help testing calls to QueryChunks --- .../example_service/example_service_test.go | 25 +++++++-------- structs/structs.go | 2 +- test_helpers.go | 31 +++++++++++++++++++ 3 files changed, 43 insertions(+), 15 deletions(-) create mode 100644 test_helpers.go diff --git a/examples/example_service/example_service_test.go b/examples/example_service/example_service_test.go index 33b2f2c..1cb465c 100644 --- a/examples/example_service/example_service_test.go +++ b/examples/example_service/example_service_test.go @@ -5,7 +5,6 @@ import ( "testing" gomock "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" "github.com/tj/assert" "github.com/vingarcia/ksql" "github.com/vingarcia/ksql/nullable" @@ -198,19 +197,17 @@ func TestStreamAllUsers(t *testing.T) { mockDB.EXPECT().QueryChunks(gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, parser ksql.ChunkParser) error { - fn, ok := parser.ForEachChunk.(func(users []UserEntity) error) - require.True(t, ok) // Chunk 1: - err := fn([]UserEntity{ + err := ksql.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ { - ID: 1, - Name: nullable.String("fake name"), - Age: nullable.Int(42), + "id": 1, + "name": "fake name", + "age": 42, }, { - ID: 2, - Name: nullable.String("another fake name"), - Age: nullable.Int(43), + "id": 2, + "name": "another fake name", + "age": 43, }, }) if err != nil { @@ -218,11 +215,11 @@ func TestStreamAllUsers(t *testing.T) { } // Chunk 2: - err = fn([]UserEntity{ + err = ksql.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ { - ID: 3, - Name: nullable.String("yet another fake name"), - Age: nullable.Int(44), + "id": 3, + "name": "yet another fake name", + "age": 44, }, }) return err diff --git a/structs/structs.go b/structs/structs.go index 3106d62..7690555 100644 --- a/structs/structs.go +++ b/structs/structs.go @@ -262,7 +262,7 @@ func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error sliceType := sliceRef.Type() if sliceType.Kind() != reflect.Ptr { return fmt.Errorf( - "FillSliceWith: expected input to be a pointer to struct but got %v", + "FillSliceWith: expected input to be a pointer to a slice of structs but got %v", sliceType, ) } diff --git a/test_helpers.go b/test_helpers.go new file mode 100644 index 0000000..65fff68 --- /dev/null +++ b/test_helpers.go @@ -0,0 +1,31 @@ +package ksql + +import ( + "reflect" + + "github.com/vingarcia/ksql/structs" +) + +// CallFunctionWithRows was created for helping test the QueryChunks method +func CallFunctionWithRows(fn interface{}, rows []map[string]interface{}) error { + fnValue := reflect.ValueOf(fn) + chunkType, err := parseInputFunc(fn) + if err != nil { + return err + } + + chunk := reflect.MakeSlice(chunkType, 0, len(rows)) + + // Create a pointer to a slice (required by FillSliceWith) + chunkPtr := reflect.New(chunkType) + chunkPtr.Elem().Set(chunk) + + err = structs.FillSliceWith(chunkPtr.Interface(), rows) + if err != nil { + return err + } + + err, _ = fnValue.Call([]reflect.Value{chunkPtr.Elem()})[0].Interface().(error) + + return err +} From 20f49eb22b99f2d9640bfdb72b6b8b84d54d24c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Fri, 11 Jun 2021 12:25:24 -0300 Subject: [PATCH 26/40] Reorganize files so the test helpers are grouped in the same pkg --- README.md | 7 +- .../example_service/example_service_test.go | 4 +- ksql.go | 34 +---- structs/func_parser.go | 40 ++++++ structs/structs.go | 95 ------------- structs/testhelpers.go | 125 ++++++++++++++++++ test_helpers.go | 31 ----- 7 files changed, 172 insertions(+), 164 deletions(-) create mode 100644 structs/func_parser.go create mode 100644 structs/testhelpers.go delete mode 100644 test_helpers.go diff --git a/README.md b/README.md index ef941d5..05430d2 100644 --- a/README.md +++ b/README.md @@ -469,9 +469,10 @@ that we actually care about, so it's better not to use composite structs. This library has a few helper functions for helping your tests: -- `ksql.FillStructWith(struct interface{}, dbRow map[string]interface{}) error` -- `ksql.FillSliceWith(structSlice interface{}, dbRows []map[string]interface{}) error` -- `ksql.StructToMap(struct interface{}) (map[string]interface{}, error)` +- `structs.FillStructWith(struct interface{}, dbRow map[string]interface{}) error` +- `structs.FillSliceWith(structSlice interface{}, dbRows []map[string]interface{}) error` +- `structs.StructToMap(struct interface{}) (map[string]interface{}, error)` +- `structs.CallFunctionWithRows(fn interface{}, rows []map[string]interface{}) (map[string]interface{}, error)` If you want to see examples (we have examples for all the public functions) just read the example tests available on our [example service](./examples/example_service) diff --git a/examples/example_service/example_service_test.go b/examples/example_service/example_service_test.go index 1cb465c..983ee88 100644 --- a/examples/example_service/example_service_test.go +++ b/examples/example_service/example_service_test.go @@ -198,7 +198,7 @@ func TestStreamAllUsers(t *testing.T) { mockDB.EXPECT().QueryChunks(gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, parser ksql.ChunkParser) error { // Chunk 1: - err := ksql.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ + err := structs.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ { "id": 1, "name": "fake name", @@ -215,7 +215,7 @@ func TestStreamAllUsers(t *testing.T) { } // Chunk 2: - err = ksql.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ + err = structs.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ { "id": 3, "name": "yet another fake name", diff --git a/ksql.go b/ksql.go index 85e6e2a..0665449 100644 --- a/ksql.go +++ b/ksql.go @@ -244,7 +244,7 @@ func (c DB) QueryChunks( parser ChunkParser, ) error { fnValue := reflect.ValueOf(parser.ForEachChunk) - chunkType, err := parseInputFunc(parser.ForEachChunk) + chunkType, err := structs.ParseInputFunc(parser.ForEachChunk) if err != nil { return err } @@ -759,38 +759,6 @@ func (c DB) Transaction(ctx context.Context, fn func(SQLProvider) error) error { } } -var errType = reflect.TypeOf(new(error)).Elem() - -func parseInputFunc(fn interface{}) (reflect.Type, error) { - if fn == nil { - return nil, fmt.Errorf("the ForEachChunk attribute is required and cannot be nil") - } - - t := reflect.TypeOf(fn) - - if t.Kind() != reflect.Func { - return nil, fmt.Errorf("the ForEachChunk callback must be a function") - } - if t.NumIn() != 1 { - return nil, fmt.Errorf("the ForEachChunk callback must have 1 argument") - } - - if t.NumOut() != 1 { - return nil, fmt.Errorf("the ForEachChunk callback must have a single return value") - } - - if t.Out(0) != errType { - return nil, fmt.Errorf("the return value of the ForEachChunk callback must be of type error") - } - - argsType := t.In(0) - if argsType.Kind() != reflect.Slice { - return nil, fmt.Errorf("the argument of the ForEachChunk callback must a slice of structs") - } - - return argsType, nil -} - type nopScanner struct{} var nopScannerValue = reflect.ValueOf(&nopScanner{}).Interface() diff --git a/structs/func_parser.go b/structs/func_parser.go new file mode 100644 index 0000000..d68db69 --- /dev/null +++ b/structs/func_parser.go @@ -0,0 +1,40 @@ +package structs + +import ( + "fmt" + "reflect" +) + +var errType = reflect.TypeOf(new(error)).Elem() + +// ParseInputFunc is used exclusively for parsing +// the ForEachChunk function used on the QueryChunks method. +func ParseInputFunc(fn interface{}) (reflect.Type, error) { + if fn == nil { + return nil, fmt.Errorf("the ForEachChunk attribute is required and cannot be nil") + } + + t := reflect.TypeOf(fn) + + if t.Kind() != reflect.Func { + return nil, fmt.Errorf("the ForEachChunk callback must be a function") + } + if t.NumIn() != 1 { + return nil, fmt.Errorf("the ForEachChunk callback must have 1 argument") + } + + if t.NumOut() != 1 { + return nil, fmt.Errorf("the ForEachChunk callback must have a single return value") + } + + if t.Out(0) != errType { + return nil, fmt.Errorf("the return value of the ForEachChunk callback must be of type error") + } + + argsType := t.In(0) + if argsType.Kind() != reflect.Slice { + return nil, fmt.Errorf("the argument of the ForEachChunk callback must a slice of structs") + } + + return argsType, nil +} diff --git a/structs/structs.go b/structs/structs.go index 7690555..a77b3a0 100644 --- a/structs/structs.go +++ b/structs/structs.go @@ -4,8 +4,6 @@ import ( "fmt" "reflect" "strings" - - "github.com/pkg/errors" ) // StructInfo stores metainformation of the struct @@ -120,56 +118,6 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) { return m, nil } -// FillStructWith is meant to be used on unit tests to mock -// the response from the database. -// -// The first argument is any struct you are passing to a ksql func, -// and the second is a map representing a database row you want -// to use to update this struct. -func FillStructWith(record interface{}, dbRow map[string]interface{}) error { - v := reflect.ValueOf(record) - t := v.Type() - - if t.Kind() != reflect.Ptr { - return fmt.Errorf( - "FillStructWith: expected input to be a pointer to struct but got %T", - record, - ) - } - - t = t.Elem() - v = v.Elem() - - if t.Kind() != reflect.Struct { - return fmt.Errorf( - "FillStructWith: expected input kind to be a struct but got %T", - record, - ) - } - - info := getCachedTagInfo(tagInfoCache, t) - for colName, rawSrc := range dbRow { - fieldInfo := info.ByName(colName) - if !fieldInfo.Valid { - // Ignore columns not tagged with `ksql:"..."` - continue - } - - src := NewPtrConverter(rawSrc) - dest := v.Field(fieldInfo.Index) - destType := t.Field(fieldInfo.Index).Type - - destValue, err := src.Convert(destType) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("FillStructWith: error on field `%s`", colName)) - } - - dest.Set(destValue) - } - - return nil -} - // PtrConverter was created to make it easier // to handle conversion between ptr and non ptr types, e.g.: // @@ -251,49 +199,6 @@ func (p PtrConverter) Convert(destType reflect.Type) (reflect.Value, error) { return destValue, nil } -// FillSliceWith is meant to be used on unit tests to mock -// the response from the database. -// -// The first argument is any slice of structs you are passing to a ksql func, -// and the second is a slice of maps representing the database rows you want -// to use to update this struct. -func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error { - sliceRef := reflect.ValueOf(entities) - sliceType := sliceRef.Type() - if sliceType.Kind() != reflect.Ptr { - return fmt.Errorf( - "FillSliceWith: expected input to be a pointer to a slice of structs but got %v", - sliceType, - ) - } - - structType, isSliceOfPtrs, err := DecodeAsSliceOfStructs(sliceType.Elem()) - if err != nil { - return errors.Wrap(err, "FillSliceWith") - } - - slice := sliceRef.Elem() - for idx, row := range dbRows { - if slice.Len() <= idx { - var elemValue reflect.Value - elemValue = reflect.New(structType) - if !isSliceOfPtrs { - elemValue = elemValue.Elem() - } - slice = reflect.Append(slice, elemValue) - } - - err := FillStructWith(slice.Index(idx).Addr().Interface(), row) - if err != nil { - return errors.Wrap(err, "FillSliceWith") - } - } - - sliceRef.Elem().Set(slice) - - return nil -} - // This function collects only the names // that will be used from the input type. // diff --git a/structs/testhelpers.go b/structs/testhelpers.go new file mode 100644 index 0000000..c02c40d --- /dev/null +++ b/structs/testhelpers.go @@ -0,0 +1,125 @@ +package structs + +import ( + "fmt" + "reflect" + + "github.com/pkg/errors" +) + +// FillStructWith is meant to be used on unit tests to mock +// the response from the database. +// +// The first argument is any struct you are passing to a ksql func, +// and the second is a map representing a database row you want +// to use to update this struct. +func FillStructWith(record interface{}, dbRow map[string]interface{}) error { + v := reflect.ValueOf(record) + t := v.Type() + + if t.Kind() != reflect.Ptr { + return fmt.Errorf( + "FillStructWith: expected input to be a pointer to struct but got %T", + record, + ) + } + + t = t.Elem() + v = v.Elem() + + if t.Kind() != reflect.Struct { + return fmt.Errorf( + "FillStructWith: expected input kind to be a struct but got %T", + record, + ) + } + + info := GetTagInfo(t) + for colName, rawSrc := range dbRow { + fieldInfo := info.ByName(colName) + if !fieldInfo.Valid { + // Ignore columns not tagged with `ksql:"..."` + continue + } + + src := NewPtrConverter(rawSrc) + dest := v.Field(fieldInfo.Index) + destType := t.Field(fieldInfo.Index).Type + + destValue, err := src.Convert(destType) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("FillStructWith: error on field `%s`", colName)) + } + + dest.Set(destValue) + } + + return nil +} + +// FillSliceWith is meant to be used on unit tests to mock +// the response from the database. +// +// The first argument is any slice of structs you are passing to a ksql func, +// and the second is a slice of maps representing the database rows you want +// to use to update this struct. +func FillSliceWith(entities interface{}, dbRows []map[string]interface{}) error { + sliceRef := reflect.ValueOf(entities) + sliceType := sliceRef.Type() + if sliceType.Kind() != reflect.Ptr { + return fmt.Errorf( + "FillSliceWith: expected input to be a pointer to a slice of structs but got %v", + sliceType, + ) + } + + structType, isSliceOfPtrs, err := DecodeAsSliceOfStructs(sliceType.Elem()) + if err != nil { + return errors.Wrap(err, "FillSliceWith") + } + + slice := sliceRef.Elem() + for idx, row := range dbRows { + if slice.Len() <= idx { + var elemValue reflect.Value + elemValue = reflect.New(structType) + if !isSliceOfPtrs { + elemValue = elemValue.Elem() + } + slice = reflect.Append(slice, elemValue) + } + + err := FillStructWith(slice.Index(idx).Addr().Interface(), row) + if err != nil { + return errors.Wrap(err, "FillSliceWith") + } + } + + sliceRef.Elem().Set(slice) + + return nil +} + +// CallFunctionWithRows was created for helping test the QueryChunks method +func CallFunctionWithRows(fn interface{}, rows []map[string]interface{}) error { + fnValue := reflect.ValueOf(fn) + chunkType, err := ParseInputFunc(fn) + if err != nil { + return err + } + + chunk := reflect.MakeSlice(chunkType, 0, len(rows)) + + // Create a pointer to a slice (required by FillSliceWith) + chunkPtr := reflect.New(chunkType) + chunkPtr.Elem().Set(chunk) + + err = FillSliceWith(chunkPtr.Interface(), rows) + if err != nil { + return err + } + + err, _ = fnValue.Call([]reflect.Value{chunkPtr.Elem()})[0].Interface().(error) + + return err +} diff --git a/test_helpers.go b/test_helpers.go deleted file mode 100644 index 65fff68..0000000 --- a/test_helpers.go +++ /dev/null @@ -1,31 +0,0 @@ -package ksql - -import ( - "reflect" - - "github.com/vingarcia/ksql/structs" -) - -// CallFunctionWithRows was created for helping test the QueryChunks method -func CallFunctionWithRows(fn interface{}, rows []map[string]interface{}) error { - fnValue := reflect.ValueOf(fn) - chunkType, err := parseInputFunc(fn) - if err != nil { - return err - } - - chunk := reflect.MakeSlice(chunkType, 0, len(rows)) - - // Create a pointer to a slice (required by FillSliceWith) - chunkPtr := reflect.New(chunkType) - chunkPtr.Elem().Set(chunk) - - err = structs.FillSliceWith(chunkPtr.Interface(), rows) - if err != nil { - return err - } - - err, _ = fnValue.Call([]reflect.Value{chunkPtr.Elem()})[0].Interface().(error) - - return err -} From 5b351c8ba236cf5c9fe299b46159578bb4329489 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Fri, 11 Jun 2021 12:53:47 -0300 Subject: [PATCH 27/40] Rename package structs to kstructs so its unambiguous --- README.md | 12 ++--- .../example_service/example_service_test.go | 14 +++--- ksql.go | 44 +++++++++---------- {structs => kstructs}/func_parser.go | 2 +- {structs => kstructs}/structs.go | 2 +- {structs => kstructs}/structs_test.go | 2 +- {structs => kstructs}/testhelpers.go | 2 +- 7 files changed, 39 insertions(+), 39 deletions(-) rename {structs => kstructs}/func_parser.go (98%) rename {structs => kstructs}/structs.go (99%) rename {structs => kstructs}/structs_test.go (99%) rename {structs => kstructs}/testhelpers.go (99%) diff --git a/README.md b/README.md index 05430d2..0762c52 100644 --- a/README.md +++ b/README.md @@ -463,16 +463,16 @@ if err != nil { In the example above, since we are only interested in a couple of columns it is far simpler and more efficient for the database to only select the columns -that we actually care about, so it's better not to use composite structs. +that we actually care about, so it's better not to use composite kstructs. ### Testing Examples This library has a few helper functions for helping your tests: -- `structs.FillStructWith(struct interface{}, dbRow map[string]interface{}) error` -- `structs.FillSliceWith(structSlice interface{}, dbRows []map[string]interface{}) error` -- `structs.StructToMap(struct interface{}) (map[string]interface{}, error)` -- `structs.CallFunctionWithRows(fn interface{}, rows []map[string]interface{}) (map[string]interface{}, error)` +- `kstructs.FillStructWith(struct interface{}, dbRow map[string]interface{}) error` +- `kstructs.FillSliceWith(structSlice interface{}, dbRows []map[string]interface{}) error` +- `kstructs.StructToMap(struct interface{}) (map[string]interface{}, error)` +- `kstructs.CallFunctionWithRows(fn interface{}, rows []map[string]interface{}) (map[string]interface{}, error)` If you want to see examples (we have examples for all the public functions) just read the example tests available on our [example service](./examples/example_service) @@ -524,7 +524,7 @@ make test - Improve error messages - Add tests for tables using composite keys - Add support for serializing structs as other formats such as YAML -- Update `structs.FillStructWith` to work with `json` tagged attributes +- Update `kstructs.FillStructWith` to work with `json` tagged attributes - Make testing easier by exposing the connection strings in an .env file - Make testing easier by automatically creating the `ksql` database - Create a way for users to submit user defined dialects diff --git a/examples/example_service/example_service_test.go b/examples/example_service/example_service_test.go index 983ee88..007f1c0 100644 --- a/examples/example_service/example_service_test.go +++ b/examples/example_service/example_service_test.go @@ -8,7 +8,7 @@ import ( "github.com/tj/assert" "github.com/vingarcia/ksql" "github.com/vingarcia/ksql/nullable" - "github.com/vingarcia/ksql/structs" + "github.com/vingarcia/ksql/kstructs" ) func TestCreateUser(t *testing.T) { @@ -58,7 +58,7 @@ func TestCreateUser(t *testing.T) { // // If you are inserting an anonymous struct (not usual) this function // can make your tests shorter: - uMap, err := structs.StructToMap(record) + uMap, err := kstructs.StructToMap(record) if err != nil { return err } @@ -95,7 +95,7 @@ func TestUpdateUserScore(t *testing.T) { DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { // This function will use reflection to fill the // struct fields with the values from the map - return structs.FillStructWith(result, map[string]interface{}{ + return kstructs.FillStructWith(result, map[string]interface{}{ // Use int this map the keys you set on the ksql tags, e.g. `ksql:"score"` // Each of these fields represent the database rows returned // by the query. @@ -138,7 +138,7 @@ func TestListUsers(t *testing.T) { DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { // This function will use reflection to fill the // struct fields with the values from the map - return structs.FillStructWith(result, map[string]interface{}{ + return kstructs.FillStructWith(result, map[string]interface{}{ // Use int this map the keys you set on the ksql tags, e.g. `ksql:"score"` // Each of these fields represent the database rows returned // by the query. @@ -147,7 +147,7 @@ func TestListUsers(t *testing.T) { }), mockDB.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, results interface{}, query string, params ...interface{}) error { - return structs.FillSliceWith(results, []map[string]interface{}{ + return kstructs.FillSliceWith(results, []map[string]interface{}{ { "id": 1, "name": "fake name", @@ -198,7 +198,7 @@ func TestStreamAllUsers(t *testing.T) { mockDB.EXPECT().QueryChunks(gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, parser ksql.ChunkParser) error { // Chunk 1: - err := structs.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ + err := kstructs.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ { "id": 1, "name": "fake name", @@ -215,7 +215,7 @@ func TestStreamAllUsers(t *testing.T) { } // Chunk 2: - err = structs.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ + err = kstructs.CallFunctionWithRows(parser.ForEachChunk, []map[string]interface{}{ { "id": 3, "name": "yet another fake name", diff --git a/ksql.go b/ksql.go index 0665449..1f9195f 100644 --- a/ksql.go +++ b/ksql.go @@ -9,7 +9,7 @@ import ( "unicode" "github.com/pkg/errors" - "github.com/vingarcia/ksql/structs" + "github.com/vingarcia/ksql/kstructs" ) var selectQueryCache = map[string]map[reflect.Type]string{} @@ -93,7 +93,7 @@ func (c DB) Query( } sliceType := slicePtrType.Elem() slice := slicePtr.Elem() - structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(sliceType) + structType, isSliceOfPtrs, err := kstructs.DecodeAsSliceOfStructs(sliceType) if err != nil { return err } @@ -105,7 +105,7 @@ func (c DB) Query( slice = slice.Slice(0, 0) } - info := structs.GetTagInfo(structType) + info := kstructs.GetTagInfo(structType) firstToken := strings.ToUpper(getFirstToken(query)) if info.IsNestedStruct && firstToken == "SELECT" { @@ -186,7 +186,7 @@ func (c DB) QueryOne( return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) } - info := structs.GetTagInfo(t) + info := kstructs.GetTagInfo(t) firstToken := strings.ToUpper(getFirstToken(query)) if info.IsNestedStruct && firstToken == "SELECT" { @@ -244,19 +244,19 @@ func (c DB) QueryChunks( parser ChunkParser, ) error { fnValue := reflect.ValueOf(parser.ForEachChunk) - chunkType, err := structs.ParseInputFunc(parser.ForEachChunk) + chunkType, err := kstructs.ParseInputFunc(parser.ForEachChunk) if err != nil { return err } chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize) - structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(chunkType) + structType, isSliceOfPtrs, err := kstructs.DecodeAsSliceOfStructs(chunkType) if err != nil { return err } - info := structs.GetTagInfo(structType) + info := kstructs.GetTagInfo(structType) firstToken := strings.ToUpper(getFirstToken(parser.Query)) if info.IsNestedStruct && firstToken == "SELECT" { @@ -416,7 +416,7 @@ func (c DB) insertWithLastInsertID( return errors.Wrap(err, "can't write to `"+idName+"` field") } - info := structs.GetTagInfo(t.Elem()) + info := kstructs.GetTagInfo(t.Elem()) id, err := result.LastInsertId() if err != nil { @@ -499,7 +499,7 @@ func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]inter t := reflect.TypeOf(ids[i]) switch t.Kind() { case reflect.Struct: - m, err := structs.StructToMap(ids[i]) + m, err := kstructs.StructToMap(ids[i]) if err != nil { return nil, errors.Wrapf(err, "could not get ID(s) from record on idx %d", i) } @@ -561,9 +561,9 @@ func buildInsertQuery( ) } - info := structs.GetTagInfo(t.Elem()) + info := kstructs.GetTagInfo(t.Elem()) - recordMap, err := structs.StructToMap(record) + recordMap, err := kstructs.StructToMap(record) if err != nil { return "", nil, nil, err } @@ -651,7 +651,7 @@ func buildUpdateQuery( record interface{}, idFieldNames ...string, ) (query string, args []interface{}, err error) { - recordMap, err := structs.StructToMap(record) + recordMap, err := kstructs.StructToMap(record) if err != nil { return "", nil, err } @@ -681,7 +681,7 @@ func buildUpdateQuery( if t.Kind() == reflect.Ptr { t = t.Elem() } - info := structs.GetTagInfo(t) + info := kstructs.GetTagInfo(t) var setQuery []string for i, k := range keys { @@ -781,13 +781,13 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record) } - info := structs.GetTagInfo(t) + info := kstructs.GetTagInfo(t) var scanArgs []interface{} if info.IsNestedStruct { // This version is positional meaning that it expect the arguments // to follow an specific order. It's ok because we don't allow the - // user to type the "SELECT" part of the query for nested structs. + // user to type the "SELECT" part of the query for nested kstructs. scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info) } else { names, err := rows.Columns() @@ -802,11 +802,11 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { return rows.Scan(scanArgs...) } -func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) []interface{} { +func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} { scanArgs := []interface{}{} for i := 0; i < v.NumField(); i++ { // TODO(vingarcia00): Handle case where type is pointer - nestedStructInfo := structs.GetTagInfo(t.Field(i).Type) + nestedStructInfo := kstructs.GetTagInfo(t.Field(i).Type) nestedStructValue := v.Field(i) for j := 0; j < nestedStructValue.NumField(); j++ { fieldInfo := nestedStructInfo.ByIndex(j) @@ -829,7 +829,7 @@ func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type return scanArgs } -func getScanArgsFromNames(dialect dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} { +func getScanArgsFromNames(dialect dialect, names []string, v reflect.Value, info kstructs.StructInfo) []interface{} { scanArgs := []interface{}{} for _, name := range names { fieldInfo := info.ByName(name) @@ -919,7 +919,7 @@ func getFirstToken(s string) string { func buildSelectQuery( dialect dialect, structType reflect.Type, - info structs.StructInfo, + info kstructs.StructInfo, selectQueryCache map[reflect.Type]string, ) (query string, err error) { if selectQuery, found := selectQueryCache[structType]; found { @@ -942,7 +942,7 @@ func buildSelectQuery( func buildSelectQueryForPlainStructs( dialect dialect, structType reflect.Type, - info structs.StructInfo, + info kstructs.StructInfo, ) string { var fields []string for i := 0; i < structType.NumField(); i++ { @@ -955,7 +955,7 @@ func buildSelectQueryForPlainStructs( func buildSelectQueryForNestedStructs( dialect dialect, structType reflect.Type, - info structs.StructInfo, + info kstructs.StructInfo, ) (string, error) { var fields []string for i := 0; i < structType.NumField(); i++ { @@ -968,7 +968,7 @@ func buildSelectQueryForNestedStructs( ) } - nestedStructInfo := structs.GetTagInfo(nestedStructType) + nestedStructInfo := kstructs.GetTagInfo(nestedStructType) for j := 0; j < structType.Field(i).Type.NumField(); j++ { fields = append( fields, diff --git a/structs/func_parser.go b/kstructs/func_parser.go similarity index 98% rename from structs/func_parser.go rename to kstructs/func_parser.go index d68db69..56550f2 100644 --- a/structs/func_parser.go +++ b/kstructs/func_parser.go @@ -1,4 +1,4 @@ -package structs +package kstructs import ( "fmt" diff --git a/structs/structs.go b/kstructs/structs.go similarity index 99% rename from structs/structs.go rename to kstructs/structs.go index a77b3a0..0232862 100644 --- a/structs/structs.go +++ b/kstructs/structs.go @@ -1,4 +1,4 @@ -package structs +package kstructs import ( "fmt" diff --git a/structs/structs_test.go b/kstructs/structs_test.go similarity index 99% rename from structs/structs_test.go rename to kstructs/structs_test.go index c11c3c1..3887be2 100644 --- a/structs/structs_test.go +++ b/kstructs/structs_test.go @@ -1,4 +1,4 @@ -package structs +package kstructs import ( "testing" diff --git a/structs/testhelpers.go b/kstructs/testhelpers.go similarity index 99% rename from structs/testhelpers.go rename to kstructs/testhelpers.go index c02c40d..caa7bb7 100644 --- a/structs/testhelpers.go +++ b/kstructs/testhelpers.go @@ -1,4 +1,4 @@ -package structs +package kstructs import ( "fmt" From 2a38ae39982e34de4400c8790c45885806505095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 13 Jun 2021 15:11:17 -0300 Subject: [PATCH 28/40] Improve Update method to return ErrRecordNotFound if no rows were updated --- ksql.go | 18 ++++++++++++++++-- ksql_test.go | 14 ++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/ksql.go b/ksql.go index 1f9195f..b359b36 100644 --- a/ksql.go +++ b/ksql.go @@ -541,9 +541,23 @@ func (c DB) Update( return err } - _, err = c.db.ExecContext(ctx, query, params...) + result, err := c.db.ExecContext(ctx, query, params...) + if err != nil { + return err + } - return err + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf( + "unexpected error: unable to fetch how many rows were affected by the update: %s", + err, + ) + } + if n < 1 { + return ErrRecordNotFound + } + + return nil } func buildInsertQuery( diff --git a/ksql_test.go b/ksql_test.go index 3eea91a..eccec01 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -1026,6 +1026,20 @@ func TestUpdate(t *testing.T) { assert.Equal(t, 42, result.Age) }) + t.Run("should return ErrRecordNotFound when asked to update an inexistent user", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + err = c.Update(ctx, UsersTable, User{ + ID: 4200, + Name: "Thayane", + }) + assert.Equal(t, ErrRecordNotFound, err) + }) + t.Run("should report database errors correctly", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() From 682f99b495ddfd51edefad976d92b2be6464c7e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Thu, 24 Jun 2021 21:43:21 -0300 Subject: [PATCH 29/40] Improve postgres container in docker-compose.yml Now the postgres container will create the ksql database automatically. --- docker-compose.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/docker-compose.yml b/docker-compose.yml index 6e70b12..715b450 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,7 @@ services: environment: - POSTGRES_USER=postgres - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=${DB_NAME:-ksql} mysql: image: mysql From b6e6667a3fe0f1285f4c39eb4fdae68811a2ee23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Mon, 28 Jun 2021 17:35:46 -0300 Subject: [PATCH 30/40] Improve the names of some public types *breaking change* --- README.md | 16 ++--- contracts.go | 6 +- examples/crud/crud.go | 2 +- examples/example_service/example_service.go | 4 +- .../example_service/example_service_test.go | 14 ++-- examples/example_service/mocks.go | 70 +++++++++---------- ksql.go | 23 ++++-- ksql_test.go | 10 +-- mocks.go | 24 +++---- 9 files changed, 92 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index 0762c52..ebdcf7a 100644 --- a/README.md +++ b/README.md @@ -68,8 +68,8 @@ The current interface is as follows and we plan on keeping it with as little functions as possible, so don't expect many additions: ```go -// SQLProvider describes the public behavior of this ORM -type SQLProvider interface { +// Provider describes the public behavior of this ORM +type Provider interface { Insert(ctx context.Context, table Table, record interface{}) error Update(ctx context.Context, table Table, record interface{}) error Delete(ctx context.Context, table Table, idsOrRecords ...interface{}) error @@ -79,7 +79,7 @@ type SQLProvider interface { QueryChunks(ctx context.Context, parser ChunkParser) error Exec(ctx context.Context, query string, params ...interface{}) error - Transaction(ctx context.Context, fn func(SQLProvider) error) error + Transaction(ctx context.Context, fn func(Provider) error) error } ``` @@ -232,7 +232,7 @@ func main() { } // Making transactions: - err = db.Transaction(ctx, func(db ksql.SQLProvider) error { + err = db.Transaction(ctx, func(db ksql.Provider) error { var cris2 User err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = ?", cris.ID) if err != nil { @@ -372,10 +372,10 @@ Querying a single joined row: ```golang var row struct{ - User User `tablename:"u"` // (here the tablename must match the aliased tablename in the query) - Post Post `tablename:"p"` // (if no alias is used you should use the actual name of the table) + User User `tablename:"u"` // (here the tablename must match the aliased tablename in the query) + Post Post `tablename:"posts"` // (if no alias is used you should use the actual name of the table) } -err = db.QueryOne(ctx, &row, "FROM users as u JOIN posts as p ON u.id = p.user_id WHERE u.id = ?", userID) +err = db.QueryOne(ctx, &row, "FROM users as u JOIN posts ON u.id = posts.user_id WHERE u.id = ?", userID) if err != nil { panic(err.Error()) } @@ -521,13 +521,13 @@ make test ### TODO List -- Improve error messages - Add tests for tables using composite keys - Add support for serializing structs as other formats such as YAML - Update `kstructs.FillStructWith` to work with `json` tagged attributes - Make testing easier by exposing the connection strings in an .env file - Make testing easier by automatically creating the `ksql` database - Create a way for users to submit user defined dialects +- Improve error messages ### Optimization Oportunities diff --git a/contracts.go b/contracts.go index 26866de..8a9866f 100644 --- a/contracts.go +++ b/contracts.go @@ -14,8 +14,8 @@ var ErrRecordNotFound error = errors.Wrap(sql.ErrNoRows, "ksql: the query return // ErrAbortIteration ... var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be used inside QueryChunks function") -// SQLProvider describes the public behavior of this ORM -type SQLProvider interface { +// Provider describes the public behavior of this ORM +type Provider interface { Insert(ctx context.Context, table Table, record interface{}) error Update(ctx context.Context, table Table, record interface{}) error Delete(ctx context.Context, table Table, idsOrRecords ...interface{}) error @@ -25,7 +25,7 @@ type SQLProvider interface { QueryChunks(ctx context.Context, parser ChunkParser) error Exec(ctx context.Context, query string, params ...interface{}) error - Transaction(ctx context.Context, fn func(SQLProvider) error) error + Transaction(ctx context.Context, fn func(Provider) error) error } // Table describes the required information for inserting, updating and diff --git a/examples/crud/crud.go b/examples/crud/crud.go index 20e78b0..6ef5df5 100644 --- a/examples/crud/crud.go +++ b/examples/crud/crud.go @@ -137,7 +137,7 @@ func main() { } // Making transactions: - err = db.Transaction(ctx, func(db ksql.SQLProvider) error { + err = db.Transaction(ctx, func(db ksql.Provider) error { var cris2 User err = db.QueryOne(ctx, &cris2, "SELECT * FROM users WHERE id = ?", cris.ID) if err != nil { diff --git a/examples/example_service/example_service.go b/examples/example_service/example_service.go index 94ca156..bbee499 100644 --- a/examples/example_service/example_service.go +++ b/examples/example_service/example_service.go @@ -40,12 +40,12 @@ type Address struct { // Service ... type Service struct { - db ksql.SQLProvider + db ksql.Provider streamChunkSize int } // NewUserService ... -func NewUserService(db ksql.SQLProvider) Service { +func NewUserService(db ksql.Provider) Service { return Service{ db: db, streamChunkSize: 100, diff --git a/examples/example_service/example_service_test.go b/examples/example_service/example_service_test.go index 007f1c0..f19d799 100644 --- a/examples/example_service/example_service_test.go +++ b/examples/example_service/example_service_test.go @@ -7,8 +7,8 @@ import ( gomock "github.com/golang/mock/gomock" "github.com/tj/assert" "github.com/vingarcia/ksql" - "github.com/vingarcia/ksql/nullable" "github.com/vingarcia/ksql/kstructs" + "github.com/vingarcia/ksql/nullable" ) func TestCreateUser(t *testing.T) { @@ -16,7 +16,7 @@ func TestCreateUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - mockDB := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ db: mockDB, @@ -42,7 +42,7 @@ func TestCreateUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - mockDB := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ db: mockDB, @@ -82,7 +82,7 @@ func TestUpdateUserScore(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - mockDB := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ db: mockDB, @@ -126,7 +126,7 @@ func TestListUsers(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - mockDB := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ db: mockDB, @@ -188,7 +188,7 @@ func TestStreamAllUsers(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - mockDB := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ db: mockDB, @@ -260,7 +260,7 @@ func TestDeleteUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - mockDB := NewMockSQLProvider(controller) + mockDB := NewMockProvider(controller) s := Service{ db: mockDB, diff --git a/examples/example_service/mocks.go b/examples/example_service/mocks.go index c170490..1307ab2 100644 --- a/examples/example_service/mocks.go +++ b/examples/example_service/mocks.go @@ -12,31 +12,31 @@ import ( ksql "github.com/vingarcia/ksql" ) -// MockSQLProvider is a mock of SQLProvider interface. -type MockSQLProvider struct { +// MockProvider is a mock of Provider interface. +type MockProvider struct { ctrl *gomock.Controller - recorder *MockSQLProviderMockRecorder + recorder *MockProviderMockRecorder } -// MockSQLProviderMockRecorder is the mock recorder for MockSQLProvider. -type MockSQLProviderMockRecorder struct { - mock *MockSQLProvider +// MockProviderMockRecorder is the mock recorder for MockProvider. +type MockProviderMockRecorder struct { + mock *MockProvider } -// NewMockSQLProvider creates a new mock instance. -func NewMockSQLProvider(ctrl *gomock.Controller) *MockSQLProvider { - mock := &MockSQLProvider{ctrl: ctrl} - mock.recorder = &MockSQLProviderMockRecorder{mock} +// NewMockProvider creates a new mock instance. +func NewMockProvider(ctrl *gomock.Controller) *MockProvider { + mock := &MockProvider{ctrl: ctrl} + mock.recorder = &MockProviderMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSQLProvider) EXPECT() *MockSQLProviderMockRecorder { +func (m *MockProvider) EXPECT() *MockProviderMockRecorder { return m.recorder } // Delete mocks base method. -func (m *MockSQLProvider) Delete(ctx context.Context, table ksql.Table, idsOrRecords ...interface{}) error { +func (m *MockProvider) Delete(ctx context.Context, table ksql.Table, idsOrRecords ...interface{}) error { m.ctrl.T.Helper() varargs := []interface{}{ctx, table} for _, a := range idsOrRecords { @@ -48,14 +48,14 @@ func (m *MockSQLProvider) Delete(ctx context.Context, table ksql.Table, idsOrRec } // Delete indicates an expected call of Delete. -func (mr *MockSQLProviderMockRecorder) Delete(ctx, table interface{}, idsOrRecords ...interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Delete(ctx, table interface{}, idsOrRecords ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{ctx, table}, idsOrRecords...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSQLProvider)(nil).Delete), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockProvider)(nil).Delete), varargs...) } // Exec mocks base method. -func (m *MockSQLProvider) Exec(ctx context.Context, query string, params ...interface{}) error { +func (m *MockProvider) Exec(ctx context.Context, query string, params ...interface{}) error { m.ctrl.T.Helper() varargs := []interface{}{ctx, query} for _, a := range params { @@ -67,14 +67,14 @@ func (m *MockSQLProvider) Exec(ctx context.Context, query string, params ...inte } // Exec indicates an expected call of Exec. -func (mr *MockSQLProviderMockRecorder) Exec(ctx, query interface{}, params ...interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Exec(ctx, query interface{}, params ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{ctx, query}, params...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLProvider)(nil).Exec), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockProvider)(nil).Exec), varargs...) } // Insert mocks base method. -func (m *MockSQLProvider) Insert(ctx context.Context, table ksql.Table, record interface{}) error { +func (m *MockProvider) Insert(ctx context.Context, table ksql.Table, record interface{}) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Insert", ctx, table, record) ret0, _ := ret[0].(error) @@ -82,13 +82,13 @@ func (m *MockSQLProvider) Insert(ctx context.Context, table ksql.Table, record i } // Insert indicates an expected call of Insert. -func (mr *MockSQLProviderMockRecorder) Insert(ctx, table, record interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Insert(ctx, table, record interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), ctx, table, record) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockProvider)(nil).Insert), ctx, table, record) } // Query mocks base method. -func (m *MockSQLProvider) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error { +func (m *MockProvider) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error { m.ctrl.T.Helper() varargs := []interface{}{ctx, records, query} for _, a := range params { @@ -100,14 +100,14 @@ func (m *MockSQLProvider) Query(ctx context.Context, records interface{}, query } // Query indicates an expected call of Query. -func (mr *MockSQLProviderMockRecorder) Query(ctx, records, query interface{}, params ...interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Query(ctx, records, query interface{}, params ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{ctx, records, query}, params...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockSQLProvider)(nil).Query), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockProvider)(nil).Query), varargs...) } // QueryChunks mocks base method. -func (m *MockSQLProvider) QueryChunks(ctx context.Context, parser ksql.ChunkParser) error { +func (m *MockProvider) QueryChunks(ctx context.Context, parser ksql.ChunkParser) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "QueryChunks", ctx, parser) ret0, _ := ret[0].(error) @@ -115,13 +115,13 @@ func (m *MockSQLProvider) QueryChunks(ctx context.Context, parser ksql.ChunkPars } // QueryChunks indicates an expected call of QueryChunks. -func (mr *MockSQLProviderMockRecorder) QueryChunks(ctx, parser interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) QueryChunks(ctx, parser interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryChunks", reflect.TypeOf((*MockSQLProvider)(nil).QueryChunks), ctx, parser) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryChunks", reflect.TypeOf((*MockProvider)(nil).QueryChunks), ctx, parser) } // QueryOne mocks base method. -func (m *MockSQLProvider) QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error { +func (m *MockProvider) QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error { m.ctrl.T.Helper() varargs := []interface{}{ctx, record, query} for _, a := range params { @@ -133,14 +133,14 @@ func (m *MockSQLProvider) QueryOne(ctx context.Context, record interface{}, quer } // QueryOne indicates an expected call of QueryOne. -func (mr *MockSQLProviderMockRecorder) QueryOne(ctx, record, query interface{}, params ...interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) QueryOne(ctx, record, query interface{}, params ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{ctx, record, query}, params...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryOne", reflect.TypeOf((*MockSQLProvider)(nil).QueryOne), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryOne", reflect.TypeOf((*MockProvider)(nil).QueryOne), varargs...) } // Transaction mocks base method. -func (m *MockSQLProvider) Transaction(ctx context.Context, fn func(ksql.SQLProvider) error) error { +func (m *MockProvider) Transaction(ctx context.Context, fn func(ksql.Provider) error) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Transaction", ctx, fn) ret0, _ := ret[0].(error) @@ -148,13 +148,13 @@ func (m *MockSQLProvider) Transaction(ctx context.Context, fn func(ksql.SQLProvi } // Transaction indicates an expected call of Transaction. -func (mr *MockSQLProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockSQLProvider)(nil).Transaction), ctx, fn) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockProvider)(nil).Transaction), ctx, fn) } // Update mocks base method. -func (m *MockSQLProvider) Update(ctx context.Context, table ksql.Table, record interface{}) error { +func (m *MockProvider) Update(ctx context.Context, table ksql.Table, record interface{}) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Update", ctx, table, record) ret0, _ := ret[0].(error) @@ -162,7 +162,7 @@ func (m *MockSQLProvider) Update(ctx context.Context, table ksql.Table, record i } // Update indicates an expected call of Update. -func (mr *MockSQLProviderMockRecorder) Update(ctx, table, record interface{}) *gomock.Call { +func (mr *MockProviderMockRecorder) Update(ctx, table, record interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), ctx, table, record) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockProvider)(nil).Update), ctx, table, record) } diff --git a/ksql.go b/ksql.go index b359b36..1b79e15 100644 --- a/ksql.go +++ b/ksql.go @@ -22,14 +22,19 @@ func init() { // DB represents the ksql client responsible for // interfacing with the "database/sql" package implementing -// the KissSQL interface `SQLProvider`. +// the KissSQL interface `ksql.Provider`. type DB struct { driver string dialect dialect - db sqlProvider + db DBAdapter } -type sqlProvider interface { +// DBAdapter is minimalistic interface to decouple our implementation +// from database/sql, i.e. if any struct implements the functions below +// with the exact same semantic as the sql package it will work with ksql. +// +// To create a new client using this adapter use ksql.NewWithDB() +type DBAdapter interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } @@ -61,6 +66,16 @@ func New( db.SetMaxOpenConns(config.MaxOpenConns) + return newWithDB(db, dbDriver, connectionString) +} + +// NewWithDB allows the user to insert a custom implementation +// of the DBAdapter interface +func newWithDB( + db DBAdapter, + dbDriver string, + connectionString string, +) (DB, error) { dialect := supportedDialects[dbDriver] if dialect == nil { return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver) @@ -731,7 +746,7 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error } // Transaction just runs an SQL command on the database returning no rows. -func (c DB) Transaction(ctx context.Context, fn func(SQLProvider) error) error { +func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { switch db := c.db.(type) { case *sql.Tx: return fn(c) diff --git a/ksql_test.go b/ksql_test.go index eccec01..45fbc15 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -1588,7 +1588,7 @@ func TestTransaction(t *testing.T) { _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) var users []User - err = c.Transaction(ctx, func(db SQLProvider) error { + err = c.Transaction(ctx, func(db Provider) error { db.Query(ctx, &users, "SELECT * FROM users ORDER BY id ASC") return nil }) @@ -1616,7 +1616,7 @@ func TestTransaction(t *testing.T) { _ = c.Insert(ctx, UsersTable, &u1) _ = c.Insert(ctx, UsersTable, &u2) - err = c.Transaction(ctx, func(db SQLProvider) error { + err = c.Transaction(ctx, func(db Provider) error { err = db.Insert(ctx, UsersTable, &User{Name: "User3"}) assert.Equal(t, nil, err) err = db.Insert(ctx, UsersTable, &User{Name: "User4"}) @@ -1878,7 +1878,7 @@ func shiftErrSlice(errs *[]error) error { return err } -func getUsersByID(dbi sqlProvider, dialect dialect, resultsPtr *[]User, ids ...uint) error { +func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uint) error { db := dbi.(*sql.DB) placeholders := make([]string, len(ids)) @@ -1920,7 +1920,7 @@ func getUsersByID(dbi sqlProvider, dialect dialect, resultsPtr *[]User, ids ...u return nil } -func getUserByID(dbi sqlProvider, dialect dialect, result *User, id uint) error { +func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error { db := dbi.(*sql.DB) row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) @@ -1941,7 +1941,7 @@ func getUserByID(dbi sqlProvider, dialect dialect, result *User, id uint) error return nil } -func getUserByName(dbi sqlProvider, dialect dialect, result *User, name string) error { +func getUserByName(dbi DBAdapter, dialect dialect, result *User, name string) error { db := dbi.(*sql.DB) row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name) diff --git a/mocks.go b/mocks.go index 8fc7321..97911cd 100644 --- a/mocks.go +++ b/mocks.go @@ -2,10 +2,10 @@ package ksql import "context" -var _ SQLProvider = MockSQLProvider{} +var _ Provider = Mock{} -// MockSQLProvider ... -type MockSQLProvider struct { +// Mock ... +type Mock struct { InsertFn func(ctx context.Context, table Table, record interface{}) error UpdateFn func(ctx context.Context, table Table, record interface{}) error DeleteFn func(ctx context.Context, table Table, ids ...interface{}) error @@ -15,45 +15,45 @@ type MockSQLProvider struct { QueryChunksFn func(ctx context.Context, parser ChunkParser) error ExecFn func(ctx context.Context, query string, params ...interface{}) error - TransactionFn func(ctx context.Context, fn func(db SQLProvider) error) error + TransactionFn func(ctx context.Context, fn func(db Provider) error) error } // Insert ... -func (m MockSQLProvider) Insert(ctx context.Context, table Table, record interface{}) error { +func (m Mock) Insert(ctx context.Context, table Table, record interface{}) error { return m.InsertFn(ctx, table, record) } // Update ... -func (m MockSQLProvider) Update(ctx context.Context, table Table, record interface{}) error { +func (m Mock) Update(ctx context.Context, table Table, record interface{}) error { return m.UpdateFn(ctx, table, record) } // Delete ... -func (m MockSQLProvider) Delete(ctx context.Context, table Table, ids ...interface{}) error { +func (m Mock) Delete(ctx context.Context, table Table, ids ...interface{}) error { return m.DeleteFn(ctx, table, ids...) } // Query ... -func (m MockSQLProvider) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error { +func (m Mock) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error { return m.QueryFn(ctx, records, query, params...) } // QueryOne ... -func (m MockSQLProvider) QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error { +func (m Mock) QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error { return m.QueryOneFn(ctx, record, query, params...) } // QueryChunks ... -func (m MockSQLProvider) QueryChunks(ctx context.Context, parser ChunkParser) error { +func (m Mock) QueryChunks(ctx context.Context, parser ChunkParser) error { return m.QueryChunksFn(ctx, parser) } // Exec ... -func (m MockSQLProvider) Exec(ctx context.Context, query string, params ...interface{}) error { +func (m Mock) Exec(ctx context.Context, query string, params ...interface{}) error { return m.ExecFn(ctx, query, params...) } // Transaction ... -func (m MockSQLProvider) Transaction(ctx context.Context, fn func(db SQLProvider) error) error { +func (m Mock) Transaction(ctx context.Context, fn func(db Provider) error) error { return m.TransactionFn(ctx, fn) } From f420553e0b02ebd286c44c77b79aa356caabb3cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Thu, 15 Jul 2021 23:14:20 -0300 Subject: [PATCH 31/40] Remove risk of panic on Insert() This panic used to happen if the user configured a table to use a specific ID column then tried to insert to the database with a struct that did not have that column. --- ksql.go | 7 +++- ksql_test.go | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/ksql.go b/ksql.go index 1b79e15..0e87e26 100644 --- a/ksql.go +++ b/ksql.go @@ -598,8 +598,13 @@ func buildInsertQuery( } for _, fieldName := range idNames { + field, found := recordMap[fieldName] + if !found { + continue + } + // Remove any ID field that was not set: - if reflect.ValueOf(recordMap[fieldName]).IsZero() { + if reflect.ValueOf(field).IsZero() { delete(recordMap, fieldName) } } diff --git a/ksql_test.go b/ksql_test.go index 45fbc15..f57ba46 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -588,7 +588,7 @@ func TestQueryOne(t *testing.T) { func TestInsert(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { - t.Run("using slice of structs", func(t *testing.T) { + t.Run("success cases", func(t *testing.T) { err := createTables(driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) @@ -656,6 +656,41 @@ func TestInsert(t *testing.T) { assert.Equal(t, u.Age, result.Age) assert.Equal(t, u.Address, result.Address) }) + + t.Run("should work with anonymous structs", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + err = c.Insert(ctx, UsersTable, &struct { + ID int `ksql:"id"` + Name string `ksql:"name"` + Address map[string]interface{} `ksql:"address,json"` + }{Name: "fake-name", Address: map[string]interface{}{"city": "bar"}}) + assert.Equal(t, nil, err) + }) + + t.Run("should work with preset IDs", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + usersByName := NewTable("users", "name") + + err = c.Insert(ctx, usersByName, &struct { + Name string `ksql:"name"` + Age int `ksql:"age"` + }{Name: "Preset Name", Age: 5455}) + assert.Equal(t, nil, err) + + var inserted User + err := getUserByName(db, c.dialect, &inserted, "Preset Name") + assert.Equal(t, nil, err) + assert.Equal(t, 5455, inserted.Age) + }) }) t.Run("testing error cases", func(t *testing.T) { @@ -708,6 +743,60 @@ func TestInsert(t *testing.T) { err = c.Insert(ctx, UsersTable, &User{Name: "foo"}) assert.NotEqual(t, nil, err) }) + + t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + err = c.Insert(ctx, UsersTable, &struct { + ID string `ksql:"id"` + NonExistingColumn int `ksql:"non_existing"` + Name string `ksql:"name"` + }{NonExistingColumn: 42, Name: "fake-name"}) + assert.NotEqual(t, nil, err) + msg := err.Error() + assert.Equal(t, true, strings.Contains(msg, "column")) + assert.Equal(t, true, strings.Contains(msg, "non_existing")) + }) + + t.Run("should not panic if the ID column doesn't exist in the database", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + brokenTable := NewTable("users", "non_existing_id") + + _ = c.Insert(ctx, brokenTable, &struct { + ID string `ksql:"non_existing_id"` + Age int `ksql:"age"` + Name string `ksql:"name"` + }{Age: 42, Name: "fake-name"}) + }) + + t.Run("should not panic if the ID column is missing in the struct", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + err = c.Insert(ctx, UsersTable, &struct { + Age int `ksql:"age"` + Name string `ksql:"name"` + }{Age: 42, Name: "Inserted With no ID"}) + assert.Equal(t, nil, err) + + var u User + err = getUserByName(db, c.dialect, &u, "Inserted With no ID") + assert.Equal(t, nil, err) + assert.NotEqual(t, uint(0), u.ID) + assert.Equal(t, 42, u.Age) + }) }) }) } From e73db4a216b57c39669c1c99aa770fe7fa0199cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sat, 31 Jul 2021 18:54:57 -0300 Subject: [PATCH 32/40] Abstract the DBAdapter so that we can support other sql adapters This was done for a few different reasons: 1. This allows us to work on top of the pgx client in the future 2. This would allow our users to implement their own DBAdapters to use with our tool. 3. This gives the users the option of using advanced configs of any sql client they want to use and just feed us with it after the configuration is done, which means we will not have to worry about supporting a growing number of configurations as we try to add support to more drivers or if we get issues asking for more advanced config options. --- README.md | 6 +++++- adapters.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ ksql.go | 48 ++++++++++++++++++++++++++++++++++++++---------- ksql_test.go | 18 +++++++++--------- 4 files changed, 99 insertions(+), 20 deletions(-) create mode 100644 adapters.go diff --git a/README.md b/README.md index ebdcf7a..ca1f2ad 100644 --- a/README.md +++ b/README.md @@ -523,11 +523,15 @@ make test - Add tests for tables using composite keys - Add support for serializing structs as other formats such as YAML -- Update `kstructs.FillStructWith` to work with `json` tagged attributes +- Update `kstructs.FillStructWith` to work with `ksql:"..,json"` tagged attributes - Make testing easier by exposing the connection strings in an .env file - Make testing easier by automatically creating the `ksql` database - Create a way for users to submit user defined dialects - Improve error messages +- Add support for the update function to work with maps for partial updates +- Add support for the insert function to work with maps +- Add support for a `ksql.Array(params ...interface{})` for allowing queries like this: + `db.Query(ctx, &user, "SELECT * FROM user WHERE id in (?)", ksql.Array(1,2,3))` ### Optimization Oportunities diff --git a/adapters.go b/adapters.go new file mode 100644 index 0000000..2e041a2 --- /dev/null +++ b/adapters.go @@ -0,0 +1,47 @@ +package ksql + +import ( + "context" + "database/sql" +) + +// SQLAdapter adapts the sql.DB type to be compatible with the `DBAdapter` interface +type SQLAdapter struct { + *sql.DB +} + +var _ DBAdapter = SQLAdapter{} + +// ExecContext implements the DBAdapter interface +func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + return s.DB.ExecContext(ctx, query, args...) +} + +// QueryContext implements the DBAdapter interface +func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + return s.DB.QueryContext(ctx, query, args...) +} + +// SQLTx is used to implement the DBAdapter interface and implements +// the Tx interface +type SQLTx struct { + *sql.Tx +} + +// ExecContext implements the Tx interface +func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + return s.Tx.ExecContext(ctx, query, args...) +} + +// QueryContext implements the Tx interface +func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + return s.Tx.QueryContext(ctx, query, args...) +} + +var _ Tx = SQLTx{} + +// BeginTx implements the Tx interface +func (s SQLAdapter) BeginTx(ctx context.Context) (Tx, error) { + tx, err := s.DB.BeginTx(ctx, nil) + return SQLTx{Tx: tx}, err +} diff --git a/ksql.go b/ksql.go index 0e87e26..1477d25 100644 --- a/ksql.go +++ b/ksql.go @@ -35,8 +35,36 @@ type DB struct { // // To create a new client using this adapter use ksql.NewWithDB() type DBAdapter interface { - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) +} + +// TxBeginner needs to be implemented by the DBAdapter in order to make it possible +// to use the `ksql.Transaction()` function. +type TxBeginner interface { + BeginTx(ctx context.Context) (Tx, error) +} + +// Result stores information about the result of an Exec query +type Result = sql.Result + +// Rows represents the results from a call to Query() +type Rows interface { + Scan(...interface{}) error + Close() error + Next() bool + Err() error + Columns() ([]string, error) +} + +var _ Rows = &sql.Rows{} + +// Tx represents a transaction and is expected to be returned by the DBAdapter.BeginTx function +type Tx interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) + Rollback() error + Commit() error } // Config describes the optional arguments accepted @@ -66,7 +94,7 @@ func New( db.SetMaxOpenConns(config.MaxOpenConns) - return newWithDB(db, dbDriver, connectionString) + return newWithDB(SQLAdapter{db}, dbDriver, connectionString) } // NewWithDB allows the user to insert a custom implementation @@ -752,11 +780,11 @@ func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error // Transaction just runs an SQL command on the database returning no rows. func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { - switch db := c.db.(type) { - case *sql.Tx: + switch txBeginner := c.db.(type) { + case Tx: return fn(c) - case *sql.DB: - tx, err := db.BeginTx(ctx, nil) + case TxBeginner: + tx, err := txBeginner.BeginTx(ctx) if err != nil { return err } @@ -789,7 +817,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { return tx.Commit() default: - return fmt.Errorf("unexpected error on ksql: db attribute has an invalid type") + return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBegginner interface") } } @@ -801,7 +829,7 @@ func (nopScanner) Scan(value interface{}) error { return nil } -func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { +func scanRows(dialect dialect, rows Rows, record interface{}) error { v := reflect.ValueOf(record) t := v.Type() if t.Kind() != reflect.Ptr { @@ -836,7 +864,7 @@ func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { return rows.Scan(scanArgs...) } -func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} { +func getScanArgsForNestedStructs(dialect dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} { scanArgs := []interface{}{} for i := 0; i < v.NumField(); i++ { // TODO(vingarcia00): Handle case where type is pointer diff --git a/ksql_test.go b/ksql_test.go index f57ba46..d940800 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -687,7 +687,7 @@ func TestInsert(t *testing.T) { assert.Equal(t, nil, err) var inserted User - err := getUserByName(db, c.dialect, &inserted, "Preset Name") + err := getUserByName(SQLAdapter{db}, c.dialect, &inserted, "Preset Name") assert.Equal(t, nil, err) assert.Equal(t, 5455, inserted.Age) }) @@ -792,7 +792,7 @@ func TestInsert(t *testing.T) { assert.Equal(t, nil, err) var u User - err = getUserByName(db, c.dialect, &u, "Inserted With no ID") + err = getUserByName(SQLAdapter{db}, c.dialect, &u, "Inserted With no ID") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) assert.Equal(t, 42, u.Age) @@ -1944,7 +1944,7 @@ func newTestDB(db *sql.DB, driver string) DB { return DB{ driver: driver, dialect: supportedDialects[driver], - db: db, + db: SQLAdapter{db}, } } @@ -1968,7 +1968,7 @@ func shiftErrSlice(errs *[]error) error { } func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uint) error { - db := dbi.(*sql.DB) + db := dbi.(SQLAdapter) placeholders := make([]string, len(ids)) params := make([]interface{}, len(ids)) @@ -1978,7 +1978,7 @@ func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uin } results := []User{} - rows, err := db.Query( + rows, err := db.DB.Query( fmt.Sprintf( "SELECT id, name, age FROM users WHERE id IN (%s)", strings.Join(placeholders, ", "), @@ -2010,9 +2010,9 @@ func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uin } func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error { - db := dbi.(*sql.DB) + db := dbi.(SQLAdapter) - row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) + row := db.DB.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) if row.Err() != nil { return row.Err() } @@ -2031,9 +2031,9 @@ func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error { } func getUserByName(dbi DBAdapter, dialect dialect, result *User, name string) error { - db := dbi.(*sql.DB) + db := dbi.(SQLAdapter) - row := db.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name) + row := db.DB.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name) if row.Err() != nil { return row.Err() } From c1a44c8e569ff3a4fc1532e81e1f6d8b5c997430 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sat, 31 Jul 2021 19:09:05 -0300 Subject: [PATCH 33/40] Update benchmark on README --- Makefile | 2 ++ README.md | 16 +++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 7a17471..5b72354 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,8 @@ test: setup bench: go test -bench=. -benchtime=$(TIME) + @echo "Benchmark executed at: $$(date --iso)" + @echo "Benchmark executed on commit: $$(git rev-parse HEAD)" lint: setup @$(GOBIN)/golint -set_exit_status -min_confidence 0.9 $(path) $(args) diff --git a/README.md b/README.md index ca1f2ad..c716f9b 100644 --- a/README.md +++ b/README.md @@ -488,14 +488,16 @@ goos: linux goarch: amd64 pkg: github.com/vingarcia/ksql cpu: Intel(R) Core(TM) i5-3210M CPU @ 2.50GHz -BenchmarkInsert/ksql-setup/insert-one-4 4970 727724 ns/op -BenchmarkInsert/sqlx-setup/insert-one-4 4842 703503 ns/op -BenchmarkQuery/ksql-setup/single-row-4 12692 282544 ns/op -BenchmarkQuery/ksql-setup/multiple-rows-4 10000 313662 ns/op -BenchmarkQuery/sqlx-setup/single-row-4 12328 291965 ns/op -BenchmarkQuery/sqlx-setup/multiple-rows-4 10000 301910 ns/op +BenchmarkInsert/ksql-setup/insert-one-4 4442 862525 ns/op +BenchmarkInsert/sqlx-setup/insert-one-4 4269 854837 ns/op +BenchmarkQuery/ksql-setup/single-row-4 10000 325756 ns/op +BenchmarkQuery/ksql-setup/multiple-rows-4 10000 339198 ns/op +BenchmarkQuery/sqlx-setup/single-row-4 11764 305418 ns/op +BenchmarkQuery/sqlx-setup/multiple-rows-4 9534 322344 ns/op PASS -ok github.com/vingarcia/ksql 39.995s +ok github.com/vingarcia/ksql 46.143s +Benchmark executed at: 2021-07-31 +Benchmark executed on commit: e73db4a216b57c39669c1c99aa770fe7fa0199cc ``` ### Running the ksql tests (for contributors) From 657ed7414bf9e84a6ea8344d35044be00348aea2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 1 Aug 2021 10:23:10 -0300 Subject: [PATCH 34/40] Add first version of the pgx adapter --- go.mod | 5 +- go.sum | 161 ++++++++++++++++++++++++++++++++-- ksql.go | 48 +++++++--- pgx_adapter.go | 106 ++++++++++++++++++++++ adapters.go => sql_adapter.go | 22 +++-- 5 files changed, 318 insertions(+), 24 deletions(-) create mode 100644 pgx_adapter.go rename adapters.go => sql_adapter.go (84%) diff --git a/go.mod b/go.mod index 3dac923..4fa2edf 100644 --- a/go.mod +++ b/go.mod @@ -7,11 +7,12 @@ require ( github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 github.com/go-sql-driver/mysql v1.4.0 github.com/golang/mock v1.5.0 + github.com/jackc/pgconn v1.10.0 // indirect + github.com/jackc/pgx/v4 v4.13.0 github.com/jmoiron/sqlx v1.2.0 - github.com/lib/pq v1.1.1 + github.com/lib/pq v1.10.2 github.com/mattn/go-sqlite3 v1.14.6 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.6.1 github.com/tj/assert v0.0.3 google.golang.org/appengine v1.6.7 // indirect ) diff --git a/go.sum b/go.sum index b80fabb..af807c8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,11 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -5,53 +13,196 @@ github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waN github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 h1:QsFkVafcKOaZoAB4WcyUHdkPbwh+VYwZgYJb/rU6EIM= github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018/go.mod h1:5C3SWkut69TSdkerzRDxXMRM5x73PGWNcRLe/xKjXhs= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.10.0 h1:4EYhlDVEMsJ30nNj0mmgwIUXoq7e9sMJrVC2ED6QlCU= +github.com/jackc/pgconn v1.10.0/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.8.1 h1:9k0IXtdJXHJbyAWQgbWr1lU+MEhPXZz6RIXxfR5oxXs= +github.com/jackc/pgtype v1.8.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.13.0 h1:JCjhT5vmhMAf/YwBHLvrBn4OGdIQBiFG6ym8Zmdx570= +github.com/jackc/pgx/v4 v4.13.0/go.mod h1:9P4X524sErlaxj0XSGZk7s+LD0eOyu1ZDUrrpznYDF0= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3 h1:JnPg/5Q9xVJGfjsO5CPUOjnJps1JaRUm8I9FXVCFK94= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA= github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= -github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/ksql.go b/ksql.go index 1477d25..2f0bba9 100644 --- a/ksql.go +++ b/ksql.go @@ -8,6 +8,7 @@ import ( "strings" "unicode" + "github.com/jackc/pgx/v4/pgxpool" "github.com/pkg/errors" "github.com/vingarcia/ksql/kstructs" ) @@ -46,7 +47,10 @@ type TxBeginner interface { } // Result stores information about the result of an Exec query -type Result = sql.Result +type Result interface { + LastInsertId() (int64, error) + RowsAffected() (int64, error) +} // Rows represents the results from a call to Query() type Rows interface { @@ -57,14 +61,12 @@ type Rows interface { Columns() ([]string, error) } -var _ Rows = &sql.Rows{} - // Tx represents a transaction and is expected to be returned by the DBAdapter.BeginTx function type Tx interface { ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) - Rollback() error - Commit() error + Rollback(ctx context.Context) error + Commit(ctx context.Context) error } // Config describes the optional arguments accepted @@ -94,12 +96,36 @@ func New( db.SetMaxOpenConns(config.MaxOpenConns) - return newWithDB(SQLAdapter{db}, dbDriver, connectionString) + return NewWithAdapter(SQLAdapter{db}, dbDriver, connectionString) } -// NewWithDB allows the user to insert a custom implementation +// NewWithPGX instantiates a new KissSQL client using the pgx +// library in the backend +// +// Configurations such as max open connections can be passed through +// the URL using the pgxpool `Config.ConnString()` or building the URL manually. +// +// More info at: https://pkg.go.dev/github.com/jackc/pgx/v4/pgxpool#Config +func NewWithPGX( + ctx context.Context, + dbDriver string, + connectionString string, +) (db DB, err error) { + pool, err := pgxpool.Connect(ctx, connectionString) + if err != nil { + return DB{}, err + } + if err = pool.Ping(ctx); err != nil { + return DB{}, err + } + + db, err = NewWithAdapter(PGXAdapter{pool}, dbDriver, connectionString) + return db, err +} + +// NewWithAdapter allows the user to insert a custom implementation // of the DBAdapter interface -func newWithDB( +func NewWithAdapter( db DBAdapter, dbDriver string, connectionString string, @@ -790,7 +816,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { } defer func() { if r := recover(); r != nil { - rollbackErr := tx.Rollback() + rollbackErr := tx.Rollback(ctx) if rollbackErr != nil { r = errors.Wrap(rollbackErr, fmt.Sprintf("unable to rollback after panic with value: %v", r), @@ -805,7 +831,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { err = fn(ormCopy) if err != nil { - rollbackErr := tx.Rollback() + rollbackErr := tx.Rollback(ctx) if rollbackErr != nil { err = errors.Wrap(rollbackErr, fmt.Sprintf("unable to rollback after error: %s", err.Error()), @@ -814,7 +840,7 @@ func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error { return err } - return tx.Commit() + return tx.Commit(ctx) default: return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBegginner interface") diff --git a/pgx_adapter.go b/pgx_adapter.go new file mode 100644 index 0000000..944dff7 --- /dev/null +++ b/pgx_adapter.go @@ -0,0 +1,106 @@ +package ksql + +import ( + "context" + "fmt" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" +) + +// PGXAdapter adapts the sql.DB type to be compatible with the `DBAdapter` interface +type PGXAdapter struct { + db *pgxpool.Pool +} + +var _ DBAdapter = PGXAdapter{} + +// ExecContext implements the DBAdapter interface +func (p PGXAdapter) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + result, err := p.db.Exec(ctx, query, args...) + return PGXResult{result}, err +} + +// QueryContext implements the DBAdapter interface +func (p PGXAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + rows, err := p.db.Query(ctx, query, args...) + return PGXRows{rows}, err +} + +// BeginTx implements the Tx interface +func (p PGXAdapter) BeginTx(ctx context.Context) (Tx, error) { + tx, err := p.db.Begin(ctx) + return PGXTx{tx}, err +} + +// PGXResult is used to implement the DBAdapter interface and implements +// the Result interface +type PGXResult struct { + tag pgconn.CommandTag +} + +// RowsAffected implements the Result interface +func (p PGXResult) RowsAffected() (int64, error) { + return p.tag.RowsAffected(), nil +} + +// LastInsertId implements the Result interface +func (p PGXResult) LastInsertId() (int64, error) { + return 0, fmt.Errorf( + "LastInsertId is not implemented in the pgx adapter, use the `RETURNING` statement instead", + ) +} + +// PGXTx is used to implement the DBAdapter interface and implements +// the Tx interface +type PGXTx struct { + tx pgx.Tx +} + +// ExecContext implements the Tx interface +func (p PGXTx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + result, err := p.tx.Exec(ctx, query, args...) + return PGXResult{result}, err +} + +// QueryContext implements the Tx interface +func (p PGXTx) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { + rows, err := p.tx.Query(ctx, query, args...) + return PGXRows{rows}, err +} + +// Rollback implements the Tx interface +func (p PGXTx) Rollback(ctx context.Context) error { + return p.tx.Rollback(ctx) +} + +// Commit implements the Tx interface +func (p PGXTx) Commit(ctx context.Context) error { + return p.tx.Commit(ctx) +} + +var _ Tx = PGXTx{} + +// PGXRows implements the Rows interface and is used to help +// the PGXAdapter to implement the DBAdapter interface. +type PGXRows struct { + pgx.Rows +} + +var _ Rows = PGXRows{} + +// Columns implements the Rows interface +func (p PGXRows) Columns() ([]string, error) { + var names []string + for _, desc := range p.Rows.FieldDescriptions() { + names = append(names, string(desc.Name)) + } + return names, nil +} + +// Close implements the Rows interface +func (p PGXRows) Close() error { + p.Rows.Close() + return nil +} diff --git a/adapters.go b/sql_adapter.go similarity index 84% rename from adapters.go rename to sql_adapter.go index 2e041a2..b3fa8a1 100644 --- a/adapters.go +++ b/sql_adapter.go @@ -22,6 +22,12 @@ func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...inte return s.DB.QueryContext(ctx, query, args...) } +// BeginTx implements the Tx interface +func (s SQLAdapter) BeginTx(ctx context.Context) (Tx, error) { + tx, err := s.DB.BeginTx(ctx, nil) + return SQLTx{Tx: tx}, err +} + // SQLTx is used to implement the DBAdapter interface and implements // the Tx interface type SQLTx struct { @@ -38,10 +44,14 @@ func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface return s.Tx.QueryContext(ctx, query, args...) } -var _ Tx = SQLTx{} - -// BeginTx implements the Tx interface -func (s SQLAdapter) BeginTx(ctx context.Context) (Tx, error) { - tx, err := s.DB.BeginTx(ctx, nil) - return SQLTx{Tx: tx}, err +// Rollback implements the Tx interface +func (s SQLTx) Rollback(ctx context.Context) error { + return s.Tx.Rollback() } + +// Commit implements the Tx interface +func (s SQLTx) Commit(ctx context.Context) error { + return s.Tx.Commit() +} + +var _ Tx = SQLTx{} From 5d6f1d7d37829045d7218a1b8f629f4c0fce9345 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 1 Aug 2021 11:51:40 -0300 Subject: [PATCH 35/40] Improve NewWithPGX() constructor --- ksql.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ksql.go b/ksql.go index 2f0bba9..049fec4 100644 --- a/ksql.go +++ b/ksql.go @@ -101,17 +101,20 @@ func New( // NewWithPGX instantiates a new KissSQL client using the pgx // library in the backend -// -// Configurations such as max open connections can be passed through -// the URL using the pgxpool `Config.ConnString()` or building the URL manually. -// -// More info at: https://pkg.go.dev/github.com/jackc/pgx/v4/pgxpool#Config func NewWithPGX( ctx context.Context, dbDriver string, connectionString string, + config Config, ) (db DB, err error) { - pool, err := pgxpool.Connect(ctx, connectionString) + pgxConf, err := pgxpool.ParseConfig(connectionString) + if err != nil { + return DB{}, err + } + + pgxConf.MaxConns = int32(config.MaxOpenConns) + + pool, err := pgxpool.ConnectConfig(ctx, pgxConf) if err != nil { return DB{}, err } From 87f57f665fc4bb18f3cfb543398bb96f665f68af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 1 Aug 2021 11:52:04 -0300 Subject: [PATCH 36/40] Add code for benchmarking the PGX adapter --- benchmark_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/benchmark_test.go b/benchmark_test.go index a31128c..0b52dda 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -52,6 +52,32 @@ func BenchmarkInsert(b *testing.B) { }) }) + pgxDB, err := ksql.NewWithPGX(ctx, driver, connStr, ksql.Config{ + MaxOpenConns: 1, + }) + if err != nil { + b.Fatalf("error creating pgx client: %s", err) + } + + b.Run("pgx-adapter-setup", func(b *testing.B) { + err := recreateTable(connStr) + if err != nil { + b.Fatalf("error creating table: %s", err.Error()) + } + + b.Run("insert-one", func(b *testing.B) { + for i := 0; i < b.N; i++ { + err := pgxDB.Insert(ctx, UsersTable, &User{ + Name: strconv.Itoa(i), + Age: i, + }) + if err != nil { + b.Fatalf("insert error: %s", err.Error()) + } + } + }) + }) + sqlxDB, err := sqlx.Open(driver, connStr) sqlxDB.SetMaxOpenConns(1) @@ -139,6 +165,48 @@ func BenchmarkQuery(b *testing.B) { }) }) + pgxDB, err := ksql.NewWithPGX(ctx, driver, connStr, ksql.Config{ + MaxOpenConns: 1, + }) + if err != nil { + b.Fatalf("error creating pgx client: %s", err) + } + + b.Run("pgx-adapter-setup", func(b *testing.B) { + err := recreateTable(connStr) + if err != nil { + b.Fatalf("error creating table: %s", err.Error()) + } + + err = insertUsers(connStr, 100) + if err != nil { + b.Fatalf("error inserting users: %s", err.Error()) + } + + b.Run("single-row", func(b *testing.B) { + for i := 0; i < b.N; i++ { + var user User + err := pgxDB.QueryOne(ctx, &user, `SELECT * FROM users OFFSET $1 LIMIT 1`, i%100) + if err != nil { + b.Fatalf("query error: %s", err.Error()) + } + } + }) + + b.Run("multiple-rows", func(b *testing.B) { + for i := 0; i < b.N; i++ { + var users []User + err := pgxDB.Query(ctx, &users, `SELECT * FROM users OFFSET $1 LIMIT 10`, i%90) + if err != nil { + b.Fatalf("query error: %s", err.Error()) + } + if len(users) < 10 { + b.Fatalf("expected 10 scanned users, but got: %d", len(users)) + } + } + }) + }) + sqlxDB, err := sqlx.Open(driver, connStr) sqlxDB.SetMaxOpenConns(1) From ec749af84c86172b149c8b9e0f64dd8d79a52d9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 1 Aug 2021 11:57:03 -0300 Subject: [PATCH 37/40] Update benchmarks by running the PGx adapter benchmarks --- README.md | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index c716f9b..00bb99a 100644 --- a/README.md +++ b/README.md @@ -488,16 +488,19 @@ goos: linux goarch: amd64 pkg: github.com/vingarcia/ksql cpu: Intel(R) Core(TM) i5-3210M CPU @ 2.50GHz -BenchmarkInsert/ksql-setup/insert-one-4 4442 862525 ns/op -BenchmarkInsert/sqlx-setup/insert-one-4 4269 854837 ns/op -BenchmarkQuery/ksql-setup/single-row-4 10000 325756 ns/op -BenchmarkQuery/ksql-setup/multiple-rows-4 10000 339198 ns/op -BenchmarkQuery/sqlx-setup/single-row-4 11764 305418 ns/op -BenchmarkQuery/sqlx-setup/multiple-rows-4 9534 322344 ns/op +BenchmarkInsert/ksql-setup/insert-one-4 4170 880911 ns/op +BenchmarkInsert/pgx-adapter-setup/insert-one-4 5780 669992 ns/op +BenchmarkInsert/sqlx-setup/insert-one-4 4389 825303 ns/op +BenchmarkQuery/ksql-setup/single-row-4 1710 306996 ns/op +BenchmarkQuery/ksql-setup/multiple-rows-4 0000 345091 ns/op +BenchmarkQuery/pgx-adapter-setup/single-row-4 5237 140154 ns/op +BenchmarkQuery/pgx-adapter-setup/multiple-rows-4 22051 164306 ns/op +BenchmarkQuery/sqlx-setup/single-row-4 11955 311654 ns/op +BenchmarkQuery/sqlx-setup/multiple-rows-4 10000 323079 ns/op PASS -ok github.com/vingarcia/ksql 46.143s -Benchmark executed at: 2021-07-31 -Benchmark executed on commit: e73db4a216b57c39669c1c99aa770fe7fa0199cc +ok github.com/vingarcia/ksql 55.231s +Benchmark executed at: 2021-08-01 +Benchmark executed on commit: 87f57f665fc4bb18f3cfb543398bb96f665f68af ``` ### Running the ksql tests (for contributors) From 5c2b9816966ffe93f78212b442b8f19b304bd0ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 1 Aug 2021 14:42:39 -0300 Subject: [PATCH 38/40] Add tests to the pgx adapter --- ksql_test.go | 726 ++++++++++++++++++++++++++++----------------------- 1 file changed, 397 insertions(+), 329 deletions(-) diff --git a/ksql_test.go b/ksql_test.go index d940800..18df87c 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -6,12 +6,14 @@ import ( "encoding/json" "errors" "fmt" + "io" "strings" "testing" _ "github.com/denisenkom/go-mssqldb" "github.com/ditointernet/go-assert" _ "github.com/go-sql-driver/mysql" + "github.com/jackc/pgx/v4/pgxpool" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/vingarcia/ksql/nullable" @@ -44,9 +46,37 @@ type Post struct { var PostsTable = NewTable("posts") +type testConfig struct { + driver string + adapterName string +} + +var supportedConfigs = []testConfig{ + { + driver: "sqlite3", + adapterName: "sql", + }, + { + driver: "postgres", + adapterName: "sql", + }, + { + driver: "mysql", + adapterName: "sql", + }, + { + driver: "sqlserver", + adapterName: "sql", + }, + { + driver: "postgres", + adapterName: "pgx", + }, +} + func TestQuery(t *testing.T) { - for driver := range supportedDialects { - t.Run(driver, func(t *testing.T) { + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { variations := []struct { desc string queryPrefix string @@ -63,17 +93,17 @@ func TestQuery(t *testing.T) { for _, variation := range variations { t.Run(variation.desc, func(t *testing.T) { t.Run("using slice of structs", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should return 0 results correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var users []User err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) assert.Equal(t, nil, err) @@ -86,14 +116,14 @@ func TestQuery(t *testing.T) { }) t.Run("should return a user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) + _, err := db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var users []User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") @@ -105,17 +135,17 @@ func TestQuery(t *testing.T) { }) t.Run("should return multiple users correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) + _, err := db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) assert.Equal(t, nil, err) - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) + _, err = db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var users []User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") @@ -132,33 +162,34 @@ func TestQuery(t *testing.T) { }) t.Run("should query joined tables correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() // This test only makes sense with no query prefix if variation.queryPrefix != "" { return } - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + _, err := db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joao User + getUserByName(db, config.driver, &joao, "João Ribeiro") assert.Equal(t, nil, err) - var joaoID uint - db.QueryRow(`SELECT id FROM users WHERE name = 'João Ribeiro'`).Scan(&joaoID) - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`) + _, err = db.ExecContext(context.TODO(), `INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`) assert.Equal(t, nil, err) - var biaID uint - db.QueryRow(`SELECT id FROM users WHERE name = 'Bia Ribeiro'`).Scan(&biaID) + var bia User + getUserByName(db, config.driver, &bia, "Bia Ribeiro") - _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post1')`)) + _, err = db.ExecContext(context.TODO(), fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post1')`)) assert.Equal(t, nil, err) - _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post2')`)) + _, err = db.ExecContext(context.TODO(), fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post2')`)) assert.Equal(t, nil, err) - _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joaoID, `, 'João Post1')`)) + _, err = db.ExecContext(context.TODO(), fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`)) assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var rows []struct { User User `tablename:"u"` Post Post `tablename:"p"` @@ -172,32 +203,32 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, 3, len(rows)) - assert.Equal(t, joaoID, rows[0].User.ID) + assert.Equal(t, joao.ID, rows[0].User.ID) assert.Equal(t, "João Ribeiro", rows[0].User.Name) assert.Equal(t, "João Post1", rows[0].Post.Title) - assert.Equal(t, biaID, rows[1].User.ID) + assert.Equal(t, bia.ID, rows[1].User.ID) assert.Equal(t, "Bia Ribeiro", rows[1].User.Name) assert.Equal(t, "Bia Post1", rows[1].Post.Title) - assert.Equal(t, biaID, rows[2].User.ID) + assert.Equal(t, bia.ID, rows[2].User.ID) assert.Equal(t, "Bia Ribeiro", rows[2].User.Name) assert.Equal(t, "Bia Post2", rows[2].Post.Title) }) }) t.Run("using slice of pointers to structs", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should return 0 results correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var users []*User err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) assert.Equal(t, nil, err) @@ -210,14 +241,15 @@ func TestQuery(t *testing.T) { }) t.Run("should return a user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) var users []*User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") @@ -229,17 +261,18 @@ func TestQuery(t *testing.T) { }) t.Run("should return multiple users correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Garcia', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Garcia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) var users []*User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") @@ -256,33 +289,34 @@ func TestQuery(t *testing.T) { }) t.Run("should query joined tables correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - // This test only makes sense with no query prefix if variation.queryPrefix != "" { return } - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - var joaoID uint - db.QueryRow(`SELECT id FROM users WHERE name = 'João Ribeiro'`).Scan(&joaoID) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) - var biaID uint - db.QueryRow(`SELECT id FROM users WHERE name = 'Bia Ribeiro'`).Scan(&biaID) - - _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post1')`)) - assert.Equal(t, nil, err) - _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, biaID, `, 'Bia Post2')`)) - assert.Equal(t, nil, err) - _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joaoID, `, 'João Post1')`)) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joao User + getUserByName(db, config.driver, &joao, "João Ribeiro") + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia Ribeiro', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + var bia User + getUserByName(db, config.driver, &bia, "Bia Ribeiro") + + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post1')`)) + assert.Equal(t, nil, err) + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, bia.ID, `, 'Bia Post2')`)) + assert.Equal(t, nil, err) + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`)) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) var rows []*struct { User User `tablename:"u"` Post Post `tablename:"p"` @@ -296,15 +330,15 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, 3, len(rows)) - assert.Equal(t, joaoID, rows[0].User.ID) + assert.Equal(t, joao.ID, rows[0].User.ID) assert.Equal(t, "João Ribeiro", rows[0].User.Name) assert.Equal(t, "João Post1", rows[0].Post.Title) - assert.Equal(t, biaID, rows[1].User.ID) + assert.Equal(t, bia.ID, rows[1].User.ID) assert.Equal(t, "Bia Ribeiro", rows[1].User.Name) assert.Equal(t, "Bia Post1", rows[1].Post.Title) - assert.Equal(t, biaID, rows[2].User.ID) + assert.Equal(t, bia.ID, rows[2].User.ID) assert.Equal(t, "Bia Ribeiro", rows[2].User.Name) assert.Equal(t, "Bia Post2", rows[2].Post.Title) }) @@ -313,23 +347,24 @@ func TestQuery(t *testing.T) { } t.Run("testing error cases", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should report error if input is not a pointer to a slice of structs", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Andréa Sá', 0)`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age) VALUES ('Caio Sá', 0)`) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Andréa Sá', 0)`) + assert.Equal(t, nil, err) + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Caio Sá', 0)`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) err = c.Query(ctx, &User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) @@ -345,22 +380,22 @@ func TestQuery(t *testing.T) { }) t.Run("should report error if the query is not valid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var users []User err := c.Query(ctx, &users, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var rows []struct { User User `tablename:"users"` Post Post `tablename:"posts"` @@ -373,11 +408,11 @@ func TestQuery(t *testing.T) { t.Run("should report error for nested structs with invalid types", func(t *testing.T) { t.Run("int", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var rows []struct { Foo int `tablename:"foo"` } @@ -395,11 +430,11 @@ func TestQuery(t *testing.T) { }) t.Run("*struct", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var rows []struct { Foo *User `tablename:"foo"` } @@ -422,8 +457,8 @@ func TestQuery(t *testing.T) { } func TestQueryOne(t *testing.T) { - for driver := range supportedDialects { - t.Run(driver, func(t *testing.T) { + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { variations := []struct { desc string queryPrefix string @@ -438,32 +473,33 @@ func TestQueryOne(t *testing.T) { }, } for _, variation := range variations { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run(variation.desc, func(t *testing.T) { t.Run("should return RecordNotFoundErr when there are no results", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) u := User{} err := c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE id=1;`) assert.Equal(t, ErrRecordNotFound, err) }) t.Run("should return a user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Bia', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) u := User{} err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") @@ -476,17 +512,18 @@ func TestQueryOne(t *testing.T) { }) t.Run("should return only the first result on multiples matches", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) var u User err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá") @@ -499,24 +536,25 @@ func TestQueryOne(t *testing.T) { }) t.Run("should query joined tables correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - // This test only makes sense with no query prefix if variation.queryPrefix != "" { return } - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - var joaoID uint - db.QueryRow(`SELECT id FROM users WHERE name = 'João Ribeiro'`).Scan(&joaoID) - - _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joaoID, `, 'João Post1')`)) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('João Ribeiro', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + var joao User + getUserByName(db, config.driver, &joao, "João Ribeiro") + + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'João Post1')`)) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) var row struct { User User `tablename:"u"` Post Post `tablename:"p"` @@ -528,7 +566,7 @@ func TestQueryOne(t *testing.T) { ), "% Ribeiro") assert.Equal(t, nil, err) - assert.Equal(t, joaoID, row.User.ID) + assert.Equal(t, joao.ID, row.User.ID) assert.Equal(t, "João Ribeiro", row.User.Name) assert.Equal(t, "João Post1", row.Post.Title) }) @@ -536,17 +574,18 @@ func TestQueryOne(t *testing.T) { } t.Run("should report error if input is not a pointer to struct", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() - - _, err := db.Exec(`INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) - assert.Equal(t, nil, err) - - _, err = db.Exec(`INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) - assert.Equal(t, nil, err) + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Andréa Sá', 0, '{"country":"US"}')`) + assert.Equal(t, nil, err) + + _, err = db.ExecContext(ctx, `INSERT INTO users (name, age, address) VALUES ('Caio Sá', 0, '{"country":"BR"}')`) + assert.Equal(t, nil, err) + + c := newTestDB(db, config.driver) err = c.QueryOne(ctx, &[]User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) @@ -556,22 +595,22 @@ func TestQueryOne(t *testing.T) { }) t.Run("should report error if the query is not valid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var user User err := c.QueryOne(ctx, &user, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) var row struct { User User `tablename:"users"` Post Post `tablename:"posts"` @@ -586,20 +625,20 @@ func TestQueryOne(t *testing.T) { } func TestInsert(t *testing.T) { - for driver := range supportedDialects { - t.Run(driver, func(t *testing.T) { + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { t.Run("success cases", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should insert one user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) u := User{ Name: "Fernanda", @@ -621,19 +660,16 @@ func TestInsert(t *testing.T) { }) t.Run("should insert ignoring the ID for sqlite and multiple ids", func(t *testing.T) { - if supportedDialects[driver].InsertMethod() != insertWithLastInsertID { + if supportedDialects[config.driver].InsertMethod() != insertWithLastInsertID { return } - db := connectDB(t, driver) - defer db.Close() - ctx := context.Background() // Using columns "id" and "name" as IDs: table := NewTable("users", "id", "name") - c, err := New(driver, connectionString[driver], Config{}) + c, err := New(config.driver, connectionString[config.driver], Config{}) assert.Equal(t, nil, err) u := User{ @@ -650,7 +686,7 @@ func TestInsert(t *testing.T) { assert.Equal(t, uint(0), u.ID) result := User{} - err = getUserByName(c.db, c.dialect, &result, "No ID returned") + err = getUserByName(c.db, config.driver, &result, "No ID returned") assert.Equal(t, nil, err) assert.Equal(t, u.Age, result.Age) @@ -658,11 +694,11 @@ func TestInsert(t *testing.T) { }) t.Run("should work with anonymous structs", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) err = c.Insert(ctx, UsersTable, &struct { ID int `ksql:"id"` Name string `ksql:"name"` @@ -672,11 +708,11 @@ func TestInsert(t *testing.T) { }) t.Run("should work with preset IDs", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) usersByName := NewTable("users", "name") @@ -687,24 +723,24 @@ func TestInsert(t *testing.T) { assert.Equal(t, nil, err) var inserted User - err := getUserByName(SQLAdapter{db}, c.dialect, &inserted, "Preset Name") + err := getUserByName(db, config.driver, &inserted, "Preset Name") assert.Equal(t, nil, err) assert.Equal(t, 5455, inserted.Age) }) }) t.Run("testing error cases", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should report error for invalid input types", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) err = c.Insert(ctx, UsersTable, "foo") assert.NotEqual(t, nil, err) @@ -731,11 +767,11 @@ func TestInsert(t *testing.T) { }) t.Run("should report error if for some reason the insertMethod is invalid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) // This is an invalid value: c.dialect = brokenDialect{} @@ -745,11 +781,11 @@ func TestInsert(t *testing.T) { }) t.Run("should not panic if a column doesn't exist in the database", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) err = c.Insert(ctx, UsersTable, &struct { ID string `ksql:"id"` @@ -763,11 +799,11 @@ func TestInsert(t *testing.T) { }) t.Run("should not panic if the ID column doesn't exist in the database", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) brokenTable := NewTable("users", "non_existing_id") @@ -779,11 +815,11 @@ func TestInsert(t *testing.T) { }) t.Run("should not panic if the ID column is missing in the struct", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) err = c.Insert(ctx, UsersTable, &struct { Age int `ksql:"age"` @@ -792,7 +828,7 @@ func TestInsert(t *testing.T) { assert.Equal(t, nil, err) var u User - err = getUserByName(SQLAdapter{db}, c.dialect, &u, "Inserted With no ID") + err = getUserByName(db, config.driver, &u, "Inserted With no ID") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) assert.Equal(t, 42, u.Age) @@ -821,19 +857,19 @@ func (brokenDialect) DriverName() string { } func TestDelete(t *testing.T) { - for driver := range supportedDialects { - t.Run(driver, func(t *testing.T) { - err := createTables(driver) + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should ignore empty lists of ids", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) u := User{ Name: "Won't be deleted", @@ -859,11 +895,11 @@ func TestDelete(t *testing.T) { }) t.Run("should delete one id correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) u1 := User{ Name: "Fernanda", @@ -907,11 +943,11 @@ func TestDelete(t *testing.T) { }) t.Run("should delete multiple ids correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) u1 := User{ Name: "Fernanda", @@ -964,29 +1000,27 @@ func TestDelete(t *testing.T) { } func TestUpdate(t *testing.T) { - for driver := range supportedDialects { - t.Run(driver, func(t *testing.T) { - err := createTables(driver) + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } t.Run("should update one user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) u := User{ Name: "Letícia", } - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 0)`) + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`) assert.Equal(t, nil, err) - row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) - assert.Equal(t, nil, row.Err()) - err = row.Scan(&u.ID) + err = getUserByName(db, config.driver, &u, "Letícia") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) @@ -1003,21 +1037,19 @@ func TestUpdate(t *testing.T) { }) t.Run("should update one user correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) u := User{ Name: "Letícia", } - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 0)`) + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 0)`) assert.Equal(t, nil, err) - row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) - assert.Equal(t, nil, row.Err()) - err = row.Scan(&u.ID) + err = getUserByName(db, config.driver, &u, "Letícia") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) @@ -1034,27 +1066,23 @@ func TestUpdate(t *testing.T) { }) t.Run("should ignore null pointers on partial updates", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) type partialUser struct { ID uint `ksql:"id"` Name string `ksql:"name"` Age *int `ksql:"age"` } - u := partialUser{ - Name: "Letícia", - Age: nullable.Int(22), - } - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 22)`) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`) assert.Equal(t, nil, err) - row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) - assert.Equal(t, nil, row.Err()) - err = row.Scan(&u.ID) + var u User + err = getUserByName(db, config.driver, &u, "Letícia") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) @@ -1075,27 +1103,23 @@ func TestUpdate(t *testing.T) { }) t.Run("should update valid pointers on partial updates", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) type partialUser struct { ID uint `ksql:"id"` Name string `ksql:"name"` Age *int `ksql:"age"` } - u := partialUser{ - Name: "Letícia", - Age: nullable.Int(22), - } - _, err := db.Exec(`INSERT INTO users (name, age) VALUES ('Letícia', 22)`) + + _, err := db.ExecContext(ctx, `INSERT INTO users (name, age) VALUES ('Letícia', 22)`) assert.Equal(t, nil, err) - row := db.QueryRow(`SELECT id FROM users WHERE name = 'Letícia'`) - assert.Equal(t, nil, row.Err()) - err = row.Scan(&u.ID) + var u User + err = getUserByName(db, config.driver, &u, "Letícia") assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) @@ -1116,11 +1140,11 @@ func TestUpdate(t *testing.T) { }) t.Run("should return ErrRecordNotFound when asked to update an inexistent user", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) err = c.Update(ctx, UsersTable, User{ ID: 4200, @@ -1130,11 +1154,11 @@ func TestUpdate(t *testing.T) { }) t.Run("should report database errors correctly", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) err = c.Update(ctx, NewTable("non_existing_table"), User{ ID: 1, @@ -1147,8 +1171,8 @@ func TestUpdate(t *testing.T) { } func TestQueryChunks(t *testing.T) { - for driver := range supportedDialects { - t.Run(driver, func(t *testing.T) { + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { variations := []struct { desc string queryPrefix string @@ -1165,16 +1189,16 @@ func TestQueryChunks(t *testing.T) { for _, variation := range variations { t.Run(variation.desc, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &User{ Name: "User1", @@ -1205,16 +1229,16 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should query one chunk correctly", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}}) _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}}) @@ -1247,16 +1271,16 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should query chunks of 1 correctly", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}}) _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}}) @@ -1289,16 +1313,16 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should load partially filled chunks correctly", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) @@ -1336,8 +1360,8 @@ func TestQueryChunks(t *testing.T) { return } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() joao := User{ Name: "Thiago Ribeiro", @@ -1349,15 +1373,15 @@ func TestQueryChunks(t *testing.T) { } ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &joao) _ = c.Insert(ctx, UsersTable, &thatiana) - _, err := db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post1')`)) + _, err := db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post1')`)) assert.Equal(t, nil, err) - _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post2')`)) + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post2')`)) assert.Equal(t, nil, err) - _, err = db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'Thiago Post1')`)) + _, err = db.ExecContext(ctx, fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, joao.ID, `, 'Thiago Post1')`)) assert.Equal(t, nil, err) var lengths []int @@ -1402,16 +1426,16 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) @@ -1441,16 +1465,16 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) @@ -1484,16 +1508,16 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should return error if the callback returns an error in the first iteration", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) @@ -1523,16 +1547,16 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should return error if the callback returns an error in the last iteration", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) @@ -1566,11 +1590,11 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should report error if the input function is invalid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) funcs := []interface{}{ nil, @@ -1611,11 +1635,11 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should report error if the query is not valid", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) err := c.QueryChunks(ctx, ChunkParser{ Query: `SELECT * FROM not a valid query`, Params: []interface{}{}, @@ -1629,11 +1653,11 @@ func TestQueryChunks(t *testing.T) { }) t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) err := c.QueryChunks(ctx, ChunkParser{ Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`, @@ -1659,19 +1683,19 @@ func TestQueryChunks(t *testing.T) { } func TestTransaction(t *testing.T) { - for driver := range supportedDialects { - t.Run(driver, func(t *testing.T) { + for _, config := range supportedConfigs { + t.Run(config.driver, func(t *testing.T) { t.Run("should query a single row correctly", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) @@ -1689,16 +1713,16 @@ func TestTransaction(t *testing.T) { }) t.Run("should rollback when there are errors", func(t *testing.T) { - err := createTables(driver) + err := createTables(config.driver) if err != nil { t.Fatal("could not create test table!, reason:", err.Error()) } - db := connectDB(t, driver) - defer db.Close() + db, closer := connectDB(t, config) + defer closer.Close() ctx := context.Background() - c := newTestDB(db, driver) + c := newTestDB(db, config.driver) u1 := User{Name: "User1", Age: 42} u2 := User{Name: "User2", Age: 42} @@ -1737,8 +1761,11 @@ func TestScanRows(t *testing.T) { dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() c := newTestDB(db, "sqlite3") _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22}) _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Age: 14}) @@ -1766,8 +1793,11 @@ func TestScanRows(t *testing.T) { dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() c := newTestDB(db, "sqlite3") _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22}) @@ -1798,8 +1828,11 @@ func TestScanRows(t *testing.T) { dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() rows, err := db.QueryContext(ctx, "select * from users where name='User2'") assert.Equal(t, nil, err) @@ -1819,8 +1852,11 @@ func TestScanRows(t *testing.T) { dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() rows, err := db.QueryContext(ctx, "select * from users where name='User2'") assert.Equal(t, nil, err) @@ -1838,8 +1874,11 @@ func TestScanRows(t *testing.T) { dialect := supportedDialects["sqlite3"] ctx := context.TODO() - db := connectDB(t, "sqlite3") - defer db.Close() + db, closer := connectDB(t, testConfig{ + driver: "sqlite3", + adapterName: "sql", + }) + defer closer.Close() rows, err := db.QueryContext(ctx, "select * from users where name='User2'") assert.Equal(t, nil, err) @@ -1940,25 +1979,41 @@ func createTables(driver string) error { return nil } -func newTestDB(db *sql.DB, driver string) DB { +func newTestDB(db DBAdapter, driver string) DB { return DB{ driver: driver, dialect: supportedDialects[driver], - db: SQLAdapter{db}, + db: db, } } -func connectDB(t *testing.T, driver string) *sql.DB { - connStr := connectionString[driver] +type NopCloser struct{} + +func (NopCloser) Close() error { return nil } + +func connectDB(t *testing.T, config testConfig) (DBAdapter, io.Closer) { + connStr := connectionString[config.driver] if connStr == "" { - panic(fmt.Sprintf("unsupported driver: '%s'", driver)) + panic(fmt.Sprintf("unsupported driver: '%s'", config.driver)) } - db, err := sql.Open(driver, connStr) - if err != nil { - t.Fatal(err.Error()) + switch config.adapterName { + case "sql": + db, err := sql.Open(config.driver, connStr) + if err != nil { + t.Fatal(err.Error()) + } + return SQLAdapter{db}, db + case "pgx": + pool, err := pgxpool.Connect(context.TODO(), connStr) + if err != nil { + t.Fatal(err.Error()) + } + return PGXAdapter{pool}, NopCloser{} } - return db + + t.Fatalf("unsupported adapter: %s", config.adapterName) + return nil, nil } func shiftErrSlice(errs *[]error) error { @@ -1967,9 +2022,7 @@ func shiftErrSlice(errs *[]error) error { return err } -func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uint) error { - db := dbi.(SQLAdapter) - +func getUsersByID(db DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uint) error { placeholders := make([]string, len(ids)) params := make([]interface{}, len(ids)) for i := range ids { @@ -1978,7 +2031,8 @@ func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uin } results := []User{} - rows, err := db.DB.Query( + rows, err := db.QueryContext( + context.TODO(), fmt.Sprintf( "SELECT id, name, age FROM users WHERE id IN (%s)", strings.Join(placeholders, ", "), @@ -2009,12 +2063,18 @@ func getUsersByID(dbi DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uin return nil } -func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error { - db := dbi.(SQLAdapter) +func getUserByID(db DBAdapter, dialect dialect, result *User, id uint) error { + rows, err := db.QueryContext(context.TODO(), `SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) + if err != nil { + return err + } + defer rows.Close() - row := db.DB.QueryRow(`SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) - if row.Err() != nil { - return row.Err() + if rows.Next() == false { + if rows.Err() != nil { + return rows.Err() + } + return sql.ErrNoRows } value := jsonSerializable{ @@ -2022,7 +2082,7 @@ func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error { Attr: &result.Address, } - err := row.Scan(&result.ID, &result.Name, &result.Age, &value) + err = rows.Scan(&result.ID, &result.Name, &result.Age, &value) if err != nil { return err } @@ -2030,16 +2090,24 @@ func getUserByID(dbi DBAdapter, dialect dialect, result *User, id uint) error { return nil } -func getUserByName(dbi DBAdapter, dialect dialect, result *User, name string) error { - db := dbi.(SQLAdapter) +func getUserByName(db DBAdapter, driver string, result *User, name string) error { + dialect := supportedDialects[driver] - row := db.DB.QueryRow(`SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name) - if row.Err() != nil { - return row.Err() + rows, err := db.QueryContext(context.TODO(), `SELECT id, name, age, address FROM users WHERE name=`+dialect.Placeholder(0), name) + if err != nil { + return err + } + defer rows.Close() + + if rows.Next() == false { + if rows.Err() != nil { + return rows.Err() + } + return sql.ErrNoRows } var rawAddr []byte - err := row.Scan(&result.ID, &result.Name, &result.Age, &rawAddr) + err = rows.Scan(&result.ID, &result.Name, &result.Age, &rawAddr) if err != nil { return err } From 37298e2c243f1ec66e88dd92ed7c4542f7820b4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 1 Aug 2021 15:04:47 -0300 Subject: [PATCH 39/40] Simplify NewWithPGX() function --- benchmark_test.go | 4 ++-- ksql.go | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 0b52dda..605ee7a 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -52,7 +52,7 @@ func BenchmarkInsert(b *testing.B) { }) }) - pgxDB, err := ksql.NewWithPGX(ctx, driver, connStr, ksql.Config{ + pgxDB, err := ksql.NewWithPGX(ctx, connStr, ksql.Config{ MaxOpenConns: 1, }) if err != nil { @@ -165,7 +165,7 @@ func BenchmarkQuery(b *testing.B) { }) }) - pgxDB, err := ksql.NewWithPGX(ctx, driver, connStr, ksql.Config{ + pgxDB, err := ksql.NewWithPGX(ctx, connStr, ksql.Config{ MaxOpenConns: 1, }) if err != nil { diff --git a/ksql.go b/ksql.go index 049fec4..eafeefe 100644 --- a/ksql.go +++ b/ksql.go @@ -103,7 +103,6 @@ func New( // library in the backend func NewWithPGX( ctx context.Context, - dbDriver string, connectionString string, config Config, ) (db DB, err error) { @@ -122,7 +121,7 @@ func NewWithPGX( return DB{}, err } - db, err = NewWithAdapter(PGXAdapter{pool}, dbDriver, connectionString) + db, err = NewWithAdapter(PGXAdapter{pool}, "postgres", connectionString) return db, err } From 6e3e558407f4b6ce906a9f0ac541bef8c35f4bed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 1 Aug 2021 15:19:03 -0300 Subject: [PATCH 40/40] Update benchmark on README --- README.md | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 00bb99a..2fe2de1 100644 --- a/README.md +++ b/README.md @@ -482,25 +482,25 @@ read the example tests available on our [example service](./examples/example_ser The benchmark is very good, the code is, in practical terms, as fast as sqlx: ```bash -$ make bench TIME=3s -go test -bench=. -benchtime=3s +$ make bench TIME=5s +go test -bench=. -benchtime=5s goos: linux goarch: amd64 pkg: github.com/vingarcia/ksql cpu: Intel(R) Core(TM) i5-3210M CPU @ 2.50GHz -BenchmarkInsert/ksql-setup/insert-one-4 4170 880911 ns/op -BenchmarkInsert/pgx-adapter-setup/insert-one-4 5780 669992 ns/op -BenchmarkInsert/sqlx-setup/insert-one-4 4389 825303 ns/op -BenchmarkQuery/ksql-setup/single-row-4 1710 306996 ns/op -BenchmarkQuery/ksql-setup/multiple-rows-4 0000 345091 ns/op -BenchmarkQuery/pgx-adapter-setup/single-row-4 5237 140154 ns/op -BenchmarkQuery/pgx-adapter-setup/multiple-rows-4 22051 164306 ns/op -BenchmarkQuery/sqlx-setup/single-row-4 11955 311654 ns/op -BenchmarkQuery/sqlx-setup/multiple-rows-4 10000 323079 ns/op +BenchmarkInsert/ksql-setup/insert-one-4 5293 960859 ns/op +BenchmarkInsert/pgx-adapter-setup/insert-one-4 7982 736973 ns/op +BenchmarkInsert/sqlx-setup/insert-one-4 6854 857824 ns/op +BenchmarkQuery/ksql-setup/single-row-4 12596 407116 ns/op +BenchmarkQuery/ksql-setup/multiple-rows-4 15883 391135 ns/op +BenchmarkQuery/pgx-adapter-setup/single-row-4 34008 165604 ns/op +BenchmarkQuery/pgx-adapter-setup/multiple-rows-4 22579 280673 ns/op +BenchmarkQuery/sqlx-setup/single-row-4 10000 512741 ns/op +BenchmarkQuery/sqlx-setup/multiple-rows-4 10779 596377 ns/op PASS -ok github.com/vingarcia/ksql 55.231s +ok github.com/vingarcia/ksql 94.951s Benchmark executed at: 2021-08-01 -Benchmark executed on commit: 87f57f665fc4bb18f3cfb543398bb96f665f68af +Benchmark executed on commit: 37298e2c243f1ec66e88dd92ed7c4542f7820b4f ``` ### Running the ksql tests (for contributors)