diff --git a/benchmarks/sqlcgen/db.go b/benchmarks/sqlcgen/db.go index 7f34027..af38f8d 100644 --- a/benchmarks/sqlcgen/db.go +++ b/benchmarks/sqlcgen/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.14.0 +// sqlc v1.26.0 package sqlcgen diff --git a/benchmarks/sqlcgen/models.go b/benchmarks/sqlcgen/models.go index 4c83f97..f108b31 100644 --- a/benchmarks/sqlcgen/models.go +++ b/benchmarks/sqlcgen/models.go @@ -1,11 +1,9 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.14.0 +// sqlc v1.26.0 package sqlcgen -import () - type User struct { ID int32 Name string diff --git a/benchmarks/sqlcgen/queries.sql.go b/benchmarks/sqlcgen/queries.sql.go index aba4a14..50dc904 100644 --- a/benchmarks/sqlcgen/queries.sql.go +++ b/benchmarks/sqlcgen/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.14.0 +// sqlc v1.26.0 // source: queries.sql package sqlcgen diff --git a/ksql.go b/ksql.go index 5ab6ce4..db48385 100644 --- a/ksql.go +++ b/ksql.go @@ -517,22 +517,38 @@ func (c DB) insertWithLastInsertID( ) error { result, err := c.db.ExecContext(ctx, query, params...) if err != nil { - return err + return fmt.Errorf("error running insert query: %w", err) } id, err := result.LastInsertId() if err != nil { - return err + return fmt.Errorf("error fetching LastInsertId: %w", err) } vID := reflect.ValueOf(id) - fieldAddr := v.Elem().Field(info.ByName(idName).Index).Addr() - fieldType := fieldAddr.Type().Elem() + fieldValue := v.Elem().Field(info.ByName(idName).Index) + fieldType := fieldValue.Type() - switch fieldType.Kind() { + baseFieldKind := fieldType.Kind() + leafFieldKind := baseFieldKind + if baseFieldKind == reflect.Pointer { + leafFieldKind = fieldType.Elem().Kind() + } + + switch leafFieldKind { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64: - fieldAddr.Elem().Set(vID.Convert(fieldType)) + if baseFieldKind == reflect.Pointer { + // If fieldValue is nil allocate memory for it: + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldType.Elem())) + } + + fieldValue.Elem().Set(vID.Convert(fieldType.Elem())) + return nil + } + + fieldValue.Set(vID.Convert(fieldType)) return nil case reflect.String: @@ -540,12 +556,6 @@ func (c DB) insertWithLastInsertID( // we cannot retrieve it, so we just return: return nil - case reflect.Pointer: - if fieldType.Elem().Kind() == reflect.String { - return nil - } - - fallthrough default: return fmt.Errorf( "error scanning field `%s` cannot assign last insert id of type int64 into field of type %v", diff --git a/test_adapters.go b/test_adapters.go index 52f6452..f4f82bc 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -841,6 +841,66 @@ func InsertTest( tt.AssertEqual(t, result.Address, u.Address) }) + t.Run("should insert one user correctly when the ID is a pointer to int allocating memory when nil", func(t *testing.T) { + c := newTestDB(db, dialect) + + type ptrUser struct { + ID *uint `ksql:"id"` + Name string `ksql:"name"` + Address address `ksql:"address,json"` + Age int `ksql:"age"` + } + u := ptrUser{ + Name: "Paulo", + Address: address{ + Country: "Brazil", + }, + } + + err := c.Insert(ctx, usersTable, &u) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, u.ID, nil) + + result := user{} + err = getUserByID(db, dialect, &result, *u.ID) + tt.AssertNoErr(t, err) + + tt.AssertEqual(t, result.Name, u.Name) + tt.AssertEqual(t, result.Address, u.Address) + }) + + t.Run("should insert one user correctly when the ID is a pointer to int reusing existing int", func(t *testing.T) { + c := newTestDB(db, dialect) + + type ptrUser struct { + ID *uint `ksql:"id"` + Name string `ksql:"name"` + Address address `ksql:"address,json"` + Age int `ksql:"age"` + } + var id uint = 0 + u := ptrUser{ + ID: &id, + Name: "Paulo", + Address: address{ + Country: "Brazil", + }, + } + + err := c.Insert(ctx, usersTable, &u) + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, u.ID, nil) + tt.AssertEqual(t, id, *u.ID) + + result := user{} + err = getUserByID(db, dialect, &result, *u.ID) + tt.AssertNoErr(t, err) + + tt.AssertEqual(t, result.ID, id) + tt.AssertEqual(t, result.Name, u.Name) + tt.AssertEqual(t, result.Address, u.Address) + }) + t.Run("should insert ignoring the ID with multiple ids", func(t *testing.T) { if dialect.InsertMethod() != sqldialect.InsertWithLastInsertID { return