diff --git a/internal/structs/structs.go b/internal/structs/structs.go index 90c88c5..5e74f49 100644 --- a/internal/structs/structs.go +++ b/internal/structs/structs.go @@ -150,11 +150,13 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) { field := v.Field(i) ft := field.Type() if ft.Kind() == reflect.Ptr { - if field.IsNil() && !fieldInfo.Modifier.Nullable { - continue + if !field.IsNil() { + field = field.Elem() + } else { + if !fieldInfo.Modifier.Nullable { + continue + } } - - field = field.Elem() } m[fieldInfo.ColumnName] = field.Interface() diff --git a/ksql.go b/ksql.go index 808991a..3249759 100644 --- a/ksql.go +++ b/ksql.go @@ -13,7 +13,6 @@ import ( "github.com/vingarcia/ksql/internal/modifiers" "github.com/vingarcia/ksql/internal/structs" "github.com/vingarcia/ksql/ksqlmodifiers" - "github.com/vingarcia/ksql/ksqltest" ) var selectQueryCache = initializeQueryCache() @@ -631,7 +630,7 @@ func normalizeIDsAsMap(idNames []string, idOrMap interface{}) (idMap map[string] switch t.Kind() { case reflect.Struct: - idMap, err = ksqltest.StructToMap(idOrMap) + idMap, err = structs.StructToMap(idOrMap) if err != nil { return nil, fmt.Errorf("could not get ID(s) from input record: %w", err) } @@ -724,7 +723,7 @@ func buildInsertQuery( info structs.StructInfo, record interface{}, ) (query string, params []interface{}, scanValues []interface{}, err error) { - recordMap, err := ksqltest.StructToMap(record) + recordMap, err := structs.StructToMap(record) if err != nil { return "", nil, nil, err } diff --git a/test_adapters.go b/test_adapters.go index c801c96..89ecb43 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -3070,6 +3070,142 @@ func ModifiersTest( tt.AssertEqual(t, taggedUser.Name, "Marta Ribeiro") }) }) + + t.Run("nullable modifier", func(t *testing.T) { + t.Run("should prevent null fields from being ignored during insertions", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + var taggedUser struct { + ID uint `ksql:"id"` + NullableField *string `ksql:"nullable_field,nullable"` + } + + var untaggedUser struct { + ID uint `ksql:"id"` + NullableField *string `ksql:"nullable_field"` + } + + err := c.Insert(ctx, usersTable, &taggedUser) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, taggedUser.ID, 0) + + err = c.QueryOne(ctx, &taggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), taggedUser.ID) + tt.AssertNoErr(t, err) + + err = c.Insert(ctx, usersTable, &untaggedUser) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, taggedUser.ID, 0) + + err = c.QueryOne(ctx, &untaggedUser, `FROM users WHERE id = `+c.dialect.Placeholder(0), untaggedUser.ID) + tt.AssertNoErr(t, err) + + tt.AssertEqual(t, taggedUser.NullableField == nil, true) + tt.AssertEqual(t, untaggedUser.NullableField, nullable.String("not_null")) + }) + + t.Run("should prevent null fields from being ignored during updates", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + type userWithNoTags struct { + ID uint `ksql:"id"` + Name string `ksql:"name"` + NullableField *string `ksql:"nullable_field"` + } + untaggedUser := userWithNoTags{ + Name: "Laurinha Ribeiro", + NullableField: nullable.String("fakeValue"), + } + err := c.Insert(ctx, usersTable, &untaggedUser) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, untaggedUser.ID, 0) + + type taggedUser struct { + ID uint `ksql:"id"` + Name string `ksql:"name"` + NullableField *string `ksql:"nullable_field,nullable"` + } + u := taggedUser{ + ID: untaggedUser.ID, + Name: "Laura Ribeiro", + NullableField: nil, + } + err = c.Patch(ctx, usersTable, u) + tt.AssertNoErr(t, err) + + var untaggedUser2 userWithNoTags + err = c.QueryOne(ctx, &untaggedUser2, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, untaggedUser2.Name, "Laura Ribeiro") + tt.AssertEqual(t, untaggedUser2.NullableField == nil, true) + }) + + t.Run("should not alter the value on queries", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + type userWithNoTags struct { + ID uint `ksql:"id"` + Name *string `ksql:"name"` + } + untaggedUser := userWithNoTags{ + Name: nullable.String("Marta Ribeiro"), + } + err := c.Insert(ctx, usersTable, &untaggedUser) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, untaggedUser.ID, 0) + + var taggedUser struct { + ID uint `ksql:"id"` + Name *string `ksql:"name,nullable"` + } + err = c.QueryOne(ctx, &taggedUser, "FROM users WHERE id = "+c.dialect.Placeholder(0), untaggedUser.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, taggedUser.ID, untaggedUser.ID) + tt.AssertEqual(t, taggedUser.Name, nullable.String("Marta Ribeiro")) + }) + + t.Run("should cause no effect if used on a non pointer field", func(t *testing.T) { + db, closer := newDBAdapter(t) + defer closer.Close() + + c := newTestDB(db, driver) + + type user struct { + ID uint `ksql:"id"` + Name string `ksql:"name,nullable"` + Age int `ksql:"age,nullable"` + } + u1 := user{ + Name: "Marta Ribeiro", + } + err := c.Insert(ctx, usersTable, &u1) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, u1.ID, 0) + + err = c.Patch(ctx, usersTable, &struct { + ID uint `ksql:"id"` + Age int `ksql:"age,nullable"` + }{ + ID: u1.ID, + Age: 42, + }) + + var u2 user + err = c.QueryOne(ctx, &u2, "FROM users WHERE id = "+c.dialect.Placeholder(0), u1.ID) + tt.AssertNoErr(t, err) + tt.AssertEqual(t, u2.ID, u1.ID) + tt.AssertEqual(t, u2.Name, "Marta Ribeiro") + tt.AssertEqual(t, u2.Age, 42) + }) + }) }) } @@ -3317,7 +3453,8 @@ func createTables(driver string, connStr string) error { name TEXT, address BLOB, created_at DATETIME, - updated_at DATETIME + updated_at DATETIME, + nullable_field TEXT DEFAULT "not_null" )`) case "postgres": _, err = db.Exec(`CREATE TABLE users ( @@ -3326,7 +3463,8 @@ func createTables(driver string, connStr string) error { name VARCHAR(50), address jsonb, created_at TIMESTAMP, - updated_at TIMESTAMP + updated_at TIMESTAMP, + nullable_field VARCHAR(50) DEFAULT 'not_null' )`) case "mysql": _, err = db.Exec(`CREATE TABLE users ( @@ -3335,7 +3473,8 @@ func createTables(driver string, connStr string) error { name VARCHAR(50), address JSON, created_at DATETIME, - updated_at DATETIME + updated_at DATETIME, + nullable_field VARCHAR(50) DEFAULT "not_null" )`) case "sqlserver": _, err = db.Exec(`CREATE TABLE users ( @@ -3344,7 +3483,8 @@ func createTables(driver string, connStr string) error { name VARCHAR(50), address NVARCHAR(4000), created_at DATETIME, - updated_at DATETIME + updated_at DATETIME, + nullable_field VARCHAR(50) DEFAULT 'not_null' )`) } if err != nil {