Fix issue on kmysql Insert where an *int ID would not work

pull/55/head
Vinícius Garcia 2024-11-03 22:33:00 -03:00
parent 661630db8d
commit 3f0e9b9a3e
5 changed files with 85 additions and 17 deletions

View File

@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.14.0 // sqlc v1.26.0
package sqlcgen package sqlcgen

View File

@ -1,11 +1,9 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.14.0 // sqlc v1.26.0
package sqlcgen package sqlcgen
import ()
type User struct { type User struct {
ID int32 ID int32
Name string Name string

View File

@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.14.0 // sqlc v1.26.0
// source: queries.sql // source: queries.sql
package sqlcgen package sqlcgen

34
ksql.go
View File

@ -517,22 +517,38 @@ func (c DB) insertWithLastInsertID(
) error { ) error {
result, err := c.db.ExecContext(ctx, query, params...) result, err := c.db.ExecContext(ctx, query, params...)
if err != nil { if err != nil {
return err return fmt.Errorf("error running insert query: %w", err)
} }
id, err := result.LastInsertId() id, err := result.LastInsertId()
if err != nil { if err != nil {
return err return fmt.Errorf("error fetching LastInsertId: %w", err)
} }
vID := reflect.ValueOf(id) vID := reflect.ValueOf(id)
fieldAddr := v.Elem().Field(info.ByName(idName).Index).Addr() fieldValue := v.Elem().Field(info.ByName(idName).Index)
fieldType := fieldAddr.Type().Elem() fieldType := fieldValue.Type()
switch fieldType.Kind() { baseFieldKind := fieldType.Kind()
leafFieldKind := baseFieldKind
if baseFieldKind == reflect.Pointer {
leafFieldKind = fieldType.Elem().Kind()
}
switch leafFieldKind {
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64: case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64:
fieldAddr.Elem().Set(vID.Convert(fieldType)) if baseFieldKind == reflect.Pointer {
// If fieldValue is nil allocate memory for it:
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldType.Elem()))
}
fieldValue.Elem().Set(vID.Convert(fieldType.Elem()))
return nil
}
fieldValue.Set(vID.Convert(fieldType))
return nil return nil
case reflect.String: case reflect.String:
@ -540,12 +556,6 @@ func (c DB) insertWithLastInsertID(
// we cannot retrieve it, so we just return: // we cannot retrieve it, so we just return:
return nil return nil
case reflect.Pointer:
if fieldType.Elem().Kind() == reflect.String {
return nil
}
fallthrough
default: default:
return fmt.Errorf( return fmt.Errorf(
"error scanning field `%s` cannot assign last insert id of type int64 into field of type %v", "error scanning field `%s` cannot assign last insert id of type int64 into field of type %v",

View File

@ -841,6 +841,66 @@ func InsertTest(
tt.AssertEqual(t, result.Address, u.Address) tt.AssertEqual(t, result.Address, u.Address)
}) })
t.Run("should insert one user correctly when the ID is a pointer to int allocating memory when nil", func(t *testing.T) {
c := newTestDB(db, dialect)
type ptrUser struct {
ID *uint `ksql:"id"`
Name string `ksql:"name"`
Address address `ksql:"address,json"`
Age int `ksql:"age"`
}
u := ptrUser{
Name: "Paulo",
Address: address{
Country: "Brazil",
},
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, nil)
result := user{}
err = getUserByID(db, dialect, &result, *u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.Name, u.Name)
tt.AssertEqual(t, result.Address, u.Address)
})
t.Run("should insert one user correctly when the ID is a pointer to int reusing existing int", func(t *testing.T) {
c := newTestDB(db, dialect)
type ptrUser struct {
ID *uint `ksql:"id"`
Name string `ksql:"name"`
Address address `ksql:"address,json"`
Age int `ksql:"age"`
}
var id uint = 0
u := ptrUser{
ID: &id,
Name: "Paulo",
Address: address{
Country: "Brazil",
},
}
err := c.Insert(ctx, usersTable, &u)
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, u.ID, nil)
tt.AssertEqual(t, id, *u.ID)
result := user{}
err = getUserByID(db, dialect, &result, *u.ID)
tt.AssertNoErr(t, err)
tt.AssertEqual(t, result.ID, id)
tt.AssertEqual(t, result.Name, u.Name)
tt.AssertEqual(t, result.Address, u.Address)
})
t.Run("should insert ignoring the ID with multiple ids", func(t *testing.T) { t.Run("should insert ignoring the ID with multiple ids", func(t *testing.T) {
if dialect.InsertMethod() != sqldialect.InsertWithLastInsertID { if dialect.InsertMethod() != sqldialect.InsertWithLastInsertID {
return return