Split the Find function into the Query & QueryOne functions

pull/2/head
Vinícius Garcia 2020-10-13 21:39:39 -03:00
parent a3bf34146d
commit e8bd504703
3 changed files with 163 additions and 43 deletions

View File

@ -2,14 +2,20 @@ package kissorm
import ( import (
"context" "context"
"fmt"
) )
// EntityNotFound ...
var EntityNotFoundErr error = fmt.Errorf("kissorm: the query returned no results")
// ORMProvider describes the public behavior of this ORM // ORMProvider describes the public behavior of this ORM
type ORMProvider interface { type ORMProvider interface {
Find(ctx context.Context, item interface{}, query string, params ...interface{}) error Insert(ctx context.Context, records ...interface{}) error
Insert(ctx context.Context, items ...interface{}) error
Delete(ctx context.Context, ids ...interface{}) error Delete(ctx context.Context, ids ...interface{}) error
Update(ctx context.Context, items ...interface{}) error Update(ctx context.Context, records ...interface{}) error
Query(ctx context.Context, records interface{}, query string, params ...interface{}) error
QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error
QueryChunks(ctx context.Context, parser ChunkParser) error QueryChunks(ctx context.Context, parser ChunkParser) error
} }
@ -23,11 +29,11 @@ type ChunkParser struct {
Chunk interface{} // Must be a pointer to a slice of structs Chunk interface{} // Must be a pointer to a slice of structs
// The closure that will be called right after // The closure that will be called right after
// filling the Chunk with ChunkSize items // filling the Chunk with ChunkSize records
// //
// Each chunk consecutively parsed will overwrite the // Each chunk consecutively parsed will overwrite the
// same slice, so don't keep references to it, if you // same slice, so don't keep references to it, if you
// need some data to be preserved after all chunks are // need some data to be preserved after all chunks are
// processed copy the items by value. // processed copy the records by value.
ForEachChunk func() error ForEachChunk func() error
} }

View File

@ -45,20 +45,68 @@ func (c Client) ChangeTable(ctx context.Context, tableName string) ORMProvider {
} }
} }
// Find one instance from the database, the input struct // Query queries several rows from the database,
// must be passed by reference and the query should // the input should be a slice of structs passed
// return only one result. // by reference and it will be filled with all the results.
func (c Client) Find( //
// Note: it is very important to make sure the query will
// return a small number of results, otherwise you risk
// of overloading the available memory.
func (c Client) Query(
ctx context.Context, ctx context.Context,
item interface{}, records interface{},
query string, query string,
params ...interface{}, params ...interface{},
) error { ) error {
t := reflect.TypeOf(records)
if t.Kind() != reflect.Ptr {
return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", records)
}
t = t.Elem()
if t.Kind() != reflect.Slice {
return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", records)
}
if t.Elem().Kind() != reflect.Struct {
return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", records)
}
it := c.db.Raw(query, params...) it := c.db.Raw(query, params...)
if it.Error != nil { if it.Error != nil {
return it.Error return it.Error
} }
it = it.Scan(item) it = it.Scan(records)
return it.Error
}
// QueryOne queries one instance from the database,
// the input struct must be passed by reference
// and the query should return only one result.
//
// QueryOne returns a EntityNotFoundErr if
// the query returns no results.
func (c Client) QueryOne(
ctx context.Context,
record interface{},
query string,
params ...interface{},
) error {
t := reflect.TypeOf(record)
if t.Kind() != reflect.Ptr {
return fmt.Errorf("kissorm: expected to receive a pointer to struct, but got: %T", record)
}
t = t.Elem()
if t.Kind() != reflect.Struct {
return fmt.Errorf("kissorm: expected to receive a pointer to struct, but got: %T", record)
}
it := c.db.Raw(query, params...)
if it.Error != nil {
return it.Error
}
it = it.Scan(record)
if it.Error != nil && it.Error.Error() == "record not found" {
return EntityNotFoundErr
}
return it.Error return it.Error
} }
@ -141,14 +189,14 @@ func (c Client) QueryChunks(
// the ID is automatically updated after insertion is completed. // the ID is automatically updated after insertion is completed.
func (c Client) Insert( func (c Client) Insert(
ctx context.Context, ctx context.Context,
items ...interface{}, records ...interface{},
) error { ) error {
if len(items) == 0 { if len(records) == 0 {
return nil return nil
} }
for _, item := range items { for _, record := range records {
r := c.db.Table(c.tableName).Create(item) r := c.db.Table(c.tableName).Create(record)
if r.Error != nil { if r.Error != nil {
return r.Error return r.Error
} }
@ -177,15 +225,15 @@ func (c Client) Delete(
// Partial updates are supported, i.e. it will ignore nil pointer attributes // Partial updates are supported, i.e. it will ignore nil pointer attributes
func (c Client) Update( func (c Client) Update(
ctx context.Context, ctx context.Context,
items ...interface{}, records ...interface{},
) error { ) error {
for _, item := range items { for _, record := range records {
m, err := StructToMap(item) m, err := StructToMap(record)
if err != nil { if err != nil {
return err return err
} }
delete(m, "id") delete(m, "id")
r := c.db.Table(c.tableName).Model(item).Updates(m) r := c.db.Table(c.tableName).Model(record).Updates(m)
if r.Error != nil { if r.Error != nil {
return r.Error return r.Error
} }

View File

@ -16,7 +16,7 @@ type User struct {
CreatedAt time.Time `gorm:"created_at"` CreatedAt time.Time `gorm:"created_at"`
} }
func TestFind(t *testing.T) { func TestQuery(t *testing.T) {
err := createTable() err := createTable()
if err != nil { if err != nil {
t.Fatal("could not create test table!") t.Fatal("could not create test table!")
@ -31,10 +31,10 @@ func TestFind(t *testing.T) {
db: db, db: db,
tableName: "users", tableName: "users",
} }
u := User{} var users []User
err := c.Find(ctx, &u, `SELECT * FROM users WHERE id=1;`) err := c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`)
assert.NotEqual(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, User{}, u) assert.Equal(t, []User{}, users)
}) })
t.Run("should return a user correctly", func(t *testing.T) { t.Run("should return a user correctly", func(t *testing.T) {
@ -50,12 +50,13 @@ func TestFind(t *testing.T) {
db: db, db: db,
tableName: "users", tableName: "users",
} }
u := User{} var users []User
err = c.Find(ctx, &u, `SELECT * FROM users WHERE name=?;`, "Bia") err = c.Query(ctx, &users, `SELECT * FROM users WHERE name=?;`, "Bia")
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
assert.Equal(t, "Bia", u.Name) assert.Equal(t, 1, len(users))
assert.NotEqual(t, 0, u.ID) assert.Equal(t, "Bia", users[0].Name)
assert.NotEqual(t, 0, users[0].ID)
}) })
t.Run("should return multiple users correctly", func(t *testing.T) { t.Run("should return multiple users correctly", func(t *testing.T) {
@ -75,10 +76,10 @@ func TestFind(t *testing.T) {
db: db, db: db,
tableName: "users", tableName: "users",
} }
users := []User{} var users []User
err = c.Find(ctx, &users, `SELECT * FROM users WHERE name like ?;`, "% Garcia") err = c.Query(ctx, &users, `SELECT * FROM users WHERE name like ?;`, "% Garcia")
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
assert.Equal(t, 2, len(users)) assert.Equal(t, 2, len(users))
assert.Equal(t, "João Garcia", users[0].Name) assert.Equal(t, "João Garcia", users[0].Name)
assert.NotEqual(t, 0, users[0].ID) assert.NotEqual(t, 0, users[0].ID)
@ -87,6 +88,71 @@ func TestFind(t *testing.T) {
}) })
} }
func TestQueryOne(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
}
t.Run("should return EntityNotFoundErr when there are no results", func(t *testing.T) {
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
u := User{}
err := c.QueryOne(ctx, &u, `SELECT * FROM users WHERE id=1;`)
assert.Equal(t, EntityNotFoundErr, err)
})
t.Run("should return a user correctly", func(t *testing.T) {
db := connectDB(t)
defer db.Close()
db.Create(&User{
Name: "Bia",
})
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
u := User{}
err = c.QueryOne(ctx, &u, `SELECT * FROM users WHERE name=?;`, "Bia")
assert.Equal(t, nil, err)
assert.Equal(t, "Bia", u.Name)
assert.NotEqual(t, 0, u.ID)
})
t.Run("should report error if input is no a pointer to struct", func(t *testing.T) {
db := connectDB(t)
defer db.Close()
db.Create(&User{
Name: "Andréa Sá",
})
db.Create(&User{
Name: "Caio Sá",
})
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
users := []User{}
err = c.QueryOne(ctx, &users, `SELECT * FROM users WHERE name like ?;`, "% Sá")
assert.NotEqual(t, nil, err)
})
}
func TestInsert(t *testing.T) { func TestInsert(t *testing.T) {
err := createTable() err := createTable()
if err != nil { if err != nil {
@ -104,7 +170,7 @@ func TestInsert(t *testing.T) {
} }
err = c.Insert(ctx) err = c.Insert(ctx)
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
}) })
t.Run("should insert one user correctly", func(t *testing.T) { t.Run("should insert one user correctly", func(t *testing.T) {
@ -122,12 +188,12 @@ func TestInsert(t *testing.T) {
} }
err := c.Insert(ctx, &u) err := c.Insert(ctx, &u)
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
result := User{} result := User{}
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result) it.Scan(&result)
assert.Equal(t, it.Error, nil) assert.Equal(t, nil, it.Error)
assert.Equal(t, u.Name, result.Name) assert.Equal(t, u.Name, result.Name)
assert.Equal(t, u.CreatedAt.Format(time.RFC3339), result.CreatedAt.Format(time.RFC3339)) assert.Equal(t, u.CreatedAt.Format(time.RFC3339), result.CreatedAt.Format(time.RFC3339))
}) })
@ -150,7 +216,7 @@ func TestDelete(t *testing.T) {
} }
err = c.Delete(ctx) err = c.Delete(ctx)
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
}) })
t.Run("should delete one id correctly", func(t *testing.T) { t.Run("should delete one id correctly", func(t *testing.T) {
@ -168,7 +234,7 @@ func TestDelete(t *testing.T) {
} }
err := c.Insert(ctx, &u) err := c.Insert(ctx, &u)
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
assert.NotEqual(t, 0, u.ID) assert.NotEqual(t, 0, u.ID)
result := User{} result := User{}
@ -177,13 +243,13 @@ func TestDelete(t *testing.T) {
assert.Equal(t, u.ID, result.ID) assert.Equal(t, u.ID, result.ID)
err = c.Delete(ctx, u.ID) err = c.Delete(ctx, u.ID)
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
result = User{} result = User{}
it = c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) it = c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result) it.Scan(&result)
assert.Equal(t, it.Error, nil) assert.Equal(t, nil, it.Error)
assert.Equal(t, uint(0), result.ID) assert.Equal(t, uint(0), result.ID)
assert.Equal(t, "", result.Name) assert.Equal(t, "", result.Name)
}) })
@ -209,18 +275,18 @@ func TestUpdate(t *testing.T) {
Name: "Thay", Name: "Thay",
} }
err := c.Insert(ctx, &u) err := c.Insert(ctx, &u)
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
assert.NotEqual(t, 0, u.ID) assert.NotEqual(t, 0, u.ID)
// Empty update, should do nothing: // Empty update, should do nothing:
err = c.Update(ctx) err = c.Update(ctx)
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
result := User{} result := User{}
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID) it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result) it.Scan(&result)
it.Close() it.Close()
assert.Equal(t, err, nil) assert.Equal(t, nil, err)
assert.Equal(t, "Thay", result.Name) assert.Equal(t, "Thay", result.Name)
}) })
@ -239,7 +305,7 @@ func TestUpdate(t *testing.T) {
ID: 1, ID: 1,
Name: "Thayane", Name: "Thayane",
}) })
assert.NotEqual(t, err, nil) assert.NotEqual(t, nil, err)
}) })
} }