Add tests to the nullable Modifier

pull/32/head
Vinícius Garcia 2022-12-15 19:48:31 -03:00
parent 4598800f87
commit 88167361c1
3 changed files with 152 additions and 11 deletions

View File

@ -150,11 +150,13 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) {
field := v.Field(i) field := v.Field(i)
ft := field.Type() ft := field.Type()
if ft.Kind() == reflect.Ptr { if ft.Kind() == reflect.Ptr {
if field.IsNil() && !fieldInfo.Modifier.Nullable { if !field.IsNil() {
continue field = field.Elem()
} else {
if !fieldInfo.Modifier.Nullable {
continue
}
} }
field = field.Elem()
} }
m[fieldInfo.ColumnName] = field.Interface() m[fieldInfo.ColumnName] = field.Interface()

View File

@ -13,7 +13,6 @@ import (
"github.com/vingarcia/ksql/internal/modifiers" "github.com/vingarcia/ksql/internal/modifiers"
"github.com/vingarcia/ksql/internal/structs" "github.com/vingarcia/ksql/internal/structs"
"github.com/vingarcia/ksql/ksqlmodifiers" "github.com/vingarcia/ksql/ksqlmodifiers"
"github.com/vingarcia/ksql/ksqltest"
) )
var selectQueryCache = initializeQueryCache() var selectQueryCache = initializeQueryCache()
@ -631,7 +630,7 @@ func normalizeIDsAsMap(idNames []string, idOrMap interface{}) (idMap map[string]
switch t.Kind() { switch t.Kind() {
case reflect.Struct: case reflect.Struct:
idMap, err = ksqltest.StructToMap(idOrMap) idMap, err = structs.StructToMap(idOrMap)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get ID(s) from input record: %w", err) return nil, fmt.Errorf("could not get ID(s) from input record: %w", err)
} }
@ -724,7 +723,7 @@ func buildInsertQuery(
info structs.StructInfo, info structs.StructInfo,
record interface{}, record interface{},
) (query string, params []interface{}, scanValues []interface{}, err error) { ) (query string, params []interface{}, scanValues []interface{}, err error) {
recordMap, err := ksqltest.StructToMap(record) recordMap, err := structs.StructToMap(record)
if err != nil { if err != nil {
return "", nil, nil, err return "", nil, nil, err
} }

View File

@ -3070,6 +3070,142 @@ func ModifiersTest(
tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro") tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro")
}) })
}) })
t.Run("nullable modifier", func(t *testing.T) {
t.Run("should prevent null fields from being ignored during insertions", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
var taggedUser struct {
ID uint `ksql:"id"`
NullableField *string `ksql:"nullable_field,nullable"`
}
var untaggedUser struct {
ID uint `ksql:"id"`
NullableField *string `ksql:"nullable_field"`
}
err := c.Insert(ctx, usersTable, &taggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, taggedUser.ID, 0)
err = c.QueryOne(ctx, &taggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), taggedUser.ID)
tt.AssertNoErr(t, err)
err = c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, taggedUser.ID, 0)
err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.NullableField == nil, true)
tt.AssertEqual(t, untaggedUser.NullableField, nullable.String("not_null"))
})
t.Run("should prevent null fields from being ignored during updates", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
NullableField *string `ksql:"nullable_field"`
}
untaggedUser := userWithNoTags{
Name: "Laurinha Ribeiro",
NullableField: nullable.String("fakeValue"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
type taggedUser struct {
ID uint `ksql:"id"`
Name string `ksql:"name"`
NullableField *string `ksql:"nullable_field,nullable"`
}
u := taggedUser{
ID: untaggedUser.ID,
Name: "Laura Ribeiro",
NullableField: nil,
}
err = c.Patch(ctx, usersTable, u)
tt.AssertNoErr(t, err)
var untaggedUser2 userWithNoTags
err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, untaggedUser2.Name, "Laura Ribeiro")
tt.AssertEqual(t, untaggedUser2.NullableField == nil, true)
})
t.Run("should not alter the value on queries", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type userWithNoTags struct {
ID uint `ksql:"id"`
Name *string `ksql:"name"`
}
untaggedUser := userWithNoTags{
Name: nullable.String("Marta Ribeiro"),
}
err := c.Insert(ctx, usersTable, &untaggedUser)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, untaggedUser.ID, 0)
var taggedUser struct {
ID uint `ksql:"id"`
Name *string `ksql:"name,nullable"`
}
err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID)
tt.AssertEqual(t, taggedUser.Name, nullable.String("Marta Ribeiro"))
})
t.Run("should cause no effect if used on a non pointer field", func(t *testing.T) {
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
type user struct {
ID uint `ksql:"id"`
Name string `ksql:"name,nullable"`
Age int `ksql:"age,nullable"`
}
u1 := user{
Name: "Marta Ribeiro",
}
err := c.Insert(ctx, usersTable, &u1)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u1.ID, 0)
err = c.Patch(ctx, usersTable, &struct {
ID uint `ksql:"id"`
Age int `ksql:"age,nullable"`
}{
ID: u1.ID,
Age: 42,
})
var u2 user
err = c.QueryOne(ctx, &u2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u1.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, u2.ID, u1.ID)
tt.AssertEqual(t, u2.Name, "Marta Ribeiro")
tt.AssertEqual(t, u2.Age, 42)
})
})
}) })
} }
@ -3317,7 +3453,8 @@ func createTables(driver string, connStr string) error {
name TEXT, name TEXT,
address BLOB, address BLOB,
created_at DATETIME, created_at DATETIME,
updated_at DATETIME updated_at DATETIME,
nullable_field TEXT DEFAULT "not_null"
)`) )`)
case "postgres": case "postgres":
_, err = db.Exec(`CREATE TABLE users ( _, err = db.Exec(`CREATE TABLE users (
@ -3326,7 +3463,8 @@ func createTables(driver string, connStr string) error {
name VARCHAR(50), name VARCHAR(50),
address jsonb, address jsonb,
created_at TIMESTAMP, created_at TIMESTAMP,
updated_at TIMESTAMP updated_at TIMESTAMP,
nullable_field VARCHAR(50) DEFAULT 'not_null'
)`) )`)
case "mysql": case "mysql":
_, err = db.Exec(`CREATE TABLE users ( _, err = db.Exec(`CREATE TABLE users (
@ -3335,7 +3473,8 @@ func createTables(driver string, connStr string) error {
name VARCHAR(50), name VARCHAR(50),
address JSON, address JSON,
created_at DATETIME, created_at DATETIME,
updated_at DATETIME updated_at DATETIME,
nullable_field VARCHAR(50) DEFAULT "not_null"
)`) )`)
case "sqlserver": case "sqlserver":
_, err = db.Exec(`CREATE TABLE users ( _, err = db.Exec(`CREATE TABLE users (
@ -3344,7 +3483,8 @@ func createTables(driver string, connStr string) error {
name VARCHAR(50), name VARCHAR(50),
address NVARCHAR(4000), address NVARCHAR(4000),
created_at DATETIME, created_at DATETIME,
updated_at DATETIME updated_at DATETIME,
nullable_field VARCHAR(50) DEFAULT 'not_null'
)`) )`)
} }
if err != nil { if err != nil {