Fix Update() to work with postgres dialect

pull/2/head
Vinícius Garcia 2020-12-29 23:51:31 -03:00
parent de8f4e56d7
commit f782fabb37
2 changed files with 194 additions and 180 deletions

View File

@ -363,7 +363,7 @@ func (c Client) Update(
records ...interface{},
) error {
for _, record := range records {
query, params, err := buildUpdateQuery(c.tableName, record, "id")
query, params, err := buildUpdateQuery(c.dialect, c.tableName, record, "id")
if err != nil {
return err
}
@ -424,6 +424,7 @@ func buildInsertQuery(
}
func buildUpdateQuery(
dialect dialect,
tableName string,
record interface{},
idFieldNames ...string,
@ -433,14 +434,19 @@ func buildUpdateQuery(
return "", nil, err
}
numAttrs := len(recordMap)
numIDs := len(idFieldNames)
args = make([]interface{}, numAttrs+numIDs)
whereArgs := args[numAttrs-len(idFieldNames):]
args = make([]interface{}, numAttrs)
numNonIDArgs := numAttrs - len(idFieldNames)
whereArgs := args[numNonIDArgs:]
var whereQuery []string
whereQuery := make([]string, len(idFieldNames))
for i, fieldName := range idFieldNames {
whereArgs[i] = recordMap[fieldName]
whereQuery = append(whereQuery, fmt.Sprintf("`%s` = ?", fieldName))
whereQuery[i] = fmt.Sprintf(
"%s = %s",
dialect.Escape(fieldName),
dialect.Placeholder(i+numNonIDArgs),
)
delete(recordMap, fieldName)
}
@ -452,12 +458,16 @@ func buildUpdateQuery(
var setQuery []string
for i, k := range keys {
args[i] = recordMap[k]
setQuery = append(setQuery, fmt.Sprintf("`%s` = ?", k))
setQuery = append(setQuery, fmt.Sprintf(
"%s = %s",
dialect.Escape(k),
dialect.Placeholder(i),
))
}
query = fmt.Sprintf(
"UPDATE `%s` SET %s WHERE %s",
tableName,
"UPDATE %s SET %s WHERE %s",
dialect.Escape(tableName),
strings.Join(setQuery, ", "),
strings.Join(whereQuery, ", "),
)

View File

@ -364,178 +364,182 @@ func TestDelete(t *testing.T) {
}
func TestUpdate(t *testing.T) {
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
for _, driver := range []string{"sqlite3", "postgres"} {
t.Run(driver, func(t *testing.T) {
err := createTable(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should ignore empty lists of ids", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestClient(db, driver, "users")
u := User{
Name: "Thay",
}
err := c.Insert(ctx, &u)
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u.ID)
// Empty update, should do nothing:
err = c.Update(ctx)
assert.Equal(t, nil, err)
result := User{}
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
it.Close()
assert.Equal(t, nil, err)
assert.Equal(t, "Thay", result.Name)
})
t.Run("should update one user correctly", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestClient(db, driver, "users")
u := User{
Name: "Letícia",
}
r := c.db.Table(c.tableName).Create(&u)
assert.Equal(t, nil, r.Error)
assert.NotEqual(t, uint(0), u.ID)
err = c.Update(ctx, User{
ID: u.ID,
Name: "Thayane",
})
assert.Equal(t, nil, err)
var result User
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, "Thayane", result.Name)
})
t.Run("should update one user correctly", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestClient(db, driver, "users")
u := User{
Name: "Letícia",
}
r := c.db.Table(c.tableName).Create(&u)
assert.Equal(t, nil, r.Error)
assert.NotEqual(t, uint(0), u.ID)
err = c.Update(ctx, User{
ID: u.ID,
Name: "Thayane",
})
assert.Equal(t, nil, err)
var result User
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, "Thayane", result.Name)
})
t.Run("should ignore null pointers on partial updates", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestClient(db, driver, "users")
type partialUser struct {
ID uint `gorm:"id"`
Name string `gorm:"name"`
Age *int `gorm:"age"`
}
u := partialUser{
Name: "Letícia",
Age: nullable.Int(22),
}
r := c.db.Table(c.tableName).Create(&u)
assert.Equal(t, nil, r.Error)
assert.NotEqual(t, uint(0), u.ID)
err = c.Update(ctx, partialUser{
ID: u.ID,
// Should be updated because it is not null, just empty:
Name: "",
// Should not be updated because it is null:
Age: nil,
})
assert.Equal(t, nil, err)
var result User
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, "", result.Name)
assert.Equal(t, 22, result.Age)
})
t.Run("should update valid pointers on partial updates", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestClient(db, driver, "users")
type partialUser struct {
ID uint `gorm:"id"`
Name string `gorm:"name"`
Age *int `gorm:"age"`
}
u := partialUser{
Name: "Letícia",
Age: nullable.Int(22),
}
r := c.db.Table(c.tableName).Create(&u)
assert.Equal(t, nil, r.Error)
assert.NotEqual(t, uint(0), u.ID)
// Should update all fields:
err = c.Update(ctx, partialUser{
ID: u.ID,
Name: "Thay",
Age: nullable.Int(42),
})
assert.Equal(t, nil, err)
var result User
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, "Thay", result.Name)
assert.Equal(t, 42, result.Age)
})
t.Run("should report database errors correctly", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestClient(db, driver, "non_existing_table")
err = c.Update(ctx, User{
ID: 1,
Name: "Thayane",
})
assert.NotEqual(t, nil, err)
})
})
}
t.Run("should ignore empty lists of ids", func(t *testing.T) {
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := newTestClient(db, "sqlite3", "users")
u := User{
Name: "Thay",
}
err := c.Insert(ctx, &u)
assert.Equal(t, nil, err)
assert.NotEqual(t, uint(0), u.ID)
// Empty update, should do nothing:
err = c.Update(ctx)
assert.Equal(t, nil, err)
result := User{}
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
it.Close()
assert.Equal(t, nil, err)
assert.Equal(t, "Thay", result.Name)
})
t.Run("should update one user correctly", func(t *testing.T) {
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := newTestClient(db, "sqlite3", "users")
u := User{
Name: "Letícia",
}
r := c.db.Table(c.tableName).Create(&u)
assert.Equal(t, nil, r.Error)
assert.NotEqual(t, uint(0), u.ID)
err = c.Update(ctx, User{
ID: u.ID,
Name: "Thayane",
})
assert.Equal(t, nil, err)
var result User
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, "Thayane", result.Name)
})
t.Run("should update one user correctly", func(t *testing.T) {
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := newTestClient(db, "sqlite3", "users")
u := User{
Name: "Letícia",
}
r := c.db.Table(c.tableName).Create(&u)
assert.Equal(t, nil, r.Error)
assert.NotEqual(t, uint(0), u.ID)
err = c.Update(ctx, User{
ID: u.ID,
Name: "Thayane",
})
assert.Equal(t, nil, err)
var result User
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, "Thayane", result.Name)
})
t.Run("should ignore null pointers on partial updates", func(t *testing.T) {
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := newTestClient(db, "sqlite3", "users")
type partialUser struct {
ID uint `gorm:"id"`
Name string `gorm:"name"`
Age *int `gorm:"age"`
}
u := partialUser{
Name: "Letícia",
Age: nullable.Int(22),
}
r := c.db.Table(c.tableName).Create(&u)
assert.Equal(t, nil, r.Error)
assert.NotEqual(t, uint(0), u.ID)
err = c.Update(ctx, partialUser{
ID: u.ID,
// Should be updated because it is not null, just empty:
Name: "",
// Should not be updated because it is null:
Age: nil,
})
assert.Equal(t, nil, err)
var result User
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, "", result.Name)
assert.Equal(t, 22, result.Age)
})
t.Run("should update valid pointers on partial updates", func(t *testing.T) {
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := newTestClient(db, "sqlite3", "users")
type partialUser struct {
ID uint `gorm:"id"`
Name string `gorm:"name"`
Age *int `gorm:"age"`
}
u := partialUser{
Name: "Letícia",
Age: nullable.Int(22),
}
r := c.db.Table(c.tableName).Create(&u)
assert.Equal(t, nil, r.Error)
assert.NotEqual(t, uint(0), u.ID)
// Should update all fields:
err = c.Update(ctx, partialUser{
ID: u.ID,
Name: "Thay",
Age: nullable.Int(42),
})
assert.Equal(t, nil, err)
var result User
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, "Thay", result.Name)
assert.Equal(t, 42, result.Age)
})
t.Run("should report database errors correctly", func(t *testing.T) {
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := newTestClient(db, "sqlite3", "non_existing_table")
err = c.Update(ctx, User{
ID: 1,
Name: "Thayane",
})
assert.NotEqual(t, nil, err)
})
}
func TestStructToMap(t *testing.T) {