diff --git a/contracts.go b/contracts.go index 95aa456..0723f96 100644 --- a/contracts.go +++ b/contracts.go @@ -2,14 +2,20 @@ package kissorm import ( "context" + "fmt" ) +// EntityNotFound ... +var EntityNotFoundErr error = fmt.Errorf("kissorm: the query returned no results") + // ORMProvider describes the public behavior of this ORM type ORMProvider interface { - Find(ctx context.Context, item interface{}, query string, params ...interface{}) error - Insert(ctx context.Context, items ...interface{}) error + Insert(ctx context.Context, records ...interface{}) error Delete(ctx context.Context, ids ...interface{}) error - Update(ctx context.Context, items ...interface{}) error + Update(ctx context.Context, records ...interface{}) error + + Query(ctx context.Context, records interface{}, query string, params ...interface{}) error + QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error QueryChunks(ctx context.Context, parser ChunkParser) error } @@ -23,11 +29,11 @@ type ChunkParser struct { Chunk interface{} // Must be a pointer to a slice of structs // The closure that will be called right after - // filling the Chunk with ChunkSize items + // filling the Chunk with ChunkSize records // // Each chunk consecutively parsed will overwrite the // same slice, so don't keep references to it, if you // need some data to be preserved after all chunks are - // processed copy the items by value. + // processed copy the records by value. ForEachChunk func() error } diff --git a/kiss_orm.go b/kiss_orm.go index 57fa2f0..f2ab40e 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -45,20 +45,68 @@ func (c Client) ChangeTable(ctx context.Context, tableName string) ORMProvider { } } -// Find one instance from the database, the input struct -// must be passed by reference and the query should -// return only one result. -func (c Client) Find( +// Query queries several rows from the database, +// the input should be a slice of structs passed +// by reference and it will be filled with all the results. +// +// Note: it is very important to make sure the query will +// return a small number of results, otherwise you risk +// of overloading the available memory. +func (c Client) Query( ctx context.Context, - item interface{}, + records interface{}, query string, params ...interface{}, ) error { + t := reflect.TypeOf(records) + if t.Kind() != reflect.Ptr { + return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", records) + } + t = t.Elem() + if t.Kind() != reflect.Slice { + return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", records) + } + if t.Elem().Kind() != reflect.Struct { + return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", records) + } + it := c.db.Raw(query, params...) if it.Error != nil { return it.Error } - it = it.Scan(item) + it = it.Scan(records) + return it.Error +} + +// QueryOne queries one instance from the database, +// the input struct must be passed by reference +// and the query should return only one result. +// +// QueryOne returns a EntityNotFoundErr if +// the query returns no results. +func (c Client) QueryOne( + ctx context.Context, + record interface{}, + query string, + params ...interface{}, +) error { + t := reflect.TypeOf(record) + if t.Kind() != reflect.Ptr { + return fmt.Errorf("kissorm: expected to receive a pointer to struct, but got: %T", record) + } + t = t.Elem() + if t.Kind() != reflect.Struct { + return fmt.Errorf("kissorm: expected to receive a pointer to struct, but got: %T", record) + } + + it := c.db.Raw(query, params...) + if it.Error != nil { + return it.Error + } + it = it.Scan(record) + if it.Error != nil && it.Error.Error() == "record not found" { + return EntityNotFoundErr + } return it.Error } @@ -141,14 +189,14 @@ func (c Client) QueryChunks( // the ID is automatically updated after insertion is completed. func (c Client) Insert( ctx context.Context, - items ...interface{}, + records ...interface{}, ) error { - if len(items) == 0 { + if len(records) == 0 { return nil } - for _, item := range items { - r := c.db.Table(c.tableName).Create(item) + for _, record := range records { + r := c.db.Table(c.tableName).Create(record) if r.Error != nil { return r.Error } @@ -177,15 +225,15 @@ func (c Client) Delete( // Partial updates are supported, i.e. it will ignore nil pointer attributes func (c Client) Update( ctx context.Context, - items ...interface{}, + records ...interface{}, ) error { - for _, item := range items { - m, err := StructToMap(item) + for _, record := range records { + m, err := StructToMap(record) if err != nil { return err } delete(m, "id") - r := c.db.Table(c.tableName).Model(item).Updates(m) + r := c.db.Table(c.tableName).Model(record).Updates(m) if r.Error != nil { return r.Error } diff --git a/kiss_orm_test.go b/kiss_orm_test.go index 29ec778..9517164 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -16,7 +16,7 @@ type User struct { CreatedAt time.Time `gorm:"created_at"` } -func TestFind(t *testing.T) { +func TestQuery(t *testing.T) { err := createTable() if err != nil { t.Fatal("could not create test table!") @@ -31,10 +31,10 @@ func TestFind(t *testing.T) { db: db, tableName: "users", } - u := User{} - err := c.Find(ctx, &u, `SELECT * FROM users WHERE id=1;`) - assert.NotEqual(t, nil, err) - assert.Equal(t, User{}, u) + var users []User + err := c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`) + assert.Equal(t, nil, err) + assert.Equal(t, []User{}, users) }) t.Run("should return a user correctly", func(t *testing.T) { @@ -50,12 +50,13 @@ func TestFind(t *testing.T) { db: db, tableName: "users", } - u := User{} - err = c.Find(ctx, &u, `SELECT * FROM users WHERE name=?;`, "Bia") + var users []User + err = c.Query(ctx, &users, `SELECT * FROM users WHERE name=?;`, "Bia") - assert.Equal(t, err, nil) - assert.Equal(t, "Bia", u.Name) - assert.NotEqual(t, 0, u.ID) + assert.Equal(t, nil, err) + assert.Equal(t, 1, len(users)) + assert.Equal(t, "Bia", users[0].Name) + assert.NotEqual(t, 0, users[0].ID) }) t.Run("should return multiple users correctly", func(t *testing.T) { @@ -75,10 +76,10 @@ func TestFind(t *testing.T) { db: db, tableName: "users", } - users := []User{} - err = c.Find(ctx, &users, `SELECT * FROM users WHERE name like ?;`, "% Garcia") + var users []User + err = c.Query(ctx, &users, `SELECT * FROM users WHERE name like ?;`, "% Garcia") - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) assert.Equal(t, 2, len(users)) assert.Equal(t, "João Garcia", users[0].Name) assert.NotEqual(t, 0, users[0].ID) @@ -87,6 +88,71 @@ func TestFind(t *testing.T) { }) } +func TestQueryOne(t *testing.T) { + err := createTable() + if err != nil { + t.Fatal("could not create test table!") + } + + t.Run("should return EntityNotFoundErr when there are no results", func(t *testing.T) { + db := connectDB(t) + defer db.Close() + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + u := User{} + err := c.QueryOne(ctx, &u, `SELECT * FROM users WHERE id=1;`) + assert.Equal(t, EntityNotFoundErr, err) + }) + + t.Run("should return a user correctly", func(t *testing.T) { + db := connectDB(t) + defer db.Close() + + db.Create(&User{ + Name: "Bia", + }) + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + u := User{} + err = c.QueryOne(ctx, &u, `SELECT * FROM users WHERE name=?;`, "Bia") + + assert.Equal(t, nil, err) + assert.Equal(t, "Bia", u.Name) + assert.NotEqual(t, 0, u.ID) + }) + + t.Run("should report error if input is no a pointer to struct", func(t *testing.T) { + db := connectDB(t) + defer db.Close() + + db.Create(&User{ + Name: "Andréa Sá", + }) + + db.Create(&User{ + Name: "Caio Sá", + }) + + ctx := context.Background() + c := Client{ + db: db, + tableName: "users", + } + users := []User{} + err = c.QueryOne(ctx, &users, `SELECT * FROM users WHERE name like ?;`, "% Sá") + + assert.NotEqual(t, nil, err) + }) +} + func TestInsert(t *testing.T) { err := createTable() if err != nil { @@ -104,7 +170,7 @@ func TestInsert(t *testing.T) { } err = c.Insert(ctx) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) }) t.Run("should insert one user correctly", func(t *testing.T) { @@ -122,12 +188,12 @@ func TestInsert(t *testing.T) { } err := c.Insert(ctx, &u) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) result := User{} it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) it.Scan(&result) - assert.Equal(t, it.Error, nil) + assert.Equal(t, nil, it.Error) assert.Equal(t, u.Name, result.Name) assert.Equal(t, u.CreatedAt.Format(time.RFC3339), result.CreatedAt.Format(time.RFC3339)) }) @@ -150,7 +216,7 @@ func TestDelete(t *testing.T) { } err = c.Delete(ctx) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) }) t.Run("should delete one id correctly", func(t *testing.T) { @@ -168,7 +234,7 @@ func TestDelete(t *testing.T) { } err := c.Insert(ctx, &u) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) assert.NotEqual(t, 0, u.ID) result := User{} @@ -177,13 +243,13 @@ func TestDelete(t *testing.T) { assert.Equal(t, u.ID, result.ID) err = c.Delete(ctx, u.ID) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) result = User{} it = c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) it.Scan(&result) - assert.Equal(t, it.Error, nil) + assert.Equal(t, nil, it.Error) assert.Equal(t, uint(0), result.ID) assert.Equal(t, "", result.Name) }) @@ -209,18 +275,18 @@ func TestUpdate(t *testing.T) { Name: "Thay", } err := c.Insert(ctx, &u) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) assert.NotEqual(t, 0, u.ID) // Empty update, should do nothing: err = c.Update(ctx) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) result := User{} it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) it.Scan(&result) it.Close() - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) assert.Equal(t, "Thay", result.Name) }) @@ -239,7 +305,7 @@ func TestUpdate(t *testing.T) { ID: 1, Name: "Thayane", }) - assert.NotEqual(t, err, nil) + assert.NotEqual(t, nil, err) }) }