From c0d7206dccb36e562e3f58155e10ff2b10887391 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 6 Jun 2021 20:51:13 -0300 Subject: [PATCH] Breaking change: Update SQLProvider interface so methods receive table info as argument --- benchmark_test.go | 6 +- contracts.go | 60 ++- examples/crud/crud.go | 19 +- examples/example_service/example_service.go | 31 +- .../example_service/example_service_test.go | 48 +-- examples/example_service/mocks.go | 24 +- go.sum | 3 + ksql.go | 111 +++--- ksql_test.go | 365 +++++++++++------- mocks.go | 18 +- 10 files changed, 403 insertions(+), 282 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 8914511..a31128c 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -12,6 +12,8 @@ import ( "github.com/vingarcia/ksql" ) +var UsersTable = ksql.NewTable("users") + func BenchmarkInsert(b *testing.B) { ctx := context.Background() @@ -20,7 +22,6 @@ func BenchmarkInsert(b *testing.B) { ksqlDB, err := ksql.New(driver, connStr, ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { b.FailNow() @@ -40,7 +41,7 @@ func BenchmarkInsert(b *testing.B) { b.Run("insert-one", func(b *testing.B) { for i := 0; i < b.N; i++ { - err := ksqlDB.Insert(ctx, &User{ + err := ksqlDB.Insert(ctx, UsersTable, &User{ Name: strconv.Itoa(i), Age: i, }) @@ -92,7 +93,6 @@ func BenchmarkQuery(b *testing.B) { ksqlDB, err := ksql.New(driver, connStr, ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { b.FailNow() diff --git a/contracts.go b/contracts.go index fc52d92..26866de 100644 --- a/contracts.go +++ b/contracts.go @@ -16,9 +16,9 @@ var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be // SQLProvider describes the public behavior of this ORM type SQLProvider interface { - Insert(ctx context.Context, record interface{}) error - Update(ctx context.Context, record interface{}) error - Delete(ctx context.Context, idsOrRecords ...interface{}) error + Insert(ctx context.Context, table Table, record interface{}) error + Update(ctx context.Context, table Table, record interface{}) error + Delete(ctx context.Context, table Table, idsOrRecords ...interface{}) error Query(ctx context.Context, records interface{}, query string, params ...interface{}) error QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error @@ -28,6 +28,60 @@ type SQLProvider interface { Transaction(ctx context.Context, fn func(SQLProvider) error) error } +// Table describes the required information for inserting, updating and +// deleting entities from the database by ID using the 3 helper functions +// created for that purpose. +type Table struct { + // this name must be set in order to use the Insert, Delete and Update helper + // functions. If you only intend to make queries or to use the Exec function + // it is safe to leave this field unset. + name string + + // IDColumns defaults to []string{"id"} if unset + idColumns []string +} + +// NewTable returns a Table instance that stores +// the tablename and the names of columns used as ID, +// if no column name is passed it defaults to using +// the `"id"` column. +// +// This Table is required only for using the helper methods: +// +// - Insert +// - Update +// - Delete +// +// Passing multiple ID columns will be interpreted +// as a single composite key, if you want +// to use the helper functions with different +// keys you'll need to create multiple Table instances +// for the same database table, each with a different +// set of ID columns, but this is usually not necessary. +func NewTable(tableName string, ids ...string) Table { + if len(ids) == 0 { + ids = []string{"id"} + } + + return Table{ + name: tableName, + idColumns: ids, + } +} + +func (t Table) insertMethodFor(dialect dialect) insertMethod { + if len(t.idColumns) == 1 { + return dialect.InsertMethod() + } + + insertMethod := dialect.InsertMethod() + if insertMethod == insertWithLastInsertID { + return insertWithNoIDRetrieval + } + + return insertMethod +} + // ChunkParser stores the arguments of the QueryChunks function type ChunkParser struct { // The Query and Params are used together to build a query with diff --git a/examples/crud/crud.go b/examples/crud/crud.go index 75fa363..20e78b0 100644 --- a/examples/crud/crud.go +++ b/examples/crud/crud.go @@ -33,11 +33,14 @@ type Address struct { City string `json:"city"` } +// UsersTable informs ksql the name of the table and that it can +// use the default value for the primary key column name: "id" +var UsersTable = ksql.NewTable("users") + func main() { ctx := context.Background() db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{ MaxOpenConns: 1, - TableName: "users", }) if err != nil { panic(err.Error()) @@ -62,14 +65,14 @@ func main() { State: "MG", }, } - err = db.Insert(ctx, &alison) + err = db.Insert(ctx, UsersTable, &alison) if err != nil { panic(err.Error()) } fmt.Println("Alison ID:", alison.ID) // Inserting inline: - err = db.Insert(ctx, &User{ + err = db.Insert(ctx, UsersTable, &User{ Name: "Cristina", Age: 27, Address: Address{ @@ -81,7 +84,7 @@ func main() { } // Deleting Alison: - err = db.Delete(ctx, alison.ID) + err = db.Delete(ctx, UsersTable, alison.ID) if err != nil { panic(err.Error()) } @@ -96,12 +99,12 @@ func main() { // Updating all fields from Cristina: cris.Name = "Cris" - err = db.Update(ctx, cris) + err = db.Update(ctx, UsersTable, cris) // Changing the age of Cristina but not touching any other fields: // Partial update technique 1: - err = db.Update(ctx, struct { + err = db.Update(ctx, UsersTable, struct { ID int `ksql:"id"` Age int `ksql:"age"` }{ID: cris.ID, Age: 28}) @@ -110,7 +113,7 @@ func main() { } // Partial update technique 2: - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris.ID, Age: nullable.Int(28), }) @@ -142,7 +145,7 @@ func main() { return err } - err = db.Update(ctx, PartialUpdateUser{ + err = db.Update(ctx, UsersTable, PartialUpdateUser{ ID: cris2.ID, Age: nullable.Int(29), }) diff --git a/examples/example_service/example_service.go b/examples/example_service/example_service.go index 85bd18f..94ca156 100644 --- a/examples/example_service/example_service.go +++ b/examples/example_service/example_service.go @@ -8,11 +8,8 @@ import ( "github.com/vingarcia/ksql/nullable" ) -// Service ... -type Service struct { - usersTable ksql.SQLProvider - streamChunkSize int -} +// UsersTable informs ksql that the ID column is named "id" +var UsersTable = ksql.NewTable("users", "id") // UserEntity represents a domain user, // the pointer fields represent optional fields that @@ -41,17 +38,23 @@ type Address struct { Country string `json:"country"` } +// Service ... +type Service struct { + db ksql.SQLProvider + streamChunkSize int +} + // NewUserService ... -func NewUserService(usersTable ksql.SQLProvider) Service { +func NewUserService(db ksql.SQLProvider) Service { return Service{ - usersTable: usersTable, + db: db, streamChunkSize: 100, } } // CreateUser ... func (s Service) CreateUser(ctx context.Context, u UserEntity) error { - return s.usersTable.Insert(ctx, &u) + return s.db.Insert(ctx, UsersTable, &u) } // UpdateUserScore update the user score adding scoreChange with the current @@ -60,12 +63,12 @@ func (s Service) UpdateUserScore(ctx context.Context, uID int, scoreChange int) var scoreRow struct { Score int `ksql:"score"` } - err := s.usersTable.QueryOne(ctx, &scoreRow, "SELECT score FROM users WHERE id = ?", uID) + err := s.db.QueryOne(ctx, &scoreRow, "SELECT score FROM users WHERE id = ?", uID) if err != nil { return err } - return s.usersTable.Update(ctx, &UserEntity{ + return s.db.Update(ctx, UsersTable, &UserEntity{ ID: uID, Score: nullable.Int(scoreRow.Score + scoreChange), }) @@ -76,12 +79,12 @@ func (s Service) ListUsers(ctx context.Context, offset, limit int) (total int, u var countRow struct { Count int `ksql:"count"` } - err = s.usersTable.QueryOne(ctx, &countRow, "SELECT count(*) as count FROM users") + err = s.db.QueryOne(ctx, &countRow, "SELECT count(*) as count FROM users") if err != nil { return 0, nil, err } - return countRow.Count, users, s.usersTable.Query(ctx, &users, "SELECT * FROM users OFFSET ? LIMIT ?", offset, limit) + return countRow.Count, users, s.db.Query(ctx, &users, "SELECT * FROM users OFFSET ? LIMIT ?", offset, limit) } // StreamAllUsers sends all users from the database to an external client @@ -91,7 +94,7 @@ func (s Service) ListUsers(ctx context.Context, offset, limit int) (total int, u // function only when the ammount of data loaded might exceed the available memory and/or // when you can't put an upper limit on the number of values returned. func (s Service) StreamAllUsers(ctx context.Context, sendUser func(u UserEntity) error) error { - return s.usersTable.QueryChunks(ctx, ksql.ChunkParser{ + return s.db.QueryChunks(ctx, ksql.ChunkParser{ Query: "SELECT * FROM users", Params: []interface{}{}, ChunkSize: s.streamChunkSize, @@ -110,5 +113,5 @@ func (s Service) StreamAllUsers(ctx context.Context, sendUser func(u UserEntity) // DeleteUser deletes a user by its ID func (s Service) DeleteUser(ctx context.Context, uID int) error { - return s.usersTable.Delete(ctx, uID) + return s.db.Delete(ctx, UsersTable, uID) } diff --git a/examples/example_service/example_service_test.go b/examples/example_service/example_service_test.go index ae142f5..33b2f2c 100644 --- a/examples/example_service/example_service_test.go +++ b/examples/example_service/example_service_test.go @@ -17,16 +17,16 @@ func TestCreateUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var users []interface{} - usersTableMock.EXPECT().Insert(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, records ...interface{}) error { + mockDB.EXPECT().Insert(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error { users = append(users, records...) return nil }) @@ -43,16 +43,16 @@ func TestCreateUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var users []map[string]interface{} - usersTableMock.EXPECT().Insert(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, records ...interface{}) error { + mockDB.EXPECT().Insert(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error { for _, record := range records { // The StructToMap function will convert a struct with `ksql` tags // into a map using the ksql attr names as keys. @@ -83,16 +83,16 @@ func TestUpdateUserScore(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var users []interface{} gomock.InOrder( - usersTableMock.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockDB.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { // This function will use reflection to fill the // struct fields with the values from the map @@ -103,8 +103,8 @@ func TestUpdateUserScore(t *testing.T) { "score": 42, }) }), - usersTableMock.EXPECT().Update(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, records ...interface{}) error { + mockDB.EXPECT().Update(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error { users = append(users, records...) return nil }), @@ -127,15 +127,15 @@ func TestListUsers(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } gomock.InOrder( - usersTableMock.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockDB.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error { // This function will use reflection to fill the // struct fields with the values from the map @@ -146,7 +146,7 @@ func TestListUsers(t *testing.T) { "count": 420, }) }), - usersTableMock.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockDB.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, results interface{}, query string, params ...interface{}) error { return structs.FillSliceWith(results, []map[string]interface{}{ { @@ -189,14 +189,14 @@ func TestStreamAllUsers(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 2, } - usersTableMock.EXPECT().QueryChunks(gomock.Any(), gomock.Any()). + mockDB.EXPECT().QueryChunks(gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, parser ksql.ChunkParser) error { fn, ok := parser.ForEachChunk.(func(users []UserEntity) error) require.True(t, ok) @@ -263,16 +263,16 @@ func TestDeleteUser(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - usersTableMock := NewMockSQLProvider(controller) + mockDB := NewMockSQLProvider(controller) s := Service{ - usersTable: usersTableMock, + db: mockDB, streamChunkSize: 100, } var ids []interface{} - usersTableMock.EXPECT().Delete(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, idArgs ...interface{}) error { + mockDB.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, table ksql.Table, idArgs ...interface{}) error { ids = append(ids, idArgs...) return nil }) diff --git a/examples/example_service/mocks.go b/examples/example_service/mocks.go index 44b7baf..c170490 100644 --- a/examples/example_service/mocks.go +++ b/examples/example_service/mocks.go @@ -36,9 +36,9 @@ func (m *MockSQLProvider) EXPECT() *MockSQLProviderMockRecorder { } // Delete mocks base method. -func (m *MockSQLProvider) Delete(ctx context.Context, idsOrRecords ...interface{}) error { +func (m *MockSQLProvider) Delete(ctx context.Context, table ksql.Table, idsOrRecords ...interface{}) error { m.ctrl.T.Helper() - varargs := []interface{}{ctx} + varargs := []interface{}{ctx, table} for _, a := range idsOrRecords { varargs = append(varargs, a) } @@ -48,9 +48,9 @@ func (m *MockSQLProvider) Delete(ctx context.Context, idsOrRecords ...interface{ } // Delete indicates an expected call of Delete. -func (mr *MockSQLProviderMockRecorder) Delete(ctx interface{}, idsOrRecords ...interface{}) *gomock.Call { +func (mr *MockSQLProviderMockRecorder) Delete(ctx, table interface{}, idsOrRecords ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx}, idsOrRecords...) + varargs := append([]interface{}{ctx, table}, idsOrRecords...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSQLProvider)(nil).Delete), varargs...) } @@ -74,17 +74,17 @@ func (mr *MockSQLProviderMockRecorder) Exec(ctx, query interface{}, params ...in } // Insert mocks base method. -func (m *MockSQLProvider) Insert(ctx context.Context, record interface{}) error { +func (m *MockSQLProvider) Insert(ctx context.Context, table ksql.Table, record interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Insert", ctx, record) + ret := m.ctrl.Call(m, "Insert", ctx, table, record) ret0, _ := ret[0].(error) return ret0 } // Insert indicates an expected call of Insert. -func (mr *MockSQLProviderMockRecorder) Insert(ctx, record interface{}) *gomock.Call { +func (mr *MockSQLProviderMockRecorder) Insert(ctx, table, record interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), ctx, record) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), ctx, table, record) } // Query mocks base method. @@ -154,15 +154,15 @@ func (mr *MockSQLProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock. } // Update mocks base method. -func (m *MockSQLProvider) Update(ctx context.Context, record interface{}) error { +func (m *MockSQLProvider) Update(ctx context.Context, table ksql.Table, record interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Update", ctx, record) + ret := m.ctrl.Call(m, "Update", ctx, table, record) ret0, _ := ret[0].(error) return ret0 } // Update indicates an expected call of Update. -func (mr *MockSQLProviderMockRecorder) Update(ctx, record interface{}) *gomock.Call { +func (mr *MockSQLProviderMockRecorder) Update(ctx, table, record interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), ctx, record) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), ctx, table, record) } diff --git a/go.sum b/go.sum index b80fabb..67d781f 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -45,8 +46,10 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e h1:aZzprAO9/8oim3qStq3wc1Xuxx4QmAGriC4VU4ojemQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= diff --git a/ksql.go b/ksql.go index 5ada2ef..85e6e2a 100644 --- a/ksql.go +++ b/ksql.go @@ -24,16 +24,9 @@ func init() { // interfacing with the "database/sql" package implementing // the KissSQL interface `SQLProvider`. type DB struct { - driver string - dialect dialect - tableName string - db sqlProvider - - // Most dbs have a single primary key, - // But in future ksql should work with compound keys as well - idCols []string - - insertMethod insertMethod + driver string + dialect dialect + db sqlProvider } type sqlProvider interface { @@ -46,14 +39,6 @@ type sqlProvider interface { type Config struct { // MaxOpenCons defaults to 1 if not set MaxOpenConns int - - // TableName must be set in order to use the Insert, Delete and Update helper - // functions. If you only intend to make queries or to use the Exec function - // it is safe to leave this field unset. - TableName string - - // IDColumns defaults to []string{"id"} if unset - IDColumns []string } // New instantiates a new KissSQL client @@ -81,23 +66,10 @@ func New( return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver) } - if len(config.IDColumns) == 0 { - config.IDColumns = []string{"id"} - } - - insertMethod := dialect.InsertMethod() - if len(config.IDColumns) > 1 && insertMethod == insertWithLastInsertID { - insertMethod = insertWithNoIDRetrieval - } - return DB{ - dialect: dialect, - driver: dbDriver, - db: db, - tableName: config.TableName, - - idCols: config.IDColumns, - insertMethod: insertMethod, + dialect: dialect, + driver: dbDriver, + db: db, }, nil } @@ -133,8 +105,16 @@ func (c DB) Query( slice = slice.Slice(0, 0) } - if strings.ToUpper(getFirstToken(query)) == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()]) + info := structs.GetTagInfo(structType) + + firstToken := strings.ToUpper(getFirstToken(query)) + if info.IsNestedStruct && firstToken == "SELECT" { + // This error check is necessary, since if we can't build the select part of the query this feature won't work. + return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query") + } + + if firstToken == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -206,8 +186,16 @@ func (c DB) QueryOne( return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) } - if strings.ToUpper(getFirstToken(query)) == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, t, selectQueryCache[c.dialect.DriverName()]) + info := structs.GetTagInfo(t) + + firstToken := strings.ToUpper(getFirstToken(query)) + if info.IsNestedStruct && firstToken == "SELECT" { + // This error check is necessary, since if we can't build the select part of the query this feature won't work. + return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query") + } + + if firstToken == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, t, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -268,8 +256,16 @@ func (c DB) QueryChunks( return err } - if strings.ToUpper(getFirstToken(parser.Query)) == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()]) + info := structs.GetTagInfo(structType) + + firstToken := strings.ToUpper(getFirstToken(parser.Query)) + if info.IsNestedStruct && firstToken == "SELECT" { + // This error check is necessary, since if we can't build the select part of the query this feature won't work. + return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query") + } + + if firstToken == "FROM" { + selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -347,22 +343,19 @@ func (c DB) QueryChunks( // the ID is automatically updated after insertion is completed. func (c DB) Insert( ctx context.Context, + table Table, record interface{}, ) error { - if c.tableName == "" { - return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Insert method") - } - - query, params, scanValues, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...) + query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, record, table.idColumns...) if err != nil { return err } - switch c.insertMethod { + switch table.insertMethodFor(c.dialect) { case insertWithReturning, insertWithOutput: - err = c.insertReturningIDs(ctx, record, query, params, scanValues, c.idCols) + err = c.insertReturningIDs(ctx, record, query, params, scanValues, table.idColumns) case insertWithLastInsertID: - err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0]) + err = c.insertWithLastInsertID(ctx, record, query, params, table.idColumns[0]) case insertWithNoIDRetrieval: err = c.insertWithNoIDRetrieval(ctx, record, query, params) default: @@ -471,27 +464,24 @@ func assertStructPtr(t reflect.Type) error { // Delete deletes one or more instances from the database by id func (c DB) Delete( ctx context.Context, + table Table, ids ...interface{}, ) error { - if c.tableName == "" { - return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Delete method") - } - if len(ids) == 0 { return nil } - idMaps, err := normalizeIDsAsMaps(c.idCols, ids) + idMaps, err := normalizeIDsAsMaps(table.idColumns, ids) if err != nil { return err } var query string var params []interface{} - if len(c.idCols) == 1 { - query, params = buildSingleKeyDeleteQuery(c.dialect, c.tableName, c.idCols[0], idMaps) + if len(table.idColumns) == 1 { + query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps) } else { - query, params = buildCompositeKeyDeleteQuery(c.dialect, c.tableName, c.idCols, idMaps) + query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps) } _, err = c.db.ExecContext(ctx, query, params...) @@ -543,13 +533,10 @@ func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]inter // Partial updates are supported, i.e. it will ignore nil pointer attributes func (c DB) Update( ctx context.Context, + table Table, record interface{}, ) error { - if c.tableName == "" { - return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Update method") - } - - query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, c.idCols...) + query, params, err := buildUpdateQuery(c.dialect, table.name, record, table.idColumns...) if err != nil { return err } @@ -964,13 +951,13 @@ func getFirstToken(s string) string { func buildSelectQuery( dialect dialect, structType reflect.Type, + info structs.StructInfo, selectQueryCache map[reflect.Type]string, ) (query string, err error) { if selectQuery, found := selectQueryCache[structType]; found { return selectQuery, nil } - info := structs.GetTagInfo(structType) if info.IsNestedStruct { query, err = buildSelectQueryForNestedStructs(dialect, structType, info) if err != nil { diff --git a/ksql_test.go b/ksql_test.go index 95d4461..3eea91a 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -25,6 +25,8 @@ type User struct { Address Address `ksql:"address,json"` } +var UsersTable = NewTable("users") + type Address struct { Street string `json:"street"` Number string `json:"number"` @@ -40,6 +42,8 @@ type Post struct { Title string `ksql:"title"` } +var PostsTable = NewTable("posts") + func TestQuery(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { @@ -69,7 +73,7 @@ func TestQuery(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []User err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) assert.Equal(t, nil, err) @@ -89,7 +93,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") @@ -111,7 +115,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") @@ -154,7 +158,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var rows []struct { User User `tablename:"u"` Post Post `tablename:"p"` @@ -193,7 +197,7 @@ func TestQuery(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []*User err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`) assert.Equal(t, nil, err) @@ -213,7 +217,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []*User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") @@ -235,7 +239,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []*User err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia") @@ -278,7 +282,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var rows []*struct { User User `tablename:"u"` Post Post `tablename:"p"` @@ -325,7 +329,7 @@ func TestQuery(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) err = c.Query(ctx, &User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) @@ -345,56 +349,72 @@ func TestQuery(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var users []User - err = c.Query(ctx, &users, `SELECT * FROM not a valid query`) + err := c.Query(ctx, &users, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) - }) - t.Run("should report error for nested structs with invalid types", func(t *testing.T) { - t.Run("int", func(t *testing.T) { + t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { db := connectDB(t, driver) defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var rows []struct { - Foo int `tablename:"foo"` + User User `tablename:"users"` + Post Post `tablename:"posts"` } - err := c.Query(ctx, &rows, fmt.Sprint( - `FROM users u JOIN posts p ON p.user_id = u.id`, - ` WHERE u.name like `, c.dialect.Placeholder(0), - ` ORDER BY u.id, p.id`, - ), "% Ribeiro") - + err := c.Query(ctx, &rows, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`) assert.NotEqual(t, nil, err) - msg := err.Error() - for _, str := range []string{"foo", "int"} { - assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) - } + assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error()) + assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error()) }) - t.Run("*struct", func(t *testing.T) { - db := connectDB(t, driver) - defer db.Close() + t.Run("should report error for nested structs with invalid types", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() - ctx := context.Background() - c := newTestDB(db, driver, "users") - var rows []struct { - Foo *User `tablename:"foo"` - } - err := c.Query(ctx, &rows, fmt.Sprint( - `FROM users u JOIN posts p ON p.user_id = u.id`, - ` WHERE u.name like `, c.dialect.Placeholder(0), - ` ORDER BY u.id, p.id`, - ), "% Ribeiro") + ctx := context.Background() + c := newTestDB(db, driver) + var rows []struct { + Foo int `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") - assert.NotEqual(t, nil, err) - msg := err.Error() - for _, str := range []string{"foo", "*ksql.User"} { - assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) - } + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "int"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) + + t.Run("*struct", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + var rows []struct { + Foo *User `tablename:"foo"` + } + err := c.Query(ctx, &rows, fmt.Sprint( + `FROM users u JOIN posts p ON p.user_id = u.id`, + ` WHERE u.name like `, c.dialect.Placeholder(0), + ` ORDER BY u.id, p.id`, + ), "% Ribeiro") + + assert.NotEqual(t, nil, err) + msg := err.Error() + for _, str := range []string{"foo", "*ksql.User"} { + assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg)) + } + }) }) }) }) @@ -429,7 +449,7 @@ func TestQueryOne(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{} err := c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE id=1;`) assert.Equal(t, ErrRecordNotFound, err) @@ -443,7 +463,7 @@ func TestQueryOne(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{} err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia") @@ -466,7 +486,7 @@ func TestQueryOne(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var u User err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá") @@ -496,7 +516,7 @@ func TestQueryOne(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var row struct { User User `tablename:"u"` Post Post `tablename:"p"` @@ -526,7 +546,7 @@ func TestQueryOne(t *testing.T) { assert.Equal(t, nil, err) ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) err = c.QueryOne(ctx, &[]User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá") assert.NotEqual(t, nil, err) @@ -540,11 +560,27 @@ func TestQueryOne(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) var user User err := c.QueryOne(ctx, &user, `SELECT * FROM not a valid query`) assert.NotEqual(t, nil, err) }) + + t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + var row struct { + User User `tablename:"users"` + Post Post `tablename:"posts"` + } + err := c.QueryOne(ctx, &row, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id LIMIT 1`) + assert.NotEqual(t, nil, err) + assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error()) + assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error()) + }) }) } } @@ -563,7 +599,7 @@ func TestInsert(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{ Name: "Fernanda", @@ -572,7 +608,7 @@ func TestInsert(t *testing.T) { }, } - err := c.Insert(ctx, &u) + err := c.Insert(ctx, UsersTable, &u) assert.Equal(t, nil, err) assert.NotEqual(t, 0, u.ID) @@ -593,11 +629,11 @@ func TestInsert(t *testing.T) { defer db.Close() ctx := context.Background() + // Using columns "id" and "name" as IDs: - c, err := New(driver, connectionString[driver], Config{ - TableName: "users", - IDColumns: []string{"id", "name"}, - }) + table := NewTable("users", "id", "name") + + c, err := New(driver, connectionString[driver], Config{}) assert.Equal(t, nil, err) u := User{ @@ -609,7 +645,7 @@ func TestInsert(t *testing.T) { }, } - err = c.Insert(ctx, &u) + err = c.Insert(ctx, table, &u) assert.Equal(t, nil, err) assert.Equal(t, uint(0), u.ID) @@ -633,15 +669,15 @@ func TestInsert(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - err = c.Insert(ctx, "foo") + err = c.Insert(ctx, UsersTable, "foo") assert.NotEqual(t, nil, err) - err = c.Insert(ctx, nullable.String("foo")) + err = c.Insert(ctx, UsersTable, nullable.String("foo")) assert.NotEqual(t, nil, err) - err = c.Insert(ctx, map[string]interface{}{ + err = c.Insert(ctx, UsersTable, map[string]interface{}{ "name": "foo", "age": 12, }) @@ -651,11 +687,11 @@ func TestInsert(t *testing.T) { &User{Name: "foo", Age: 22}, &User{Name: "bar", Age: 32}, } - err = c.Insert(ctx, ifUserForgetToExpandList) + err = c.Insert(ctx, UsersTable, ifUserForgetToExpandList) assert.NotEqual(t, nil, err) // We might want to support this in the future, but not for now: - err = c.Insert(ctx, User{Name: "not a ptr to user", Age: 42}) + err = c.Insert(ctx, UsersTable, User{Name: "not a ptr to user", Age: 42}) assert.NotEqual(t, nil, err) }) @@ -664,12 +700,12 @@ func TestInsert(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) // This is an invalid value: - c.insertMethod = insertMethod(42) + c.dialect = brokenDialect{} - err = c.Insert(ctx, &User{Name: "foo"}) + err = c.Insert(ctx, UsersTable, &User{Name: "foo"}) assert.NotEqual(t, nil, err) }) }) @@ -677,6 +713,24 @@ func TestInsert(t *testing.T) { } } +type brokenDialect struct{} + +func (brokenDialect) InsertMethod() insertMethod { + return insertMethod(42) +} + +func (brokenDialect) Escape(str string) string { + return str +} + +func (brokenDialect) Placeholder(idx int) string { + return "?" +} + +func (brokenDialect) DriverName() string { + return "fake-driver-name" +} + func TestDelete(t *testing.T) { for driver := range supportedDialects { t.Run(driver, func(t *testing.T) { @@ -690,13 +744,13 @@ func TestDelete(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{ Name: "Won't be deleted", } - err := c.Insert(ctx, &u) + err := c.Insert(ctx, UsersTable, &u) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) @@ -706,7 +760,7 @@ func TestDelete(t *testing.T) { assert.Equal(t, u.ID, result.ID) - err = c.Delete(ctx) + err = c.Delete(ctx, UsersTable) assert.Equal(t, nil, err) result = User{} @@ -720,13 +774,13 @@ func TestDelete(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u1 := User{ Name: "Fernanda", } - err := c.Insert(ctx, &u1) + err := c.Insert(ctx, UsersTable, &u1) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u1.ID) @@ -739,7 +793,7 @@ func TestDelete(t *testing.T) { Name: "Won't be deleted", } - err = c.Insert(ctx, &u2) + err = c.Insert(ctx, UsersTable, &u2) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u2.ID) @@ -748,7 +802,7 @@ func TestDelete(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, u2.ID, result.ID) - err = c.Delete(ctx, u1.ID) + err = c.Delete(ctx, UsersTable, u1.ID) assert.Equal(t, nil, err) result = User{} @@ -768,26 +822,26 @@ func TestDelete(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u1 := User{ Name: "Fernanda", } - err := c.Insert(ctx, &u1) + err := c.Insert(ctx, UsersTable, &u1) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u1.ID) u2 := User{ Name: "Juliano", } - err = c.Insert(ctx, &u2) + err = c.Insert(ctx, UsersTable, &u2) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u2.ID) u3 := User{ Name: "This won't be deleted", } - err = c.Insert(ctx, &u3) + err = c.Insert(ctx, UsersTable, &u3) assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u3.ID) @@ -806,7 +860,7 @@ func TestDelete(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, u3.ID, result.ID) - err = c.Delete(ctx, u1.ID, u2.ID) + err = c.Delete(ctx, UsersTable, u1.ID, u2.ID) assert.Equal(t, nil, err) results := []User{} @@ -833,7 +887,7 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{ Name: "Letícia", @@ -847,7 +901,7 @@ func TestUpdate(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) - err = c.Update(ctx, User{ + err = c.Update(ctx, UsersTable, User{ ID: u.ID, Name: "Thayane", }) @@ -864,7 +918,7 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u := User{ Name: "Letícia", @@ -878,7 +932,7 @@ func TestUpdate(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) - err = c.Update(ctx, User{ + err = c.Update(ctx, UsersTable, User{ ID: u.ID, Name: "Thayane", }) @@ -895,7 +949,7 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) type partialUser struct { ID uint `ksql:"id"` @@ -915,7 +969,7 @@ func TestUpdate(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, uint(0), u.ID) - err = c.Update(ctx, partialUser{ + err = c.Update(ctx, UsersTable, partialUser{ ID: u.ID, // Should be updated because it is not null, just empty: Name: "", @@ -936,7 +990,7 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) type partialUser struct { ID uint `ksql:"id"` @@ -957,7 +1011,7 @@ func TestUpdate(t *testing.T) { assert.NotEqual(t, uint(0), u.ID) // Should update all fields: - err = c.Update(ctx, partialUser{ + err = c.Update(ctx, UsersTable, partialUser{ ID: u.ID, Name: "Thay", Age: nullable.Int(42), @@ -977,9 +1031,9 @@ func TestUpdate(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "non_existing_table") + c := newTestDB(db, driver) - err = c.Update(ctx, User{ + err = c.Update(ctx, NewTable("non_existing_table"), User{ ID: 1, Name: "Thayane", }) @@ -1017,9 +1071,9 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{ + _ = c.Insert(ctx, UsersTable, &User{ Name: "User1", Address: Address{Country: "BR"}, }) @@ -1057,10 +1111,10 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) - _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}}) var lengths []int var users []User @@ -1099,10 +1153,10 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}}) - _ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}}) var lengths []int var users []User @@ -1141,11 +1195,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) var lengths []int var users []User @@ -1192,9 +1246,9 @@ func TestQueryChunks(t *testing.T) { } ctx := context.Background() - c := newTestDB(db, driver, "users") - _ = c.Insert(ctx, &joao) - _ = c.Insert(ctx, &thatiana) + c := newTestDB(db, driver) + _ = c.Insert(ctx, UsersTable, &joao) + _ = c.Insert(ctx, UsersTable, &thatiana) _, err := db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post1')`)) assert.Equal(t, nil, err) @@ -1254,11 +1308,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) var lengths []int var users []User @@ -1293,11 +1347,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) returnVals := []error{nil, ErrAbortIteration} var lengths []int @@ -1336,11 +1390,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) var lengths []int var users []User @@ -1375,11 +1429,11 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) - _ = c.Insert(ctx, &User{Name: "User3"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3"}) returnVals := []error{nil, errors.New("fake error msg")} var lengths []int @@ -1413,7 +1467,7 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) funcs := []interface{}{ nil, @@ -1458,7 +1512,7 @@ func TestQueryChunks(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) err := c.QueryChunks(ctx, ChunkParser{ Query: `SELECT * FROM not a valid query`, Params: []interface{}{}, @@ -1470,6 +1524,31 @@ func TestQueryChunks(t *testing.T) { }) assert.NotEqual(t, nil, err) }) + + t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) { + db := connectDB(t, driver) + defer db.Close() + + ctx := context.Background() + c := newTestDB(db, driver) + + err := c.QueryChunks(ctx, ChunkParser{ + Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`, + Params: []interface{}{}, + + ChunkSize: 2, + ForEachChunk: func(buffer []struct { + User User `tablename:"users"` + Post Post `tablename:"posts"` + }) error { + return nil + }, + }) + + assert.NotEqual(t, nil, err) + assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error()) + assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error()) + }) }) } }) @@ -1489,10 +1568,10 @@ func TestTransaction(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) - _ = c.Insert(ctx, &User{Name: "User1"}) - _ = c.Insert(ctx, &User{Name: "User2"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User1"}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2"}) var users []User err = c.Transaction(ctx, func(db SQLProvider) error { @@ -1516,17 +1595,17 @@ func TestTransaction(t *testing.T) { defer db.Close() ctx := context.Background() - c := newTestDB(db, driver, "users") + c := newTestDB(db, driver) u1 := User{Name: "User1", Age: 42} u2 := User{Name: "User2", Age: 42} - _ = c.Insert(ctx, &u1) - _ = c.Insert(ctx, &u2) + _ = c.Insert(ctx, UsersTable, &u1) + _ = c.Insert(ctx, UsersTable, &u2) err = c.Transaction(ctx, func(db SQLProvider) error { - err = db.Insert(ctx, &User{Name: "User3"}) + err = db.Insert(ctx, UsersTable, &User{Name: "User3"}) assert.Equal(t, nil, err) - err = db.Insert(ctx, &User{Name: "User4"}) + err = db.Insert(ctx, UsersTable, &User{Name: "User4"}) assert.Equal(t, nil, err) err = db.Exec(ctx, "UPDATE users SET age = 22") assert.Equal(t, nil, err) @@ -1557,10 +1636,10 @@ func TestScanRows(t *testing.T) { ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() - c := newTestDB(db, "sqlite3", "users") - _ = c.Insert(ctx, &User{Name: "User1", Age: 22}) - _ = c.Insert(ctx, &User{Name: "User2", Age: 14}) - _ = c.Insert(ctx, &User{Name: "User3", Age: 43}) + c := newTestDB(db, "sqlite3") + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User2", Age: 14}) + _ = c.Insert(ctx, UsersTable, &User{Name: "User3", Age: 43}) rows, err := db.QueryContext(ctx, "select * from users where name='User2'") assert.Equal(t, nil, err) @@ -1586,8 +1665,8 @@ func TestScanRows(t *testing.T) { ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() - c := newTestDB(db, "sqlite3", "users") - _ = c.Insert(ctx, &User{Name: "User1", Age: 22}) + c := newTestDB(db, "sqlite3") + _ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22}) rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User1'") assert.Equal(t, nil, err) @@ -1758,19 +1837,11 @@ func createTables(driver string) error { return nil } -func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB { - if len(ids) == 0 { - ids = []string{"id"} - } - +func newTestDB(db *sql.DB, driver string) DB { return DB{ - driver: driver, - dialect: supportedDialects[driver], - db: db, - tableName: tableName, - - idCols: ids, - insertMethod: supportedDialects[driver].InsertMethod(), + driver: driver, + dialect: supportedDialects[driver], + db: db, } } diff --git a/mocks.go b/mocks.go index eab1770..8fc7321 100644 --- a/mocks.go +++ b/mocks.go @@ -6,9 +6,9 @@ var _ SQLProvider = MockSQLProvider{} // MockSQLProvider ... type MockSQLProvider struct { - InsertFn func(ctx context.Context, record interface{}) error - UpdateFn func(ctx context.Context, record interface{}) error - DeleteFn func(ctx context.Context, ids ...interface{}) error + InsertFn func(ctx context.Context, table Table, record interface{}) error + UpdateFn func(ctx context.Context, table Table, record interface{}) error + DeleteFn func(ctx context.Context, table Table, ids ...interface{}) error QueryFn func(ctx context.Context, records interface{}, query string, params ...interface{}) error QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error @@ -19,18 +19,18 @@ type MockSQLProvider struct { } // Insert ... -func (m MockSQLProvider) Insert(ctx context.Context, record interface{}) error { - return m.InsertFn(ctx, record) +func (m MockSQLProvider) Insert(ctx context.Context, table Table, record interface{}) error { + return m.InsertFn(ctx, table, record) } // Update ... -func (m MockSQLProvider) Update(ctx context.Context, record interface{}) error { - return m.UpdateFn(ctx, record) +func (m MockSQLProvider) Update(ctx context.Context, table Table, record interface{}) error { + return m.UpdateFn(ctx, table, record) } // Delete ... -func (m MockSQLProvider) Delete(ctx context.Context, ids ...interface{}) error { - return m.DeleteFn(ctx, ids...) +func (m MockSQLProvider) Delete(ctx context.Context, table Table, ids ...interface{}) error { + return m.DeleteFn(ctx, table, ids...) } // Query ...