Breaking change: Update SQLProvider interface so methods receive table info as argument

pull/2/head
Vinícius Garcia 2021-06-06 20:51:13 -03:00
parent 54f5b7b1eb
commit c0d7206dcc
10 changed files with 403 additions and 282 deletions

View File

@ -12,6 +12,8 @@ import (
"github.com/vingarcia/ksql"
)
var UsersTable = ksql.NewTable("users")
func BenchmarkInsert(b *testing.B) {
ctx := context.Background()
@ -20,7 +22,6 @@ func BenchmarkInsert(b *testing.B) {
ksqlDB, err := ksql.New(driver, connStr, ksql.Config{
MaxOpenConns: 1,
TableName: "users",
})
if err != nil {
b.FailNow()
@ -40,7 +41,7 @@ func BenchmarkInsert(b *testing.B) {
b.Run("insert-one", func(b *testing.B) {
for i := 0; i < b.N; i++ {
err := ksqlDB.Insert(ctx, &User{
err := ksqlDB.Insert(ctx, UsersTable, &User{
Name: strconv.Itoa(i),
Age: i,
})
@ -92,7 +93,6 @@ func BenchmarkQuery(b *testing.B) {
ksqlDB, err := ksql.New(driver, connStr, ksql.Config{
MaxOpenConns: 1,
TableName: "users",
})
if err != nil {
b.FailNow()

View File

@ -16,9 +16,9 @@ var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be
// SQLProvider describes the public behavior of this ORM
type SQLProvider interface {
Insert(ctx context.Context, record interface{}) error
Update(ctx context.Context, record interface{}) error
Delete(ctx context.Context, idsOrRecords ...interface{}) error
Insert(ctx context.Context, table Table, record interface{}) error
Update(ctx context.Context, table Table, record interface{}) error
Delete(ctx context.Context, table Table, idsOrRecords ...interface{}) error
Query(ctx context.Context, records interface{}, query string, params ...interface{}) error
QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error
@ -28,6 +28,60 @@ type SQLProvider interface {
Transaction(ctx context.Context, fn func(SQLProvider) error) error
}
// Table describes the required information for inserting, updating and
// deleting entities from the database by ID using the 3 helper functions
// created for that purpose.
type Table struct {
// this name must be set in order to use the Insert, Delete and Update helper
// functions. If you only intend to make queries or to use the Exec function
// it is safe to leave this field unset.
name string
// IDColumns defaults to []string{"id"} if unset
idColumns []string
}
// NewTable returns a Table instance that stores
// the tablename and the names of columns used as ID,
// if no column name is passed it defaults to using
// the `"id"` column.
//
// This Table is required only for using the helper methods:
//
// - Insert
// - Update
// - Delete
//
// Passing multiple ID columns will be interpreted
// as a single composite key, if you want
// to use the helper functions with different
// keys you'll need to create multiple Table instances
// for the same database table, each with a different
// set of ID columns, but this is usually not necessary.
func NewTable(tableName string, ids ...string) Table {
if len(ids) == 0 {
ids = []string{"id"}
}
return Table{
name: tableName,
idColumns: ids,
}
}
func (t Table) insertMethodFor(dialect dialect) insertMethod {
if len(t.idColumns) == 1 {
return dialect.InsertMethod()
}
insertMethod := dialect.InsertMethod()
if insertMethod == insertWithLastInsertID {
return insertWithNoIDRetrieval
}
return insertMethod
}
// ChunkParser stores the arguments of the QueryChunks function
type ChunkParser struct {
// The Query and Params are used together to build a query with

View File

@ -33,11 +33,14 @@ type Address struct {
City string `json:"city"`
}
// UsersTable informs ksql the name of the table and that it can
// use the default value for the primary key column name: "id"
var UsersTable = ksql.NewTable("users")
func main() {
ctx := context.Background()
db, err := ksql.New("sqlite3", "/tmp/hello.sqlite", ksql.Config{
MaxOpenConns: 1,
TableName: "users",
})
if err != nil {
panic(err.Error())
@ -62,14 +65,14 @@ func main() {
State: "MG",
},
}
err = db.Insert(ctx, &alison)
err = db.Insert(ctx, UsersTable, &alison)
if err != nil {
panic(err.Error())
}
fmt.Println("Alison ID:", alison.ID)
// Inserting inline:
err = db.Insert(ctx, &User{
err = db.Insert(ctx, UsersTable, &User{
Name: "Cristina",
Age: 27,
Address: Address{
@ -81,7 +84,7 @@ func main() {
}
// Deleting Alison:
err = db.Delete(ctx, alison.ID)
err = db.Delete(ctx, UsersTable, alison.ID)
if err != nil {
panic(err.Error())
}
@ -96,12 +99,12 @@ func main() {
// Updating all fields from Cristina:
cris.Name = "Cris"
err = db.Update(ctx, cris)
err = db.Update(ctx, UsersTable, cris)
// Changing the age of Cristina but not touching any other fields:
// Partial update technique 1:
err = db.Update(ctx, struct {
err = db.Update(ctx, UsersTable, struct {
ID int `ksql:"id"`
Age int `ksql:"age"`
}{ID: cris.ID, Age: 28})
@ -110,7 +113,7 @@ func main() {
}
// Partial update technique 2:
err = db.Update(ctx, PartialUpdateUser{
err = db.Update(ctx, UsersTable, PartialUpdateUser{
ID: cris.ID,
Age: nullable.Int(28),
})
@ -142,7 +145,7 @@ func main() {
return err
}
err = db.Update(ctx, PartialUpdateUser{
err = db.Update(ctx, UsersTable, PartialUpdateUser{
ID: cris2.ID,
Age: nullable.Int(29),
})

View File

@ -8,11 +8,8 @@ import (
"github.com/vingarcia/ksql/nullable"
)
// Service ...
type Service struct {
usersTable ksql.SQLProvider
streamChunkSize int
}
// UsersTable informs ksql that the ID column is named "id"
var UsersTable = ksql.NewTable("users", "id")
// UserEntity represents a domain user,
// the pointer fields represent optional fields that
@ -41,17 +38,23 @@ type Address struct {
Country string `json:"country"`
}
// Service ...
type Service struct {
db ksql.SQLProvider
streamChunkSize int
}
// NewUserService ...
func NewUserService(usersTable ksql.SQLProvider) Service {
func NewUserService(db ksql.SQLProvider) Service {
return Service{
usersTable: usersTable,
db: db,
streamChunkSize: 100,
}
}
// CreateUser ...
func (s Service) CreateUser(ctx context.Context, u UserEntity) error {
return s.usersTable.Insert(ctx, &u)
return s.db.Insert(ctx, UsersTable, &u)
}
// UpdateUserScore update the user score adding scoreChange with the current
@ -60,12 +63,12 @@ func (s Service) UpdateUserScore(ctx context.Context, uID int, scoreChange int)
var scoreRow struct {
Score int `ksql:"score"`
}
err := s.usersTable.QueryOne(ctx, &scoreRow, "SELECT score FROM users WHERE id = ?", uID)
err := s.db.QueryOne(ctx, &scoreRow, "SELECT score FROM users WHERE id = ?", uID)
if err != nil {
return err
}
return s.usersTable.Update(ctx, &UserEntity{
return s.db.Update(ctx, UsersTable, &UserEntity{
ID: uID,
Score: nullable.Int(scoreRow.Score + scoreChange),
})
@ -76,12 +79,12 @@ func (s Service) ListUsers(ctx context.Context, offset, limit int) (total int, u
var countRow struct {
Count int `ksql:"count"`
}
err = s.usersTable.QueryOne(ctx, &countRow, "SELECT count(*) as count FROM users")
err = s.db.QueryOne(ctx, &countRow, "SELECT count(*) as count FROM users")
if err != nil {
return 0, nil, err
}
return countRow.Count, users, s.usersTable.Query(ctx, &users, "SELECT * FROM users OFFSET ? LIMIT ?", offset, limit)
return countRow.Count, users, s.db.Query(ctx, &users, "SELECT * FROM users OFFSET ? LIMIT ?", offset, limit)
}
// StreamAllUsers sends all users from the database to an external client
@ -91,7 +94,7 @@ func (s Service) ListUsers(ctx context.Context, offset, limit int) (total int, u
// function only when the ammount of data loaded might exceed the available memory and/or
// when you can't put an upper limit on the number of values returned.
func (s Service) StreamAllUsers(ctx context.Context, sendUser func(u UserEntity) error) error {
return s.usersTable.QueryChunks(ctx, ksql.ChunkParser{
return s.db.QueryChunks(ctx, ksql.ChunkParser{
Query: "SELECT * FROM users",
Params: []interface{}{},
ChunkSize: s.streamChunkSize,
@ -110,5 +113,5 @@ func (s Service) StreamAllUsers(ctx context.Context, sendUser func(u UserEntity)
// DeleteUser deletes a user by its ID
func (s Service) DeleteUser(ctx context.Context, uID int) error {
return s.usersTable.Delete(ctx, uID)
return s.db.Delete(ctx, UsersTable, uID)
}

View File

@ -17,16 +17,16 @@ func TestCreateUser(t *testing.T) {
controller := gomock.NewController(t)
defer controller.Finish()
usersTableMock := NewMockSQLProvider(controller)
mockDB := NewMockSQLProvider(controller)
s := Service{
usersTable: usersTableMock,
db: mockDB,
streamChunkSize: 100,
}
var users []interface{}
usersTableMock.EXPECT().Insert(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, records ...interface{}) error {
mockDB.EXPECT().Insert(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error {
users = append(users, records...)
return nil
})
@ -43,16 +43,16 @@ func TestCreateUser(t *testing.T) {
controller := gomock.NewController(t)
defer controller.Finish()
usersTableMock := NewMockSQLProvider(controller)
mockDB := NewMockSQLProvider(controller)
s := Service{
usersTable: usersTableMock,
db: mockDB,
streamChunkSize: 100,
}
var users []map[string]interface{}
usersTableMock.EXPECT().Insert(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, records ...interface{}) error {
mockDB.EXPECT().Insert(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error {
for _, record := range records {
// The StructToMap function will convert a struct with `ksql` tags
// into a map using the ksql attr names as keys.
@ -83,16 +83,16 @@ func TestUpdateUserScore(t *testing.T) {
controller := gomock.NewController(t)
defer controller.Finish()
usersTableMock := NewMockSQLProvider(controller)
mockDB := NewMockSQLProvider(controller)
s := Service{
usersTable: usersTableMock,
db: mockDB,
streamChunkSize: 100,
}
var users []interface{}
gomock.InOrder(
usersTableMock.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
mockDB.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error {
// This function will use reflection to fill the
// struct fields with the values from the map
@ -103,8 +103,8 @@ func TestUpdateUserScore(t *testing.T) {
"score": 42,
})
}),
usersTableMock.EXPECT().Update(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, records ...interface{}) error {
mockDB.EXPECT().Update(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, table ksql.Table, records ...interface{}) error {
users = append(users, records...)
return nil
}),
@ -127,15 +127,15 @@ func TestListUsers(t *testing.T) {
controller := gomock.NewController(t)
defer controller.Finish()
usersTableMock := NewMockSQLProvider(controller)
mockDB := NewMockSQLProvider(controller)
s := Service{
usersTable: usersTableMock,
db: mockDB,
streamChunkSize: 100,
}
gomock.InOrder(
usersTableMock.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
mockDB.EXPECT().QueryOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, result interface{}, query string, params ...interface{}) error {
// This function will use reflection to fill the
// struct fields with the values from the map
@ -146,7 +146,7 @@ func TestListUsers(t *testing.T) {
"count": 420,
})
}),
usersTableMock.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
mockDB.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, results interface{}, query string, params ...interface{}) error {
return structs.FillSliceWith(results, []map[string]interface{}{
{
@ -189,14 +189,14 @@ func TestStreamAllUsers(t *testing.T) {
controller := gomock.NewController(t)
defer controller.Finish()
usersTableMock := NewMockSQLProvider(controller)
mockDB := NewMockSQLProvider(controller)
s := Service{
usersTable: usersTableMock,
db: mockDB,
streamChunkSize: 2,
}
usersTableMock.EXPECT().QueryChunks(gomock.Any(), gomock.Any()).
mockDB.EXPECT().QueryChunks(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, parser ksql.ChunkParser) error {
fn, ok := parser.ForEachChunk.(func(users []UserEntity) error)
require.True(t, ok)
@ -263,16 +263,16 @@ func TestDeleteUser(t *testing.T) {
controller := gomock.NewController(t)
defer controller.Finish()
usersTableMock := NewMockSQLProvider(controller)
mockDB := NewMockSQLProvider(controller)
s := Service{
usersTable: usersTableMock,
db: mockDB,
streamChunkSize: 100,
}
var ids []interface{}
usersTableMock.EXPECT().Delete(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, idArgs ...interface{}) error {
mockDB.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, table ksql.Table, idArgs ...interface{}) error {
ids = append(ids, idArgs...)
return nil
})

View File

@ -36,9 +36,9 @@ func (m *MockSQLProvider) EXPECT() *MockSQLProviderMockRecorder {
}
// Delete mocks base method.
func (m *MockSQLProvider) Delete(ctx context.Context, idsOrRecords ...interface{}) error {
func (m *MockSQLProvider) Delete(ctx context.Context, table ksql.Table, idsOrRecords ...interface{}) error {
m.ctrl.T.Helper()
varargs := []interface{}{ctx}
varargs := []interface{}{ctx, table}
for _, a := range idsOrRecords {
varargs = append(varargs, a)
}
@ -48,9 +48,9 @@ func (m *MockSQLProvider) Delete(ctx context.Context, idsOrRecords ...interface{
}
// Delete indicates an expected call of Delete.
func (mr *MockSQLProviderMockRecorder) Delete(ctx interface{}, idsOrRecords ...interface{}) *gomock.Call {
func (mr *MockSQLProviderMockRecorder) Delete(ctx, table interface{}, idsOrRecords ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{ctx}, idsOrRecords...)
varargs := append([]interface{}{ctx, table}, idsOrRecords...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSQLProvider)(nil).Delete), varargs...)
}
@ -74,17 +74,17 @@ func (mr *MockSQLProviderMockRecorder) Exec(ctx, query interface{}, params ...in
}
// Insert mocks base method.
func (m *MockSQLProvider) Insert(ctx context.Context, record interface{}) error {
func (m *MockSQLProvider) Insert(ctx context.Context, table ksql.Table, record interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Insert", ctx, record)
ret := m.ctrl.Call(m, "Insert", ctx, table, record)
ret0, _ := ret[0].(error)
return ret0
}
// Insert indicates an expected call of Insert.
func (mr *MockSQLProviderMockRecorder) Insert(ctx, record interface{}) *gomock.Call {
func (mr *MockSQLProviderMockRecorder) Insert(ctx, table, record interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), ctx, record)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSQLProvider)(nil).Insert), ctx, table, record)
}
// Query mocks base method.
@ -154,15 +154,15 @@ func (mr *MockSQLProviderMockRecorder) Transaction(ctx, fn interface{}) *gomock.
}
// Update mocks base method.
func (m *MockSQLProvider) Update(ctx context.Context, record interface{}) error {
func (m *MockSQLProvider) Update(ctx context.Context, table ksql.Table, record interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", ctx, record)
ret := m.ctrl.Call(m, "Update", ctx, table, record)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockSQLProviderMockRecorder) Update(ctx, record interface{}) *gomock.Call {
func (mr *MockSQLProviderMockRecorder) Update(ctx, table, record interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), ctx, record)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSQLProvider)(nil).Update), ctx, table, record)
}

3
go.sum
View File

@ -35,6 +35,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@ -45,8 +46,10 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e h1:aZzprAO9/8oim3qStq3wc1Xuxx4QmAGriC4VU4ojemQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=

111
ksql.go
View File

@ -24,16 +24,9 @@ func init() {
// interfacing with the "database/sql" package implementing
// the KissSQL interface `SQLProvider`.
type DB struct {
driver string
dialect dialect
tableName string
db sqlProvider
// Most dbs have a single primary key,
// But in future ksql should work with compound keys as well
idCols []string
insertMethod insertMethod
driver string
dialect dialect
db sqlProvider
}
type sqlProvider interface {
@ -46,14 +39,6 @@ type sqlProvider interface {
type Config struct {
// MaxOpenCons defaults to 1 if not set
MaxOpenConns int
// TableName must be set in order to use the Insert, Delete and Update helper
// functions. If you only intend to make queries or to use the Exec function
// it is safe to leave this field unset.
TableName string
// IDColumns defaults to []string{"id"} if unset
IDColumns []string
}
// New instantiates a new KissSQL client
@ -81,23 +66,10 @@ func New(
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
}
if len(config.IDColumns) == 0 {
config.IDColumns = []string{"id"}
}
insertMethod := dialect.InsertMethod()
if len(config.IDColumns) > 1 && insertMethod == insertWithLastInsertID {
insertMethod = insertWithNoIDRetrieval
}
return DB{
dialect: dialect,
driver: dbDriver,
db: db,
tableName: config.TableName,
idCols: config.IDColumns,
insertMethod: insertMethod,
dialect: dialect,
driver: dbDriver,
db: db,
}, nil
}
@ -133,8 +105,16 @@ func (c DB) Query(
slice = slice.Slice(0, 0)
}
if strings.ToUpper(getFirstToken(query)) == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()])
info := structs.GetTagInfo(structType)
firstToken := strings.ToUpper(getFirstToken(query))
if info.IsNestedStruct && firstToken == "SELECT" {
// This error check is necessary, since if we can't build the select part of the query this feature won't work.
return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
}
if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
if err != nil {
return err
}
@ -206,8 +186,16 @@ func (c DB) QueryOne(
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
}
if strings.ToUpper(getFirstToken(query)) == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, t, selectQueryCache[c.dialect.DriverName()])
info := structs.GetTagInfo(t)
firstToken := strings.ToUpper(getFirstToken(query))
if info.IsNestedStruct && firstToken == "SELECT" {
// This error check is necessary, since if we can't build the select part of the query this feature won't work.
return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
}
if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, t, info, selectQueryCache[c.dialect.DriverName()])
if err != nil {
return err
}
@ -268,8 +256,16 @@ func (c DB) QueryChunks(
return err
}
if strings.ToUpper(getFirstToken(parser.Query)) == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, structType, selectQueryCache[c.dialect.DriverName()])
info := structs.GetTagInfo(structType)
firstToken := strings.ToUpper(getFirstToken(parser.Query))
if info.IsNestedStruct && firstToken == "SELECT" {
// This error check is necessary, since if we can't build the select part of the query this feature won't work.
return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
}
if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
if err != nil {
return err
}
@ -347,22 +343,19 @@ func (c DB) QueryChunks(
// the ID is automatically updated after insertion is completed.
func (c DB) Insert(
ctx context.Context,
table Table,
record interface{},
) error {
if c.tableName == "" {
return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Insert method")
}
query, params, scanValues, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...)
query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, record, table.idColumns...)
if err != nil {
return err
}
switch c.insertMethod {
switch table.insertMethodFor(c.dialect) {
case insertWithReturning, insertWithOutput:
err = c.insertReturningIDs(ctx, record, query, params, scanValues, c.idCols)
err = c.insertReturningIDs(ctx, record, query, params, scanValues, table.idColumns)
case insertWithLastInsertID:
err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0])
err = c.insertWithLastInsertID(ctx, record, query, params, table.idColumns[0])
case insertWithNoIDRetrieval:
err = c.insertWithNoIDRetrieval(ctx, record, query, params)
default:
@ -471,27 +464,24 @@ func assertStructPtr(t reflect.Type) error {
// Delete deletes one or more instances from the database by id
func (c DB) Delete(
ctx context.Context,
table Table,
ids ...interface{},
) error {
if c.tableName == "" {
return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Delete method")
}
if len(ids) == 0 {
return nil
}
idMaps, err := normalizeIDsAsMaps(c.idCols, ids)
idMaps, err := normalizeIDsAsMaps(table.idColumns, ids)
if err != nil {
return err
}
var query string
var params []interface{}
if len(c.idCols) == 1 {
query, params = buildSingleKeyDeleteQuery(c.dialect, c.tableName, c.idCols[0], idMaps)
if len(table.idColumns) == 1 {
query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps)
} else {
query, params = buildCompositeKeyDeleteQuery(c.dialect, c.tableName, c.idCols, idMaps)
query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps)
}
_, err = c.db.ExecContext(ctx, query, params...)
@ -543,13 +533,10 @@ func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]inter
// Partial updates are supported, i.e. it will ignore nil pointer attributes
func (c DB) Update(
ctx context.Context,
table Table,
record interface{},
) error {
if c.tableName == "" {
return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Update method")
}
query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, c.idCols...)
query, params, err := buildUpdateQuery(c.dialect, table.name, record, table.idColumns...)
if err != nil {
return err
}
@ -964,13 +951,13 @@ func getFirstToken(s string) string {
func buildSelectQuery(
dialect dialect,
structType reflect.Type,
info structs.StructInfo,
selectQueryCache map[reflect.Type]string,
) (query string, err error) {
if selectQuery, found := selectQueryCache[structType]; found {
return selectQuery, nil
}
info := structs.GetTagInfo(structType)
if info.IsNestedStruct {
query, err = buildSelectQueryForNestedStructs(dialect, structType, info)
if err != nil {

View File

@ -25,6 +25,8 @@ type User struct {
Address Address `ksql:"address,json"`
}
var UsersTable = NewTable("users")
type Address struct {
Street string `json:"street"`
Number string `json:"number"`
@ -40,6 +42,8 @@ type Post struct {
Title string `ksql:"title"`
}
var PostsTable = NewTable("posts")
func TestQuery(t *testing.T) {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
@ -69,7 +73,7 @@ func TestQuery(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var users []User
err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
assert.Equal(t, nil, err)
@ -89,7 +93,7 @@ func TestQuery(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var users []User
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia")
@ -111,7 +115,7 @@ func TestQuery(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var users []User
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia")
@ -154,7 +158,7 @@ func TestQuery(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var rows []struct {
User User `tablename:"u"`
Post Post `tablename:"p"`
@ -193,7 +197,7 @@ func TestQuery(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var users []*User
err := c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE id=1;`)
assert.Equal(t, nil, err)
@ -213,7 +217,7 @@ func TestQuery(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var users []*User
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia")
@ -235,7 +239,7 @@ func TestQuery(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var users []*User
err = c.Query(ctx, &users, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0), "% Garcia")
@ -278,7 +282,7 @@ func TestQuery(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var rows []*struct {
User User `tablename:"u"`
Post Post `tablename:"p"`
@ -325,7 +329,7 @@ func TestQuery(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
err = c.Query(ctx, &User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
assert.NotEqual(t, nil, err)
@ -345,56 +349,72 @@ func TestQuery(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var users []User
err = c.Query(ctx, &users, `SELECT * FROM not a valid query`)
err := c.Query(ctx, &users, `SELECT * FROM not a valid query`)
assert.NotEqual(t, nil, err)
})
})
t.Run("should report error for nested structs with invalid types", func(t *testing.T) {
t.Run("int", func(t *testing.T) {
t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var rows []struct {
Foo int `tablename:"foo"`
User User `tablename:"users"`
Post Post `tablename:"posts"`
}
err := c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
err := c.Query(ctx, &rows, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`)
assert.NotEqual(t, nil, err)
msg := err.Error()
for _, str := range []string{"foo", "int"} {
assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
}
assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error())
assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error())
})
t.Run("*struct", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
t.Run("should report error for nested structs with invalid types", func(t *testing.T) {
t.Run("int", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
var rows []struct {
Foo *User `tablename:"foo"`
}
err := c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
ctx := context.Background()
c := newTestDB(db, driver)
var rows []struct {
Foo int `tablename:"foo"`
}
err := c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
assert.NotEqual(t, nil, err)
msg := err.Error()
for _, str := range []string{"foo", "*ksql.User"} {
assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
}
assert.NotEqual(t, nil, err)
msg := err.Error()
for _, str := range []string{"foo", "int"} {
assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
}
})
t.Run("*struct", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver)
var rows []struct {
Foo *User `tablename:"foo"`
}
err := c.Query(ctx, &rows, fmt.Sprint(
`FROM users u JOIN posts p ON p.user_id = u.id`,
` WHERE u.name like `, c.dialect.Placeholder(0),
` ORDER BY u.id, p.id`,
), "% Ribeiro")
assert.NotEqual(t, nil, err)
msg := err.Error()
for _, str := range []string{"foo", "*ksql.User"} {
assert.Equal(t, true, strings.Contains(msg, str), fmt.Sprintf("missing expected substr '%s' in error message: '%s'", str, msg))
}
})
})
})
})
@ -429,7 +449,7 @@ func TestQueryOne(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
u := User{}
err := c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE id=1;`)
assert.Equal(t, ErrRecordNotFound, err)
@ -443,7 +463,7 @@ func TestQueryOne(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
u := User{}
err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name=`+c.dialect.Placeholder(0), "Bia")
@ -466,7 +486,7 @@ func TestQueryOne(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var u User
err = c.QueryOne(ctx, &u, variation.queryPrefix+`FROM users WHERE name like `+c.dialect.Placeholder(0)+` ORDER BY id ASC`, "% Sá")
@ -496,7 +516,7 @@ func TestQueryOne(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var row struct {
User User `tablename:"u"`
Post Post `tablename:"p"`
@ -526,7 +546,7 @@ func TestQueryOne(t *testing.T) {
assert.Equal(t, nil, err)
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
err = c.QueryOne(ctx, &[]User{}, `SELECT * FROM users WHERE name like `+c.dialect.Placeholder(0), "% Sá")
assert.NotEqual(t, nil, err)
@ -540,11 +560,27 @@ func TestQueryOne(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
var user User
err := c.QueryOne(ctx, &user, `SELECT * FROM not a valid query`)
assert.NotEqual(t, nil, err)
})
t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver)
var row struct {
User User `tablename:"users"`
Post Post `tablename:"posts"`
}
err := c.QueryOne(ctx, &row, `SELECT * FROM users u JOIN posts p ON u.id = p.user_id LIMIT 1`)
assert.NotEqual(t, nil, err)
assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error())
assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error())
})
})
}
}
@ -563,7 +599,7 @@ func TestInsert(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
u := User{
Name: "Fernanda",
@ -572,7 +608,7 @@ func TestInsert(t *testing.T) {
},
}
err := c.Insert(ctx, &u)
err := c.Insert(ctx, UsersTable, &u)
assert.Equal(t, nil, err)
assert.NotEqual(t, 0, u.ID)
@ -593,11 +629,11 @@ func TestInsert(t *testing.T) {
defer db.Close()
ctx := context.Background()
// Using columns "id" and "name" as IDs:
c, err := New(driver, connectionString[driver], Config{
TableName: "users",
IDColumns: []string{"id", "name"},
})
table := NewTable("users", "id", "name")
c, err := New(driver, connectionString[driver], Config{})
assert.Equal(t, nil, err)
u := User{
@ -609,7 +645,7 @@ func TestInsert(t *testing.T) {
},
}
err = c.Insert(ctx, &u)
err = c.Insert(ctx, table, &u)
assert.Equal(t, nil, err)
assert.Equal(t, uint(0), u.ID)
@ -633,15 +669,15 @@ func TestInsert(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
err = c.Insert(ctx, "foo")
err = c.Insert(ctx, UsersTable, "foo")
assert.NotEqual(t, nil, err)
err = c.Insert(ctx, nullable.String("foo"))
err = c.Insert(ctx, UsersTable, nullable.String("foo"))
assert.NotEqual(t, nil, err)
err = c.Insert(ctx, map[string]interface{}{
err = c.Insert(ctx, UsersTable, map[string]interface{}{
"name": "foo",
"age": 12,
})
@ -651,11 +687,11 @@ func TestInsert(t *testing.T) {
&User{Name: "foo", Age: 22},
&User{Name: "bar", Age: 32},
}
err = c.Insert(ctx, ifUserForgetToExpandList)
err = c.Insert(ctx, UsersTable, ifUserForgetToExpandList)
assert.NotEqual(t, nil, err)
// We might want to support this in the future, but not for now:
err = c.Insert(ctx, User{Name: "not a ptr to user", Age: 42})
err = c.Insert(ctx, UsersTable, User{Name: "not a ptr to user", Age: 42})
assert.NotEqual(t, nil, err)
})
@ -664,12 +700,12 @@ func TestInsert(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
// This is an invalid value:
c.insertMethod = insertMethod(42)
c.dialect = brokenDialect{}
err = c.Insert(ctx, &User{Name: "foo"})
err = c.Insert(ctx, UsersTable, &User{Name: "foo"})
assert.NotEqual(t, nil, err)
})
})
@ -677,6 +713,24 @@ func TestInsert(t *testing.T) {
}
}
type brokenDialect struct{}
func (brokenDialect) InsertMethod() insertMethod {
return insertMethod(42)
}
func (brokenDialect) Escape(str string) string {
return str
}
func (brokenDialect) Placeholder(idx int) string {
return "?"
}
func (brokenDialect) DriverName() string {
return "fake-driver-name"
}
func TestDelete(t *testing.T) {
for driver := range supportedDialects {
t.Run(driver, func(t *testing.T) {
@ -690,13 +744,13 @@ func TestDelete(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
u := User{
Name: "Won't be deleted",
}
err := c.Insert(ctx, &u)
err := c.Insert(ctx, UsersTable, &u)
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u.ID)
@ -706,7 +760,7 @@ func TestDelete(t *testing.T) {
assert.Equal(t, u.ID, result.ID)
err = c.Delete(ctx)
err = c.Delete(ctx, UsersTable)
assert.Equal(t, nil, err)
result = User{}
@ -720,13 +774,13 @@ func TestDelete(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
u1 := User{
Name: "Fernanda",
}
err := c.Insert(ctx, &u1)
err := c.Insert(ctx, UsersTable, &u1)
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u1.ID)
@ -739,7 +793,7 @@ func TestDelete(t *testing.T) {
Name: "Won't be deleted",
}
err = c.Insert(ctx, &u2)
err = c.Insert(ctx, UsersTable, &u2)
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u2.ID)
@ -748,7 +802,7 @@ func TestDelete(t *testing.T) {
assert.Equal(t, nil, err)
assert.Equal(t, u2.ID, result.ID)
err = c.Delete(ctx, u1.ID)
err = c.Delete(ctx, UsersTable, u1.ID)
assert.Equal(t, nil, err)
result = User{}
@ -768,26 +822,26 @@ func TestDelete(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
u1 := User{
Name: "Fernanda",
}
err := c.Insert(ctx, &u1)
err := c.Insert(ctx, UsersTable, &u1)
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u1.ID)
u2 := User{
Name: "Juliano",
}
err = c.Insert(ctx, &u2)
err = c.Insert(ctx, UsersTable, &u2)
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u2.ID)
u3 := User{
Name: "This won't be deleted",
}
err = c.Insert(ctx, &u3)
err = c.Insert(ctx, UsersTable, &u3)
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u3.ID)
@ -806,7 +860,7 @@ func TestDelete(t *testing.T) {
assert.Equal(t, nil, err)
assert.Equal(t, u3.ID, result.ID)
err = c.Delete(ctx, u1.ID, u2.ID)
err = c.Delete(ctx, UsersTable, u1.ID, u2.ID)
assert.Equal(t, nil, err)
results := []User{}
@ -833,7 +887,7 @@ func TestUpdate(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
u := User{
Name: "Letícia",
@ -847,7 +901,7 @@ func TestUpdate(t *testing.T) {
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u.ID)
err = c.Update(ctx, User{
err = c.Update(ctx, UsersTable, User{
ID: u.ID,
Name: "Thayane",
})
@ -864,7 +918,7 @@ func TestUpdate(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
u := User{
Name: "Letícia",
@ -878,7 +932,7 @@ func TestUpdate(t *testing.T) {
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u.ID)
err = c.Update(ctx, User{
err = c.Update(ctx, UsersTable, User{
ID: u.ID,
Name: "Thayane",
})
@ -895,7 +949,7 @@ func TestUpdate(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
type partialUser struct {
ID uint `ksql:"id"`
@ -915,7 +969,7 @@ func TestUpdate(t *testing.T) {
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u.ID)
err = c.Update(ctx, partialUser{
err = c.Update(ctx, UsersTable, partialUser{
ID: u.ID,
// Should be updated because it is not null, just empty:
Name: "",
@ -936,7 +990,7 @@ func TestUpdate(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
type partialUser struct {
ID uint `ksql:"id"`
@ -957,7 +1011,7 @@ func TestUpdate(t *testing.T) {
assert.NotEqual(t, uint(0), u.ID)
// Should update all fields:
err = c.Update(ctx, partialUser{
err = c.Update(ctx, UsersTable, partialUser{
ID: u.ID,
Name: "Thay",
Age: nullable.Int(42),
@ -977,9 +1031,9 @@ func TestUpdate(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "non_existing_table")
c := newTestDB(db, driver)
err = c.Update(ctx, User{
err = c.Update(ctx, NewTable("non_existing_table"), User{
ID: 1,
Name: "Thayane",
})
@ -1017,9 +1071,9 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
_ = c.Insert(ctx, &User{
_ = c.Insert(ctx, UsersTable, &User{
Name: "User1",
Address: Address{Country: "BR"},
})
@ -1057,10 +1111,10 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
_ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}})
_ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}})
_ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}})
_ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}})
var lengths []int
var users []User
@ -1099,10 +1153,10 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
_ = c.Insert(ctx, &User{Name: "User1", Address: Address{Country: "US"}})
_ = c.Insert(ctx, &User{Name: "User2", Address: Address{Country: "BR"}})
_ = c.Insert(ctx, UsersTable, &User{Name: "User1", Address: Address{Country: "US"}})
_ = c.Insert(ctx, UsersTable, &User{Name: "User2", Address: Address{Country: "BR"}})
var lengths []int
var users []User
@ -1141,11 +1195,11 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
_ = c.Insert(ctx, &User{Name: "User3"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User1"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User2"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User3"})
var lengths []int
var users []User
@ -1192,9 +1246,9 @@ func TestQueryChunks(t *testing.T) {
}
ctx := context.Background()
c := newTestDB(db, driver, "users")
_ = c.Insert(ctx, &joao)
_ = c.Insert(ctx, &thatiana)
c := newTestDB(db, driver)
_ = c.Insert(ctx, UsersTable, &joao)
_ = c.Insert(ctx, UsersTable, &thatiana)
_, err := db.Exec(fmt.Sprint(`INSERT INTO posts (user_id, title) VALUES (`, thatiana.ID, `, 'Thatiana Post1')`))
assert.Equal(t, nil, err)
@ -1254,11 +1308,11 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
_ = c.Insert(ctx, &User{Name: "User3"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User1"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User2"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User3"})
var lengths []int
var users []User
@ -1293,11 +1347,11 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
_ = c.Insert(ctx, &User{Name: "User3"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User1"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User2"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User3"})
returnVals := []error{nil, ErrAbortIteration}
var lengths []int
@ -1336,11 +1390,11 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
_ = c.Insert(ctx, &User{Name: "User3"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User1"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User2"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User3"})
var lengths []int
var users []User
@ -1375,11 +1429,11 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
_ = c.Insert(ctx, &User{Name: "User3"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User1"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User2"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User3"})
returnVals := []error{nil, errors.New("fake error msg")}
var lengths []int
@ -1413,7 +1467,7 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
funcs := []interface{}{
nil,
@ -1458,7 +1512,7 @@ func TestQueryChunks(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
err := c.QueryChunks(ctx, ChunkParser{
Query: `SELECT * FROM not a valid query`,
Params: []interface{}{},
@ -1470,6 +1524,31 @@ func TestQueryChunks(t *testing.T) {
})
assert.NotEqual(t, nil, err)
})
t.Run("should report error if using nested struct and the query starts with SELECT", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver)
err := c.QueryChunks(ctx, ChunkParser{
Query: `SELECT * FROM users u JOIN posts p ON u.id = p.user_id`,
Params: []interface{}{},
ChunkSize: 2,
ForEachChunk: func(buffer []struct {
User User `tablename:"users"`
Post Post `tablename:"posts"`
}) error {
return nil
},
})
assert.NotEqual(t, nil, err)
assert.Equal(t, true, strings.Contains(err.Error(), "nested struct"), "unexpected error msg: "+err.Error())
assert.Equal(t, true, strings.Contains(err.Error(), "feature"), "unexpected error msg: "+err.Error())
})
})
}
})
@ -1489,10 +1568,10 @@ func TestTransaction(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User1"})
_ = c.Insert(ctx, UsersTable, &User{Name: "User2"})
var users []User
err = c.Transaction(ctx, func(db SQLProvider) error {
@ -1516,17 +1595,17 @@ func TestTransaction(t *testing.T) {
defer db.Close()
ctx := context.Background()
c := newTestDB(db, driver, "users")
c := newTestDB(db, driver)
u1 := User{Name: "User1", Age: 42}
u2 := User{Name: "User2", Age: 42}
_ = c.Insert(ctx, &u1)
_ = c.Insert(ctx, &u2)
_ = c.Insert(ctx, UsersTable, &u1)
_ = c.Insert(ctx, UsersTable, &u2)
err = c.Transaction(ctx, func(db SQLProvider) error {
err = db.Insert(ctx, &User{Name: "User3"})
err = db.Insert(ctx, UsersTable, &User{Name: "User3"})
assert.Equal(t, nil, err)
err = db.Insert(ctx, &User{Name: "User4"})
err = db.Insert(ctx, UsersTable, &User{Name: "User4"})
assert.Equal(t, nil, err)
err = db.Exec(ctx, "UPDATE users SET age = 22")
assert.Equal(t, nil, err)
@ -1557,10 +1636,10 @@ func TestScanRows(t *testing.T) {
ctx := context.TODO()
db := connectDB(t, "sqlite3")
defer db.Close()
c := newTestDB(db, "sqlite3", "users")
_ = c.Insert(ctx, &User{Name: "User1", Age: 22})
_ = c.Insert(ctx, &User{Name: "User2", Age: 14})
_ = c.Insert(ctx, &User{Name: "User3", Age: 43})
c := newTestDB(db, "sqlite3")
_ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22})
_ = c.Insert(ctx, UsersTable, &User{Name: "User2", Age: 14})
_ = c.Insert(ctx, UsersTable, &User{Name: "User3", Age: 43})
rows, err := db.QueryContext(ctx, "select * from users where name='User2'")
assert.Equal(t, nil, err)
@ -1586,8 +1665,8 @@ func TestScanRows(t *testing.T) {
ctx := context.TODO()
db := connectDB(t, "sqlite3")
defer db.Close()
c := newTestDB(db, "sqlite3", "users")
_ = c.Insert(ctx, &User{Name: "User1", Age: 22})
c := newTestDB(db, "sqlite3")
_ = c.Insert(ctx, UsersTable, &User{Name: "User1", Age: 22})
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User1'")
assert.Equal(t, nil, err)
@ -1758,19 +1837,11 @@ func createTables(driver string) error {
return nil
}
func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB {
if len(ids) == 0 {
ids = []string{"id"}
}
func newTestDB(db *sql.DB, driver string) DB {
return DB{
driver: driver,
dialect: supportedDialects[driver],
db: db,
tableName: tableName,
idCols: ids,
insertMethod: supportedDialects[driver].InsertMethod(),
driver: driver,
dialect: supportedDialects[driver],
db: db,
}
}

View File

@ -6,9 +6,9 @@ var _ SQLProvider = MockSQLProvider{}
// MockSQLProvider ...
type MockSQLProvider struct {
InsertFn func(ctx context.Context, record interface{}) error
UpdateFn func(ctx context.Context, record interface{}) error
DeleteFn func(ctx context.Context, ids ...interface{}) error
InsertFn func(ctx context.Context, table Table, record interface{}) error
UpdateFn func(ctx context.Context, table Table, record interface{}) error
DeleteFn func(ctx context.Context, table Table, ids ...interface{}) error
QueryFn func(ctx context.Context, records interface{}, query string, params ...interface{}) error
QueryOneFn func(ctx context.Context, record interface{}, query string, params ...interface{}) error
@ -19,18 +19,18 @@ type MockSQLProvider struct {
}
// Insert ...
func (m MockSQLProvider) Insert(ctx context.Context, record interface{}) error {
return m.InsertFn(ctx, record)
func (m MockSQLProvider) Insert(ctx context.Context, table Table, record interface{}) error {
return m.InsertFn(ctx, table, record)
}
// Update ...
func (m MockSQLProvider) Update(ctx context.Context, record interface{}) error {
return m.UpdateFn(ctx, record)
func (m MockSQLProvider) Update(ctx context.Context, table Table, record interface{}) error {
return m.UpdateFn(ctx, table, record)
}
// Delete ...
func (m MockSQLProvider) Delete(ctx context.Context, ids ...interface{}) error {
return m.DeleteFn(ctx, ids...)
func (m MockSQLProvider) Delete(ctx context.Context, table Table, ids ...interface{}) error {
return m.DeleteFn(ctx, table, ids...)
}
// Query ...