mirror of https://github.com/VinGarcia/ksql.git
Test Insert when the ID retrieval is not supported
parent
c768876908
commit
933ded26f4
|
@ -311,6 +311,35 @@ func TestInsert(t *testing.T) {
|
|||
|
||||
assert.Equal(t, u.Name, result.Name)
|
||||
})
|
||||
|
||||
t.Run("should insert ignoring the ID for sqlite & multiple ids", func(t *testing.T) {
|
||||
if driver != "sqlite3" {
|
||||
return
|
||||
}
|
||||
|
||||
db := connectDB(t, driver)
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
// Using columns "id" and "name" as IDs:
|
||||
c, err := New(driver, connectionString[driver], 1, "users", "id", "name")
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
u := User{
|
||||
Name: "No ID returned",
|
||||
Age: 3434, // Random number to avoid false positives on this test
|
||||
}
|
||||
|
||||
err = c.Insert(ctx, &u)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, uint(0), u.ID)
|
||||
|
||||
result := User{}
|
||||
err = getUserByName(c.db, c.dialect, &result, "No ID returned")
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
assert.Equal(t, u.Age, result.Age)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("testing error cases", func(t *testing.T) {
|
||||
|
@ -1318,14 +1347,18 @@ func createTable(driver string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newTestDB(db *sql.DB, driver string, tableName string) DB {
|
||||
func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB {
|
||||
if len(ids) == 0 {
|
||||
ids = []string{"id"}
|
||||
}
|
||||
|
||||
return DB{
|
||||
driver: driver,
|
||||
dialect: getDriverDialect(driver),
|
||||
db: db,
|
||||
tableName: tableName,
|
||||
|
||||
idCols: []string{"id"},
|
||||
idCols: ids,
|
||||
insertMethod: map[string]insertMethod{
|
||||
"sqlite3": insertWithLastInsertID,
|
||||
"postgres": insertWithReturning,
|
||||
|
@ -1404,3 +1437,14 @@ func getUserByID(dbi sqlProvider, dialect dialect, result *User, id uint) error
|
|||
err := row.Scan(&result.ID, &result.Name, &result.Age)
|
||||
return err
|
||||
}
|
||||
|
||||
func getUserByName(dbi sqlProvider, dialect dialect, result *User, name string) error {
|
||||
db := dbi.(*sql.DB)
|
||||
|
||||
row := db.QueryRow(`SELECT id, name, age FROM users WHERE name=`+dialect.Placeholder(0), name)
|
||||
if row.Err() != nil {
|
||||
return row.Err()
|
||||
}
|
||||
err := row.Scan(&result.ID, &result.Name, &result.Age)
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue