mirror of https://github.com/VinGarcia/ksql.git
Fix Update() to work with postgres dialect
parent
de8f4e56d7
commit
f782fabb37
28
kiss_orm.go
28
kiss_orm.go
|
@ -363,7 +363,7 @@ func (c Client) Update(
|
|||
records ...interface{},
|
||||
) error {
|
||||
for _, record := range records {
|
||||
query, params, err := buildUpdateQuery(c.tableName, record, "id")
|
||||
query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, "id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -424,6 +424,7 @@ func buildInsertQuery(
|
|||
}
|
||||
|
||||
func buildUpdateQuery(
|
||||
dialect dialect,
|
||||
tableName string,
|
||||
record interface{},
|
||||
idFieldNames ...string,
|
||||
|
@ -433,14 +434,19 @@ func buildUpdateQuery(
|
|||
return "", nil, err
|
||||
}
|
||||
numAttrs := len(recordMap)
|
||||
numIDs := len(idFieldNames)
|
||||
args = make([]interface{}, numAttrs+numIDs)
|
||||
whereArgs := args[numAttrs-len(idFieldNames):]
|
||||
args = make([]interface{}, numAttrs)
|
||||
numNonIDArgs := numAttrs - len(idFieldNames)
|
||||
whereArgs := args[numNonIDArgs:]
|
||||
|
||||
var whereQuery []string
|
||||
whereQuery := make([]string, len(idFieldNames))
|
||||
for i, fieldName := range idFieldNames {
|
||||
whereArgs[i] = recordMap[fieldName]
|
||||
whereQuery = append(whereQuery, fmt.Sprintf("`%s` = ?", fieldName))
|
||||
whereQuery[i] = fmt.Sprintf(
|
||||
"%s = %s",
|
||||
dialect.Escape(fieldName),
|
||||
dialect.Placeholder(i+numNonIDArgs),
|
||||
)
|
||||
|
||||
delete(recordMap, fieldName)
|
||||
}
|
||||
|
||||
|
@ -452,12 +458,16 @@ func buildUpdateQuery(
|
|||
var setQuery []string
|
||||
for i, k := range keys {
|
||||
args[i] = recordMap[k]
|
||||
setQuery = append(setQuery, fmt.Sprintf("`%s` = ?", k))
|
||||
setQuery = append(setQuery, fmt.Sprintf(
|
||||
"%s = %s",
|
||||
dialect.Escape(k),
|
||||
dialect.Placeholder(i),
|
||||
))
|
||||
}
|
||||
|
||||
query = fmt.Sprintf(
|
||||
"UPDATE `%s` SET %s WHERE %s",
|
||||
tableName,
|
||||
"UPDATE %s SET %s WHERE %s",
|
||||
dialect.Escape(tableName),
|
||||
strings.Join(setQuery, ", "),
|
||||
strings.Join(whereQuery, ", "),
|
||||
)
|
||||
|
|
346
kiss_orm_test.go
346
kiss_orm_test.go
|
@ -364,178 +364,182 @@ func TestDelete(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
err := createTable("sqlite3")
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
for _, driver := range []string{"sqlite3", "postgres"} {
|
||||
t.Run(driver, func(t *testing.T) {
|
||||
err := createTable(driver)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
|
||||
t.Run("should ignore empty lists of ids", func(t *testing.T) {
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, driver, "users")
|
||||
|
||||
u := User{
|
||||
Name: "Thay",
|
||||
}
|
||||
err := c.Insert(ctx, &u)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
// Empty update, should do nothing:
|
||||
err = c.Update(ctx)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
result := User{}
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
it.Close()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
assert.Equal(t, "Thay", result.Name)
|
||||
})
|
||||
|
||||
t.Run("should update one user correctly", func(t *testing.T) {
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, driver, "users")
|
||||
|
||||
u := User{
|
||||
Name: "Letícia",
|
||||
}
|
||||
r := c.db.Table(c.tableName).Create(&u)
|
||||
assert.Equal(t, nil, r.Error)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
err = c.Update(ctx, User{
|
||||
ID: u.ID,
|
||||
Name: "Thayane",
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
var result User
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
assert.Equal(t, nil, it.Error)
|
||||
assert.Equal(t, "Thayane", result.Name)
|
||||
})
|
||||
|
||||
t.Run("should update one user correctly", func(t *testing.T) {
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, driver, "users")
|
||||
|
||||
u := User{
|
||||
Name: "Letícia",
|
||||
}
|
||||
r := c.db.Table(c.tableName).Create(&u)
|
||||
assert.Equal(t, nil, r.Error)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
err = c.Update(ctx, User{
|
||||
ID: u.ID,
|
||||
Name: "Thayane",
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
var result User
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
assert.Equal(t, nil, it.Error)
|
||||
assert.Equal(t, "Thayane", result.Name)
|
||||
})
|
||||
|
||||
t.Run("should ignore null pointers on partial updates", func(t *testing.T) {
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, driver, "users")
|
||||
|
||||
type partialUser struct {
|
||||
ID uint `gorm:"id"`
|
||||
Name string `gorm:"name"`
|
||||
Age *int `gorm:"age"`
|
||||
}
|
||||
u := partialUser{
|
||||
Name: "Letícia",
|
||||
Age: nullable.Int(22),
|
||||
}
|
||||
r := c.db.Table(c.tableName).Create(&u)
|
||||
assert.Equal(t, nil, r.Error)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
err = c.Update(ctx, partialUser{
|
||||
ID: u.ID,
|
||||
// Should be updated because it is not null, just empty:
|
||||
Name: "",
|
||||
// Should not be updated because it is null:
|
||||
Age: nil,
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
var result User
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
assert.Equal(t, nil, it.Error)
|
||||
assert.Equal(t, "", result.Name)
|
||||
assert.Equal(t, 22, result.Age)
|
||||
})
|
||||
|
||||
t.Run("should update valid pointers on partial updates", func(t *testing.T) {
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, driver, "users")
|
||||
|
||||
type partialUser struct {
|
||||
ID uint `gorm:"id"`
|
||||
Name string `gorm:"name"`
|
||||
Age *int `gorm:"age"`
|
||||
}
|
||||
u := partialUser{
|
||||
Name: "Letícia",
|
||||
Age: nullable.Int(22),
|
||||
}
|
||||
r := c.db.Table(c.tableName).Create(&u)
|
||||
assert.Equal(t, nil, r.Error)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
// Should update all fields:
|
||||
err = c.Update(ctx, partialUser{
|
||||
ID: u.ID,
|
||||
Name: "Thay",
|
||||
Age: nullable.Int(42),
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
var result User
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
assert.Equal(t, nil, it.Error)
|
||||
assert.Equal(t, "Thay", result.Name)
|
||||
assert.Equal(t, 42, result.Age)
|
||||
})
|
||||
|
||||
t.Run("should report database errors correctly", func(t *testing.T) {
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, driver, "non_existing_table")
|
||||
|
||||
err = c.Update(ctx, User{
|
||||
ID: 1,
|
||||
Name: "Thayane",
|
||||
})
|
||||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("should ignore empty lists of ids", func(t *testing.T) {
|
||||
db := connectDB(t, "sqlite3")
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, "sqlite3", "users")
|
||||
|
||||
u := User{
|
||||
Name: "Thay",
|
||||
}
|
||||
err := c.Insert(ctx, &u)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
// Empty update, should do nothing:
|
||||
err = c.Update(ctx)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
result := User{}
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
it.Close()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
assert.Equal(t, "Thay", result.Name)
|
||||
})
|
||||
|
||||
t.Run("should update one user correctly", func(t *testing.T) {
|
||||
db := connectDB(t, "sqlite3")
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, "sqlite3", "users")
|
||||
|
||||
u := User{
|
||||
Name: "Letícia",
|
||||
}
|
||||
r := c.db.Table(c.tableName).Create(&u)
|
||||
assert.Equal(t, nil, r.Error)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
err = c.Update(ctx, User{
|
||||
ID: u.ID,
|
||||
Name: "Thayane",
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
var result User
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
assert.Equal(t, nil, it.Error)
|
||||
assert.Equal(t, "Thayane", result.Name)
|
||||
})
|
||||
|
||||
t.Run("should update one user correctly", func(t *testing.T) {
|
||||
db := connectDB(t, "sqlite3")
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, "sqlite3", "users")
|
||||
|
||||
u := User{
|
||||
Name: "Letícia",
|
||||
}
|
||||
r := c.db.Table(c.tableName).Create(&u)
|
||||
assert.Equal(t, nil, r.Error)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
err = c.Update(ctx, User{
|
||||
ID: u.ID,
|
||||
Name: "Thayane",
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
var result User
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
assert.Equal(t, nil, it.Error)
|
||||
assert.Equal(t, "Thayane", result.Name)
|
||||
})
|
||||
|
||||
t.Run("should ignore null pointers on partial updates", func(t *testing.T) {
|
||||
db := connectDB(t, "sqlite3")
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, "sqlite3", "users")
|
||||
|
||||
type partialUser struct {
|
||||
ID uint `gorm:"id"`
|
||||
Name string `gorm:"name"`
|
||||
Age *int `gorm:"age"`
|
||||
}
|
||||
u := partialUser{
|
||||
Name: "Letícia",
|
||||
Age: nullable.Int(22),
|
||||
}
|
||||
r := c.db.Table(c.tableName).Create(&u)
|
||||
assert.Equal(t, nil, r.Error)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
err = c.Update(ctx, partialUser{
|
||||
ID: u.ID,
|
||||
// Should be updated because it is not null, just empty:
|
||||
Name: "",
|
||||
// Should not be updated because it is null:
|
||||
Age: nil,
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
var result User
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
assert.Equal(t, nil, it.Error)
|
||||
assert.Equal(t, "", result.Name)
|
||||
assert.Equal(t, 22, result.Age)
|
||||
})
|
||||
|
||||
t.Run("should update valid pointers on partial updates", func(t *testing.T) {
|
||||
db := connectDB(t, "sqlite3")
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, "sqlite3", "users")
|
||||
|
||||
type partialUser struct {
|
||||
ID uint `gorm:"id"`
|
||||
Name string `gorm:"name"`
|
||||
Age *int `gorm:"age"`
|
||||
}
|
||||
u := partialUser{
|
||||
Name: "Letícia",
|
||||
Age: nullable.Int(22),
|
||||
}
|
||||
r := c.db.Table(c.tableName).Create(&u)
|
||||
assert.Equal(t, nil, r.Error)
|
||||
assert.NotEqual(t, uint(0), u.ID)
|
||||
|
||||
// Should update all fields:
|
||||
err = c.Update(ctx, partialUser{
|
||||
ID: u.ID,
|
||||
Name: "Thay",
|
||||
Age: nullable.Int(42),
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
var result User
|
||||
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
|
||||
it.Scan(&result)
|
||||
assert.Equal(t, nil, it.Error)
|
||||
assert.Equal(t, "Thay", result.Name)
|
||||
assert.Equal(t, 42, result.Age)
|
||||
})
|
||||
|
||||
t.Run("should report database errors correctly", func(t *testing.T) {
|
||||
db := connectDB(t, "sqlite3")
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestClient(db, "sqlite3", "non_existing_table")
|
||||
|
||||
err = c.Update(ctx, User{
|
||||
ID: 1,
|
||||
Name: "Thayane",
|
||||
})
|
||||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructToMap(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue