mirror of https://github.com/VinGarcia/ksql.git
Fix issue on kmysql Insert where an *int ID would not work
parent
661630db8d
commit
3f0e9b9a3e
|
@ -1,6 +1,6 @@
|
|||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.14.0
|
||||
// sqlc v1.26.0
|
||||
|
||||
package sqlcgen
|
||||
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.14.0
|
||||
// sqlc v1.26.0
|
||||
|
||||
package sqlcgen
|
||||
|
||||
import ()
|
||||
|
||||
type User struct {
|
||||
ID int32
|
||||
Name string
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.14.0
|
||||
// sqlc v1.26.0
|
||||
// source: queries.sql
|
||||
|
||||
package sqlcgen
|
||||
|
|
34
ksql.go
34
ksql.go
|
@ -517,22 +517,38 @@ func (c DB) insertWithLastInsertID(
|
|||
) error {
|
||||
result, err := c.db.ExecContext(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error running insert query: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error fetching LastInsertId: %w", err)
|
||||
}
|
||||
|
||||
vID := reflect.ValueOf(id)
|
||||
|
||||
fieldAddr := v.Elem().Field(info.ByName(idName).Index).Addr()
|
||||
fieldType := fieldAddr.Type().Elem()
|
||||
fieldValue := v.Elem().Field(info.ByName(idName).Index)
|
||||
fieldType := fieldValue.Type()
|
||||
|
||||
switch fieldType.Kind() {
|
||||
baseFieldKind := fieldType.Kind()
|
||||
leafFieldKind := baseFieldKind
|
||||
if baseFieldKind == reflect.Pointer {
|
||||
leafFieldKind = fieldType.Elem().Kind()
|
||||
}
|
||||
|
||||
switch leafFieldKind {
|
||||
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64:
|
||||
fieldAddr.Elem().Set(vID.Convert(fieldType))
|
||||
if baseFieldKind == reflect.Pointer {
|
||||
// If fieldValue is nil allocate memory for it:
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(fieldType.Elem()))
|
||||
}
|
||||
|
||||
fieldValue.Elem().Set(vID.Convert(fieldType.Elem()))
|
||||
return nil
|
||||
}
|
||||
|
||||
fieldValue.Set(vID.Convert(fieldType))
|
||||
return nil
|
||||
|
||||
case reflect.String:
|
||||
|
@ -540,12 +556,6 @@ func (c DB) insertWithLastInsertID(
|
|||
// we cannot retrieve it, so we just return:
|
||||
return nil
|
||||
|
||||
case reflect.Pointer:
|
||||
if fieldType.Elem().Kind() == reflect.String {
|
||||
return nil
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"error scanning field `%s` cannot assign last insert id of type int64 into field of type %v",
|
||||
|
|
|
@ -841,6 +841,66 @@ func InsertTest(
|
|||
tt.AssertEqual(t, result.Address, u.Address)
|
||||
})
|
||||
|
||||
t.Run("should insert one user correctly when the ID is a pointer to int allocating memory when nil", func(t *testing.T) {
|
||||
c := newTestDB(db, dialect)
|
||||
|
||||
type ptrUser struct {
|
||||
ID *uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
Address address `ksql:"address,json"`
|
||||
Age int `ksql:"age"`
|
||||
}
|
||||
u := ptrUser{
|
||||
Name: "Paulo",
|
||||
Address: address{
|
||||
Country: "Brazil",
|
||||
},
|
||||
}
|
||||
|
||||
err := c.Insert(ctx, usersTable, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, u.ID, nil)
|
||||
|
||||
result := user{}
|
||||
err = getUserByID(db, dialect, &result, *u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
tt.AssertEqual(t, result.Name, u.Name)
|
||||
tt.AssertEqual(t, result.Address, u.Address)
|
||||
})
|
||||
|
||||
t.Run("should insert one user correctly when the ID is a pointer to int reusing existing int", func(t *testing.T) {
|
||||
c := newTestDB(db, dialect)
|
||||
|
||||
type ptrUser struct {
|
||||
ID *uint `ksql:"id"`
|
||||
Name string `ksql:"name"`
|
||||
Address address `ksql:"address,json"`
|
||||
Age int `ksql:"age"`
|
||||
}
|
||||
var id uint = 0
|
||||
u := ptrUser{
|
||||
ID: &id,
|
||||
Name: "Paulo",
|
||||
Address: address{
|
||||
Country: "Brazil",
|
||||
},
|
||||
}
|
||||
|
||||
err := c.Insert(ctx, usersTable, &u)
|
||||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, u.ID, nil)
|
||||
tt.AssertEqual(t, id, *u.ID)
|
||||
|
||||
result := user{}
|
||||
err = getUserByID(db, dialect, &result, *u.ID)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
tt.AssertEqual(t, result.ID, id)
|
||||
tt.AssertEqual(t, result.Name, u.Name)
|
||||
tt.AssertEqual(t, result.Address, u.Address)
|
||||
})
|
||||
|
||||
t.Run("should insert ignoring the ID with multiple ids", func(t *testing.T) {
|
||||
if dialect.InsertMethod() != sqldialect.InsertWithLastInsertID {
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue